# Multi-Class Classification Metrics

In this notebook, you'll learn:
- How to evaluate multi-class classification problems
- Confusion matrix for 3+ classes
- Macro vs Micro vs Weighted averaging
- Per-class metrics analysis
- When to use which averaging strategy

## From Binary to Multi-Class

Binary classification (2 classes) is simpler:
- Clear definition of positive/negative
- Single confusion matrix

Multi-class classification (3+ classes) requires:
- One-vs-rest approach for per-class metrics
- Aggregation strategies (macro, micro, weighted)
- More complex confusion matrix

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, classification_report,
    precision_score, recall_score, f1_score,
    accuracy_score, ConfusionMatrixDisplay
)

# Set random seed
np.random.seed(42)

# Display settings
pd.set_option('display.max_columns', None)
plt.style.use('default')
sns.set_palette('colorblind')

## Load Multi-Class Data

We'll use a 3-class classification dataset:

In [None]:
# Load the multi-class data
df = pd.read_csv('../../fixtures/input/multiclass_data.csv')

print(f"Dataset shape: {df.shape}")
print("\nFirst few rows:")
df.head()

In [None]:
# Extract labels
y_true = df['true_label'].values
y_pred = df['predicted_label'].values

# Examine class distribution
print("Class distribution (true labels):")
print(df['true_label'].value_counts().sort_index())
print()
print("Class distribution (predicted labels):")
print(df['predicted_label'].value_counts().sort_index())

# Visualize distributions
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

df['true_label'].value_counts().sort_index().plot(kind='bar', ax=axes[0], 
                                                   color=['skyblue', 'lightgreen', 'lightcoral'])
axes[0].set_title('True Label Distribution', fontsize=12)
axes[0].set_xlabel('Class')
axes[0].set_ylabel('Count')
axes[0].set_xticklabels(['Class 0', 'Class 1', 'Class 2'], rotation=0)

df['predicted_label'].value_counts().sort_index().plot(kind='bar', ax=axes[1],
                                                        color=['skyblue', 'lightgreen', 'lightcoral'])
axes[1].set_title('Predicted Label Distribution', fontsize=12)
axes[1].set_xlabel('Class')
axes[1].set_ylabel('Count')
axes[1].set_xticklabels(['Class 0', 'Class 1', 'Class 2'], rotation=0)

plt.tight_layout()
plt.show()

## Multi-Class Confusion Matrix

For 3 classes, the confusion matrix is 3x3:
- Rows: Actual classes
- Columns: Predicted classes
- Diagonal: Correct predictions
- Off-diagonal: Misclassifications
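
To make the layout concrete, here is a minimal sketch that builds a 3x3 confusion matrix by hand in plain Python. The labels `toy_true` and `toy_pred` are hypothetical values made up for illustration and are not taken from the notebook's dataset.

```python
# Hypothetical toy labels, not taken from the notebook's dataset
toy_true = [0, 0, 0, 1, 1, 2, 2, 2]
toy_pred = [0, 0, 1, 1, 2, 2, 2, 0]
n_classes = 3

# toy_cm[actual][predicted] counts the samples with that (actual, predicted) pair
toy_cm = [[0] * n_classes for _ in range(n_classes)]
for actual, predicted in zip(toy_true, toy_pred):
    toy_cm[actual][predicted] += 1

for row in toy_cm:
    print(row)

actual_totals = [sum(row) for row in toy_cm]             # row sums: samples per actual class
predicted_totals = [sum(col) for col in zip(*toy_cm)]    # column sums: predictions per class
n_correct = sum(toy_cm[i][i] for i in range(n_classes))  # diagonal: correct predictions

print("Actual per class:   ", actual_totals)
print("Predicted per class:", predicted_totals)
print("Correct predictions:", n_correct)
```

Row sums give the number of samples that truly belong to each class, column sums give how often each class was predicted, and the diagonal sum divided by the total is the accuracy.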

In [None]:
# Calculate confusion matrix
cm = confusion_matrix(y_true, y_pred)

print("Confusion Matrix:")
print(cm)
print()
print("Layout:")
print("        Pred 0  Pred 1  Pred 2")
print("Actual 0  [ cm[0,0]  cm[0,1]  cm[0,2] ]")
print("Actual 1  [ cm[1,0]  cm[1,1]  cm[1,2] ]")
print("Actual 2  [ cm[2,0]  cm[2,1]  cm[2,2] ]")

In [None]:
# Visualize confusion matrix
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Plot 1: Counts
disp = ConfusionMatrixDisplay(confusion_matrix=cm, 
                              display_labels=['Class 0', 'Class 1', 'Class 2'])
disp.plot(ax=axes[0], cmap='Blues', values_format='d')
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14)

# Plot 2: Normalized (by true label)
cm_normalized = confusion_matrix(y_true, y_pred, normalize='true')
disp_norm = ConfusionMatrixDisplay(confusion_matrix=cm_normalized,
                                   display_labels=['Class 0', 'Class 1', 'Class 2'])
disp_norm.plot(ax=axes[1], cmap='Blues', values_format='.2f')
axes[1].set_title('Confusion Matrix (Normalized by Row)', fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
# Analyze the confusion matrix
print("Confusion Matrix Analysis:")
print()
for i in range(3):
    total = cm[i, :].sum()
    correct = cm[i, i]
    print(f"Class {i}:")
    print(f"  Total samples: {total}")
    print(f"  Correctly classified: {correct} ({correct/total*100:.1f}%)")
    print(f"  Misclassifications:")
    for j in range(3):
        if i != j and cm[i, j] > 0:
            print(f"    - Predicted as Class {j}: {cm[i, j]} ({cm[i, j]/total*100:.1f}%)")
    print()

## Per-Class Metrics

For each class, we can calculate precision, recall, and F1 using **one-vs-rest**:
- Treat the class as "positive"
- All other classes as "negative"
- Calculate metrics as in binary classification
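
The one-vs-rest steps above can be sketched for every class at once. The matrix `toy_cm` and the helper `one_vs_rest_counts` are hypothetical names introduced only for this illustration; the matrix follows the same convention as the rest of the notebook (rows are actual classes, columns are predicted classes).

```python
# Hypothetical 3x3 confusion matrix (rows = actual, columns = predicted)
toy_cm = [
    [2, 1, 0],
    [0, 1, 1],
    [1, 0, 2],
]

def one_vs_rest_counts(matrix, k):
    """Return (TP, FP, FN, TN) with class k as positive and all others as negative."""
    tp_k = matrix[k][k]
    fp_k = sum(row[k] for row in matrix) - tp_k   # column k without the diagonal
    fn_k = sum(matrix[k]) - tp_k                  # row k without the diagonal
    tn_k = sum(map(sum, matrix)) - tp_k - fp_k - fn_k
    return tp_k, fp_k, fn_k, tn_k

toy_counts = {k: one_vs_rest_counts(toy_cm, k) for k in range(len(toy_cm))}

for k, (tp_k, fp_k, fn_k, tn_k) in toy_counts.items():
    precision_k = tp_k / (tp_k + fp_k) if (tp_k + fp_k) > 0 else 0.0
    recall_k = tp_k / (tp_k + fn_k) if (tp_k + fn_k) > 0 else 0.0
    print(f"Class {k}: TP={tp_k} FP={fp_k} FN={fn_k} TN={tn_k} "
          f"precision={precision_k:.3f} recall={recall_k:.3f}")
```

scikit-learn's `precision_score`, `recall_score`, and `f1_score` return the same per-class values as an array when called with `average=None`.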

In [None]:
# Calculate per-class metrics manually for Class 1
# For Class 1: it's positive, Classes 0 and 2 are negative
class_idx = 1

# True Positives: predicted as class_idx AND actually class_idx
tp = cm[class_idx, class_idx]

# False Positives: predicted as class_idx BUT not actually class_idx
fp = cm[:, class_idx].sum() - tp

# False Negatives: actually class_idx BUT not predicted as class_idx
fn = cm[class_idx, :].sum() - tp

# True Negatives: not class_idx AND not predicted as class_idx
tn = cm.sum() - tp - fp - fn

# Calculate metrics
precision_class1 = tp / (tp + fp) if (tp + fp) > 0 else 0
recall_class1 = tp / (tp + fn) if (tp + fn) > 0 else 0
f1_class1 = 2 * precision_class1 * recall_class1 / (precision_class1 + recall_class1) if (precision_class1 + recall_class1) > 0 else 0

print(f"Manual calculation for Class {class_idx}:")
print(f"  TP: {tp}, FP: {fp}, FN: {fn}, TN: {tn}")
print(f"  Precision: {precision_class1:.4f}")
print(f"  Recall: {recall_class1:.4f}")
print(f"  F1-Score: {f1_class1:.4f}")

## Averaging Strategies

To get a single metric for all classes, we can use different averaging strategies:

### 1. Macro Average
- Calculate metric for each class independently
- Take the simple average
- **Treats all classes equally**, regardless of size, so poor performance on a rare class lowers the score as much as poor performance on a common one

$$\text{Macro} = \frac{1}{N} \sum_{i=1}^{N} \text{metric}_i$$

### 2. Micro Average
- Aggregate TP, FP, FN across all classes
- Calculate metric from aggregated values
- **Weights every sample equally**, so frequent classes dominate the result on imbalanced data

$$\text{Micro-Precision} = \frac{\sum_i TP_i}{\sum_i (TP_i + FP_i)}, \qquad \text{Micro-Recall} = \frac{\sum_i TP_i}{\sum_i (TP_i + FN_i)}$$
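
For single-label problems, where every sample has exactly one true class and one predicted class and all classes are included in the average, the micro-averaged scores collapse to accuracy. A sketch of the reasoning, with $N_{\text{samples}}$ (a symbol introduced here) denoting the total number of samples: each misclassified sample counts once as a false negative for its true class and once as a false positive for its predicted class, so

$$\sum_i FP_i = \sum_i FN_i = N_{\text{samples}} - \sum_i TP_i$$

Substituting this into the micro-averaged precision and recall gives

$$\frac{\sum_i TP_i}{\sum_i (TP_i + FP_i)} = \frac{\sum_i TP_i}{\sum_i (TP_i + FN_i)} = \frac{\sum_i TP_i}{N_{\text{samples}}} = \text{Accuracy}$$

Micro-F1 is the harmonic mean of two equal values, so it matches as well.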

### 3. Weighted Average
- Calculate metric for each class
- Weight by class support (number of samples)
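- **Reflects the class distribution**: larger classes contribute more to the score, so weak performance on a small class has less influence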

$$\text{Weighted} = \frac{\sum (n_i \times \text{metric}_i)}{\sum n_i}$$
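
The three strategies can diverge noticeably on imbalanced data. The sketch below applies them to F1 on a hypothetical confusion matrix (80 samples in class 0, 10 each in classes 1 and 2); the matrix and the helper `f1_from_counts` are assumptions made up for illustration, not values from the notebook's dataset.

```python
# Hypothetical imbalanced confusion matrix (rows = actual, columns = predicted):
# class 0 has 80 samples, classes 1 and 2 have 10 each.
toy_cm = [
    [80, 0, 0],
    [5, 5, 0],
    [5, 0, 5],
]
n_classes = len(toy_cm)

# One-vs-rest counts per class
toy_tp = [toy_cm[k][k] for k in range(n_classes)]
toy_fp = [sum(row[k] for row in toy_cm) - toy_tp[k] for k in range(n_classes)]
toy_fn = [sum(toy_cm[k]) - toy_tp[k] for k in range(n_classes)]
toy_support = [sum(row) for row in toy_cm]

def f1_from_counts(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); returns 0.0 when the denominator is 0."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else 0.0

toy_f1 = [f1_from_counts(toy_tp[k], toy_fp[k], toy_fn[k]) for k in range(n_classes)]

# Macro: plain mean of the per-class scores
toy_macro_f1 = sum(toy_f1) / n_classes
# Micro: pool the counts first, then compute the score once
toy_micro_f1 = f1_from_counts(sum(toy_tp), sum(toy_fp), sum(toy_fn))
# Weighted: mean of the per-class scores, weighted by support
toy_weighted_f1 = sum(n * f for n, f in zip(toy_support, toy_f1)) / sum(toy_support)

print("Per-class F1:", [round(f, 4) for f in toy_f1])
print(f"Macro:    {toy_macro_f1:.4f}")
print(f"Micro:    {toy_micro_f1:.4f}")
print(f"Weighted: {toy_weighted_f1:.4f}")
```

In this toy case the two small classes have clearly lower F1 than the large one, which pulls the macro average well below the micro and weighted averages.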

In [None]:
# Calculate metrics with different averaging strategies
print("Precision:")
print(f"  Macro:    {precision_score(y_true, y_pred, average='macro'):.4f}")
print(f"  Micro:    {precision_score(y_true, y_pred, average='micro'):.4f}")
print(f"  Weighted: {precision_score(y_true, y_pred, average='weighted'):.4f}")
print()

print("Recall:")
print(f"  Macro:    {recall_score(y_true, y_pred, average='macro'):.4f}")
print(f"  Micro:    {recall_score(y_true, y_pred, average='micro'):.4f}")
print(f"  Weighted: {recall_score(y_true, y_pred, average='weighted'):.4f}")
print()

print("F1-Score:")
print(f"  Macro:    {f1_score(y_true, y_pred, average='macro'):.4f}")
print(f"  Micro:    {f1_score(y_true, y_pred, average='micro'):.4f}")
print(f"  Weighted: {f1_score(y_true, y_pred, average='weighted'):.4f}")
print()

print("Accuracy:")
print(f"  {accuracy_score(y_true, y_pred):.4f}")
print()
print("Note: For single-label multi-class problems, micro-averaged precision/recall/F1 all equal accuracy!")

## Verify Macro Average Calculation

In [None]:
# Calculate macro average manually
per_class_precision = []
per_class_recall = []
per_class_f1 = []

for class_idx in range(3):
    # Calculate TP, FP, FN for this class
    tp = cm[class_idx, class_idx]
    fp = cm[:, class_idx].sum() - tp
    fn = cm[class_idx, :].sum() - tp
    
    # Calculate metrics
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    per_class_precision.append(precision)
    per_class_recall.append(recall)
    per_class_f1.append(f1)
    
    print(f"Class {class_idx}: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")

# Calculate macro averages
macro_precision = np.mean(per_class_precision)
macro_recall = np.mean(per_class_recall)
macro_f1 = np.mean(per_class_f1)

print()
print("Manual Macro Averages:")
print(f"  Precision: {macro_precision:.4f}")
print(f"  Recall: {macro_recall:.4f}")
print(f"  F1: {macro_f1:.4f}")

print()
print("sklearn Macro Averages:")
print(f"  Precision: {precision_score(y_true, y_pred, average='macro'):.4f}")
print(f"  Recall: {recall_score(y_true, y_pred, average='macro'):.4f}")
print(f"  F1: {f1_score(y_true, y_pred, average='macro'):.4f}")

## Classification Report

`classification_report` summarizes per-class precision, recall, F1, and support, along with accuracy and the macro and weighted averages:

In [None]:
# Generate classification report
report = classification_report(y_true, y_pred, 
                              target_names=['Class 0', 'Class 1', 'Class 2'])
print("Classification Report:")
print(report)

## Visualize Per-Class Performance

In [None]:
# Get report as dictionary
report_dict = classification_report(y_true, y_pred, 
                                   target_names=['Class 0', 'Class 1', 'Class 2'],
                                   output_dict=True)

# Extract per-class metrics
classes = ['Class 0', 'Class 1', 'Class 2']
metrics = ['precision', 'recall', 'f1-score']

data = []
for class_name in classes:
    data.append([
        report_dict[class_name]['precision'],
        report_dict[class_name]['recall'],
        report_dict[class_name]['f1-score']
    ])

data = np.array(data)

# Create grouped bar chart
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(metrics))
width = 0.25

colors = ['skyblue', 'lightgreen', 'lightcoral']
for i, class_name in enumerate(classes):
    offset = width * (i - 1)
    bars = ax.bar(x + offset, data[i], width, label=class_name, color=colors[i])
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.3f}',
                   xy=(bar.get_x() + bar.get_width() / 2, height),
                   xytext=(0, 3),
                   textcoords="offset points",
                   ha='center', va='bottom', fontsize=8)

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Per-Class Performance Metrics', fontsize=14, pad=20)
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend()
ax.set_ylim(0, 1.0)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## Compare Averaging Strategies Visually

In [None]:
# Calculate all averaging strategies
averaging_methods = ['macro', 'micro', 'weighted']
metric_names = ['Precision', 'Recall', 'F1-Score']

results = []
for avg in averaging_methods:
    # The same three calls work for every averaging strategy
    prec = precision_score(y_true, y_pred, average=avg)
    rec = recall_score(y_true, y_pred, average=avg)
    f1 = f1_score(y_true, y_pred, average=avg)
    results.append([prec, rec, f1])

results = np.array(results)

# Plot comparison
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(metric_names))
width = 0.25

colors = ['gold', 'lightblue', 'lightgreen']
for i, avg_method in enumerate(averaging_methods):
    offset = width * (i - 1)
    bars = ax.bar(x + offset, results[i], width, label=avg_method.capitalize(), color=colors[i])
    
    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.3f}',
                   xy=(bar.get_x() + bar.get_width() / 2, height),
                   xytext=(0, 3),
                   textcoords="offset points",
                   ha='center', va='bottom', fontsize=9)

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Comparison of Averaging Strategies', fontsize=14, pad=20)
ax.set_xticks(x)
ax.set_xticklabels(metric_names)
ax.legend()
ax.set_ylim(0, 1.0)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## When to Use Which Averaging Strategy?

Let's compare the per-class scores with each averaged score to see how the strategies differ:

In [None]:
# Compare per-class F1 and support with the averaged F1-scores
print("Scenario Analysis:")
print()
print("Per-class performance:")
for class_name in classes:
    support = report_dict[class_name]['support']
    f1 = report_dict[class_name]['f1-score']
    print(f"{class_name}: F1={f1:.3f}, Support={support}")

print()
print("Averaged F1-scores:")
print(f"  Macro:    {f1_score(y_true, y_pred, average='macro'):.4f}")
print(f"  Micro:    {f1_score(y_true, y_pred, average='micro'):.4f}")
print(f"  Weighted: {f1_score(y_true, y_pred, average='weighted'):.4f}")
print()

print("Interpretation:")
print("- Macro: Simple average, treats all classes equally")
print("- Micro: Equivalent to accuracy for multi-class")
print("- Weighted: Accounts for class imbalance by weighting by support")

## Exercise 1: Calculate Weighted Average Manually

Verify the weighted average calculation:

In [None]:
# YOUR CODE HERE
# 1. Get per-class F1 scores and supports
# 2. Calculate weighted average: sum(f1_i * support_i) / sum(support_i)
# 3. Compare with sklearn result



## Exercise 2: Identify Confusion Patterns

Analyze the confusion matrix to identify which classes are most confused:

In [None]:
# YOUR CODE HERE
# 1. Find the most common misclassification (largest off-diagonal element)
# 2. Calculate percentage of each class misclassified
# 3. Identify which pair of classes is most confused



## Exercise 3: Compare with Random Baseline

Calculate metrics for a random classifier:

In [None]:
# YOUR CODE HERE
# 1. Create random predictions (uniform over 3 classes)
# 2. Calculate confusion matrix and metrics
# 3. Compare with your model
# How much better is your model than random?



## Summary

In this notebook, you learned:

1. **Multi-class confusion matrix**: 3x3 matrix showing all classification outcomes
2. **Per-class metrics**: Calculate precision/recall/F1 using one-vs-rest
3. **Averaging strategies**:
   - **Macro**: Simple average (treats all classes equally)
   - **Micro**: Aggregate then calculate (equals accuracy for single-label problems)
   - **Weighted**: Weight by class support (accounts for imbalance)
4. **Classification report**: Comprehensive view of all metrics

### Key Takeaways:

- Multi-class extends binary metrics using one-vs-rest
- Confusion matrix reveals class-specific performance and confusion patterns
- Different averaging strategies serve different purposes
- Macro average is sensitive to poor performance on any class
- Weighted average is more representative for imbalanced datasets

### Decision Guide for Averaging:

| Scenario | Use Macro | Use Micro | Use Weighted |
|----------|-----------|-----------|-------------|
| Balanced classes | ✓ | ✓ | |
| Imbalanced classes, and the score should reflect class frequencies | | | ✓ |
| All classes equally important | ✓ | | |
| Want to penalize poor minority class performance | ✓ | | |
| Want overall accuracy-like metric | | ✓ | |

### Next Steps:

You now have a solid grounding in multi-class classification metrics. Practice by:
1. Working through the task notebooks
2. Applying these metrics to your own datasets
3. Choosing appropriate metrics for your use cases