# Adaptive Transformer: U-Net-Style Adaptivity

Demonstrates U-Net-style growth and pruning of attention heads in the adaptive transformer, simulated by periodically rescaling the attention head gates in every block.

In [ ]:
!pip install transformers datasets matplotlib seaborn torch

In [ ]:
import torch
from transformers import AutoTokenizer
from models.adaptive_transformer import AdaptiveTransformerModel
from loader import load_adaptive_model, load_baseline_model
import seaborn as sns
import matplotlib.pyplot as plt

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Model Initialization (Baseline & Adaptive)

In [ ]:
# Hugging Face checkpoint shared by both models.
model_name = "distilgpt2"
# Load the reference (non-adaptive) model first: the adaptive loader takes it as
# an argument (presumably to initialise from its weights -- confirm in loader.py).
baseline = load_baseline_model(model_name, device)
adaptive = load_adaptive_model(model_name, baseline, device)

## Simulate Adaptive U-Net Style Growth & Pruning

In [ ]:
def simulate_unet_adaptation(model, steps=50, *, period=10, step_size=0.1, threshold=0.2):
    """Drive every block's attention gates through a periodic grow/shrink schedule.

    Each cycle of ``period`` steps uses a triangular scale factor. The factor rises
    by ``step_size`` per step in the first half of the cycle (1.0, 1.1, ..., 1.4
    with the defaults). It then falls back in the second half (1.5, 1.4, ..., 1.1).
    The factor is applied to a snapshot of the gates taken before the first step,
    so the pattern repeats every cycle instead of compounding.

    Args:
        model: Object whose ``blocks`` iterate to mappings with an ``'attn'``
            entry exposing a ``gate`` tensor (presumably one gate value per
            attention head -- confirm against AdaptiveTransformerModel).
        steps: Number of simulation steps to run.
        period: Length of one grow/shrink cycle, in steps (keyword-only).
        step_size: Change in the scale factor per step (keyword-only).
        threshold: A gate value strictly greater than this counts as an active
            head (keyword-only).

    Returns:
        A list with one int per step: the number of gate entries above
        ``threshold``, summed over all blocks, after that step's rescaling.

    Raises:
        ValueError: If ``period`` is less than 1.

    Side effects:
        Overwrites each ``attn.gate`` in place; the model is left holding the
        last step's gate values. Gates are clamped to [0, 1], so a gate that
        starts at 1 stays at 1 for the whole simulation.
    """
    if period < 1:
        raise ValueError(f"period must be >= 1, got {period}")

    attns = [block['attn'] for block in model.blocks]
    # Scale from a fixed snapshot rather than compounding onto the current gates.
    # Every factor in the schedule is >= 1, so compounding would only ever grow
    # the gates and the shrink half of the cycle would never prune anything.
    base_gates = [attn.gate.detach().clone() for attn in attns]
    half = period // 2

    head_counts = []
    # These gate edits are a simulation, not training: keep them out of autograd.
    with torch.no_grad():
        for step in range(steps):
            phase = step % period
            if phase < half:
                # Growth half: 1.0, 1.0 + step_size, ...
                factor = 1 + step_size * phase
            else:
                # Shrink half: back down from the peak toward 1.0.
                factor = 1 + step_size * half - step_size * (phase - half)

            for attn, base in zip(attns, base_gates):
                attn.gate.copy_(torch.clamp(base * factor, 0, 1))

            head_counts.append(
                sum((attn.gate > threshold).sum().item() for attn in attns)
            )

    return head_counts

In [ ]:
# Number of simulation steps. The x-axis below is derived from the result, so
# changing this value keeps the plot and the data in sync.
num_steps = 50
head_counts = simulate_unet_adaptation(adaptive, steps=num_steps)

plt.figure(figsize=(12, 6))
sns.lineplot(x=range(len(head_counts)), y=head_counts, marker='o')
plt.xlabel('Simulation Steps')
plt.ylabel('Active Attention Heads')
plt.title('U-Net Style Adaptive Head Expansion and Pruning')
plt.grid(True)
plt.show()

## Demonstration

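The plot shows the periodic U-Net-style pattern in the number of active attention heads (heads whose gate value exceeds 0.2):

- Growth phase: the gates are scaled up, so more heads cross the activation threshold. In a real model this would happen when the input needs more capacity; here it follows a fixed schedule.
- Pruning phase: the scaling eases back, and heads whose gates drop below the threshold are pruned to save resources.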