# ✅ Checkpoint Resumption Test

This notebook demonstrates that training resumes correctly from a saved checkpoint and that controller gates, optimizer state, and training metrics continue progressing as expected.

In [None]:
!pip install transformers datasets

In [None]:
import torch
import os
from transformers import AutoTokenizer
from models.loaders.loader import load_baseline_model, load_adaptive_model
from datasets.dataset_loader import load_and_tokenize_dataset
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.training import compute_loss
from torch.optim import AdamW

In [None]:
# Seed PyTorch's random number generators before any model is built or trained
# so that repeated Restart & Run All executions start from the same random state.
torch.manual_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# load_adaptive_model receives the baseline model as one of its arguments.
baseline = load_baseline_model(model_name, device)
adaptive = load_adaptive_model(model_name, baseline, device)

# Final expression of the cell: call .train() on the adaptive model before the
# warm-up steps (for a torch.nn.Module this enables training mode).
adaptive.train()

In [None]:
# Only the first value returned by the loader is used; the second is discarded.
train_ids, _ = load_and_tokenize_dataset(
    model_name=model_name,
    dataset_name="tiny_shakespeare",
)

# Fixed batch reused by every training step below: at most the first four
# entries of train_ids, converted to a tensor and moved to the target device.
first_entries = train_ids[:4]
inputs = torch.tensor(first_entries).to(device)

# Optimizer over all parameters exposed by the adaptive model.
optimizer = AdamW(adaptive.parameters(), lr=1e-4)

## ⏺️ Step 1: Train and Save Checkpoint

In [None]:
# Warm-up phase: a couple of optimizer updates on the fixed batch, after which
# the model and optimizer are written to a checkpoint file.
warmup_steps = 2

for step in range(warmup_steps):
    optimizer.zero_grad()

    # Forward pass and loss in one expression; the inputs double as the targets.
    loss = compute_loss(adaptive(inputs), inputs)

    loss.backward()
    optimizer.step()
    print(f"Step {step}, Loss: {loss.item():.4f}")

# The step recorded in the checkpoint equals the number of updates performed above.
save_checkpoint(
    "resumption_test.pth",
    adaptive,
    optimizer,
    {},
    epoch=0,
    step=warmup_steps,
)

## 🔁 Step 2: Reload Model and Resume

In [None]:
# Build a brand-new model and optimizer, then restore them from the checkpoint
# file written in Step 1.
baseline = load_baseline_model(model_name, device)
resumed = load_adaptive_model(model_name, baseline, device)
resumed.train()
opt2 = AdamW(resumed.parameters(), lr=1e-4)
checkpoint_data = load_checkpoint("resumption_test.pth", resumed, opt2)

# Verify the restore instead of assuming it worked. The objects trained in
# Step 1 (`adaptive`, `optimizer`) are still in memory and have not been
# updated since the checkpoint was saved, so they serve as the reference.
# NOTE(review): assumes both models are torch.nn.Module instances whose
# state_dict() values are tensors — confirm against the loader.
assert checkpoint_data["step"] == 2, (
    f"expected the saved step counter 2, got {checkpoint_data['step']!r}"
)

saved_state = adaptive.state_dict()
restored_state = resumed.state_dict()
assert saved_state.keys() == restored_state.keys(), "state_dict keys differ after reload"
for name, saved_tensor in saved_state.items():
    assert torch.equal(saved_tensor, restored_state[name]), (
        f"parameter/buffer '{name}' was not restored from the checkpoint"
    )

# A freshly constructed optimizer has no per-parameter state, so matching entry
# counts show that the optimizer state was loaded as well.
assert len(opt2.state_dict()["state"]) == len(optimizer.state_dict()["state"]), (
    "optimizer state was not restored from the checkpoint"
)

## ✅ Step 3: Validate Training Continuation

In [None]:
# Continue training the restored model on the same fixed batch. The printed
# step index is offset by the step counter read back from the checkpoint, so
# the numbering carries on from where the warm-up phase stopped.
extra_steps = 2

for step in range(extra_steps):
    opt2.zero_grad()

    logits = resumed(inputs)
    loss = compute_loss(logits, inputs)

    loss.backward()
    opt2.step()

    print(f"[Resumed] Step {step + checkpoint_data['step']}, Loss: {loss.item():.4f}")