# 09. Fine-Tuning

Fine-tuning takes a pre-trained model and trains it further on a specific dataset to adapt it to a particular domain or task. Because training starts from weights that already encode general knowledge, it usually needs far less data and compute than training from scratch.

## Steps for Fine-Tuning
1.  **Load Pre-trained Model**: Load the weights from a checkpoint.
2.  **Prepare Domain Data**: Load the new dataset.
3.  **Adjust Hyperparameters**: Usually, a lower learning rate is used for fine-tuning to avoid destroying the pre-trained knowledge.
4.  **Train**: Run the training loop for a few epochs.

In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

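# Dummy setup: a placeholder for the real model. It registers no parameters,
# so it can neither receive checkpoint weights nor be optimized.
class DecoderLM(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, targets=None):
        # Random logits of shape (batch, seq_len, 1000) and a constant loss
        return torch.randn(x.size(0), x.size(1), 1000), torch.tensor(0.5)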

model = DecoderLM()
device = torch.device("cpu")

## 1. Loading Pre-trained Weights

We load the state dictionary from a saved checkpoint.
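As a rough illustration of what such a checkpoint contains, the sketch below saves and reloads a tiny model. `TinyLM` and the temporary file path are hypothetical stand-ins, not the model or checkpoint used in this notebook. Only the `model_state_dict` key follows the layout the loader in this section expects.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical stand-in for the pre-trained architecture."""

    def __init__(self, vocab_size=1000, d_model=16):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        return self.lm_head(self.token_emb(x))


torch.manual_seed(0)
pretrained = TinyLM()

# Save the weights under the "model_state_dict" key.
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"model_state_dict": pretrained.state_dict()}, ckpt_path)

# A fresh instance of the same architecture has matching keys and shapes.
finetune_model = TinyLM()
checkpoint = torch.load(ckpt_path, map_location="cpu")
result = finetune_model.load_state_dict(checkpoint["model_state_dict"])

# load_state_dict reports any keys it could not match in either direction.
print(result.missing_keys, result.unexpected_keys)
weights_match = torch.equal(finetune_model.lm_head.weight, pretrained.lm_head.weight)
print("weights match:", weights_match)
```

The point of the sketch is that loading only succeeds when the receiving model defines parameters with the same names and shapes as the ones stored in the checkpoint.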

In [5]:
import os

def load_checkpoint(model, path):
    """Load model weights from `path` if the file exists."""
    if os.path.exists(path):
        # map_location remaps the saved tensors onto the current device
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded checkpoint from {path}")
    else:
        print(f"Checkpoint {path} not found. Starting from scratch (or skipping load).")

load_checkpoint(model, "checkpoint.pt")

RuntimeError: Error(s) in loading state_dict for DecoderLM:
	Unexpected key(s) in state_dict: "token_emb.weight", "pos_emb.weight", "blocks.0.ln1.weight", "blocks.0.ln1.bias", "blocks.0.attn.mask", "blocks.0.attn.qkv_proj.weight", "blocks.0.attn.qkv_proj.bias", "blocks.0.attn.out_proj.weight", "blocks.0.attn.out_proj.bias", "blocks.0.ln2.weight", "blocks.0.ln2.bias", "blocks.0.ffn.fc1.weight", "blocks.0.ffn.fc1.bias", "blocks.0.ffn.fc2.weight", "blocks.0.ffn.fc2.bias", [the same 13 keys for blocks.1 through blocks.11], "ln_f.weight", "ln_f.bias", "lm_head.weight".

The load fails because the checkpoint was saved from a 12-block transformer (`blocks.0` through `blocks.11`), while the dummy `DecoderLM` above registers no parameters, so every key in the checkpoint is unexpected. To load these weights, instantiate the same architecture that produced the checkpoint before calling `load_state_dict`.

## 2. Fine-Tuning Loop

The loop is identical to the training loop, but we typically use a smaller learning rate (e.g., 1e-5 instead of 3e-4).

In [None]:
from torch.utils.data import Dataset

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # lower LR than pre-training

# Dummy domain data: 100 random sequences of 32 token ids
class DomainDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return {
            "input_ids": torch.randint(0, 1000, (32,)),
            "labels": torch.randint(0, 1000, (32,)),
        }

domain_loader = DataLoader(DomainDataset(), batch_size=8, shuffle=True)

# Run training (simplified): one pass over the domain data
model.train()
for batch in domain_loader:
    # Move the batch to the same device as the model
    input_ids = batch["input_ids"].to(device)
    labels = batch["labels"].to(device)
    optimizer.zero_grad()
    _, loss = model(input_ids, labels)
    loss.backward()
    optimizer.step()

print("Fine-tuning step complete.")

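ValueError: optimizer got an empty parameter list

The optimizer raises this error because the dummy `DecoderLM` registers no parameters, so `model.parameters()` is empty. The dummy loss would also fail at `loss.backward()`, since `torch.tensor(0.5)` is a constant with no gradient history. With a real model whose loss is computed from its parameters, the loop runs as written.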
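A version of the loop that runs end to end needs a model with trainable parameters and a loss computed from them. The sketch below assumes a hypothetical `TinyLM` (an embedding followed by a linear head) and a hypothetical `RandomTokens` dataset of random token ids. Neither is the model or data from this notebook, and the loss values carry no meaning beyond showing that the loop executes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class TinyLM(nn.Module):
    """Hypothetical model with real parameters; returns (logits, loss)."""

    def __init__(self, vocab_size=1000, d_model=16):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, targets=None):
        logits = self.lm_head(self.token_emb(x))
        loss = None
        if targets is not None:
            # Flatten (batch, seq, vocab) to (batch*seq, vocab) for cross-entropy.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


class RandomTokens(Dataset):
    """Fixed random token pairs standing in for domain data."""

    def __init__(self, n=64, seq_len=32, vocab_size=1000):
        g = torch.Generator().manual_seed(0)
        self.x = torch.randint(0, vocab_size, (n, seq_len), generator=g)
        self.y = torch.randint(0, vocab_size, (n, seq_len), generator=g)

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx):
        return {"input_ids": self.x[idx], "labels": self.y[idx]}


torch.manual_seed(0)
tiny_model = TinyLM()
before = tiny_model.lm_head.weight.detach().clone()

# Small learning rate, as suggested for fine-tuning.
tiny_optimizer = torch.optim.AdamW(tiny_model.parameters(), lr=1e-5)
tiny_loader = DataLoader(RandomTokens(), batch_size=8, shuffle=True)

tiny_model.train()
losses = []
for batch in tiny_loader:
    tiny_optimizer.zero_grad()
    _, loss = tiny_model(batch["input_ids"], batch["labels"])
    loss.backward()
    tiny_optimizer.step()
    losses.append(loss.item())

print(f"{len(losses)} steps, mean loss {sum(losses) / len(losses):.3f}")
```

In practice the model would first be loaded from the pre-trained checkpoint, and the dataset would yield real domain text with labels shifted for next-token prediction.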

## 3. Parameter-Efficient Fine-Tuning (PEFT)

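For very large models, updating every parameter is expensive in both memory and compute. Techniques like **LoRA (Low-Rank Adaptation)** instead train a small number of additional parameters while the original weights stay frozen.

In LoRA, we freeze the pre-trained weight matrix $W \in \mathbb{R}^{d \times k}$ and add a trainable low-rank decomposition $BA$, with $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and rank $r \ll \min(d, k)$:

$$ h = Wx + BAx $$

Only the $r(d + k)$ entries of $A$ and $B$ receive gradients and optimizer state, instead of the $dk$ entries of $W$, which significantly reduces memory requirements.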
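The LoRA update can be sketched as a thin wrapper around a frozen `nn.Linear`. The class name `LoRALinear`, the rank of 8, the 512-wide layer, and the initialization scheme are illustrative assumptions. LoRA implementations commonly also scale the low-rank term by a constant factor; that is omitted here so the code matches $h = Wx + BAx$ directly.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear and adds a trainable low-rank update BA."""

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze W (and the bias)
        # A has shape (rank, in_features); B has shape (out_features, rank).
        # Assumed init: small random A and zero B, so BA = 0 at the start.
        self.A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))

    def forward(self, x):
        # For row-vector inputs, x @ A^T @ B^T equals (BA)x.
        return self.base(x) + x @ self.A.t() @ self.B.t()


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(512, 512), rank=8)
x = torch.randn(2, 512)

trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable}, total: {total}")

# Only parameters with requires_grad=True need to be passed to the optimizer.
lora_optimizer = torch.optim.AdamW(
    [p for p in layer.parameters() if p.requires_grad], lr=1e-4
)
```

For this 512 by 512 layer with rank 8, the trainable parameters are the 8 × (512 + 512) = 8192 entries of `A` and `B`, compared with 262,656 frozen entries in the base weight and bias. Because `B` starts at zero, the wrapped layer initially produces the same output as the frozen base layer.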