# Model Testing Notebook
## Reasoning Distillation Project

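This notebook tests:
1. Student model initialization (FLAN-T5)
2. Model forward pass and generation (including a comparison of decoding strategies)
3. Teacher model setup (DatasetTeacher)
4. End-to-end inference pipeline
5. Model performance metrics
6. Memory and efficiency analysis
7. Encoder/decoder freezing and a FLAN-T5-small vs. FLAN-T5-base comparison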

In [None]:
# Setup
import sys
from pathlib import Path

# Add project root to path so the `src.*` imports below resolve.
# The root is taken to be the parent of the current working directory, i.e.
# this assumes the kernel was started from a direct subfolder of the
# repository (e.g. `notebooks/`) -- TODO confirm against the repo layout.
# Inserting at index 0 gives this checkout priority over any other `src`
# package that is already on sys.path.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pprint import pprint
from tqdm import tqdm
import time

from src.data.data_loader import TeacherDataLoader
from src.data.preprocessor import ReasoningPreprocessor, PreprocessConfig
from src.data.dataset import ESNLIDataset, create_dataloaders

from src.models.student import (
    StudentModel,
    StudentConfig,
    create_student_model,
    compare_model_sizes
)
from src.models.teacher import (
    DatasetTeacher,
    compare_teacher_modes
)

# Styling: seaborn's 'whitegrid' theme plus a 12x6 default figure size.
# Both are process-wide settings, so every later plot that does not pass its
# own `figsize` picks them up.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

%load_ext autoreload
%autoreload 2

In [None]:
# Set device (GPU or CPU)
# This determines where models and tensors will be loaded.
# `device` is a plain string ("cuda" or "cpu") and is reused by every later
# cell, both for `.to(device)` calls and for `device == "cuda"` checks.
device = "cuda" if torch.cuda.is_available() else "cpu"

print("=" * 70)
print("DEVICE CONFIGURATION")
print("=" * 70)
print(f"\n✓ Device: {device.upper()}")

if device == "cuda":
    # Index 0 is reported; on multi-GPU machines this is the first visible GPU.
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✓ CUDA version: {torch.version.cuda}")
    print(f"✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️  Using CPU - Training will be slower")
    # torch.get_num_threads() is the size of PyTorch's intra-op thread pool,
    # which is not necessarily the number of physical CPU cores, so the label
    # says what is actually being measured.
    print(f"   PyTorch CPU threads: {torch.get_num_threads()}")


## 1. Compare Model Sizes
Understand the different FLAN-T5 model sizes available.

In [None]:
# Display model size comparison.
# `compare_model_sizes` is imported from src.models.student and is called
# here only for whatever it displays; its return value is not assigned.
# NOTE(review): presumably it prints a table of the available FLAN-T5
# variants -- confirm against the helper's implementation.
compare_model_sizes()

## 2. Initialize Student Model
Load and inspect the FLAN-T5 student model.

In [None]:
# Create student model configuration
print("=" * 70)
print("INITIALIZING STUDENT MODEL")
print("=" * 70)

# The model name and the 256/128 sequence-length limits are the same values
# that the PreprocessConfig in the data-loading section uses; keep the two in
# sync so the tokenized data matches what the student expects.
student_config = StudentConfig(
    model_name="google/flan-t5-base",
    max_source_length=256,
    max_target_length=128,
    device=device,
    num_beams=4,
    temperature=1.0
)

# Initialize student.
# NOTE(review): presumably this loads the pretrained weights and moves them
# to `device`; confirm in src.models.student.
student = StudentModel(student_config)

print("\n✓ Student model initialized!")

In [None]:
# Show the student's metadata followed by its memory footprint.
# `model_info` and `memory` are read again by later cells (architecture
# summary, memory analysis, final summary), so both stay at notebook scope.
print(f"\n{'=' * 70}")
print("STUDENT MODEL INFORMATION")
print("=" * 70)

model_info = student.get_model_info()
pprint(model_info)

print("\nMemory Footprint:")
memory = student.get_memory_footprint()
for metric, size_mb in memory.items():
    # Every entry is formatted as a float with an "MB" suffix.
    print(f"  {metric}: {size_mb:.2f} MB")

In [None]:
# Visualize model architecture: print the headline numbers from `model_info`
# and plot how the parameters split between encoder, decoder and the rest.
print("\n" + "=" * 70)
print("MODEL ARCHITECTURE SUMMARY")
print("=" * 70)

print("\nEncoder:")
print(f"  Layers: {model_info['encoder_layers']}")
print(f"  Hidden size: {model_info['hidden_size']}")
print(f"  Attention heads: {model_info['num_heads']}")

print("\nDecoder:")
print(f"  Layers: {model_info['decoder_layers']}")
print(f"  Hidden size: {model_info['hidden_size']}")
print(f"  Attention heads: {model_info['num_heads']}")

print(f"\nVocabulary size: {model_info['vocab_size']:,}")
print(f"Total parameters: {model_info['parameters']:,}")

# Calculate parameter distribution.
# A parameter tensor that is reachable from BOTH stacks (e.g. a token
# embedding shared by encoder and decoder) would be counted twice if each
# stack were summed naively, pushing "Other" towards zero or even below it.
# Count only the parameters exclusive to each stack; everything else
# (shared embeddings, output head, ...) then lands in "Other".
# With no shared tensors this reduces to the plain per-stack sums.
encoder_param_ids = {id(p) for p in student.model.encoder.parameters()}
decoder_param_ids = {id(p) for p in student.model.decoder.parameters()}
shared_param_ids = encoder_param_ids & decoder_param_ids

encoder_params = sum(
    p.numel() for p in student.model.encoder.parameters()
    if id(p) not in shared_param_ids
)
decoder_params = sum(
    p.numel() for p in student.model.decoder.parameters()
    if id(p) not in shared_param_ids
)
other_params = model_info['parameters'] - encoder_params - decoder_params

# Plot parameter distribution
fig, ax = plt.subplots(figsize=(10, 6))
components = ['Encoder', 'Decoder', 'Other (embeddings, etc.)']
params = [encoder_params, decoder_params, other_params]
colors = ['#3498db', '#e74c3c', '#95a5a6']

bars = ax.bar(components, params, color=colors, alpha=0.7, edgecolor='black')
ax.set_ylabel('Number of Parameters')
ax.set_title('FLAN-T5 Parameter Distribution')

# Add value labels on bars: absolute size in millions plus share of the total
for bar, param in zip(bars, params):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{param/1e6:.1f}M\n({param/model_info["parameters"]*100:.1f}%)',
            ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.show()

## 3. Load Test Data

In [None]:
# Load small dataset for testing
print("=" * 70)
print("LOADING TEST DATA")
print("=" * 70)

loader = TeacherDataLoader()
esnli_data = loader.load_esnli()

# Use small subset: the first 20 validation examples, or the whole split if
# it happens to be shorter (selecting past the end would raise).
validation_split = esnli_data['validation']
test_data = validation_split.select(range(min(20, len(validation_split))))

print(f"\n✓ Loaded {len(test_data)} test samples")

In [None]:
# Create dataset and dataloader.
# The preprocessing limits mirror the StudentConfig values (same model name,
# 256 source / 128 target tokens).
preprocess_config = PreprocessConfig(
    model_name="google/flan-t5-base",
    max_source_length=256,
    max_target_length=128
)
preprocessor = ReasoningPreprocessor(preprocess_config)

test_dataset = ESNLIDataset(
    test_data,
    preprocessor,
    use_cache=False
)

# batch_size=4 fixes how many rows the `batch` used by later cells contains.
# NOTE(review): `create_dataloaders` is used here as if it returned a single
# loader when given one dataset -- confirm against src.data.dataset.
test_loader = create_dataloaders(
    test_dataset,
    batch_size=4,
    num_workers=0,
    pad_token_id=preprocessor.tokenizer.pad_token_id,
    shuffle_train=False
)

print(f"✓ Created test dataloader with {len(test_loader)} batches")

## 4. Test Forward Pass (Training Mode)

In [None]:
# Test forward pass: push one batch through the student in training mode and
# report the loss and logits it returns.
print("=" * 70)
print("TESTING FORWARD PASS")
print("=" * 70)

# Get a batch (the first one; later cells keep reusing this same `batch`)
batch = next(iter(test_loader))

# Move to device
batch = {k: v.to(device) for k, v in batch.items()}

print("\nBatch shapes:")
for key, value in batch.items():
    print(f"  {key}: {value.shape}")

# Forward pass in training mode. Gradients are tracked here (no `no_grad`),
# so `outputs` keeps its autograd graph alive until it is reassigned.
student.model.train()
outputs = student(
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    labels=batch['labels']
)

print("\nOutputs:")
print(f"  Loss: {outputs['loss'].item():.4f}")
print(f"  Logits shape: {outputs['logits'].shape}")
print(f"  Logits range: [{outputs['logits'].min():.2f}, {outputs['logits'].max():.2f}]")

print("\n✓ Forward pass successful!")

In [None]:
# Per-sample loss for the current batch, evaluated one row at a time
print(f"\n{'=' * 70}")
print("ANALYZING BATCH LOSSES")
print("=" * 70)

# Evaluation mode, no gradient tracking; each row is fed as a batch of one
# so the returned loss belongs to that single sample.
student.model.eval()

with torch.no_grad():
    batch_losses = [
        student(
            input_ids=batch['input_ids'][row:row + 1],
            attention_mask=batch['attention_mask'][row:row + 1],
            labels=batch['labels'][row:row + 1],
        )['loss'].item()
        for row in range(batch['input_ids'].shape[0])
    ]

# Bar per sample, with the batch mean drawn as a dashed reference line
mean_loss = np.mean(batch_losses)

plt.figure(figsize=(10, 5))
plt.bar(range(len(batch_losses)), batch_losses, color='#e74c3c', alpha=0.7, edgecolor='black')
plt.axhline(float(mean_loss), color='blue', linestyle='--', label=f'Mean: {mean_loss:.3f}')
plt.xlabel('Sample Index')
plt.ylabel('Loss')
plt.title('Per-Sample Loss in Batch (Untrained Model)')
plt.legend()
plt.tight_layout()
plt.show()

print("\nLoss statistics:")
for stat_label, stat_fn in (("Mean", np.mean), ("Std", np.std), ("Min", np.min), ("Max", np.max)):
    print(f"  {stat_label}: {stat_fn(batch_losses):.4f}")

## 5. Test Generation (Inference Mode)

In [None]:
# Test generation: beam-search decode the current batch and time it.
print("=" * 70)
print("TESTING GENERATION")
print("=" * 70)

student.model.eval()

# Generate from batch
print("\nGenerating predictions...")

# CUDA kernels run asynchronously, so synchronize on both sides of the timed
# region: before, so earlier queued work is not billed to this run; after, so
# the clock is read only once generation has really finished.
# perf_counter() is a monotonic, high-resolution clock meant for intervals.
if device == "cuda":
    torch.cuda.synchronize()
start_time = time.perf_counter()

with torch.no_grad():
    generated_ids = student.generate(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        max_length=128,
        num_beams=4
    )

if device == "cuda":
    torch.cuda.synchronize()
generation_time = time.perf_counter() - start_time

print(f"\n✓ Generation complete in {generation_time:.2f}s")
print(f"  Time per sample: {generation_time / batch['input_ids'].shape[0]:.3f}s")
print(f"\nGenerated IDs shape: {generated_ids.shape}")

In [None]:
# Turn generated ids, labels and inputs back into text and show them side by side
print(f"\n{'=' * 70}")
print("PREDICTIONS vs GROUND TRUTH")
print("=" * 70)

predictions = student.decode_batch(generated_ids)

# Labels contain -100 at some positions (presumably the loss ignore-index
# used for padding -- confirm in the collate function). Swap those for the
# pad token id, on a copy, so the labels can be decoded.
pad_id = preprocessor.tokenizer.pad_token_id
labels_for_decode = batch['labels'].masked_fill(batch['labels'] == -100, pad_id)
ground_truths = student.decode_batch(labels_for_decode)

# Decoded inputs (also reused by the generation-strategy cell)
inputs = student.decode_batch(batch['input_ids'])

# Show at most the first three samples
rule = "=" * 70
for i, prediction in enumerate(predictions):
    if i >= 3:
        break
    truth = ground_truths[i]
    print(f"\n{rule}")
    print(f"SAMPLE {i+1}")
    print(rule)
    print(f"\nInput:\n{inputs[i]}")
    print(f"\nGround Truth:\n{truth}")
    print(f"\nPrediction (untrained):\n{prediction}")
    print(f"\nPrediction length: {len(prediction.split())} words")
    print(f"Ground truth length: {len(truth.split())} words")

## 6. Test Different Generation Strategies

In [None]:
# Compare generation strategies on a single input: greedy, beam search and
# two sampling temperatures, recording the decoded text and wall time of each.
print("=" * 70)
print("COMPARING GENERATION STRATEGIES")
print("=" * 70)

# Seed used before every sampling run so the sampled text is the same on
# every Restart & Run All.
GENERATION_SEED = 42

# Take single sample (slices keep the leading batch dimension)
single_input = batch['input_ids'][0:1]
single_mask = batch['attention_mask'][0:1]

strategies = {
    'Greedy': {'num_beams': 1, 'do_sample': False},
    'Beam Search (4)': {'num_beams': 4, 'do_sample': False},
    'Sampling (T=1.0)': {'num_beams': 1, 'do_sample': True, 'temperature': 1.0, 'top_k': 50},
    'Sampling (T=0.7)': {'num_beams': 1, 'do_sample': True, 'temperature': 0.7, 'top_k': 50}
}

results = {}

with torch.no_grad():
    for name, params in strategies.items():
        if params.get('do_sample'):
            # Reseed per strategy so each sampled output is independent of
            # how many random numbers the previous strategy consumed.
            torch.manual_seed(GENERATION_SEED)

        # Synchronize around the timed region so asynchronous CUDA work is
        # attributed to the right strategy.
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        generated = student.generate(
            input_ids=single_input,
            attention_mask=single_mask,
            max_length=128,
            **params
        )
        if device == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

        decoded = student.decode_batch(generated)[0]
        results[name] = {'text': decoded, 'time': elapsed}

# Display results (`inputs` comes from the decoding cell above)
print(f"\nInput: {inputs[0]}\n")

for name, result in results.items():
    print(f"\n--- {name} ({result['time']:.3f}s) ---")
    print(result['text'])

## 7. Initialize Teacher (DatasetTeacher)

In [None]:
# Compare teacher modes.
# `compare_teacher_modes` is imported from src.models.teacher and is called
# here only for whatever it displays; its return value is not assigned.
# NOTE(review): presumably it prints an overview of the supported teacher
# modes -- confirm against the helper's implementation.
compare_teacher_modes()

In [None]:
# Initialize DatasetTeacher (constructed with its defaults; `teacher` is
# used by the knowledge-extraction cell that follows).
print("\n" + "=" * 70)
print("INITIALIZING DATASET TEACHER")
print("=" * 70)

teacher = DatasetTeacher()

print("\n✓ DatasetTeacher initialized!")
print("\nThis teacher uses pre-generated explanations from:")
print("  • e-SNLI: Human-written explanations")
print("\nNo additional compute required!")

In [None]:
# Ask the teacher for its label and explanation on the first few raw samples
print(f"\n{'=' * 70}")
print("EXTRACTING TEACHER KNOWLEDGE")
print("=" * 70)

# Positions within `test_data` to inspect
sample_indices = list(range(3))

for idx in sample_indices:
    raw_sample = test_data[idx]

    # The teacher reads the raw (un-tokenized) record, treated as an NLI task
    teacher_knowledge = teacher.extract_teacher_knowledge(raw_sample, task_type="nli")

    print(f"\n--- Sample {idx + 1} ---")
    print(f"Premise: {raw_sample['premise']}")
    print(f"Hypothesis: {raw_sample['hypothesis']}")
    print("\nTeacher Knowledge:")
    print(f"  Label: {teacher_knowledge['label']}")
    print(f"  Explanation: {teacher_knowledge['explanation']}")

    # Only truthy (non-empty) alternative explanations are counted
    usable_alternatives = sum(
        1 for alt in teacher_knowledge['alternative_explanations'] if alt
    )
    print(f"  Alternative explanations available: {usable_alternatives}")

## 8. Performance Benchmarking

In [None]:
# Benchmark inference speed: greedy generation at several batch sizes,
# one untimed warm-up run followed by five timed runs per size.
print("=" * 70)
print("INFERENCE SPEED BENCHMARK")
print("=" * 70)

batch_sizes = [1, 4, 8, 16] if device == "cuda" else [1, 4, 8]
inference_times = []

student.model.eval()

# `batch` only holds as many rows as the dataloader's batch size, so a plain
# `[:bs]` slice silently returns fewer than `bs` rows for the larger settings
# while throughput would still be computed as `bs / time`. Tile the available
# rows so every benchmark input really contains `bs` samples.
rows_available = batch['input_ids'].shape[0]

for bs in batch_sizes:
    # Create batch with exactly `bs` rows
    tiles = -(-bs // rows_available)  # ceil(bs / rows_available)
    test_input = torch.cat([batch['input_ids']] * tiles, dim=0)[:bs].to(device)
    test_mask = torch.cat([batch['attention_mask']] * tiles, dim=0)[:bs].to(device)
    assert test_input.shape[0] == bs, "benchmark input does not match the requested batch size"

    # Warmup (untimed): absorbs one-time costs such as kernel selection
    with torch.no_grad():
        _ = student.generate(test_input, test_mask, max_length=128, num_beams=1)

    # Benchmark
    times = []
    for _ in range(5):
        # Make sure nothing from the warm-up or previous run is still queued
        # on the GPU when the clock starts.
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            _ = student.generate(test_input, test_mask, max_length=128, num_beams=1)
        if device == "cuda":
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

    avg_time = np.mean(times)
    throughput = bs / avg_time
    inference_times.append({'batch_size': bs, 'time': avg_time, 'throughput': throughput})

    print(f"\nBatch size {bs}:")
    print(f"  Avg time: {avg_time:.3f}s")
    print(f"  Throughput: {throughput:.2f} samples/sec")
    print(f"  Time per sample: {avg_time/bs:.3f}s")

In [None]:
# Visualize throughput: latency and samples/sec as a function of batch size,
# using the records collected in `inference_times` by the benchmark cell.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# `batch_sizes_list` and `throughput_list` are read again by the summary cell.
batch_sizes_list = [r['batch_size'] for r in inference_times]
times_list = [r['time'] for r in inference_times]
throughput_list = [r['throughput'] for r in inference_times]

# Plot 1: Time vs Batch Size
ax1.plot(batch_sizes_list, times_list, marker='o', linewidth=2, markersize=8, color='#e74c3c')
ax1.set_xlabel('Batch Size')
ax1.set_ylabel('Time (seconds)')
ax1.set_title('Inference Time vs Batch Size')
ax1.grid(True, alpha=0.3)

# Plot 2: Throughput vs Batch Size
ax2.plot(batch_sizes_list, throughput_list, marker='s', linewidth=2, markersize=8, color='#2ecc71')
ax2.set_xlabel('Batch Size')
ax2.set_ylabel('Throughput (samples/sec)')
ax2.set_title('Throughput vs Batch Size')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# The batch size with the highest measured throughput
print(f"\n✓ Optimal batch size for throughput: {batch_sizes_list[np.argmax(throughput_list)]}")

## 9. Memory Usage Analysis

In [None]:
# Analyze memory usage: measure what one no-grad forward pass adds on top of
# what is already allocated on the GPU, and how high the allocation peaks.
print("=" * 70)
print("MEMORY USAGE ANALYSIS")
print("=" * 70)

if device == "cuda":
    # Drop the result of any earlier forward pass first. The training-mode
    # pass keeps its autograd graph alive through `outputs`, which would
    # otherwise inflate the baseline and then be freed mid-measurement.
    outputs = None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Baseline: everything resident before the forward pass (model weights
    # plus any tensors other cells still hold on the GPU).
    initial_memory = torch.cuda.memory_allocated() / 1e6

    with torch.no_grad():
        outputs = student(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )

    forward_memory = torch.cuda.memory_allocated() / 1e6
    peak_memory = torch.cuda.max_memory_allocated() / 1e6

    # Both deltas use the measured baseline, so they share one source and one
    # unit (memory['total_mb'] comes from a different computation).
    retained_memory = forward_memory - initial_memory   # still held after the pass
    transient_memory = peak_memory - forward_memory     # freed again before the pass returned

    print("\nGPU Memory Usage:")
    print(f"  Model parameters: {memory['total_mb']:.2f} MB")
    print(f"  Allocated before forward pass: {initial_memory:.2f} MB")
    print(f"  After forward pass: {forward_memory:.2f} MB")
    print(f"  Peak usage: {peak_memory:.2f} MB")
    print(f"  Activations: {retained_memory:.2f} MB")

    # The third bar shows the transient overhead above the post-forward
    # level, not the absolute peak, and is labelled accordingly.
    components = ['Model\nParameters', 'Activations', 'Transient Peak\nOverhead']
    sizes = [memory['total_mb'], retained_memory, transient_memory]
    colors = ['#3498db', '#e67e22', '#e74c3c']

    plt.figure(figsize=(10, 6))
    bars = plt.bar(components, sizes, color=colors, alpha=0.7, edgecolor='black')
    plt.ylabel('Memory (MB)')
    plt.title('GPU Memory Breakdown During Inference')

    for bar, size in zip(bars, sizes):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                f'{size:.1f} MB',
                ha='center', va='bottom', fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.show()

else:
    print("\nCPU mode - GPU memory analysis not available")
    print(f"Model memory footprint: {memory['total_mb']:.2f} MB")

## 10. Test Model Freezing

In [None]:
# Section banner for the freezing/unfreezing checks
print("=" * 70, "TESTING MODEL FREEZING", "=" * 70, sep="\n")


def count_trainable_params(model):
    """Count the trainable parameter elements of a wrapped model.

    Args:
        model: Any object exposing the underlying network as ``model.model``,
            where that network provides ``parameters()``.

    Returns:
        Total number of elements across all parameters whose
        ``requires_grad`` flag is set (0 if there are none).
    """
    trainable = 0
    for param in model.model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Walk the student through freeze/unfreeze of each stack and report how many
# parameters remain trainable at every step, as a share of the initial count.

# Initial state
initial_trainable = count_trainable_params(student)
print(f"\nInitial trainable parameters: {initial_trainable:,}")

# Freeze encoder
student.freeze_encoder()
encoder_frozen = count_trainable_params(student)
print(f"After freezing encoder: {encoder_frozen:,} ({encoder_frozen/initial_trainable*100:.1f}%)")

# Unfreeze encoder, freeze decoder
student.unfreeze_encoder()
student.freeze_decoder()
decoder_frozen = count_trainable_params(student)
print(f"After freezing decoder: {decoder_frozen:,} ({decoder_frozen/initial_trainable*100:.1f}%)")

# Unfreeze all
student.unfreeze_decoder()
final_trainable = count_trainable_params(student)
print(f"After unfreezing all: {final_trainable:,}")

# The round trip must leave the model exactly as trainable as it started
assert initial_trainable == final_trainable, "Parameter count mismatch after unfreezing!"
print("\n✓ Freezing/unfreezing works correctly!")

## 11. Compare Different Model Sizes

In [None]:
# Compare small vs base models on parameter count, memory footprint and
# greedy-generation latency for a batch of up to four samples.
print("=" * 70)
print("COMPARING MODEL SIZES")
print("=" * 70)

# Create small model
print("\nLoading FLAN-T5-small...")
student_small = create_student_model(model_size="small", device=device)

# Order matters: the speedup/compression ratios below index [0] as small
# and [1] as base.
models_comparison = [
    {'name': 'FLAN-T5-small', 'model': student_small},
    {'name': 'FLAN-T5-base', 'model': student}
]

comparison_results = []

for model_dict in models_comparison:
    model_name = model_dict['name']
    model = model_dict['model']

    info = model.get_model_info()
    mem = model.get_memory_footprint()

    # Measure inference time
    model.model.eval()
    test_input = batch['input_ids'][:4].to(device)
    test_mask = batch['attention_mask'][:4].to(device)

    # Untimed warm-up: the small model was only just loaded, so without this
    # its first timed run would include one-time start-up costs that the
    # already exercised base model no longer pays, skewing the comparison.
    with torch.no_grad():
        _ = model.generate(test_input, test_mask, max_length=128, num_beams=1)

    times = []
    for _ in range(3):
        # Synchronize before starting and before stopping the clock so only
        # this run's GPU work is measured.
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.no_grad():
            _ = model.generate(test_input, test_mask, max_length=128, num_beams=1)
        if device == "cuda":
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

    avg_time = np.mean(times)

    comparison_results.append({
        'name': model_name,
        'parameters': info['parameters'],
        'memory_mb': mem['total_mb'],
        'inference_time': avg_time,
        'layers': info['encoder_layers']
    })

# Display comparison
print("\n" + "=" * 70)
print(f"{'Model':<20} {'Params':<15} {'Memory (MB)':<15} {'Time (s)':<15} {'Layers':<10}")
print("=" * 70)

for result in comparison_results:
    print(f"{result['name']:<20} {result['parameters']/1e6:<14.1f}M {result['memory_mb']:<15.1f} "
          f"{result['inference_time']:<15.3f} {result['layers']:<10}")

# Calculate speedup (base time / small time) and size ratio (base / small);
# both are reused by the summary cell.
speedup = comparison_results[1]['inference_time'] / comparison_results[0]['inference_time']
compression = comparison_results[1]['parameters'] / comparison_results[0]['parameters']

print("\n" + "=" * 70)
print(f"Small is {speedup:.2f}x faster than base")
print(f"Base has {compression:.2f}x more parameters than small")
print("=" * 70)

In [None]:
# Three side-by-side bar charts (size, memory, latency) built from
# `comparison_results`, one bar per model.
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

model_names = [r['name'].replace('FLAN-T5-', '') for r in comparison_results]
colors_viz = ['#3498db', '#e74c3c']

# Series to plot
params = [r['parameters']/1e6 for r in comparison_results]
memory_vals = [r['memory_mb'] for r in comparison_results]
time_vals = [r['inference_time'] for r in comparison_results]

# One entry per panel: (values, y-axis label, title, label offset above the
# bar in data units, label template)
panels = [
    (params, 'Parameters (Millions)', 'Model Size', 5, '{:.1f}M'),
    (memory_vals, 'Memory (MB)', 'Memory Footprint', 10, '{:.0f}MB'),
    (time_vals, 'Inference Time (seconds)', 'Inference Speed (batch=4)', 0.01, '{:.3f}s'),
]

for panel_ax, (values, ylabel, title, offset, label_template) in zip(axes, panels):
    panel_ax.bar(model_names, values, color=colors_viz, alpha=0.7, edgecolor='black')
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_title(title)
    for i, v in enumerate(values):
        panel_ax.text(i, v + offset, label_template.format(v),
                      ha='center', va='bottom', fontweight='bold')

plt.suptitle('FLAN-T5 Model Comparison', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 12. Summary and Recommendations

In [None]:
# Final recap. The numbers come from variables set by earlier cells
# (`model_info`, `memory`, `inference_times`, `speedup`, `compression`,
# `batch_sizes_list`, `throughput_list`, `strategies`), so this cell only
# works after the whole notebook has run.
# NOTE(review): the "WORKING" lines are static text, not the result of a
# check -- they only mean the earlier cells ran without raising.
print("\n" + "=" * 70)
print("MODEL TESTING SUMMARY")
print("=" * 70)

print("\n✅ STUDENT MODEL (FLAN-T5-base):")
print(f"  • Parameters: {model_info['parameters']:,}")
print(f"  • Memory: {memory['total_mb']:.2f} MB")
# inference_times[0] is the first benchmarked batch size (1), so 1/time is
# the single-sample rate.
print(f"  • Inference speed: ~{1/inference_times[0]['time']:.2f} samples/sec (single)")
print("  • Forward pass: WORKING ✓")
print("  • Generation: WORKING ✓")
print("  • Freezing: WORKING ✓")

print("\n✅ TEACHER MODEL (DatasetTeacher):")
print("  • Mode: Dataset-as-Teacher")
print("  • Compute cost: ZERO")
print("  • Knowledge extraction: WORKING ✓")
print("  • Data sources: e-SNLI (human)")

print("\n📊 PERFORMANCE INSIGHTS:")
print(f"  • FLAN-T5-small is {speedup:.2f}x faster but {compression:.2f}x smaller")
print(f"  • Optimal batch size for throughput: {batch_sizes_list[np.argmax(throughput_list)]}")
# Report the real number of strategies compared instead of a hard-coded count
print(f"  • Generation strategies tested: {len(strategies)} variants")