# Adaptive Transformer Training and Evaluation Notebook
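This notebook fine-tunes an adaptive transformer derived from `distilgpt2` on the Tiny Shakespeare dataset, prints sample generations during training, evaluates loss and perplexity after each epoch, and plots the logged metrics (loss, perplexity, active attention heads, and trainable parameter count).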

In [ ]:
# Install the third-party packages used by this notebook (PyTorch is imported below but not installed here).
!pip install -q transformers datasets matplotlib seaborn

In [ ]:
import torch
from transformers import AutoTokenizer
from datasets import load_dataset
from models.adaptive_transformer import AdaptiveTransformerModel
from models.loaders.loader import load_adaptive_model, load_baseline_model
from utils.training import compute_loss, evaluate_model, count_active_heads, count_trainable_params
from utils.metrics_logger import MetricsLogger
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Setup: device, tokenizer, baseline/adaptive models, tokenized training data, optimizer, metrics logger.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_name = 'distilgpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2-family tokenizers ship without a padding token, so padding='max_length'
# below would raise. Reuse the end-of-sequence token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
baseline_model = load_baseline_model(model_name, device)
adaptive_model = load_adaptive_model(model_name, baseline_model, device)

# NOTE(review): on the Hugging Face Hub each tiny_shakespeare split appears to be a
# single long text row; if so, truncation keeps only its first 128 tokens and training
# sees one example. Consider chunking the text into 128-token blocks — verify.
dataset = load_dataset('tiny_shakespeare')['train']


def tokenize(batch):
    """Tokenize a batch of raw text into fixed-length id sequences.

    Every sequence is truncated or padded to exactly 128 tokens, so batches of
    examples stack into one rectangular tensor.
    """
    return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=128)


dataset = dataset.map(tokenize, batched=True)
# Without a torch format, the default collate_fn transposes each example's list of
# token ids into a list of per-position tensors. With it, batch['input_ids'] arrives
# as a single (batch_size, 128) tensor.
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

optimizer = torch.optim.AdamW(adaptive_model.parameters(), lr=2e-5)
metrics_logger = MetricsLogger()

In [ ]:
# Fine-tune the adaptive model, print periodic generation samples, and log per-epoch metrics.
epochs = 3
# Number of leading tokens of a training row used as the generation prompt.
sample_prompt_len = 16

# Per-epoch metrics are computed on the held-out validation split, so val_loss and
# perplexity measure generalization rather than fit to the training batches.
val_dataset = load_dataset('tiny_shakespeare')['validation'].map(tokenize, batched=True)
val_input_ids = torch.as_tensor(val_dataset[:]['input_ids']).to(device)

for epoch in range(epochs):
    # Set training mode at the top of every epoch; evaluation below may switch modes.
    adaptive_model.train()
    total_loss = 0.0
    num_batches = 0
    for i, batch in enumerate(train_loader):
        # as_tensor reuses an existing tensor instead of copy-constructing it.
        input_ids = torch.as_tensor(batch['input_ids']).to(device)
        optimizer.zero_grad()
        logits = adaptive_model(input_ids)
        # NOTE(review): no attention mask is passed, so padded positions presumably
        # contribute to the loss — confirm against utils.training.compute_loss.
        loss = compute_loss(logits, input_ids)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if i % 10 == 0:
            # Rows are padded/truncated to 128 tokens, already longer than max_length=50.
            # Under the Hugging Face convention (max_length counts the prompt), that leaves
            # no room to generate, so prompt with a short prefix of the row instead.
            adaptive_model.eval()
            with torch.no_grad():
                generated = adaptive_model.generate(input_ids[:1, :sample_prompt_len], max_length=50)
            print('Epoch', epoch, 'Batch', i, 'Generated:', tokenizer.decode(generated[0], skip_special_tokens=True))
            adaptive_model.train()

    # Guard against an empty loader instead of dividing by zero.
    avg_loss = total_loss / max(num_batches, 1)
    # NOTE(review): assumes evaluate_model switches to eval mode / disables grad
    # itself — confirm against utils.training.evaluate_model.
    val_loss, val_perplexity, _ = evaluate_model(adaptive_model, val_input_ids, val_input_ids)
    active_heads = count_active_heads(adaptive_model)
    param_count = count_trainable_params(adaptive_model)

    metrics_logger.log({
        'train_loss': avg_loss,
        'val_loss': val_loss,
        'perplexity': val_perplexity,
        'active_heads': active_heads,
        'param_count': param_count,
    })

    print(f'Epoch {epoch}: Loss {avg_loss}, Val Loss {val_loss}, Perplexity {val_perplexity}, Heads {active_heads}, Params {param_count}')

In [ ]:
# Visualization of metrics: a 2x2 grid with one panel per logged quantity,
# each series plotted against its position in the log (one entry per epoch).
metrics = metrics_logger.get_metrics()

fig, axs = plt.subplots(2, 2, figsize=(12, 10))


def _draw_panel(ax, title, series):
    """Plot each (metric key, legend label, color-or-None) series on ax, then title and legend it."""
    for key, label, color in series:
        plot_kwargs = {'label': label}
        if color is not None:
            plot_kwargs['color'] = color
        ax.plot(metrics[key], **plot_kwargs)
    ax.set_title(title)
    ax.legend()


# (row, column, panel title, series to draw on that panel)
_panel_specs = [
    (0, 0, 'Loss Over Epochs', [('train_loss', 'Train Loss', None),
                                ('val_loss', 'Validation Loss', None)]),
    (0, 1, 'Perplexity Over Epochs', [('perplexity', 'Perplexity', 'green')]),
    (1, 0, 'Active Heads', [('active_heads', 'Active Heads', 'purple')]),
    (1, 1, 'Parameter Count', [('param_count', 'Parameter Count', 'orange')]),
]
for _row, _col, _title, _series in _panel_specs:
    _draw_panel(axs[_row, _col], _title, _series)

plt.tight_layout()
plt.show()