In [27]:
!pip install torch torchvision -f https://download.pytorch.org/whl/cu111/torch_stable.html pandas numpy scikit-learn

Looking in links: https://download.pytorch.org/whl/cu111/torch_stable.html


In [28]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import ToTensor
from mnist_classifier import MNISTClassifier
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [29]:
def load_trained_model(model_path):
    model = MNISTClassifier()
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    model.eval()
    return model

model_path = 'mnist_classifier.pth'
model = load_trained_model(model_path)


def preprocess_image(image_path):
    original_image = Image.open(image_path)
    img = Image.open(image_path).convert('L')
    img = img.resize((28, 28))
    img = np.array(img)
    img = 255 - img  # Invert the image, as the model was trained on white digits on a black background
    img = img.astype(np.float32) / 255.0
    img = ToTensor()(img)
    return img.unsqueeze(0), original_image


def predict(model, image_path):
    image, original_image = preprocess_image(image_path)
    image = image.to(device)

    with torch.no_grad():
        output = model(image)
        prediction = torch.argmax(output, dim=1)
    return prediction.item(), original_image, image


In [30]:
image_path = 'testingDigitPics/3.png'
prediction, original_image, new_image = predict(model, image_path)
print(f"The predicted digit is: {prediction}")

print(original_image)
print(new_image)

#plt.figure()
#plt.imshow(Image.open(image_path), cmap='gray')
#plt.imshow(original_image, cmap='gray', vmin=0, vmax=255)
#plt.show()

The predicted digit is: 8
<PIL.PngImagePlugin.PngImageFile image mode=RGB size=28x28 at 0x1B1D8BD53F0>
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.00