### Importing Dependencies

In [None]:
%pip install tokenizers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR

import json
import numpy as np
import os
from tqdm import tqdm
from tokenizers import Tokenizer
from dataclasses import dataclass
from typing import List, Optional
from torch.utils.data import Dataset, DataLoader

from ModelArchitecture import Transformer, ModelConfig

import warnings
warnings.filterwarnings('ignore')

### Device Configuration

In [None]:
# Use the first visible CUDA GPU when available, otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"Device: {device}")

if has_cuda:
    # Report which GPU was picked and how much memory it has (in GB).
    print(f"GPU: {torch.cuda.get_device_name()}")
    gpu_total_memory = torch.cuda.get_device_properties(0).total_memory
    print(f"Memory: {gpu_total_memory / 1e9:.1f} GB")

### Load Tokenizer

In [None]:
# Path to the serialized tokenizer (Hugging Face `tokenizers` JSON format).
TOKENIZER_PATH = "LumenTokenizer.json"

tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vocab_size = tokenizer.get_vocab_size()
print(f"Tokenizer loaded - Vocab size: {vocab_size:,}")

### SFT Dataset Class

In [None]:
class SFTDataset(Dataset):
    """Supervised fine-tuning dataset backed by a JSON Lines file.

    Each non-blank line of ``jsonl_path`` is parsed as JSON. Records that are
    JSON objects with a string ``"text"`` field are kept. Everything else is
    skipped and counted: malformed JSON, non-object values, and objects whose
    ``"text"`` is missing or not a string.

    Texts are tokenized lazily in ``__getitem__``. Each item is returned as a
    next-token-prediction pair ``(input_ids, target_ids)``, where
    ``target_ids`` is the token sequence shifted left by one position.

    Args:
        jsonl_path: Path to the ``.jsonl`` file.
        tokenizer: Object whose ``encode(text)`` returns something with an
            ``ids`` list of ints (e.g. a ``tokenizers.Tokenizer``).
        max_length: Maximum number of tokens kept per sample *before* the
            shift. Each returned tensor therefore has at most
            ``max_length - 1`` elements.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        skipped = 0
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                # Check the type before indexing. A bare JSON string such as
                # "some text" would pass a `'text' in item` substring test and
                # then crash on item['text'].
                if isinstance(item, dict) and isinstance(item.get('text'), str):
                    self.data.append(item['text'])
                else:
                    skipped += 1

        print(f"Loaded {len(self.data)} samples from {jsonl_path}")
        if skipped:
            print(f"Skipped {skipped} malformed or text-less lines from {jsonl_path}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return ``(input_ids, target_ids)`` as long tensors."""
        tokens = self.tokenizer.encode(self.data[idx]).ids

        # Truncate only; padding to a common length happens per batch in collate_fn.
        tokens = tokens[:self.max_length]

        # Shift by one for next-token prediction. NOTE: a sample that encodes
        # to fewer than 2 tokens yields two empty tensors.
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, target_ids

def collate_fn(batch):
    """Right-pad a batch of ``(input_ids, target_ids)`` pairs into two 2-D tensors.

    Both outputs have shape ``(len(batch), L)``, where ``L`` is the longest
    ``input_ids`` in the batch. Unused positions are filled with 0, which the
    loss computations in this notebook treat as padding (``ignore_index=0``).
    """
    sources = [pair[0] for pair in batch]
    labels = [pair[1] for pair in batch]

    width = max(seq.size(0) for seq in sources)
    rows = len(batch)

    padded_sources = torch.zeros((rows, width), dtype=torch.long)
    padded_labels = torch.zeros((rows, width), dtype=torch.long)

    for row, (src, lab) in enumerate(zip(sources, labels)):
        padded_sources[row, :src.size(0)] = src
        padded_labels[row, :lab.size(0)] = lab

    return padded_sources, padded_labels

### Load Dataset

In [None]:
DATASET_PATH = "Instruct_Dataset.jsonl"

sft_dataset = SFTDataset(DATASET_PATH, tokenizer, max_length=2048)
if len(sft_dataset) == 0:
    raise ValueError(f"No usable samples found in {DATASET_PATH}")

# The global torch seed is only set later, in the training-configuration cell.
# random_split therefore gets its own seeded generator so the train/val split
# is identical across kernel restarts. 42 matches SFTConfig's default seed.
SPLIT_SEED = 42
train_size = int(0.9 * len(sft_dataset))
val_size = len(sft_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    sft_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"Train samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")

### Model Configuration & Initialization

In [None]:
# Shape-determining values below must match the pretrained checkpoint,
# otherwise load_state_dict raises a size-mismatch error.
MODEL_VOCAB_SIZE = 32000

config = ModelConfig(
    vocab_size=MODEL_VOCAB_SIZE,
    hidden_size=768,
    n_heads=12,
    n_kv_heads=4,
    n_kv_groups=3,             # equals n_heads // n_kv_heads
    head_dim=64,               # equals hidden_size // n_heads
    n_layers=12,
    attention_bias=False,
    intermediate_size=3072,
    mlp_bias=False,
    eps=1e-5,
    dropout=0.1,
    max_position_embeddings=2048,
    pre_norm=True,
    tie_weights=True,
    max_seq_len=2048
)

# Every token id produced by the tokenizer must be < the model's vocab_size.
if vocab_size > MODEL_VOCAB_SIZE:
    raise ValueError(
        f"Tokenizer vocab size ({vocab_size:,}) exceeds model vocab_size "
        f"({MODEL_VOCAB_SIZE:,}); token ids would be out of range for the model."
    )

# Initialize model
model = Transformer(config)

# Load pretrained weights
pretrained_path = "best_model_params.pt"
if os.path.exists(pretrained_path):
    # The model has not been moved to `device` yet (that happens below), so
    # load onto CPU. Mapping straight to the GPU would only add a temporary
    # extra copy of the weights.
    state_dict = torch.load(pretrained_path, map_location="cpu")
    model.load_state_dict(state_dict)
    del state_dict
    print(f"✓ Loaded pretrained weights from {pretrained_path}")
else:
    print(f"Warning: Pretrained weights not found at {pretrained_path}")
    print("Starting from random initialization")

model = model.to(device)
param_count = sum(p.numel() for p in model.parameters())
print(f"Model Parameters: {param_count:,}")

### Training Configuration

In [None]:
@dataclass
class SFTConfig:
    """Hyperparameters for the supervised fine-tuning run."""

    # Training
    num_epochs: int = 5
    batch_size: int = 12                   # samples per forward/backward pass
    gradient_accumulation_steps: int = 4   # micro-batches per optimizer step
    eval_every_steps: int = 500            # counted in optimizer steps

    # Optimization
    learning_rate: float = 3e-5
    warmup_steps: int = 400
    min_lr: float = 1e-6
    weight_decay: float = 0.01
    grad_clip_norm: float = 1.0
    betas: tuple = (0.9, 0.95)

    # Checkpointing
    checkpoint_dir: str = "sft_checkpoints"
    best_model_path: str = "best_sft_model.pt"

    # Mixed precision
    use_amp: bool = True
    # bfloat16 where the GPU supports it, float16 on older GPUs, float32 on
    # CPU-only machines. CUDA availability is checked first because
    # torch.cuda.is_bf16_supported() may raise on CPU-only builds of older
    # PyTorch releases.
    dtype: str = (
        ("bfloat16" if torch.cuda.is_bf16_supported() else "float16")
        if torch.cuda.is_available()
        else "float32"
    )

    seed: int = 42

sft_config = SFTConfig()

# Setup precision
device_type = "cuda" if torch.cuda.is_available() else "cpu"
# Looked up unconditionally so an invalid dtype string fails fast here.
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[sft_config.dtype]
# Autocast runs only on CUDA and only when the config enables AMP;
# otherwise forward passes run without autocast.
use_autocast = sft_config.use_amp and device_type == "cuda"
ctx = torch.autocast(device_type=device_type, dtype=ptdtype) if use_autocast else nullcontext()

print("SFT Configuration:")
print(f"  Epochs: {sft_config.num_epochs}")
print(f"  Batch size: {sft_config.batch_size} (accumulation: {sft_config.gradient_accumulation_steps})")
print(f"  Learning rate: {sft_config.learning_rate}")
print(f"  Precision: {sft_config.dtype}" + ("" if use_autocast else " (autocast disabled)"))

# Create checkpoint directory
os.makedirs(sft_config.checkpoint_dir, exist_ok=True)

# Seed the global RNGs (used e.g. by dropout and DataLoader shuffling).
# NOTE: this runs after the train/val split and model construction cells.
torch.manual_seed(sft_config.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(sft_config.seed)

### DataLoaders

In [None]:
# Settings shared by both loaders; only shuffling differs between them.
# num_workers=0 builds batches in the main process; pinned host memory is only
# requested when a CUDA device is present.
_loader_kwargs = dict(
    batch_size=sft_config.batch_size,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

### Optimizer & Scheduler

In [None]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=sft_config.learning_rate,
    betas=sft_config.betas,
    weight_decay=sft_config.weight_decay
)

# The training loop restarts its accumulation counter every epoch, so count
# optimizer steps per epoch. Ceil-divide so a trailing partial accumulation
# group is included.
accum = sft_config.gradient_accumulation_steps
steps_per_epoch = (len(train_loader) + accum - 1) // accum
total_steps = steps_per_epoch * sft_config.num_epochs

# CosineAnnealingLR takes a modulo by T_max when stepping, so keep it >= 1
# even when the whole run is shorter than the warmup.
decay_steps = max(1, total_steps - sft_config.warmup_steps)

# Warmup: LinearLR's default start_factor (1/3) ramps the LR up to
# `learning_rate` over `warmup_steps`, then cosine decay runs down to `min_lr`.
scheduler_warmup = LinearLR(optimizer, total_iters=sft_config.warmup_steps)
scheduler_decay = CosineAnnealingLR(
    optimizer,
    T_max=decay_steps,
    eta_min=sft_config.min_lr
)
scheduler = SequentialLR(
    optimizer,
    [scheduler_warmup, scheduler_decay],
    milestones=[sft_config.warmup_steps]
)

# Loss scaling is only needed for float16 on CUDA (bfloat16 shares float32's
# exponent range). torch.cuda.amp.GradScaler's first positional parameter is
# `init_scale`, so no device string is passed here.
scaler = GradScaler(enabled=(sft_config.dtype == "float16" and torch.cuda.is_available()))

print(f"Total training steps: {total_steps}")

### Evaluation Function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, max_batches=None):
    """Return the mean per-token cross-entropy of ``model`` over ``dataloader``.

    Positions whose target id is 0 (the padding value used by ``collate_fn``)
    are excluded from both the summed loss and the token count. The result is
    therefore a true average over real target tokens, however much padding
    each batch contains.

    Args:
        model: Module mapping ``input_ids`` to logits of shape ``(..., vocab)``.
        dataloader: Iterable of ``(input_ids, target_ids)`` batches.
        max_batches: If truthy, stop after this many batches.

    Returns:
        Mean loss as a float, or 0 when no non-padding target tokens were seen.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    for i, (input_ids, target_ids) in enumerate(dataloader):
        if max_batches and i >= max_batches:
            break

        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        with ctx:
            logits = model(input_ids)
            # Sum rather than mean so each batch is weighted by its number of
            # real (non-padding) target tokens.
            loss_sum = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                target_ids.reshape(-1),
                ignore_index=0,  # Ignore padding
                reduction='sum',
            )

        total_loss += loss_sum.item()
        total_tokens += (target_ids != 0).sum().item()

    # Restore whatever mode the caller had the model in.
    model.train(was_training)
    return total_loss / total_tokens if total_tokens > 0 else 0

### Training Loop

In [None]:
# NOTE: re-running this cell continues from the current in-memory model,
# optimizer and scheduler state rather than from the pretrained checkpoint.
model.train()
best_val_loss = float('inf')
train_losses = []
val_losses = []
step = 0  # number of optimizer steps taken so far

accum_steps = sft_config.gradient_accumulation_steps
num_batches = len(train_loader)


def _save_if_best(current_val_loss, best_so_far):
    """Save the model weights if ``current_val_loss`` beats ``best_so_far``.

    Returns the (possibly updated) best validation loss.
    """
    if current_val_loss < best_so_far:
        torch.save(model.state_dict(), sft_config.best_model_path)
        print(f"✓ Best model saved (val_loss: {current_val_loss:.4f})")
        return current_val_loss
    return best_so_far


print("Starting Supervised Fine-Tuning...")
print("="*60)

# Discard gradients left over from an interrupted earlier run of this cell.
optimizer.zero_grad(set_to_none=True)

for epoch in range(sft_config.num_epochs):
    print(f"\nEpoch {epoch + 1}/{sft_config.num_epochs}")
    epoch_loss = 0.0  # sum of *unscaled* per-batch losses this epoch

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

    for batch_idx, (input_ids, target_ids) in enumerate(progress_bar):
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        # Size of the accumulation group this micro-batch belongs to. The last
        # group of an epoch is shorter when num_batches % accum_steps != 0.
        # Dividing by the real group size keeps the accumulated gradient a
        # mean over that group.
        group_start = (batch_idx // accum_steps) * accum_steps
        group_size = min(accum_steps, num_batches - group_start)
        is_group_end = (batch_idx + 1) % accum_steps == 0 or batch_idx + 1 == num_batches

        # Forward pass
        with ctx:
            logits = model(input_ids)
            batch_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                target_ids.view(-1),
                ignore_index=0  # Ignore padding
            )
            loss = batch_loss / group_size

        # Backward pass
        scaler.scale(loss).backward()

        # Track the unscaled loss so train loss is comparable to val loss.
        batch_loss_value = batch_loss.item()
        epoch_loss += batch_loss_value

        # Optimizer step at the end of every accumulation group, including the
        # trailing partial group, so no gradients carry into the next epoch.
        stepped = False
        if is_group_end:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), sft_config.grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            step += 1
            stepped = True

        progress_bar.set_postfix({'loss': f'{batch_loss_value:.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'})

        # Evaluate only right after a qualifying optimizer step. Checking on
        # every micro-batch would re-run the evaluation accum_steps times
        # while `step` stays on the same multiple.
        if stepped and step % sft_config.eval_every_steps == 0:
            val_loss = evaluate(model, val_loader, max_batches=50)
            train_losses.append(epoch_loss / (batch_idx + 1))
            val_losses.append(val_loss)

            print(f"\nStep {step}: Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_loss:.4f}")

            # Save best model
            best_val_loss = _save_if_best(val_loss, best_val_loss)

    # End of epoch evaluation (full validation set). The best-model check runs
    # here too, so a best model is saved even when a run never reaches
    # eval_every_steps. Mid-epoch checks use only 50 batches, so the two
    # estimates are not over identical data.
    val_loss = evaluate(model, val_loader)
    avg_train_loss = epoch_loss / max(num_batches, 1)
    print(f"\nEpoch {epoch + 1} Complete - Avg Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")
    best_val_loss = _save_if_best(val_loss, best_val_loss)

    # Save checkpoint (includes scaler state and step so a run can be resumed)
    checkpoint_path = os.path.join(sft_config.checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pt")
    torch.save({
        'epoch': epoch + 1,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'val_loss': val_loss,
        'best_val_loss': best_val_loss,
        'train_losses': train_losses,
        'val_losses': val_losses
    }, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")

print("\n" + "="*60)
print("Training Complete!")
print(f"Best validation loss: {best_val_loss:.4f}")

### Training Loss Visualization

In [None]:
# Plot the train/validation loss recorded at each mid-epoch evaluation point.
if not (train_losses and val_losses):
    print("No loss history available yet")
else:
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Train Loss', color='#4C78A8', linewidth=2)
    ax.plot(val_losses, label='Validation Loss', color='#F58518', linewidth=2)
    ax.set_xlabel('Evaluation Step')
    ax.set_ylabel('Loss')
    ax.set_title('SFT Training Progress')
    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()
    plt.show()