### Importing Necessary Dependencies


In [None]:
import os
import json
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from ModelArchitecture import Transformer, ModelConfig, generate
from safetensors.torch import load_file

### Device Configurations

In [None]:
# Use the CUDA device when torch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

### Importing and Setting up Tokenizer

In [None]:
# Load the tokenizer from its serialized file (path is relative to the working directory).
tokenizer_file = "LumenTokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_file)

def encode(text: str) -> torch.LongTensor:
    """Tokenize ``text`` with the module-level tokenizer and return the ids as a tensor.

    The ids are stored with dtype ``torch.long`` and a leading axis of size 1
    is added via ``unsqueeze(0)``.
    """
    token_ids = tokenizer.encode(text).ids
    ids_tensor = torch.tensor(token_ids, dtype=torch.long)
    return ids_tensor.unsqueeze(0)

def decode(ids: torch.LongTensor) -> str:
    """Convert a tensor of token ids back to text with the module-level tokenizer.

    The tensor is turned into a plain Python list before being handed to
    ``tokenizer.decode``.
    """
    id_list = ids.tolist()
    return tokenizer.decode(id_list)


### Model Configurations

In [None]:
# Read the model hyperparameters from config.json (path is relative to the
# working directory). The encoding is pinned to UTF-8 so the file is decoded
# the same way on every platform instead of depending on the locale default.
with open("config.json", "r", encoding="utf-8") as f:
    config_dict = json.load(f)

# The JSON keys are forwarded as keyword arguments, so they have to match
# the parameter names accepted by ModelConfig.
config = ModelConfig(**config_dict)

### Initializing Model

In [None]:
# Build the network from the loaded configuration, then place it on the selected device.
model = Transformer(config)
model = model.to(device)

### Loading the Pretrained Model

#### .safetensors

In [None]:
# Load the pretrained weights from the .safetensors checkpoint onto the selected device.
weights_path = "../Models/LumenBase.safetensors"
state = load_file(weights_path, device=str(device))

# strict=False tolerates key mismatches, so report them explicitly: a
# checkpoint whose keys do not line up would otherwise "load" without error
# while leaving parameters at their initial values.
# NOTE(review): a few missing keys may be expected (e.g. tied weights that were
# not serialized) - confirm against the model definition.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
print(f"Loaded .safetensors checkpoint: {weights_path}")

# Switch to evaluation mode for inference.
model.eval()

#### .pt (Optional)

In [None]:
# Optional alternative: load the weights from a .pt checkpoint instead.
# Running this cell replaces whatever weights the model currently holds.
weights_path = "../Models/best_model_params.pt"
# SECURITY: torch.load unpickles the file, and unpickling can execute arbitrary
# code. weights_only=True restricts loading to tensors and plain containers,
# which is all a state dict needs. Only load checkpoints from trusted sources.
state = torch.load(weights_path, map_location=device, weights_only=True)

# strict=False tolerates key mismatches, so report them explicitly: a
# checkpoint whose keys do not line up would otherwise "load" without error
# while leaving parameters at their initial values.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
print(f"Loaded .pt checkpoint: {weights_path}")

# Switch to evaluation mode for inference.
model.eval()

### Generator Function 

In [None]:
@torch.no_grad()
def generate_text(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    top_k: int = 0,
    top_p: float = 0.9,
    do_sample: bool = True,
    eos_token_id: int | None = None,
    pad_token_id: int | None = None,
):
    """Generate a continuation for ``prompt`` and return it as decoded text.

    The prompt is tokenized with ``encode``, moved to the module-level
    ``device`` and passed to ``generate`` together with the module-level
    ``model``. Every sampling argument is forwarded unchanged; its exact
    meaning is defined by ``generate``. Gradient tracking is disabled for
    the whole call.

    Returns:
        The decoded text of row 0 of the generated ids, with the first
        ``prompt length`` positions removed (this assumes ``generate``
        returns the prompt ids followed by the new tokens).
    """
    prompt_ids = encode(prompt).to(device)

    # Collect the pass-through sampling options in one place.
    sampling_options = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
    )
    generated = generate(
        model=model,
        input_ids=prompt_ids,
        device=device,
        **sampling_options,
    )

    # Keep only the tokens that follow the prompt in the first (only) sequence.
    prompt_length = prompt_ids.size(1)
    new_token_ids = generated[0, prompt_length:]
    return decode(new_token_ids.cpu())


### Inference

In [None]:
# Run one sample generation and show the prompt followed by the generated continuation.
prompt = "Once upon a time"
output = generate_text(prompt, temperature=0.7, top_p=0.8, max_new_tokens=100)
# A single print with a newline separator produces the same text as four separate prints.
print("Prompt:", prompt, "\nGeneration:", output, sep="\n")
