# 📊 Adaptive Behavior Visualization

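This notebook visualizes how attention heads adapt during training: it plots each head's gate value (the sigmoid of the controller's gate logit) before, during, and after a short training run, showing which heads stay active and which are driven toward pruning.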

In [ ]:
!pip install transformers datasets matplotlib seaborn

In [ ]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from models.adaptive_transformer import AdaptiveTransformerModel
from controller.controller_ann import ANNController
from loader import load_adaptive_model, load_baseline_model
from data_modules.dataset_loader import load_and_tokenize_dataset

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## ⚙️ Configuration

In [ ]:
# Model identifier and dataset name handed to the project loaders below.
model_name = 'distilgpt2'
dataset_name = 'tiny_shakespeare'
# The loader returns a (train, validation) pair; the cells in this notebook
# only read train_data.
# NOTE(review): presumably each item is a sequence of token ids — confirm in
# data_modules.dataset_loader.
train_data, val_data = load_and_tokenize_dataset(model_name, dataset_name)

## 🧠 Load Models and Controller

In [ ]:
# The baseline model is loaded first because load_adaptive_model takes it as
# an argument when building the adaptive model.
baseline_model = load_baseline_model(model_name, device)
adaptive_model = load_adaptive_model(model_name, baseline_model, device)
# Size the controller from the adaptive model's config (n_layer x n_head).
controller = ANNController(num_layers=adaptive_model.config.n_layer, num_heads=adaptive_model.config.n_head)
# Move the controller to the target device and attach it to the model.
# NOTE(review): assumes .to(device) returns the controller itself (as
# torch.nn.Module.to does), so `controller` and `adaptive_model.controller`
# refer to the same object — confirm ANNController is an nn.Module.
adaptive_model.controller = controller.to(device)

## 🔄 Visualize Adaptive Gates (Before Training)

In [ ]:
def plot_gates(controller, title='Head Activation Gates'):
    """Draw an annotated heatmap of the controller's gate values.

    The plotted values are the element-wise sigmoid of
    ``controller.gate_logits``, detached from autograd and copied to the CPU
    as a NumPy array. The colour scale is fixed to [0, 1].

    Args:
        controller: Object exposing a ``gate_logits`` tensor.
            NOTE(review): presumably shaped (layers, heads), matching the
            axis labels — confirm against ANNController.
        title: Title shown above the heatmap.
    """
    gate_probs = torch.sigmoid(controller.gate_logits)
    gate_array = gate_probs.detach().cpu().numpy()

    plt.figure(figsize=(12, 6))
    sns.heatmap(gate_array, annot=True, cmap='viridis', vmin=0, vmax=1)

    # Labelling calls are independent of one another.
    plt.title(title)
    plt.ylabel('Layers')
    plt.xlabel('Heads')
    plt.show()

# Snapshot of the gates before any optimizer step has been taken.
plot_gates(controller, 'Initial Head Activation Gates')

## 🚀 Short Training Loop (Demo Adaptivity)

In [ ]:
# Hyperparameters for the short demonstration run.
learning_rate = 1e-5
num_steps = 100
plot_every = 25
reg_weight = 1e-4

optimizer = torch.optim.AdamW(adaptive_model.parameters(), lr=learning_rate)
adaptive_model.train()

for step in range(num_steps):
    # Cycle through the training examples one at a time; unsqueeze(0) adds a
    # leading batch dimension of size 1.
    example = train_data[step % len(train_data)]
    inputs = torch.tensor(example).unsqueeze(0).to(device)
    outputs = adaptive_model(inputs)

    # Shifted objective: the output at position t is scored against the input
    # token at position t + 1, so the last output and first token are dropped.
    # NOTE(review): assumes the model returns a raw 3-D logits tensor — confirm.
    vocab_size = outputs.size(-1)
    shifted_logits = outputs[:, :-1, :].contiguous().view(-1, vocab_size)
    shifted_targets = inputs[:, 1:].contiguous().view(-1)
    loss = torch.nn.functional.cross_entropy(shifted_logits, shifted_targets)

    # Add the controller's regularization term, scaled by reg_weight.
    loss += controller.regularization_loss() * reg_weight

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Report and plot only on every plot_every-th step (including step 0).
    if step % plot_every != 0:
        continue
    print(f"Step {step}: Loss = {loss.item():.4f}")
    plot_gates(controller, f'Adaptive Gates at Step {step}')

## 📉 Final Adaptivity

In [ ]:
# Gate values as they stand after the training loop above has finished.
plot_gates(controller, 'Final Head Activation Gates')