In [2]:
!pip install albumentations==1.3.1

Collecting albumentations==1.3.1
  Downloading albumentations-1.3.1-py3-none-any.whl.metadata (34 kB)
Collecting qudida>=0.0.4 (from albumentations==1.3.1)
  Downloading qudida-0.0.4-py3-none-any.whl.metadata (1.5 kB)
Downloading albumentations-1.3.1-py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.7/125.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading qudida-0.0.4-py3-none-any.whl (3.5 kB)
Installing collected packages: qudida, albumentations
  Attempting uninstall: albumentations
    Found existing installation: albumentations 2.0.8
    Uninstalling albumentations-2.0.8:
      Successfully uninstalled albumentations-2.0.8
Successfully installed albumentations-1.3.1 qudida-0.0.4


In [6]:
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import albumentations as A

class DigitPrediction:
    """Run a trained PyTorch digit classifier on images supplied as NumPy arrays.

    The constructor loads weights into a caller-supplied model instance and puts
    it in eval mode; `predict_image_class` turns one image into a
    [1, 1, 28, 28] tensor and returns the index of the highest-scoring class.
    """

    # Spatial size (width, height) every image is resized to before inference.
    _INPUT_SIZE = (28, 28)

    def __init__(self, model_path: str, model: torch.nn.Module = None, device: str = 'cpu'):
        """
        model_path : path to the saved PyTorch weights (.pt or .pth). The file may
                     hold either a raw state dict or a dict with a 'state_dict' key.
        model      : instance of the network architecture the weights belong to.
                     Required in practice; the None default only exists so the
                     error below can explain what is missing.
        device     : 'cpu' or 'cuda'

        Raises ValueError when `model` is None.
        """

        self.device = torch.device(device)

        if model is None:
            raise ValueError("You must provide `model` argument (an instance of your model class) when model architecture is custom.")

        self.model = model.to(self.device)
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary
        # code on torch versions where weights_only is not the default. Only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)

        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        # Inference only: switches layers such as dropout/batch-norm to eval behaviour.
        self.model.eval()

        # ToTensor scales 8-bit pixels to [0, 1]; Normalize then maps that to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])


    @staticmethod
    def _to_uint8(image) -> np.ndarray:
        """
        Coerce an array-like image to a uint8 array on the 0-255 scale.

        uint8 input is returned untouched. Booleans map to 0/255. Floats whose
        maximum is <= 1.0 are treated as being on a 0-1 scale and multiplied by
        255; other floats and all integers are assumed to already be on a 0-255
        scale and are clipped to it. NaN becomes 0.

        Raises ValueError for an empty array and TypeError for non-numeric dtypes.
        """

        pixels = np.asarray(image)
        if pixels.size == 0:
            raise ValueError("Cannot prepare an empty image.")

        if pixels.dtype == np.uint8:
            return pixels
        if pixels.dtype == np.bool_:
            return pixels.astype(np.uint8) * 255
        if np.issubdtype(pixels.dtype, np.floating):
            pixels = np.nan_to_num(pixels.astype(np.float64), nan=0.0, posinf=255.0, neginf=0.0)
            if pixels.max() <= 1.0:
                pixels = pixels * 255.0
            return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
        if np.issubdtype(pixels.dtype, np.integer):
            return np.clip(pixels, 0, 255).astype(np.uint8)

        raise TypeError(f"Unsupported image dtype: {pixels.dtype}")


    def _prepare_image(self, image: np.ndarray) -> torch.Tensor:
        """
        image : array-like (H, W) or (H, W, C) of the digit image, any numeric dtype
        Returns a tensor of shape [1, 1, 28, 28] on `self.device`, ready for inference.

        NOTE(review): no colour inversion is applied; presumably the model expects
        the same digit/background polarity as its training data - confirm.
        """

        # PIL only understands a few dtypes (e.g. it rejects int64 and float RGB),
        # and non-8-bit modes are not rescaled by ToTensor, so normalise to uint8 first.
        pixels = self._to_uint8(image)

        is_gray = pixels.ndim == 2
        is_rgb = pixels.ndim == 3 and pixels.shape[2] == 3
        if not (is_gray or is_rgb):
            # Drop singleton axes, e.g. (H, W, 1) -> (H, W); other layouts go to PIL as-is.
            pixels = pixels.squeeze()

        # Always end up with a single-channel 8-bit image, whatever came in.
        img_pil = Image.fromarray(pixels).convert('L')
        img_pil = img_pil.resize(self._INPUT_SIZE)

        tensor = self.transform(img_pil)
        tensor = tensor.unsqueeze(0)  # add the batch dimension
        tensor = tensor.to(self.device)

        return tensor


    def predict_image_class(self, image: np.ndarray) -> int:
        """
        image : numpy array holding one digit image
        Returns the predicted digit (0–9), i.e. the index of the largest model output.
        """

        input_tensor = self._prepare_image(image)
        # Gradients are not needed for inference.
        with torch.no_grad():
            output = self.model(input_tensor)
            _, pred = torch.max(output, dim=1)

        return int(pred.item())


In [11]:
!tar -xzf /content/mnist-classifier.tar.gz -C /content/

In [9]:
import torch
import numpy as np
from PIL import Image

class MyConvNet(torch.nn.Module):
    """Small two-stage convolutional network that outputs 10 class scores.

    Each stage is a 5x5 convolution, a ReLU and a shared 2x2 max-pool. The pooled
    feature map is reshaped to rows of 96 values and passed through two linear
    layers. For a 28x28 single-channel input the second stage yields
    6 channels x 4 x 4 = 96 features per sample.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Attribute names and creation order are kept unchanged: they determine the
        # state_dict keys and the order in which parameters are initialised.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=6, kernel_size=5)
        self.fc1 = nn.Linear(in_features=96, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=10)

    def forward(self, x):
        """Return the raw scores from the last linear layer (no softmax applied)."""
        relu = torch.relu
        stage_one = self.pool(relu(self.conv1(x)))
        stage_two = self.pool(relu(self.conv2(stage_one)))
        # Row count is inferred, so each row holds exactly 96 features.
        flat = stage_two.view(-1, 96)
        hidden = relu(self.fc1(flat))
        return self.fc2(hidden)

# Build the predictor: a fresh MyConvNet instance receives the weights stored in the
# checkpoint file, and inference runs on the CPU.
dp = DigitPrediction(model_path='/content/mnist-classifier.pt', model=MyConvNet(), device='cpu')

In [10]:
# Read one digit picture as 8-bit grayscale, convert it to a NumPy array and classify it.
# NOTE(review): '/path/to/some_digit.png' looks like a placeholder - point it at a real
# image file, otherwise Image.open raises FileNotFoundError.
img = Image.open('/path/to/some_digit.png').convert(mode='L')
img_np = np.array(img)

predicted_digit = dp.predict_image_class(image=img_np)
print(f"Predicted digit: {predicted_digit}")

FileNotFoundError: [Errno 2] No such file or directory: '/path/to/some_digit.png'