# 🔢 MNIST Digit Recognition Dashboard
Interactive demo using a Random Forest classifier

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import ipywidgets as widgets
from IPython.display import display, clear_output

In [None]:
# Load and prepare data.
# NOTE(review): fetch_openml pulls the dataset from OpenML, so the first run
# presumably needs network access.
print("Loading MNIST dataset...")
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
# Labels are cast to int so they compare directly against integer digits later on.
X, y = mnist.data, mnist.target.astype(int)

# Hold out a stratified 20% test split so every digit keeps its class proportion.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Train on only the first TRAIN_SUBSET_SIZE rows to keep the demo fast.
# train_test_split shuffles by default, so these leading rows form a random subset.
TRAIN_SUBSET_SIZE = 10_000
X_train_small = X_train[:TRAIN_SUBSET_SIZE]
y_train_small = y_train[:TRAIN_SUBSET_SIZE]

print(f"Training samples: {len(X_train_small)}")
print(f"Test samples: {len(X_test)}")

In [None]:
# Train the model on the training subset.
print("Training Random Forest classifier...")
model = RandomForestClassifier(
    n_estimators=100,  # number of trees in the forest
    max_depth=20,      # cap tree depth to limit model size/complexity
    n_jobs=-1,         # use all available CPU cores
    random_state=42,   # reproducible training
)
model.fit(X_train_small, y_train_small)

# Evaluate on the full held-out test split.
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"✅ Model trained! Test accuracy: {accuracy:.2%}")

---
## 📊 Model Performance

In [None]:
# Visualize model performance as two side-by-side panels.
fig, (ax_cm, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: confusion matrix (rows = actual digit, columns = predicted digit).
cm = confusion_matrix(y_test, y_pred)
heatmap = ax_cm.imshow(cm, cmap='Blues')
ax_cm.set_title('Confusion Matrix', fontsize=14)
ax_cm.set_xlabel('Predicted')
ax_cm.set_ylabel('Actual')
ax_cm.set_xticks(range(10))
ax_cm.set_yticks(range(10))
fig.colorbar(heatmap, ax=ax_cm)

# Right panel: per-digit accuracy, i.e. the share of each actual digit's
# samples that landed on the diagonal (per-class recall). Bars are shaded
# along the RdYlGn colormap by that fraction.
class_acc = np.diag(cm) / cm.sum(axis=1)
ax_acc.bar(range(10), class_acc * 100, color=plt.cm.RdYlGn(class_acc))
ax_acc.set_title('Accuracy per Digit', fontsize=14)
ax_acc.set_xlabel('Digit')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_xticks(range(10))
ax_acc.set_ylim(0, 100)
# Dashed reference line at the overall test accuracy.
ax_acc.axhline(y=accuracy * 100, color='red', linestyle='--', label=f'Overall: {accuracy:.1%}')
ax_acc.legend()

plt.tight_layout()
plt.show()

---
## 🎯 Interactive Prediction

In [None]:
# Interactive widget: each click classifies a random test image and shows the
# model's per-class probabilities next to it.
output = widgets.Output()

def predict_random_sample(b):
    """Classify a random test image and plot it beside its class probabilities.

    Args:
        b: The clicked ``widgets.Button`` (``None`` for the initial direct
            call). Unused; it is only there to satisfy the ``on_click``
            callback signature.
    """
    with output:
        clear_output(wait=True)

        # Pick a random test sample.
        idx = np.random.randint(0, len(X_test))
        sample = X_test[idx]
        true_label = y_test[idx]

        # A single predict_proba call gives both outputs. The forest's
        # prediction is the class with the highest mean probability, so there
        # is no need to run the model a second time via predict().
        classes = model.classes_
        pred_proba = model.predict_proba(X_test[idx:idx + 1])[0]
        pred_idx = int(np.argmax(pred_proba))
        pred_label = classes[pred_idx]
        is_correct = pred_label == true_label

        fig, axes = plt.subplots(1, 2, figsize=(10, 4))

        # Left: the digit image, titled with a correct/incorrect indicator.
        axes[0].imshow(sample.reshape(28, 28), cmap='gray')
        axes[0].axis('off')
        if is_correct:
            axes[0].set_title(f'[OK] Correct! (True: {true_label})', fontsize=14, color='green')
        else:
            axes[0].set_title(f'[X] Wrong! (True: {true_label}, Pred: {pred_label})', fontsize=14, color='red')

        # Right: probability bars, indexed by class position (matching the
        # predict_proba columns). The true class is green; a wrong prediction
        # is highlighted in orange.
        colors = ['green' if c == true_label else 'steelblue' for c in classes]
        if not is_correct:
            colors[pred_idx] = 'orange'
        positions = np.arange(len(classes))
        axes[1].barh(positions, pred_proba * 100, color=colors)
        axes[1].set_yticks(positions)
        axes[1].set_yticklabels(classes)
        axes[1].set_xlabel('Confidence (%)')
        axes[1].set_title('Prediction Probabilities', fontsize=14)
        axes[1].set_xlim(0, 100)

        plt.tight_layout()
        plt.show()

button = widgets.Button(
    description='🎲 Random Sample',
    button_style='primary',
    layout=widgets.Layout(width='200px', height='40px')
)
button.on_click(predict_random_sample)

display(widgets.VBox([button, output]))

# Trigger initial prediction so the output area is not empty on first render.
predict_random_sample(None)

---
## 🔍 Explore Specific Digits

In [None]:
# Gallery view: choose a digit from the dropdown to see up to 8 random test
# images of it, each annotated with the model's prediction and confidence.
output2 = widgets.Output()

def show_digit_samples(change):
    """Render up to 8 random test images of the selected digit with predictions.

    Args:
        change: The ipywidgets ``observe`` change notification (``None`` for the
            initial direct call). Unused; the digit is read from
            ``digit_dropdown.value``.
    """
    with output2:
        clear_output(wait=True)
        digit = digit_dropdown.value

        # Find test samples of this digit.
        indices = np.where(y_test == digit)[0]
        if indices.size == 0:
            # Nothing to show; also avoids calling predict_proba on zero rows.
            print(f'No test samples of digit "{digit}".')
            return
        sample_indices = np.random.choice(indices, min(8, len(indices)), replace=False)

        # Score all selected samples in one batched call rather than two model
        # calls per image. The predicted class is the most probable column.
        samples = X_test[sample_indices]
        probas = model.predict_proba(samples)
        best = probas.argmax(axis=1)
        preds = model.classes_[best]
        confidences = probas[np.arange(len(best)), best] * 100

        fig, axes = plt.subplots(2, 4, figsize=(10, 5))
        axes = axes.flatten()
        # Hide every panel's axes up front so that, when there are fewer than
        # 8 samples, the unused panels stay blank instead of showing ticks.
        for ax in axes:
            ax.axis('off')

        for ax, sample, pred, confidence in zip(axes, samples, preds, confidences):
            ax.imshow(sample.reshape(28, 28), cmap='gray')
            status = '[OK]' if pred == digit else '[X]'
            ax.set_title(f'{status} Pred: {pred} ({confidence:.0f}%)', fontsize=10)

        plt.suptitle(f'Random samples of digit "{digit}"', fontsize=14)
        plt.tight_layout()
        plt.show()

digit_dropdown = widgets.Dropdown(
    options=list(range(10)),
    value=0,
    description='Digit:',
    layout=widgets.Layout(width='150px')
)
# Re-render the gallery whenever the selected digit changes.
digit_dropdown.observe(show_digit_samples, names='value')

display(widgets.VBox([digit_dropdown, output2]))
show_digit_samples(None)