# Ablation Studies Notebook
## Reasoning Distillation Project

This notebook performs systematic ablation studies to understand **knowledge distillation** hyperparameters:

### Knowledge Distillation Setup:
```
Dataset → Teacher Model (FLAN-T5-XL) → Soft Logits (probabilities)
       ↘                              ↗
         Student Model (FLAN-T5-Small)

Loss = α·CE(student, labels) + β·KL(student||teacher)
```

### Ablation Studies:
1. **Distillation Weight (β)**: How much to learn from teacher vs labels (0.0, 0.3, 0.5, 0.7)
2. **Temperature**: Softness of probability distributions (1.0, 2.0, 3.0, 4.0)
3. **Label Smoothing**: Regularization effect (0.0, 0.1, 0.2)
4. **Training Data Size**: Data efficiency with distillation (10%, 50%, 100%)
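5. **Generation Temperature**: Inference-time temperature variations

> **Note:** The cells in this notebook run items 1, 2, and 4 (titled "Ablation Study 1–3" in the section headings). Label smoothing is held at 0.0 throughout, and generation temperature is not varied here.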

In [None]:
# Setup: make the project package importable, then import everything the
# notebook needs in one place.
import sys
from pathlib import Path
import warnings
# Blanket filter: suppresses every warning category for the rest of the
# session, including any emitted by the libraries imported below.
warnings.filterwarnings('ignore')

# Add project root to path
# The parent of the current working directory is prepended to sys.path so the
# `src.*` imports below can resolve.
# NOTE(review): assumes the kernel is started one directory below the project
# root — confirm.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pprint import pprint
import time

# Project-local modules (resolved through the sys.path entry added above).
# NOTE(review): `pprint`, `compare_distillation_strategies` and
# `format_metrics` are not referenced by any cell visible in this notebook —
# confirm before removing.
from src.data.data_loader import TeacherDataLoader
from src.data.preprocessor import ReasoningPreprocessor, PreprocessConfig
from src.data.dataset import ESNLIDataset, create_dataloaders

from src.models.student import StudentModel, StudentConfig
from src.models.teacher import FlanT5Teacher, TeacherConfig  # Teacher model for distillation

from src.training.distillation import (
    DistillationConfig,
    TokenLevelDistillation,  # Token-level distillation with teacher
    compare_distillation_strategies
)

from src.training.trainer import Trainer, TrainingConfig

from src.evaluation.evaluator import Evaluator, EvaluationConfig
from src.evaluation.metrics import MetricsConfig, format_metrics

# Styling
# Applies seaborn's 'whitegrid' style to all figures created afterwards.
sns.set_style('whitegrid')

In [None]:
# Choose the compute backend once: CUDA when torch reports it available,
# otherwise CPU. The model configs in later cells read `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device: " + device)

## 1. Load Base Dataset and Teacher Model

For ablation studies, we use smaller subsets but still require the teacher model to provide soft probability distributions for knowledge distillation.

In [None]:
# Load e-SNLI and build the train/validation loaders shared by the
# distillation-weight and temperature ablations.
print("=" * 70)
print("LOADING DATA")
print("=" * 70)

loader = TeacherDataLoader()
esnli_data = loader.load_esnli()

# Use subset for faster ablation experiments.
# The requested sizes are clamped to the split length so a reduced/debug copy
# of the dataset does not ask `select` for out-of-range rows; the actual
# counts are printed right below.
n_train_subset = min(500, len(esnli_data['train']))
n_val_subset = min(100, len(esnli_data['validation']))
train_subset = esnli_data['train'].select(range(n_train_subset))
val_subset = esnli_data['validation'].select(range(n_val_subset))

print(f"\n✓ Train samples: {len(train_subset)}")
print(f"✓ Val samples: {len(val_subset)}")

# Prepare data: the preprocessor is configured with the student's model name
# and the same source/target length limits used by the model configs later on.
preprocess_config = PreprocessConfig(
    model_name="google/flan-t5-small",
    max_source_length=128,
    max_target_length=64
)

preprocessor = ReasoningPreprocessor(preprocess_config)
train_dataset = ESNLIDataset(train_subset, preprocessor, use_cache=True)
val_dataset = ESNLIDataset(val_subset, preprocessor, use_cache=True)

# num_workers=0 keeps batch loading in the main process.
train_loader, val_loader = create_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=16,
    num_workers=0,
    pad_token_id=preprocessor.tokenizer.pad_token_id
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

In [None]:
# Load Teacher Model (FLAN-T5-XL) for knowledge distillation.
# The teacher is built once here; the ablation loops in later cells pass this
# same instance to every distillation strategy they create.
print("=" * 70)
print("LOADING TEACHER MODEL (google/flan-t5-xl)")
print("=" * 70)

teacher_config = TeacherConfig(
    model_name="google/flan-t5-xl",
    device=device,
    use_fp16=(device == "cuda"),  # half precision is only requested on GPU
    max_source_length=128,
    max_target_length=64
)

print(f"Loading {teacher_config.model_name}...")
print(f"Device: {device}, FP16: {teacher_config.use_fp16}")

teacher = FlanT5Teacher(teacher_config)

print("\n✓ Teacher loaded!")
print(f"  Parameters: {teacher.count_parameters():,}")
print("\n⚠️  Teacher model will be used for ALL ablation studies")
print("   to provide soft probability distributions (dark knowledge)")

## 2. Ablation Study 1: Distillation Weight (β)

Test the effect of different distillation weights on model performance.

**Loss = α·CE(student, labels) + β·KL(student||teacher)**

- β = 0.0: No distillation (standard supervised learning)
- β = 0.3: Moderate distillation
- β = 0.5: Balanced distillation (recommended)
- β = 0.7: Strong distillation (more teacher influence)

In [None]:
# Ablation 1: Distillation Weight (β)
# Trains one fresh student per β with every other setting held fixed, scores
# it on the validation loader, and appends one summary row per run to
# `distill_weight_results` (read by the plotting and summary cells).
print("=" * 70)
print("ABLATION STUDY 1: DISTILLATION WEIGHT (β)")
print("=" * 70)

distill_weights = [0.0, 0.3, 0.5, 0.7]
distill_weight_results = []

# Every run restarts from the same RNG state so that differences between β
# values are not confounded by run-to-run randomness.
# NOTE(review): this reseeds torch and NumPy only — confirm that Trainer /
# create_dataloaders do not draw from another RNG source.
ablation_seed = 42
ce_weight = 1.0  # α, held fixed for the whole sweep

for beta in distill_weights:
    torch.manual_seed(ablation_seed)
    np.random.seed(ablation_seed)

    print(f"\n{'='*70}")
    print(f"Testing β (distill_weight) = {beta}")
    print(f"Loss = {ce_weight}·CE + {beta}·KL(student||teacher)")
    print(f"{'='*70}")

    # Create fresh student model so no weights carry over between runs
    student_config = StudentConfig(
        model_name="google/flan-t5-small",
        max_source_length=128,
        max_target_length=64,
        device=device
    )
    student = StudentModel(student_config)

    # Create distillation strategy with teacher model
    distill_config = DistillationConfig(
        ce_weight=ce_weight,     # α - Cross-entropy weight
        distill_weight=beta,     # β - KL divergence weight
        temperature=2.0,         # Temperature for softening
        label_smoothing=0.0
    )

    # Token-level distillation requires teacher model
    distillation_strategy = TokenLevelDistillation(
        teacher_model=teacher,
        config=distill_config
    )

    # Train
    training_config = TrainingConfig(
        num_epochs=3,
        learning_rate=5e-5,
        eval_steps=20,
        save_steps=1000,  # Don't save
        logging_steps=10,
        output_dir=f"../experiments/ablation_distill_weight_{beta}",
        eval_strategy="steps"
    )

    trainer = Trainer(
        model=student,
        train_dataloader=train_loader,
        eval_dataloader=val_loader,
        distillation_strategy=distillation_strategy,
        config=training_config
    )

    start_time = time.time()
    history = trainer.train()
    training_time = time.time() - start_time

    # Evaluate (BERTScore is switched off; ROUGE and faithfulness are on)
    eval_config = EvaluationConfig(
        metrics_config=MetricsConfig(
            compute_rouge=True,
            compute_bertscore=False,
            compute_faithfulness=True
        ),
        save_predictions=False,
        output_dir=f"../experiments/ablation_distill_weight_{beta}_eval"
    )

    evaluator = Evaluator(student, eval_config)
    results = evaluator.evaluate(val_loader, split_name="val")

    # Store results. Both history lists are guarded the same way, so a run
    # that logged nothing records None instead of raising IndexError.
    train_history = history['train_history']
    eval_history = history['eval_history']
    distill_weight_results.append({
        'distill_weight': beta,
        'accuracy': results['metrics']['label_accuracy'],
        'rouge1': results['metrics']['rouge1'],
        'rougeL': results['metrics']['rougeL'],
        'faithfulness': results['metrics']['faithfulness'],
        'final_train_loss': train_history[-1]['loss'] if train_history else None,
        'final_eval_loss': eval_history[-1]['eval_loss'] if eval_history else None,
        'training_time': training_time
    })

    print(f"\n✓ β = {beta} completed")
    print(f"  Accuracy: {results['metrics']['label_accuracy']:.4f}")
    print(f"  ROUGE-L: {results['metrics']['rougeL']:.4f}")

print("\n" + "="*70)
print("DISTILLATION WEIGHT ABLATION COMPLETE")
print("="*70)

In [None]:
# Visualize distillation weight results: one panel per metric in a 2x2 grid,
# each point labelled with its value, followed by the raw table.
distill_df = pd.DataFrame(distill_weight_results)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()  # 2x2 grid -> flat sequence so panels can be indexed 0..3

metrics_to_plot = ['accuracy', 'rouge1', 'rougeL', 'faithfulness']
colors_palette = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']

for idx, (metric, color) in enumerate(zip(metrics_to_plot, colors_palette)):
    axes[idx].plot(distill_df['distill_weight'], distill_df[metric],
                   marker='o', linewidth=2, markersize=10, color=color)
    axes[idx].set_xlabel('Distillation Weight (β)')
    axes[idx].set_ylabel(metric.upper())
    axes[idx].set_title(f'{metric.upper()} vs Distillation Weight')
    axes[idx].grid(True, alpha=0.3)

    # Add value labels slightly above each marker (offset is in data units)
    for x, y in zip(distill_df['distill_weight'], distill_df[metric]):
        axes[idx].text(x, y + 0.01, f'{y:.3f}', ha='center', fontsize=9)

plt.suptitle('Ablation Study: Distillation Weight (β) Impact\nLoss = α·CE + β·KL(student||teacher)',
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nDistillation Weight Results Summary:")
print(distill_df.to_string(index=False))

## 3. Ablation Study 2: Distillation Temperature

Test different temperature values for softening probability distributions.

Higher temperature → softer distributions → more knowledge transfer from teacher's non-top predictions

In [None]:
# Ablation 2: Distillation Temperature
# Trains one fresh student per temperature with the distillation weight fixed
# at 0.5, scores it on the validation loader, and appends one summary row per
# run to `temperature_results` (read by the plotting and summary cells).
print("=" * 70)
print("ABLATION STUDY 2: DISTILLATION TEMPERATURE")
print("=" * 70)

temperatures = [1.0, 2.0, 3.0, 4.0]
temperature_results = []

# Every run restarts from the same RNG state so that differences between
# temperatures are not confounded by run-to-run randomness.
# NOTE(review): this reseeds torch and NumPy only — confirm that Trainer /
# create_dataloaders do not draw from another RNG source.
ablation_seed = 42

for temp in temperatures:
    torch.manual_seed(ablation_seed)
    np.random.seed(ablation_seed)

    print(f"\n{'='*70}")
    print(f"Testing Temperature = {temp}")
    print(f"{'='*70}")

    # Create fresh student model so no weights carry over between runs
    student = StudentModel(StudentConfig(
        model_name="google/flan-t5-small",
        max_source_length=128,
        max_target_length=64,
        device=device
    ))

    # Create distillation strategy with different temperature
    distill_config = DistillationConfig(
        ce_weight=1.0,
        distill_weight=0.5,      # Fixed distillation weight
        temperature=temp,         # Varying temperature
        label_smoothing=0.0
    )

    distillation_strategy = TokenLevelDistillation(
        teacher_model=teacher,
        config=distill_config
    )

    # Train
    training_config = TrainingConfig(
        num_epochs=3,
        learning_rate=5e-5,
        eval_steps=20,
        save_steps=1000,
        logging_steps=10,
        output_dir=f"../experiments/ablation_temp_{temp}",
        eval_strategy="steps"
    )

    trainer = Trainer(
        model=student,
        train_dataloader=train_loader,
        eval_dataloader=val_loader,
        distillation_strategy=distillation_strategy,
        config=training_config
    )

    start_time = time.time()
    history = trainer.train()
    training_time = time.time() - start_time

    # Evaluate (BERTScore is switched off; ROUGE and faithfulness are on)
    evaluator = Evaluator(student, EvaluationConfig(
        metrics_config=MetricsConfig(
            compute_rouge=True,
            compute_bertscore=False,
            compute_faithfulness=True
        ),
        save_predictions=False,
        output_dir=f"../experiments/ablation_temp_{temp}_eval"
    ))

    results = evaluator.evaluate(val_loader, split_name="val")

    # Store results
    temperature_results.append({
        'temperature': temp,
        'accuracy': results['metrics']['label_accuracy'],
        'rouge1': results['metrics']['rouge1'],
        'rougeL': results['metrics']['rougeL'],
        'faithfulness': results['metrics']['faithfulness'],
        'training_time': training_time
    })

    print(f"\n✓ Temperature {temp} completed")
    print(f"  Accuracy: {results['metrics']['label_accuracy']:.4f}")

print("\n" + "="*70)
print("TEMPERATURE ABLATION COMPLETE")
print("="*70)

In [None]:
# Temperature sweep figure: every metric on the left panel, label accuracy
# alone (with value labels) on the right panel, then the raw table.
temp_df = pd.DataFrame(temperature_results)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
overview_ax = axes[0]
accuracy_ax = axes[1]
temperature_axis = temp_df['temperature']

# Left panel: one line per metric, colour-matched by position
metrics = ['accuracy', 'rouge1', 'rougeL', 'faithfulness']
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']

for metric_name, line_color in zip(metrics, colors):
    line_style = dict(marker='o', linewidth=2, markersize=8,
                      label=metric_name.upper(), color=line_color)
    overview_ax.plot(temperature_axis, temp_df[metric_name], **line_style)

overview_ax.set_xlabel('Distillation Temperature')
overview_ax.set_ylabel('Score')
overview_ax.set_title('Performance vs Distillation Temperature')
overview_ax.legend()
overview_ax.grid(True, alpha=0.3)

# Right panel: accuracy on its own scale
accuracy_series = temp_df['accuracy']
accuracy_ax.plot(temperature_axis, accuracy_series,
                 marker='o', linewidth=3, markersize=10, color='#e74c3c')
accuracy_ax.set_xlabel('Distillation Temperature')
accuracy_ax.set_ylabel('Label Accuracy')
accuracy_ax.set_title('Label Accuracy vs Temperature')
accuracy_ax.grid(True, alpha=0.3)

# Value label just above each accuracy marker (offset is in data units)
for temp_value, acc_value in zip(temperature_axis, accuracy_series):
    label_text = f'{acc_value:.3f}'
    accuracy_ax.text(temp_value, acc_value + 0.005, label_text,
                     ha='center', fontsize=10)

plt.suptitle('Ablation Study: Distillation Temperature Impact', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nTemperature Results Summary:")
print(temp_df.to_string(index=False))

## 4. Ablation Study 3: Training Data Size with Distillation

Evaluate how model performance scales with training data size when using knowledge distillation.

In [None]:
# Ablation 3: Training Data Size
# Trains one fresh student per training-set fraction with the distillation
# settings fixed, scores it on the shared validation loader, and appends one
# summary row per run to `data_size_results` (read by the plotting and
# summary cells).
print("=" * 70)
print("ABLATION STUDY 3: TRAINING DATA SIZE WITH DISTILLATION")
print("=" * 70)

data_fractions = [0.1, 0.5, 1.0]
data_size_results = []

# Use larger base dataset for this study
full_train = esnli_data['train'].select(range(2000))

# Every run restarts from the same RNG state so that differences between
# data sizes are not confounded by run-to-run randomness.
# NOTE(review): this reseeds torch and NumPy only — confirm that Trainer /
# create_dataloaders do not draw from another RNG source.
ablation_seed = 42

for fraction in data_fractions:
    torch.manual_seed(ablation_seed)
    np.random.seed(ablation_seed)

    print(f"\n{'='*70}")
    print(f"Testing Data Fraction = {fraction*100}%")
    print(f"{'='*70}")

    # Sample data: the first n rows of the base subset. round() avoids losing
    # a sample to float truncation (e.g. 0.29 * 100 -> 28.99...), and the
    # floor of 1 keeps very small fractions from producing an empty dataset.
    n_samples = max(1, round(len(full_train) * fraction))
    train_fraction = full_train.select(range(n_samples))

    # Create datasets; only the train loader is kept because evaluation uses
    # the shared `val_loader`.
    # NOTE(review): use_cache=True is reused across subsets of different
    # sizes — confirm the cache distinguishes them.
    train_dataset_frac = ESNLIDataset(train_fraction, preprocessor, use_cache=True)
    train_loader_frac, _ = create_dataloaders(
        train_dataset=train_dataset_frac,
        val_dataset=val_dataset,
        batch_size=16,
        num_workers=0,
        pad_token_id=preprocessor.tokenizer.pad_token_id
    )

    # Create fresh student model so no weights carry over between runs
    student = StudentModel(StudentConfig(
        model_name="google/flan-t5-small",
        max_source_length=128,
        max_target_length=64,
        device=device
    ))

    # Create distillation strategy (fixed params for the whole sweep)
    distillation_strategy = TokenLevelDistillation(
        teacher_model=teacher,
        config=DistillationConfig(
            ce_weight=1.0,
            distill_weight=0.5,
            temperature=2.0,
            label_smoothing=0.0
        )
    )

    # Train
    training_config = TrainingConfig(
        num_epochs=3,
        learning_rate=5e-5,
        eval_steps=20,
        save_steps=1000,
        logging_steps=10,
        output_dir=f"../experiments/ablation_datasize_{fraction}",
        eval_strategy="steps"
    )

    trainer = Trainer(
        model=student,
        train_dataloader=train_loader_frac,
        eval_dataloader=val_loader,
        distillation_strategy=distillation_strategy,
        config=training_config
    )

    start_time = time.time()
    history = trainer.train()
    training_time = time.time() - start_time

    # Evaluate (BERTScore is switched off; ROUGE and faithfulness are on)
    evaluator = Evaluator(student, EvaluationConfig(
        metrics_config=MetricsConfig(
            compute_rouge=True,
            compute_bertscore=False,
            compute_faithfulness=True
        ),
        save_predictions=False,
        output_dir=f"../experiments/ablation_datasize_{fraction}_eval"
    ))

    results = evaluator.evaluate(val_loader, split_name="val")

    # Store results
    data_size_results.append({
        'fraction': fraction,
        'n_samples': n_samples,
        'accuracy': results['metrics']['label_accuracy'],
        'rouge1': results['metrics']['rouge1'],
        'rougeL': results['metrics']['rougeL'],
        'faithfulness': results['metrics']['faithfulness'],
        'training_time': training_time
    })

    print(f"\n✓ Data fraction {fraction} completed")
    print(f"  Samples: {n_samples}")
    print(f"  Accuracy: {results['metrics']['label_accuracy']:.4f}")

print("\n" + "="*70)
print("DATA SIZE ABLATION COMPLETE")
print("="*70)

In [None]:
# Data-size sweep figure: every metric against sample count on the left
# panel, wall-clock training time on the right panel, then the raw table.
datasize_df = pd.DataFrame(data_size_results)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
scores_ax = axes[0]
timing_ax = axes[1]
sample_counts = datasize_df['n_samples']

# Left panel: one line per metric, colour-matched by position
metrics = ['accuracy', 'rouge1', 'rougeL', 'faithfulness']
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']

for metric_name, line_color in zip(metrics, colors):
    line_style = dict(marker='o', linewidth=2, markersize=8,
                      label=metric_name.upper(), color=line_color)
    scores_ax.plot(sample_counts, datasize_df[metric_name], **line_style)

scores_ax.set_xlabel('Number of Training Samples')
scores_ax.set_ylabel('Score')
scores_ax.set_title('Performance vs Training Data Size (with Distillation)')
scores_ax.legend()
scores_ax.grid(True, alpha=0.3)

# Right panel: training time per run
elapsed_seconds = datasize_df['training_time']
timing_ax.plot(sample_counts, elapsed_seconds,
               marker='s', linewidth=2, markersize=8, color='#9b59b6')
timing_ax.set_xlabel('Number of Training Samples')
timing_ax.set_ylabel('Training Time (seconds)')
timing_ax.set_title('Training Time vs Data Size')
timing_ax.grid(True, alpha=0.3)

# Duration label just above each marker (offset is in data units, i.e. seconds)
for count, seconds in zip(sample_counts, elapsed_seconds):
    label_text = f'{seconds:.0f}s'
    timing_ax.text(count, seconds + 5, label_text, ha='center', fontsize=9)

plt.suptitle('Ablation Study: Training Data Size Impact with Knowledge Distillation',
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nData Size Results Summary:")
print(datasize_df.to_string(index=False))

## 5. Summary and Recommendations

Compile findings from all ablation studies to determine optimal hyperparameters for knowledge distillation.

In [None]:
# Compile all ablation results.
# Reads the three result frames built by the plotting cells (`distill_df`,
# `temp_df`, `datasize_df`), picks the best-accuracy row of each
# hyperparameter sweep, and prints a recommended configuration.
print("=" * 70)
print("ABLATION STUDIES SUMMARY")
print("=" * 70)

print("\nDISTILLATION WEIGHT (β):")
# idxmax returns an index *label*, so the follow-up lookups use .loc.
best_distill_idx = distill_df['accuracy'].idxmax()
best_distill_weight: float = distill_df.loc[best_distill_idx, 'distill_weight']  # type: ignore
best_distill_accuracy: float = distill_df.loc[best_distill_idx, 'accuracy']  # type: ignore
print(f"  Best value: β = {best_distill_weight}")
print(f"  Best accuracy: {best_distill_accuracy:.4f}")
print(f"  Insight: {'Distillation helps!' if best_distill_weight > 0 else 'No distillation benefit'}")
print(f"  Recommendation: Use distill_weight={best_distill_weight} for final training")

print("\nDISTILLATION TEMPERATURE:")
best_temp_idx = temp_df['accuracy'].idxmax()
best_temperature: float = temp_df.loc[best_temp_idx, 'temperature']  # type: ignore
best_temp_accuracy: float = temp_df.loc[best_temp_idx, 'accuracy']  # type: ignore
print(f"  Best value: T = {best_temperature}")
print(f"  Best accuracy: {best_temp_accuracy:.4f}")
print(f"  Recommendation: Use temperature={best_temperature} for distillation")

print("\nTRAINING DATA SIZE:")
print("  Performance scaling with distillation:")
for size_row in datasize_df.itertuples(index=False):
    print(f"    {size_row.n_samples:4d} samples → Accuracy: {size_row.accuracy:.4f}")

# Extract values for calculations. Rows are addressed by position (.iloc) so
# this does not depend on the frame's index labels, and the "mid" row is the
# middle of however many fractions were run (row 1 for the three-fraction sweep).
size_accuracy = datasize_df['accuracy']
mid_position = len(datasize_df) // 2
accuracy_first: float = size_accuracy.iloc[0]  # type: ignore
accuracy_mid: float = size_accuracy.iloc[mid_position]  # type: ignore
accuracy_last: float = size_accuracy.iloc[-1]  # type: ignore
improvement = accuracy_last - accuracy_first
# Share of the largest-data accuracy already reached at the middle data size;
# falls back to 0.0 to avoid dividing by zero.
data_efficiency = accuracy_mid / accuracy_last if accuracy_last != 0 else 0.0
print(f"  Recommendation: {'More data helps significantly' if improvement > 0.1 else 'Diminishing returns - distillation may compensate for less data'}")

print("\n" + "="*70)
print("KEY FINDINGS FOR KNOWLEDGE DISTILLATION:")
print("="*70)
print(f"1. Optimal distillation weight (β): {best_distill_weight}")
print(f"2. Optimal temperature: {best_temperature}")
print(f"3. Data efficiency: {'High' if data_efficiency > 0.9 else 'Moderate'}")
print("\nRecommended configuration for notebook 06 (Training Loop):")
print("   - ce_weight (α) = 1.0")
print(f"   - distill_weight (β) = {best_distill_weight}")
print(f"   - temperature = {best_temperature}")
print("   - Teacher: google/flan-t5-xl")
print("   - Student: google/flan-t5-small")

In [None]:
# Save all results to CSV for later analysis.
# The file-name -> frame mapping is the single source for both the writes and
# the printed file list, so the two cannot drift apart.
output_dir = Path("../experiments/ablation_studies")
output_dir.mkdir(parents=True, exist_ok=True)

result_tables = {
    "distillation_weight_results.csv": distill_df,
    "temperature_results.csv": temp_df,
    "data_size_results.csv": datasize_df,
}
for file_name, table in result_tables.items():
    table.to_csv(output_dir / file_name, index=False)

print(f"✓ All ablation results saved to {output_dir}")
print("\n📁 Files saved:")
for file_name in result_tables:
    print(f"   - {file_name}")