
# Supervised Learning example with MNIST dataset

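This notebook trains a small convolutional neural network (CNN) on the MNIST handwritten-digit dataset, evaluates it on the test split, and lets you try it interactively by drawing digits on a canvas.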

In [None]:

import numpy as np, time
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from ipycanvas import Canvas
from ipywidgets import VBox, HBox, Button, Output, Layout
from IPython.display import display, clear_output, IFrame
import netron

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Seed both RNGs so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)


## Data import and visualization

In [None]:
# Both splits only need conversion to tensors; the two pipelines are kept
# separate so they can diverge later (e.g. train-only augmentation).
transform_train = transforms.Compose([transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])

train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform_train)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform_test)

# Shuffle only the training batches; the test order stays fixed.
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=512, shuffle=False)

print(f"Training dataset: {len(train_ds)} elements")
print(f"Test dataset: {len(test_ds)} elements")

# Preview the first 40 training samples in a 5x8 grid.
fig, axes = plt.subplots(5, 8, figsize=(12, 8))
for i, ax in enumerate(axes.flat):
    sample_img, sample_label = train_ds[i]
    ax.imshow(sample_img.squeeze(), cmap='gray')
    ax.set_title(f"Label: {sample_label}")
    ax.axis('off')
plt.tight_layout()
plt.show()


## Model and training

In [None]:

# Model Creation
class TinyConvNet(nn.Module):
    """Small CNN: two conv+ReLU+max-pool stages followed by a two-layer classifier.

    Expects inputs shaped [B,1,28,28] and produces [B,10] logits.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are created in a fixed order (conv1, conv2, pool, fc1, fc2)
        # so seeded initialisation and the printed module summary stay stable.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)   # [1,28,28]  -> [16,24,24]
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)  # [16,12,12] -> [32,8,8] (pooled to [32,4,4] afterwards)
        self.pool = nn.MaxPool2d(2)                    # halves height and width
        # Classifier head on the flattened [32,4,4] feature map.
        self.fc1 = nn.Linear(32 * 4 * 4, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x, return_activations=False):
        """Compute logits for a batch.

        Args:
            x: input batch shaped [B,1,28,28].
            return_activations: when true, also return intermediate activations.

        Returns:
            The logits tensor [B,10]; or, when ``return_activations`` is set, a
            tuple ``(logits, activations)`` where ``activations`` maps
            'conv1', 'conv2', 'fc1' and 'logits' to detached NumPy arrays.
        """
        conv1_act = F.relu(self.conv1(x))                 # [B,16,24,24]
        conv2_in = self.pool(conv1_act)                   # [B,16,12,12]
        conv2_act = F.relu(self.conv2(conv2_in))          # [B,32,8,8]
        pooled = self.pool(conv2_act)                     # [B,32,4,4]
        features = pooled.view(pooled.size(0), -1)        # [B,512]
        hidden = F.relu(self.fc1(features))               # [B,64]
        logits = self.fc2(hidden)                         # [B,10]

        if not return_activations:
            return logits

        captured = (
            ('conv1', conv1_act),
            ('conv2', conv2_act),
            ('fc1', hidden),
            ('logits', logits),
        )
        return logits, {key: tensor.detach().cpu().numpy() for key, tensor in captured}

# Instantiate the network on the selected device with Adam + cross-entropy.
model = TinyConvNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Count only trainable parameters.
n_trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("Number of model parameters:", n_trainable_params)
model



In [None]:

def train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one optimisation pass over ``loader`` and report averaged metrics.

    Args:
        model: network to train; it is switched to training mode.
        loader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar loss.
            The per-batch value is weighted by batch size, so it is assumed to
            be a batch mean (the default for ``nn.CrossEntropyLoss``).

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    # Move batches to wherever the model's parameters live instead of relying
    # on a notebook-level global; a parameterless model leaves batches in place.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    n = 0
    loss_sum = 0.0
    acc_sum = 0.0
    for xb, yb in loader:
        if target_device is not None:
            xb, yb = xb.to(target_device), yb.to(target_device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        optimizer.step()
        bs = yb.size(0)
        n += bs
        loss_sum += loss.item() * bs
        acc_sum += (logits.argmax(1) == yb).float().sum().item()

    if n == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / n, acc_sum / n

def evaluate(model, loader, loss_fn):
    """Score ``model`` on ``loader`` without updating any weights.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable yielding ``(inputs, targets)`` batches.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar loss.
            The per-batch value is weighted by batch size, so it is assumed to
            be a batch mean (the default for ``nn.CrossEntropyLoss``).

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    # Move batches to wherever the model's parameters live instead of relying
    # on a notebook-level global; a parameterless model leaves batches in place.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    n = 0
    loss_sum = 0.0
    acc_sum = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            if target_device is not None:
                xb, yb = xb.to(target_device), yb.to(target_device)
            logits = model(xb)
            loss = loss_fn(logits, yb)
            bs = yb.size(0)
            n += bs
            loss_sum += loss.item() * bs
            acc_sum += (logits.argmax(1) == yb).float().sum().item()

    if n == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return loss_sum / n, acc_sum / n

# Train for a fixed number of epochs, re-drawing the learning curves after each one.
EPOCHS = 3
tr_hist = {"loss": [], "acc": []}
te_hist = {"loss": [], "acc": []}

for ep in range(1, EPOCHS + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    te_loss, te_acc = evaluate(model, test_loader, criterion)

    tr_hist["loss"].append(tr_loss)
    tr_hist["acc"].append(tr_acc)
    te_hist["loss"].append(te_loss)
    te_hist["acc"].append(te_acc)

    # Replace the previous epoch's report instead of stacking outputs.
    clear_output(wait=True)
    print(f"Epoch {ep}/{EPOCHS}  | train {tr_loss:.4f}/{tr_acc:.4f}  test {te_loss:.4f}/{te_acc:.4f}")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
    for curve_ax, metric, heading in ((ax1, "loss", "Loss"), (ax2, "acc", "Accuracy")):
        curve_ax.plot(tr_hist[metric], label="train")
        curve_ax.plot(te_hist[metric], label="test")
        curve_ax.set_title(heading)
        curve_ax.legend()
    plt.show()

print("Training done.")


## Test and interactive prediction

In [None]:

# --- Draw area via ipycanvas (smaller canvas, white background) ---
from PIL import Image, ImageDraw

# Drawing surface: a 200x200 ipycanvas widget shown at its native pixel size.
canvas = Canvas(width=200, height=200, sync_image_data=True)
canvas.layout.width = canvas.layout.height = '200px'

# White background so the black strokes are visible ...
canvas.fill_style = 'white'
canvas.fill_rect(0, 0, canvas.width, canvas.height)
# ... and a light border drawn on the same canvas so coordinates line up.
canvas.stroke_style = 'lightgray'
canvas.stroke_rect(0.5, 0.5, canvas.width-1, canvas.height-1)

# Off-screen grayscale PIL image ('L' mode, 255 = white) mirroring every stroke;
# predictions read from this buffer rather than from the widget.
buf_img = Image.new('L', (canvas.width, canvas.height), color=255)
buf_draw = ImageDraw.Draw(buf_img)

# Controls: fixed brush diameter (no slider), a Clear button, and one Output
# widget holding both the preview image and the probability chart.
BRUSH_SIZE = 18
btn_clear = Button(description="Clear")
pred_out = Output()

# Mutable flag shared by the mouse handlers: True while a stroke is in progress.
is_drawing = dict(val=False)

# Prediction throttle: at most one prediction per PRED_INTERVAL seconds.
PRED_INTERVAL = 1.0
_last_pred_time = 0.0

def _get_xy(args):
    if len(args) >= 2 and isinstance(args[0], (int, float)) and isinstance(args[1], (int, float)):
        return args[0], args[1]
    if len(args) >= 1 and isinstance(args[0], dict):
        ev = args[0]
        return ev.get('x', None), ev.get('y', None)
    return None, None

def _draw_circle_on_buffer(x, y, r):
    """Stamp a filled black disc of radius ``r`` centred on (x, y) into the PIL buffer."""
    radius = int(round(r))
    left, top = int(round(x - radius)), int(round(y - radius))
    right, bottom = int(round(x + radius)), int(round(y + radius))
    # fill=0 is black in the grayscale buffer.
    buf_draw.ellipse((left, top, right, bottom), fill=0)

def _stamp_brush(x, y):
    """Paint one round brush dab centred on (x, y) on the widget and on the PIL buffer."""
    radius = BRUSH_SIZE/2
    canvas.fill_style = 'black'  # draw black strokes
    canvas.fill_circle(x, y, radius)
    _draw_circle_on_buffer(x, y, radius)


def _refresh_prediction(force=False):
    """Best-effort call to ``predict_and_show``.

    Errors are swallowed on purpose: this runs inside widget event callbacks,
    and a failed preview refresh must not interrupt drawing.
    """
    try:
        predict_and_show(force)
    except Exception:
        pass


def on_mouse_down(*args):
    """Start a stroke: paint the first dab and request a (throttled) prediction."""
    x, y = _get_xy(args)
    # Both coordinates are required; a missing y would otherwise fail in the paint call.
    if x is None or y is None:
        return
    is_drawing['val'] = True
    _stamp_brush(x, y)
    _refresh_prediction()


def on_mouse_up(*args):
    """End the current stroke and force a prediction of the finished drawing."""
    is_drawing['val'] = False
    _refresh_prediction(force=True)


def on_mouse_move(*args):
    """While a stroke is active, paint a dab at the pointer and request a (throttled) prediction."""
    if not is_drawing['val']:
        return
    x, y = _get_xy(args)
    if x is None or y is None:
        return
    _stamp_brush(x, y)
    _refresh_prediction()


def on_mouse_out(*args):
    """Finish an active stroke when the pointer leaves the canvas.

    Without this, releasing the button outside the canvas produces no mouse-up
    event, and the next pointer move over the canvas would keep painting.
    """
    if is_drawing['val']:
        on_mouse_up(*args)


canvas.on_mouse_down(on_mouse_down)
canvas.on_mouse_up(on_mouse_up)
canvas.on_mouse_move(on_mouse_move)
canvas.on_mouse_out(on_mouse_out)

def clear_canvas(*_):
    """Wipe the drawing and refresh the prediction panel.

    Repaints the widget's white background and border, replaces the PIL buffer
    with a blank one, and re-renders the preview. Positional arguments (such as
    the clicked button passed by ``on_click``) are ignored.
    """
    global buf_img, buf_draw

    # clear visible canvas
    canvas.fill_style = 'white'
    canvas.fill_rect(0, 0, canvas.width, canvas.height)
    canvas.stroke_style = 'lightgray'
    canvas.stroke_rect(0.5, 0.5, canvas.width-1, canvas.height-1)

    # reset PIL buffer
    buf_img = Image.new('L', (canvas.width, canvas.height), color=255)
    buf_draw = ImageDraw.Draw(buf_img)

    # clear the output pane but keep it visible (wait=True defers the clear
    # until new output arrives)
    pred_out.clear_output(wait=True)

    # Force the redraw: a throttled call made within PRED_INTERVAL of the last
    # prediction would be skipped, leaving the old digit and probabilities shown.
    try:
        predict_and_show(True)
    except Exception:
        # best-effort: a failed preview refresh must not break the button callback
        pass

def get_28x28_tensor():
    """Convert the PIL stroke buffer into a model-ready tensor.

    Returns:
        A float tensor shaped [1,1,28,28] where strokes are close to 1.0 and
        the background is 0.0.
    """
    # The PIL buffer (buf_img) is the authoritative record of the strokes.
    downscaled = buf_img.copy().resize((28,28), Image.LANCZOS)
    # In the buffer 0 = black stroke and 255 = white background; scale to
    # [0,1] and invert so that strokes become the bright pixels.
    ink = 1.0 - np.array(downscaled, dtype=np.float32) / 255.0
    # Add batch and channel axes: [28,28] -> [1,1,28,28].
    return torch.from_numpy(ink).unsqueeze(0).unsqueeze(0).float()

def _predict_and_show_now(x):
    """Run the model on ``x`` and render the preview plus class probabilities.

    Args:
        x: input tensor shaped [1,1,28,28] (as produced by ``get_28x28_tensor``).

    The figure (input image on the left, probability bar chart on the right,
    predicted class highlighted in orange) replaces the contents of ``pred_out``.
    """
    model.eval()
    with torch.no_grad():
        logits = model(x.to(device))
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    pred = int(np.argmax(probs))
    classes = np.arange(len(probs))

    # update single combined output (image left, probabilities right)
    with pred_out:
        pred_out.clear_output(wait=True)
        fig, (ax_im, ax_bar) = plt.subplots(1,2, figsize=(6,3), gridspec_kw={'width_ratios':[1,1.2]})
        try:
            # image (no title 'Pred')
            ax_im.imshow(x[0,0].detach().cpu().numpy(), cmap='gray')
            ax_im.axis('off')
            # probability bar chart; keep the returned bars so the predicted one
            # can be addressed directly rather than via the axes' child order
            bars = ax_bar.bar(classes, probs, color='tab:blue')
            ax_bar.set_title('Class probabilities')
            ax_bar.set_xticks(classes)
            ax_bar.set_ylim(0,1)
            # visually mark predicted class on bar chart
            bars[pred].set_color('tab:orange')
            plt.tight_layout()
            plt.show()
        finally:
            # This runs on every throttled prediction; close the figure so
            # repeated calls do not accumulate open figures.
            plt.close(fig)

def predict_and_show(force = False):
    """Predict from the current drawing, at most once per PRED_INTERVAL seconds.

    Args:
        force: when true, bypass the throttle and predict immediately.
    """
    global _last_pred_time

    current = time.time()
    too_soon = (current - _last_pred_time) < PRED_INTERVAL
    if too_soon and not force:
        # Throttled: drop this request.
        return

    # Record the timestamp before the (possibly slow) prediction runs.
    _last_pred_time = current
    sample = get_28x28_tensor()
    try:
        _predict_and_show_now(sample)
    except Exception:
        # guard: rendering errors must not propagate into UI callbacks
        pass

# Wire the Clear button to the reset routine.
btn_clear.on_click(clear_canvas)

# Layout: canvas on the left; preview + probabilities stacked above the Clear
# button on the right.
side_panel = VBox([pred_out, btn_clear])
ui = HBox([canvas, side_panel], layout=Layout(width='600px'))
display(ui)

# Render an initial (blank) prediction so the panel is populated before drawing.
try:
    predict_and_show()
except Exception:
    pass