# Training Loop Testing Notebook
## Reasoning Distillation Project

This notebook tests:
1. Distillation loss computation
2. Trainer initialization
3. Training loop (small scale)
4. Evaluation pipeline
5. Checkpointing and resuming
6. Training history visualization

In [None]:
# Setup: silence warnings, make the project's `src` package importable, import
# everything the notebook uses, and pick the compute device for later cells.
import sys
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Add project root to path. Assumes the notebook runs from a direct
# subdirectory of the project root (e.g. notebooks/). Guarded so re-running
# the cell does not push duplicate entries onto sys.path.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pprint import pprint
import time

from src.data.data_loader import TeacherDataLoader
from src.data.preprocessor import ReasoningPreprocessor, PreprocessConfig
from src.data.dataset import ESNLIDataset, create_dataloaders

from src.models.student import StudentModel, StudentConfig, create_student_model
from src.models.teacher import DatasetTeacher

from src.training.distillation import (
    DistillationConfig,
    SequenceLevelDistillation,
    DistillationLoss,
    compare_distillation_strategies
)

from src.training.trainer import (
    Trainer,
    TrainingConfig,
    create_trainer
)

# Compute device shared by every later cell (student config, batch transfer,
# CUDA profiling). Later cells compare it against the string "cuda", so it is
# kept as a plain str rather than a torch.device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Compare Distillation Strategies

In [None]:
# Display strategy comparison.
# compare_distillation_strategies comes from src.training.distillation;
# presumably it prints a comparison of the available distillation strategies
# (not visible here — verify in that module). It is kept as the cell's final
# bare expression so Jupyter displays any returned value as the cell output.
compare_distillation_strategies()

## 2. Test Distillation Loss

In [None]:
# Initialize distillation loss: a DistillationLoss configured for pure
# cross-entropy (distill_weight=0.0) with label smoothing, used by the
# dummy-data test in the next cell.
print("=" * 70)
print("TESTING DISTILLATION LOSS")
print("=" * 70)

distill_config = DistillationConfig(
    ce_weight=1.0,
    distill_weight=0.0,  # No explicit distillation for sequence-level
    temperature=1.0,
    label_smoothing=0.1
)

loss_fn = DistillationLoss(distill_config)

print("\n✓ DistillationLoss initialized!")

In [None]:
# Test loss computation with dummy data: random logits and random target ids
# (with the last 5 positions marked as -100 padding) are pushed through
# `loss_fn`, every loss component is printed, and a backward pass is run.
print("\n" + "=" * 70)
print("COMPUTING LOSS ON DUMMY DATA")
print("=" * 70)

# Create dummy logits and labels
batch_size = 4
seq_len = 32
vocab_size = 32128  # FLAN-T5 vocab size

# requires_grad=True makes the logits a leaf the loss can differentiate
# against; without it, backward() below raises because nothing in the graph
# requires gradients.
dummy_logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
dummy_labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Add some -100 (padding) labels
dummy_labels[:, -5:] = -100

print(f"\nLogits shape: {dummy_logits.shape}")
print(f"Labels shape: {dummy_labels.shape}")
print(f"Padding tokens: {(dummy_labels == -100).sum().item()}")

# Compute loss
losses = loss_fn(dummy_logits, dummy_labels)

print("\nLoss components:")
for key, value in losses.items():
    print(f"  {key}: {value.item():.4f}")

# Test backward pass and confirm the gradient actually reached the logits
losses['total_loss'].backward()
assert dummy_logits.grad is not None, "total_loss did not propagate gradients to the logits"
print("\n✓ Backward pass successful!")

## 3. Prepare Small Dataset for Testing

In [None]:
# Load small dataset: fetch e-SNLI through TeacherDataLoader and keep only a
# small prefix of the train and validation splits so the test run stays fast.
print("=" * 70)
print("LOADING TEST DATA")
print("=" * 70)

loader = TeacherDataLoader()
esnli_data = loader.load_esnli()

# Use very small subsets for fast testing. Sizes are clamped to the split
# length so select() never asks for indices past the end of a smaller split.
n_train = min(100, len(esnli_data['train']))
n_val = min(30, len(esnli_data['validation']))
train_subset = esnli_data['train'].select(range(n_train))  # up to 100 samples
val_subset = esnli_data['validation'].select(range(n_val))  # up to 30 samples

print(f"\n✓ Train samples: {len(train_subset)}")
print(f"✓ Val samples: {len(val_subset)}")

In [None]:
# Create datasets and dataloaders: tokenize the subsets with the FLAN-T5-small
# preprocessor and wrap them in batch-size-8 loaders padded with the
# tokenizer's pad id.
preprocess_config = PreprocessConfig(
    model_name="google/flan-t5-small",  # Use small for faster testing
    max_source_length=128,
    max_target_length=64
)

preprocessor = ReasoningPreprocessor(preprocess_config)

train_dataset = ESNLIDataset(train_subset, preprocessor, use_cache=True)
val_dataset = ESNLIDataset(val_subset, preprocessor, use_cache=True)

# Create dataloaders with small batch size; num_workers=0 keeps loading in the
# notebook process.
train_loader, val_loader = create_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=8,
    num_workers=0,
    pad_token_id=preprocessor.tokenizer.pad_token_id
)

print(f"\n✓ Train batches: {len(train_loader)}")
print(f"✓ Val batches: {len(val_loader)}")

## 4. Initialize Student Model

In [None]:
# Create small student model for fast testing: FLAN-T5-small with the same
# source/target length limits as the preprocessor, placed on `device`.
print("=" * 70)
print("INITIALIZING STUDENT MODEL")
print("=" * 70)

student_config = StudentConfig(
    model_name="google/flan-t5-small",
    max_source_length=128,
    max_target_length=64,
    device=device
)

student = StudentModel(student_config)

print(f"\n✓ Model loaded: {student.count_parameters():,} parameters")
print(f"✓ Memory: {student.get_memory_footprint()['total_mb']:.2f} MB")

## 5. Initialize Distillation Strategy

In [None]:
# Create sequence-level distillation strategy: the dataset's explanations act
# as the teacher, so only the (label-smoothed) cross-entropy term is weighted.
print("=" * 70)
print("INITIALIZING DISTILLATION STRATEGY")
print("=" * 70)

distill_config = DistillationConfig(
    ce_weight=1.0,
    distill_weight=0.0,  # Using dataset as teacher
    label_smoothing=0.1,
    distillation_type="sequence_level"
)

distillation_strategy = SequenceLevelDistillation(distill_config)

print("\n✓ SequenceLevelDistillation initialized!")

## 6. Test Single Training Step

In [None]:
# Test single training step manually: one forward pass through the student,
# the sequence-level distillation loss, one backward pass and a gradient-norm
# check. No optimizer step is taken in this cell.
print("=" * 70)
print("TESTING SINGLE TRAINING STEP")
print("=" * 70)

# Get a batch and move every tensor in it to the compute device
batch = next(iter(train_loader))
batch = {k: v.to(device) for k, v in batch.items()}

print(f"\nBatch shapes:")
for key, value in batch.items():
    print(f"  {key}: {value.shape}")

# Drop gradients left by a previous run of this cell; otherwise backward()
# accumulates into them and the reported gradient norm grows on each re-run.
student.model.zero_grad(set_to_none=True)

# Forward pass
student.model.train()
outputs = student(
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    labels=batch['labels']
)

# Compute distillation loss from the student's logits
losses = distillation_strategy.compute_loss(
    outputs['logits'],
    batch['labels']
)

print(f"\nLoss values:")
for key, value in losses.items():
    print(f"  {key}: {value.item():.4f}")

# Backward pass
loss = losses['total_loss']
loss.backward()

print("\n✓ Forward and backward pass successful!")

# clip_grad_norm_ returns the total gradient norm (measured before clipping)
# and clips the gradients in place to max_norm=1.0.
grad_norm = torch.nn.utils.clip_grad_norm_(student.model.parameters(), 1.0)
print(f"✓ Gradient norm: {grad_norm:.4f}")

## 7. Initialize Trainer

In [None]:
# Create training configuration and a Trainer for a short 2-epoch smoke run
# with frequent evaluation/logging, writing checkpoints to
# ../experiments/test_run.
print("=" * 70)
print("INITIALIZING TRAINER")
print("=" * 70)

training_config = TrainingConfig(
    num_epochs=2,  # Just 2 epochs for testing
    learning_rate=5e-5,
    warmup_steps=10,
    eval_steps=10,  # Evaluate frequently
    save_steps=20,
    logging_steps=5,
    output_dir="../experiments/test_run",
    eval_strategy="steps",
    save_strategy="steps",
    save_total_limit=2,
    early_stopping_patience=5,
    gradient_accumulation_steps=1,
    max_grad_norm=1.0,
    fp16=False  # Disable for testing
)

# Re-initialize model so training starts from fresh pretrained weights rather
# than the model touched by the single-step test (which holds gradients).
student = StudentModel(student_config)

# Create trainer
trainer = Trainer(
    model=student,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    distillation_strategy=distillation_strategy,
    config=training_config
)

print("\n✓ Trainer initialized!")
# Batches per epoch times epochs; with gradient_accumulation_steps=1 this is
# also the number of optimizer steps.
print(f"✓ Total training steps: {len(train_loader) * training_config.num_epochs}")

## 8. Run Training Loop

In [None]:
# Train the model with the configured Trainer and report the wall-clock time;
# `history` holds whatever trainer.train() returns (used by later cells).
print("\n" + "=" * 70)
print("STARTING TRAINING")
print("=" * 70)

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time (time.time can jump if the system clock is adjusted).
start_time = time.perf_counter()

history = trainer.train()

training_time = time.perf_counter() - start_time

print(f"\n✓ Training completed in {training_time:.2f}s")
print(f"✓ Average time per epoch: {training_time / training_config.num_epochs:.2f}s")

## 9. Analyze Training History

In [None]:
# Summarize the metrics returned by trainer.train(): one line per training
# epoch, then at most the first five evaluation results with a count of the
# remainder.
banner = "=" * 70
print(banner)
print("TRAINING HISTORY")
print(banner)

train_history, eval_history = history['train_history'], history['eval_history']

print(f"\nTrain history ({len(train_history)} epochs):")
for epoch_idx, epoch_metrics in enumerate(train_history, start=1):
    print(f"  Epoch {epoch_idx}: loss={epoch_metrics['loss']:.4f}")

max_shown = 5
print(f"\nEval history ({len(eval_history)} evaluations):")
for eval_idx, eval_metrics in enumerate(eval_history[:max_shown], start=1):
    print(f"  Eval {eval_idx}: eval_loss={eval_metrics['eval_loss']:.4f}")
hidden = len(eval_history) - max_shown
if hidden > 0:
    print(f"  ... and {hidden} more")

In [None]:
# Plot per-epoch training loss (left panel) and per-evaluation loss (right
# panel, only when evaluations were recorded), then report the relative
# change between the first and last epoch's training loss.
def _decorate_loss_axis(ax, xlabel, title):
    """Apply the shared labels, legend and grid to one loss subplot."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 2, figsize=(14, 5))
train_ax, eval_ax = axes[0], axes[1]

# Left panel: training loss per epoch (epochs numbered from 1)
epochs = range(1, len(train_history) + 1)
train_losses = [epoch_metrics['loss'] for epoch_metrics in train_history]
train_ax.plot(epochs, train_losses, marker='o', linewidth=2, markersize=8,
              color='#e74c3c', label='Train Loss')
_decorate_loss_axis(train_ax, 'Epoch', 'Training Loss per Epoch')

# Right panel: evaluation loss, numbered by evaluation index
if eval_history:
    eval_steps = range(1, len(eval_history) + 1)
    eval_losses = [eval_metrics['eval_loss'] for eval_metrics in eval_history]
    eval_ax.plot(eval_steps, eval_losses, marker='s', linewidth=2, markersize=8,
                 color='#3498db', label='Eval Loss')
    _decorate_loss_axis(eval_ax, 'Evaluation Step', 'Evaluation Loss')

plt.tight_layout()
plt.show()

# A relative change needs at least two epochs to compare
if len(train_losses) >= 2:
    initial_loss, final_loss = train_losses[0], train_losses[-1]
    reduction = (initial_loss - final_loss) / initial_loss * 100
    print(f"\nLoss reduction: {reduction:.2f}%")
    print(f"Initial loss: {initial_loss:.4f}")
    print(f"Final loss: {final_loss:.4f}")

## 10. Test Model Generation After Training

In [None]:
# Test generation on validation samples: beam-search decode a few validation
# inputs with the trained student and print each next to its reference target.
print("=" * 70)
print("TESTING GENERATION AFTER TRAINING")
print("=" * 70)

student.model.eval()

# Get a batch from validation set
val_batch = next(iter(val_loader))
val_batch = {k: v.to(device) for k, v in val_batch.items()}

# Show up to 3 samples; a small validation set can yield a batch with fewer
# rows, and iterating a fixed range(3) would then raise IndexError.
num_show = min(3, val_batch['input_ids'].size(0))

# Generate predictions
with torch.no_grad():
    generated_ids = student.generate(
        input_ids=val_batch['input_ids'][:num_show],
        attention_mask=val_batch['attention_mask'][:num_show],
        max_length=64,
        num_beams=4
    )

# Decode
predictions = student.decode_batch(generated_ids)
inputs = student.decode_batch(val_batch['input_ids'][:num_show])

# -100 marks padded label positions; swap them for the pad token id so the
# reference targets can be decoded.
labels = val_batch['labels'][:num_show].clone()
labels[labels == -100] = student.tokenizer.pad_token_id
ground_truths = student.decode_batch(labels)

# Display
for i in range(num_show):
    print(f"\n{'='*70}")
    print(f"SAMPLE {i+1}")
    print(f"{'='*70}")
    print(f"\nInput:\n{inputs[i]}")
    print(f"\nGround Truth:\n{ground_truths[i]}")
    print(f"\nPrediction:\n{predictions[i]}")

## 11. Test Checkpoint Saving and Loading

In [None]:
# Check saved checkpoints: list everything in the training output directory,
# showing the total on-disk size of each subdirectory.
print("=" * 70)
print("CHECKING SAVED CHECKPOINTS")
print("=" * 70)

output_dir = Path(training_config.output_dir)

if output_dir.exists():
    checkpoints = list(output_dir.iterdir())
    print(f"\nFound {len(checkpoints)} items in output directory:")
    for checkpoint in sorted(checkpoints):
        if checkpoint.is_dir():
            # Sum the sizes of all regular files anywhere below the directory
            size = sum(f.stat().st_size for f in checkpoint.rglob('*') if f.is_file())
            print(f"  📁 {checkpoint.name} ({size / 1e6:.2f} MB)")
        else:
            print(f"  📄 {checkpoint.name}")
else:
    print("\nOutput directory not found!")

In [None]:
# Test loading best model: if the trainer wrote a best_model directory, reload
# it with the original student config and check that it can generate for one
# validation input.
print("\n" + "=" * 70)
print("TESTING MODEL LOADING")
print("=" * 70)

best_model_path = output_dir / "best_model"

if best_model_path.exists():
    print(f"\nLoading best model from {best_model_path}...")

    # Load model
    loaded_student = StudentModel.load_model(
        str(best_model_path),
        config=student_config
    )

    print("✓ Model loaded successfully!")
    print(f"✓ Parameters: {loaded_student.count_parameters():,}")

    # Test generation with loaded model on the first validation sample
    loaded_student.model.eval()
    with torch.no_grad():
        test_gen = loaded_student.generate(
            input_ids=val_batch['input_ids'][:1],
            attention_mask=val_batch['attention_mask'][:1]
        )

    test_pred = loaded_student.decode_batch(test_gen)[0]
    print(f"\nTest generation: {test_pred[:100]}...")
    print("\n✓ Loaded model can generate!")
else:
    print("\nBest model not found (training might not have completed enough steps)")

## 12. Test Resume Training

In [None]:
# Test resuming from checkpoint: rebuild a model and Trainer from the newest
# numbered "checkpoint-<step>" directory and restore the trainer's state,
# without actually continuing training.
print("=" * 70)
print("TESTING RESUME TRAINING")
print("=" * 70)


def _checkpoint_step(path):
    """Return the step encoded in a "checkpoint-<step>" name, or None if it has none."""
    prefix, sep, step = path.name.partition("-")
    if prefix != "checkpoint" or not sep or not step.isdigit():
        return None
    return int(step)


# Find a checkpoint to resume from. A missing output directory yields no
# checkpoints instead of raising, and non-numeric names such as
# "checkpoint-best" are skipped instead of crashing the int() sort key.
checkpoints = []
if output_dir.exists():
    checkpoints = sorted(
        (d for d in output_dir.iterdir()
         if d.is_dir() and _checkpoint_step(d) is not None),
        key=_checkpoint_step,
    )

if checkpoints:
    resume_checkpoint = checkpoints[-1]
    print(f"\nResuming from: {resume_checkpoint.name}")

    # Create new trainer config; 3 epochs total, i.e. one more than the
    # original 2-epoch run.
    resume_config = TrainingConfig(
        num_epochs=3,
        learning_rate=5e-5,
        warmup_steps=10,
        eval_steps=10,
        save_steps=20,
        logging_steps=5,
        output_dir="../experiments/test_run_resumed",
        eval_strategy="steps"
    )

    # Load model weights from the checkpoint
    resumed_student = StudentModel.load_model(
        str(resume_checkpoint),
        config=student_config
    )

    resumed_trainer = Trainer(
        model=resumed_student,
        train_dataloader=train_loader,
        eval_dataloader=val_loader,
        distillation_strategy=distillation_strategy,
        config=resume_config
    )

    # Load training state (step, epoch, best metric, ...)
    resumed_trainer.load_checkpoint(str(resume_checkpoint))

    # best_metric may not be a number (e.g. if it is None before any
    # evaluation) — fall back to str() instead of failing the format spec.
    try:
        best_metric_str = f"{resumed_trainer.best_metric:.4f}"
    except (TypeError, ValueError):
        best_metric_str = str(resumed_trainer.best_metric)

    print(f"\n✓ Checkpoint loaded!")
    print(f"  Resumed from step: {resumed_trainer.global_step}")
    print(f"  Resumed from epoch: {resumed_trainer.epoch}")
    print(f"  Best metric: {best_metric_str}")

    print("\n💡 Training state successfully restored!")
    print("   You can call resumed_trainer.train() to continue training.")
else:
    print("\nNo checkpoints found to resume from")

## 13. Analyze Optimizer State

In [None]:
# Inspect the trainer's optimizer: its class, each parameter group's learning
# rate / weight decay / parameter count and, when a scheduler is attached, how
# far the live learning rate has moved from the configured one.
banner = "=" * 70
print(banner)
print("OPTIMIZER STATE ANALYSIS")
print(banner)

opt = trainer.optimizer
optimizer_state = opt.state_dict()
param_groups = optimizer_state['param_groups']

print(f"\nOptimizer: {type(opt).__name__}")
print(f"\nParameter groups: {len(param_groups)}")

for group_idx, group in enumerate(param_groups):
    print(f"\nGroup {group_idx}:")
    print(f"  Learning rate: {group['lr']:.2e}")
    print(f"  Weight decay: {group['weight_decay']}")
    print(f"  Parameters: {len(group['params'])}")

# Compare the live LR (read from the optimizer, not the snapshot above)
# against the value the run was configured with.
if trainer.scheduler:
    current_lr = opt.param_groups[0]['lr']
    base_lr = training_config.learning_rate
    print(f"\nCurrent learning rate: {current_lr:.2e}")
    print(f"Initial learning rate: {base_lr:.2e}")
    print(f"LR change: {(current_lr / base_lr - 1) * 100:+.1f}%")

## 14. Compare Before/After Training Performance

In [None]:
# Compare loss before and after training: first vs. last recorded evaluation
# loss, printed and drawn as a two-bar chart annotated with the change.
print("=" * 70)
print("BEFORE/AFTER COMPARISON")
print("=" * 70)

if eval_history:
    initial_eval_loss = eval_history[0]['eval_loss']
    final_eval_loss = eval_history[-1]['eval_loss']
    improvement = initial_eval_loss - final_eval_loss
    # Guard the percentage against a zero initial loss.
    improvement_pct = improvement / initial_eval_loss * 100 if initial_eval_loss else 0.0

    print(f"\nEvaluation Loss:")
    print(f"  Initial: {initial_eval_loss:.4f}")
    print(f"  Final: {final_eval_loss:.4f}")
    print(f"  Improvement: {improvement:.4f} ({improvement_pct:.2f}%)")

    # Visualize improvement
    fig, ax = plt.subplots(figsize=(10, 6))

    # Dedicated names so the loss dict (`losses`), decoded-label tensor
    # (`labels`) and loss tensor (`loss`) from earlier cells are not clobbered.
    bar_losses = [initial_eval_loss, final_eval_loss]
    bar_labels = ['Before Training', 'After Training']
    bar_colors = ['#e74c3c', '#2ecc71']

    bars = ax.bar(bar_labels, bar_losses, color=bar_colors, alpha=0.7, edgecolor='black', width=0.5)
    ax.set_ylabel('Evaluation Loss')
    ax.set_title('Model Performance: Before vs After Training')

    # Add value labels on top of each bar
    for bar, bar_loss in zip(bars, bar_losses):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{bar_loss:.4f}',
                ha='center', va='bottom', fontsize=12, fontweight='bold')

    # Annotate the change; on short runs the loss can go up, so the arrow and
    # wording follow the actual sign instead of always claiming an improvement.
    if improvement >= 0:
        arrow, change_word = '↓', 'improvement'
    else:
        arrow, change_word = '↑', 'regression'
    ax.annotate(
        f'{arrow} {abs(improvement):.4f}\n({abs(improvement_pct):.1f}% {change_word})',
        xy=(0.5, max(bar_losses) * 0.5),
        fontsize=14,
        ha='center',
        bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.3)
    )

    plt.tight_layout()
    plt.show()
else:
    print("\nNo evaluation history available for comparison")

## 15. Test Different Training Configurations

In [None]:
# Compare different label smoothing values: compute the sequence-level loss on
# one training batch for several smoothing settings and plot loss vs.
# smoothing.
print("=" * 70)
print("TESTING LABEL SMOOTHING IMPACT")
print("=" * 70)

smoothing_values = [0.0, 0.1, 0.2]
smoothing_results = []

# Get a test batch
test_batch = next(iter(train_loader))
test_batch = {k: v.to(device) for k, v in test_batch.items()}

# In eval mode with no_grad the forward pass does not depend on the smoothing
# value, so run it once and reuse the logits for every configuration.
student.model.eval()
with torch.no_grad():
    outputs = student(
        input_ids=test_batch['input_ids'],
        attention_mask=test_batch['attention_mask'],
        labels=test_batch['labels']
    )

    for smoothing in smoothing_values:
        # Match the strategy under test (distill_weight=0.0, sequence-level)
        # so label_smoothing is the only setting that varies.
        test_config = DistillationConfig(
            ce_weight=1.0,
            distill_weight=0.0,
            label_smoothing=smoothing,
            distillation_type="sequence_level"
        )
        test_strategy = SequenceLevelDistillation(test_config)

        # Compute loss
        losses = test_strategy.compute_loss(outputs['logits'], test_batch['labels'])

        smoothing_results.append({
            'smoothing': smoothing,
            'loss': losses['total_loss'].item()
        })

# Display results
print("\nLabel Smoothing Impact on Loss:")
for result in smoothing_results:
    print(f"  Smoothing={result['smoothing']:.1f}: loss={result['loss']:.4f}")

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))
smoothing_vals = [r['smoothing'] for r in smoothing_results]
loss_vals = [r['loss'] for r in smoothing_results]

ax.plot(smoothing_vals, loss_vals, marker='o', linewidth=2, markersize=10, color='#9b59b6')
ax.set_xlabel('Label Smoothing')
ax.set_ylabel('Loss')
ax.set_title('Impact of Label Smoothing on Loss')
ax.grid(True, alpha=0.3)

for x, y in zip(smoothing_vals, loss_vals):
    ax.text(x, y + 0.01, f'{y:.4f}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

## 16. Memory and Speed Profiling

In [None]:
# Profile training step memory and speed: time 5 forward/backward/optimizer
# steps (the first is treated as warm-up) and, on CUDA, report peak memory.
# NOTE: this runs real optimizer steps through `trainer.optimizer`, so it
# further updates the weights of the trained `student` model.
print("=" * 70)
print("PROFILING TRAINING STEP")
print("=" * 70)

student.model.train()
# Clear gradients left by earlier cells so the first profiled step does not
# apply stale gradients.
trainer.optimizer.zero_grad()

if device == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    initial_memory = torch.cuda.memory_allocated() / 1e6

# Time training steps. One iterator is reused (re-created only if exhausted)
# instead of building a fresh DataLoader iterator every step.
num_profile_steps = 5
times = []
data_iter = iter(train_loader)
for _ in range(num_profile_steps):
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(train_loader)
        batch = next(data_iter)
    batch = {k: v.to(device) for k, v in batch.items()}

    # CUDA work is asynchronous: wait for the host->device copy to finish so it
    # is not attributed to the timed step.
    if device == "cuda":
        torch.cuda.synchronize()

    start = time.perf_counter()

    # Forward
    outputs = student(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        labels=batch['labels']
    )
    losses = distillation_strategy.compute_loss(outputs['logits'], batch['labels'])
    loss = losses['total_loss']

    # Backward
    loss.backward()

    # Optimizer step
    trainer.optimizer.step()
    trainer.optimizer.zero_grad()

    # Wait for queued kernels before stopping the clock
    if device == "cuda":
        torch.cuda.synchronize()

    times.append(time.perf_counter() - start)

avg_time = np.mean(times[1:])  # Skip first (warmup)
std_time = np.std(times[1:])

print(f"\nTraining Step Performance:")
print(f"  Average time: {avg_time:.3f}s ± {std_time:.3f}s")
print(f"  Throughput: {len(batch['input_ids']) / avg_time:.2f} samples/sec")

if device == "cuda":
    peak_memory = torch.cuda.max_memory_allocated() / 1e6
    print(f"\nMemory Usage:")
    print(f"  Initial: {initial_memory:.2f} MB")
    print(f"  Peak: {peak_memory:.2f} MB")
    print(f"  Overhead: {peak_memory - initial_memory:.2f} MB")

## 17. Summary and Recommendations

In [None]:
# Print a recap of the tests run in this notebook together with the training
# and profiling numbers collected by earlier cells.
print("\n" + "=" * 70)
print("TRAINING LOOP TESTING SUMMARY")
print("=" * 70)

print("\n✅ COMPONENTS TESTED:")
print("  ✓ Distillation loss computation")
print("  ✓ Single training step (forward + backward)")
print("  ✓ Trainer initialization")
print(f"  ✓ Full training loop ({training_config.num_epochs} epochs)")
print("  ✓ Evaluation pipeline")
print("  ✓ Checkpointing (save/load)")
print("  ✓ Resume training")
print("  ✓ Optimizer state management")
print("  ✓ Learning rate scheduling")

print("\n📊 TRAINING RESULTS:")
print(f"  • Training completed: {training_config.num_epochs} epochs")
print(f"  • Total steps: {trainer.global_step}")
if train_history:
    print(f"  • Final train loss: {train_history[-1]['loss']:.4f}")
if eval_history:
    first_eval_loss = eval_history[0]['eval_loss']
    last_eval_loss = eval_history[-1]['eval_loss']
    print(f"  • Final eval loss: {last_eval_loss:.4f}")
    # Skip the percentage if the first eval loss is 0 (avoids ZeroDivisionError)
    if first_eval_loss:
        improvement = (first_eval_loss - last_eval_loss) / first_eval_loss * 100
        print(f"  • Eval improvement: {improvement:.2f}%")
print(f"  • Training time: {training_time:.2f}s")
# Counts every subdirectory of output_dir (checkpoint-* dirs and best_model);
# a missing directory counts as zero instead of raising.
saved_dirs = [d for d in output_dir.iterdir() if d.is_dir()] if output_dir.exists() else []
print(f"  • Checkpoints saved: {len(saved_dirs)}")

print("\n⚡ PERFORMANCE:")
print(f"  • Avg step time: {avg_time:.3f}s")
print(f"  • Throughput: {len(batch['input_ids']) / avg_time:.2f} samples/sec")
if device == "cuda":
    print(f"  • Peak GPU memory: {peak_memory:.2f} MB")

print("\n" + "=" * 70)
print("🎉 ALL TRAINING LOOP TESTS PASSED!")
print("=" * 70)