# Linear Probe Results Analysis

This notebook analyzes the linear probing results from SimCLR experiments on CIFAR-10.

**Experiment Setup:**
- 3 Encoders: ResNet18, ViT (4x4 patches), MLP
- 3 Augmentation Modes: all, crop, all-no-crop
- 5 runs per configuration
- Linear probe: 100 epochs, SGD with lr=0.1

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Global plot styling shared by every figure below: matplotlib's
# seaborn-v0_8 whitegrid sheet plus seaborn's evenly-spaced HUSL palette.
plt.style.use(style='seaborn-v0_8-whitegrid')
sns.set_palette(palette='husl')

## 1. Load Results

In [None]:
# Configuration
RESULTS_DIR = './experiment_results'  # Change to 'experiment_results_adjusted' for adjusted InfoNCE
ENCODERS = ['resnet', 'vit', 'mlp']
AUG_MODES = ['all', 'crop', 'all-no-crop']
NUM_RUNS = 5

# Column order of the per-run table. Passing it to the DataFrame constructor
# keeps this schema even when no summary files are found, so later cells can
# still reference columns such as df['best_test_acc'] without a KeyError.
RESULT_COLUMNS = [
    'encoder', 'aug_mode', 'run',
    'best_test_acc', 'final_test_acc', 'final_train_acc', 'total_epochs',
]

# Collect results: one record per run directory
# RESULTS_DIR/<encoder>/<aug_mode>/run_<k>/ that holds a readable
# linear_probe_summary.json. Metrics absent from a summary are stored as None.
results = []
warnings = []  # human-readable notes about incomplete configurations / unreadable runs

for encoder in ENCODERS:
    for aug_mode in AUG_MODES:
        completed_runs = 0
        for run in range(1, NUM_RUNS + 1):
            exp_dir = Path(RESULTS_DIR) / encoder / aug_mode / f'run_{run}'
            summary_file = exp_dir / 'linear_probe_summary.json'

            if not summary_file.exists():
                continue

            try:
                with open(summary_file, 'r', encoding='utf-8') as f:
                    summary = json.load(f)
            except (OSError, json.JSONDecodeError) as err:
                # A run that is still writing (or crashed mid-write) can leave a
                # truncated file; report it instead of aborting the whole load.
                warnings.append(f"⚠️ {encoder}/{aug_mode}/run_{run}: unreadable summary file ({err})")
                continue

            results.append({
                'encoder': encoder,
                'aug_mode': aug_mode,
                'run': run,
                'best_test_acc': summary.get('best_test_acc'),
                'final_test_acc': summary.get('final_test_acc'),
                'final_train_acc': summary.get('final_train_acc'),
                'total_epochs': summary.get('total_epochs'),
            })
            completed_runs += 1

        # Track warnings for incomplete configurations
        if completed_runs == 0:
            warnings.append(f"⚠️ {encoder}/{aug_mode}: NO completed runs (skipping)")
        elif completed_runs < NUM_RUNS:
            warnings.append(f"⚠️ {encoder}/{aug_mode}: Only {completed_runs}/{NUM_RUNS} runs completed")

# Create DataFrame (explicit columns so an empty result still has a schema)
df = pd.DataFrame(results, columns=RESULT_COLUMNS)

print(f"Loaded {len(df)} experiment results")
print(f"\nResults directory: {RESULTS_DIR}")
print()

# Show warnings
if warnings:
    print("=" * 60)
    print("WARNINGS - Incomplete Data:")
    print("=" * 60)
    for w in warnings:
        print(w)
    print("=" * 60)
else:
    # Report the configured run count rather than a hard-coded "5/5"
    print(f"✓ All configurations have complete data ({NUM_RUNS}/{NUM_RUNS} runs each)")

In [None]:
# Preview up to the first 20 per-run records, or say so when nothing loaded
if len(df) == 0:
    print("No data to display!")
else:
    display(df.head(20))

## 2. Summary Statistics

In [None]:
if len(df) == 0:
    print("No data available for summary statistics.")
else:
    # Per (encoder, aug_mode): mean/std/count of best test accuracy and
    # mean/std of final test accuracy, built directly with display-ready
    # column names via pandas named aggregation, rounded to 2 decimals.
    named_aggs = {
        'Best Acc (Mean)': ('best_test_acc', 'mean'),
        'Best Acc (Std)': ('best_test_acc', 'std'),
        'N Runs': ('best_test_acc', 'count'),
        'Final Acc (Mean)': ('final_test_acc', 'mean'),
        'Final Acc (Std)': ('final_test_acc', 'std'),
    }
    summary_stats = df.groupby(['encoder', 'aug_mode']).agg(**named_aggs).round(2)

    print("Summary Statistics (Best Test Accuracy):")
    display(summary_stats)

## 3. Heatmap: Encoder vs Augmentation Mode

In [None]:
if len(df) > 0:
    # Pivot: rows = encoders, columns = augmentation modes,
    # values = mean best test accuracy over the available runs.
    heatmap_data = df.groupby(['encoder', 'aug_mode'])['best_test_acc'].mean().unstack()

    # Reorder columns and rows to the configured experiment order
    col_order = [m for m in AUG_MODES if m in heatmap_data.columns]
    row_order = [e for e in ENCODERS if e in heatmap_data.index]
    heatmap_data = heatmap_data.reindex(index=row_order, columns=col_order)

    # Count runs that actually reported best_test_acc, i.e. exactly the values
    # averaged into each cell (counting 'run' would also include runs whose
    # summary lacked the metric, so n would disagree with the cell value).
    run_counts = df.groupby(['encoder', 'aug_mode'])['best_test_acc'].count().unstack()
    run_counts = run_counts.reindex(index=row_order, columns=col_order)

    # Populated cells only, so the colour scale comes from real values
    # (a true mean over cells rather than a mean of per-column means).
    cell_values = heatmap_data.to_numpy(dtype=float).ravel()
    cell_values = cell_values[~np.isnan(cell_values)]

    if cell_values.size == 0:
        # Every configuration lacks best_test_acc: nothing to colour.
        print("No data available for heatmap.")
    else:
        # Create figure
        fig, ax = plt.subplots(figsize=(10, 6))

        # Create heatmap; +/-5 points of headroom around the observed range
        sns.heatmap(
            heatmap_data,
            annot=True,
            fmt='.1f',
            cmap='RdYlGn',
            center=cell_values.mean(),
            vmin=cell_values.min() - 5,
            vmax=cell_values.max() + 5,
            ax=ax,
            cbar_kws={'label': 'Best Test Accuracy (%)'}
        )

        # Mark cells backed by fewer than NUM_RUNS runs with their run count
        for i, encoder in enumerate(row_order):
            for j, aug_mode in enumerate(col_order):
                count = run_counts.loc[encoder, aug_mode]
                count = 0 if pd.isna(count) else int(count)
                if count < NUM_RUNS:
                    ax.text(j + 0.5, i + 0.85, f'n={count}',
                            ha='center', va='center', fontsize=8, color='gray')

        ax.set_xlabel('Augmentation Mode', fontsize=12)
        ax.set_ylabel('Encoder', fontsize=12)
        ax.set_title('Linear Probe Accuracy: Encoder vs Augmentation\n(Mean Best Test Accuracy %)', fontsize=14)

        # Keep tick labels horizontal
        ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
        ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

        plt.tight_layout()
        plt.savefig('linear_probe_heatmap.png', dpi=150, bbox_inches='tight')
        plt.show()

        print("\nHeatmap saved to: linear_probe_heatmap.png")
else:
    print("No data available for heatmap.")

## 4. Bar Plot with Error Bars (±1 std across runs)

In [None]:
if len(df) > 0:
    fig, ax = plt.subplots(figsize=(12, 6))

    # Per-configuration mean and std of best test accuracy. std is NaN when a
    # configuration has a single run; draw that as a zero-length error bar.
    plot_data = df.groupby(['encoder', 'aug_mode']).agg({
        'best_test_acc': ['mean', 'std']
    }).reset_index()
    plot_data.columns = ['encoder', 'aug_mode', 'mean', 'std']
    plot_data['std'] = plot_data['std'].fillna(0)

    # Grouped bar layout: one group per augmentation mode, one bar per encoder.
    # Bar width scales with the number of encoders (0.25 for three).
    x = np.arange(len(AUG_MODES))
    width = 0.75 / len(ENCODERS)

    colors = {'resnet': '#2ecc71', 'vit': '#3498db', 'mlp': '#e74c3c'}

    for i, encoder in enumerate(ENCODERS):
        encoder_data = plot_data[plot_data['encoder'] == encoder]
        if len(encoder_data) == 0:
            continue
        # Align to AUG_MODES order. Configurations without runs become NaN, so
        # no bar is drawn (a 0-height bar would read as 0% accuracy).
        stats = encoder_data.set_index('aug_mode').reindex(AUG_MODES)

        ax.bar(x + i * width, stats['mean'].to_numpy(), width,
               label=encoder.upper(),
               yerr=stats['std'].to_numpy(),
               capsize=3,
               color=colors.get(encoder, 'gray'))

    ax.set_xlabel('Augmentation Mode', fontsize=12)
    ax.set_ylabel('Best Test Accuracy (%)', fontsize=12)
    ax.set_title('Linear Probe Accuracy by Encoder and Augmentation Mode', fontsize=14)
    # Center each tick under its group of len(ENCODERS) bars
    ax.set_xticks(x + width * (len(ENCODERS) - 1) / 2)
    ax.set_xticklabels(AUG_MODES)
    ax.legend(title='Encoder')
    ax.set_ylim(0, 100)

    # Add grid
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    ax.set_axisbelow(True)

    plt.tight_layout()
    plt.savefig('linear_probe_barplot.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\nBar plot saved to: linear_probe_barplot.png")
else:
    print("No data available for bar plot.")

## 5. Per-Run Results (Box Plot)

In [None]:
if len(df) > 0:
    # One panel per encoder. squeeze=False keeps `axes` 2-D even with a single
    # encoder, so indexing works for any len(ENCODERS) (3 panels -> 15x5 in).
    n_panels = len(ENCODERS)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), sharey=True, squeeze=False)
    axes = axes[0]

    for idx, encoder in enumerate(ENCODERS):
        panel = axes[idx]
        encoder_data = df[df['encoder'] == encoder]

        if len(encoder_data) > 0:
            # Order augmentation modes as configured, keeping only those present
            order = [m for m in AUG_MODES if m in encoder_data['aug_mode'].values]

            sns.boxplot(
                data=encoder_data,
                x='aug_mode',
                y='best_test_acc',
                order=order,
                ax=panel,
                palette='Set2'
            )

            # Overlay the individual runs
            sns.stripplot(
                data=encoder_data,
                x='aug_mode',
                y='best_test_acc',
                order=order,
                ax=panel,
                color='black',
                alpha=0.5,
                size=6
            )

        panel.set_title(f'{encoder.upper()}', fontsize=14)
        panel.set_xlabel('Augmentation Mode', fontsize=11)
        # Shared y axis: label only the leftmost panel
        if idx == 0:
            panel.set_ylabel('Best Test Accuracy (%)', fontsize=11)
        else:
            panel.set_ylabel('')
        panel.tick_params(axis='x', rotation=15)

    plt.suptitle('Distribution of Linear Probe Accuracy Across Runs', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig('linear_probe_boxplot.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\nBox plot saved to: linear_probe_boxplot.png")
else:
    print("No data available for box plot.")

## 6. Best Configuration

In [None]:
if len(df) > 0:
    # Mean best test accuracy per (encoder, aug_mode). Configurations whose runs
    # never reported best_test_acc have a NaN mean; drop them so idxmax() sees
    # a real maximum and the ranking never lists "nan%".
    mean_acc = df.groupby(['encoder', 'aug_mode'])['best_test_acc'].mean().dropna()

    if mean_acc.empty:
        print("No data available.")
    else:
        # Find best configuration
        best_config = mean_acc.idxmax()
        best_acc = mean_acc.max()

        print("=" * 60)
        print("BEST CONFIGURATION")
        print("=" * 60)
        print(f"Encoder: {best_config[0]}")
        print(f"Augmentation: {best_config[1]}")
        print(f"Mean Best Test Accuracy: {best_acc:.2f}%")
        print("=" * 60)

        # Show ranking, highest mean first
        print("\nFull Ranking (by mean best test accuracy):")
        ranking = mean_acc.sort_values(ascending=False)
        for i, ((encoder, aug), acc) in enumerate(ranking.items(), 1):
            print(f"{i}. {encoder} + {aug}: {acc:.2f}%")
else:
    print("No data available.")

## 7. Export Results

In [None]:
if len(df) == 0:
    print("No data to export.")
else:
    # Raw per-run table, one row per loaded run
    df.to_csv('linear_probe_all_results.csv', index=False)
    print("Full results exported to: linear_probe_all_results.csv")

    # Per-configuration statistics of best test accuracy, rounded to 2 decimals
    summary_export = (
        df.groupby(['encoder', 'aug_mode'])['best_test_acc']
        .agg(['mean', 'std', 'min', 'max', 'count'])
        .round(2)
    )
    summary_export.columns = ['mean', 'std', 'min', 'max', 'n_runs']
    summary_export.to_csv('linear_probe_summary.csv')
    print("Summary exported to: linear_probe_summary.csv")