# Controller Dynamics

Track how the controller's gate logits evolve during training, and what gate value (the sigmoid of each logit) each attention head receives over time.

In [None]:
# Install third-party dependencies. `!` is the IPython shell escape, so this runs pip in a subprocess.
!pip install transformers datasets matplotlib seaborn torch

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from models.loaders.loader import load_adaptive_model, load_baseline_model
from transformers import AutoTokenizer

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Tokenizer and both models are loaded from the same checkpoint name.
# load_adaptive_model receives the already-loaded baseline, so the baseline
# has to be created first.
model_name = 'distilgpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
baseline = load_baseline_model(model_name, device)
adaptive = load_adaptive_model(model_name, baseline, device)

# Handle to the adaptive model's controller; its `gate_logits` are tracked below.
controller = adaptive.controller

## Simulate Logit Updates (or Load from Training Logs)

In [None]:
# Simulated gate-logit history (replace with real training logs if available).
# The random walk runs on a detached CPU copy of the controller's logits, so
# the loaded model's parameters are NOT perturbed for any later cells. Noise
# comes from a private seeded generator, so reruns give the same history
# without reseeding PyTorch's global RNG.
NUM_SIM_STEPS = 20
SIM_NOISE_STD = 0.2
SIM_SEED = 0


def simulate_logit_history(initial_logits, num_steps=NUM_SIM_STEPS,
                           noise_std=SIM_NOISE_STD, seed=SIM_SEED):
    """Simulate gate-logit snapshots as a Gaussian random walk.

    Args:
        initial_logits: Starting gate logits. They are read, never modified.
        num_steps: Number of noise steps, and therefore of snapshots returned.
        noise_std: Standard deviation of the per-step Gaussian noise.
        seed: Seed for the private ``torch.Generator`` that drives the noise.

    Returns:
        A list of ``num_steps`` CPU tensors shaped like ``initial_logits``.
        Entry ``k`` is the state after ``k + 1`` noise steps.
    """
    current = initial_logits.detach().cpu()
    generator = torch.Generator().manual_seed(seed)
    history = []
    for _ in range(num_steps):
        # Sample in the default float dtype and cast, so half-precision logits
        # do not rely on CPU support for sampling directly in that dtype.
        noise = torch.randn(current.shape, generator=generator).to(current.dtype)
        # Out-of-place add: the source logits stay untouched and every
        # snapshot in `history` is a distinct tensor.
        current = current + noise * noise_std
        history.append(current)
    return history


logit_history = simulate_logit_history(controller.gate_logits)

# Convert to tensor: (steps, layers, heads)
logit_tensor = torch.stack(logit_history)

## Visualize Gate Evolution

In [None]:
def plot_gate_dynamics(logits):
    """Plot each head's gate value, sigmoid(logit), over training steps.

    Draws one subplot per layer with one line per head, then shows the figure.

    Args:
        logits: Gate logits shaped ``(steps, layers, heads)``. Anything
            ``torch.as_tensor`` accepts works, including CUDA tensors and
            tensors that still require grad.

    Raises:
        ValueError: If ``logits`` is not 3-D or has no layers.
    """
    # Matplotlib converts its inputs through NumPy, which rejects CUDA tensors
    # and tensors attached to the autograd graph. Normalize to plain CPU floats
    # first; float() also lets sigmoid accept integer or half-precision input.
    logits = torch.as_tensor(logits).detach().cpu().float()
    if logits.dim() != 3:
        raise ValueError(
            'expected logits shaped (steps, layers, heads), '
            f'got {tuple(logits.shape)}'
        )
    _, num_layers, num_heads = logits.shape
    if num_layers == 0:
        raise ValueError('logits must contain at least one layer')

    # Apply the sigmoid once for the whole history instead of once per head.
    gates = torch.sigmoid(logits).numpy()

    # squeeze=False always yields a 2-D array of Axes, so the single-layer
    # case needs no special handling.
    fig, axes = plt.subplots(
        num_layers, 1, figsize=(12, num_layers * 2), sharex=True, squeeze=False
    )
    for layer, ax in enumerate(axes[:, 0]):
        for head in range(num_heads):
            ax.plot(gates[:, layer, head], label=f'h{head}')
        ax.set_title(f'Layer {layer} Gate Dynamics')
        ax.set_ylabel('Sigmoid(gate logit)')
        # Spread many heads over several legend columns so the legend does
        # not cover the curves in these short subplots.
        ax.legend(loc='upper right', ncol=max(1, min(num_heads, 6)), fontsize='small')
    # The x-axis is shared, so only the bottom subplot needs the label.
    axes[-1, 0].set_xlabel('Training Step')
    fig.tight_layout()
    plt.show()


plot_gate_dynamics(logit_tensor)