# HMM Filter Comparison Summary

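Compare HMM filter effectiveness across all models: load each model's HMM-filter results, tabulate raw vs. HMM-filtered accuracy, and save the table as a CSV.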

In [None]:
import os
import json
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# Output locations, resolved relative to the notebook's working directory.
METRICS_DIR = Path('..') / 'outputs' / 'metrics'
FIGURES_DIR = Path('..') / 'outputs' / 'figures'

# QUICK_MODE is switched on by setting the env var to true/1/yes (any case);
# quick runs read and write files carrying a '_quick' suffix.
_quick_flag = os.environ.get('QUICK_MODE', '')
QUICK_MODE = _quick_flag.lower() in {'true', '1', 'yes'}
if QUICK_MODE:
    FILE_SUFFIX = '_quick'
else:
    FILE_SUFFIX = ''

# Start-of-run banner.
_banner_lines = ['=' * 60, 'HMM Filter Comparison Summary']
if QUICK_MODE:
    _banner_lines.append('QUICK MODE')
_banner_lines.append('=' * 60)
print('\n'.join(_banner_lines))

In [None]:
print('\nLoading metrics...')

# Models whose HMM-filter result files are looked up under METRICS_DIR.
models = ['xgboost', 'lstm', 'lstm_fcn', 'cnn_transformer', 'transkal', 'lstm_autoencoder', 'conv_autoencoder']
metrics = {}  # model name -> parsed JSON results for that model

for model in models:
    path = METRICS_DIR / f'{model}_hmm_filter_results{FILE_SUFFIX}.json'
    if not path.exists():
        print(f'  Missing {model}')
        continue
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(path, encoding='utf-8') as f:
            metrics[model] = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        # Report and skip a corrupt/truncated file (as with a missing one) so
        # the remaining models can still be compared.
        print(f'  Unreadable {model}: {err}')
        continue
    print(f'  Loaded {model}')

print(f'\nLoaded {len(metrics)} models')

In [None]:
print('\nModel Comparison:')
print('='*80)


def _comparison_row(model, m):
    """Build one summary-table row from a model's loaded results.

    Two result layouts are recognised:

    * detector format -- has an ``overall_metrics`` dict;
    * HMM-filter format -- has ``raw_metrics``, ``filtered_metrics`` (keyed by
      the stickiness value as a string), ``best_stickiness`` and
      ``best_improvement``.

    Args:
        model: key the results were loaded under; used as the display name
            when the JSON has no ``'model'`` field.
        m: parsed JSON results for that model.

    Returns:
        A dict of column name -> value, or None if the layout is unrecognised.
    """
    name = m.get('model', model)
    if 'overall_metrics' in m:  # Detector format
        overall = m['overall_metrics']
        # NaN (not 0) marks a metric absent from the file, so it cannot be
        # mistaken for a genuine zero score.
        recall = overall.get('fault_detection_recall',
                             m.get('metrics', {}).get('recall', np.nan))
        return {
            'Model': name,
            'Accuracy': overall['accuracy'],
            'F1 (weighted)': overall.get('f1_weighted', np.nan),
            'Recall': recall,
        }
    if 'raw_metrics' in m:  # HMM filter format
        # JSON object keys are always strings, hence str() on the lookup.
        best = m['filtered_metrics'][str(m['best_stickiness'])]
        return {
            'Model': name,
            'Raw Accuracy': m['raw_metrics']['accuracy'],
            'HMM Accuracy': best['accuracy'],
            'Improvement': m['best_improvement']['accuracy_delta'],
            'Best Stickiness': m['best_stickiness'],
        }
    return None


rows = []
for model, m in metrics.items():
    row = _comparison_row(model, m)
    if row is None:
        print(f'  Skipping {model}: unrecognised results format')
        continue
    rows.append(row)

# If both layouts are present, pandas fills the columns a row lacks with NaN.
df = pd.DataFrame(rows)
if df.empty:
    # Don't overwrite a previous comparison CSV with an empty one.
    print('No results to compare.')
else:
    print(df.to_string(index=False))
    df.to_csv(METRICS_DIR / f'hmm_filter_comparison{FILE_SUFFIX}.csv', index=False)

In [None]:
# End-of-run banner: a blank line, a rule, the message, then a closing rule.
print('\n' + '=' * 60, 'Comparison Complete!', '=' * 60, sep='\n')