### Importing Dependencies

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.cuda.amp import autocast, GradScaler
from contextlib import nullcontext
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR

import numpy as np
import os

from tqdm import tqdm
from tokenizers import Tokenizer
from dataclasses import dataclass
from typing import List, Optional
from ModelArchitecture import Transformer, ModelConfig

import warnings
warnings.filterwarnings('ignore')

### Device Configurations

In [None]:
# Report the installed PyTorch build, then choose CUDA when it is usable and
# fall back to the CPU otherwise.
print(torch.__version__)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    # Device index 0 is queried; total_memory is divided by 1e9 for display.
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

### Loading Dataset

In [None]:
# Load the two .npy arrays that get_batch() samples from (paths are relative
# to the notebook's working directory).
train_split, val_split = (
    np.load(path) for path in ("train_split.npy", "val_split.npy")
)

print("Training Data:", train_split.shape)
print("Validation Data:", val_split.shape)

### Model Configurations

In [None]:
# Hyperparameters handed to the project-local ModelConfig. Field meanings are
# defined in ModelArchitecture; the arithmetic notes below only restate the
# relationships between the literal values used here.
config = ModelConfig(
    # Vocabulary / width
    vocab_size=32000,
    hidden_size=768,
    # Attention layout
    n_heads=12,
    n_kv_heads=4,
    n_kv_groups=3,  # == n_heads // n_kv_heads (12 // 4)
    head_dim=64,  # == hidden_size // n_heads (768 // 12)
    n_layers=12,
    attention_bias=False,
    # Feed-forward
    intermediate_size=3072,  # == 4 * hidden_size
    mlp_bias=False,
    # Normalisation / regularisation
    eps=1e-5,
    dropout=0.1,
    # Sequence limits and remaining switches
    max_position_embeddings=2048,
    pre_norm=True,
    tie_weights=True,
    max_seq_len=2048,  # also the window length sampled by get_batch()
)

print("Model Configuration:")
print(
    "  Architecture: "
    f"{config.hidden_size}d, {config.n_layers}L, {config.n_heads}H, {config.intermediate_size}ff"
)
print(f"  Vocabulary: {config.vocab_size:,} tokens")

### Model Initialization

In [None]:
# Build the network from `config` and move it to the selected device.
model = Transformer(config).to(device)

# Sum the element counts of everything model.parameters() yields.
parameter_count = sum(param.numel() for param in model.parameters())
print(f"Model Device: {device}")
print(f"Total Parameters: {parameter_count:,}")

### Training Configurations

In [None]:
@dataclass
class TrainingConfig:
    """Hyperparameters and bookkeeping settings for the training run.

    Field order is part of the generated positional constructor, so new
    fields should only ever be appended.
    """

    # --- Learning-rate schedule and run length ---
    learning_rate: float = 0.0003  # peak LR handed to AdamW
    warmup_steps: int = 200  # length of the LinearLR warmup phase
    min_lr: float = 0.0001  # eta_min of the cosine decay
    max_steps: int = 50_000  # upper bound of the training loop
    eval_every: int = 5_000  # loop steps between loss estimates
    eval_steps: int = 100  # batches averaged per loss estimate

    # --- Batch and sequence settings ---
    batch_size: int = 12  # sequences per get_batch() call
    gradient_accumulation_steps: int = 4  # micro-batches per optimizer update

    # --- Optimizer regularisation / numerics ---
    weight_decay: float = 0.1
    grad_clip_norm: float = 0.5  # max_norm for clip_grad_norm_
    betas: tuple = (0.9, 0.95)
    eps: float = 1e-09

    # --- Checkpointing ---
    checkpoint_every: int = 5_000  # loop steps between full checkpoints
    checkpoint_dir: str = "checkpoints"
    max_checkpoints: int = 5  # not read by any code visible in this notebook
    best_model_path: str = "best_model_params.pt"

    # --- Mixed precision ---
    use_amp: bool = True  # not read by any code visible in this notebook
    dtype: str = "bfloat16"  # overwritten by the precision auto-detection below

    seed: int = 42

train_config = TrainingConfig()

print("Training Configuration:")
print(f"  Optimization: lr={train_config.learning_rate}, warmup={train_config.warmup_steps}")
print(f"  Batch: size={train_config.batch_size}, accumulation={train_config.gradient_accumulation_steps}")
print(f"  Training: {train_config.max_steps} steps, eval every {train_config.eval_every}")
print(f"  Checkpointing: Every {train_config.checkpoint_every} steps to '{train_config.checkpoint_dir}'")

# Device setup: reuse the `device` picked in the device-configuration cell and
# reduce it to the plain "cuda"/"cpu" string that torch.autocast expects.
device_type = "cuda" if "cuda" in str(device) else "cpu"
if torch.cuda.is_available():
    torch.cuda.set_device(0)

# Precision auto-detection: float32 without a GPU, bfloat16 when the GPU
# reports support for it, float16 on any other GPU. This overwrites the
# dataclass default.
if not torch.cuda.is_available():
    train_config.dtype = "float32"
elif torch.cuda.is_bf16_supported():
    train_config.dtype = "bfloat16"
else:
    train_config.dtype = "float16"

# The three names assigned above are also the torch dtype attribute names.
ptdtype = getattr(torch, train_config.dtype)

# Autocast is only used off-CPU; on CPU the context manager is a no-op.
if device_type == "cpu":
    ctx = nullcontext()
else:
    ctx = torch.autocast(device_type=device_type, dtype=ptdtype)

print(f"  Precision: {train_config.dtype}")

# Reproducibility. Note that the model was already constructed in an earlier
# cell, so these seeds do not influence its weight initialisation.
torch.manual_seed(train_config.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(train_config.seed)
    torch.cuda.manual_seed_all(train_config.seed)

### Optimizer and Scheduler

In [None]:
# AdamW over every model parameter, configured entirely from train_config.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config.learning_rate,
    betas=train_config.betas,
    weight_decay=train_config.weight_decay,
    eps=train_config.eps,
)

# The training loop calls scheduler.step() once per *optimizer update*, i.e.
# once every `gradient_accumulation_steps` loop iterations, whereas max_steps
# bounds the loop in iterations. Sizing the schedulers directly in iterations
# would stretch warmup and cosine decay by the accumulation factor, so the
# decay would never reach min_lr. Convert both horizons to optimizer updates,
# rounding up because the loop also performs a final partial update.
# NOTE(review): warmup_steps is treated as loop iterations, like max_steps —
# confirm that this is the intended unit.
# NOTE(review): the resume cell later overwrites train_config.max_steps; the
# horizon computed here is fixed at construction time and does not follow it.
lr_accum_steps = max(1, train_config.gradient_accumulation_steps)
warmup_updates = -(-train_config.warmup_steps // lr_accum_steps)
total_updates = -(-train_config.max_steps // lr_accum_steps)
decay_updates = max(1, total_updates - warmup_updates)  # T_max must stay positive

# Linear warmup (LinearLR's default start factor) followed by cosine decay
# down to min_lr; SequentialLR switches between them at the warmup boundary.
scheduler_warmup = LinearLR(optimizer, total_iters=warmup_updates)
scheduler_decay = CosineAnnealingLR(
    optimizer,
    T_max=decay_updates,
    eta_min=train_config.min_lr,
)
scheduler = SequentialLR(
    optimizer,
    [scheduler_warmup, scheduler_decay],
    milestones=[warmup_updates],
)


### Mixed Precision

In [None]:
# Gradient scaler for mixed precision; it is enabled only when the detected
# training dtype is float16 and constructed disabled for bfloat16/float32.
scaler = GradScaler(enabled=train_config.dtype == "float16")

### Batching 

In [None]:
def get_batch(split):
    """Sample a random batch of (input, target) windows from one split.

    Args:
        split: 'train' selects `train_split`; any other value selects
            `val_split`.

    Returns:
        Tuple `(x, y)` of int64 tensors on `device`, each with
        `train_config.batch_size` rows of `config.max_seq_len` entries.
        `y` is the same window as `x` shifted one position to the right.

    Assumes the split arrays are 1-D — TODO confirm against how the .npy
    files were produced.
    """
    source = train_split if split == 'train' else val_split
    window = config.max_seq_len

    # Start offsets are drawn from torch's global RNG and converted to plain
    # Python ints so they can be used for numpy slicing.
    starts = torch.randint(len(source) - window, (train_config.batch_size,)).tolist()

    inputs = []
    for start in starts:
        inputs.append(torch.from_numpy(source[start:start + window]).long())
    targets = []
    for start in starts:
        targets.append(torch.from_numpy(source[start + 1:start + 1 + window]).long())
    x = torch.stack(inputs)
    y = torch.stack(targets)

    if device.type != 'cuda':
        return x.to(device), y.to(device)

    # On CUDA, stage through pinned host memory and request a non-blocking copy.
    return (
        x.pin_memory().to(device, non_blocking=True),
        y.pin_memory().to(device, non_blocking=True),
    )

### Loss Estimation

In [None]:
def estimate_loss(model, eval_iters=100):
    """Estimate the mean cross-entropy loss on the train and val splits.

    Args:
        model: Module called as `model(X)` to produce logits whose last
            dimension is the class dimension.
        eval_iters: Number of random batches averaged per split.

    Returns:
        Dict with keys 'train' and 'val', each a 0-dim tensor holding the
        mean loss over `eval_iters` batches drawn with `get_batch`.

    The model is put in eval mode for the measurement and is always left in
    train mode on return.
    """
    results = {}
    model.eval()
    with torch.inference_mode():
        for split_name in ('train', 'val'):
            batch_losses = torch.zeros(eval_iters)
            for i in range(eval_iters):
                inputs, targets = get_batch(split_name)
                logits = model(inputs)
                # Flatten every leading dimension so cross_entropy sees
                # (N, vocab) logits against (N,) integer targets.
                flat_logits = logits.view(-1, logits.size(-1))
                flat_targets = targets.view(-1).long()
                batch_losses[i] = F.cross_entropy(flat_logits, flat_targets).item()
            results[split_name] = batch_losses.mean()
    model.train()
    return results

### Resume Training

In [None]:
# STEP 1: Choose the checkpoint step to resume from (0 means start fresh).
resume_from_step = 0  # ← set to a saved checkpoint step (1000, 2000, 3000, etc.)

if resume_from_step <= 0:
    # Fresh run: no history yet and counting starts at zero.
    start_step = 0
    best_val_loss = float('inf')
    train_loss_list, validation_loss_list, evaluation_steps = [], [], []
else:
    # STEP 2: Locate and load the checkpoint file.
    checkpoint_path = f"{train_config.checkpoint_dir}/checkpoint_step_{resume_from_step}.pt"
    print(f"Loading checkpoint from: {checkpoint_path}")
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # STEP 3: Restore model, optimizer and scheduler state.
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # STEP 4: Restore bookkeeping; missing keys fall back to empty defaults.
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    train_loss_list = checkpoint.get('train_loss_list', [])
    validation_loss_list = checkpoint.get('validation_loss_list', [])
    evaluation_steps = checkpoint.get('evaluation_steps', [])

    # STEP 5: Continue counting from the step stored in the checkpoint.
    start_step = checkpoint['step']
    print(f"Resumed from step {start_step}")

# Target step for this session. This overrides the TrainingConfig value, and
# the scheduler built earlier keeps the horizon it was constructed with.
train_config.max_steps = 10000
print(f"Best validation loss so far: {'N/A' if best_val_loss == float('inf') else best_val_loss}")
print(f"Training will continue from step {start_step + 1} to step {train_config.max_steps}")


In [None]:
# Main training cell: runs loop steps start_step+1 .. train_config.max_steps,
# accumulating gradients over `gradient_accumulation_steps` micro-batches per
# optimizer update, periodically estimating train/val loss, saving the best
# weights and full checkpoints, and finally writing the loss history to disk.
model = model.to(device)
model.train()  # an interrupted estimate_loss() call could have left eval mode on

os.makedirs(train_config.checkpoint_dir, exist_ok=True)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(train_config.seed)
    torch.cuda.set_device(0)
    torch.multiprocessing.set_sharing_strategy('file_system')

# Copy the history containers so the ones restored from a checkpoint are not
# mutated in place.
train_loss_list = list(train_loss_list)
validation_loss_list = list(validation_loss_list)
evaluation_steps = list(evaluation_steps)
if not evaluation_steps and train_loss_list:
    # Checkpoint carried losses but no step index: rebuild it on the assumption
    # that one entry was recorded every `eval_every` steps.
    evaluation_steps = [train_config.eval_every * (i + 1) for i in range(len(train_loss_list))]
history_path = os.path.join(train_config.checkpoint_dir, "loss_history.pt")

print(f"Starting training for {train_config.max_steps - start_step} more iterations...")
print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Batch size: {train_config.batch_size}, Sequence length: {config.max_seq_len}")
print("="*50)

# Drop gradients left over from a previously interrupted run of this cell so
# they are not silently folded into the first update.
optimizer.zero_grad(set_to_none=True)

for step in tqdm(range(start_step + 1, train_config.max_steps + 1)):
    X, y = get_batch("train")

    # Only the forward pass and the loss run under autocast.
    with ctx:
        logits = model(X)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1).long())
        # Scale so that the accumulated gradient is the mean over micro-batches.
        loss = loss / train_config.gradient_accumulation_steps
    # Backward is issued outside the autocast region, as the AMP docs recommend.
    scaler.scale(loss).backward()

    # Apply an update every `gradient_accumulation_steps` micro-batches, plus
    # once at the very last step so a partial window is not lost.
    if ((step % train_config.gradient_accumulation_steps) == 0) or (step == train_config.max_steps):
        # Gradients must be unscaled before clipping; otherwise, with the fp16
        # scaler enabled, the norm threshold is compared against scaled values.
        # This is a no-op when the scaler is disabled.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=train_config.grad_clip_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

    # Evaluation and checkpointing run *after* this step's work, so a
    # checkpoint labelled `step` contains the result of `step` and resuming at
    # step + 1 neither repeats nor skips a micro-batch. If checkpoint_every is
    # not a multiple of gradient_accumulation_steps, a checkpoint can still
    # fall mid-accumulation and the pending gradients are not stored.
    if step % train_config.eval_every == 0:
        losses = estimate_loss(model, train_config.eval_steps)
        current_train_loss = float(losses['train'])
        current_val_loss = float(losses['val'])
        train_loss_list.append(current_train_loss)
        validation_loss_list.append(current_val_loss)
        evaluation_steps.append(step)

        print(f"Step {step}: train loss {current_train_loss:.4f}, val loss {current_val_loss:.4f}")
        print(f"The current learning rate: {optimizer.param_groups[0]['lr']:.5f}")

        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            torch.save(model.state_dict(), train_config.best_model_path)
            print(f"New best model saved with val loss: {best_val_loss:.4f}")

    if step % train_config.checkpoint_every == 0:
        checkpoint_path = f"{train_config.checkpoint_dir}/checkpoint_step_{step}.pt"
        torch.save({
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,
            'train_loss_list': train_loss_list,
            'validation_loss_list': validation_loss_list,
            'evaluation_steps': evaluation_steps
        }, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Final step: {train_config.max_steps}")

# Persist the loss curves so the plotting cell works after a kernel restart.
history = {
    'steps': evaluation_steps,
    'train': train_loss_list,
    'val': validation_loss_list,
    'best_val_loss': best_val_loss,
    'max_steps': train_config.max_steps
}
torch.save(history, history_path)
print(f"Loss history saved to {history_path}")


### Loss Curve

In [None]:
# Plot train vs. validation loss. The history file written at the end of
# training takes precedence; when it is absent, whatever loss lists are still
# alive in the kernel are used instead.
history_path = os.path.join(train_config.checkpoint_dir, "loss_history.pt")

if not os.path.exists(history_path):
    # In-memory fallback; each name may be missing if training never ran.
    steps = list(evaluation_steps) if 'evaluation_steps' in globals() else []
    train_curve = [float(x) for x in train_loss_list] if 'train_loss_list' in globals() else []
    val_curve = [float(x) for x in validation_loss_list] if 'validation_loss_list' in globals() else []
    max_steps = int(train_config.max_steps)
else:
    history = torch.load(history_path, map_location='cpu')
    steps = list(history.get('steps', []))
    train_curve = [float(x) for x in history.get('train', [])]
    val_curve = [float(x) for x in history.get('val', [])]
    max_steps = int(history.get('max_steps', train_config.max_steps))

if not steps or not (len(steps) == len(train_curve) == len(val_curve)):
    print("No loss history available yet. Run training to populate the history.")
else:
    # Sort the three series together by step so the lines are drawn left to right.
    ordered = sorted(zip(steps, train_curve, val_curve))
    steps = [row[0] for row in ordered]
    train_curve = [row[1] for row in ordered]
    val_curve = [row[2] for row in ordered]

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(steps, train_curve, label='Train loss', color='#4C78A8')
    ax.plot(steps, val_curve, label='Validation loss', color='#F58518')
    ax.set_xlabel('Training step')
    ax.set_ylabel('Loss')
    ax.set_title('Training vs Validation Loss')
    ax.set_xlim(0, max_steps)
    ax.grid(alpha=0.2)
    ax.legend()
    fig.tight_layout()


### Conversion to Safetensors (Optional)

In [None]:
from safetensors.torch import save_file

# Source .pt state dict and destination .safetensors file. The source path
# differs from train_config.best_model_path, so the weights are expected to
# have been placed under ../Models beforehand.
pt_file = "../Models/best_model_params.pt"
safetensors_file = "best_model_params.safetensors"

state_dict = torch.load(pt_file, map_location='cpu')

# Give every entry its own storage: cloning avoids shared-memory issues when
# several keys refer to the same underlying tensor.
tensors_to_save = {}
for key, tensor in state_dict.items():
    tensors_to_save[key] = tensor.clone()

save_file(tensors_to_save, safetensors_file)
print(f"Converted {pt_file} → {safetensors_file}")


Converted ../Models/best_model_params.pt → best_model_params.safetensors
