### Importing Depenedencies

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR

import json
import numpy as np
import os
from tqdm import tqdm
from tokenizers import Tokenizer
from dataclasses import dataclass
from typing import List, Optional
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import load_file

from ModelArchitecture import Transformer, ModelConfig, generate

import warnings
warnings.filterwarnings('ignore')

### Device Configurations

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Device: cpu


### Importing Tokenizer

In [10]:
tokenizer = Tokenizer.from_file("LumenTokenizer.json")
vocab_size = tokenizer.get_vocab_size()
print(f"Tokenizer loaded - Vocab size: {vocab_size:,}")

Tokenizer loaded - Vocab size: 32,000


### Initializing Model

In [2]:
config = ModelConfig(
    vocab_size=32000,          
    hidden_size=768,           
    n_heads=12,               
    n_kv_heads=4,              
    n_kv_groups=3,             
    head_dim=64,              
    n_layers=12,          
    attention_bias=False,      
    intermediate_size=3072,    
    mlp_bias=False,            
    eps=1e-5,                  
    dropout=0.1,               
    max_position_embeddings=2048,
    pre_norm=True,             
    tie_weights=True,
    max_seq_len=2048
)

model = Transformer(config)

### Load Model

#### SafeTensor

In [8]:
weights_path = "/Users/hariom/LLM Development/PostTraining/Models/best_sft_model_params.safetensors"
state = load_file(weights_path, device=str(device))

model.load_state_dict(state, strict=False)
print(f"Loaded .safetensors checkpoint: {weights_path}")

model.eval()

Loaded .safetensors checkpoint: /Users/hariom/LLM Development/PostTraining/Models/best_sft_model_params.safetensors


Transformer(
  (token_embedding): Embedding(32000, 768)
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): GroupedMultiQueryAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (k_proj): Linear(in_features=768, out_features=256, bias=False)
        (v_proj): Linear(in_features=768, out_features=256, bias=False)
        (w_o): Linear(in_features=768, out_features=768, bias=False)
        (rope): RotaryEmbedding()
      )
      (feed_forward): SwiGLUFeedForward(
        (dropout): Dropout(p=0.1, inplace=False)
        (gate_proj): Linear(in_features=768, out_features=3072, bias=True)
        (up_proj): Linear(in_features=768, out_features=3072, bias=True)
        (down_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): SiLU()
      )
      (attn_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (em

#### .pt

In [7]:
weights_path = "best_sft_model_latest.pt"
state = torch.load(weights_path, map_location=device)

model.load_state_dict(state, strict=False)
print(f"Loaded .pt checkpoint: {weights_path}")

model.eval()

Loaded .pt checkpoint: best_sft_model_latest.pt


Transformer(
  (token_embedding): Embedding(32000, 768)
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): GroupedMultiQueryAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (k_proj): Linear(in_features=768, out_features=256, bias=False)
        (v_proj): Linear(in_features=768, out_features=256, bias=False)
        (w_o): Linear(in_features=768, out_features=768, bias=False)
        (rope): RotaryEmbedding()
      )
      (feed_forward): SwiGLUFeedForward(
        (dropout): Dropout(p=0.1, inplace=False)
        (gate_proj): Linear(in_features=768, out_features=3072, bias=True)
        (up_proj): Linear(in_features=768, out_features=3072, bias=True)
        (down_proj): Linear(in_features=3072, out_features=768, bias=True)
        (act): SiLU()
      )
      (attn_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (em

In [11]:
eos_token = "<|im_end|>"
eos_token_id = tokenizer.encode(eos_token).ids[0]
print(f"EOS token ID: {eos_token_id}")

def generate_response(prompt, max_tokens=200):
    # Format prompt in chat format
    formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    
    # Tokenize
    tokens = tokenizer.encode(formatted_prompt).ids
    input_ids = torch.tensor([tokens], dtype=torch.long, device=device)
    
    # Generate with EOS token stopping
    with torch.no_grad():
        output = generate(
            model, 
            input_ids, 
            max_new_tokens=max_tokens,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            do_sample=True,
            eos_token_id=eos_token_id  # Stop at <|im_end|>
        )
    
    # Decode full output
    full_text = tokenizer.decode(output[0].tolist())
    
    # Extract only the assistant's response
    if "<|im_start|>assistant" in full_text:
        # Get text after the assistant marker
        response_part = full_text.split("<|im_start|>assistant")[-1]
        # Remove the closing tag if present
        if "<|im_end|>" in response_part:
            response_part = response_part.split("<|im_end|>")[0]
        return response_part.strip()
    
    return full_text


EOS token ID: 6


### Inference

In [14]:
test_prompt = "Who created you?"
print(f"Prompt: {test_prompt}\n")
print(f"Response:\n{generate_response(test_prompt)}")


Prompt: Who created you?

Response:
user
Who created you?
assistant
I’m an AI assistant named Lumen, created by Hariom Jangra from India.
