# Adaptive Attention Visualization

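This notebook visualizes how the per-head attention gates of an adaptive transformer change as its controller adapts them (strengthening or pruning individual heads). Note that the adaptation steps below are driven by randomly generated entropy metrics rather than real training, so the plots illustrate the update mechanism, not learned behavior.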

In [None]:
!pip install transformers datasets seaborn matplotlib torch

In [None]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from models.adaptive_transformer import AdaptiveTransformerModel
from controller.controller_ann import ANNController

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Load a Pretrained Adaptive Model

In [None]:
# Build the gate controller and wrap pretrained DistilGPT2 in the adaptive model.
# The controller's gate_logits are what the heatmaps below visualize.
model_name = 'distilgpt2'
num_layers = 6   # DistilGPT2 has 6 transformer layers
num_heads = 12   # ... and 12 attention heads per layer
# NOTE(review): these must match the checkpoint's config; consider reading them
# from the model config instead of hardcoding.

controller = ANNController(num_layers=num_layers, num_heads=num_heads)
model = AdaptiveTransformerModel.from_pretrained(model_name, controller).to(device)
# Switch to eval mode. The trailing ';' suppresses the (very long) module repr
# that would otherwise be displayed as this cell's output.
model.eval();

### Visualize Attention Head Gates Before Adaptation

In [None]:
# Snapshot the controller's gates before any adaptation step.
# sigmoid maps the raw logits into [0, 1], matching the colour scale below.
# NOTE(review): assumes gate_logits is (num_layers, num_heads), matching the
# axis labels - confirm in ANNController.
gate_values = torch.sigmoid(controller.gate_logits).detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(gate_values, ax=ax, cmap='viridis', annot=True, fmt='.2f', vmin=0, vmax=1)
ax.set(title='Initial Attention Head Gate Values', xlabel='Head', ylabel='Layer')
plt.show()

### Simulate Adaptation Steps and Visualize the Gates

In [None]:
# Simulate adaptive updates: feed the controller synthetic "entropy" metrics
# (uniform random noise, not real measurements) and plot the gates after each step.
# NOTE(review): update_gates presumably mutates the controller in place, so
# re-running this cell continues from the current gates, not the initial ones.
NUM_SIM_STEPS = 5
SEED = 0  # fixed so the simulated metrics, and therefore the plots, are reproducible

# A dedicated CPU generator makes the noise independent of global RNG state.
# The samples are then moved to the gate logits' device, since update_gates
# presumably combines them with gate_logits (which may be on CUDA).
sim_rng = torch.Generator().manual_seed(SEED)
gate_device = controller.gate_logits.device

for step in range(NUM_SIM_STEPS):
    fake_metrics = {
        'entropy': torch.rand(num_layers, num_heads, generator=sim_rng).to(gate_device)
    }
    controller.update_gates(fake_metrics)

    gate_values = torch.sigmoid(controller.gate_logits).detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(12, 6))
    sns.heatmap(gate_values, ax=ax, cmap='viridis', annot=True, fmt='.2f', vmin=0, vmax=1)
    ax.set(
        title=f'Attention Head Gates After Adaptation Step {step+1}',
        xlabel='Head',
        ylabel='Layer',
    )
    plt.show()
    plt.close(fig)  # release the figure so open figures don't accumulate across steps