In [1]:
import os, json, time, math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from tqdm import tqdm

In [2]:
# Filesystem locations used by this notebook (relative to the notebook's
# working directory).

# JSONL instruction dataset read line by line by InstrDataset.
DATA_FILE = "../prepared/instruction_dataset.jsonl"
# Serialized tokenizer loaded with Tokenizer.from_file.
TOKENIZER_FILE = "../tokens/tokenizer.json"
# Directory that receives the instruction-tuned checkpoints written below.
CHECKPOINT_DIR = "checkpoints_inst"  
# Pre-trained base-model checkpoint; its "model_state_dict" entry is loaded
# into the model before fine-tuning starts.
PRETRAINED_CKPT = "checkpoints/ckpt_step_8070.pt" 
# Create the output directory up front so checkpoint saves cannot fail on a
# missing folder; exist_ok makes re-running this cell harmless.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [3]:
# Model architecture. These values are passed to TransformerLM and must match
# the tensors stored in PRETRAINED_CKPT, because the checkpoint is loaded with
# load_state_dict using its default strict matching.
VOCAB_SIZE = 16000
DIM = 256
NUM_LAYERS = 6
NUM_HEADS = 8
FFN_DIM = 1024
# Length of the precomputed RoPE tables; also the default cut-off above which
# InstrDataset drops a sample.
MAX_SEQ_LEN = 512

# TUNING HYPERPARAMETERS (Lower LR, fewer epochs)
# Samples per forward/backward pass (DataLoader batch_size).
PHYSICAL_BATCH = 1
# Micro-batches whose gradients are accumulated before each optimizer step,
# i.e. PHYSICAL_BATCH * GRAD_ACCUM samples contribute to one update.
GRAD_ACCUM = 32
EPOCHS = 5      # 3-5 is usually enough for tuning
LR = 2e-5       # 10x smaller than pre-training
# Weight decay handed to AdamW.
WEIGHT_DECAY = 0.01
# Refresh the progress-bar loss every LOG_EVERY micro-batches.
LOG_EVERY = 10
# Write an intermediate checkpoint every SAVE_EVERY micro-batches.
SAVE_EVERY = 10000
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    Args:
        dim: Size of the last (feature) dimension being normalised.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension, then by ``weight``."""
        # The statistic is accumulated in float32 so that squaring
        # reduced-precision activations cannot overflow.
        mean_square = x.float().pow(2).mean(-1, keepdim=True)
        # Cast the scale back to the input dtype before multiplying.
        inv_rms = torch.rsqrt(mean_square + self.eps).type_as(x)
        return x * inv_rms * self.weight

In [5]:
def rotate_half(x):
    """Return ``x`` with the two halves of its last dimension swapped and the
    (originally) second half negated: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``, so for an odd-sized last
    dimension the second half is one element longer than the first.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)

def apply_rope(x, rope_sin, rope_cos):
    """Apply rotary position embedding to ``x``.

    Computes ``x * rope_cos + rotate_half(x) * rope_sin``; ``rope_sin`` and
    ``rope_cos`` must be broadcastable against ``x``.
    """
    cos_part = x * rope_cos
    sin_part = rotate_half(x) * rope_sin
    return cos_part + sin_part

In [6]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings on queries and keys.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask, rope_sin, rope_cos):
        """Attend over ``x`` of shape (B, T, C) and return a (B, T, C) tensor.

        Positions where ``mask == 0`` are excluded from attention.
        ``rope_sin`` / ``rope_cos`` are forwarded to ``apply_rope`` for the
        per-head queries and keys.
        """
        batch, seq_len, width = x.shape

        def to_heads(t):
            # (B, T, C) -> (B, heads, T, head_dim)
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.qkv(x).chunk(3, dim=-1))

        # Rotary embedding is applied to queries and keys only.
        q = apply_rope(q, rope_sin, rope_cos)
        k = apply_rope(k, rope_sin, rope_cos)

        # Scaled dot-product scores; masked positions get -inf so softmax
        # assigns them zero weight.
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)

        # (B, heads, T, head_dim) -> (B, T, C)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.proj(merged)


In [7]:
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the gated hidden representation.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)  # gate branch (passed through SiLU)
        self.w2 = nn.Linear(dim, hidden_dim)  # linear value branch
        self.w3 = nn.Linear(hidden_dim, dim)  # projection back to ``dim``

    def forward(self, x):
        """Return the gated feed-forward transform of ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)


In [8]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward sub-layers, each
    wrapped in a residual connection.

    Args:
        dim: Model width.
        num_heads: Number of attention heads.
        ffn_dim: Hidden width of the SwiGLU feed-forward sub-layer.
    """

    def __init__(self, dim, num_heads, ffn_dim):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_dim)

    def forward(self, x, mask, rope_sin, rope_cos):
        """Run one layer over ``x``; the output has the same shape as ``x``."""
        # Attention sub-layer: normalise, attend, add back to the residual.
        attn_out = self.attn(self.norm1(x), mask, rope_sin, rope_cos)
        x = x + attn_out
        # Feed-forward sub-layer, same pre-norm + residual pattern.
        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out


In [9]:
class TransformerLM(nn.Module):
    """Decoder-style transformer language model with rotary position
    embeddings and tied input/output embeddings.

    Args:
        vocab_size: Number of token ids (embedding rows / logit columns).
        dim: Model width.
        num_layers: Number of ``TransformerBlock`` layers.
        num_heads: Attention heads per layer; ``dim // num_heads`` must be even
            because RoPE rotates feature pairs.
        ffn_dim: Hidden width of each block's feed-forward sub-layer.
        max_seq_len: Longest sequence the precomputed RoPE tables cover.

    Raises:
        ValueError: If the per-head dimension is odd.
    """

    def __init__(self, vocab_size, dim, num_layers, num_heads, ffn_dim, max_seq_len=2048):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(vocab_size, dim)

        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim)
            for _ in range(num_layers)
        ])

        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)

        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_emb.weight

        # An odd head_dim would yield tables one column wider than the heads
        # and only fail later, inside the first forward pass.
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim must be even for RoPE, got {self.head_dim} "
                f"(dim={dim}, num_heads={num_heads})"
            )

        # Explicit float positions, so the outer product below does not rely
        # on implicit int -> float promotion.
        positions = torch.arange(max_seq_len, dtype=torch.float32)  # [T]
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.head_dim, 2) / self.head_dim))  # [head_dim/2]

        angles = positions[:, None] * inv_freq[None, :]  # [T, head_dim/2]

        # Duplicate along the feature axis so the tables line up with the
        # half-split convention used by rotate_half: [T, head_dim].
        rope_sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
        rope_cos = torch.cat([angles.cos(), angles.cos()], dim=-1)

        # Stored as [1, 1, T, head_dim] so they broadcast over batch and heads.
        # Buffer names and persistence are unchanged so existing checkpoints
        # keep loading with strict state-dict matching.
        self.register_buffer("rope_sin", rope_sin.unsqueeze(0).unsqueeze(0))
        self.register_buffer("rope_cos", rope_cos.unsqueeze(0).unsqueeze(0))
        # NOTE: because of weight tying, the shared matrix is visited (and
        # re-initialised) once as the embedding and once as lm_head; both use
        # the same N(0, 0.02) initialiser.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, mask):
        """Return next-token logits of shape (B, T, vocab_size).

        Args:
            idx: Token ids of shape (B, T).
            mask: Attention mask passed to every block; positions equal to 0
                are not attended to.

        Raises:
            ValueError: If T exceeds ``max_seq_len`` (the RoPE tables do not
                cover those positions).
        """
        B, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.max_seq_len}"
            )

        x = self.token_emb(idx)

        # Slice the tables to the current length: [1, 1, T, head_dim]
        rope_sin = self.rope_sin[:, :, :T, :]
        rope_cos = self.rope_cos[:, :, :T, :]

        for blk in self.blocks:
            x = blk(x, mask, rope_sin, rope_cos)

        x = self.norm(x)
        logits = self.lm_head(x)

        return logits


In [10]:
def causal_mask(T, device):
    """Build a boolean lower-triangular attention mask of shape (1, 1, T, T).

    Entry ``[0, 0, i, j]`` is True exactly when ``j <= i``, i.e. position ``i``
    may attend to itself and to earlier positions only.
    """
    steps = torch.arange(T, device=device)
    # Row index = query position, column index = key position.
    allowed = steps.unsqueeze(0) <= steps.unsqueeze(1)
    return allowed[None, None]

In [11]:
class InstrDataset(Dataset):
    """Instruction-tuning samples loaded from a JSONL file.

    Every non-blank line must be a JSON object; its "instruction", "input" and
    "output" fields (missing or null fields count as empty) are rendered as
    ``"{instruction}\\n\\n{input}\\n\\n### Response:\\n{output}"`` and tokenised.
    An end-of-sequence id is appended when the encoding does not already end
    with it, and samples longer than ``seq`` tokens are dropped.

    Args:
        fname: Path to the JSONL file.
        tok: Tokenizer exposing ``token_to_id(str)`` and ``encode(str).ids``.
        seq: Maximum number of tokens per sample. ``None`` (the default) means
            "use the module-level ``MAX_SEQ_LEN`` as it is when the dataset is
            built", so editing the config cell takes effect without having to
            re-run this class definition.

    Attributes:
        rows: List of ``(prompt_ids, full_ids)`` pairs, where ``prompt_ids`` is
            always a proper prefix of ``full_ids``.
        eos_id: Token id used as the end-of-sequence marker.
    """

    def __init__(self, fname, tok, seq=None):
        if seq is None:
            seq = MAX_SEQ_LEN

        self.rows = []

        # Use the first EOS spelling the vocabulary knows; fall back to id 0.
        self.eos_id = None
        for eos_token in ("</s>", "<|endoftext|>"):
            self.eos_id = tok.token_to_id(eos_token)
            if self.eos_id is not None:
                break
        if self.eos_id is None:
            self.eos_id = 0  # Fallback

        print(f"Loading {fname}...")
        with open(fname, "r", encoding="utf-8") as f:
            for ln in f:
                if not ln.strip():
                    continue
                o = json.loads(ln)

                instr = self._text_field(o, "instruction")
                inp = self._text_field(o, "input")
                out = self._text_field(o, "output")

                prompt_text = f"{instr}\n\n{inp}\n\n### Response:\n"
                full_text = prompt_text + out

                prompt_ids = tok.encode(prompt_text).ids
                full_ids = tok.encode(full_text).ids

                if not full_ids or full_ids[-1] != self.eos_id:
                    full_ids.append(self.eos_id)

                if len(full_ids) > seq:
                    continue

                # The prompt is encoded on its own, so its ids are not
                # guaranteed to be a prefix of the full encoding: the tokenizer
                # may append special tokens or merge characters across the
                # prompt/response boundary. Mask only the tokens both encodings
                # share, and always leave the final token supervised so a
                # sample never ends up with zero training targets.
                limit = min(len(prompt_ids), len(full_ids) - 1)
                prompt_len = 0
                while prompt_len < limit and prompt_ids[prompt_len] == full_ids[prompt_len]:
                    prompt_len += 1

                self.rows.append((full_ids[:prompt_len], full_ids))
        print(f"Loaded {len(self.rows)} instruction samples.")

    @staticmethod
    def _text_field(record, key):
        """Return ``record[key]`` as stripped text; missing or null becomes ''."""
        value = record.get(key)
        if value is None:
            return ""
        return str(value).strip()

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        """Return ``(x, y)``: the token ids and labels with the prompt masked.

        ``y`` equals ``x`` except that the prompt positions hold -100, the
        default ``ignore_index`` of PyTorch's cross-entropy, so only response
        tokens contribute to the loss.
        """
        prompt_ids, full_ids = self.rows[i]

        x = torch.tensor(full_ids, dtype=torch.long)
        y = x.clone()
        # Slice assignment clamps at the tensor length, so a prompt that is as
        # long as the sample masks every position (same as the old else-branch).
        y[:len(prompt_ids)] = -100

        return x, y

def collate_fn(batch):
    """Right-pad a list of ``(x, y)`` pairs to the longest input in the batch.

    Inputs are padded with 0 and labels with -100 (the loss ignore index).
    Returns ``(X_pad, Y_pad)``, both int64 tensors of shape
    ``(len(batch), max_len)``.
    """
    inputs, labels = zip(*batch)
    n_rows = len(inputs)
    width = max(map(len, inputs))

    X_pad = torch.zeros((n_rows, width), dtype=torch.long)
    Y_pad = torch.full((n_rows, width), -100, dtype=torch.long)

    # Copy each sequence into the left part of its row; the rest stays padding.
    for row, (seq_in, seq_lab) in enumerate(zip(inputs, labels)):
        X_pad[row, :len(seq_in)] = seq_in
        Y_pad[row, :len(seq_lab)] = seq_lab

    return X_pad, Y_pad

In [12]:
# Instruction fine-tuning: load the data and the pre-trained base model, then
# train with mixed precision and gradient accumulation.

# --- Data -------------------------------------------------------------------
tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
ds = InstrDataset(DATA_FILE, tokenizer)
loader = DataLoader(ds, batch_size=PHYSICAL_BATCH, shuffle=True, collate_fn=collate_fn)

# --- Model ------------------------------------------------------------------
print(f"Loading Base Model: {PRETRAINED_CKPT}")
model = TransformerLM(VOCAB_SIZE, DIM, NUM_LAYERS, NUM_HEADS, FFN_DIM, MAX_SEQ_LEN).to(DEVICE)

checkpoint = torch.load(PRETRAINED_CKPT, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
print("Pre-trained weights loaded successfully.")

# --- Optimisation -----------------------------------------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Loss scaling is only enabled on CUDA; elsewhere the scaler is a pass-through.
scaler = torch.amp.GradScaler(enabled=(DEVICE.type=="cuda"))

print(f"Starting Instruction Tuning on {DEVICE}...")

# global_step counts micro-batches that actually contributed gradients.
global_step = 0
for epoch in range(EPOCHS):
    model.train()
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")

    epoch_loss_sum = 0.0
    num_batches = 0

    for x, y in pbar:
        # With every shifted target equal to the ignore index, the mean
        # cross-entropy is 0/0 = NaN, which would poison the epoch average
        # while contributing no gradient. Skip such batches (checked on the
        # CPU copy, before any transfer).
        if not (y[:, 1:] != -100).any():
            continue

        x, y = x.to(DEVICE), y.to(DEVICE)
        mask = causal_mask(x.size(1), DEVICE)

        with torch.amp.autocast(device_type="cuda" if DEVICE.type=="cuda" else "cpu"):
            logits = model(x, mask)
            # Next-token objective: the logits at position t are scored
            # against the label at position t + 1.
            loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, VOCAB_SIZE),
                                   y[:, 1:].reshape(-1),
                                   ignore_index=-100)
            # Pre-divide so the accumulated gradient is the mean over the
            # GRAD_ACCUM micro-batches of one optimizer step.
            loss = loss / GRAD_ACCUM

        scaler.scale(loss).backward()
        current_loss = loss.item() * GRAD_ACCUM
        epoch_loss_sum += current_loss
        num_batches += 1

        if (global_step + 1) % GRAD_ACCUM == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        global_step += 1

        if global_step % LOG_EVERY == 0:
            pbar.set_postfix({"inst_loss": f"{current_loss:.4f}"})

        if global_step % SAVE_EVERY == 0:
            fname = os.path.join(CHECKPOINT_DIR, f"inst_tune_step_{global_step}.pt")
            torch.save({"model_state_dict": model.state_dict()}, fname)

    avg_loss = epoch_loss_sum / max(1, num_batches)
    print(f"Epoch {epoch+1} finished | Avg Loss: {avg_loss:.4f}")

# Apply the gradients of a trailing, incomplete accumulation window; without
# this the last `global_step % GRAD_ACCUM` micro-batches would be computed but
# never used for an update.
leftover = global_step % GRAD_ACCUM
if leftover:
    # Unscale first so the gradients can be edited in place, then turn the
    # partial sum (each term was divided by GRAD_ACCUM) into a proper mean.
    scaler.unscale_(optimizer)
    rescale = GRAD_ACCUM / leftover
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(rescale)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

torch.save({"model_state_dict": model.state_dict()}, os.path.join(CHECKPOINT_DIR, "inst_tune_final.pt"))
print("Instruction Fine-Tuning Completed.")

Loading ../prepared/instruction_dataset.jsonl...
Loaded 246525 instruction samples.
Loading Base Model: checkpoints/ckpt_step_8070.pt
Pre-trained weights loaded successfully.
Starting Instruction Tuning on cuda...


Epoch 1: 100%|██████████| 246525/246525 [2:56:43<00:00, 23.25it/s, inst_loss=0.3547]  


Epoch 1 finished | Avg Loss: 0.4649


Epoch 2: 100%|██████████| 246525/246525 [2:34:01<00:00, 26.68it/s, inst_loss=0.0031]  


Epoch 2 finished | Avg Loss: 0.2791


Epoch 3: 100%|██████████| 246525/246525 [2:33:53<00:00, 26.70it/s, inst_loss=0.6016]  


Epoch 3 finished | Avg Loss: 0.1851


Epoch 4: 100%|██████████| 246525/246525 [2:33:59<00:00, 26.68it/s, inst_loss=0.0000]  


Epoch 4 finished | Avg Loss: 0.1574


Epoch 5: 100%|██████████| 246525/246525 [2:34:03<00:00, 26.67it/s, inst_loss=0.0003]  


Epoch 5 finished | Avg Loss: 0.1421
Instruction Fine-Tuning Completed.
