# NeuroHand: Training Results Analysis

Visualization and analysis of the trained EEGNet model.

**Model Performance:**
- Accuracy: 62.97%
- Training Time: ~20 minutes
- Best Epoch: 174

**Sections:**
1. Load Results
2. Training History (Loss & Accuracy Curves)
3. Confusion Matrix Analysis
4. Per-Class Performance
5. Model Architecture Summary
6. Analysis & Insights
7. Export Summary Report

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch

from src.visualization.plot_results import (
    plot_training_history,
    plot_confusion_matrix,
    plot_class_performance
)
from src.models.eegnet import EEGNet

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✅ Imports successful!")

## 1. Load Results

In [None]:
# Load training history
history_path = Path('../models/checkpoints/training_history.npy')
history = np.load(history_path, allow_pickle=True).item()

print("üìä Training History Loaded:")
print(f"   Epochs trained: {len(history['train_loss'])}")
print(f"   Best test accuracy: {max(history['test_acc']):.2f}%")
print(f"   Final train loss: {history['train_loss'][-1]:.4f}")
print(f"   Final test loss: {history['test_loss'][-1]:.4f}")

In [None]:
# Load evaluation results
eval_path = Path('../models/checkpoints/evaluation_results.npy')
eval_results = np.load(eval_path, allow_pickle=True).item()

print("üìä Evaluation Results Loaded:")
print(f"   Overall Accuracy: {eval_results['accuracy']:.2f}%")
print(f"   Precision (avg): {eval_results['precision'].mean():.2f}")
print(f"   Recall (avg): {eval_results['recall'].mean():.2f}")
print(f"   F1-Score (avg): {eval_results['f1'].mean():.2f}")

## 2. Training History Visualization

Loss and accuracy curves over training epochs.

In [None]:
fig = plot_training_history(history, figsize=(16, 6))
plt.show()

# Find best epoch
best_epoch = np.argmax(history['test_acc'])
best_acc = history['test_acc'][best_epoch]

print(f"\nüèÜ Best Performance:")
print(f"   Epoch: {best_epoch + 1}")
print(f"   Test Accuracy: {best_acc:.2f}%")
print(f"   Train Accuracy: {history['train_acc'][best_epoch]:.2f}%")
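
The difference between train and test accuracy at the best epoch is a quick overfitting check. A minimal sketch, assuming `history` stores per-epoch accuracies in percent as lists under the same keys used above; `generalization_gap` is a hypothetical helper and the toy history is made up for illustration:

```python
import numpy as np

def generalization_gap(history):
    """Return (1-indexed best epoch, test acc, train acc, gap) at the epoch
    with the highest test accuracy. Gap is in percentage points."""
    test_acc = np.asarray(history['test_acc'], dtype=float)
    train_acc = np.asarray(history['train_acc'], dtype=float)
    idx = int(np.argmax(test_acc))  # 0-indexed; first maximum wins on ties
    gap = train_acc[idx] - test_acc[idx]
    return idx + 1, test_acc[idx], train_acc[idx], gap

# Toy history, not the notebook's results
toy = {'train_acc': [40.0, 55.0, 70.0, 80.0], 'test_acc': [38.0, 50.0, 58.0, 56.0]}
epoch, test, train, gap = generalization_gap(toy)
print(f"Best epoch {epoch}: test {test:.1f}%, train {train:.1f}%, gap {gap:.1f} pp")
```

A large positive gap suggests the model is memorizing training trials; calling `generalization_gap(history)` applies the same check to the real curves.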

## 3. Confusion Matrix

Shows which classes are confused with each other.

In [None]:
class_names = ['Left Hand', 'Right Hand', 'Feet', 'Tongue']
cm = eval_results['confusion_matrix']

# Raw counts
fig = plot_confusion_matrix(
    cm,
    class_names=class_names,
    normalize=False,
    title='Confusion Matrix (Raw Counts)'
)
plt.show()

# Normalized (percentages)
fig = plot_confusion_matrix(
    cm,
    class_names=class_names,
    normalize=True,
    title='Confusion Matrix (Normalized %)'
)
plt.show()
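
The normalized view divides each row by its total, so every row sums to 100% and the diagonal becomes per-class recall. A sketch of that computation, assuming rows are true labels and columns are predictions (the scikit-learn convention); whether `plot_confusion_matrix` normalizes exactly this way is an assumption, and `normalize_rows` is a hypothetical helper:

```python
import numpy as np

def normalize_rows(cm):
    """Row-normalize a confusion matrix (rows = true class) into percentages.

    Each row then sums to 100, so the diagonal entry is that class's recall.
    Rows with no samples are left as zeros instead of dividing by zero.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0) * 100

toy_cm = np.array([[8, 2], [3, 7]])  # toy 2-class example, not the notebook's data
print(normalize_rows(toy_cm))
```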

## 4. Per-Class Performance

Detailed metrics for each motor imagery class.

In [None]:
precision = eval_results['precision']
recall = eval_results['recall']
f1 = eval_results['f1']

fig = plot_class_performance(
    precision,
    recall,
    f1,
    class_names=class_names,
    figsize=(14, 7)
)
plt.show()

# Print detailed table
print("\nüìä Detailed Per-Class Metrics:")
print("=" * 70)
print(f"{'Class':<15} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
print("=" * 70)
for i, name in enumerate(class_names):
    print(f"{name:<15} {precision[i]:<12.4f} {recall[i]:<12.4f} {f1[i]:<12.4f}")
print("=" * 70)
print(f"{'Average':<15} {precision.mean():<12.4f} {recall.mean():<12.4f} {f1.mean():<12.4f}")
print("=" * 70)
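
Precision, recall, and F1 can all be recovered from the confusion matrix loaded in Section 3, which makes a useful cross-check on `eval_results`. A minimal sketch under the same rows-are-true-labels assumption; `per_class_metrics` is a hypothetical helper and the toy matrix is not the notebook's data:

```python
import numpy as np

def per_class_metrics(cm):
    """Per-class precision, recall and F1 from a confusion matrix.

    Assumes rows are true labels and columns are predictions.
    Classes that are never predicted (or never occur) get 0 instead of NaN.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual = cm.sum(axis=1)     # row sums: how often each class truly occurred
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1

toy_cm = np.array([[8, 2], [3, 7]])
p, r, f = per_class_metrics(toy_cm)
print(np.round(p, 3), np.round(r, 3), np.round(f, 3))
```

Calling `per_class_metrics(cm)` on the real matrix should reproduce the table above if `eval_results` was computed this way. Averaging with `.mean()`, as the table does, is macro averaging: each class counts equally regardless of how many trials it has.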

## 5. Model Architecture Summary

In [None]:
# Load model
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model = EEGNet(n_classes=4, n_channels=22, n_samples=1000).to(device)

checkpoint = torch.load('../models/checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print("üß† EEGNet Architecture:")
print("=" * 70)
print(model)
print("=" * 70)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nüìä Model Statistics:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Model size: ~{total_params * 4 / 1024:.1f} KB")
print(f"   Training epoch: {checkpoint['epoch']}")
print(f"   Training accuracy: {checkpoint['train_acc']:.2f}%")
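
The parameter count can be reproduced by hand. The sketch below assumes the standard EEGNet-8,2 configuration (8 temporal filters of length 64, depth multiplier 2, 16 separable filters of length 16, average pooling by 4 then 8, bias-free convolutions, affine BatchNorm, and a linear classifier with bias). The notebook's `EEGNet` class is not shown here, so these hyperparameters are assumptions, and `eegnet_param_count` is a hypothetical helper:

```python
def eegnet_param_count(n_classes=4, n_channels=22, n_samples=1000,
                       F1=8, D=2, F2=16, kern_length=64, sep_kernel=16,
                       pool1=4, pool2=8):
    """Count trainable parameters of an EEGNet-style network.

    Assumes bias-free convolutions, BatchNorm layers with a learnable scale and
    shift per channel (running stats are buffers, not parameters), and a final
    linear layer with bias on the flattened features.
    """
    temporal_conv = F1 * kern_length          # Conv2d(1, F1, (1, kern_length))
    bn1 = 2 * F1
    depthwise_conv = F1 * D * n_channels      # Conv2d(F1, F1*D, (n_channels, 1), groups=F1)
    bn2 = 2 * F1 * D
    separable_depth = F1 * D * sep_kernel     # Conv2d(F1*D, F1*D, (1, sep_kernel), groups=F1*D)
    separable_point = F1 * D * F2             # Conv2d(F1*D, F2, 1)
    bn3 = 2 * F2
    time_steps = n_samples // pool1 // pool2  # length left after both pooling layers
    classifier = F2 * time_steps * n_classes + n_classes
    return (temporal_conv + bn1 + depthwise_conv + bn2 +
            separable_depth + separable_point + bn3 + classifier)

total = eegnet_param_count()
print(f"{total:,} parameters, ~{total * 4 / 1024:.1f} KB as float32")
# → 3,444 parameters, ~13.5 KB as float32
```

Under those assumptions the total is 3,444, matching the figure quoted in the analysis below. The classifier dominates (1,988 of the 3,444 parameters) because its size grows with `n_samples`.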

## 6. Analysis & Insights

### Key Findings:

**Strengths:**
- 🎯 **Left Hand**: Highest recall (73%) - the model correctly identifies roughly three out of four left-hand imagery trials
- 🎯 **Right Hand**: Highest precision (68%) - when the model predicts Right Hand, it is correct 68% of the time
- ⚡ **Compact Model**: Only 3,444 parameters (~13.5 KB of float32 weights)
- ⚡ **Fast Training**: ~20 minutes on an M4 MacBook

**Areas for Improvement:**
- ⚠️ **Feet**: Lowest F1-score (59%) - may benefit from more training data or feature engineering
- ⚠️ **Overall Accuracy**: 63% is moderate for 4-class motor imagery (chance level: 25%; target: 70-75%)
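
One way to put 63% in context against the 25% chance level is Cohen's kappa, the metric used to rank entries in BCI Competition IV-2a. A minimal sketch computing it from a confusion matrix (rows = true labels). With perfectly balanced labels and predictions, chance agreement is 1/4 and kappa reduces to (accuracy − 0.25) / 0.75; that balanced case is an approximation, not a value read from `eval_results`:

```python
import numpy as np

def cohen_kappa_from_cm(cm):
    """Cohen's kappa from a confusion matrix (rows = true, columns = predicted)."""
    cm = np.asarray(cm, dtype=float)
    n = cm.sum()
    p_observed = np.trace(cm) / n
    # Chance agreement: sum over classes of (true frequency * predicted frequency)
    p_expected = np.sum(cm.sum(axis=1) * cm.sum(axis=0)) / n**2
    return (p_observed - p_expected) / (1 - p_expected)

# Balanced 4-class approximation using the reported accuracy
acc = 0.6297
print(f"kappa ≈ {(acc - 0.25) / 0.75:.3f}")  # → kappa ≈ 0.506
```

Calling `cohen_kappa_from_cm(cm)` on the notebook's matrix gives the exact value without the balanced-class assumption.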

### Medical Context:
- **Public Dataset Performance**: 60-65% is a typical baseline for BCI Competition IV-2a
- **Clinical BCI Systems**: Typically achieve 70-85% with subject-specific calibration
- **Transfer Learning Potential**: Fine-tuning on personal OpenBCI data could improve accuracy by an estimated 10-15%

### Next Steps:
1. ✅ Baseline model complete
2. 🔜 Collect personal motor imagery data with OpenBCI
3. 🔜 Fine-tune model on personal data
4. 🔜 Real-time testing with prosthetic hand

**NEVER GIVE UP!** 💪🧠🤖

## 7. Export Summary Report

In [None]:
# Create summary report
report = {
    'model': 'EEGNet',
    'dataset': 'BCI Competition IV-2a',
    'n_subjects': 9,
    'n_trials': 5184,
    'train_trials': 4147,
    'test_trials': 1037,
    'best_epoch': best_epoch + 1,
    'total_epochs': len(history['train_loss']),
    'accuracy': eval_results['accuracy'],
    'precision_avg': eval_results['precision'].mean(),
    'recall_avg': eval_results['recall'].mean(),
    'f1_avg': eval_results['f1'].mean(),
    'model_params': total_params,
    'training_time': '19m 58s'
}

# Save report
report_path = Path('../models/checkpoints/summary_report.npy')
np.save(report_path, report)

print("üìÑ Summary Report:")
print("=" * 70)
for key, value in report.items():
    print(f"   {key}: {value}")
print("=" * 70)
print(f"\nüíæ Report saved to: {report_path}")