<a href="https://github.com/timeseriesAI/tsai-rs" target="_parent"><img src="https://img.shields.io/badge/tsai--rs-Time%20Series%20AI%20in%20Rust-blue" alt="tsai-rs"/></a>

# Prediction Dynamics and Analysis with tsai-rs

This notebook demonstrates how to analyze model predictions and understand prediction dynamics using **tsai-rs**.

## Purpose

Understanding model predictions is crucial for:
1. **Debugging**: Identify why certain samples are misclassified
2. **Model improvement**: Focus on hard examples
3. **Trust**: Build confidence in model predictions
4. **Data quality**: Find mislabeled samples

This approach follows advice from Andrej Karpathy's blog post "A Recipe for Training Neural Networks":
> "Visualize prediction dynamics: I like to visualize model predictions on a fixed test batch during the course of training. The 'dynamics' of how these predictions move will give you incredibly good intuition for how the training progresses."

## Install tsai-rs

```bash
cd crates/tsai_python
maturin develop --release
```

## Import Libraries

In [None]:
import tsai_rs
import numpy as np
import matplotlib.pyplot as plt

print(f"tsai-rs version: {tsai_rs.version()}")
tsai_rs.my_setup()

## Load Data

In [None]:
dsid = 'NATOPS'
X_train, y_train, X_test, y_test = tsai_rs.get_UCR_data(dsid, return_split=True)

n_vars = X_train.shape[1]
seq_len = X_train.shape[2]
n_classes = len(np.unique(y_train))

print(f"Dataset: {dsid}")
print(f"X_test shape: {X_test.shape}")
print(f"Classes: {np.unique(y_test)}")

In [None]:
# Standardize
X_train_std = tsai_rs.ts_standardize(X_train.astype(np.float32), by_sample=True)
X_test_std = tsai_rs.ts_standardize(X_test.astype(np.float32), by_sample=True)
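
Per-sample standardization rescales each series to zero mean and unit variance. The sketch below shows the idea in plain NumPy. It assumes `by_sample=True` means the mean and standard deviation are computed separately for each sample, across all variables and time steps; whether `tsai_rs.ts_standardize` does exactly this is an assumption, and `standardize_per_sample` and `X_demo` are hypothetical names used only for illustration.

```python
import numpy as np

def standardize_per_sample(X, eps=1e-8):
    """Z-normalize each sample of X, shaped (n_samples, n_vars, seq_len).

    Assumption: statistics are taken over the (variables, time) axes of
    each sample. This is an illustrative helper, not the tsai-rs code.
    """
    mean = X.mean(axis=(1, 2), keepdims=True)
    std = X.std(axis=(1, 2), keepdims=True)
    return (X - mean) / (std + eps)

# Synthetic data with a non-zero mean and non-unit scale
rng = np.random.default_rng(0)
X_demo = rng.normal(loc=5.0, scale=3.0, size=(4, 2, 50)).astype(np.float32)
X_demo_std = standardize_per_sample(X_demo)

print("Per-sample mean:", X_demo_std.mean(axis=(1, 2)))
print("Per-sample std: ", X_demo_std.std(axis=(1, 2)))
```

After this transform every sample has a mean close to 0 and a standard deviation close to 1, so differences in offset or amplitude between recordings no longer dominate the plots.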

## Simulating Model Predictions

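No model is trained in this notebook. Instead, we generate synthetic class probabilities with a target accuracy so that the analysis tools can be shown end to end. With a real model, replace `y_pred` and `probs` with the model's outputs on the test set.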

In [None]:
def simulate_predictions(y_true, n_classes, accuracy=0.85, seed=42):
    """Simulate predicted probabilities with approximately the requested accuracy."""
    np.random.seed(seed)
    n_samples = len(y_true)

    # Create predicted probabilities
    probs = np.zeros((n_samples, n_classes))

    for i, true_label in enumerate(y_true):
        # With probability 'accuracy', predict the correct class
        if np.random.random() < accuracy:
            # Correct prediction: high confidence on the true class
            top_class = int(true_label)
            probs[i, top_class] = np.random.uniform(0.6, 0.95)
        else:
            # Incorrect prediction: moderate confidence on a wrong class
            top_class = np.random.choice([c for c in range(n_classes) if c != int(true_label)])
            probs[i, top_class] = np.random.uniform(0.4, 0.8)

        # Spread the remaining mass over the other classes. Dividing by
        # (n_classes - 1) keeps every other class below top_class when
        # n_classes >= 3, so the argmax is always the intended class.
        remaining = 1 - probs[i, top_class]
        for j in range(n_classes):
            if j != top_class:
                probs[i, j] = remaining * np.random.random() / (n_classes - 1)

        # Normalize so each row sums to 1
        probs[i] = probs[i] / probs[i].sum()
    
    y_pred = np.argmax(probs, axis=1)
    return y_pred, probs

# Simulate predictions
y_pred, probs = simulate_predictions(y_test, n_classes, accuracy=0.80)

actual_acc = (y_pred == y_test).mean()
print(f"Simulated accuracy: {actual_acc:.2%}")

## Confusion Matrix Analysis

In [None]:
# Compute confusion matrix using tsai-rs
cm = tsai_rs.confusion_matrix(y_test, y_pred, n_classes)

print("Confusion Matrix:")
print(cm)

In [None]:
# Visualize confusion matrix
fig, ax = plt.subplots(figsize=(10, 8))

im = ax.imshow(cm, cmap='Blues')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title(f'Confusion Matrix - Accuracy: {actual_acc:.2%}')

# Add text annotations
for i in range(n_classes):
    for j in range(n_classes):
        ax.text(j, i, int(cm[i, j]),
                ha="center", va="center", color="white" if cm[i, j] > cm.max()/2 else "black")

plt.colorbar(im)
plt.tight_layout()
plt.show()
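
To make the matrix easier to read, it helps to see how the counts are built and how per-class recall falls out of them. The following is a minimal NumPy sketch, assuming rows index the true class and columns the predicted class (the same orientation as the axis labels in the plot). `confusion_counts` and the `*_demo` arrays are hypothetical names; this is not the tsai-rs implementation.

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes):
    """Count (true, predicted) pairs; rows are true classes, columns predicted."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same cell appears many times
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

y_true_demo = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred_demo = np.array([0, 1, 1, 1, 2, 0, 2])

cm_demo = confusion_counts(y_true_demo, y_pred_demo, n_classes=3)
# Recall per class: correct predictions divided by the number of true samples
recall_demo = cm_demo.diagonal() / cm_demo.sum(axis=1)

print(cm_demo)
print("Per-class recall:", recall_demo)
```

Dividing each row by its sum in the same way gives a row-normalized matrix, which is often easier to compare across classes of different sizes.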

## Top Losses Analysis

Identifying samples with the highest losses helps understand where the model struggles.
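
As a rough illustration of the idea, the per-sample loss can be computed as the negative log probability assigned to the true class, and the hardest samples are the ones with the largest values. The sketch below uses plain NumPy; `top_k_losses` and the demo arrays are hypothetical names, and this is only an assumed equivalent of what a top-losses helper does, not the tsai-rs code.

```python
import numpy as np

def top_k_losses(probs, y_true, k):
    """Return indices of the k largest per-sample losses, plus all losses."""
    # Negative log probability of the true class; the epsilon avoids log(0)
    per_sample = -np.log(probs[np.arange(len(y_true)), y_true] + 1e-10)
    # argsort is ascending, so reverse it and keep the first k entries
    order = np.argsort(per_sample)[::-1][:k]
    return order, per_sample

probs_demo = np.array([
    [0.90, 0.05, 0.05],   # true class 0, confident and correct
    [0.20, 0.70, 0.10],   # true class 0, confidently wrong
    [0.60, 0.30, 0.10],   # true class 1, wrong
    [0.05, 0.05, 0.90],   # true class 2, confident and correct
])
y_demo = np.array([0, 0, 1, 2])

idx_demo, loss_demo = top_k_losses(probs_demo, y_demo, k=2)
for i in idx_demo:
    print(f"sample {i}: loss = {loss_demo[i]:.4f}")
```

A confidently wrong prediction produces the largest loss, which is why top-loss samples are good candidates when looking for mislabeled data.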

In [None]:
# Compute losses (negative log probability of true class)
losses = -np.log(probs[np.arange(len(y_test)), y_test.astype(int)] + 1e-10)

# Get top losses using tsai-rs
top_k = 10
top_loss_indices = tsai_rs.top_losses(losses, k=top_k)

print(f"Top {top_k} losses:")
print(f"{'Index':<8} {'True':<8} {'Pred':<8} {'Loss':<10} {'Confidence':<12}")
print("-" * 50)
for idx in top_loss_indices:
    true_label = y_test[idx]
    pred_label = y_pred[idx]
    loss = losses[idx]
    conf = probs[idx].max()
    marker = "WRONG" if true_label != pred_label else ""
    print(f"{idx:<8} {true_label:<8} {pred_label:<8} {loss:<10.4f} {conf:<12.4f} {marker}")

## Visualize Top Loss Samples

In [None]:
# Visualize time series with highest losses
n_show = min(6, len(top_loss_indices))
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for i, idx in enumerate(top_loss_indices[:n_show]):
    ts = X_test_std[idx, 0, :]  # First variable
    true_label = y_test[idx]
    pred_label = y_pred[idx]
    conf = probs[idx].max()
    
    axes[i].plot(ts)
    color = 'red' if true_label != pred_label else 'green'
    axes[i].set_title(f'True: {true_label}, Pred: {pred_label}\nConf: {conf:.2%}', color=color)

plt.suptitle('Top Loss Samples (Red = Misclassified)')
plt.tight_layout()
plt.show()

## Prediction Confidence Distribution

In [None]:
# Analyze confidence distribution
confidences = probs.max(axis=1)
correct_mask = y_pred == y_test

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Overall confidence distribution
axes[0].hist(confidences, bins=20, edgecolor='black', alpha=0.7)
axes[0].axvline(confidences.mean(), color='red', linestyle='--', label=f'Mean: {confidences.mean():.2f}')
axes[0].set_xlabel('Confidence')
axes[0].set_ylabel('Count')
axes[0].set_title('Overall Confidence Distribution')
axes[0].legend()

# Confidence: correct vs incorrect
axes[1].hist(confidences[correct_mask], bins=15, alpha=0.6, label=f'Correct (n={correct_mask.sum()})', color='green')
axes[1].hist(confidences[~correct_mask], bins=15, alpha=0.6, label=f'Incorrect (n={(~correct_mask).sum()})', color='red')
axes[1].set_xlabel('Confidence')
axes[1].set_ylabel('Count')
axes[1].set_title('Confidence: Correct vs Incorrect Predictions')
axes[1].legend()

plt.tight_layout()
plt.show()

print(f"Mean confidence (correct): {confidences[correct_mask].mean():.3f}")
print(f"Mean confidence (incorrect): {confidences[~correct_mask].mean():.3f}")
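
The histograms compare the two groups visually. A complementary numeric check is to bin samples by confidence and measure the accuracy inside each bin. The sketch below is one possible way to do that in NumPy; `accuracy_by_confidence`, the bin count, and the demo arrays are assumptions made for illustration.

```python
import numpy as np

def accuracy_by_confidence(confidences, correct, n_bins=5):
    """Return (bin_low, bin_high, count, accuracy) for each non-empty bin."""
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitize against the interior edges so a confidence of exactly 1.0
    # lands in the last bin instead of overflowing past it
    bin_ids = np.digitize(confidences, edges[1:-1])
    rows = []
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            rows.append((float(edges[b]), float(edges[b + 1]),
                         int(in_bin.sum()), float(correct[in_bin].mean())))
    return rows

conf_demo = np.array([0.95, 0.91, 0.85, 0.62, 0.55, 0.45, 1.0])
correct_demo = np.array([True, True, True, False, True, False, True])

rows_demo = accuracy_by_confidence(conf_demo, correct_demo)
for low, high, count, acc in rows_demo:
    print(f"[{low:.1f}, {high:.1f}]  n={count}  accuracy={acc:.2f}")
```

In the notebook, passing `confidences` and `correct_mask` gives the same table for the test set. If accuracy in a bin is far below the bin's confidence range, the model is overconfident there.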

## Per-Class Performance Analysis

In [None]:
# Analyze per-class metrics
classes = np.unique(y_test)

print(f"{'Class':<10} {'Samples':<10} {'Accuracy':<12} {'Avg Conf':<12} {'Avg Loss'}")
print("-" * 60)

class_stats = []
for cls in classes:
    mask = y_test == cls
    n_samples = mask.sum()
    acc = (y_pred[mask] == y_test[mask]).mean()
    avg_conf = confidences[mask].mean()
    avg_loss = losses[mask].mean()
    
    class_stats.append((cls, n_samples, acc, avg_conf, avg_loss))
    print(f"{cls:<10} {n_samples:<10} {acc:<12.2%} {avg_conf:<12.3f} {avg_loss:.4f}")

In [None]:
# Visualize per-class accuracy
class_labels = [str(s[0]) for s in class_stats]
accuracies = [s[2] for s in class_stats]

fig, ax = plt.subplots(figsize=(12, 6))
bars = ax.bar(class_labels, accuracies, color='steelblue', edgecolor='black')
ax.axhline(actual_acc, color='red', linestyle='--', label=f'Overall: {actual_acc:.2%}')

# Color bars based on performance
for bar, acc in zip(bars, accuracies):
    if acc < actual_acc - 0.1:
        bar.set_color('salmon')
    elif acc > actual_acc + 0.1:
        bar.set_color('lightgreen')

ax.set_xlabel('Class')
ax.set_ylabel('Accuracy')
ax.set_title('Per-Class Accuracy')
ax.legend()
ax.set_ylim(0, 1)

plt.tight_layout()
plt.show()

## Common Misclassifications

In [None]:
from collections import Counter

# Find most common misclassification pairs
misclassified_mask = y_pred != y_test
misclassified_pairs = list(zip(y_test[misclassified_mask], y_pred[misclassified_mask]))
pair_counts = Counter(misclassified_pairs)

print("Most common misclassifications:")
print(f"{'True':<10} {'Predicted':<12} {'Count':<10} {'% of Errors'}")
print("-" * 45)

total_errors = len(misclassified_pairs)
for (true_cls, pred_cls), count in pair_counts.most_common(10):
    pct = count / total_errors * 100
    print(f"{true_cls:<10} {pred_cls:<12} {count:<10} {pct:.1f}%")
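
The same information is contained in the off-diagonal cells of a confusion matrix. As an alternative sketch, assuming rows are true classes and columns are predicted classes, the most confused pairs can be read directly from the matrix. `top_confusions` and `cm_pairs_demo` are hypothetical names.

```python
import numpy as np

def top_confusions(cm, k=3):
    """Return up to k (true_class, predicted_class, count) tuples, largest first."""
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)          # ignore correct predictions
    # Sort all cells of the flattened matrix, largest count first
    flat_order = np.argsort(off_diag, axis=None)[::-1][:k]
    pairs = []
    for flat in flat_order:
        i, j = np.unravel_index(flat, off_diag.shape)
        if off_diag[i, j] > 0:             # skip cells with no errors
            pairs.append((int(i), int(j), int(off_diag[i, j])))
    return pairs

cm_pairs_demo = np.array([
    [8, 2, 0],
    [1, 9, 0],
    [0, 4, 6],
])

pairs_demo = top_confusions(cm_pairs_demo, k=2)
for true_cls, pred_cls, count in pairs_demo:
    print(f"true {true_cls} -> predicted {pred_cls}: {count} errors")
```

This version needs only the matrix, which is convenient when the per-sample predictions are no longer available.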

## Hard vs Easy Samples

In [None]:
# Categorize samples by difficulty
easy_threshold = 0.9
hard_threshold = 0.6

easy_mask = (confidences >= easy_threshold) & correct_mask
hard_mask = (confidences < hard_threshold) | ~correct_mask
medium_mask = ~easy_mask & ~hard_mask

print("Sample Difficulty Distribution:")
print(f"  Easy (conf >= {easy_threshold}, correct):   {easy_mask.sum()} ({easy_mask.mean():.1%})")
print(f"  Medium:                                    {medium_mask.sum()} ({medium_mask.mean():.1%})")
print(f"  Hard (conf < {hard_threshold} or wrong):   {hard_mask.sum()} ({hard_mask.mean():.1%})")

In [None]:
# Compare easy vs hard samples
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# Easy samples
easy_indices = np.where(easy_mask)[0][:3]
for i, idx in enumerate(easy_indices):
    ts = X_test_std[idx, 0, :]
    axes[0, i].plot(ts, color='green')
    axes[0, i].set_title(f'Easy - Conf: {confidences[idx]:.2%}')
axes[0, 0].set_ylabel('Easy Samples')

# Hard samples
hard_indices = np.where(hard_mask)[0][:3]
for i, idx in enumerate(hard_indices):
    ts = X_test_std[idx, 0, :]
    color = 'red' if y_pred[idx] != y_test[idx] else 'orange'
    axes[1, i].plot(ts, color=color)
    status = 'Wrong' if y_pred[idx] != y_test[idx] else 'Low conf'
    axes[1, i].set_title(f'Hard ({status}) - Conf: {confidences[idx]:.2%}')
axes[1, 0].set_ylabel('Hard Samples')

plt.suptitle('Comparison: Easy vs Hard Samples')
plt.tight_layout()
plt.show()

## Insights from Prediction Dynamics

When visualizing predictions during training, look for:

1. **Easy vs difficult classes**: Which classes converge first?
2. **Stability**: Do predictions oscillate or converge smoothly?
3. **Class imbalance effects**: Are frequent classes dominating?
4. **Learning rate effects**: A learning rate that is too high shows up as jittery, unstable predictions
5. **Overfitting signals**: Training predictions become near-perfect while validation predictions stay unstable
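
The points above can be made concrete by recording the predicted probabilities for a fixed batch after every epoch and summarizing how they move. The sketch below is a minimal illustration in NumPy. No model is trained: the per-epoch probabilities are synthetic (noise plus a signal that grows with the epoch), and `summarize_dynamics`, `history`, and `y_fixed` are hypothetical names. With a real training loop, `history` would be filled with the model's outputs on the same batch at the end of each epoch.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def summarize_dynamics(prob_history, y_true):
    """Summarize predictions on a fixed batch across epochs.

    prob_history has shape (n_epochs, n_samples, n_classes).
    Returns per-sample flip counts, the true-class probability per
    epoch and sample, and the batch accuracy per epoch.
    """
    preds = prob_history.argmax(axis=-1)                  # (n_epochs, n_samples)
    # A flip is a change of predicted class between consecutive epochs
    flips = (preds[1:] != preds[:-1]).sum(axis=0)         # (n_samples,)
    true_prob = prob_history[:, np.arange(len(y_true)), y_true]
    acc_per_epoch = (preds == y_true).mean(axis=1)        # (n_epochs,)
    return flips, true_prob, acc_per_epoch

# Synthetic stand-in for a training run: the signal grows with each epoch
rng = np.random.default_rng(0)
n_epochs, n_samples, n_cls = 10, 8, 3
y_fixed = rng.integers(0, n_cls, size=n_samples)
signal = np.eye(n_cls)[y_fixed] * 4.0
history = np.stack([
    softmax((epoch / (n_epochs - 1)) * signal
            + rng.normal(scale=0.5, size=(n_samples, n_cls)))
    for epoch in range(n_epochs)
])

flips, true_prob, acc_per_epoch = summarize_dynamics(history, y_fixed)
print("Flips per sample:  ", flips)
print("Accuracy per epoch:", np.round(acc_per_epoch, 2))
```

Samples with many flips are the unstable ones, and calling `plt.plot(true_prob)` draws one line per sample showing how the true-class probability evolves over the epochs.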

## Summary

This notebook demonstrated prediction analysis techniques:

### Analysis Tools
- `confusion_matrix`: Visualize classification performance
- `top_losses`: Identify hardest samples

### Key Insights
1. **Confidence analysis**: Compare confidence on correct and incorrect predictions. Here the gap is built into the simulation, so verify it on real model outputs
2. **Per-class performance**: Some classes may be harder
3. **Common errors**: Identify frequently confused class pairs
4. **Hard samples**: Focus improvement efforts on difficult cases

### Applications
- Model debugging and improvement
- Data quality assessment
- Building trust in predictions

In [None]:
# Quick reference
print("Prediction Analysis Functions:")
print("==============================")
print("\n# Confusion matrix")
print("cm = tsai_rs.confusion_matrix(y_true, y_pred, n_classes)")
print("\n# Top losses")
print("indices = tsai_rs.top_losses(losses, k=10)")