In [11]:
from model import LeNet5 # import the defined model from 'model.py'
import numpy as np
import os
import torch
from torchvision.datasets import mnist # loading mnist dataset
from torch.nn import CrossEntropyLoss # loss function for multiclass classfication
from torch.optim import SGD # stochastic gradient descent optimizer
from torch.utils.data import DataLoader # loads images in 'batch_size'
from torchvision.transforms import ToTensor
import tqdm
from tqdm import trange

In [12]:
if __name__ == '__main__':

    SEED = 0 # fixed seed so weight init and shuffle order repeat across runs (GPU kernels may still add nondeterminism)
    torch.manual_seed(SEED)

    device = 'cuda' if torch.cuda.is_available() else 'cpu' # use the GPU when one is available
    batch_size = 256 # mini-batch size shared by the train and test loaders
    # NOTE(review): torchvision's MNIST appears to download both splits into `root` regardless of `train`,
    # so using two different roots downloads the dataset twice; a shared root (e.g. './data') would avoid
    # that — confirm before changing paths, since existing checkouts already have files in './train' / './test'.
    train = mnist.MNIST(root = './train', train = True, transform = ToTensor(), download = True)
    test = mnist.MNIST(root = './test', train = False, transform = ToTensor(), download = True)
    # Reshuffle the training set each epoch so SGD does not see batches in a fixed order;
    # the test set keeps its fixed order because accuracy does not depend on it.
    train_loader = DataLoader(train, batch_size = batch_size, shuffle = True)
    test_loader = DataLoader(test, batch_size = batch_size)

    model = LeNet5().to(device) # move model parameters to `device`
    sgd = SGD(model.parameters(), lr = 1e-1) # plain SGD, learning rate 0.1
    loss_fn = CrossEntropyLoss() # multiclass classification loss
    epochs = 100 # number of full passes over the training set
    prev_acc = 0 # NOTE(review): never read in this notebook; kept only so any external reference keeps working

    for epoch in trange(epochs, desc="Training Progress"): # training loop
        model.train() # training mode (matters only if LeNet5 has dropout/batch-norm — not visible here)
        for train_x, train_label in train_loader:
            train_x = train_x.to(device) # move image batch to `device`
            train_label = train_label.to(device) # move matching labels to `device`
            sgd.zero_grad() # gradients accumulate across backward() calls unless cleared
            pred_y = model(train_x.float()) # class scores; assumes LeNet5 returns raw logits — TODO confirm
            loss = loss_fn(pred_y, train_label.long()) # CrossEntropyLoss expects int64 class indices
            loss.backward() # backpropagate
            sgd.step() # apply one SGD update

    # State consumed by the evaluation cell: fresh accuracy counters and inference mode.
    # Set once after training instead of being re-assigned on every epoch.
    all_correct_num = 0
    all_sample_num = 0
    model.eval()



Training Progress: 100%|██████████| 100/100 [12:07<00:00,  7.27s/it]


In [13]:
# testing loop: accuracy of the trained model on the MNIST test split, then save its weights.
# Counters and eval mode are set here so this cell does not depend on state left by the
# training cell, and re-running it recomputes the counts instead of adding onto old totals.
model.eval()
all_correct_num = 0
all_sample_num = 0

with torch.no_grad():  # no autograd graph is needed for evaluation
    for test_x, test_label in tqdm.tqdm(test_loader, desc="Evaluating"):
        test_x = test_x.to(device) # move test images to `device`
        test_label = test_label.to(device) # move test labels to `device`

        pred_y = torch.argmax(model(test_x.float()), dim=-1) # predicted class = index of the highest score

        all_correct_num += (pred_y == test_label).sum().item() # correct predictions in this batch
        all_sample_num += test_label.shape[0] # samples in this batch

acc = all_correct_num / all_sample_num # fraction of test samples classified correctly
print('Accuracy: {:.3f}'.format(acc), flush=True)

# Save model weights; makedirs(exist_ok=True) avoids the check-then-create race of isdir() + mkdir().
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), 'models/LeNet5_wts.pth')

Evaluating: 100%|██████████| 40/40 [00:01<00:00, 32.55it/s]

Accuracy: 0.990



