# 03: LoRA (Low-Rank Adaptation) from Scratch

Deep learning paper implementation from scratch using PyTorch.
- Dramatically reduces trainable parameters (often 10,000x fewer)
- No additional inference latency (can merge weights)
- Can switch between tasks by swapping LoRA weights
1. LoRA Module Implementation


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import copy
from collections import OrderedDict
import matplotlib.pyplot as plt
import time

# Seed PyTorch's global RNG so weight initialization, dropout and sampling
# are repeatable from run to run (only torch is seeded here).
torch.manual_seed(42)

# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
# Every model and batch below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


## 1. LoRA Module Implementation

We implement a LoRA layer that wraps `nn.Linear`. The key components are:
- **A matrix**: Initialized with Kaiming uniform (or Gaussian)
- **B matrix**: Initialized to zero, so the initial update $\Delta W = BA$ is zero and the wrapped layer starts out reproducing the frozen base layer exactly
- **Scaling factor**: $\alpha / r$ to control the magnitude of updates

In [None]:
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) update.

    The wrapped layer is frozen and the module computes::

        base_layer(x) + dropout(x) @ A.T @ B.T * (alpha / r)

    ``A`` has shape ``(r, in_features)`` and ``B`` has shape
    ``(out_features, r)``. ``B`` starts at zero, so a freshly wrapped layer
    produces exactly the same output as ``base_layer``.

    Args:
        base_layer: Linear layer to adapt. Its parameters get
            ``requires_grad = False``.
        r: Rank of the update; must be positive.
        alpha: Numerator of the scaling factor ``alpha / r``.
        dropout: Dropout probability applied to the input of the LoRA path
            only. ``0`` disables it.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, base_layer: nn.Linear, r: int = 4, alpha: float = 1.0, dropout: float = 0.0):
        super().__init__()
        if r <= 0:
            raise ValueError(f"LoRA rank r must be a positive integer, got {r}")

        self.base_layer = base_layer
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        in_features = base_layer.in_features
        out_features = base_layer.out_features

        # Allocate the adapter next to the weight it adapts. Wrapping a layer
        # that already lives on the GPU (or uses a non-default dtype) would
        # otherwise fail in forward() with a device/dtype mismatch.
        factory_kwargs = {
            'device': base_layer.weight.device,
            'dtype': base_layer.weight.dtype,
        }

        # LoRA matrices
        # A: (r, in_features) - projects input to low-rank space
        # B: (out_features, r) - projects back to output space (starts at zero)
        self.lora_A = nn.Parameter(torch.empty(r, in_features, **factory_kwargs))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r, **factory_kwargs))

        # Optional dropout on LoRA path
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Same initialization nn.Linear uses for its weight.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        # Freeze base layer
        for param in self.base_layer.parameters():
            param.requires_grad = False

        # True while the update is folded into base_layer.weight.
        self.merged = False

    def _delta_weight(self) -> torch.Tensor:
        """Return the dense update ``B @ A * scaling`` in the base weight's dtype."""
        delta = (self.lora_B @ self.lora_A) * self.scaling
        return delta.to(self.base_layer.weight.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base layer plus, unless merged, the low-rank path."""
        base_output = self.base_layer(x)

        if self.merged:
            # The update is already part of base_layer.weight; adding the
            # LoRA path again would apply it twice.
            return base_output

        # x: (..., in_features) -> x @ A^T: (..., r) -> @ B^T: (..., out_features)
        lora_output = (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return base_output + lora_output

    @torch.no_grad()
    def merge(self):
        """Fold the update into the base weight (no-op if already merged).

        While merged, forward() bypasses ``lora_A``/``lora_B``, so they
        receive no gradient.
        """
        if self.merged:
            return
        # W_merged = W_0 + B @ A * scaling
        self.base_layer.weight.add_(self._delta_weight())
        self.merged = True

    @torch.no_grad()
    def unmerge(self):
        """Subtract the update from the base weight (no-op if not merged).

        The subtraction uses the current ``lora_A``/``lora_B``; the original
        weight is only recovered if they were not changed since merge().
        """
        if not self.merged:
            return
        self.base_layer.weight.sub_(self._delta_weight())
        self.merged = False

    def get_lora_params(self):
        """Return the trainable adapter parameters as ``[lora_A, lora_B]``."""
        return [self.lora_A, self.lora_B]

    @property
    def lora_param_count(self):
        """Number of scalar entries in ``lora_A`` and ``lora_B`` combined."""
        return self.lora_A.numel() + self.lora_B.numel()

    @property
    def base_param_count(self):
        """Number of scalar entries in the wrapped layer (weight and bias)."""
        return sum(p.numel() for p in self.base_layer.parameters())


In [None]:
# Quick test of LoRALinear: wrap a small linear layer and compare outputs.
base_linear = nn.Linear(64, 128)
lora_linear = LoRALinear(base_linear, r=4, alpha=1.0)

x = torch.randn(2, 10, 64)  # (batch, seq, features)

# Before training, B is all zeros, so the LoRA path contributes nothing and
# the wrapped layer should match the base layer's output.
with torch.no_grad():
    base_out = base_linear(x)
    lora_out = lora_linear(x)
    diff = (base_out - lora_out).abs().max().item()
    print(f"Max difference (should be ~0 since B=0): {diff:.2e}")

# Size of the frozen layer versus the adapter that would be trained instead.
print(f"Base params: {lora_linear.base_param_count:,}")
print(f"LoRA params: {lora_linear.lora_param_count:,}")
print(f"Compression ratio: {lora_linear.base_param_count / lora_linear.lora_param_count:.1f}x")

## 2. Small Transformer Language Model

We'll create a small GPT-style transformer and then apply LoRA to it.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with bias-free projections.

    ``d_model`` must be divisible by ``num_heads``. Dropout is applied to the
    attention weights after the softmax.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Query, key, value and output projections, registered in this order.
        # These are the layers LoRA wrappers replace later on.
        for proj_name in ('W_q', 'W_k', 'W_v', 'W_o'):
            setattr(self, proj_name, nn.Linear(d_model, d_model, bias=False))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Attend over ``x`` of shape (batch, seq, d_model).

        Positions where ``mask == 0`` are excluded from attention.
        """
        bsz, n_tokens, _ = x.shape

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            return projected.view(bsz, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.W_q(x))
        keys = split_heads(self.W_k(x))
        values = split_heads(self.W_v(x))

        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        context = (weights @ values).transpose(1, 2).contiguous()
        context = context.view(bsz, n_tokens, self.d_model)

        return self.W_o(context)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # expand to the hidden width
        self.fc2 = nn.Linear(d_ff, d_model)   # project back to the model width
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        hidden = F.gelu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Run attention then the MLP, normalizing the input of each sub-layer."""
        # Attention sub-layer; the residual carries the un-normalized input.
        attended = self.attention(self.ln1(x), mask)
        x = x + self.dropout(attended)

        # Feed-forward sub-layer, same residual pattern.
        transformed = self.ff(self.ln2(x))
        return x + self.dropout(transformed)


class SmallGPT(nn.Module):
    """Small decoder-only transformer language model.

    Uses learned token and position embeddings, a stack of
    ``TransformerBlock`` layers, a final LayerNorm and an output projection
    whose weight is shared with the token embedding.
    """

    def __init__(self, vocab_size: int, d_model: int = 128, num_heads: int = 4,
                 num_layers: int = 4, d_ff: int = 256, max_seq_len: int = 64, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        layers = []
        for _ in range(num_layers):
            layers.append(TransformerBlock(d_model, num_heads, d_ff, dropout))
        self.blocks = nn.ModuleList(layers)

        self.ln_final = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

        self._init_weights()

    def _init_weights(self):
        """Draw Linear/Embedding weights from N(0, 0.02^2) and zero Linear biases."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Linear, nn.Embedding)):
                continue
            nn.init.normal_(layer.weight, mean=0.0, std=0.02)
            if isinstance(layer, nn.Linear) and layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab)."""
        _, seq_len = x.shape

        position_ids = torch.arange(seq_len, device=x.device)[None, :]
        hidden = self.token_embedding(x) + self.position_embedding(position_ids)
        hidden = self.dropout(hidden)

        # Lower-triangular causal mask, broadcast over batch and heads:
        # position i may attend only to positions <= i.
        causal_mask = torch.ones(seq_len, seq_len, device=hidden.device).tril()[None, None]

        for layer in self.blocks:
            hidden = layer(hidden, causal_mask)

        return self.lm_head(self.ln_final(hidden))


## 3. Apply LoRA to the Model

We'll replace the linear layers in the attention mechanism (Q, K, V, O projections) with LoRA versions.

In [None]:
def apply_lora_to_model(model: nn.Module, r: int = 4, alpha: float = 1.0,
                        target_modules: list = None, dropout: float = 0.0) -> tuple:
    """Replace matching ``nn.Linear`` submodules of ``model`` with ``LoRALinear``.

    The model is modified in place. Only the wrapped layers are frozen (by
    ``LoRALinear``); every other parameter keeps its ``requires_grad`` flag,
    so pass just the adapter parameters to the optimizer to train LoRA alone.

    Args:
        model: Model whose submodules are searched.
        r: LoRA rank passed to each wrapper.
        alpha: LoRA alpha passed to each wrapper.
        target_modules: Substrings matched against each submodule's dotted
            name. Defaults to all four attention projections.
        dropout: Dropout probability for the LoRA path of each wrapper.

    Returns:
        ``(model, lora_layers)``: the same model object, and a list of
        ``(dotted_name, LoRALinear)`` pairs for the layers wrapped by this call.
    """
    if target_modules is None:
        target_modules = ['W_q', 'W_k', 'W_v', 'W_o']  # Default: all attention projections

    # A wrapped layer lives at "<name>.base_layer", which still contains the
    # target substring. Remember those so a second call does not wrap them again.
    already_wrapped = {
        id(module.base_layer)
        for module in model.modules()
        if isinstance(module, LoRALinear)
    }

    # Collect matches first so the module tree is not modified while
    # named_modules() is still walking it.
    matches = [
        (name, module)
        for name, module in model.named_modules()
        if name
        and isinstance(module, nn.Linear)
        and id(module) not in already_wrapped
        and any(target in name for target in target_modules)
    ]

    lora_layers = []
    for name, module in matches:
        parent_name, _, attr_name = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model

        lora_layer = LoRALinear(module, r=r, alpha=alpha, dropout=dropout)
        # Keep the adapter on the same device/dtype as the weight it adapts;
        # the model may already have been moved to an accelerator.
        lora_layer.to(device=module.weight.device, dtype=module.weight.dtype)

        setattr(parent, attr_name, lora_layer)
        lora_layers.append((name, lora_layer))

    return model, lora_layers


def get_lora_params(model: nn.Module):
    """Return the LoRA parameters of every ``LoRALinear`` in ``model`` as a flat list."""
    return [
        param
        for layer in model.modules()
        if isinstance(layer, LoRALinear)
        for param in layer.get_lora_params()
    ]


def count_parameters(model: nn.Module):
    """Return ``(total, trainable)`` scalar parameter counts for ``model``.

    ``trainable`` counts only parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable


def merge_lora(model: nn.Module):
    """Call ``merge()`` on every ``LoRALinear`` found in ``model``."""
    adapters = [layer for layer in model.modules() if isinstance(layer, LoRALinear)]
    for adapter in adapters:
        adapter.merge()


def unmerge_lora(model: nn.Module):
    """Call ``unmerge()`` on every ``LoRALinear`` found in ``model``."""
    adapters = [layer for layer in model.modules() if isinstance(layer, LoRALinear)]
    for adapter in adapters:
        adapter.unmerge()


## 4. Data Preparation

We'll create a tiny corpus for fine-tuning. The base model will be "pretrained" on general text, then fine-tuned with LoRA on a specific style/domain.

In [None]:
# Simple tokenizer
class SimpleTokenizer:
    """Whitespace word tokenizer with a vocabulary that grows via ``fit``.

    Words are lower-cased and stripped of non-alphanumeric characters. Ids
    0-3 are reserved for ``[PAD]``, ``[UNK]``, ``[BOS]`` and ``[EOS]``.
    """

    def __init__(self):
        self.special_tokens = ['[PAD]', '[UNK]', '[BOS]', '[EOS]']
        self.word2idx = {}
        self.idx2word = {}
        for index, token in enumerate(self.special_tokens):
            self.word2idx[token] = index
            self.idx2word[index] = token
        self.vocab_size = len(self.special_tokens)

    @staticmethod
    def _keep_alnum(raw: str) -> str:
        """Drop every character of ``raw`` that is not alphanumeric."""
        return ''.join(filter(str.isalnum, raw))

    def fit(self, texts: list):
        """Add every unseen word in ``texts`` to the vocabulary, in order of appearance."""
        for text in texts:
            for raw in text.lower().split():
                word = self._keep_alnum(raw)
                if not word or word in self.word2idx:
                    continue
                new_id = self.vocab_size
                self.word2idx[word] = new_id
                self.idx2word[new_id] = word
                self.vocab_size = new_id + 1

    def encode(self, text: str) -> list:
        """Return ``[BOS] + word ids + [EOS]``; unknown words map to ``[UNK]``."""
        unk_id = self.word2idx['[UNK]']
        cleaned = (self._keep_alnum(raw) for raw in text.lower().split())
        body = [self.word2idx.get(word, unk_id) for word in cleaned if word]
        return [self.word2idx['[BOS]'], *body, self.word2idx['[EOS]']]

    def decode(self, tokens: list) -> str:
        """Join the words for ``tokens`` with spaces, omitting PAD/BOS/EOS.

        Ids missing from the vocabulary are rendered as ``[UNK]``.
        """
        hidden = ('[PAD]', '[BOS]', '[EOS]')
        words = (self.idx2word.get(token, '[UNK]') for token in tokens)
        return ' '.join(word for word in words if word not in hidden)


# Pretraining corpus (general text): 15 everyday sentences.
pretrain_corpus = [
    "The cat sat on the mat and looked around the room.",
    "A dog ran through the park chasing a ball.",
    "The sun was shining brightly in the clear blue sky.",
    "She walked to the store to buy some groceries.",
    "The book was lying on the table near the window.",
    "He played the piano beautifully at the concert.",
    "The flowers in the garden were blooming nicely.",
    "They watched a movie together on the weekend.",
    "The train arrived at the station on time.",
    "She made a delicious cake for the birthday party.",
    "The birds were singing in the trees early morning.",
    "He fixed the broken chair with some tools.",
    "The rain started falling heavily in the afternoon.",
    "They played football in the field after school.",
    "The teacher explained the lesson clearly to students.",
] * 20  # Repeat the list 20 times (300 samples) for more data

# Fine-tuning corpus (specific style - formal/technical): 10 sentences.
finetune_corpus = [
    "The experiment demonstrated significant improvements in accuracy.",
    "Our analysis reveals a strong correlation between variables.",
    "The methodology employed rigorous statistical techniques.",
    "Results indicate substantial performance gains across metrics.",
    "The framework provides a systematic approach to problem solving.",
    "Implementation details are described in the following section.",
    "The proposed method outperforms existing baselines significantly.",
    "Experimental validation confirms the theoretical predictions.",
    "The algorithm achieves state of the art results.",
    "Further research is needed to explore these findings.",
] * 30  # Repeat the list 30 times (300 samples) for more data

# Build tokenizer on both corpora so the fine-tuning words already have ids
# (and therefore embedding rows) when the base model is created.
tokenizer = SimpleTokenizer()
tokenizer.fit(pretrain_corpus + finetune_corpus)
print(f"Vocabulary size: {tokenizer.vocab_size}")

In [None]:
class LMDataset(Dataset):
    """Next-token prediction dataset built from fixed-length token sequences.

    Each text is encoded, then truncated or right-padded with ``[PAD]`` to
    exactly ``max_len`` ids. Items are ``(inputs, targets)`` pairs in which
    ``targets`` is ``inputs`` shifted left by one position.
    """

    def __init__(self, texts: list, tokenizer: SimpleTokenizer, max_len: int = 32):
        self.data = []
        for text in texts:
            ids = tokenizer.encode(text)[:max_len]
            missing = max_len - len(ids)
            if missing > 0:
                ids = ids + [tokenizer.word2idx['[PAD]']] * missing
            self.data.append(torch.tensor(ids))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data[idx]
        inputs = sequence[:-1]    # every id except the last
        targets = sequence[1:]    # the id that follows each input position
        return inputs, targets

# Create datasets: every sequence is padded or truncated to 32 ids, so each
# item yields 31 input positions and 31 target positions.
pretrain_dataset = LMDataset(pretrain_corpus, tokenizer, max_len=32)
finetune_dataset = LMDataset(finetune_corpus, tokenizer, max_len=32)

# Shuffled mini-batches of 16 sequences for both training phases.
pretrain_loader = DataLoader(pretrain_dataset, batch_size=16, shuffle=True)
finetune_loader = DataLoader(finetune_dataset, batch_size=16, shuffle=True)

print(f"Pretrain samples: {len(pretrain_dataset)}")
print(f"Finetune samples: {len(finetune_dataset)}")

In [None]:
def train_epoch(model, dataloader, optimizer, device, pad_idx=0):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module mapping a batch of token ids to logits of shape
            (batch, seq, vocab).
        dataloader: Iterable of ``(inputs, targets)`` tensor pairs.
        optimizer: Optimizer holding the parameters to update. It may cover
            only a subset of ``model``'s parameters (e.g. LoRA adapters).
        device: Device the batches are moved to.
        pad_idx: Target id excluded from the loss.

    Returns:
        Mean cross-entropy per non-padding target token over the epoch, or
        ``0`` if no such token was seen.
    """
    model.train()

    # Clip only what the optimizer will actually step. Parameters that are
    # trainable but not optimized must not contribute to the clipping norm,
    # otherwise they shrink the updates of the optimized ones.
    optimized_params = [
        param
        for group in optimizer.param_groups
        for param in group['params']
    ]

    total_loss = 0
    total_tokens = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Count non-pad tokens. A batch without any would produce a NaN mean
        # loss and a NaN optimizer step, so skip it.
        num_tokens = (targets != pad_idx).sum().item()
        if num_tokens == 0:
            continue

        # optimizer.zero_grad() only clears the optimizer's own parameters;
        # also clear the rest of the model so stale gradients do not pile up.
        model.zero_grad()
        optimizer.zero_grad()
        logits = model(inputs)

        # Compute loss (ignore padding)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=pad_idx
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(optimized_params, 1.0)
        optimizer.step()

        # Weight each batch mean by its token count to get a per-token average.
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens

    return total_loss / total_tokens if total_tokens > 0 else 0


def evaluate(model, dataloader, device, pad_idx=0):
    """Compute per-token loss and perplexity of ``model`` over ``dataloader``.

    Puts the model in eval mode and runs without gradients. Targets equal to
    ``pad_idx`` are excluded.

    Returns:
        ``(avg_loss, perplexity)``. ``avg_loss`` is ``0`` when no non-padding
        target was seen; perplexity is ``inf`` unless ``avg_loss < 100``.
    """
    model.eval()
    loss_sum = 0
    token_count = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)
            flat_targets = batch_targets.view(-1)

            logits = model(batch_inputs)
            per_token = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                flat_targets,
                ignore_index=pad_idx,
                reduction='none'
            )

            keep = flat_targets != pad_idx
            loss_sum += (per_token * keep).sum().item()
            token_count += keep.sum().item()

    if token_count > 0:
        avg_loss = loss_sum / token_count
    else:
        avg_loss = 0

    # Guard math.exp against overflow on very large losses.
    if avg_loss < 100:
        perplexity = math.exp(avg_loss)
    else:
        perplexity = float('inf')
    return avg_loss, perplexity


## 6. Pretrain Base Model

First, we'll "pretrain" the base model on general text.

In [None]:
# Create base model. max_seq_len matches the dataset's max_len of 32, so the
# 31-token training inputs fit within the position embedding table.
base_model = SmallGPT(
    vocab_size=tokenizer.vocab_size,
    d_model=128,
    num_heads=4,
    num_layers=4,
    d_ff=256,
    max_seq_len=32,
    dropout=0.1
).to(device)

# Nothing is frozen yet, so both counts are expected to be equal here.
total_params, trainable_params = count_parameters(base_model)
print(f"Base model - Total params: {total_params:,}, Trainable: {trainable_params:,}")

In [None]:
# Pretrain: update every parameter of the base model on the general corpus.
print("Pretraining base model...")
optimizer = torch.optim.AdamW(base_model.parameters(), lr=1e-3, weight_decay=0.01)

pretrain_losses = []
for epoch in range(10):
    loss = train_epoch(base_model, pretrain_loader, optimizer, device)
    pretrain_losses.append(loss)
    # Every second epoch, report perplexity. It is measured on the training
    # loader itself; there is no held-out split in this notebook.
    if (epoch + 1) % 2 == 0:
        _, ppl = evaluate(base_model, pretrain_loader, device)
        print(f"Epoch {epoch+1}: Loss = {loss:.4f}, Perplexity = {ppl:.2f}")

print("Pretraining complete!")

In [None]:
# Save base model weights for comparison later. A deep copy is taken so the
# snapshot stays independent of base_model, which is modified in place when
# LoRA is applied to it below.
base_model_state = copy.deepcopy(base_model.state_dict())

## 7. Apply LoRA and Fine-tune

Now we apply LoRA to the attention projections and fine-tune only the LoRA parameters.

In [None]:
# Apply LoRA to attention projections.
# NOTE: apply_lora_to_model modifies and returns the model it is given, so
# lora_model and base_model refer to the same (now wrapped) object.
lora_model, lora_layers = apply_lora_to_model(
    base_model,
    r=4,
    alpha=8.0,  # Common practice: alpha = 2*r
    target_modules=['W_q', 'W_v'],  # Only Q and V projections (common in practice)
    dropout=0.05
)

# List each wrapped layer with the size of its frozen base and of its adapter.
print(f"Applied LoRA to {len(lora_layers)} layers:")
for name, layer in lora_layers:
    print(f"  - {name}: r={layer.r}, base_params={layer.base_param_count:,}, lora_params={layer.lora_param_count:,}")

In [None]:
# Parameter count comparison
total_params, trainable_params = count_parameters(lora_model)
lora_only_params = sum(p.numel() for p in get_lora_params(lora_model))

# total_params is measured on the wrapped model, so it already contains the
# adapter matrices. Full fine-tuning would train the original weights only,
# so compare LoRA against the total minus the adapters.
full_finetune_params = total_params - lora_only_params

print("\n" + "="*50)
print("PARAMETER COUNT COMPARISON")
print("="*50)
print(f"Total model parameters:      {total_params:>10,}")
print(f"Full fine-tuning params:     {full_finetune_params:>10,}")
print(f"LoRA trainable params:       {lora_only_params:>10,}")
print(f"Reduction ratio:             {full_finetune_params / lora_only_params:>10.1f}x")
print(f"% of original:               {100 * lora_only_params / full_finetune_params:>10.2f}%")
print("="*50)

In [None]:
# Fine-tune with LoRA (only LoRA params): the optimizer is given nothing but
# the adapter matrices, so no other weight is updated.
print("\nFine-tuning with LoRA...")
lora_params = get_lora_params(lora_model)
optimizer = torch.optim.AdamW(lora_params, lr=1e-3, weight_decay=0.01)

finetune_losses = []
start_time = time.time()

for epoch in range(15):
    loss = train_epoch(lora_model, finetune_loader, optimizer, device)
    finetune_losses.append(loss)
    # Every third epoch, report perplexity on the fine-tuning loader itself.
    if (epoch + 1) % 3 == 0:
        _, ppl = evaluate(lora_model, finetune_loader, device)
        print(f"Epoch {epoch+1}: Loss = {loss:.4f}, Perplexity = {ppl:.2f}")

# Wall-clock time; it includes the periodic evaluation passes above.
lora_finetune_time = time.time() - start_time
print(f"\nLoRA fine-tuning time: {lora_finetune_time:.2f}s")

In [None]:
# Test merge/unmerge: folding the adapters into the base weights and taking
# them out again should leave the model's outputs unchanged.
print("Testing merge/unmerge...")

# Eval mode disables dropout (including the LoRA-path dropout), which would
# otherwise make the three forward passes below differ at random.
lora_model.eval()
test_input = torch.randint(0, tokenizer.vocab_size, (2, 16)).to(device)

# Output before merge (with LoRA path)
with torch.no_grad():
    output_before_merge = lora_model(test_input).clone()

# Merge LoRA weights
merge_lora(lora_model)

# Output after merge (LoRA baked into weights)
with torch.no_grad():
    output_after_merge = lora_model(test_input).clone()

# Check difference; only floating-point rounding error is expected.
max_diff = (output_before_merge - output_after_merge).abs().max().item()
print(f"Max difference after merge: {max_diff:.2e}")
print(f"Merge preserves outputs: {max_diff < 1e-5}")

# Unmerge
unmerge_lora(lora_model)

# Output after unmerge
with torch.no_grad():
    output_after_unmerge = lora_model(test_input).clone()

max_diff_unmerge = (output_before_merge - output_after_unmerge).abs().max().item()
print(f"Max difference after unmerge: {max_diff_unmerge:.2e}")
print(f"Unmerge restores original: {max_diff_unmerge < 1e-5}")

## 9. Ablation: Rank r=4 vs r=16

Compare different LoRA ranks in terms of performance and training speed.

In [None]:
def run_lora_experiment(rank: int, alpha: float = None):
    """Fine-tune a fresh copy of the pretrained model with LoRA at ``rank``.

    Rebuilds the model, restores the saved pretrained weights, wraps the
    query and value projections, trains the adapters for 15 epochs on the
    fine-tuning loader, and evaluates on that same loader.

    Args:
        rank: LoRA rank ``r``.
        alpha: LoRA alpha; defaults to ``2 * rank`` when omitted.

    Returns:
        Dict with keys ``rank``, ``alpha``, ``lora_params``, ``final_loss``,
        ``final_ppl``, ``train_time`` (seconds) and ``losses`` (per epoch).
    """
    alpha = 2 * rank if alpha is None else alpha  # Common heuristic

    # Same architecture as the pretrained base model, so its state loads.
    model_config = dict(
        vocab_size=tokenizer.vocab_size,
        d_model=128,
        num_heads=4,
        num_layers=4,
        d_ff=256,
        max_seq_len=32,
        dropout=0.1,
    )
    model = SmallGPT(**model_config).to(device)
    model.load_state_dict(base_model_state)

    # Wrap only the query and value projections.
    model, _ = apply_lora_to_model(
        model, r=rank, alpha=alpha, target_modules=['W_q', 'W_v']
    )

    adapter_params = get_lora_params(model)
    lora_param_count = sum(param.numel() for param in adapter_params)
    optimizer = torch.optim.AdamW(adapter_params, lr=1e-3, weight_decay=0.01)

    # Time the training epochs only; the final evaluation is excluded.
    started = time.time()
    losses = [
        train_epoch(model, finetune_loader, optimizer, device)
        for _ in range(15)
    ]
    train_time = time.time() - started

    final_loss, final_ppl = evaluate(model, finetune_loader, device)

    summary = {}
    summary['rank'] = rank
    summary['alpha'] = alpha
    summary['lora_params'] = lora_param_count
    summary['final_loss'] = final_loss
    summary['final_ppl'] = final_ppl
    summary['train_time'] = train_time
    summary['losses'] = losses
    return summary

print("Running ablation study...\n")

# Run experiments: both start from the same saved pretrained weights and use
# the default alpha = 2 * rank, so only the rank differs between them.
results_r4 = run_lora_experiment(rank=4)
print(f"Rank 4: Loss={results_r4['final_loss']:.4f}, PPL={results_r4['final_ppl']:.2f}, Time={results_r4['train_time']:.2f}s")

results_r16 = run_lora_experiment(rank=16)
print(f"Rank 16: Loss={results_r16['final_loss']:.4f}, PPL={results_r16['final_ppl']:.2f}, Time={results_r16['train_time']:.2f}s")

In [None]:
# Ablation results comparison: one row per metric, one column per rank.
print("\n" + "="*70)
print("ABLATION STUDY: LoRA Rank Comparison")
print("="*70)
print(f"{'Metric':<25} {'Rank=4':<20} {'Rank=16':<20}")
print("-"*70)
print(f"{'LoRA Parameters':<25} {results_r4['lora_params']:>15,} {results_r16['lora_params']:>15,}")
print(f"{'Final Loss':<25} {results_r4['final_loss']:>15.4f} {results_r16['final_loss']:>15.4f}")
print(f"{'Final Perplexity':<25} {results_r4['final_ppl']:>15.2f} {results_r16['final_ppl']:>15.2f}")
print(f"{'Training Time (s)':<25} {results_r4['train_time']:>15.2f} {results_r16['train_time']:>15.2f}")
# Adapter parameter count divided by final perplexity.
print(f"{'Params/Perf Ratio':<25} {results_r4['lora_params']/results_r4['final_ppl']:>15.1f} {results_r16['lora_params']/results_r16['final_ppl']:>15.1f}")
print("="*70)

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss curves
axes[0].plot(results_r4['losses'], label='Rank=4', marker='o', markersize=3)
axes[0].plot(results_r16['losses'], label='Rank=16', marker='s', markersize=3)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss by LoRA Rank')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Parameter comparison bar chart.
# base_model was wrapped with LoRA in place earlier, so its raw count
# includes adapter matrices. Subtract them to get the number of weights full
# fine-tuning would train (a no-op if the model carries no adapters).
total_params = count_parameters(base_model)[0] - sum(
    p.numel() for p in get_lora_params(base_model)
)
x = ['Full\nFine-tuning', 'LoRA\nRank=4', 'LoRA\nRank=16']
y = [total_params, results_r4['lora_params'], results_r16['lora_params']]
colors = ['#ff6b6b', '#4ecdc4', '#45b7d1']

bars = axes[1].bar(x, y, color=colors)
axes[1].set_ylabel('Trainable Parameters')
axes[1].set_title('Parameter Efficiency')
# Log scale: the counts differ by orders of magnitude.
axes[1].set_yscale('log')

# Add value labels
for bar, val in zip(bars, y):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                 f'{val:,}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

## 10. Text Generation Comparison

In [None]:
def generate_text(model, tokenizer, prompt: str, max_len: int = 20, temperature: float = 0.8):
    """Continue ``prompt`` autoregressively and return the decoded text.

    Generation stops after ``max_len`` new tokens, when the sequence reaches
    ``model.max_seq_len``, or when ``[EOS]`` is produced (it is not appended).
    The model is left in eval mode.

    Args:
        model: Language model returning logits of shape (batch, seq, vocab)
            and exposing ``max_seq_len``.
        tokenizer: Object providing ``encode``, ``decode`` and ``word2idx``.
        prompt: Text to continue.
        max_len: Maximum number of tokens to generate.
        temperature: Softmax temperature for sampling. Values ``<= 0`` select
            the most likely token at every step (greedy decoding).

    Returns:
        The prompt plus generated words, as decoded by ``tokenizer``.
    """
    model.eval()

    # Run on whatever device the model lives on rather than a global setting.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    eos_id = tokenizer.word2idx['[EOS]']
    prompt_ids = tokenizer.encode(prompt)[:-1]  # Remove EOS
    tokens = torch.tensor(prompt_ids, dtype=torch.long, device=model_device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_len):
            # The position embedding table cannot index past max_seq_len.
            if tokens.size(1) >= model.max_seq_len:
                break

            next_token_logits = model(tokens)[0, -1, :]
            if temperature > 0:
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            else:
                # Dividing by a zero temperature is undefined; take the argmax.
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)

            if next_token.item() == eos_id:
                break

            tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(tokens[0].tolist())

# Rebuild an unadapted model from the saved pretrained weights to serve as
# the "before fine-tuning" reference.
base_only_model = SmallGPT(
    vocab_size=tokenizer.vocab_size,
    d_model=128,
    num_heads=4,
    num_layers=4,
    d_ff=256,
    max_seq_len=32,
    dropout=0.1
).to(device)
base_only_model.load_state_dict(base_model_state)

print("Sample generations:")
print("="*60)

# Prompts in the style of the fine-tuning corpus. Generation samples from the
# model's distribution, so the printed continuations vary between runs.
prompts = ["The experiment", "Our analysis", "The results"]
for prompt in prompts:
    print(f"\nPrompt: '{prompt}'")
    print(f"  Base model:  {generate_text(base_only_model, tokenizer, prompt)}")
    print(f"  LoRA model:  {generate_text(lora_model, tokenizer, prompt)}")