In [1]:
import torch
from torchvision import transforms
from PIL import Image

In [2]:
# MNIST digit classifier: two 3x3 conv layers, a 2x2 max-pool, then two linear layers.
class Net(torch.nn.Module):
    """Small CNN that maps a (N, 1, 28, 28) batch to per-class probabilities (N, 10).

    The layer attribute names (conv1, conv2, fc1, fc2) become the state-dict keys
    that the pretrained checkpoint is loaded into, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> 32x26x26 -> 64x24x24 (3x3 kernels, stride 1, no padding)
        self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        # after the 2x2 max-pool: 64 channels * 12 * 12 = 9216 flattened features
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Return softmax probabilities over the 10 classes; each row sums to 1."""
        functional = torch.nn.functional
        # convolutional feature extractor
        features = functional.relu(self.conv1(x))
        features = functional.relu(self.conv2(features))
        features = self.dropout1(functional.max_pool2d(features, 2))
        # classifier head
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(functional.relu(self.fc1(flat)))
        logits = self.fc2(hidden)
        return functional.softmax(logits, dim=1)

In [8]:
# Load the pretrained weights into a fresh model and switch it to inference mode.
# The model is never moved to a GPU, so the checkpoint is always mapped to the CPU:
# this works whether or not CUDA is available and avoids staging the tensors on a
# GPU only to copy them back into the CPU parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted
# source (on torch >= 1.13, consider passing weights_only=True).
MODEL_PATH = 'mnist_cnn.pt'
model = Net()
state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval();

In [9]:
# Example: classify a single handwritten digit image with the loaded model.
# The image is forced to single-channel (grayscale) 28x28: Resize((28, 28)) pins both
# sides, whereas Resize(28) only sets the shorter edge and would break the 28x28 batch
# shape for non-square inputs. The Normalize constants are kept from the original cell
# (presumably the MNIST training mean/std — confirm against the training script).
with Image.open("3.png") as img:
    three = img.convert('L')  # convert() loads the pixels, so the file can be closed
preprocess = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
three_tensor = preprocess(three).unsqueeze(0)  # (1, 28, 28) -> batch of one (1, 1, 28, 28)
with torch.no_grad():  # inference only; don't build an autograd graph
    prediction = model(three_tensor)
prob_tensor, item_tensor = prediction.max(dim=1)  # best probability and its class index
item = item_tensor.item()
prob = prob_tensor.item()
print(f'Predicted {item} with probability: {prob}.')

Predicted 3 with probability: 1.0.
