# EMNIST Model Evaluation

This notebook provides a comprehensive evaluation of the trained EMNIST CNN model.

**Evaluation Metrics:**
- Overall test accuracy and top-5 accuracy
- Per-class precision, recall, F1-score
- Confusion matrix visualization
- Analysis of commonly confused character pairs
- Sample predictions with confidence scores

## 1. Import Libraries and Load Model

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path
import tensorflow as tf
from tensorflow import keras
from sklearn.metrics import classification_report, confusion_matrix

# Import custom modules
from src.data.dataset import load_emnist
from src.utils.label_mapping import load_label_mapping

# Set visualization style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print(f"TensorFlow version: {tf.__version__}")
print("✓ Libraries imported successfully")

# Load trained model.
# Every later cell reads `model`, so a missing file is a hard error here rather
# than a printed warning that lets "Run All" continue into confusing NameErrors.
MODEL_PATH = '../models/emnist_cnn_v1.keras'

if not Path(MODEL_PATH).exists():
    raise FileNotFoundError(
        f"Model not found at {MODEL_PATH}. "
        "Please train the model first using notebooks/03_model_training.ipynb"
    )

print(f"\nLoading model from {MODEL_PATH}...")
model = keras.models.load_model(MODEL_PATH)
print("✓ Model loaded successfully")
print(f"  Input shape: {model.input_shape}")
print(f"  Output shape: {model.output_shape}")

## 2. Load Test Data and Make Predictions

In [None]:
# Only the test split is needed here; the first two return values are discarded.
print("Loading EMNIST test set...")
_, _, x_test, y_test_labels = load_emnist()

# Scale pixel values to [0, 1] and add a trailing channel axis -> (N, 28, 28, 1).
x_test = (x_test.astype(np.float32) / 255.0).reshape(-1, 28, 28, 1)
# One-hot targets over the 62 classes, used by model.evaluate in a later cell.
y_test = tf.keras.utils.to_categorical(y_test_labels, num_classes=62)

print(f"✓ Test set loaded: {len(x_test):,} samples")

# Human-readable character for every class id 0..61, in class-id order.
label_mapping = load_label_mapping()
class_names = [label_mapping[label_id] for label_id in range(62)]
print(f"✓ Label mapping loaded: {len(class_names)} classes")

# Class probabilities for every test image, then the arg-max class id per row.
print("\nMaking predictions on test set...")
y_pred_probs = model.predict(x_test, verbose=0)
y_pred_labels = y_pred_probs.argmax(axis=1)
print("✓ Predictions complete")

## 3. Overall Performance Metrics

In [None]:
# Evaluate overall metrics.
# model.evaluate returns [loss, metric_1, metric_2, ...] in compile order; this
# cell assumes metric_1 is accuracy and an optional metric_2 is top-5 accuracy
# (TODO confirm against the compile() call in the training notebook).
test_results = model.evaluate(x_test, y_test, verbose=0)
test_loss = test_results[0]
test_accuracy = test_results[1]
test_top5 = test_results[2] if len(test_results) > 2 else None

print("="*70)
print("OVERALL PERFORMANCE METRICS")
print("="*70)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy*100:.2f}%")
# Compare against None explicitly: a legitimate 0.0 top-5 score must still be shown.
if test_top5 is not None:
    print(f"Top-5 Accuracy: {test_top5*100:.2f}%")


def _masked_accuracy(mask):
    """Return the fraction of samples selected by *mask* that were predicted correctly.

    Returns NaN when *mask* selects no samples, instead of letting np.mean emit
    a "Mean of empty slice" RuntimeWarning.
    """
    if not mask.any():
        return float('nan')
    return float(np.mean(y_pred_labels[mask] == y_test_labels[mask]))


# Compute accuracy per category. Label layout: 0-9 digits, 10-35 uppercase, 36+ lowercase.
print("\nAccuracy by Character Category:")
digits_mask = y_test_labels < 10
uppercase_mask = (y_test_labels >= 10) & (y_test_labels < 36)
lowercase_mask = y_test_labels >= 36

digits_acc = _masked_accuracy(digits_mask)
uppercase_acc = _masked_accuracy(uppercase_mask)
lowercase_acc = _masked_accuracy(lowercase_mask)

print(f"  Digits (0-9): {digits_acc*100:.2f}%")
print(f"  Uppercase (A-Z): {uppercase_acc*100:.2f}%")
print(f"  Lowercase (a-z): {lowercase_acc*100:.2f}%")
print("="*70)

## 4. Confusion Matrix Visualization

In [None]:
# Compute confusion matrix.
# Pin the label set so the matrix is always len(class_names) x len(class_names)
# and its rows/columns line up with class_names, even when some class never
# appears in either the true or the predicted labels.
num_classes = len(class_names)
print("Computing confusion matrix...")
conf_matrix = confusion_matrix(
    y_test_labels, y_pred_labels, labels=np.arange(num_classes)
)
print(f"✓ Confusion matrix computed: {conf_matrix.shape}")

# Row-normalize: entry [i, j] is the fraction of true class i predicted as j.
# Classes with no test samples get an all-NaN row instead of a 0/0 warning.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix_norm = np.divide(
    conf_matrix,
    row_totals,
    out=np.full(conf_matrix.shape, np.nan),
    where=row_totals > 0,
)

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(20, 18))

sns.heatmap(
    conf_matrix_norm,
    annot=False,
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    cbar_kws={'label': 'Normalized Count'},
    ax=ax
)

ax.set_xlabel('Predicted Label', fontsize=14)
ax.set_ylabel('True Label', fontsize=14)
ax.set_title('Confusion Matrix - EMNIST 62 Classes', fontsize=16, fontweight='bold', pad=20)

plt.tight_layout()

# Save confusion matrix
save_path = Path('../models/confusion_matrix.png')
plt.savefig(save_path, dpi=150, bbox_inches='tight')
print(f"✓ Confusion matrix saved to {save_path}")

plt.show()

# Diagonal of the row-normalized matrix = per-class recall. Use nan-aware
# reductions so classes without test samples are ignored rather than
# turning every statistic into NaN.
diagonal_acc = np.diag(conf_matrix_norm)
worst_idx = np.nanargmin(diagonal_acc)
best_idx = np.nanargmax(diagonal_acc)
print("\nPer-class accuracy statistics:")
print(f"  Mean: {np.nanmean(diagonal_acc)*100:.2f}%")
print(f"  Std: {np.nanstd(diagonal_acc)*100:.2f}%")
print(f"  Min: {diagonal_acc[worst_idx]*100:.2f}% (class '{class_names[worst_idx]}')")
print(f"  Max: {diagonal_acc[best_idx]*100:.2f}% (class '{class_names[best_idx]}')")

## 5. Commonly Confused Character Pairs

In [None]:
# Find commonly confused pairs: off-diagonal cells (true class i predicted as
# class j, i != j) with more than CONFUSION_THRESHOLD occurrences.
CONFUSION_THRESHOLD = 50  # threshold for "commonly confused"
TOP_K = 15                # how many pairs to list and plot

off_diagonal = conf_matrix.copy()
np.fill_diagonal(off_diagonal, 0)
class_support = conf_matrix.sum(axis=1)  # number of test samples per true class

# np.argwhere yields (i, j) in row-major order, matching a nested i/j loop.
confused_pairs = [
    {
        'true_class': class_names[i],
        'predicted_class': class_names[j],
        'count': int(conf_matrix[i, j]),
        'percentage': float(conf_matrix[i, j] / class_support[i] * 100),
    }
    for i, j in np.argwhere(off_diagonal > CONFUSION_THRESHOLD)
]

# Sort by count (stable, so ties keep row-major order)
confused_pairs.sort(key=lambda pair: pair['count'], reverse=True)

print(f"Found {len(confused_pairs)} commonly confused pairs (>{CONFUSION_THRESHOLD} confusions)\n")
print(f"Top {TOP_K} Most Confused Character Pairs:")
print("="*70)
for rank, pair in enumerate(confused_pairs[:TOP_K], 1):
    print(f"{rank:2d}. '{pair['true_class']}' → '{pair['predicted_class']}': "
          f"{pair['count']:4d} times ({pair['percentage']:5.2f}%)")

# Visualize top confused pairs
fig, ax = plt.subplots(figsize=(12, 8))

top_pairs = confused_pairs[:TOP_K]
labels = [f"'{p['true_class']}'→'{p['predicted_class']}'" for p in top_pairs]
counts = [p['count'] for p in top_pairs]
positions = range(len(top_pairs))

bars = ax.barh(positions, counts, color='coral', edgecolor='darkred', alpha=0.7)
ax.set_yticks(positions)
ax.set_yticklabels(labels)
ax.set_xlabel('Number of Confusions', fontsize=12)
# Title reflects how many pairs are actually shown (may be fewer than TOP_K).
ax.set_title(f'Top {len(top_pairs)} Most Confused Character Pairs', fontsize=14, fontweight='bold')
ax.invert_yaxis()
ax.grid(axis='x', alpha=0.3)

# Add count labels; padding is in points, so it works for any count scale.
ax.bar_label(bars, padding=3, fontsize=10)

plt.tight_layout()
plt.show()

## 6. Per-Class Performance Analysis

In [None]:
# Generate classification report.
# Passing `labels` explicitly keeps the report aligned with class_names even if a
# class is absent from both y_true and y_pred; without it sklearn raises because
# the number of observed classes no longer matches len(target_names).
label_ids = np.arange(len(class_names))
report = classification_report(
    y_test_labels,
    y_pred_labels,
    labels=label_ids,
    target_names=class_names,
    output_dict=True,
    zero_division=0,
)

# Extract per-class metrics in class_names order. There is deliberately no
# filtering, so each list has exactly one entry per class and bar positions
# cannot drift out of alignment with their tick labels.
per_class_f1 = [report[name]['f1-score'] for name in class_names]
per_class_precision = [report[name]['precision'] for name in class_names]
per_class_recall = [report[name]['recall'] for name in class_names]

# (y-axis label, title, values, bar colour, edge colour) for each panel.
metric_panels = [
    ('F1-Score', 'Per-Class F1-Score', per_class_f1, 'skyblue', 'navy'),
    ('Precision', 'Per-Class Precision', per_class_precision, 'lightgreen', 'darkgreen'),
    ('Recall', 'Per-Class Recall', per_class_recall, 'lightcoral', 'darkred'),
]

# Plot per-class F1-score, precision and recall
fig, axes = plt.subplots(3, 1, figsize=(18, 15))

for ax, (ylabel, title, values, face_color, edge_color) in zip(axes, metric_panels):
    mean_value = np.mean(values)
    ax.bar(label_ids, values, color=face_color, edgecolor=edge_color, alpha=0.7)
    ax.set_xticks(label_ids)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.axhline(y=mean_value, color='red', linestyle='--', label=f'Mean: {mean_value:.3f}')
    ax.grid(axis='y', alpha=0.3)
    ax.legend()

axes[-1].set_xlabel('Character Class', fontsize=12)

plt.tight_layout()
plt.show()

# Find best and worst performing classes (pairs keep names aligned with scores)
sorted_f1 = sorted(zip(class_names, per_class_f1), key=lambda item: item[1])

print("\nBest Performing Classes (Top 10 by F1-Score):")
for char, f1 in reversed(sorted_f1[-10:]):
    print(f"  '{char}': {f1:.4f}")

print("\nWorst Performing Classes (Bottom 10 by F1-Score):")
for char, f1 in sorted_f1[:10]:
    print(f"  '{char}': {f1:.4f}")

## 7. Sample Predictions Visualization

In [None]:
# Visualize correct and incorrect predictions
SAMPLES_PER_GROUP = 16  # images drawn from each of the correct / incorrect pools
GRID_ROWS, GRID_COLS = 4, 8

fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(20, 10))
fig.suptitle('Sample Predictions: Correct (Green) vs Incorrect (Red)', fontsize=16, fontweight='bold')

# Get some correct and incorrect predictions
correct_mask = y_pred_labels == y_test_labels
incorrect_mask = ~correct_mask

correct_indices = np.where(correct_mask)[0]
incorrect_indices = np.where(incorrect_mask)[0]

# A local RandomState(42) produces the same draws as the former global
# np.random.seed(42) + np.random.choice/shuffle, without resetting NumPy's
# global RNG for later cells. Cap each draw at the pool size so a very accurate
# (or very poor) model does not make replace=False raise.
rng = np.random.RandomState(42)
sample_correct = rng.choice(
    correct_indices, min(SAMPLES_PER_GROUP, len(correct_indices)), replace=False
)
sample_incorrect = rng.choice(
    incorrect_indices, min(SAMPLES_PER_GROUP, len(incorrect_indices)), replace=False
)

samples = list(sample_correct) + list(sample_incorrect)
rng.shuffle(samples)

# Blank every cell first so any grid slots left unused (fewer samples) stay empty.
for ax in axes.flat:
    ax.axis('off')

# axes.flat walks the grid row by row; zip stops at whichever runs out first.
for ax, sample_idx in zip(axes.flat, samples):
    image = x_test[sample_idx].reshape(28, 28)
    true_label = class_names[y_test_labels[sample_idx]]
    pred_label = class_names[y_pred_labels[sample_idx]]
    confidence = y_pred_probs[sample_idx][y_pred_labels[sample_idx]] * 100

    is_correct = y_pred_labels[sample_idx] == y_test_labels[sample_idx]
    color = 'green' if is_correct else 'red'

    ax.imshow(image, cmap='gray')
    title = f"T:'{true_label}' P:'{pred_label}'\n{confidence:.1f}%"
    ax.set_title(title, fontsize=9, color=color, fontweight='bold')

plt.tight_layout()
plt.show()

# Statistics
print("\nPrediction Statistics:")
print(f"  Correct predictions: {correct_mask.sum():,} ({correct_mask.mean()*100:.2f}%)")
print(f"  Incorrect predictions: {incorrect_mask.sum():,} ({incorrect_mask.mean()*100:.2f}%)")

## 8. Confidence Analysis

Analyze the distribution of prediction confidence to understand how certain the model is.

In [None]:
# Analyze prediction confidence: the top softmax probability of each prediction.
max_confidences = y_pred_probs.max(axis=1)
correct_confidences = max_confidences[correct_mask]
incorrect_confidences = max_confidences[incorrect_mask]


def _mean_or_nan(values):
    """Return the mean of *values*, or NaN for an empty array (no RuntimeWarning)."""
    return float(values.mean()) if values.size else float('nan')


fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Overall confidence distribution
overall_mean = max_confidences.mean()
axes[0].hist(max_confidences, bins=50, alpha=0.7, edgecolor='black')
axes[0].axvline(overall_mean, color='red', linestyle='--', linewidth=2, label=f'Mean: {overall_mean:.3f}')
axes[0].set_xlabel('Confidence', fontsize=12)
axes[0].set_ylabel('Frequency', fontsize=12)
axes[0].set_title('Overall Prediction Confidence Distribution', fontsize=13, fontweight='bold')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Correct vs incorrect confidence
axes[1].hist([correct_confidences, incorrect_confidences], bins=50, alpha=0.7,
             label=['Correct', 'Incorrect'], color=['green', 'red'], edgecolor='black')
axes[1].set_xlabel('Confidence', fontsize=12)
axes[1].set_ylabel('Frequency', fontsize=12)
axes[1].set_title('Confidence: Correct vs Incorrect Predictions', fontsize=13, fontweight='bold')
axes[1].legend()
axes[1].grid(alpha=0.3)

# Box plot comparison. Tick labels are set explicitly because boxplot's
# `labels=` keyword was renamed to `tick_labels` in Matplotlib 3.9 (deprecated);
# set_xticks works on both older and newer versions. Boxes sit at x = 1, 2.
box_data = [correct_confidences, incorrect_confidences]
bp = axes[2].boxplot(box_data, patch_artist=True)
axes[2].set_xticks([1, 2], ['Correct', 'Incorrect'])
bp['boxes'][0].set_facecolor('lightgreen')
bp['boxes'][1].set_facecolor('lightcoral')
axes[2].set_ylabel('Confidence', fontsize=12)
axes[2].set_title('Confidence Distribution Comparison', fontsize=13, fontweight='bold')
axes[2].grid(alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Confidence statistics (NaN is printed if a group has no samples)
print("\nConfidence Statistics:")
print(f"  Overall mean confidence: {overall_mean:.4f}")
print(f"  Correct predictions mean confidence: {_mean_or_nan(correct_confidences):.4f}")
print(f"  Incorrect predictions mean confidence: {_mean_or_nan(incorrect_confidences):.4f}")
print(f"\n  Low confidence correct (<0.5): {(correct_confidences < 0.5).sum():,}")
print(f"  High confidence incorrect (>0.8): {(incorrect_confidences > 0.8).sum():,}")