### Importing Necessary Dependencies


In [2]:
import os
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from ModelArchitecture import Transformer, ModelConfig, generate

### Device Configurations

In [3]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

Device: cpu


### Importing and Setting up Tokenizer

In [4]:
# Load the serialized tokenizer from a JSON file; the relative path is resolved
# against the notebook's working directory.
tokenizer = Tokenizer.from_file("LumenTokenizer.json")

def encode(text: str) -> torch.LongTensor:
    """Tokenize `text` and return its token ids as a batched tensor.

    Args:
        text: Raw string to tokenize with the module-level `tokenizer`.

    Returns:
        An int64 tensor of the token ids with a leading dimension of size 1
        added in front (i.e. a single-item batch).
    """
    token_ids = tokenizer.encode(text).ids
    flat = torch.tensor(token_ids, dtype=torch.long)
    # Add the leading batch dimension.
    return torch.unsqueeze(flat, 0)

def decode(ids: torch.LongTensor) -> str:
    """Convert a tensor of token ids back into text.

    Accepts either a flat (1-D) tensor of ids or a single-row batch of shape
    (1, seq_len) such as the one returned by `encode`, so that
    `decode(encode(text))` works without manual squeezing.

    Args:
        ids: Tensor of token ids, 1-D or with a leading batch dimension of 1.

    Returns:
        The string produced by the module-level `tokenizer` for those ids.
    """
    # `tokenizer.decode` is given a flat list of ints; drop a singleton batch
    # dimension so a (1, seq_len) tensor does not become a nested list.
    if ids.dim() == 2 and ids.size(0) == 1:
        ids = ids.squeeze(0)
    return tokenizer.decode(ids.tolist())


### Model Configurations

In [5]:
# Hyperparameters handed to ModelConfig, grouped by concern.
# NOTE(review): the model repr printed further down shows bias=True on the
# feed-forward Linear layers even though mlp_bias=False is passed here —
# confirm that ModelArchitecture honours this flag.
config = ModelConfig(
    # Vocabulary and sequence limits
    vocab_size=32000,
    max_position_embeddings=2048,
    max_seq_len=2048,
    # Width and depth
    hidden_size=768,
    n_layers=12,
    intermediate_size=3072,
    # Attention layout: n_heads / n_kv_heads = 12 / 4 = 3 (n_kv_groups),
    # and n_heads * head_dim = 12 * 64 = 768 (hidden_size)
    n_heads=12,
    n_kv_heads=4,
    n_kv_groups=3,
    head_dim=64,
    attention_bias=False,
    # Feed-forward
    mlp_bias=False,
    # Normalization, regularization and weight sharing
    eps=1e-5,
    dropout=0.0,
    pre_norm=True,
    tie_weights=True,
)

### Initializing Model

In [6]:
# Build the network from `config`, then place it on the selected device.
model = Transformer(config)
model = model.to(device)

### Loading the Pretrained Weights

In [7]:
weights_path = "../Models/best_model_params_80k.pt"

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source (consider weights_only=True if the
# checkpoint contains nothing but tensors and plain containers).
checkpoint = torch.load(weights_path, map_location=device)

# If the file holds a wrapper dict with a "model_state_dict" entry, unwrap it;
# otherwise treat the loaded object itself as the state dict.
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    state = checkpoint["model_state_dict"]
else:
    state = checkpoint

# strict=False tolerates key mismatches, so report them explicitly instead of
# letting a partially loaded (or entirely unloaded) model pass unnoticed.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    print(f"Warning: {len(load_result.missing_keys)} missing key(s) not found in checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: {len(load_result.unexpected_keys)} unexpected key(s) ignored: {load_result.unexpected_keys}")
print(f"Loaded .pt checkpoint: {weights_path}")

# Switch to inference mode; as the last expression this also displays the model.
model.eval()

Loaded .pt checkpoint: ../Models/best_model_params_80k.pt


Transformer(
  (token_embedding): Embedding(32000, 768)
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): GroupedMultiQueryAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (k_proj): Linear(in_features=768, out_features=256, bias=False)
        (v_proj): Linear(in_features=768, out_features=256, bias=False)
        (w_o): Linear(in_features=768, out_features=768, bias=False)
        (rope): RotaryEmbedding()
      )
      (feed_forward): SwiGLUFeedForward(
        (dropout): Dropout(p=0.0, inplace=False)
        (gate_proj): Linear(in_features=768, out_features=3072, bias=True)
        (up_proj): Linear(in_features=768, out_features=3072, bias=True)
        (down_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): SiLU()
      )
      (attn_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (em

### Generator Function 

In [8]:
@torch.no_grad()
def generate_text(
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
    top_k: int = 0,
    top_p: float = 0.9,
    do_sample: bool = True,
    eos_token_id: int | None = None,
    pad_token_id: int | None = None,
):
    """Generate a text continuation for `prompt` with the module-level `model`.

    The prompt is encoded, passed to `generate` together with the sampling
    arguments (forwarded unchanged), and only the tokens after the prompt are
    decoded and returned. Runs with gradient tracking disabled.

    Args:
        prompt: Text to continue.
        max_new_tokens, temperature, top_k, top_p, do_sample, eos_token_id,
        pad_token_id: Passed straight through to `generate`; their exact
            semantics are defined there.

    Returns:
        The decoded continuation, without the prompt text.
    """
    prompt_ids = encode(prompt).to(device)
    prompt_len = prompt_ids.size(1)

    generated = generate(
        model=model,
        input_ids=prompt_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        eos_token_id=eos_token_id,
        pad_token_id=pad_token_id,
        device=device,
    )

    # Keep only what follows the prompt in the first (only) sequence.
    new_tokens = generated[0][prompt_len:]
    return decode(new_tokens.cpu())


### Inference

In [11]:
# Seed torch's RNG so the sampled continuation is repeatable on
# Restart Kernel -> Run All (assumes `generate` samples through torch's
# global generator — TODO confirm).
torch.manual_seed(42)

prompt = "Once upon a time"
output = generate_text(prompt, max_new_tokens=100, temperature=0.7, top_p=0.8)
print("Prompt:")
print(prompt)
print("\nGeneration:")
print(output)


Prompt:
Once upon a time

Generation:
, in a land far away called Japan, there was a little girl named Maria. She loved to play with her toys and share them with her friends. But sometimes, she would feel sad or worried because she didn't know how to help her friends.

Maria's mom explained to her that she had a special kind of helper called a therapist. She told Maria that she would talk to her parents and help her understand why she felt sad or worried. The therapist would listen carefully
