# Baseline Training (No Distillation)
## Reasoning Distillation Project

This notebook trains the **same student model WITHOUT knowledge distillation** to serve as a baseline for comparison.

### Training Setup:
```
Loss = α·CE(student, labels) + β·KL(student||teacher)
     = 1.0·CE + 0.0·KL
     = CE (standard supervised learning)
```

### Purpose:
- Train student model with **β = 0.0** (no teacher guidance)
- Use **the same hyperparameters** as notebook 06 (distillation training)
- Provide a **fair baseline** to measure the value of knowledge distillation

### Key Differences from Notebook 06:
| Parameter | Notebook 06 (Distillation) | This Notebook (Baseline) |
|-----------|---------------------------|-------------------------|
| β (distill_weight) | 0.5 | **0.0** |
| Teacher guidance | Yes | **No** |
| Loss function | CE + KL | **CE only** |

In [None]:
# Setup: silence warnings, make the project's `src` package importable, and
# import every library/module used by the cells below.
import sys
from pathlib import Path
import warnings
# Suppress ALL warnings to keep notebook output clean (note: this also hides
# deprecation/runtime warnings that could be useful when debugging).
warnings.filterwarnings('ignore')

# Add project root to path
# Path.cwd().parent assumes the notebook runs from a direct subdirectory of
# the project root (e.g. a notebooks/ folder) — TODO confirm.
project_root = Path.cwd().parent
# Prepend so the local `src` package takes priority over any installed copy.
sys.path.insert(0, str(project_root))

# Imports
# NOTE(review): np, pd, sns and pprint are not used by any cell in this notebook.
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pprint import pprint
import time

# Project modules: data loading/preprocessing, the student/teacher model
# wrappers, the distillation loss strategy, and the training loop.
from src.data.data_loader import TeacherDataLoader
from src.data.preprocessor import ReasoningPreprocessor, PreprocessConfig
from src.data.dataset import ESNLIDataset, create_dataloaders

from src.models.student import StudentModel, StudentConfig
from src.models.teacher import FlanT5Teacher, TeacherConfig

from src.training.distillation import (
    DistillationConfig,
    TokenLevelDistillation,
)

from src.training.trainer import Trainer, TrainingConfig

In [None]:
# Choose where models/tensors live: a CUDA GPU when one is visible, else the CPU.
# The string form is passed to the model configs in later cells.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## 1. Load Dataset

Use the **same dataset configuration** as notebook 06 for a fair comparison.

In [None]:
# Load e-SNLI and keep fixed-size prefixes of the train/validation splits
# (the same subset sizes as the distillation notebook, so runs are comparable).
banner = "=" * 70
print(banner)
print("LOADING DATASET FOR BASELINE TRAINING")
print(banner)

loader = TeacherDataLoader()
esnli_data = loader.load_esnli()

n_train, n_val = 50000, 5000
train_subset = esnli_data['train'].select(range(n_train))
val_subset = esnli_data['validation'].select(range(n_val))

print(f"\n✓ Train samples: {len(train_subset)}")
print(f"✓ Val samples: {len(val_subset)}")

In [None]:
# Tokenization settings: same checkpoint and max lengths as the student model.
preprocess_config = PreprocessConfig(
    model_name="google/flan-t5-small",
    max_source_length=128,
    max_target_length=64,
)
preprocessor = ReasoningPreprocessor(preprocess_config)

# Wrap each split (train first, then validation) with caching enabled.
train_dataset, val_dataset = (
    ESNLIDataset(split, preprocessor, use_cache=True)
    for split in (train_subset, val_subset)
)

# Pad id handed to the loader factory — presumably used when collating
# variable-length sequences into a batch.
pad_id = preprocessor.tokenizer.pad_token_id
train_loader, val_loader = create_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=32,
    num_workers=4,
    pad_token_id=pad_id,
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 2. Load Teacher Model

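We still need to load the teacher model for the distillation interface, but with **β = 0.0**, the KL loss will be multiplied by zero, effectively disabling teacher guidance. Note that, depending on how `TokenLevelDistillation` is implemented, the teacher may still run a forward pass on every batch. In that case it costs GPU memory and time even though it does not affect the loss.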

In [None]:
# Load the FLAN-T5-XL teacher. With β = 0.0 it is not meant to shape the loss;
# it is constructed because TokenLevelDistillation takes a `teacher_model`.
# NOTE(review): if the strategy still runs teacher forward passes when
# distill_weight == 0, this costs memory/compute without changing the loss — confirm.
print("=" * 70)
print("LOADING TEACHER MODEL (for interface only, β=0.0)")
print("=" * 70)

on_gpu = device == "cuda"
teacher_config = TeacherConfig(
    model_name="google/flan-t5-xl",
    device=device,
    use_fp16=on_gpu,  # half precision only when running on a GPU
    max_source_length=128,
    max_target_length=64,
)

print(f"Loading {teacher_config.model_name}...")
teacher = FlanT5Teacher(teacher_config)

print("\n✓ Teacher loaded (will NOT be used for training, β=0.0)")
print(f"  Parameters: {teacher.count_parameters():,}")

## 3. Initialize Student Model

In [None]:
# Initialize the FLAN-T5-small student with the same configuration as notebook 06.
rule = "=" * 70
print(rule)
print("INITIALIZING STUDENT MODEL")
print(rule)

student_config = StudentConfig(
    model_name="google/flan-t5-small",
    device=device,
    max_source_length=128,
    max_target_length=64,
)
student = StudentModel(student_config)

print("\n✓ Student loaded!")
print(f"  Parameters: {student.count_parameters():,}")
footprint = student.get_memory_footprint()
print(f"  Memory: {footprint['total_mb']:.2f} MB")

## 4. Configure Baseline Training (β = 0.0)

**KEY DIFFERENCE**: `distill_weight = 0.0` means NO knowledge distillation.

```
Loss = 1.0·CE + 0.0·KL = CE (standard supervised learning)
```

In [None]:
# Configure Baseline Training (NO DISTILLATION): full-weight CE, zero-weight KL.
rule = "=" * 70
print(rule)
print("CONFIGURING BASELINE TRAINING (β = 0.0)")
print(rule)

baseline_distill_config = DistillationConfig(
    distillation_type="token_level",
    ce_weight=1.0,        # α: weight of the supervised cross-entropy term
    distill_weight=0.0,   # β: weight of the KL term; 0.0 disables teacher guidance
    temperature=2.0,      # per the original note, not used when β = 0.0
    label_smoothing=0.0,
)

# The strategy still receives the teacher because its constructor takes one;
# per the notebook's description, the KL term is scaled by β = 0 and drops out.
baseline_strategy = TokenLevelDistillation(
    config=baseline_distill_config,
    teacher_model=teacher,
)

cfg = baseline_distill_config
for line in (
    "\n✓ Baseline Training configured!",
    f"  Loss = {cfg.ce_weight}·CE + {cfg.distill_weight}·KL",
    "  Loss = CE only (standard supervised learning)",
    "  ⚠️  Teacher guidance: DISABLED (β=0.0)",
):
    print(line)

## 5. Initialize Trainer

Use **the same training hyperparameters** as notebook 06 for a fair comparison; only the output directory differs.

In [None]:
# Create training configuration - SAME as notebook 06.
# Every value except output_dir must match notebook 06 so that the only
# experimental variable is β (distill_weight).
import math

print("=" * 70)
print("INITIALIZING TRAINER (SAME config as distillation training)")
print("=" * 70)

training_config = TrainingConfig(
    num_epochs=7,                          # SAME
    learning_rate=5e-5,                    # SAME
    warmup_steps=1200,                     # SAME
    eval_steps=1000,                       # SAME
    save_steps=1000,                       # SAME
    logging_steps=50,                      # SAME
    output_dir="../experiments/baseline",  # DIFFERENT: separate output dir
    eval_strategy="steps",                 # SAME
    save_strategy="steps",                 # SAME
    save_total_limit=3,                    # SAME
    early_stopping_patience=5,             # SAME
    early_stopping_threshold=0.001,        # SAME
    lr_scheduler_type="cosine",            # SAME
    gradient_accumulation_steps=2,         # SAME
    max_grad_norm=1.0,                     # SAME
    fp16=False,                            # SAME
    seed=42                                # SAME seed for reproducibility
)

# Create trainer with the baseline (CE-only, β = 0.0) strategy.
trainer = Trainer(
    model=student,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
    distillation_strategy=baseline_strategy,
    config=training_config
)

# Step accounting. Each optimizer update consumes `gradient_accumulation_steps`
# batches, so there are fewer updates than batches. Both numbers are upper
# bounds because early stopping may end training sooner.
# NOTE(review): whether warmup_steps/eval_steps count batches or optimizer
# updates depends on Trainer — confirm. The update count is approximate
# because trainers differ in how they handle a trailing partial accumulation.
batches_per_epoch = len(train_loader)
accum_steps = training_config.gradient_accumulation_steps
total_batches = batches_per_epoch * training_config.num_epochs
optimizer_steps = math.ceil(batches_per_epoch / accum_steps) * training_config.num_epochs

print("\n✓ Trainer initialized!")
print(f"  Total training batches (max): {total_batches}")
print(f"  Optimizer steps (max, ≈): {optimizer_steps} (gradient_accumulation_steps={accum_steps})")
print(f"  Training type: BASELINE (no distillation)")
print(f"  Loss: CE only (β=0.0)")
print(f"  Output: {training_config.output_dir}")

## 6. Run Baseline Training

In [None]:
# Train Baseline Model (CE-only loss; no teacher guidance).
print("\n" + "=" * 70)
print("STARTING BASELINE TRAINING (NO DISTILLATION)")
print("=" * 70)
print(f"\nStudent: {student_config.model_name}")
print(f"Loss: {baseline_distill_config.ce_weight}·CE + {baseline_distill_config.distill_weight}·KL")
print(f"Loss: CE only (standard supervised learning)")
print(f"Teacher guidance: DISABLED")
print("=" * 70)

# perf_counter is monotonic and intended for measuring elapsed intervals.
start_time = time.perf_counter()

history = trainer.train()

training_time = time.perf_counter() - start_time  # seconds

# Early stopping can end training before num_epochs, so average over the
# epochs actually recorded in the history (one train_history entry per epoch).
# Fall back to the configured count if the history is empty.
epochs_run = len(history['train_history']) or training_config.num_epochs

print(f"\n✓ Baseline training completed in {training_time/60:.1f} minutes")
print(f"✓ Average time per epoch: {training_time / epochs_run:.1f}s ({epochs_run} epoch(s) run)")

## 7. Analyze Training History

In [None]:
# Print per-epoch training loss and the first few evaluation losses.
print("=" * 70)
print("BASELINE TRAINING HISTORY")
print("=" * 70)

train_history = history['train_history']
eval_history = history['eval_history']

print(f"\nTrain history ({len(train_history)} epochs):")
for epoch_no, record in enumerate(train_history, start=1):
    print(f"  Epoch {epoch_no}: loss={record['loss']:.4f}")

max_shown = 5  # only the first few evaluations are listed individually
print(f"\nEval history ({len(eval_history)} evaluations):")
for eval_no, record in enumerate(eval_history[:max_shown], start=1):
    print(f"  Eval {eval_no}: eval_loss={record['eval_loss']:.4f}")

hidden = len(eval_history) - max_shown
if hidden > 0:
    print(f"  ... and {hidden} more")

In [None]:
# Plot training curves: per-epoch train loss (left) and loss at each
# evaluation (right), then report the relative drop in train loss.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_train, ax_eval = axes

train_losses = [record['loss'] for record in train_history]
epochs = range(1, len(train_losses) + 1)

ax_train.plot(epochs, train_losses, marker='o', linewidth=2, markersize=8,
              color='#e74c3c', label='Train Loss (Baseline)')
ax_train.set(xlabel='Epoch', ylabel='Loss',
             title='Baseline Training Loss (β=0.0, No Distillation)')
ax_train.legend()
ax_train.grid(True, alpha=0.3)

# The right panel stays empty when no evaluation was recorded.
if eval_history:
    eval_losses = [record['eval_loss'] for record in eval_history]
    eval_steps = range(1, len(eval_losses) + 1)
    ax_eval.plot(eval_steps, eval_losses, marker='s', linewidth=2, markersize=8,
                 color='#3498db', label='Eval Loss (Baseline)')
    ax_eval.set(xlabel='Evaluation Step', ylabel='Loss',
                title='Baseline Evaluation Loss')
    ax_eval.legend()
    ax_eval.grid(True, alpha=0.3)

plt.suptitle('Baseline Training (No Knowledge Distillation)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Relative change from the first to the last recorded epoch (needs >= 2 epochs).
if len(train_losses) >= 2:
    first_loss, last_loss = train_losses[0], train_losses[-1]
    pct_drop = (first_loss - last_loss) / first_loss * 100
    print(f"\nLoss reduction: {pct_drop:.2f}%")
    print(f"Initial loss: {first_loss:.4f}")
    print(f"Final loss: {last_loss:.4f}")

## 8. Test Generation

In [None]:
# Test generation after baseline training: decode a few validation examples
# and show input / reference / prediction side by side.
print("=" * 70)
print("TESTING GENERATION (BASELINE MODEL)")
print("=" * 70)

student.model.eval()

# First validation batch. Move tensor entries to the model's device and pass
# any non-tensor entries (if the collate function emits them) through unchanged.
val_batch = next(iter(val_loader))
val_batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in val_batch.items()}

# Show up to 3 samples; a short (e.g. final) batch may contain fewer rows.
num_samples = min(3, val_batch['input_ids'].size(0))

# Generate predictions (generation length matches the student's target length).
with torch.no_grad():
    generated_ids = student.generate(
        input_ids=val_batch['input_ids'][:num_samples],
        attention_mask=val_batch['attention_mask'][:num_samples],
        max_length=student_config.max_target_length,
        num_beams=4
    )

# Decode
predictions = student.decode_batch(generated_ids)
inputs = student.decode_batch(val_batch['input_ids'][:num_samples])

# -100 is presumably the loss ignore-index used for label padding; swap in the
# tokenizer's pad id so the labels can be decoded. Clone to leave the batch intact.
labels = val_batch['labels'][:num_samples].clone()
labels[labels == -100] = student.tokenizer.pad_token_id
ground_truths = student.decode_batch(labels)

# Display
for i, (source, reference, prediction) in enumerate(zip(inputs, ground_truths, predictions), start=1):
    print(f"\n{'='*70}")
    print(f"SAMPLE {i}")
    print(f"{'='*70}")
    print(f"\nInput:\n{source}")
    print(f"\nGround Truth:\n{reference}")
    print(f"\nBaseline Prediction:\n{prediction}")

## 9. Save Baseline Model

In [None]:
# List what the trainer wrote to the output directory; for subdirectories
# (e.g. checkpoints) also report their total on-disk size.
print("=" * 70)
print("SAVED BASELINE CHECKPOINTS")
print("=" * 70)

output_dir = Path(training_config.output_dir)

if not output_dir.exists():
    print("\nOutput directory not found!")
else:
    entries = sorted(output_dir.iterdir())
    print(f"\nFound {len(entries)} items in output directory:")
    for entry in entries:
        if not entry.is_dir():
            print(f"- {entry.name}")
            continue
        dir_bytes = sum(p.stat().st_size for p in entry.rglob('*') if p.is_file())
        print(f"- {entry.name} ({dir_bytes / 1e6:.2f} MB)")

In [None]:
# Save a training summary for the comparison notebook (07). Values are read
# from the live config objects rather than re-typed, so the file always
# reflects the configuration that actually ran.
import json

output_dir = Path(training_config.output_dir)
# The trainer may not have created the directory (e.g. if no checkpoint was saved).
output_dir.mkdir(parents=True, exist_ok=True)

summary = {
    'model_type': 'baseline',
    'distill_weight': baseline_distill_config.distill_weight,
    'ce_weight': baseline_distill_config.ce_weight,
    'description': 'No knowledge distillation (standard supervised learning)',
    'training_config': {
        'num_epochs': training_config.num_epochs,
        'learning_rate': training_config.learning_rate,
        # Read from the loader when it exposes batch_size; 32 matches the value
        # passed to create_dataloaders.
        'batch_size': getattr(train_loader, 'batch_size', None) or 32,
        'gradient_accumulation_steps': training_config.gradient_accumulation_steps,
        'seed': training_config.seed
    },
    # Early stopping can end training before num_epochs.
    'epochs_completed': len(train_history),
    'final_train_loss': train_history[-1]['loss'] if train_history else None,
    'final_eval_loss': eval_history[-1]['eval_loss'] if eval_history else None,
    'training_time_minutes': training_time / 60
}

summary_path = output_dir / 'training_summary.json'
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print(f"\n✓ Training summary saved to: {summary_path}")
print(f"\nSummary:")
print(json.dumps(summary, indent=2))

In [None]:
# Final report. Confirm that the best-model directory exists before claiming
# it was saved; its presence depends on the Trainer's saving behavior.
print("=" * 70)
print("BASELINE TRAINING COMPLETE")
print("=" * 70)

best_model_dir = Path(training_config.output_dir) / "best_model"
if best_model_dir.exists():
    print(f"\n✓ Model saved to: {best_model_dir}")
else:
    print(f"\n⚠️  Expected best model at {best_model_dir}, but it was not found — inspect the output directory.")
print(f"✓ Training type: Baseline (β=0.0, no distillation)")
print(f"✓ Training time: {training_time/60:.1f} minutes")
print(f"\n→ Now run notebook 07 to compare with distilled model!")