In [None]:
import torchvision
import torch
import PIL as Image
import matplotlib.pyplot as plt

In [None]:
# hyper-parameters shared by the cells below
NUM_OF_CLASSES = 10    # width of the network's output layer
BATCH_SIZE = 64        # samples per DataLoader batch
NUM_OF_EPOCHS = 5      # full passes over the training set
LEARNING_RATE = 0.01   # step size handed to the optimizer

# preprocessing applied to every image: resize to 227 x 227, convert to a
# tensor, then normalise the single channel with mean 0.5 and std 0.5
resize_transformation = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(227, 227)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST train / test splits stored under ./data; only the training split asks
# for a download, so the test files are expected to be on disk by the time the
# second line runs
train_dataset = torchvision.datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=resize_transformation,
)
test_dataset = torchvision.datasets.MNIST(
    root='data',
    train=False,
    download=False,
    transform=resize_transformation,
)

In [None]:
# batch iterators over both splits; only the training batches are shuffled
train_data = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_data = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
class AlexNet(torch.nn.Module):
    """AlexNet-style convolutional network for single-channel images.

    The model consists of a five-layer convolutional feature extractor (Tanh
    activations, three max-pooling stages) followed by a three-layer MLP
    classifier. The first linear layer expects 9216 = 256 * 6 * 6 features,
    which is what the feature extractor yields for 227 x 227 inputs.

    ``forward`` returns raw, unnormalised class scores (logits). No softmax is
    applied inside the model because ``torch.nn.functional.cross_entropy``
    already applies ``log_softmax`` to its input; normalising twice flattens
    the gradients. Call ``torch.softmax(logits, dim=1)`` on the output when
    probabilities are needed.

    Args:
        num_of_classes: number of output units of the final linear layer.
    """

    def __init__(self, num_of_classes):
        super().__init__()
        # feature extraction component
        self.feature_extraction = torch.nn.Sequential(
            # convolutional layer 1
            torch.nn.Conv2d(in_channels = 1, out_channels = 96, stride = (4,4), padding = 0, kernel_size = (11,11)),
            torch.nn.Tanh(),
            torch.nn.MaxPool2d(kernel_size = (3,3), stride = (2, 2), padding = 0),
            # convolutional layer 2
            torch.nn.Conv2d(in_channels = 96, out_channels = 256, kernel_size = (5, 5), padding = 2, stride = (1, 1)),
            torch.nn.Tanh(),
            torch.nn.MaxPool2d(kernel_size = (3, 3), stride = (2, 2), padding = 0),
            # convolutional layer 3
            torch.nn.Conv2d(in_channels = 256, out_channels = 384, stride = (1, 1), padding = 1, kernel_size = (3, 3)),
            torch.nn.Tanh(),
            # convolutional layer 4
            torch.nn.Conv2d(in_channels = 384, out_channels = 384, kernel_size = (3,3), padding = 1, stride = (1, 1)),
            torch.nn.Tanh(),
            # convolutional layer 5
            torch.nn.Conv2d(in_channels = 384, out_channels = 256, kernel_size = (3, 3), padding = 1, stride = (1, 1)),
            # activation added so that layer 5 is non-linear like layers 1-4
            # (parameter-free, so state_dict keys are unaffected)
            torch.nn.Tanh(),
            torch.nn.MaxPool2d(kernel_size = (3, 3), padding = 0, stride = (2, 2))
        )
        # classifier component
        self.classifier = torch.nn.Sequential(
            # first layer of MLP
            torch.nn.Linear(in_features = 9216, out_features = 4096, bias = True),
            torch.nn.ReLU(),
            # second layer of MLP
            torch.nn.Linear(in_features = 4096, out_features = 4096, bias = True),
            torch.nn.ReLU(),
            # final layer of MLP: emits logits, deliberately without a Softmax
            torch.nn.Linear(in_features = 4096, out_features = num_of_classes, bias = True)
        )

    def forward(self, X):
        """Compute class logits for a batch of images.

        Args:
            X: tensor of shape (batch, 1, 227, 227).

        Returns:
            Tensor of shape (batch, num_of_classes) holding unnormalised scores.
        """
        X = self.feature_extraction(X)
        # keep the batch dimension, collapse channels and spatial dims
        X = torch.flatten(X, 1)
        X = self.classifier(X)
        return X

# instantiate the network with one output unit per class
model = AlexNet(NUM_OF_CLASSES)

In [None]:
# stochastic gradient descent over all of the model's parameters
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# training loop
# Batches are moved to whatever device the model's parameters live on, so the
# loop runs unchanged on CPU or after the model has been moved to a GPU.
device = next(model.parameters()).device
# make sure train-time behaviour is active, even if the model was switched to
# eval mode by another cell before this one is (re-)run
model.train()
for epoch in range(NUM_OF_EPOCHS):
    running_cost = 0.0
    num_batches = 0
    for batch_index, (data, labels) in enumerate(train_data):
        data, labels = data.to(device), labels.to(device)
        # clear the gradients left over from the previous step
        optimizer.zero_grad()
        # forward propagation of the network
        logits = model(data)
        # cross entropy loss between the predicted scores and the actual labels
        cost = torch.nn.functional.cross_entropy(logits, labels)
        # back propagation of the cost value
        cost.backward()
        # update the weights using the gradients of the loss
        optimizer.step()
        running_cost += cost.item()
        num_batches += 1
    # one line per epoch: the mean loss is steadier than the last batch's loss
    if num_batches:
        print(f"Epoch : {epoch + 1}/{NUM_OF_EPOCHS} | Mean loss value : {running_cost / num_batches:.4f}")

In [None]:
# computing the classification accuracy on the training and testing datasets
def compute_accuracy(model, data_loader):
    """Return the accuracy of ``model`` on ``data_loader`` as a percentage.

    Args:
        model: a ``torch.nn.Module`` mapping a batch of inputs to per-class
            scores of shape (batch, num_classes).
        data_loader: iterable yielding ``(data, labels)`` batches.

    Returns:
        Percentage (0-100) of samples whose highest-scoring class equals the
        label; 0.0 if the loader yields no samples.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    correct_pred = 0
    num_examples = 0
    try:
        with torch.no_grad():
            for data, labels in data_loader:
                data, labels = data.to(device), labels.to(device)
                # dim=1 picks the best class per sample; without it argmax
                # returns a single index into the flattened batch
                predictions = torch.argmax(model(data), dim=1)
                correct_pred += torch.sum(predictions == labels).item()
                num_examples += labels.size(0)
    finally:
        # leave the model in the mode it was found in
        model.train(was_training)
    if num_examples == 0:
        return 0.0
    return (correct_pred / num_examples) * 100


train_accuracy = compute_accuracy(model, train_data)
test_accuracy = compute_accuracy(model, test_data)

In [None]:
# report the accuracies computed by the evaluation cell
train_report = f"Training accuracy of AlexNet model : {train_accuracy}%"
print(train_report)
test_report = f"Testing accuracy of the AlexNet model : {test_accuracy}%"
print(test_report)