In [None]:
%matplotlib widget

import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
import io, base64, os, sys


class Net(nn.Module):
    """Three-layer MLP classifier for 28x28 digits (784 -> 64 -> 32 -> 10).

    ``forward`` returns the logits together with the activations of every
    layer so callers can inspect the network's internal state.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        """Run the MLP on a batch of images.

        Args:
            x: tensor holding 784 values per sample, e.g. ``(B, 1, 28, 28)``
               or ``(B, 784)``; it is flattened to ``(B, 784)``.

        Returns:
            ``(logits, activations)`` where ``logits`` has shape ``(B, 10)``
            and ``activations`` is ``[input, h1, h2, logits]``, each with all
            size-1 dimensions squeezed out (so for ``B == 1`` the batch axis
            is dropped).
        """
        # reshape (not view) so non-contiguous inputs such as transposed or
        # sliced tensors are accepted as well.
        x = x.reshape(-1, 28 * 28)
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        out = self.fc3(h2)
        return out, [x.squeeze(), h1.squeeze(), h2.squeeze(), out.squeeze()]


net = Net()


# Reuse cached weights when present; otherwise train on MNIST and cache them.
try:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute arbitrary code when loaded.
    net.load_state_dict(
        torch.load("mnist_mlp.pth", map_location="cpu", weights_only=True)
    )
except FileNotFoundError:
    transform = transforms.Compose([transforms.ToTensor()])  # pixels -> [0, 1]
    train = datasets.MNIST(root=".", train=True, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(train, batch_size=256, shuffle=True)
    opt = torch.optim.Adam(net.parameters(), lr=3e-3)
    net.train()  # explicit training mode for the optimisation loop
    for epoch in range(15):
        for X, y in loader:
            opt.zero_grad()
            # net returns (logits, activations); only the logits feed the loss.
            loss = F.cross_entropy(net(X)[0], y)
            loss.backward()
            opt.step()
    torch.save(net.state_dict(), "mnist_mlp.pth")

net.eval()


# 28x28 float drawing buffer (MNIST resolution); the mouse handlers write
# into it and on_key feeds it to the classifier.
canvas = np.zeros((28, 28))
fig = plt.figure(figsize=(10, 7), facecolor="black")

# Drawing panel on the left of the figure. imshow uses origin="lower", so
# canvas row 0 is displayed at the bottom of the panel.
ax_draw = fig.add_axes([0.02, 0.15, 0.28, 0.7], facecolor="black")
im_disp = ax_draw.imshow(canvas, cmap="Greys", vmin=0, vmax=1, origin="lower")
ax_draw.set_title("Draw digit\n[hold left mouse]", color="white")
ax_draw.axis("off")

_frames = []  # RGB uint8 snapshots of the whole figure, one per capture


def capture_frame():
    """Redraw the figure and append an RGB snapshot of it to ``_frames``.

    ``fig.canvas.draw()`` also re-renders the figure; the event handlers rely
    on this call to make their changes visible.
    """
    fig.canvas.draw()
    # buffer_rgba() exposes the Agg render buffer as (H, W, 4) uint8 values
    # that are already in 0..255, so no rescaling is needed (multiplying by
    # 255 would overflow uint8). Copy so the stored frame does not alias a
    # buffer that the next draw() overwrites.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    _frames.append(rgba[:, :, :3].copy())
# True while a stroke is in progress (set by on_press, cleared by on_release).
drawing = False


def on_press(event):
    """Start a stroke when a mouse button goes down inside the drawing panel."""
    global drawing
    if event.inaxes is not ax_draw:
        return
    drawing = True


def on_release(event):
    """End the current stroke and snapshot the figure.

    Fires on any button release, wherever it happens, so a stroke that
    wanders off the drawing panel is still terminated.
    """
    global drawing
    drawing = False  # `event` is intentionally unused
    capture_frame()


def on_move(event):
    """Paint a 3x3 patch of 1.0 into ``canvas`` while a stroke is in progress.

    The pointer's data coordinates (one unit per pixel of the imshow panel)
    are rounded to the nearest pixel. The brush is clipped at the canvas
    border, so strokes along the first row or column are still recorded.
    """
    if not drawing or event.inaxes is not ax_draw:
        return
    col, row = int(event.xdata + 0.5), int(event.ydata + 0.5)
    if 0 <= row < 28 and 0 <= col < 28:
        # Clamp the lower slice bounds: a start of -1 would wrap to the last
        # index and make the slice empty, silently dropping edge strokes.
        canvas[max(row - 1, 0) : row + 2, max(col - 1, 0) : col + 2] = 1.0
        im_disp.set_data(canvas)
        capture_frame()


# Wire the mouse handlers to the figure canvas, in press/release/move order.
# The connection ids can be passed to fig.canvas.mpl_disconnect later.
cid1, cid2, cid3 = (
    fig.canvas.mpl_connect(event_name, handler)
    for event_name, handler in (
        ("button_press_event", on_press),
        ("button_release_event", on_release),
        ("motion_notify_event", on_move),
    )
)


def on_key(event):
    """Classify the current drawing when the space bar is pressed.

    Runs ``net`` on the drawing buffer, shows the predicted digit as the
    drawing panel's title, and records a frame.
    """
    if event.key != " ":
        return

    # The panel is shown with origin="lower", so canvas row 0 is the bottom
    # of what the user drew. MNIST images (and therefore ``net``) put row 0
    # at the top. Flip vertically, or the model sees an upside-down digit.
    # ``.copy()`` also gives torch a contiguous, positively-strided array.
    upright = canvas[::-1].copy()
    img = torch.tensor(upright, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        logits, _ = net(img)
        pred = logits.argmax(dim=1).item()

    ax_draw.set_title(f"Predicted: {pred}", color="yellow")
    capture_frame()


# Space bar -> classify the drawing. Keep the connection id (like cid1-cid3)
# so the handler can be detached with fig.canvas.mpl_disconnect(cid4);
# binding it also avoids echoing the bare integer id as cell output if this
# is the last line of the cell.
cid4 = fig.canvas.mpl_connect("key_press_event", on_key)