# Model Comparison Analysis

This notebook aggregates results from `./runs/<variant>/` and produces comprehensive visualizations:
- Learning curves for Top-1 accuracy and validation loss
- Best Top-1 and Top-5 accuracy bar charts
- Accuracy vs training time scatter plot
- Family averages bar chart
- Per-class comparison bars for top-performing models

All outputs are saved to the `./comparison_outputs/` directory.

## 1. Import Required Libraries and Setup

In [None]:
import json
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt

# Enable inline plotting for Jupyter
%matplotlib inline

# Start from matplotlib's stock style, then make 10x6 inches the default canvas
plt.style.use("default")
plt.rcParams.update({"figure.figsize": (10, 6)})

print("✓ Libraries imported successfully!")

## 2. Define Configuration and Helper Functions

In [None]:
# Model variants to compare; the finder helpers look for each one under RUNS/<variant>/.
VARIANTS = [
    "r18_base",
    "r18_plus",
    "r34_base",
    "r34_plus",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "densenet121",
]

# Input locations (relative to the notebook's working directory).
RUNS = Path("./runs")
RESULTS_JSON = Path("experiment_results.json")

# Output location; created up front so every later cell can write into it.
OUT = Path("./comparison_outputs")
OUT.mkdir(parents=True, exist_ok=True)

def _find_run_file(variant, filename):
    """Locate ``filename`` for ``variant`` under the ``RUNS`` directory.

    Two layouts are probed, in this order:

    1. ``RUNS/<variant>/<filename>``
    2. ``RUNS/<variant>/<variant>/<filename>`` (nested run folder)

    Returns the first candidate that is an existing file, or ``None`` when
    neither is.
    """
    run_dir = RUNS / variant
    for candidate in (run_dir / filename, run_dir / variant / filename):
        # is_file() rather than exists(): a directory that happens to carry
        # this name is not something the callers can parse.
        if candidate.is_file():
            return candidate
    return None

def find_metrics_path(variant):
    """Find the metrics log file (``metrics_log.tsv``) for a given variant, or None."""
    return _find_run_file(variant, "metrics_log.tsv")

def find_summary_path(variant):
    """Find the model summary file (``model_summary.json``) for a given variant, or None."""
    return _find_run_file(variant, "model_summary.json")

def find_per_class_path(variant):
    """Find the per-class metrics file (``per_class_metrics.csv``) for a given variant, or None."""
    return _find_run_file(variant, "per_class_metrics.csv")

def bucket(v):
    """Map a variant name to the display name of its model family.

    The four ResNet variants are matched by exact name; EfficientNet and
    DenseNet variants are matched by name prefix. Anything else is "Other".
    """
    if v in ("r18_base", "r34_base"):
        return "ResNet (base)"
    if v in ("r18_plus", "r34_plus"):
        return "ResNet (plus)"

    prefix_families = (
        ("efficientnet", "EfficientNet"),
        ("densenet", "DenseNet"),
    )
    for prefix, family in prefix_families:
        if v.startswith(prefix):
            return family
    return "Other"

# Echo the configuration: number of tracked variants and the absolute path
# of the output directory, so the reader can confirm where files will land.
print(f"✓ Configuration loaded - tracking {len(VARIANTS)} variants")
print(f"✓ Output directory: {OUT.resolve()}")

## 3. Load Training Time Data

In [None]:
# Load per-variant training times if the results file is present.
# Expected shape: a JSON list of objects carrying "variant" and "training_time".
# NOTE(review): times are presumably seconds (a later cell divides by 60 and
# labels the axis "minutes") — confirm against whatever writes RESULTS_JSON.
times_by_variant = {}
if RESULTS_JSON.exists():
    try:
        with open(RESULTS_JSON, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Timing data is optional, so warn and continue instead of aborting.
        print(f"[WARN] Could not read {RESULTS_JSON}: {e}")
        data = None

    if isinstance(data, list):
        for r in data:
            if not (isinstance(r, dict) and isinstance(r.get("variant"), str) and r["variant"]):
                continue
            t = r.get("training_time")
            # Keep only real numbers: a missing/None or non-numeric time cannot be
            # plotted and would otherwise inflate the "loaded" count below.
            if isinstance(t, (int, float)) and not isinstance(t, bool):
                times_by_variant[r["variant"]] = t

print(f"✓ Training times loaded for {len(times_by_variant)} variants")
if times_by_variant:
    print("Available timing data:", list(times_by_variant.keys()))

## 4. Collect Metrics from All Model Variants

In [None]:
# Read each variant's per-epoch metrics table (tab-separated) into a dict keyed
# by variant name. Variants with no file, or an unreadable one, are reported and skipped.
per_epoch = {}
for v in VARIANTS:
    mp = find_metrics_path(v)
    if mp is None:
        print(f"[INFO] No metrics file found for {v}")
        continue
    try:
        df = pd.read_csv(mp, sep="\t")
        per_epoch[v] = df
        print(f"✓ Loaded metrics for {v}: {len(df)} epochs")
    except Exception as e:
        print(f"[WARN] Could not read metrics for {v}: {e}")

print(f"\n✓ Successfully loaded metrics for {len(per_epoch)} variants")
print("Available variants:", list(per_epoch.keys()))

## 5. Generate Top-1 Accuracy Learning Curves

In [None]:
# Top-1 learning curves: one line per variant that logged both an "epoch"
# and a "top1" column. The figure is saved only if something was drawn.
colors = plt.cm.tab10(range(len(per_epoch)))
plt.figure(figsize=(12,8))
any_plot = False

for i, (v, df) in enumerate(per_epoch.items()):
    if "epoch" not in df.columns or "top1" not in df.columns:
        continue
    plt.plot(df["epoch"], df["top1"], label=v, color=colors[i],
             linewidth=2, marker='o', markersize=4)
    any_plot = True

plt.title("Top-1 Accuracy vs Epoch", fontsize=16, fontweight='bold')
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Top-1 Accuracy (%)", fontsize=12)
plt.grid(True, alpha=0.3)

if not any_plot:
    print("No Top-1 accuracy data available for plotting")
    plt.close()
else:
    # Legend sits outside the axes on the right so it never covers the curves
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(OUT/"learning_curves_top1.png", dpi=150, bbox_inches='tight')
    plt.show()

## 6. Generate Validation Loss Learning Curves

In [None]:
# Validation-loss learning curves: one line per variant that logged both an
# "epoch" and a "val_loss" column. The figure is saved only if something was drawn.
plt.figure(figsize=(12,8))
any_plot = False
# Build the palette here instead of reusing the one from the Top-1 cell, so this
# cell runs on its own and stays in sync with the current contents of per_epoch.
# The formula is the same, so a variant keeps the same colour in both figures.
colors = plt.cm.tab10(range(len(per_epoch)))

for i, (v, df) in enumerate(per_epoch.items()):
    if "epoch" in df.columns and "val_loss" in df.columns:
        plt.plot(df["epoch"], df["val_loss"], label=v, linewidth=2,
                marker='o', markersize=4, color=colors[i])
        any_plot = True

plt.title("Validation Loss vs Epoch", fontsize=16, fontweight='bold')
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("Validation Loss", fontsize=12)
plt.grid(True, alpha=0.3)

if any_plot:
    # Legend sits outside the axes on the right so it never covers the curves
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(OUT/"learning_curves_val_loss.png", dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No validation loss data available for plotting")
    plt.close()

## 7. Create Performance Summary and Best Accuracy Bar Charts

In [None]:
# Create performance summary
rows = []
for v, df in per_epoch.items():
    row = {"variant": v}
    row["best_top1"] = float(df["top1"].max()) if "top1" in df.columns else float("nan")
    row["best_top5"] = float(df["top5"].max()) if "top5" in df.columns else float("nan")
    rows.append(row)

perf = pd.DataFrame(rows).sort_values("best_top1", ascending=False)
perf.to_csv(OUT/"summary_performance.csv", index=False)

print("Performance Summary:")
display(perf.round(2))

In [None]:
# Best Top-1 Accuracy Bar Chart (bars appear in `perf` order, i.e. best first)
if not perf.empty and perf["best_top1"].notna().any():
    plt.figure(figsize=(12,8))
    # Sample the colormap with floats spread over [0, 1]. Integer inputs are
    # treated as raw indices into its 256-entry table, so range(n) would paint
    # every bar in almost the same dark shade.
    n_bars = len(perf)
    bar_colors = plt.cm.viridis([i / max(n_bars - 1, 1) for i in range(n_bars)])
    bars = plt.bar(perf["variant"], perf["best_top1"],
                   color=bar_colors, alpha=0.8)

    # Add value labels on bars (variants without a Top-1 value get no label)
    for bar, value in zip(bars, perf["best_top1"]):
        if not pd.isna(value):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                    f'{value:.1f}%', ha='center', va='bottom', fontweight='bold')

    plt.title("Best Top-1 Accuracy by Model", fontsize=16, fontweight='bold')
    plt.xlabel("Model Variant", fontsize=12)
    plt.ylabel("Best Top-1 Accuracy (%)", fontsize=12)
    plt.xticks(rotation=45, ha="right")
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(OUT/"best_top1_bar.png", dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No Top-1 accuracy data available for bar chart")

In [None]:
# Best Top-5 Accuracy Bar Chart (only variants that logged a Top-5 value)
if not perf.empty and perf["best_top5"].notna().any():
    plt.figure(figsize=(12,8))
    valid_top5 = perf.dropna(subset=['best_top5'])
    # Sample the colormap with floats spread over [0, 1]. Integer inputs are
    # treated as raw indices into its 256-entry table, so range(n) would paint
    # every bar in almost the same dark shade.
    n_bars = len(valid_top5)
    bar_colors = plt.cm.plasma([i / max(n_bars - 1, 1) for i in range(n_bars)])
    bars = plt.bar(valid_top5["variant"], valid_top5["best_top5"],
                   color=bar_colors, alpha=0.8)

    # Add value labels on bars
    for bar, value in zip(bars, valid_top5["best_top5"]):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                f'{value:.1f}%', ha='center', va='bottom', fontweight='bold')

    plt.title("Best Top-5 Accuracy by Model", fontsize=16, fontweight='bold')
    plt.xlabel("Model Variant", fontsize=12)
    plt.ylabel("Best Top-5 Accuracy (%)", fontsize=12)
    plt.xticks(rotation=45, ha="right")
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(OUT/"best_top5_bar.png", dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("No Top-5 accuracy data available for bar chart")

## 8. Generate Accuracy vs Training Time Scatter Plot

In [None]:
# Create accuracy vs time data: one row per variant that has both a Top-1
# column and a recorded training time (converted to minutes by dividing by 60).
rows = []
for v, df in per_epoch.items():
    if "top1" not in df.columns:
        continue
    t = times_by_variant.get(v)
    if t is None:
        continue
    try:
        minutes = float(t) / 60.0
    except (TypeError, ValueError):
        # A malformed entry in the results file should not abort the analysis
        print(f"[WARN] Ignoring non-numeric training time for {v}: {t!r}")
        continue
    rows.append({"variant": v, "best_top1": float(df["top1"].max()), "minutes": minutes})

# Explicit columns keep the schema intact when no timing data exists. Without
# them sort_values raises KeyError on the empty frame and the "No timing data"
# branch below is unreachable.
trade = pd.DataFrame(rows, columns=["variant", "best_top1", "minutes"]).sort_values(
    "best_top1", ascending=False
)
trade.to_csv(OUT/"accuracy_vs_time.csv", index=False)

if not trade.empty:
    plt.figure(figsize=(10,8))
    # Points are coloured by row position in `trade`, which is sorted best-first,
    # so colour value 0 is the most accurate variant.
    scatter = plt.scatter(trade["minutes"], trade["best_top1"],
                         s=120, c=range(len(trade)), cmap='viridis', alpha=0.7)

    # Add annotations
    for _, r in trade.iterrows():
        plt.annotate(r["variant"], (r["minutes"], r["best_top1"]),
                    xytext=(5,5), textcoords="offset points",
                    fontsize=10, fontweight='bold')

    plt.title("Best Top-1 vs Training Time", fontsize=16, fontweight='bold')
    plt.xlabel("Training Time (minutes)", fontsize=12)
    plt.ylabel("Best Top-1 Accuracy (%)", fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.colorbar(scatter, label='Performance Rank')
    plt.tight_layout()
    plt.savefig(OUT/"accuracy_vs_time_scatter.png", dpi=150, bbox_inches='tight')
    plt.show()

    print("Accuracy vs Time Summary:")
    display(trade.round(2))
else:
    print("No timing data available for scatter plot")

## 9. Create Family Average Performance Analysis

In [None]:
# Create family analysis: tag each variant's best Top-1 with its model family,
# then average within each family.
fam_rows = []
for v, df in per_epoch.items():
    if "top1" not in df.columns:
        continue
    fam_rows.append({"family": bucket(v), "variant": v, "best_top1": float(df["top1"].max())})

fam = pd.DataFrame(fam_rows)
fam.to_csv(OUT/"family_raw.csv", index=False)

if fam.empty:
    print("No family data available")
else:
    # Mean best Top-1 per family, strongest family first
    avg = (
        fam.groupby("family", as_index=False)["best_top1"]
        .mean()
        .sort_values("best_top1", ascending=False)
    )
    avg.to_csv(OUT/"family_avg.csv", index=False)

    plt.figure(figsize=(10,6))
    bars = plt.bar(avg["family"], avg["best_top1"],
                   color=plt.cm.Set3(range(len(avg))), alpha=0.8)

    # Print each family's average just above its bar
    for bar, value in zip(bars, avg["best_top1"]):
        label_x = bar.get_x() + bar.get_width()/2
        label_y = bar.get_height() + 0.5
        plt.text(label_x, label_y, f'{value:.1f}%',
                 ha='center', va='bottom', fontweight='bold')

    plt.title("Average Best Top-1 by Model Family", fontsize=16, fontweight='bold')
    plt.xlabel("Family", fontsize=12)
    plt.ylabel("Average Best Top-1 (%)", fontsize=12)
    plt.xticks(rotation=15)
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(OUT/"family_averages_bar.png", dpi=150, bbox_inches='tight')
    plt.show()

    print("Family Performance Summary:")
    display(avg.round(2))

## 10. Generate Per-Class Accuracy Comparison for Top Models

In [None]:
# Find top-2 variants by best_top1 (`perf` is already sorted best-first)
top2 = list(perf["variant"].head(2).values) if not perf.empty else []
pcs_loaded = 0

print(f"Analyzing per-class metrics for top 2 models: {top2}")

for v in top2:
    p = find_per_class_path(v)
    if p is None:
        print(f"No per-class metrics file found for {v}")
        continue

    fig = None
    try:
        df = pd.read_csv(p)
        # Prefer an "accuracy" column; use "recall" as a proxy if it is absent.
        # The column is read directly rather than after set_index(first column),
        # which would drop the metric if it happened to be the first column.
        metric_col = next((c for c in ("accuracy", "recall") if c in df.columns), None)
        if metric_col is None or df.empty:
            print(f"No usable metrics found in per-class file for {v}")
            continue
        series = df[metric_col]
        metric_name = metric_col.capitalize()

        fig = plt.figure(figsize=(12,6))
        # Sample the colormap with floats spread over [0, 1]. Integer inputs are
        # treated as raw indices into its 256-entry table, so range(n) would
        # give near-identical colours for small n.
        n_classes = len(series)
        bar_colors = plt.cm.viridis([i / max(n_classes - 1, 1) for i in range(n_classes)])
        bars = plt.bar(range(n_classes), series.values,
                      color=bar_colors, alpha=0.7)

        plt.title(f"Per-Class {metric_name} — {v}", fontsize=14, fontweight='bold')
        plt.xlabel("Class (index order)", fontsize=12)
        plt.ylabel(f"{metric_name} Score", fontsize=12)
        plt.grid(axis='y', alpha=0.3)

        # Add some statistics as text. A real newline is needed here: an
        # escaped backslash would be drawn literally inside the box.
        mean_score = series.mean()
        std_score = series.std()
        plt.text(0.02, 0.98, f'Mean: {mean_score:.3f}\nStd: {std_score:.3f}',
                transform=plt.gca().transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        plt.tight_layout()
        plt.savefig(OUT/f"per_class_{v}.png", dpi=150, bbox_inches='tight')
        plt.show()
        pcs_loaded += 1

        print(f"✓ Loaded per-class metrics for {v}: {len(series)} classes, mean {metric_name.lower()}: {mean_score:.3f}")

    except Exception as e:
        print(f"[WARN] Per-class metrics for {v}: {e}")
        # Don't leave a half-built figure open to be rendered with the cell output
        if fig is not None:
            plt.close(fig)

print(f"Successfully loaded per-class data for {pcs_loaded} models")

## Summary

In [None]:
# Final recap. Each artifact is listed only if the condition under which its
# cell produces it holds, so the report matches what was actually written.
has_top1_curves = any("epoch" in d.columns and "top1" in d.columns for d in per_epoch.values())
has_loss_curves = any("epoch" in d.columns and "val_loss" in d.columns for d in per_epoch.values())

print("="*60)
print("MODEL COMPARISON ANALYSIS COMPLETE")
print("="*60)
print(f"📊 Analyzed {len(per_epoch)} model variants")
print(f"📁 Outputs saved to: {OUT.resolve()}")
# "\n" must be a real newline escape; a doubled backslash prints the two
# characters backslash + n instead of a blank line.
print("\nGenerated visualizations:")
if has_top1_curves:
    print("  ✓ Learning curves (Top-1 accuracy)")
if has_loss_curves:
    print("  ✓ Learning curves (Validation loss)")
if not perf.empty and perf["best_top1"].notna().any():
    print("  ✓ Best Top-1 accuracy comparison")
if not perf.empty and perf["best_top5"].notna().any():
    print("  ✓ Best Top-5 accuracy comparison")
if not trade.empty:
    print("  ✓ Accuracy vs training time scatter")
if not fam.empty:
    print("  ✓ Model family averages")
if pcs_loaded > 0:
    print(f"  ✓ Per-class analysis for top {pcs_loaded} models")
print("\nData files saved:")
print("  • summary_performance.csv")
print("  • accuracy_vs_time.csv")
print("  • family_raw.csv")
# family_avg.csv is only written when at least one variant had Top-1 data
if not fam.empty:
    print("  • family_avg.csv")
print("="*60)