In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from torchsummary import summary

# Directory where the MNIST files are stored (and downloaded to when missing).
ROOT = './data'

# No `transform` argument is passed to either split; a later cell calls
# `.save(...)` on the yielded images, so they are presumably PIL images.
train_data = datasets.MNIST(ROOT, train=True, download=True)
test_data = datasets.MNIST(ROOT, train=False, download=True)

In [6]:
# Export the first few test-set images as PNG files so the inference cells
# further down have files on disk to read.
NUM_SAMPLE_IMAGES = 10
SAMPLE_DIR = Path('./inference/images')

# `image.save` does not create missing parent directories, so make sure the
# target folder exists before writing (no-op when it is already there).
SAMPLE_DIR.mkdir(parents=True, exist_ok=True)

# `min` guards against a dataset that is shorter than the requested count.
for index in range(min(NUM_SAMPLE_IMAGES, len(test_data))):
    image, label = test_data[index]
    image.save(SAMPLE_DIR / f'{index}.png')

In [26]:
class LeNetClassifier(nn.Module):
    """LeNet-style convolutional classifier for single-channel images.

    Layer order, as applied in `forward`:
        conv(1->6, 5x5, 'same') -> avgpool(2) -> ReLU
        -> conv(6->16, 5x5)     -> avgpool(2) -> ReLU
        -> flatten -> linear(400->120) -> linear(120->84) -> linear(84->num_classes)

    `fc_1` expects 16 * 5 * 5 features, which is what the convolutional
    stack produces for e.g. a 1x28x28 input (28 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self, num_classes):
        """Build the layers.

        Args:
            num_classes: Size of the output vector produced by the last
                linear layer.
        """
        super().__init__()

        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding='same')
        self.avgpool1 = nn.AvgPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.avgpool2 = nn.AvgPool2d(2)

        # Fully connected head.
        self.flatten = nn.Flatten()
        self.fc_1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc_2 = nn.Linear(in_features=120, out_features=84)
        self.fc_3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, inputs):
        """Map a batch of images to unnormalized class scores.

        Args:
            inputs: Tensor of shape (batch, 1, H, W) whose spatial size
                reduces to 5x5 after the two conv/pool stages.

        Returns:
            Tensor of shape (batch, num_classes); no softmax is applied.
        """
        # In each stage the pooling runs before the ReLU.
        features = F.relu(self.avgpool1(self.conv1(inputs)))
        features = F.relu(self.avgpool2(self.conv2(features)))
        flat = self.flatten(features)
        # NOTE(review): the three linear layers are chained with no activation
        # in between; kept as-is because saved weights depend on this layout.
        return self.fc_3(self.fc_2(self.fc_1(flat)))

In [27]:
def load_model(model_path, num_classes=10):
    """Build a LeNetClassifier and restore its weights from a checkpoint.

    Args:
        model_path: Path to a file containing the model's state dict.
        num_classes: Number of output classes the network is built with;
            must match the checkpoint being loaded.

    Returns:
        The LeNetClassifier with the checkpoint weights loaded, switched to
        evaluation mode.
    """
    lenet_model = LeNetClassifier(num_classes)
    # `weights_only=True` restricts unpickling to tensors/primitive containers;
    # `map_location` places the tensors on CPU regardless of where they were saved.
    state_dict = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))
    lenet_model.load_state_dict(state_dict)
    lenet_model.eval()
    return lenet_model

In [29]:
# Restore the trained network; `load_model` returns it on CPU in eval mode.
model = load_model(model_path='./model/lenet_model.pt')

In [30]:
def inference(image, model):
    """Classify a single image and report the winning class with its confidence.

    The image is center-cropped to a square when it is not one already,
    converted to grayscale, resized to 28x28, converted to a tensor and
    normalized before being passed through `model` as a batch of one.

    Args:
        image: PIL image (its `.size` is read as (width, height)).
        model: Callable mapping a (1, 1, 28, 28) float tensor to a
            (1, num_classes) tensor of class scores. It is called as-is;
            this function does not switch it to eval mode.

    Returns:
        tuple: `(confidence, label)` where `confidence` is the highest
        softmax probability expressed as a percentage (float) and `label`
        is the index of that class (int).
    """
    width, height = image.size
    if width != height:
        # Crop to the largest centered square so the resize below does not
        # distort the aspect ratio.
        image = transforms.CenterCrop(min(width, height))(image)

    preprocess = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(28),
        transforms.ToTensor(),
        # Presumably the MNIST training-set mean/std -- TODO confirm they
        # match the normalization used when the model was trained.
        transforms.Normalize(mean=[0.1307], std=[0.3081])
    ])
    # (1, 28, 28) -> (1, 1, 28, 28): add the batch dimension.
    batch = preprocess(image).unsqueeze(0)

    with torch.no_grad():
        scores = model(batch)
        probabilities = F.softmax(scores, dim=1)

    top_probability, top_class = torch.max(probabilities, dim=1)
    return top_probability.item() * 100, top_class.item()

In [34]:
# Classify each of the ten sample PNGs under ./inference/images.
for index in range(10):
    # The context manager releases the file handle deterministically, even
    # if `inference` raises part-way through.
    with Image.open(f'./inference/images/{index}.png') as image:
        # `confidence` (a percentage) is returned as well but not printed here.
        confidence, predicted = inference(image, model)

    print(f'Predicted: {predicted}')


Predicted: 7
Predicted: 2
Predicted: 1
Predicted: 0
Predicted: 4
Predicted: 1
Predicted: 4
Predicted: 9
Predicted: 5
Predicted: 9
