In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
class CheckpointHandler:
    def __init__(self, best_valid_loss=float('inf')):
        self.best_valid_loss = best_valid_loss
    
    def save_best_model(self, state, filename):
        if state['loss'] < self.best_valid_loss:
            self.best_valid_loss = state['loss']
            torch.save(state, filename)
            print('checkpoint file updated')

In [3]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(1, 32, kernel_size= 3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride= 2),
                                 nn.Conv2d(32, 64, kernel_size= 3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride= 2),
                                 nn.Conv2d(64, 32, kernel_size= 3),
                                 nn.ReLU(),
                                 nn.MaxPool2d((2, 2), stride= 2)
                                 )
        self.classify_head = nn.Sequential(nn.Flatten(),
                                           nn.Linear(32, 20, bias= True),
                                           nn.Linear(20, 10, bias= True))
    def forward(self, x):
        return self.classify_head(self.net(x))

In [4]:
model = CNN()
model.to('cuda')
optimizer = optim.SGD(model.parameters(), lr= 0.001)
criterion = nn.CrossEntropyLoss()
checkpoint_handler = CheckpointHandler()

In [5]:
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ))])
train = datasets.MNIST('.', train= True, download= True, transform= transforms)
test = datasets.MNIST('.', download= True, train= False, transform= transforms)
train_loader = DataLoader(train, batch_size= 64, shuffle= True)
test_loader = DataLoader(test, batch_size= 64)

In [6]:
for epoch in range(10):
    model.train()
    running_loss = 0
    for input, target in train_loader:
        input, target = input.to('cuda'), target.to('cuda')
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch - {epoch}, loss = {running_loss}')

    state = {'epoch': epoch,
             'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'loss': running_loss}

    checkpoint_handler.save_best_model(state, './q3_checkpoints/checkpoint_best.pth')

Epoch - 0, loss = 2160.1290917396545
checkpoint file updated
Epoch - 1, loss = 2151.158231973648
checkpoint file updated
Epoch - 2, loss = 2140.428228378296
checkpoint file updated
Epoch - 3, loss = 2125.0057702064514
checkpoint file updated
Epoch - 4, loss = 2097.898699760437
checkpoint file updated
Epoch - 5, loss = 2042.660430431366
checkpoint file updated
Epoch - 6, loss = 1921.1561652421951
checkpoint file updated
Epoch - 7, loss = 1714.6891666650772
checkpoint file updated
Epoch - 8, loss = 1445.5838387012482
checkpoint file updated
Epoch - 9, loss = 1109.99817186594
checkpoint file updated


In [7]:
# retraining
model = CNN()
checkpoint = torch.load('./q3_checkpoints/checkpoint_best.pth', weights_only= False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer = torch.optim.SGD(model.parameters(), lr = 0.001)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
criterion = nn.CrossEntropyLoss()

In [8]:
model.to('cuda')
for epoch in range(10):
    model.train()
    running_loss = 0
    for input, target in train_loader:
        input, target = input.to('cuda'), target.to('cuda')
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch - {epoch}, loss = {running_loss}')

    state = {'epoch': epoch,
             'model_state_dict': model.state_dict(),
             'optimizer_state_dict': optimizer.state_dict(),
             'loss': running_loss}

    checkpoint_handler.save_best_model(state, './q3_checkpoints/checkpoint_best.pth')

Epoch - 0, loss = 819.0250169038773
checkpoint file updated
Epoch - 1, loss = 632.2271310389042
checkpoint file updated
Epoch - 2, loss = 509.1673307865858
checkpoint file updated
Epoch - 3, loss = 422.3464578092098
checkpoint file updated
Epoch - 4, loss = 360.0502180606127
checkpoint file updated
Epoch - 5, loss = 315.3632535338402
checkpoint file updated
Epoch - 6, loss = 282.4337933883071
checkpoint file updated
Epoch - 7, loss = 257.16326431185007
checkpoint file updated
Epoch - 8, loss = 237.74852107465267
checkpoint file updated
Epoch - 9, loss = 221.97417832911015
checkpoint file updated


In [10]:
all_preds, all_target = [], []
model.eval()
with torch.no_grad():
    for input, target in test_loader:
        input, target = input.to('cuda'), target.to('cuda')
        output = model(input)
        val, index = torch.max(output, dim= 1)
        all_preds.extend(index.to('cpu'))
        all_target.extend(target.to('cpu'))
from sklearn.metrics import accuracy_score
print(accuracy_score(all_preds, all_target))

0.9368
