# 🔥 Adaptive Attention Heatmap Comparison

This notebook visualizes the attention patterns (as heatmaps) of an adaptive transformer compared to a baseline model like GPT-2 or DistilGPT2.

- Uses Hugging Face models
- Visualizes the attention matrices
- Compares the adaptive model against a non-adaptive (baseline) model side by side

In [None]:
!pip install transformers datasets matplotlib

In [None]:
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from models.adaptive_transformer import AdaptiveTransformerModel
from models.loaders.loader import load_baseline_model, load_adaptive_model

## ⚙️ Configuration

In [None]:
# Checkpoint name and the text whose attention pattern will be visualized.
model_name = "distilgpt2"
prompt = "The quick brown fox jumps over the lazy dog."

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenize the prompt once; the resulting ids are placed on `device`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = (
    tokenizer(prompt, return_tensors="pt")
    .input_ids
    .to(device)
)

## 🔍 Load Models

In [None]:
# Build the baseline first: the adaptive loader takes the baseline model as an
# argument in addition to the checkpoint name and device.
baseline = load_baseline_model(model_name, device)
adaptive = load_adaptive_model(model_name, baseline, device)

# Switch both models to evaluation mode before extracting attention.
baseline.eval()
# Trailing semicolon: if eval() returns the module (as torch.nn.Module.eval
# does), the notebook would otherwise print the entire module repr as the
# cell's output.
adaptive.eval();

## 🎯 Extract Attention (Baseline)

In [None]:
# Single forward pass with attention outputs requested; no gradients needed.
with torch.no_grad():
    output = baseline(input_ids, output_attentions=True)

# One attention tensor per layer, copied to the CPU so the maps can be handed
# to matplotlib even when the model runs on a GPU (the adaptive maps are moved
# to the CPU the same way).
# Presumably each entry is [batch, head, seq, seq], as for Hugging Face causal
# LMs -- TODO confirm for the model returned by load_baseline_model.
baseline_attn = [layer_attn.cpu() for layer_attn in output.attentions]

## 🎯 Extract Attention (Adaptive)

In [None]:
# Re-implement the attention part of the adaptive model's forward pass so the
# per-head attention weights can be captured.
def get_adaptive_attn(model, input_ids):
    """Compute per-layer attention weights of the adaptive model.

    For every block in ``model.blocks`` the query/key projections of each head
    are applied to the embedded input, and the softmax of the scaled dot
    products is recorded.

    Args:
        model: Object exposing ``eval()``, ``embed`` and ``blocks``; each block
            is indexable with ``"attn"`` and that attention object provides
            ``W_q``, ``W_k`` (iterables of per-head callables) and ``head_dim``.
        input_ids: Token ids passed to ``model.embed``.
            Assumed to be (batch, seq) -- TODO confirm.

    Returns:
        list: One CPU tensor per block, detached from autograd. With a
        (batch, seq, dim) embedding each tensor is (batch, head, seq, seq).

    Side effects:
        Puts ``model`` into evaluation mode.

    NOTE(review): ``x`` is never advanced through the block (residual / MLP /
    norms are not applied), so every layer's weights are computed from the raw
    embeddings; only the first layer can match the model's real forward pass.
    No causal mask or pre-attention normalization is applied either -- verify
    against AdaptiveTransformerModel.forward before drawing conclusions.
    """
    model.eval()
    attn_maps = []
    # Without no_grad the returned weights would require grad, and converting
    # such tensors to NumPy (as matplotlib does when plotting) raises.
    with torch.no_grad():
        x = model.embed(input_ids)
        for block in model.blocks:
            attn = block["attn"]
            # Stack the per-head projections along a new head axis (dim=1).
            q = torch.stack([head(x) for head in attn.W_q], dim=1)
            k = torch.stack([head(x) for head in attn.W_k], dim=1)
            # Scaled dot-product scores, normalized over the key axis.
            scores = torch.matmul(q, k.transpose(-2, -1)) / (attn.head_dim ** 0.5)
            weights = torch.softmax(scores, dim=-1)
            attn_maps.append(weights.cpu())
            # forward continues, not shown for brevity
    return attn_maps

# Attention weights of the adaptive model for the same prompt as the baseline.
adaptive_attn = get_adaptive_attn(model=adaptive, input_ids=input_ids)

## 📊 Plot Attention Heatmaps

In [None]:
def plot_heatmaps(attn_maps, title_prefix, layer=0, max_heads=4):
    """Show the attention matrices of one layer as a row of heatmaps.

    Args:
        attn_maps: Sequence indexed by layer; each entry is an array or tensor
            indexed as ``[batch, head, row, col]``. Only batch element 0 is
            drawn.
        title_prefix: Text placed before ``"Head <i>"`` in each subplot title.
        layer: Index into ``attn_maps`` of the layer to draw.
        max_heads: Upper bound on the number of heads drawn; fewer are drawn
            when the layer has fewer heads.

    Raises:
        ValueError: If there is no head to draw (``max_heads`` < 1 or the
            layer has no heads).
    """
    layer_attn = attn_maps[layer]
    if torch.is_tensor(layer_attn):
        # matplotlib converts via NumPy, which rejects tensors that live on a
        # GPU, require grad, or use bfloat16 -- normalize them up front.
        layer_attn = layer_attn.detach().float().cpu().numpy()

    n_heads = min(max_heads, layer_attn.shape[1])
    if n_heads < 1:
        raise ValueError("plot_heatmaps needs at least one head to draw")

    # One column per drawn head, so no empty axes are left over.
    # squeeze=False keeps `axes` 2-D even when a single head is drawn.
    fig, axes = plt.subplots(1, n_heads, figsize=(4 * n_heads, 4), squeeze=False)
    for head_idx, ax in enumerate(axes[0]):
        ax.imshow(layer_attn[0, head_idx], cmap="viridis")
        ax.set_title(f"{title_prefix} Head {head_idx}")
        ax.axis("off")
    fig.tight_layout()
    plt.show()

# Draw the first layer of each model, baseline first and adaptive second.
for label, maps in (("Baseline", baseline_attn), ("Adaptive", adaptive_attn)):
    plot_heatmaps(maps, label, layer=0)