In [1]:
! pip install datasets wandb

Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Downloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (47.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: pyarrow
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 19.0.1
    Uninstalling pyarrow-19.0.1:
      Successfully uninstalled pyarrow-19.0.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
pylibcudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
cudf-cu12 25.2.2 requires pyarrow<20.0.0

In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F
import os
import time
import json
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from torch.utils.data import IterableDataset, DataLoader
import copy
from datasets import load_dataset
from torch.utils.data import Dataset
import random

## RoPE

In [3]:
def precompute_freq_cis(dim: int, end: int, theta: float = 10000.0):
    """Precomputes the angles for rotation.
    Drawing a map of all possible positions ahead of time.

    Args:
        dim (int): The dimension of attention head.
        end (int): The maximum sequence length.
        theta (float, optional): The base frequency. Defaults to 10000.0 which is standard for Llama.
    """
    
    # Create a list of frequencies
    # 1 / (theta ^ (2i / dim))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    # Create a list of positions: 0, 1, 2, ..., end-1
    t = torch.arange(end, device=freqs.device)
    
    # Compute the outer product to get all position-frequency combinations
    freqs = torch.outer(t, freqs).float() # Shape: [end, dim//2]
    
    # Turn them into polar coordinates (mag 1, angle freqs)
    # Cis = cos + i*sin
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Helper to match shapes.
    freqs_cis: [max_len, dim // 2]
    x: [Batch, seq_len, head, dim // 2]

    Args:
        freqs_cis (torch.Tensor): Polar coordinates of frequencies.
        x (torch.Tensor): Input tensor.
    """
    
    ndim = x.ndim # Number of dimensions in x
    assert 0 <= 1 < ndim # Ensure x has at least 2 dimensions
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]) # Check shape compatibility
    
    # Reshape freqs_cis to [1, seq_len, 1, dim // 2] so that it broadcasts over Batch and Head
    shape = [d if i==1 or i==ndim-1 else 1 for i, d in enumerate(x.shape)]
    
    return freqs_cis.view(*shape)

def apply_rotary_pos_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """The Actual Rotation. We treat q/k vectors as complex numbers and multiply by the frequencies.

    Args:
        xq (torch.Tensor): x query tensor.
        xk (torch.Tensor): x key tensor.
        freqs_cis (torch.Tensor): Polar coordinates of frequencies.
    """
    
    # 1. Turn Query and Key into complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    # 2. Get the specific frequencies for the current sequence length
    freqs_cis = reshape_for_broadcast(freqs_cis=freqs_cis, x=xq_)
    
    # 3. Rotate (Multiply) and convert back to real numbers
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)

## Attention

In [4]:
class EfficientAttention(nn.Module):
    def __init__(self, d_model: int, n_head: int, n_kv_head: int, window_size: int):
        super().__init__()
        self.n_head = n_head # Number of query heads
        self.n_kv_head = n_kv_head # Number of key/value heads
        self.d_head = d_model // n_head # Dimension per head
        self.window_size = window_size # Size of the local attention window
        
        #* The GQA Ratio (Grouped Query Attention Ratio)
        #* If n_head=8 and n_kv_head=2, then n_rep=4
        #* This means 1 K/V head will serve 4 query heads
        self.n_rep = self.n_head // self.n_kv_head
        
        # Q needs full size: (d_model -> d_model)
        #* Why? Because each query head is unique
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        
        #* Instead of mapping (d_model -> d_model), we map (d_model -> d_model / 4)
        # Why? Because each K/V head is shared among multiple query heads
        # This reduces the number of parameters and computation
        self.k_proj = nn.Linear(d_model, self.n_kv_head * self.d_head, bias=False)
        self.v_proj = nn.Linear(d_model, self.n_kv_head * self.d_head, bias=False)
        
        #* Output projection to combine heads back to d_model
        self.output_proj = nn.Linear(d_model, d_model, bias=False)
        
    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size() # Batch size, sequence length, model dimension (Channels)
        
        # Calculate Q, K, V
        # Q is standard shape: [Batch, Time, 8 heads, 32 dim]
        q = self.q_proj(x).view(B, T, self.n_head, self.d_head).transpose(1, 2) #* Why transpose? -> [B, n_head, T, d_head] PyTorch's matrix multiplication (@) operates on the last two dimensions, so we need Time and Dim at the end.
        
        # K, V are reduced shape: [Batch, Time, 2 heads, 32 dim]        
        # In standard attention, these would also be 8 heads
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.d_head).transpose(1, 2)
        
        #* Apply RoPE to Query and Key
        # Before that we must transpose to [B, T, head, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        q, k = apply_rotary_pos_emb(q, k, freqs_cis=freqs_cis)
        # Tranpose back to [B, head, T, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        
        #* Since K and V have fewer heads, we need to repeat them
        # so that they can align with the Q heads during attention computation
        # We take the 2 KV heads and copy them 4 times each to get 8 "virtual" heads
        # This allows the math to work without storing 8 unique heads in memory
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        
        # Calculate Scores
        #* It calculates how much every token relates to every other token
        att = (q @ k.transpose(-2, -1)) * (1.0/math.sqrt(self.d_head))
        
        #* Sliding window Mask
        
        #* Casual Mask: "Don't look into the future"
        # Creates a lower trianlge of 1s
        casual_mask = torch.tril(torch.ones(T, T, device=x.device))
        
        #* Window Mask: "Don't look too far back"
        # Creates an upper traingle starting from `window_size` back
        # If window is 16, it blocks everything older than 16 steps ago
        window_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=-self.window_size+1)
        
        #* Combine masks
        mask = casual_mask * window_mask # [T, T]
        
        # Apply Mask
        att = att.masked_fill(mask == 0, float('-inf'))
        
        #* Softmax and Weighted Sum
        att = F.softmax(att, dim=-1)
        y = att @ v
        
        #* Reshape and Output Projection
        # y: [B, n_head, T, d_head] -> [B, T, n_head, d_head] -> [B, T, C]
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.output_proj(y)

## RMSNorm Layer and Decoder Block

In [5]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.
    Why: It is faster than LayerNorm because it skips calculatin the mean.
    It only calculates variance.
    """
    
    def __init__(self, dim, eps = 1e-6):
        super().__init__()

        self.eps = eps # Prevents division by zero
        
        # The learnable weight parameter (gamma)
        self.weight = nn.Parameter(torch.ones(dim))
        
    def _norm(self, x):
        
        # x.pow(2) -> squares each element
        # .mean(-1, keepdim=True) -> mean over the last dimension (features)
        # + self.eps -> add epsilon for numerical stability
        # torch.rsqrt(...) -> reciprocal of the square root
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    
    def forward(self, x):
        
        # Convert to float32 for stability during normalization and later convert it back to original dtype
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
    
class SwiGLUMLP(nn.Module):
    """
    Swiss Gated Linear Unit (SwiGLU) Feed-Forward Network.
    Why: It learns efficiently and better than standard ReLU FFNs.
    But it has 3 layers instead of 2.
    """
    
    def __init__(self, d_model, expansion_factor=2.5):
        super().__init__()
        
        # Standard Transformer use expansion_factor of 4
        # SLMs use ~2.6 (Llama) or lower to save memory
        hidden_dim = int(d_model * expansion_factor)
        
        # 1. Gate projection - Determines which information to pass through
        self.gate_proj = nn.Linear(d_model, hidden_dim, bias=False)
        # 2. Up projection - Expands the dimensionality
        self.up_proj = nn.Linear(d_model, hidden_dim, bias=False)
        # 3. Down projection - Reduces back to d_model
        self.down_proj = nn.Linear(hidden_dim, d_model, bias=False)
        
    def forward(self, x):
        # Apply the SwiGLU activation: (SiLU(gate) * up_proj) -> down_proj
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        
        # Element-wise multiplication (gating mechanism)
        fused = gate * up
        
        return self.down_proj(fused)
    
class DecoderBlock(nn.Module):
    """A Single Transformer Block.
    
    Flow: Input -> RMSNorm -> Attention -> Residual -> RMSNorm -> FFN -> Residual
    """
    
    def __init__(self, config):
        super().__init__()
        
        d_model = config['d_model']
        n_head = config['n_head']
        n_kv_head = config['n_kv_head']
        window_size = config['window_size']
        mlp_ratio = config.get('mlp_ratio', 2.5) # Default to 2.5 for efficiency
        
        # 1. Attention Engine
        self.self_attn = EfficientAttention(
            d_model=d_model,
            n_head=n_head,
            n_kv_head=n_kv_head,
            window_size=window_size
        )
        
        # 2, The Thinking Engine (FFN)
        self.mlp = SwiGLUMLP(
            d_model=d_model,
            expansion_factor=mlp_ratio
        )
        
        # 3. Normalization Layers (1 before attention, 1 before FFN)
        self.input_layernorm = RMSNorm(d_model)
        self.post_attn_layernorm = RMSNorm(d_model)
        
    def forward(self, x, freqs_cis):
        
        # 1. Attention Block with Residual Connection
        # Norm before attention because it is more stable
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(x, freqs_cis)
        x = residual + x
        
        # 2. MLP Block with Residual Connection
        residual = x
        x = self.post_attn_layernorm(x)
        x = self.mlp(x)
        x = residual + x
        
        return x

## TinySLM (Bringing together everything till now!)

In [6]:
class TinySLM(nn.Module):
    """A TinySLM model"""
    
    def __init__(self, config):
        super().__init__()
        
        self.config = config
        
        # 1. Embeddings
        self.vocab_size = config['vocab_size']
        self.d_model = config['d_model']
        self.token_embedding = nn.Embedding(self.vocab_size, self.d_model)
        
        # 2. The Transformer Layers
        # Implementating Block-Wise Weight Sharing
        # If n_unique_layers < n_layers, we reuse the modules
        self.layers = nn.ModuleList()
        n_layers = config['n_layers']
        
        # Create the actual blocks
        for _ in range(n_layers):
            self.layers.append(DecoderBlock(config))
            
        # 3. Final Normalization
        self.norm = RMSNorm(self.d_model)
        
        # 4. The Output Head
        self.output = nn.Linear(self.d_model, self.vocab_size, bias=False)
        
        # 5. Weight Tying
        # The matrix that turns Tokens -> Vectors is often the transpose of Vectors -> Tokens
        # Sharing them saves memory and improves performance ~20-30%
        self.token_embedding.weight = self.output.weight
        
        # 6. Precompute RoPE frequencies
        # Compute enough for the max context window (eg. 2048)
        self.freqs_cis = precompute_freq_cis(
            dim=self.d_model // config['n_head'], # Dimension per head
            end=config['max_seq_len'] * 2, # Just to be safe
            theta=10000.0
        )
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # 1. Embed Tokens
        x = self.token_embedding(idx) # Shape: [B, T, d_model]
        
        # 2. Prepare RoPE frequencies for current sequence length
        freqs_cis = self.freqs_cis[:T].to(x.device)
        
        # 3. Run through Layers
        for layer in self.layers:
            x = layer(x, freqs_cis=freqs_cis)
            
        # 4. Final Normalization
        x = self.norm(x)
        
        # 5. Calculate logits (if training)
        logits = self.output(x)
        
        loss = None
        if targets is not None:
            # Flatten for cross-entropy [B*T, vocab_size]
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
            
        return logits, loss

## Tokenizer

In [7]:
def train_tokenizer(input_file="/kaggle/input/textbook-dataset/hybrid_textbook_data.jsonl", voacb_size=32000):
    print("--------- Training Tokenizer ---------")
    
    # 1. Initialize an empty BPE tokenizer
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    
    # 2. Setup Trainer
    # Specialized tokens for controlling the model behavior
    special_tokens = [
        "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"
    ]
    trainer = BpeTrainer(
        vocab_size=voacb_size,
        special_tokens=special_tokens
    )
    
    # 3. Extract text fields to temporary file
    
    temp_text_file = "data/raw/temp_training_text.txt"
    os.makedirs("data/raw", exist_ok=True)
    
    print(f"Extracting text data from {input_file}...")
    line_count =0
    
    with open(input_file, 'r', encoding='utf-8') as infile, \
        open(temp_text_file, 'w', encoding='utf-8') as outfile:
            
            for line in infile:
                try:
                    data = json.loads(line.strip())
                    text = data.get('text', '')
                    
                    if text and len(text.strip()) > 0:
                        outfile.write(text + '\n')
                        line_count+=1
                except json.JSONDecodeError:
                    print(f"Skipping invalid JSON line: {line.strip()}")
                    continue
                
    print(f"Extracted {line_count} lines of text data to {temp_text_file}.")
    
    # 4. Train tokenizer on extracted text
    print("Training BPE tokenizer...")
    tokenizer.train([temp_text_file], trainer)
    
    # 5. Clean up temp file
    os.remove(temp_text_file)
    
    # 6. Save the tokenizer
    output_path = "data/tokenizer/tiny_slm_tokenizer.json"
    os.makedirs("data/tokenizer", exist_ok=True)
    tokenizer.save(output_path)
    print(f"Tokenizer trained and saved at {output_path}")
    
    return tokenizer

In [8]:
train_tokenizer()

--------- Training Tokenizer ---------
Extracting text data from /kaggle/input/textbook-dataset/hybrid_textbook_data.jsonl...
Skipping invalid JSON line: {"topic": "Write a long and very detailed course unit for a textbook on \"Sustainable Transportation: Policy and Planning in Practice\" intended for young children.\nWe are currently writing the first chapter: \"1. Introduction\".\nWe have already covered the following sub-units in the current chapter: \"1.1. Overview of the course\".\nWrite the new sub-unit titled \"1.2. Importance of sustainable transportation\" while trying to be:\n- Rigorous - you create challenging textbooks that cover the material in depth.\n- Engaging - your textbooks have a narrative arc and engaging tone, like the writing of Michael Lewis.\n- Applied - you use specific and practical examples. For example, if the topic is integration in calculus, include equations and proofs of the concept you're teaching. As another example, if the topic is the history of the

Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"[PAD]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":1, "content":"[UNK]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":2, "content":"[CLS]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":3, "content":"[SEP]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":4, "content":"[MASK]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}], normalizer=None, pre_tokenizer=Whitespace(), post_processor=None, decoder=None, model=BPE(dropout=None, unk_token="[UNK]", continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={"[PAD]":0, "[UNK]":1, "[CLS]":2, "[SEP]":3, "[MASK]":4, "!":5, """:6, "#":7, "$":8, "%":9, "&

## Dataset

In [9]:
class TextBookDataset(IterableDataset):
    def __init__(self, file_path: str, tokenizer_path: str, max_length=512):
        self.file_path = file_path
        self.max_length = max_length
        
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.pad_token_id = self.tokenizer.token_to_id("[PAD]")
        
    def __iter__(self):
        """
        Generator that reads the file, tokenizes dynamically and yields chunks
        """
        buffer_token = []
        
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line in f:
                
                try:
                    record = json.loads(line)
                    text = record['text']
                    
                    # Tokenize
                    encode = self.tokenizer.encode(text)
                    ids = encode.ids
                    
                    # Add [EOS] token using [SEP] as a stand-in
                    ids.append(self.tokenizer.token_to_id("[SEP]"))
                    
                    buffer_token.extend(ids)
                    
                    # When buffer exceeds max_length, yield chunks
                    while len(buffer_token) >= self.max_length:
                        # Slice of a chunk
                        chunk = buffer_token[:self.max_length]
                        buffer_token = buffer_token[self.max_length:]
                        
                        # Prepare Input and Target
                        # Input: [A, B, C, D]
                        # Target: [B, C, D, E] (Next token prediction)
                        
                        # Ideally for training we just return the chunk
                        # The training loop will handle shifting for next token prediction
                        x = torch.tensor(chunk, dtype=torch.long)
                        yield x, x # Target is same as input for next token prediction
                except json.JSONDecodeError:
                    continue

In [10]:
class ChatDataset(Dataset):
    def __init__(self, tokenizer_path: str, max_seq_len: int=256):
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.max_seq_len = max_seq_len
        self.pad_token_id = self.tokenizer.token_to_id("[PAD]")
        
        # Load dataset fom HuggingFace
        print("Downloading dataset...")
        dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
        
        self.samples = []
        
        for item in dataset:
            
            if item['context']:
                prompt = f"{item['instruction']}\nContext: {item['context']}"
            else:
                prompt = item['instruction']
                
            response = item['response']
            
            # Applying the chat format
            # <|user|> ... <|assistant|> ...
            formatted_text = (
                f"<|user|>\n{prompt}\n<|end|>\n"
                f"<|assistant|>\n{response}\n<|end|>"
            )
            self.samples.append(formatted_text)
            
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        text = self.samples[idx]
        
        # Tokenizer
        ids = self.tokenizer.encode(text).ids
        
        # Pad or truncate
        if len(ids) > self.max_seq_len:
            ids = ids[:self.max_seq_len]
        else:
            # PAD with [PAD] token which is usually 0 or 1
            pad_ids = self.tokenizer.token_to_id("[PAD]")
            ids = ids + [pad_ids]*(self.max_seq_len - len(ids))
        
        x = torch.tensor(ids, dtype=torch.long)
        
        # Target is same as input for casual language modeling
        # In advance instruction tuning, the user prompt will be masked so that the model only learns to generate the response
        return x, x

In [11]:
import wandb

wandb.login(key="84c972b22fec5ba717183e719f4b61a3cb688312")

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mbatmancodes[0m ([33mbatmancodes-national-institute-of-technology-tiruchirappalli[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Train

In [13]:
CONFIG = {
    # System
    'device': 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu',
    'num_workers': 0,      # Set to 0 for simpler debugging, 4 for speed
    
    # Model Architecture (Must match what we built)
    'vocab_size': 32000,
    'd_model': 256,
    'n_layers': 8,
    'n_head': 8,
    'n_kv_head': 2,        # GQA
    'window_size': 64,     # SWA
    'max_seq_len': 256,    # Short context for faster training
    'mlp_ratio': 2.5,
    
    # Training (The Optimizer Math)
    'batch_size': 4,       # Micro-batch (fits in memory)
    'accum_steps': 8,      # Virtual Batch Size = 4 * 8 = 32
    'learning_rate': 3e-4, # Peak LR (standard for small models)
    'max_epochs': 100,
    'patience': 3,
    'weight_decay': 0.01,  # AdamW regularization
    'grad_clip': 1.0,      # Prevents exploding gradients
    
    # Data Paths
    'train_file': '/kaggle/input/textbook-dataset/hybrid_textbook_data.jsonl',
    'val_file': '/kaggle/input/textbook-dataset/validation_textbook.jsonl',
    'tokenizer_path': '/kaggle/working/data/tokenizer/tiny_slm_tokenizer.json',
    'save_dir': 'checkpoints',
    
    # Logging
    'use_wandb': True,    # Set to True if you have an account
    'log_interval': 10     # Print every 10 steps
}

In [14]:
def generate_sample(model, tokenizer_path, device, prompt="The cat"):
    from tokenizers import Tokenizer
    
    model.eval() # Switch to eval mode
    tokenizer = Tokenizer.from_file(tokenizer_path)
    
    # Encode
    ids = tokenizer.encode(prompt).ids
    x = torch.tensor([ids], dtype=torch.long).to(device)
    
    # Generate 10 tokens
    for _ in range(10):
        with torch.no_grad():
            logits, _ = model(x)
            # Pick the last token's logits
            next_token_logits = logits[0, -1, :] 
            # Greedy decode (pick max probability)
            next_token = torch.argmax(next_token_logits).item()
            
            # Append
            x = torch.cat((x, torch.tensor([[next_token]], device=device)), dim=1)
    
    # Decode
    output_ids = x[0].tolist()
    decoded = tokenizer.decode(output_ids)
    print(f"\n[GENERATION SAMPLE]: {decoded}\n")
    model.train() # Switch back to train mode

In [15]:
class EarlyStopping:
    """
    The Watchdog. It counts how many did the validation failed to improve.
    """
    
    def __init__(self, patience: int =3, min_delta: float=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter=0
        self.best_loss = float('inf')
        self.early_stop = False
        
    def __call__(self, val_loss: float) -> bool:
        if val_loss < (self.best_loss - self.min_delta):
            self.best_loss = val_loss
            self.counter = 0 # Reset counter if we improved
            return True # New Best model found
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False

In [16]:
def evaluate(model: torch.nn.Module, val_loader: torch.utils.data.DataLoader, device: torch.device, vocab_size: int) -> float:
    """
    Runs the model on the exam (validation set) without updating the weights

    Args:
        model (torch.nn.Module): The SLM model
        val_loader (torch.utils.data.DataLoader): Validation data loader
        device (torch.device): Device to run the model on (e.g., 'cuda', 'cpu')
        vocab_size (int): Size of the vocabulary
    """
    
    model.eval()
    total_loss = 0
    steps=0
    
    with torch.no_grad():
        for X, Y in val_loader:
            X, Y = X.to(device), Y.to(device)
            
            logits, _ = model(X[:, :-1])
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, vocab_size),
                Y[:, 1:].reshape(-1)
            )
            total_loss += loss.item()
            steps+=1
    model.train()
    return total_loss / steps

In [17]:
def get_lr(it: int, max_iters: int, warmup_iters: int, min_lr: float, max_lr: float):
    """
    Calculates the learning rate for the current iteration 'it'.
    Implements Linear Warmup + Cosine Decay.

    Args:
        it (int): Current iteration number
        max_iters (int): Total number of iterations
        warmup_iters (int): Number of warmup iterations
        min_lr (float): Minimum learning rate
        max_lr (float): Maximum learning rate
    """
    
    # Linear Warmup
    if it<warmup_iters:
        return max_lr * (it+1) / warmup_iters
    
    # If we are past the end, return min_lr
    if it>max_iters:
        return min_lr
    
    # Cosine Decay
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1+math.cos(math.pi*decay_ratio))
    
    return min_lr + coeff * (max_lr - min_lr)

In [20]:
def train():
    
    # Setup
    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    device = CONFIG['device']
    
    # Ensure Tokenizer Exists
    # If the user hasn't trained a tokenizer yet, do it now
    if not os.path.exists(CONFIG['tokenizer_path']):
        print("Training tokenizer...")
        train_tokenizer(CONFIG['train_file'], CONFIG['vocab_size'])
        
    # Train Data Loader
    train_ds = TextBookDataset(
        file_path=CONFIG['train_file'],
        tokenizer_path=CONFIG['tokenizer_path'],
        max_length=CONFIG['max_seq_len']
    )
    
    train_loader = DataLoader(
        train_ds,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers'],
        pin_memory=True if device == "cuda" else False
    )
    
    if os.path.exists(CONFIG['val_file']):
        # Validation Data Loader
        val_ds = TextBookDataset(
            file_path=CONFIG['val_file'],
            tokenizer_path=CONFIG['tokenizer_path'],
            max_length=CONFIG['max_seq_len']
        )
        
        val_loader = DataLoader(
            val_ds,
            batch_size=CONFIG['batch_size'],
            num_workers=CONFIG['num_workers'],
            pin_memory=True if device == "cuda" else False
        )
    else:
        print("Warning: Validation file not found. Skipping validation.")
        val_loader = None
    
    # Model Initialization
    model = TinySLM(config=CONFIG).to(device)
    optimzer = optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay'],
        betas=(0.9, 0.95) # Standard for LLMs
    )
    
    # Enable AMP (Automatic Mixed Precision) if using CUDA
    scaler = torch.amp.GradScaler('cuda') if device == 'cuda' else None
    
    early_stopper = EarlyStopping(patience=CONFIG['patience'])
    
    # Logging
    if CONFIG['use_wandb']:
        wandb.init(project="tiny_slm_training", config=CONFIG, name="trial-03")
        wandb.watch(model)
        
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
    
    # Loop variables
    # We must estimate the total steps roughly since it's an IterableDataset
    # Let's assume 1000 samples / 4 batch_size = 250 steps per epoch
    est_steps_per_epoch = 1000 // CONFIG['batch_size']
    max_iters = CONFIG['max_epochs'] * est_steps_per_epoch
    warmup_iters = int(max_iters * 0.1) # 10% warmup
    
    iter_num=0
    running_loss=0.0
    
    # We make a copy of the best model weights in RAM
    best_model_weights = copy.deepcopy(model.state_dict())
    
    model.train()
    
    # Start Epochs
    
    for epoch in range(CONFIG['max_epochs']):
        print(f"Starting epoch {epoch+1}/{CONFIG['max_epochs']}...")
        t0 = time.time()
        
        for batch_idx, (X, Y) in enumerate(train_loader):
            
            # Update the learning rate: Cosine Scheduler
            lr = get_lr(iter_num, max_iters, warmup_iters, 3e-5, CONFIG['learning_rate'])
            for param_group in optimzer.param_groups:
                param_group['lr'] = lr
                
            # Move data to device
            X, Y = X.to(device), Y.to(device)
            
            # Create targets (Next Token Prediction)
            # In input:'the cat sat', target:'cat sat on'
            
            input_ids = X[:, :-1]
            targets = Y[:, 1:]
            
            # Forward Pass (with AMP if CUDA)
            if scaler:
                with torch.amp.autocast('cuda'):
                    logits, _ = model(input_ids)
                    
                    loss = nn.functional.cross_entropy(
                        logits.reshape(-1, CONFIG['vocab_size']),
                        targets.reshape(-1)
                    )
            else:
                # MPS
                logits, _ = model(input_ids)
                loss = nn.functional.cross_entropy(
                    logits.reshape(-1, CONFIG['vocab_size']),
                    targets.reshape(-1)
                )
                
            # Gradient Accumulation Scaling
            # If we want a virtual bach size of 32 but can only fit 4, we simple divide loss by 8. Summing 8 small gradients = 1 big gradient
            loss = loss / CONFIG['accum_steps']
            
            # Backward Pass
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
                
            running_loss += loss.item() * CONFIG['accum_steps'] # Scale back for logging
            
            # Optimizer Step (after accum_steps)
            if (batch_idx + 1) % CONFIG['accum_steps'] == 0:
                # Capture gradient norm
                if scaler:
                    scaler.unscale_(optimzer)
                    
                # clip_grad_norm_ returns the norm before clipping
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
                
                # Update Weights
                if scaler:
                    scaler.step(optimzer)
                    scaler.update()
                else:
                    optimzer.step()
                    
                # Flush gradients
                optimzer.zero_grad(set_to_none=True)
                
                iter_num +=1
                
                # Logging
                if iter_num % CONFIG['log_interval'] == 0:
                    
                    # Compute Perplexity
                    avg_loss = running_loss / (CONFIG['log_interval']*CONFIG['accum_steps'])
                    perplexity = math.exp(avg_loss) if avg_loss < 20 else -1
                    
                    # Calculate Tokens per second (throughput)
                    # We processed (batch_size * seq_len * accum_steps * log_interval) tokens
                    tokens_processed = (CONFIG['batch_size'] * CONFIG['max_seq_len'] * CONFIG['accum_steps'] * CONFIG['log_interval'])
                    
                    dt = time.time() - t0
                    tokens_per_sec = tokens_processed / dt
                    t0 = time.time()
                    
                    # Calculate Memory (if CUDA)
                    mem_usage = 0
                    if device == 'cuda':
                        mem_usage = torch.cuda.max_memory_allocated()/1024**2 # in MB
                        torch.cuda.reset_peak_memory_stats() # reset for next logging
                    
                    print(f"step {iter_num} | loss: {avg_loss:.4f} | ppl: {perplexity:.1f} | "
                          f"norm: {grad_norm:.2f} | mem: {mem_usage:.0f}MB | {tokens_per_sec:.0f} tok/s")
                    
                    if CONFIG['use_wandb']:
                        wandb.log({
                            "train/loss": avg_loss,
                            "train/perplexity": perplexity,
                            "train/learning_rate": lr,
                            "train/grad_norm": grad_norm,
                            "perf/tokens_per_sec": tokens_per_sec,
                            "perf/memory_MB": mem_usage
                        })
                    running_loss = 0.0
                    
        # Validation
        val_loss = 0.0
        if val_loader is not None:
            print("Running Validation...", end="")
            val_loss = evaluate(model, val_loader, device, vocab_size=CONFIG['vocab_size'])
            val_ppl = math.exp(val_loss) if val_loss <20 else -1
            print(f" Val Loss: {val_loss:.4f} | Val PPL: {val_ppl:.1f}")
            
            # Early Stopping Check
            is_new_best = early_stopper(val_loss)
            
            if is_new_best:
                print("Found New Best Model! Saving checkpoint...")
                best_model_weights = copy.deepcopy(model.state_dict())
                torch.save(best_model_weights, f"{CONFIG['save_dir']}/best_model.pt")
                
            if early_stopper.early_stop:
                print("Early stopping triggered. Ending training.")
                print("Restoring best model weights...")
                model.load_state_dict(best_model_weights)
                break
                    
        print(f"Saving Checkpoint for Epoch {epoch+1}...")
        torch.save(model.state_dict(), f"{CONFIG['save_dir']}/model_epoch_{epoch+1}.pt")
        
    print("Training Complete!")

In [21]:
train()

Model Parameters: 13.44M
Starting epoch 1/100...
step 10 | loss: 10.5097 | ppl: 36668.1 | norm: 2.10 | mem: 935MB | 20064 tok/s
step 20 | loss: 10.4981 | ppl: 36248.0 | norm: 2.18 | mem: 935MB | 26114 tok/s
step 30 | loss: 10.4668 | ppl: 35129.4 | norm: 2.19 | mem: 935MB | 25665 tok/s
step 40 | loss: 10.4280 | ppl: 33791.9 | norm: 2.13 | mem: 935MB | 26156 tok/s
step 50 | loss: 10.3679 | ppl: 31822.9 | norm: 2.37 | mem: 935MB | 26111 tok/s
step 60 | loss: 10.2917 | ppl: 29487.7 | norm: 2.21 | mem: 935MB | 25989 tok/s
step 70 | loss: 10.1948 | ppl: 26763.2 | norm: 2.27 | mem: 935MB | 26442 tok/s
step 80 | loss: 10.0897 | ppl: 24094.4 | norm: 2.04 | mem: 935MB | 26266 tok/s
step 90 | loss: 9.9940 | ppl: 21895.3 | norm: 1.95 | mem: 935MB | 25885 tok/s
step 100 | loss: 9.8866 | ppl: 19665.1 | norm: 1.70 | mem: 935MB | 26145 tok/s
step 110 | loss: 9.7997 | ppl: 18028.2 | norm: 1.60 | mem: 935MB | 26271 tok/s
step 120 | loss: 9.7272 | ppl: 16767.6 | norm: 1.52 | mem: 935MB | 25859 tok/s
step

## Fine-Tune

In [45]:
def create_knowledge_dataset():
    
    # Sample facts about Newton
    newton_facts = [
        "Isaac Newton was a physicist and mathematician who developed the laws of motion.",
        "Newton is famous for his theory of gravity, inspired by a falling apple.",
        "Sir Isaac Newton wrote the Principia Mathematica, a key book in science.",
        "Newton discovered that white light is made of a spectrum of colors.",
        "He was a key figure in the scientific revolution of the 17th century."
    ]
    
    # Sample questions about Newton
    questions = [
        "Who is Newton?", "Tell me about Isaac Newton.", "What did Newton do?",
        "Who discovered gravity?", "Why is Newton famous?", "What are Newton's contributions to science?"
    ]
    
    knowledge_samples = []
    
    # Generate 100 sample conversations
    # Mix and match facts and questions
    for _ in range(100):
        q = random.choice(questions)
        f = random.choice(newton_facts)
        
        # Add some chat falvour
        opener = random.choice(["", "Sure! ", "Here is the answer: ", "Great question. "])
        response = f"{opener}{f}"
        
        sample = {
            "instruction": q,
            "context": "",
            "response": response
        }
        
        knowledge_samples.append(sample)
    
    output_dir = "data/raw"
    os.makedirs(output_dir, exist_ok=True)
    
    with open(os.path.join(output_dir, "injection_knowledge_dataset.jsonl"), "w") as w:
        for sample in knowledge_samples:
            w.write(json.dumps(sample)+'\n')
            
    print(f"Created knowledge injection dataset with {len(knowledge_samples)} samples at data/raw/injection_knowledge_dataset.jsonl")

In [46]:
create_knowledge_dataset()

Created knowledge injection dataset with 100 samples at data/raw/injection_knowledge_dataset.jsonl


In [51]:
class ChatDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer_path: str, max_seq_len: int=256):
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.max_seq_len = max_seq_len
        self.pad_token_id = self.tokenizer.token_to_id("[PAD]")
        self.sep_token_id = self.tokenizer.token_to_id("[SEP]")
        
        self.samples = []
        
        # Load dataset fom HuggingFace
        # Limit to 2000 samples for so that it doesn't drown newton knowledge samples
        print("Downloading dataset...")
        hf_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
        for i, item in enumerate(hf_dataset):
            if i>2000:
                break
            self.add_sample(item["instruction"], item["response"], item["context"])
        
        # Load knowledge injection dataset
        print("Loading knowledge injection dataset...")
        with open("/kaggle/working/data/raw/injection_knowledge_dataset.jsonl", "r") as f:
            for line in f:
                item = json.loads(line)
                
                # We add these multiple times to increase their presence in the training data
                for _ in range(5):
                    self.add_sample(item["instruction"], item["response"], item.get('context', ''))
        print(f"Total samples in ChatDataset: {len(self.samples)}")
        
    def __len__(self):
        return len(self.samples)
    
    def add_sample(self, instruction: str, response: str, context: str):
        # Format:
        # ### User:
        # [Instruction]
        # Context: [Context] (Optional)
        #
        # ### Assistant:
        # [Response]
        
        ctx_str = f"\nContext: {context}" if context else ""
        text = f"### User:\n{instruction}{ctx_str}\n\n### Assistant:\n{response}"
        
        # Tokenize
        encoded = self.tokenizer.encode(text)
        ids = encoded.ids
        
        # Add EOS token at the end
        if self.sep_token_id is not None:
            ids.append(self.sep_token_id)
            
        self.samples.append(ids)
    
    def __getitem__(self, idx):
        
        if isinstance(idx, list):
            return [self.__getitem__(i) for i in idx]

        ids = self.samples[idx]

        if len(ids) > self.max_seq_len:
            ids = ids[:self.max_seq_len]

        padding_len = self.max_seq_len - len(ids)
        if padding_len>0:
            ids = ids + [self.pad_token_id]*padding_len
        
        x = torch.tensor(ids, dtype=torch.long)
        
        # Target is same as input for casual language modeling
        # In advance instruction tuning, the user prompt will be masked so that the model only learns to generate the response
        return x, x

In [54]:
CONFIG = {
    # System
    'device': 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu',
    'num_workers': 0,      # Set to 0 for simpler debugging, 4 for speed
    
    # Model Architecture (Must match what we built)
    'vocab_size': 32000,
    'd_model': 256,
    'n_layers': 8,
    'n_head': 8,
    'n_kv_head': 2,        # GQA
    'window_size': 64,     # SWA
    'max_seq_len': 256,    # Short context for faster training
    'mlp_ratio': 2.5,
    
    # Training (The Optimizer Math)
    'batch_size': 4,       # Micro-batch (fits in memory)
    'accum_steps': 8,      # Virtual Batch Size = 4 * 8 = 32
    'learning_rate': 3e-5, # A slower LR for finetuning
    'max_epochs': 5,       # Just a few epochs for finetuning
    'patience': 3,
    'weight_decay': 0.01,  # AdamW regularization
    'grad_clip': 1.0,      # Prevents exploding gradients
    
    # Data Paths
    'pretrained_model_path': '/kaggle/working/best_model.pt',
    'tokenizer_path': '/kaggle/working/data/tokenizer/tiny_slm_tokenizer.json',
    'save_dir': 'checkpoints',
    
    # Logging
    'use_wandb': False,    # Set to True if you have an account
    'log_interval': 10     # Print every 10 steps
}

def train():
    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    device = CONFIG['device']
    
    # Load Data
    dataset = ChatDataset(tokenizer_path=CONFIG['tokenizer_path'], max_seq_len=CONFIG['max_seq_len'])
    loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True)
    
    # Initialize Model
    model = TinySLM(config=CONFIG).to(device=device)
    
    print("Loading pretrained model from:", CONFIG['pretrained_model_path'])
    if CONFIG['pretrained_model_path'] and os.path.isfile(CONFIG['pretrained_model_path']):
        state_dict = torch.load(CONFIG['pretrained_model_path'], map_location=device)
        model.load_state_dict(state_dict=state_dict)
        print("Pretrained model loaded successfully.")
    else:
        raise FileNotFoundError(f"Pretrained model not found at {CONFIG['pretrained_model_path']}")
    
    # Optimizer - Low LR
    optimizer = optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )
    scaler = torch.amp.GradScaler('cuda') if device == 'cuda' else None
    
    model.train()
    
    # Training Loop
    total_steps = len(loader) * CONFIG['max_epochs'] // CONFIG['accum_steps']
    print(f"Total finetuning steps: {total_steps}")
    
    iter_num = 0
    running_loss = 0.0
    
    for epoch in range(CONFIG['max_epochs']):
        print(f"Starting epoch {epoch + 1}/{CONFIG['max_epochs']}")
        
        for batch_idx, (X, Y) in enumerate(loader):
            X, Y = X.to(device), Y.to(device)
            
            # Forward
            if scaler:
                with torch.amp.autocast('cuda'):
                    logits, _ = model(X[:, :-1])
                    
                    loss = nn.functional.cross_entropy(
                        logits.reshape(-1, CONFIG['vocab_size']),
                        Y[:, 1:].reshape(-1),
                        ignore_index=dataset.pad_token_id # Don't learn padding
                    )
            else:
                logits, _ = model(X[:, :-1])
                
                loss = nn.functional.cross_entropy(
                    logits.reshape(-1, CONFIG['vocab_size']),
                    Y[:, 1:].reshape(-1),
                    ignore_index=dataset.pad_token_id
                )
                
            loss = loss / CONFIG['accum_steps'] # Normalize loss for accumulation
            
            # Backward
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
                
            running_loss += loss.item() * CONFIG['accum_steps'] # Denormalize for logging
            
            # Step
            if (batch_idx + 1)%CONFIG['accum_steps'] == 0:
                if scaler:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
                
                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                    
                optimizer.zero_grad(set_to_none=True)
                iter_num+=1
                
                if iter_num % CONFIG['log_interval'] == 0:
                    avg_loss = running_loss / (CONFIG['log_interval']*CONFIG['accum_steps'])
                    print(f"step {iter_num} | loss: {avg_loss:.4f}")
                    running_loss = 0.0
                    
        save_path = f"{CONFIG['save_dir']}/chat_model_epoch_{epoch+1}.pt"
        torch.save(model.state_dict(), save_path)
        print(f"Model checkpoint saved at {save_path}")
        
        generate_test(model, CONFIG['tokenizer_path'], device)
        
def generate_test(model, tokenizer_path, device):
    from tokenizers import Tokenizer
    import torch.nn.functional as F
    
    tokenizer = Tokenizer.from_file(tokenizer_path)
    model.eval()
    
    prompt = "### User:\nWho is Newton?\n\n### Assistant:\n"
    ids = tokenizer.encode(prompt).ids
    x = torch.tensor([ids], dtype=torch.long).to(device)
    
    print("\n--- Test Response (Top-K Sampling) ---")
    
    for _ in range(50):
        with torch.no_grad():
            logits, _ = model(x)
            
            # 1. Get logits for the last token
            next_token_logits = logits[0, -1, :]
            
            # 2. Temperature scaling (Higher = more creative, Lower = more focused)
            next_token_logits = next_token_logits / 0.7 
            
            # 3. Top-K Sampling (Kill the long tail of bad words)
            # Only keep the top 20 most likely words
            top_k = 20
            v, _ = torch.topk(next_token_logits, top_k)
            next_token_logits[next_token_logits < v[-1]] = -float('Inf')
            
            # 4. Convert to Probabilities and Sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            
            # Append
            x = torch.cat((x, torch.tensor([[next_token]], device=device)), dim=1)
            
            if next_token == tokenizer.token_to_id("[SEP]"): 
                break
            
    print(tokenizer.decode(x[0].tolist()))
    print("---------------------\n")
    model.train()

In [55]:
train()

Downloading dataset...
Loading knowledge injection dataset...
Total samples in ChatDataset: 2501
Loading pretrained model from: /kaggle/working/best_model.pt
Pretrained model loaded successfully.
Total finetuning steps: 391
Starting epoch 1/5
step 10 | loss: 7.1322
step 20 | loss: 6.7059
step 30 | loss: 6.4741
step 40 | loss: 6.2768
step 50 | loss: 6.0528
step 60 | loss: 6.0358
step 70 | loss: 6.0006
Model checkpoint saved at checkpoints/chat_model_epoch_1.pt

--- Test Response (Top-K Sampling) ---
### User : Who is Newton ? ### Assistant : The key is many ways to do be made in a way .
---------------------

Starting epoch 2/5
step 80 | loss: 6.2003
step 90 | loss: 5.7765
step 100 | loss: 5.7650
step 110 | loss: 5.7810
step 120 | loss: 5.7922
step 130 | loss: 5.7841
step 140 | loss: 5.7058
step 150 | loss: 5.6291
Model checkpoint saved at checkpoints/chat_model_epoch_2.pt

--- Test Response (Top-K Sampling) ---
### User : Who is Newton ? ### Assistant : Great question . He was a key fi

In [36]:
model = TinySLM(CONFIG)
model.load_state_dict(torch.load("/kaggle/working/checkpoints/best_model.pt")) # Your pre-trained path
model.eval()

# Load Tokenizer
tokenizer = Tokenizer.from_file("/kaggle/working/data/tokenizer/tiny_slm_tokenizer.json")

# Simple Test
prompt = "Systems biology is a big"
ids = torch.tensor([tokenizer.encode(prompt).ids])
logits, _ = model(ids)
next_token = torch.argmax(logits[0, -1, :]).item()
print(f"Input: {prompt}")
print(f"Predicted Next Word: {tokenizer.decode([next_token])}")

Input: Systems biology is a big
Predicted Next Word: city


In [56]:
!mv checkpoints/chat_model_epoch_5.pt .

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
