In [None]:
import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# Label for each output index of the classifier: position i holds the digit
# i rendered as a string ("0" through "9").
class_mapping = [str(digit) for digit in range(10)]

In [None]:
class FeedForwardNet(nn.Module):
    """Two-layer fully connected classifier.

    The network flattens each sample to 28*28 features, applies a
    784 -> 256 linear layer with ReLU, a 256 -> 10 linear layer, and finally
    a softmax over dimension 1 so that each output row sums to 1.
    """

    def __init__(self):
        """Create the flatten, dense and softmax sub-modules."""
        super().__init__()
        self.flatten = nn.Flatten()
        hidden_layer = nn.Linear(28 * 28, 256)
        output_layer = nn.Linear(256, 10)
        self.dense_layers = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_data):
        """Return the softmax-normalised class scores for ``input_data``.

        Args:
            input_data: Tensor whose dimensions after the first flatten to
                28*28 values per row (required by the first linear layer).

        Returns:
            Tensor with 10 columns; each row is the softmax of the logits
            produced by the dense layers.
        """
        return self.softmax(self.dense_layers(self.flatten(input_data)))

def download_mnist_datasets():
    """Build the MNIST training and validation datasets.

    Both datasets are rooted at the ``data`` directory, are created with
    ``download=True`` and convert samples with ``ToTensor()``.

    Returns:
        A ``(train_data, validation_data)`` tuple; the first is created with
        ``train=True`` and the second with ``train=False``.
    """
    # Training split is constructed first, then the validation split.
    train_data, validation_data = (
        datasets.MNIST(
            root="data",
            download=True,
            train=is_train_split,
            transform=ToTensor(),
        )
        for is_train_split in (True, False)
    )
    return train_data, validation_data

In [None]:
def predict(model, input, target, class_mapping):
    """Run the model on one sample and map the result to class labels.

    Args:
        model: ``nn.Module`` to evaluate. It is switched to evaluation mode
            and is left in that mode when the function returns.
        input: Tensor passed to ``model`` unchanged.
        target: Key of the expected class in ``class_mapping``; a
            single-element tensor is converted to a plain Python number.
        class_mapping: Sequence or mapping from class index to label.

    Returns:
        A ``(predicted, expected)`` tuple of entries from ``class_mapping``.
    """
    # Evaluation mode must be set before any evaluation or inference.
    model.eval()
    with torch.no_grad():
        # One row of scores per sample, e.g. shape (1, 10) for a single digit.
        predictions = model(input)
        # Position of the highest score in the first row. .item() turns the
        # 0-dim tensor into a Python int so the lookup below also works when
        # class_mapping is a dict (tensors hash by identity, not by value).
        predicted_index = predictions[0].argmax(0).item()

    # Targets taken from a batch arrive as tensors; normalise them the same way.
    if isinstance(target, torch.Tensor):
        target = target.item()

    predicted = class_mapping[predicted_index]
    expected = class_mapping[target]
    return predicted, expected

In [None]:
if __name__ == "__main__":

    # Rebuild the architecture, then restore the trained weights into it.
    feed_forward_net = FeedForwardNet()
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; the model and the sample below both stay on the CPU.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    state_dict = torch.load(
        "trained-models/feedforwardnet.pth",
        map_location="cpu",
    )
    feed_forward_net.load_state_dict(state_dict)

    # Load the MNIST validation dataset (the training split is not used here).
    _, validation_data = download_mnist_datasets()

    # Take the first validation sample. A single lookup avoids fetching and
    # transforming the same item twice.
    input, target = validation_data[0]

    # Make an inference
    predicted, expected = predict(feed_forward_net, input, target, class_mapping)

    # Print the results
    print(f"Predicted: {predicted}, Expected: {expected}")