# Using a sketch to evaluate a neural network 🤖

> Why use a sketch?

- A sketch is an interactive way to test a trained neural network on newly generated test data.
- Your drawing is directly available as an image (`sketch.image`) in the notebook's Python code.
- Modify the sketch and observe changes in predictions to gain insights into model performance.


In [None]:
%pip install -q ipysketch_lite matplotlib numpy torch torchvision

The following code builds, trains, and evaluates a simple image classifier for MNIST handwritten digits. It takes a 28x28 image as input and predicts which of the 10 digit classes (0–9) it shows.

In [None]:
import contextlib

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader

# Setting up the data: convert images to tensors, then standardise them with
# (0.1307, 0.3081) -- the mean/std values commonly quoted for MNIST.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Build both splits (root ./data, download=True). Anything written to stdout
# while constructing them is discarded to keep the notebook output tidy.
with contextlib.redirect_stdout(None):
    train_dataset, test_dataset = (
        torchvision.datasets.MNIST(
            root="./data", train=is_train, download=True, transform=transform
        )
        for is_train in (True, False)
    )

print(
    "MNIST dataset loaded successfully"
    if all(len(split) > 0 for split in (train_dataset, test_dataset))
    else "Loading MNIST dataset failed"
)

# Shuffle only the training data; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


class SimpleNN(nn.Module):
    """Two-layer fully connected classifier for 28x28 digit images.

    Architecture: flatten -> Linear(784, 500) -> ReLU -> Linear(500, 10),
    producing one logit per class.
    """

    def __init__(self):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(28 * 28, 500), nn.Linear(500, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10).

        ``x`` is reshaped with ``view(-1, 784)``, so it must be contiguous and
        hold N * 784 elements, e.g. a DataLoader batch of shape (N, 1, 28, 28).
        """
        return self.fc2(torch.relu(self.fc1(x.view(-1, 28 * 28))))

    @torch.no_grad()
    def predict(self, input_image):
        """Predict the class of a single 28x28 image.

        Args:
            input_image: array-like with 28 * 28 elements (e.g. a 28x28 numpy
                array, bool or float). It is fed to the network as-is, so the
                caller must apply any preprocessing used during training.

        Returns:
            A 1-element int64 tensor holding the predicted class index.
        """
        x = torch.tensor(input_image, dtype=torch.float32).view(1, 28, 28)
        # Use ``self`` (not a module-level ``model`` global) so each instance
        # predicts with its own weights.
        return self(x).argmax(dim=1)


model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training the model: standard mini-batch loop, logging the loss of every
# 200th batch (1-based step count).
num_epochs = 4
for epoch in range(num_epochs):
    for step, (data, targets) in enumerate(train_loader, start=1):
        # Clear stale gradients, then forward, backward, and update.
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if step % 200 != 0:
            continue
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Step [{step}/{len(train_loader)}], Loss: {loss.item():.4f}"
        )

# Testing the model: top-1 accuracy over every batch of the test loader.
model.eval()
with torch.no_grad():
    correct, total = 0, 0
    for data, targets in test_loader:
        outputs = model(data)
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

    # Report the number of images actually evaluated instead of a hard-coded count.
    print(
        f"Accuracy of the model on the {total} test images: {100 * correct / total:.2f}%"
    )

# Displaying the first image in the test set. Index the dataset directly:
# the loop variables above hold the *last* batch once the loop finishes, so
# data[0] would not be the first test image.
first_image = test_dataset[0][0]
with torch.no_grad():
    first_prediction = model(first_image).argmax(dim=1)
plt.title(f"Predicted label: {first_prediction.item()}")
plt.imshow(first_image.view(28, 28).numpy(), cmap="gray")
plt.show()

After training the model, test it on your own hand-drawn digits to see how well it performs on data outside MNIST.

In [None]:
from ipysketch_lite import Sketch

# Interactive 100x100 drawing canvas: draw a digit here, then run the next cell.
sketch = Sketch(height=100, width=100)

Resize the sketch to 28x28 (the MNIST resolution), convert it to a black-and-white image, and predict the digit with the model.

In [None]:
# Grab the current drawing and downscale it to 28x28
# (resample=0 selects nearest-neighbour resampling in PIL).
image = sketch.image
image = image.resize((28, 28), resample=0)

# Binary "ink" mask: a pixel is foreground if any of its channels is non-zero.
# NOTE(review): assumes the sketch background is all zeros (e.g. transparent
# black) -- TODO confirm against ipysketch_lite's image format.
# np.atleast_3d lets a single-channel (HxW) image through as HxWx1.
image_array = np.array(image)
out = np.atleast_3d(image_array).sum(axis=2) > 0

# The model was trained on inputs passed through
# transforms.Normalize((0.1307,), (0.3081,)); feeding the raw 0/1 mask would
# put the sketch far outside the training distribution, so normalise it the
# same way before predicting.
model_input = (out.astype(np.float32) - 0.1307) / 0.3081

# Make a prediction with the sketch
predicted = model.predict(model_input)

plt.title(f"Predicted label: {predicted[0].item()}")
plt.imshow(out, cmap="gray")
plt.show()