# Model Comparison on New TEP Dataset

This notebook compares all trained models on a newly generated, independent Tennessee Eastman Process (TEP) dataset.

**Models Compared**:
- Multiclass: XGBoost, LSTM, LSTM-FCN, CNN-Transformer, TransKal
- Binary: LSTM-Autoencoder, Conv-Autoencoder

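**Purpose**: Summarize how well each model generalizes, by comparing its metrics on the new dataset with its accuracy on the original test set (the accuracy delta, new − original).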

## Configuration & Imports

In [None]:
import os
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Where the per-model evaluation JSON files live and where figures are written.
METRICS_DIR = Path('../outputs/metrics')
FIGURES_DIR = Path('../outputs/figures')

# Quick mode is switched on by setting the QUICK_MODE environment variable to
# "true", "1" or "yes" (case-insensitive). In quick mode every file read or
# written by this notebook carries a "_quick" suffix.
quick_env_value = os.environ.get('QUICK_MODE', '').lower()
QUICK_MODE = quick_env_value in ('true', '1', 'yes')
if QUICK_MODE:
    FILE_SUFFIX = '_quick'
else:
    FILE_SUFFIX = ''

banner_rule = '=' * 60
print(banner_rule)
print('Model Comparison on New TEP Dataset')
if QUICK_MODE:
    print('QUICK MODE')
print(banner_rule)

## Load All Evaluation Metrics

In [None]:
print('\nLoading evaluation metrics...')

# Model identifiers; each is looked up as
# METRICS_DIR / '<model>_new_eval_metrics<FILE_SUFFIX>.json'.
multiclass_models = ['xgboost', 'lstm', 'lstm_fcn', 'cnn_transformer', 'transkal']
binary_models = ['lstm_autoencoder', 'conv_autoencoder']


def _load_metrics(model_names):
    """Load the new-dataset evaluation metrics JSON for each name in *model_names*.

    Prints one status line per model: "Loaded", "Missing" (file not found) or
    "Invalid" (file is not valid JSON).

    Returns:
        dict mapping model name -> parsed JSON object, containing only the
        models whose metrics file exists and parses.
    """
    loaded = {}
    for model in model_names:
        metrics_file = METRICS_DIR / f'{model}_new_eval_metrics{FILE_SUFFIX}.json'
        try:
            with open(metrics_file, encoding='utf-8') as f:
                loaded[model] = json.load(f)
        except FileNotFoundError:
            print(f'  Missing {model}')
        except json.JSONDecodeError as err:
            # A truncated/corrupt file should not abort the whole comparison;
            # report it and leave the model out, exactly like a missing one.
            print(f'  Invalid {model}: {err}')
        else:
            print(f'  Loaded {model}')
    return loaded


multiclass_metrics = _load_metrics(multiclass_models)
binary_metrics = _load_metrics(binary_models)

print(f'\nLoaded: {len(multiclass_metrics)} multiclass, {len(binary_metrics)} binary')

## Multiclass Model Comparison

In [None]:
print('\nMulticlass Model Comparison (New Evaluation Dataset):')
print('='*90)

# Declare the table schema explicitly: when no multiclass metrics were loaded,
# pd.DataFrame([]) would have no columns and sort_values('Accuracy') would
# raise KeyError. With explicit columns an empty table stays well-formed for
# the cells below.
multiclass_columns = [
    'Model', 'Accuracy', 'Balanced Acc', 'F1 (weighted)', 'F1 (macro)',
    'Orig Accuracy', 'Delta',
]

rows = []
for model, metrics in multiclass_metrics.items():
    rows.append({
        'Model': metrics['model'],
        'Accuracy': metrics['accuracy'],
        'Balanced Acc': metrics['balanced_accuracy'],
        'F1 (weighted)': metrics['f1_weighted'],
        'F1 (macro)': metrics['f1_macro'],
        'Orig Accuracy': metrics['comparison_with_original']['original_accuracy'],
        'Delta': metrics['comparison_with_original']['accuracy_delta'],
    })

df_multi = pd.DataFrame(rows, columns=multiclass_columns)
df_multi = df_multi.sort_values('Accuracy', ascending=False)

if df_multi.empty:
    # Mirrors the binary-model cell's behaviour when nothing was loaded.
    print('No multiclass model metrics available.')
else:
    print(df_multi.to_string(index=False))
    # Save the table, best (highest new-dataset accuracy) first.
    df_multi.to_csv(METRICS_DIR / f'multiclass_comparison_new_eval{FILE_SUFFIX}.csv', index=False)

In [None]:
# Two panels: (left) original-test vs new-dataset accuracy per model,
# (right) the accuracy delta (new - original) per model, green when the model
# did at least as well on the new data, red otherwise.
if df_multi.empty:
    print('No multiclass model metrics to plot.')
else:
    # savefig does not create missing directories.
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    models = df_multi['Model'].tolist()
    # Shared numeric x positions so both panels label models identically
    # (and duplicate display names don't collapse into one category).
    x = np.arange(len(models))
    width = 0.35

    # Bar chart: Accuracy comparison
    ax = axes[0]
    orig_acc = df_multi['Orig Accuracy'].tolist()
    new_acc = df_multi['Accuracy'].tolist()
    ax.bar(x - width/2, orig_acc, width, label='Original Test', color='steelblue', alpha=0.7)
    ax.bar(x + width/2, new_acc, width, label='New Eval', color='coral', alpha=0.7)
    ax.set_xlabel('Model')
    ax.set_ylabel('Accuracy')
    ax.set_title('Multiclass Accuracy: Original vs New Evaluation')
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=45, ha='right')
    ax.legend()
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', alpha=0.3)

    # Delta plot
    ax = axes[1]
    deltas = df_multi['Delta'].tolist()
    colors = ['green' if d >= 0 else 'red' for d in deltas]
    ax.bar(x, deltas, color=colors, edgecolor='black', alpha=0.7)
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    ax.set_xlabel('Model')
    ax.set_ylabel('Accuracy Delta (New - Original)')
    ax.set_title('Generalization Gap')
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    output_file = FIGURES_DIR / f'multiclass_comparison_new_eval{FILE_SUFFIX}.png'
    plt.savefig(output_file, dpi=150, bbox_inches='tight')
    plt.show()
    print(f'Saved to {output_file}')

## Binary Model Comparison

In [None]:
print('\nBinary Model Comparison (New Evaluation Dataset):')
print('=' * 80)

if not binary_metrics:
    print('No binary model metrics available.')
else:
    # One row per binary model: new-dataset classification metrics plus the
    # original-test accuracy and the accuracy delta recorded in its JSON.
    rows = [
        {
            'Model': m['model'],
            'Accuracy': m['accuracy'],
            'Precision': m['precision'],
            'Recall': m['recall'],
            'F1': m['f1'],
            'AUC-ROC': m['auc_roc'],
            'Orig Accuracy': m['comparison_with_original']['original_accuracy'],
            'Delta': m['comparison_with_original']['accuracy_delta'],
        }
        for m in binary_metrics.values()
    ]

    # Rank by F1 on the new dataset, best first.
    df_binary = pd.DataFrame(rows).sort_values('F1', ascending=False)
    print(df_binary.to_string(index=False))

    binary_csv_path = METRICS_DIR / f'binary_comparison_new_eval{FILE_SUFFIX}.csv'
    df_binary.to_csv(binary_csv_path, index=False)

## Summary

In [None]:
print('\n' + '='*60)
print('MODEL COMPARISON SUMMARY - NEW TEP DATASET')
if QUICK_MODE:
    print('(Quick mode)')
print('='*60)

if not multiclass_metrics and not binary_metrics:
    print('\nNo evaluation metrics were loaded; nothing to summarize.')

if multiclass_metrics:
    # Best multiclass model = highest accuracy on the new dataset
    # (on ties, max() keeps the model loaded first).
    best_multi = max(multiclass_metrics.items(), key=lambda item: item[1]['accuracy'])
    best_multi_metrics = best_multi[1]
    print(f'\nBest Multiclass Model: {best_multi_metrics["model"]}')
    print(f'  Accuracy: {best_multi_metrics["accuracy"]:.4f}')
    print(f'  F1 (weighted): {best_multi_metrics["f1_weighted"]:.4f}')
    print(f'  Generalization gap: {best_multi_metrics["comparison_with_original"]["accuracy_delta"]:+.4f}')

if binary_metrics:
    # Best binary model = highest F1 on the new dataset (same ranking as the table).
    best_binary = max(binary_metrics.items(), key=lambda item: item[1]['f1'])
    best_binary_metrics = best_binary[1]
    print(f'\nBest Binary Model: {best_binary_metrics["model"]}')
    print(f'  F1: {best_binary_metrics["f1"]:.4f}')
    print(f'  AUC-ROC: {best_binary_metrics["auc_roc"]:.4f}')
    print(f'  Generalization gap: {best_binary_metrics["comparison_with_original"]["accuracy_delta"]:+.4f}')

# List only artifacts this run can have produced: the multiclass CSV and figure
# need multiclass metrics, the binary CSV needs binary metrics.
output_files = []
if multiclass_metrics:
    output_files.append(METRICS_DIR / f'multiclass_comparison_new_eval{FILE_SUFFIX}.csv')
if binary_metrics:
    output_files.append(METRICS_DIR / f'binary_comparison_new_eval{FILE_SUFFIX}.csv')
if multiclass_metrics:
    output_files.append(FIGURES_DIR / f'multiclass_comparison_new_eval{FILE_SUFFIX}.png')

print('\nOutputs:')
for output_path in output_files:
    print(f'  - {output_path}')
if not output_files:
    print('  (none)')
print('='*60)