## 1. Imports & Configuration

Install `torch`, `torchvision`, and `numpy` before running. `tkinter` is also required; it ships with most Python installations.

In [42]:
import tkinter as tk
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms


# Configuration
GRID_SIZE = 28           # Drawing grid is 28x28 cells (one cell per MNIST pixel)
CELL_SIZE = 14           # On-screen size of one grid cell, in pixels
# Short aliases: keep both spellings defined so every cell in the notebook resolves.
GRID = GRID_SIZE
CELL = CELL_SIZE
PREDICT_DELAY_MS = 180   # Wait 180ms after drawing stops before predicting

# Prefer the GPU when one is available; every model and tensor is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## 2. Convolutional Neural Network (CNN)

Steps:
1.  **Convolution (Filters)**: Scans the image to find simple features like edges and curves.
2.  **ReLU (Activation)**: Filters out negative values (darkness) and keeps positive signals.
3.  **MaxPool (Pooling)**: Shrinks the image to focus on the most important parts.
4.  **Fully Connected**: The final decision maker that identifies which digit was written.

In [43]:
class SimpleCNN(nn.Module):
    """Small CNN digit classifier: two conv/ReLU/max-pool stages, then two linear layers.

    ``forward`` returns the class scores together with the pooled output of
    each convolution stage so callers can visualize the intermediate features.
    """

    def __init__(self):
        super().__init__()

        # Convolution stages: 1 -> 16 -> 32 channels; 3x3 kernels with padding=1
        # leave the spatial size unchanged.
        self.conv_layer1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv_layer2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Activation and pooling modules are stateless and reused by every stage.
        self.relu_activation = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head. 32 channels * 7 * 7 is the flattened size of the
        # second stage's output for a 28x28 input (each 2x2 pool halves the size).
        feature_count = 32 * 7 * 7
        self.fc_layer1 = nn.Linear(feature_count, 128)
        self.dropout_layer = nn.Dropout(0.5)
        self.fc_layer2 = nn.Linear(128, 10)

    def forward(self, input_image):
        """Return ``(scores, conv1_features, conv2_features)`` for a batch of images.

        For a [Batch_Size, 1, 28, 28] input the shapes are [Batch_Size, 10],
        [Batch_Size, 16, 14, 14] and [Batch_Size, 32, 7, 7] respectively.
        """
        # Stage 1: convolution -> ReLU -> 2x2 max-pool
        stage1 = self.relu_activation(self.conv_layer1(input_image))
        conv1_features = self.max_pool(stage1)

        # Stage 2: same pipeline applied to the stage-1 features
        stage2 = self.relu_activation(self.conv_layer2(conv1_features))
        conv2_features = self.max_pool(stage2)

        # Flatten to one row per image; in_features equals 32 * 7 * 7
        flat_features = conv2_features.view(-1, self.fc_layer1.in_features)

        # Hidden layer; nn.Dropout is only active while the module is in training mode
        hidden = self.relu_activation(self.fc_layer1(flat_features))
        hidden = self.dropout_layer(hidden)

        # Final prediction scores (raw logits, no softmax applied here)
        final_output = self.fc_layer2(hidden)
        return final_output, conv1_features, conv2_features

## 3. Training with data

The model learns from the handwritten digits in the MNIST dataset.

**The Training Loop:**
1.  **Forward Pass**: The AI guesses what the number is.
2.  **Calculate Loss**: Calculate how wrong the guess was.
3.  **Backward Pass**: Adjust the AI's internal numbers (weights) to fix the mistake.
4.  **Repeat**: Do this thousands of times.

In [44]:
def train_and_get_model():
    """Train a fresh SimpleCNN on MNIST, print its test accuracy, and return it.

    MNIST is downloaded into ``./data`` if it is not already there. Images are
    converted with ``ToTensor`` and then normalized with mean 0.5 / std 0.5, so
    any caller doing inference must apply the same normalization.

    Returns:
        The trained model, on ``device`` and left in eval mode.
    """
    print("Preparing model...")
    model = SimpleCNN().to(device)
    print(model) # summary

    print("Downloading MNIST data")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Download Training Data
    train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

    # Download Test Data
    test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=False)

    # Loss function & optimizer (Adam)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    EPOCHS = 5
    print("Starting Training")

    # Training mode enables dropout
    model.train()

    for epoch in range(EPOCHS):
        total_loss = 0.0

        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # 1. Zero the parameter gradients
            optimizer.zero_grad()

            # 2. Forward Pass: Get predictions (feature maps are not needed here)
            predictions, _, _ = model(batch_images)

            # 3. Calculate Loss: How wrong were we?
            loss = loss_function(predictions, batch_labels)

            # 4. Backward Pass: Calculate corrections
            loss.backward()

            # 5. Optimize: Update weights
            optimizer.step()

            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)

        # f-string formatting: concatenating str with int/float raises TypeError
        print(f"Epoch {epoch + 1}/{EPOCHS}, Average Loss: {average_loss:.4f}")

    print("Finished Training.")

    # Evaluate the model accuracy
    correct_guesses = 0
    total_samples = 0
    model.eval()

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs, _, _ = model(images)
            # Index of the highest score per row is the predicted digit
            predicted_class = outputs.argmax(dim=1)

            total_samples += labels.size(0)
            correct_guesses += (predicted_class == labels).sum().item()

    accuracy = 100 * correct_guesses / total_samples

    rounded_accuracy = round(accuracy, 2)
    print("Accuracy on test images:", rounded_accuracy, "%")

    return model

## 4. Interface

The application has the following features:
1.  **Drawing Pad**: A grid where user can click and drag to draw.
2.  **Visualization Panels**: Two panels that show what the Convolutions (filters) are "seeing."
3.  **Probabilities**: A bar chart showing the AI's confidence.

In [None]:
# Center the Digit
def center_grid_content(grid):
    '''
    Shift the drawn content so its center of mass sits at the middle of the grid.

    grid: 2-D numpy array of pixel intensities.

    Returns a new array of the same shape and dtype in which every strictly
    positive cell has been moved by the same whole-cell offset; cells that would
    land outside the grid are dropped. If the intensities sum to zero or less
    (e.g. an empty grid) the input array itself is returned unchanged.
    '''
    height, width = grid.shape
    total_mass = grid.sum()

    if total_mass <= 0:
        return grid

    # Intensity-weighted mean row / column index
    y_indices, x_indices = np.indices((height, width))
    center_y = (y_indices * grid).sum() / total_mass
    center_x = (x_indices * grid).sum() / total_mass

    shift_y = int(round(height/2 - center_y))
    shift_x = int(round(width/2 - center_x))

    # Move all positive cells in one vectorized step instead of a per-pixel loop
    src_rows, src_cols = np.nonzero(grid > 0)
    dst_rows = src_rows + shift_y
    dst_cols = src_cols + shift_x
    in_bounds = (
        (dst_rows >= 0) & (dst_rows < height) &
        (dst_cols >= 0) & (dst_cols < width)
    )

    new_grid = np.zeros_like(grid)
    new_grid[dst_rows[in_bounds], dst_cols[in_bounds]] = grid[src_rows[in_bounds], src_cols[in_bounds]]

    return new_grid


def value_to_hex_color(value):
    '''
    Map an intensity to a grayscale "#rrggbb" string.

    The value is clamped to [0, 1], scaled to 0-255 (truncating, not rounding),
    and used for all three color channels.
    '''
    clamped = float(np.clip(value, 0.0, 1.0))
    channel = format(int(clamped * 255), "02x")
    return "#" + channel * 3


# Main Application
def run_pixel_grid_realtime_app(model):
    '''
    Open the Tkinter drawing window and run live digit recognition with `model`.

    The window shows, top to bottom: a probability bar per digit, the first 8
    feature maps of each convolution stage, a GRID x GRID drawing pad, and a
    Clear button. A prediction runs PREDICT_DELAY_MS after the last change to
    the drawing. This call blocks in the Tk main loop until the window closes.

    model: called with a [1, 1, GRID, GRID] float tensor on `device`; must return
           (logits, conv1_features, conv2_features). It is put into eval mode.
    '''
    model.eval()

    # Bind the config constants once to descriptive local names
    grid_size = GRID
    cell_size = CELL
    feature_map_count = 8   # channels displayed per convolution stage (2 rows of 4)

    # Initialize main window
    root_window = tk.Tk()
    root_window.title("Real-Time Digit Recognizer")

    # Drawing state: one intensity in [0, 1] per grid cell (mutated in place)
    pixel_grid = np.zeros((grid_size, grid_size), dtype=np.float32)

    # Handle of the pending `after` callback, or None when no prediction is queued
    predict_timer_job = None

    # Layout Structure
    top_frame = tk.Frame(root_window)
    top_frame.pack(padx=10, pady=10)

    middle_frame = tk.Frame(root_window)
    middle_frame.pack(padx=10, pady=10)

    bottom_frame = tk.Frame(root_window)
    bottom_frame.pack(padx=10, pady=10)


    # SECTION A: Probability Bars
    prob_canvas = tk.Canvas(top_frame, width=grid_size*cell_size, height=70, bg="#111111", highlightthickness=0)
    prob_canvas.pack()

    # Create the labels from "0" to "9"
    for digit in range(10):
        x_pos = digit * (cell_size*2.75) + 13
        prob_canvas.create_text(x_pos, 10, text=str(digit), fill="white", font=("Arial", 12, "bold"))

    # Create the bars that go up and down
    probability_rectangles = []
    for digit in range(10):
        x_start = digit * (cell_size*2.7) + 4
        y_start = 25
        x_end = x_start + (cell_size*2.2)
        y_end = 65

        # Green rectangle, created with zero height (top == bottom == y_end)
        rect_id = prob_canvas.create_rectangle(x_start, y_end, x_end, y_end, fill="#66ff66", width=0)
        probability_rectangles.append((rect_id, x_start, y_start, x_end, y_end))

    prediction_text = prob_canvas.create_text((grid_size*cell_size)//2, 50, text="Draw a digit...", fill="white", font=("Arial", 12))


    # SECTION B: Convolutional Feature Maps (What the AI Sees)

    def build_feature_panel(title, map_height, map_width, scale):
        '''
        Create one titled panel holding `feature_map_count` feature maps.

        Each map is drawn as map_height x map_width canvas rectangles of
        `scale` pixels. Returns (canvas, pixels) where pixels[i] lists the
        rectangle ids of map i in row-major order.
        '''
        panel = tk.Frame(middle_frame, bg="#111111")
        panel.pack(side="left", padx=10)

        tk.Label(panel, text=title, font=("Arial", 10, "bold"), bg="#111111", fg="white").pack()

        # Canvas size needed for 2 rows of 4 maps
        canvas_w = 4 * map_width * scale + 10
        canvas_h = 2 * map_height * scale + 10

        canvas = tk.Canvas(panel, width=canvas_w, height=canvas_h, bg="black", highlightthickness=0)
        canvas.pack(pady=5)

        all_pixels = []
        for i in range(feature_map_count):
            row_group, col_group = divmod(i, 4)

            # Top-left corner of this map, with a 2px gap between maps
            base_x = col_group * (map_width * scale + 2)
            base_y = row_group * (map_height * scale + 2)

            pixel_ids = []
            for r in range(map_height):
                for c in range(map_width):
                    x0 = base_x + c * scale
                    y0 = base_y + r * scale
                    x1 = x0 + scale
                    y1 = y0 + scale
                    rect_id = canvas.create_rectangle(x0, y0, x1, y1, fill="#000000", width=0)
                    pixel_ids.append(rect_id)
            all_pixels.append(pixel_ids)

        return canvas, all_pixels

    # Conv1 Panel (Edges): 14x14 maps scaled up 4x so they are easier to see
    layer1_canvas, layer1_pixels = build_feature_panel("Layer 1: Edges", 14, 14, 4)

    # Conv2 Panel (Patterns): 7x7 maps scaled up 8x since they are smaller
    layer2_canvas, layer2_pixels = build_feature_panel("Layer 2: Patterns", 7, 7, 8)


    # SECTION C: Drawing Area (Input)

    drawing_canvas = tk.Canvas(bottom_frame, width=grid_size*cell_size, height=grid_size*cell_size, bg="#111111", highlightthickness=0)
    drawing_canvas.pack()

    # Grid of rectangle ids for drawing, indexed as grid_rect_ids[row][col]
    grid_rect_ids = [[None] * grid_size for _ in range(grid_size)]

    for r in range(grid_size):
        for c in range(grid_size):
            x0 = c * cell_size
            y0 = r * cell_size
            x1 = x0 + cell_size - 1
            y1 = y0 + cell_size - 1
            # Create a single cell
            rect_id = drawing_canvas.create_rectangle(x0, y0, x1, y1, fill=value_to_hex_color(0), width=1, outline="#222222")
            grid_rect_ids[r][c] = rect_id


    # Functionality: Clearing the Canvas
    button_frame = tk.Frame(root_window)
    button_frame.pack(pady=10)

    def clear_canvas_action():
        '''
        Redraw black grid, reset probabilities display, and reset feature maps display
        '''
        nonlocal predict_timer_job

        # Drop any queued prediction so it cannot fire on the now-empty grid
        # and overwrite the "Cleared" message.
        if predict_timer_job is not None:
            root_window.after_cancel(predict_timer_job)
            predict_timer_job = None

        pixel_grid[:, :] = 0.0

        for r in range(grid_size):
            for c in range(grid_size):
                drawing_canvas.itemconfig(grid_rect_ids[r][c], fill=value_to_hex_color(0))

        # Collapses every bar; the status text it sets is replaced on the next line
        update_probability_display(np.zeros(10))
        prob_canvas.itemconfig(prediction_text, text="Cleared. Draw again.")

        update_feature_maps(None, None)

    tk.Button(button_frame, text="Clear", command=clear_canvas_action, width=15, height=2, bg="#dddddd").pack()


    # Drawing Logic
    def on_mouse_drag(event):
        '''
        Paint a soft round brush centered on the mouse position.
        '''
        mouse_x, mouse_y = event.x, event.y

        radius_px = cell_size * 2

        # Only visit cells inside the brush's bounding box (optimization)
        col_min = max(0, int((mouse_x - radius_px) // cell_size))
        col_max = min(grid_size - 1, int((mouse_x + radius_px) // cell_size))
        row_min = max(0, int((mouse_y - radius_px) // cell_size))
        row_max = min(grid_size - 1, int((mouse_y + radius_px) // cell_size))

        grid_changed = False

        for r in range(row_min, row_max + 1):
            for c in range(col_min, col_max + 1):
                # Calculate center of this cell
                cell_center_x = c * cell_size + (cell_size / 2)
                cell_center_y = r * cell_size + (cell_size / 2)

                # Distance from mouse to cell center
                distance = np.sqrt((mouse_x - cell_center_x)**2 + (mouse_y - cell_center_y)**2)

                if distance < radius_px:
                    # Calculate intensity: closer = brighter (quadratic falloff)
                    normalized_dist = distance / radius_px
                    intensity = (1.0 - normalized_dist) ** 2

                    # Cells only ever get brighter while drawing
                    current_val = pixel_grid[r, c]
                    new_val = max(current_val, intensity)

                    # Skip imperceptible changes to avoid needless redraws
                    if new_val > current_val + 0.01:
                        pixel_grid[r, c] = new_val
                        # Update the color on screen
                        drawing_canvas.itemconfig(grid_rect_ids[r][c], fill=value_to_hex_color(new_val))
                        grid_changed = True

        if grid_changed:
            schedule_prediction()

    # Bind mouse events to the function
    drawing_canvas.bind("<B1-Motion>", on_mouse_drag)
    drawing_canvas.bind("<Button-1>", on_mouse_drag)


    # Functionality: Prediction & Updates
    def schedule_prediction():
        '''
        (Re)start the countdown so inference runs PREDICT_DELAY_MS after the last change.
        '''
        nonlocal predict_timer_job
        if predict_timer_job is not None:
            root_window.after_cancel(predict_timer_job)
        predict_timer_job = root_window.after(PREDICT_DELAY_MS, run_inference)

    def prepare_image_for_model():
        '''
        Center the drawing and convert it to a normalized [1, 1, H, W] tensor on `device`.
        '''
        centered_grid = center_grid_content(pixel_grid)
        image_array = centered_grid.astype(np.float32)
        # The training data is normalized with mean 0.5 / std 0.5 (see
        # train_and_get_model), so map the [0, 1] drawing to [-1, 1] the same way.
        image_array = (image_array - 0.5) / 0.5
        # Add batch dimension and channel dimension: [1, 1, 28, 28]
        tensor = torch.from_numpy(image_array).unsqueeze(0).unsqueeze(0)
        return tensor.to(device)

    def update_probability_display(probabilities):
        '''
        Show the most likely digit and resize the ten bars to `probabilities`.
        '''
        best_digit = int(np.argmax(probabilities))
        confidence = float(probabilities[best_digit])

        prob_canvas.itemconfig(prediction_text, text=f"Prediction: {best_digit}  (Confidence: {confidence:.2f})")

        for i in range(10):
            bar_id, x0, y_start, x1, y_end = probability_rectangles[i]
            # Calculate new height
            height = (y_end - y_start) * float(probabilities[i])
            # Update bar coordinates (Tkinter coords: x0, y0, x1, y1)
            prob_canvas.coords(bar_id, x0, y_end - height, x1, y_end)

    def update_feature_maps(layer1_data, layer2_data):
        '''
        Redraw both feature-map panels; passing None for a layer blanks that panel.
        '''

        def update_single_layer(data_tensor, canvas_obj, pixels_list, grid_h, grid_w):
            if data_tensor is None:
                # Reset to black
                for i in range(feature_map_count):
                    for pixel_id in pixels_list[i]:
                        canvas_obj.itemconfig(pixel_id, fill="#000000")
                return

            # Convert the first batch item to a numpy array
            feature_maps = data_tensor[0].detach().cpu().numpy()

            # Loop through the first `feature_map_count` channels
            for i in range(feature_map_count):
                single_map = feature_maps[i]

                # Normalize values to 0-1 for display
                min_val = single_map.min()
                max_val = single_map.max()

                if max_val - min_val > 1e-5:
                    normalized_map = (single_map - min_val) / (max_val - min_val)
                else:
                    # Flat map: show it as black rather than dividing by ~0
                    normalized_map = np.zeros_like(single_map)

                # Update pixels (ids are stored in row-major order)
                pixel_ids = pixels_list[i]
                pixel_index = 0
                for r in range(grid_h):
                    for c in range(grid_w):
                        value = normalized_map[r, c]
                        color = value_to_hex_color(value)
                        canvas_obj.itemconfig(pixel_ids[pixel_index], fill=color)
                        pixel_index += 1

        # Calculate and update Layer 1
        update_single_layer(layer1_data, layer1_canvas, layer1_pixels, 14, 14)

        # Calculate and update Layer 2
        update_single_layer(layer2_data, layer2_canvas, layer2_pixels, 7, 7)

    def run_inference():
        '''
        Run the model on the current drawing and refresh every display.
        '''
        nonlocal predict_timer_job
        predict_timer_job = None

        input_tensor = prepare_image_for_model()

        with torch.no_grad():
            # Get logits AND the feature maps
            logits, features1, features2 = model(input_tensor)

            # Convert logits to probabilities (Softmax)
            probabilities = F.softmax(logits, dim=1)[0].cpu().numpy()

        update_probability_display(probabilities)
        update_feature_maps(features1, features2)

    # Initial Reset
    clear_canvas_action()

    # Start the app loop
    root_window.mainloop()

## 5. Running the App

Finally, we put it all together. Run this cell to start!

**Note:** When you run this, a window will pop up. You must **close that window** to stop this cell from running.

In [46]:
if __name__ == "__main__":
    MODEL_FILENAME = "mnist_cnn.pth"

    checkpoint_missing = not os.path.exists(MODEL_FILENAME)

    if checkpoint_missing:
        # Nothing cached on disk: train from scratch and store the weights.
        print(f"\n[INFO] No saved model found. Starting training...")
        trained_model = train_and_get_model()
        torch.save(trained_model.state_dict(), MODEL_FILENAME)
        print(f"[INFO] Model saved to '{MODEL_FILENAME}'.")
    else:
        print(f"\n[INFO] Found saved model at '{MODEL_FILENAME}'. Loading...")
        trained_model = SimpleCNN().to(device)
        try:
            saved_weights = torch.load(MODEL_FILENAME, map_location=device)
            trained_model.load_state_dict(saved_weights)
            print("[INFO] Model loaded successfully!")
        except Exception as e:
            # Unreadable or incompatible checkpoint: retrain and overwrite it.
            print(f"[WARNING] Could not load model: {e}")
            print("[INFO] Starting fresh training...")
            trained_model = train_and_get_model()
            torch.save(trained_model.state_dict(), MODEL_FILENAME)

    # Launch the application (blocks until the window is closed)
    run_pixel_grid_realtime_app(trained_model)


[INFO] Found saved model at 'mnist_cnn.pth'. Loading...
[INFO] Model loaded successfully!
