# 💡 Low-Resource Adaptivity

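This notebook demonstrates that the Adaptive Transformer model self-prunes its attention heads and continues to function in constrained compute environments (e.g., a Colab T4 GPU or a CPU). During training, heads whose gate values fall below a threshold are periodically zeroed to simulate pruning, and the number of active heads is tracked alongside the training loss.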

In [ ]:
!pip install transformers datasets matplotlib torch

In [ ]:
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from models.loaders.loader import load_adaptive_model, load_baseline_model
from data_modules.dataset_loader import load_and_tokenize_dataset
from utils.training import count_active_heads, compute_loss
from torch.optim import AdamW
import gc

In [ ]:
# Pick the compute backend: a CUDA GPU when PyTorch can see one, otherwise the CPU.
_has_gpu = torch.cuda.is_available()
device = torch.device("cuda") if _has_gpu else torch.device("cpu")
print(f"🧠 Running on: {device}")

In [ ]:
# Hugging Face checkpoint shared by the tokenizer and both models below.
model_name = "distilgpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# The adaptive model is built from the already-loaded baseline
# (load_adaptive_model receives the baseline instance as an argument).
baseline = load_baseline_model(model_name, device)
adaptive = load_adaptive_model(model_name, baseline, device)

# Switch the adaptive model to training mode for the fine-tuning loop below.
# Left as the cell's last expression, exactly as before.
adaptive.train()

In [ ]:
# Tokenize Tiny Shakespeare with the same checkpoint's tokenizer; only the first
# return value is used here (the second is presumably a held-out split — confirm).
train_ids, _ = load_and_tokenize_dataset(model_name=model_name, dataset_name="tiny_shakespeare")

# One fixed batch built from the first 8 entries of train_ids, placed on `device`.
# NOTE(review): torch.tensor requires these entries to have equal length.
inputs = torch.tensor(train_ids[:8]).to(device)

# Optimizer over every parameter of the adaptive model.
optimizer = AdamW(adaptive.parameters(), lr=5e-5)

## 🚦 Adaptive Behavior Under Limited Resources

In [ ]:
# Knobs for the constrained-compute simulation (previously inline magic numbers).
num_steps = 20          # optimizer steps to run on the fixed batch
prune_every = 5         # hard-prune on steps 0, prune_every, 2*prune_every, ...
prune_threshold = 0.1   # gate values strictly below this are set to 0.0

active_heads_history = []  # active-head count after each optimizer step
losses = []                # scalar training loss for each step

for step in range(num_steps):
    optimizer.zero_grad()
    logits = adaptive(inputs)
    loss = compute_loss(logits, inputs)
    loss.backward()
    optimizer.step()

    # .item() copies the scalar to the host (a device sync on GPU); do it once per step.
    loss_value = loss.item()
    # Measured after optimizer.step() but before this step's pruning below.
    active_heads = count_active_heads(adaptive)
    active_heads_history.append(active_heads)
    losses.append(loss_value)
    print(f"Step {step}: Loss = {loss_value:.4f}, Active Heads = {active_heads}")

    # Simulate pruning under a constrained setting: zero every attention gate
    # below the threshold. Done without autograd tracking since it is an
    # in-place edit outside the loss computation.
    # NOTE(review): if `gate` is a parameter seen by the optimizer, later
    # AdamW steps may move zeroed gates away from 0 — confirm whether pruning
    # is meant to be permanent.
    if step % prune_every == 0:
        with torch.no_grad():
            for block in adaptive.blocks:
                gate = block['attn'].gate
                gate[gate < prune_threshold] = 0.0

# Release memory held by Python garbage and, on GPU, PyTorch's cached allocator blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## 📈 Results

In [ ]:
# Two side-by-side panels, both indexed by training step (list position):
# the training loss on the left and the active-head count on the right.
fig, (ax_loss, ax_heads) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(losses, label="Loss")
ax_loss.set_title("Loss Over Steps")
ax_loss.set_xlabel("Step")
ax_loss.set_ylabel("Loss")
ax_loss.grid(True)  # explicit on; bare grid() toggles the current state
ax_loss.legend()    # label= is only rendered once a legend is drawn

ax_heads.plot(active_heads_history, label="Active Heads", color="orange")
ax_heads.set_title("Active Heads vs. Training Step")
ax_heads.set_xlabel("Step")
ax_heads.set_ylabel("Active Heads")
ax_heads.grid(True)
ax_heads.legend()

fig.tight_layout()
plt.show()