# Ablation Study: Model Comparison

Compare all 4 ablation variants side-by-side:
- **V0**: CTC-only baseline (no LoRA)
- **V1**: LoRA base (r=8, alpha=16)
- **V2**: LoRA higher capacity (r=16, alpha=32)
- **V3**: LoRA lower capacity (r=4, alpha=8)


## Installation

In [None]:
%pip install -q matplotlib numpy pandas seaborn


## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import json

# Global plotting defaults, applied once for every figure created in this notebook.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a recent
# matplotlib release — confirm against the installed version.
plt.style.use('seaborn-v0_8-darkgrid')
# Default colour cycle; the variant plots below pass explicit colours instead.
sns.set_palette("husl")

# Confirmation that the import/setup cell ran to completion.
print("Libraries imported successfully")


## Load Training Results

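Load the training histories produced by the 4 training notebooks. Each history is read from `trainer_state.json` inside that variant's output directory, so make sure you've run all training notebooks first.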


In [None]:
# Registry of ablation variants, in display order.
# Each entry maps a variant key to:
#   output_dir - directory the corresponding training run wrote to
#   label      - legend / table label
#   color      - hex colour used for this variant in every plot
variants = {
    key: {'output_dir': output_dir, 'label': label, 'color': color}
    for key, output_dir, label, color in (
        ('V0_CTC_Only', 'wav2vec2_ctc_only_baseline', 'V0: CTC-Only', '#1f77b4'),
        ('V1_Base', 'wav2vec2_lora_v1', 'V1: LoRA (r=8)', '#ff7f0e'),
        ('V2_Higher', 'wav2vec2_lora_v2', 'V2: LoRA (r=16)', '#2ca02c'),
        ('V3_Lower', 'wav2vec2_lora_v3', 'V3: LoRA (r=4)', '#d62728'),
    )
}

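# Load training histories from trainer state files.
# Note: load_training_history() reads <output_dir>/trainer_state.json and expects a JSON
# object with a 'log_history' key (e.g. as written by trainer.save_state() after training).
# Alternatively, take trainer.state.log_history directly from trainer objects if available.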

def load_training_history(output_dir):
    """Load the logged training history of one run from its saved trainer state.

    The state is looked up in this order:

    1. ``<output_dir>/trainer_state.json``
    2. ``<output_dir>/checkpoint-<step>/trainer_state.json``, newest step first
       (presumably the layout written by the Hugging Face ``Trainer`` when it
       saves checkpoints — confirm against the training notebooks).

    Args:
        output_dir: Directory (``str`` or ``Path``) the training run wrote to.

    Returns:
        The ``log_history`` list stored in the first state file found, an empty
        list if that file has no ``log_history`` key, or ``None`` if no state
        file exists at all.
    """
    root = Path(output_dir)

    def checkpoint_step(path):
        # Order checkpoints by their numeric step so that 'checkpoint-1000'
        # ranks above 'checkpoint-500'; a plain string sort would not.
        suffix = path.name.rpartition('-')[2]
        return int(suffix) if suffix.isdigit() else -1

    checkpoints = sorted(
        (path for path in root.glob('checkpoint-*') if path.is_dir()),
        key=checkpoint_step,
        reverse=True,
    )
    candidates = [root / 'trainer_state.json']
    candidates.extend(checkpoint / 'trainer_state.json' for checkpoint in checkpoints)

    for state_file in candidates:
        if state_file.is_file():
            # Explicit encoding: the platform default is not always UTF-8.
            with open(state_file, 'r', encoding='utf-8') as f:
                state = json.load(f)
            return state.get('log_history', [])
    return None

# Usage notes for this template. As an alternative to the state files, the
# histories can be taken from trainer objects that are still in memory after
# running the training notebooks.
print(
    "\nTo use this notebook:",
    "1. Run all 4 training notebooks first",
    "2. Save trainer.state.log_history to JSON files, OR",
    "3. Load trainer objects directly from the training notebooks",
    "\nFor now, this is a template. Update the data loading section based on your setup.",
    sep="\n",
)


## Comparison Visualizations

Create side-by-side comparison plots for all variants.


In [None]:
# Example: Create comparison plots
# This is a template - update with actual data

def _logs_with_metric(history, metric_key, exclude_key=None):
    """Return the entries of `history` that contain `metric_key` and lack `exclude_key`."""
    # A run without saved state may be stored as None; treat it as empty
    # instead of failing on iteration.
    return [
        log for log in (history or [])
        if metric_key in log and (exclude_key is None or exclude_key not in log)
    ]


def _draw_metric_panel(ax, all_histories, specs, metric_key, ylabel, title, *,
                       exclude_key=None, marker=None, headline=False,
                       zero_floor=False):
    """Plot `metric_key` against training step for every variant and style `ax`."""
    for variant_name, variant_info in specs.items():
        logs = _logs_with_metric(all_histories.get(variant_name), metric_key, exclude_key)
        if not logs:
            continue
        point_style = {'marker': marker, 'markersize': 4} if marker else {}
        ax.plot(
            [log['step'] for log in logs],
            [log[metric_key] for log in logs],
            label=variant_info['label'],
            linewidth=2.5,
            color=variant_info['color'],
            alpha=0.8,
            **point_style,
        )

    # The full-width headline panel uses slightly larger fonts than the rest.
    label_size, title_size, legend_size = (14, 16, 12) if headline else (13, 15, 11)
    title_extra = {'pad': 15} if headline else {}
    ax.set_xlabel('Training Steps', fontsize=label_size, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=label_size, fontweight='bold')
    ax.set_title(title, fontsize=title_size, fontweight='bold', **title_extra)
    # Only build a legend when something was drawn; an empty legend just
    # triggers a matplotlib warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=legend_size, loc='best')
    ax.grid(True, alpha=0.3, linestyle='--')
    if zero_floor:
        ax.set_ylim(bottom=0)
    ax.set_facecolor('#f8f9fa')


def create_comparison_plots(all_histories, variant_specs=None,
                            save_path='ablation_study_comparison.png'):
    """Create, save and show a multi-panel comparison of all ablation variants.

    Panels: training loss (full width), validation loss, WER, CER, and a bar
    chart of each variant's final WER (its last logged ``eval_wer``).

    Args:
        all_histories: Mapping from variant key to that run's log history
            (a list of dicts with a ``'step'`` key plus metric keys such as
            ``'loss'``, ``'eval_loss'``, ``'eval_wer'``, ``'eval_cer'``).
            Variants that are absent, ``None`` or empty are skipped.
        variant_specs: Optional mapping from variant key to a dict with
            ``'label'`` and ``'color'``. Defaults to the module-level
            ``variants`` registry.
        save_path: Where the PNG is written.

    Returns:
        None. The figure is written to `save_path` and then displayed.
    """
    specs = variants if variant_specs is None else variant_specs

    fig = plt.figure(figsize=(18, 12))
    gs = fig.add_gridspec(3, 2, hspace=0.35, wspace=0.3)

    # Plot 1: training loss. Entries that also carry 'eval_loss' are
    # evaluation records and are excluded.
    _draw_metric_panel(
        fig.add_subplot(gs[0, :]), all_histories, specs, 'loss',
        'Training Loss', 'Training Loss Comparison',
        exclude_key='eval_loss', headline=True,
    )
    # Plot 2: validation loss
    _draw_metric_panel(
        fig.add_subplot(gs[1, 0]), all_histories, specs, 'eval_loss',
        'Validation Loss', 'Validation Loss Comparison', marker='o',
    )
    # Plot 3: WER
    _draw_metric_panel(
        fig.add_subplot(gs[1, 1]), all_histories, specs, 'eval_wer',
        'Word Error Rate', 'WER Comparison', marker='s', zero_floor=True,
    )
    # Plot 4: CER
    _draw_metric_panel(
        fig.add_subplot(gs[2, 0]), all_histories, specs, 'eval_cer',
        'Character Error Rate', 'CER Comparison', marker='^', zero_floor=True,
    )

    # Plot 5: final WER bar chart. Labels, values and colours are collected
    # together so they cannot get out of step with each other.
    ax5 = fig.add_subplot(gs[2, 1])
    bar_labels, bar_values, bar_colors = [], [], []
    for variant_name, variant_info in specs.items():
        wer_logs = _logs_with_metric(all_histories.get(variant_name), 'eval_wer')
        if wer_logs:
            bar_labels.append(variant_info['label'])
            bar_values.append(wer_logs[-1]['eval_wer'])
            bar_colors.append(variant_info['color'])

    ax5.set_ylabel('Final WER', fontsize=13, fontweight='bold')
    ax5.set_title('Final WER Comparison', fontsize=15, fontweight='bold')
    ax5.grid(True, alpha=0.3, axis='y', linestyle='--')
    ax5.set_facecolor('#f8f9fa')
    if bar_labels:
        bars = ax5.bar(bar_labels, bar_values, color=bar_colors, alpha=0.7,
                       edgecolor='black', linewidth=1.5)
        # Add value labels on bars
        for bar in bars:
            height = bar.get_height()
            ax5.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{height:.4f}', ha='center', va='bottom',
                     fontsize=11, fontweight='bold')

    fig.suptitle('Ablation Study: Model Comparison', fontsize=20, fontweight='bold', y=0.995)
    fig.tight_layout(rect=[0, 0, 1, 0.99])

    # Save before showing: some backends release the figure once it has been
    # shown, so saving first is the order that works everywhere.
    fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"\nComparison plot saved to: {save_path}")
    plt.show()

# Example usage (update with actual data).
# The keys must match the keys of `variants`: create_comparison_plots() looks
# each variant's history up by that key.
# NOTE: load_training_history() returns None when no trainer_state.json is
# found, so drop missing runs before plotting.
# all_histories = {
#     'V0_CTC_Only': load_training_history('wav2vec2_ctc_only_baseline'),
#     'V1_Base': load_training_history('wav2vec2_lora_v1'),
#     'V2_Higher': load_training_history('wav2vec2_lora_v2'),
#     'V3_Lower': load_training_history('wav2vec2_lora_v3')
# }
# all_histories = {k: v for k, v in all_histories.items() if v is not None}
# create_comparison_plots(all_histories)

print("Comparison plotting function ready. Load your training histories and call create_comparison_plots()")


## Summary Table

Create a summary table comparing final metrics across all variants.


In [None]:
# Create summary table
def _metric_series(history, metric_key, exclude_key=None):
    """Return the non-null values of `metric_key` from `history`, in log order."""
    return [
        log[metric_key] for log in history
        if log.get(metric_key) is not None
        and (exclude_key is None or exclude_key not in log)
    ]


def create_summary_table(all_histories, variant_specs=None,
                         csv_path='ablation_study_summary.csv'):
    """Build, print and save a table of final and best metrics per variant.

    "Final" is the last logged value of a metric and "Best" is its minimum
    over the run. A score of exactly 0.0 counts as a valid value.

    Args:
        all_histories: Mapping from variant key to that run's log history
            (list of dicts with keys such as ``'loss'``, ``'eval_loss'``,
            ``'eval_wer'``, ``'eval_cer'``). Variants absent from the mapping
            get no row; a ``None`` or empty history gets a row of missing
            values.
        variant_specs: Optional mapping from variant key to a dict with a
            ``'label'`` entry. Defaults to the module-level ``variants``.
        csv_path: Where the CSV copy of the table is written.

    Returns:
        pandas.DataFrame with one row per variant and a fixed set of columns.
    """
    specs = variants if variant_specs is None else variant_specs
    # Fixed column set so the table has the same shape even when a variant
    # (or every variant) has no evaluation logs.
    columns = [
        'Variant',
        'Final Train Loss', 'Final Val Loss', 'Final WER', 'Final CER',
        'Best Val Loss', 'Best WER', 'Best CER',
    ]

    summary_data = []
    for variant_name, variant_info in specs.items():
        if variant_name not in all_histories:
            continue
        # A run without saved state may be stored as None.
        history = all_histories[variant_name] or []

        # Entries that also carry 'eval_loss' are evaluation records, not
        # training-loss records.
        train_losses = _metric_series(history, 'loss', exclude_key='eval_loss')
        val_losses = _metric_series(history, 'eval_loss')
        wers = _metric_series(history, 'eval_wer')
        cers = _metric_series(history, 'eval_cer')

        summary_data.append({
            'Variant': variant_info['label'],
            'Final Train Loss': train_losses[-1] if train_losses else None,
            'Final Val Loss': val_losses[-1] if val_losses else None,
            'Final WER': wers[-1] if wers else None,
            'Final CER': cers[-1] if cers else None,
            'Best Val Loss': min(val_losses) if val_losses else None,
            'Best WER': min(wers) if wers else None,
            'Best CER': min(cers) if cers else None,
        })

    df = pd.DataFrame(summary_data, columns=columns)

    # Display formatted table
    rule = "=" * 100
    print("\n" + rule)
    print("ABLATION STUDY SUMMARY")
    print(rule)
    print(df.to_string(index=False))
    print(rule)

    # Save to CSV
    df.to_csv(csv_path, index=False)
    print(f"\nSummary saved to: {csv_path}")

    return df

# Example usage (`all_histories` maps the keys of `variants` to log-history
# lists, the same mapping that is passed to create_comparison_plots()):
# summary_df = create_summary_table(all_histories)

print("Summary table function ready. Load your training histories and call create_summary_table()")
