In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr

In [2]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Define the CNN model
class MNISTClassifier(nn.Module):
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

In [4]:
# Load and preprocess MNIST dataset
def load_mnist_data(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Download and load training data
    full_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    
    # Split into train and validation sets (80/20)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader

In [5]:
# Training function
def train_model(model, train_loader, val_loader, epochs=10, learning_rate=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    model.to(device)
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * images.size(0)
        
        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * images.size(0)
                
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        # Print statistics
        train_loss = train_loss / len(train_loader.dataset)
        val_loss = val_loss / len(val_loader.dataset)
        accuracy = 100 * correct / total
        
        print(f"Epoch {epoch+1}/{epochs} - "
                f"Train Loss: {train_loss:.4f} - "
                f"Val Loss: {val_loss:.4f} - "
                f"Val Accuracy: {accuracy:.2f}%")
    
    return model

In [6]:
# Function to preprocess sketchpad input for the model
def preprocess_image(input_data):    
    # Extract image from dictionary (handles newer Gradio versions)
    if isinstance(input_data, dict):
        if 'image' in input_data:
            image = input_data['image']
        elif 'composite' in input_data:
            image = input_data['composite']
        else:
            raise ValueError("Dictionary input missing image data")
    elif isinstance(input_data, np.ndarray):
        image = input_data
    else:
        try:
            image = np.array(input_data)
        except:
            raise ValueError(f"Unsupported input type: {type(input_data)}")
    
    print(f"Image shape before processing: {image.shape}")
    
    # Convert RGBA/RGB to grayscale if needed
    if image.ndim == 3:
        if image.shape[2] == 4:  # RGBA image
            image = image[..., :3]  # Drop alpha channel
        image = np.mean(image, axis=2)  # Convert to grayscale
    
    # Invert colors (MNIST has white digits on black background)
    image = 255 - image
    
    # Convert to tensor and normalize
    image_tensor = transforms.functional.to_tensor(image).unsqueeze(0)  # Add batch dimension
    image_tensor = transforms.functional.resize(image_tensor, (28, 28))
    image_tensor = transforms.functional.normalize(image_tensor, (0.1307,), (0.3081,))
    
    return image_tensor

In [7]:
# Function to create probability bar chart
def plot_probabilities(probabilities):
    fig, ax = plt.subplots()
    bars = ax.bar(range(10), probabilities)
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Prediction Probabilities')
    ax.set_xticks(range(10))
    ax.set_ylim(0, 1)
    
    # Add probability text on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}',
                ha='center', va='bottom')
    
    return fig

In [8]:
# Prediction function for Gradio interface
def predict_digit(input_data):
    try:
        # Preprocess the input image
        input_tensor = preprocess_image(input_data).to(device)
        
        # Make prediction
        model.eval()
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]
        
        # Get predicted digit
        predicted_digit = int(np.argmax(probabilities))
        
        # Create probability plot
        prob_plot = plot_probabilities(probabilities)
        
        return predicted_digit, prob_plot
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        return "Error", None

In [9]:
# Load data and train model
print("Loading data...")
train_loader, val_loader = load_mnist_data(batch_size=64)

print("Creating model...")
model = MNISTClassifier()

print("Training model...")
model = train_model(model, train_loader, val_loader, epochs=10)

Loading data...
Creating model...
Training model...
Epoch 1/10 - Train Loss: 0.2390 - Val Loss: 0.0640 - Val Accuracy: 98.13%
Epoch 2/10 - Train Loss: 0.0947 - Val Loss: 0.0589 - Val Accuracy: 98.30%
Epoch 3/10 - Train Loss: 0.0743 - Val Loss: 0.0453 - Val Accuracy: 98.72%
Epoch 4/10 - Train Loss: 0.0605 - Val Loss: 0.0407 - Val Accuracy: 98.83%
Epoch 5/10 - Train Loss: 0.0507 - Val Loss: 0.0350 - Val Accuracy: 98.99%
Epoch 6/10 - Train Loss: 0.0466 - Val Loss: 0.0395 - Val Accuracy: 98.94%
Epoch 7/10 - Train Loss: 0.0429 - Val Loss: 0.0347 - Val Accuracy: 99.04%
Epoch 8/10 - Train Loss: 0.0412 - Val Loss: 0.0341 - Val Accuracy: 99.09%
Epoch 9/10 - Train Loss: 0.0361 - Val Loss: 0.0389 - Val Accuracy: 98.99%
Epoch 10/10 - Train Loss: 0.0364 - Val Loss: 0.0340 - Val Accuracy: 99.09%


In [10]:
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# MNIST Digit Recognition")
    gr.Markdown("Draw a digit (0-9) in the box below and see the model's prediction.")
    
    with gr.Row():
        # Sketchpad with explicit numpy output
        sketchpad = gr.Sketchpad(label="Draw Digit", 
                                image_mode="L",
                                type="numpy")
        with gr.Column():
            label = gr.Label(label="Predicted Digit")
            plot = gr.Plot(label="Prediction Probabilities")
    
    sketchpad.change(
        fn=predict_digit,
        inputs=sketchpad,
        outputs=[label, plot]
    )

In [11]:
if __name__ == "__main__":
    demo.launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


Image shape before processing: (800, 800)
Error during prediction: 'NoneType' object has no attribute 'shape'
Image shape before processing: (800, 800)
Error during prediction: 'NoneType' object has no attribute 'shape'
Image shape before processing: (800, 800)
