In [1]:
# Quote the URL: unquoted, the shell treats '&' as "run in background" and drops '&dl=...'.
# dl=1 asks Dropbox for the raw file instead of a preview page.
!wget -q --show-progress -O mnist_data.zip "https://www.dropbox.com/scl/fi/g73rfpr4p37iqq63fmzpt/mnist_data.zip?rlkey=3gcae8bh74texxvie27dtsthu&dl=1"



In [2]:
# -o overwrites existing files without prompting, so re-running the cell doesn't stall on unzip's "replace?" question.
!unzip -q -o mnist_data.zip

In [1]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset,random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms
import os
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
import torch

# Fall back to CPU when no CUDA device is available so the notebook still runs off-GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so RNG-driven steps (weight init, dropout, DataLoader shuffling,
# an unseeded random_split) repeat across runs; cuDNN kernels may still add GPU nondeterminism.
torch.manual_seed(42)

In [3]:
# Per-image preprocessing applied by the ImageFolder dataset below.
tf = transforms.Compose(
    [
        # ImageFolder's default loader yields RGB images; conv1 expects a single channel.
        transforms.Grayscale(num_output_channels=1),
        # The CNN's 7*7*64 flatten size only matches 28x28 inputs.
        transforms.Resize((28, 28)),
        # PIL image -> float tensor of shape (C, H, W), 8-bit pixels scaled to [0, 1].
        transforms.ToTensor(),
    ]
)

In [4]:
# Read the folder that the unzip cell extracted into the working directory, rather than a
# hard-coded Colab path, so the notebook also works when not run from /content.
# ImageFolder assigns each image the label of its sub-directory under train/.
dataset = ImageFolder(root=os.path.join('mnist_data', 'train'), transform=tf)

In [5]:
# Hold out 10% of the images for validation, deriving both sizes from the dataset so that
# random_split's requirement (lengths must sum to len(dataset)) always holds.
# For the 42,000-image training folder this reproduces the former 37800 / 4200 split.
val_len = len(dataset) // 10
train_len = len(dataset) - val_len

In [6]:
# A dedicated seeded generator keeps train/validation membership identical on every run,
# so validation accuracy is comparable between runs.
train_ds, val_ds = random_split(dataset, [train_len, val_len],
                                generator=torch.Generator().manual_seed(42))

In [7]:
# Only the training split is shuffled; evaluation order does not affect accuracy.
# NOTE: despite its name, `test_loader` iterates the held-out validation split (val_ds).
train_loader, test_loader = (
    DataLoader(train_ds, batch_size=64, shuffle=True),
    DataLoader(val_ds, batch_size=64, shuffle=False),
)

In [13]:
class MnistCNN(nn.Module):
  """Two conv/ReLU/max-pool stages followed by a dropout-regularised two-layer classifier.

  Expects single-channel 28x28 inputs, i.e. tensors of shape (N, 1, 28, 28): fc1's
  input size only matches after two 2x2 poolings (28 -> 14 -> 7). forward() returns
  unnormalised class scores of shape (N, 10); no softmax is applied.
  """

  def __init__(self):
    super().__init__()

    # Feature extractor: 3x3 kernels with padding=1 keep H and W; each pool halves them.
    self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
    self.relu1 = nn.ReLU()
    self.pool1 = nn.MaxPool2d(2, 2)

    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
    self.relu2 = nn.ReLU()
    self.pool2 = nn.MaxPool2d(2, 2)

    self.flat = nn.Flatten()

    # Classifier head: (64 channels x 7 x 7) features -> 128 hidden units -> 10 scores.
    self.fc1 = nn.Linear(64 * 7 * 7, 128)
    self.relu_fc1 = nn.ReLU()
    self.dropout = nn.Dropout(0.25)
    self.fc2 = nn.Linear(128, 10)

    self._init_wts()

  def _init_wts(self):
    """Re-initialise every Conv2d/Linear layer: Kaiming-normal weights (ReLU gain), zero biases."""
    weighted_layers = (layer for layer in self.modules()
                       if isinstance(layer, (nn.Conv2d, nn.Linear)))
    for layer in weighted_layers:
      nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
      if layer.bias is not None:
        nn.init.zeros_(layer.bias)

  def forward(self, x):
    features = self.pool1(self.relu1(self.conv1(x)))          # (N, 32, 14, 14) for 28x28 input
    features = self.pool2(self.relu2(self.conv2(features)))   # (N, 64, 7, 7)
    hidden = self.dropout(self.relu_fc1(self.fc1(self.flat(features))))
    return self.fc2(hidden)




In [14]:
# save checkpoint
def save_checkpoint(model, epoch, loss, output_dir='cnn_output'):
  """Persist model.state_dict() for one epoch.

  Writes <output_dir>/model-<epoch, 2 digits>-<loss, 4 decimals>.pt, creating
  output_dir first if it does not exist. Returns None.
  """
  checkpoint_path = os.path.join(output_dir, f'model-{epoch:02d}-{loss:.4f}.pt')
  os.makedirs(output_dir, exist_ok=True)
  torch.save(model.state_dict(), checkpoint_path)

In [15]:
class EarlyStopping:
    """Signal when a monitored, higher-is-better score has stopped improving.

    Call step(score) once per epoch. After `patience` consecutive calls whose score
    does not exceed the best seen so far by more than `min_delta`, `should_stop`
    becomes True and stays True.

    Args:
        patience: number of consecutive non-improving steps tolerated.
        min_delta: minimum increase over best_score that counts as an improvement.
            The default 0.0 keeps the original "strictly greater" rule.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        # Start at -inf so the first score is always accepted as the best; a start of 0
        # would treat a first score of 0 (or any negative metric) as "no improvement".
        self.best_score = float('-inf')
        self.should_stop = False

    def step(self, score):
        if score > self.best_score + self.min_delta:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

In [16]:
class AccuracyTracker:
  """Evaluate top-1 accuracy on a loader and keep a per-epoch history.

  accuracy_scores holds (epoch, accuracy) tuples in the order compute() was called.
  """

  def __init__(self):
    self.accuracy_scores = []

  def compute(self, model, loader, epoch):
    """Compute the accuracy of `model` over every batch in `loader`.

    Args:
      model: classifier whose output's argmax over dim 1 is the predicted class.
      loader: iterable of (inputs, integer labels) batches, e.g. a DataLoader.
      epoch: value stored alongside the score and shown in the printed message.

    Returns:
      The accuracy (correct / total) as a float. It is also appended to
      accuracy_scores and printed.

    Raises:
      ValueError: if the loader yields no samples (accuracy is undefined).
    """
    # Use the device the model's parameters live on instead of a notebook-level global;
    # a parameterless model falls back to CPU.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()  # evaluation mode: disables dropout
    correct = total = 0
    try:
      with torch.no_grad():
        for xb, yb in loader:
          xb, yb = xb.to(run_device), yb.to(run_device)
          preds = model(xb)
          correct += (preds.argmax(1) == yb).sum().item()
          total += yb.size(0)
    finally:
      # Restore the caller's mode so evaluating never silently leaves the model in eval().
      model.train(was_training)

    if total == 0:
      raise ValueError('cannot compute accuracy: loader yielded no samples')
    acc = correct / total
    self.accuracy_scores.append((epoch, acc))
    print(f'\n accuracy on validation set epoch {epoch}:{acc:.4f}')
    return acc


In [27]:
# Fresh model on the selected device, plus the objects train_model() needs.
model = MnistCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()  # expects raw class scores and integer labels
earlystopper = EarlyStopping()
acctracker = AccuracyTracker()

In [28]:
def _train_one_epoch(model, train_loader, loss_fn, optimizer):
  """Run one optimisation pass over train_loader.

  Batches are moved to the notebook-level `device`.

  Returns:
    (avg_loss, train_acc): the sample-weighted mean of the per-batch losses and the
    fraction of training samples whose argmax prediction matched the label.

  Raises:
    ValueError: if train_loader yields no samples.
  """
  model.train()  # make sure training mode (dropout active) is on at the start of every epoch
  loss_sum = 0.0
  correct = total = 0
  for xb, yb in train_loader:
    xb, yb = xb.to(device), yb.to(device)
    optimizer.zero_grad()
    preds = model(xb)
    loss = loss_fn(preds, yb)
    loss.backward()
    optimizer.step()

    n = yb.size(0)
    # Assumes loss_fn averages over the batch (nn.CrossEntropyLoss() does by default):
    # weighting by batch size keeps a short final batch from skewing the epoch mean.
    loss_sum += loss.item() * n
    # Running accuracy on training batches, measured with dropout active.
    correct += (preds.argmax(1) == yb).sum().item()
    total += n

  if total == 0:
    raise ValueError('train_loader yielded no samples')
  return loss_sum / total, correct / total


def train_model(model, train_loader, test_loader, loss_fn,
                optimizer, n_epochs,
                earlystopper,
                acctracker):
  """Train `model` for up to `n_epochs` epochs with per-epoch evaluation and checkpoints.

  Each epoch: one pass over train_loader (loss/accuracy printed), accuracy on
  test_loader recorded through acctracker, a checkpoint written by save_checkpoint,
  then earlystopper is updated with that accuracy. Training stops as soon as
  earlystopper.should_stop is set. Returns None.
  """
  for epoch in range(1, n_epochs + 1):
    avg_loss, train_acc = _train_one_epoch(model, train_loader, loss_fn, optimizer)
    print(f'Train epoch {epoch:02d} : train_loss = {avg_loss:.4f},accuracy = {train_acc:.4f}')

    acctracker.compute(model, test_loader, epoch)
    test_acc = acctracker.accuracy_scores[-1][1]

    save_checkpoint(model, epoch, avg_loss)
    earlystopper.step(test_acc)

    if earlystopper.should_stop:
      print('Early Stopping Triggered')
      break


In [29]:
# Upper bound on training epochs; EarlyStopping can end the run sooner.
n_epochs = 100

In [30]:
# Train until EarlyStopping fires or n_epochs is reached; the per-epoch log follows.
train_model(model, train_loader, test_loader, loss_fn,
            optimizer=optimizer, n_epochs=n_epochs,
            earlystopper=earlystopper, acctracker=acctracker)

Train epoch 01 : train_loss = 0.2052,accuracy = 0.9368

 accuracy on validation set epoch 1:0.9850
Train epoch 02 : train_loss = 0.0680,accuracy = 0.9785

 accuracy on validation set epoch 2:0.9864
Train epoch 03 : train_loss = 0.0466,accuracy = 0.9863

 accuracy on validation set epoch 3:0.9871
Train epoch 04 : train_loss = 0.0368,accuracy = 0.9884

 accuracy on validation set epoch 4:0.9850
Train epoch 05 : train_loss = 0.0306,accuracy = 0.9905

 accuracy on validation set epoch 5:0.9857
Train epoch 06 : train_loss = 0.0242,accuracy = 0.9923

 accuracy on validation set epoch 6:0.9886
Train epoch 07 : train_loss = 0.0189,accuracy = 0.9935

 accuracy on validation set epoch 7:0.9879
Train epoch 08 : train_loss = 0.0186,accuracy = 0.9940

 accuracy on validation set epoch 8:0.9879
Train epoch 09 : train_loss = 0.0152,accuracy = 0.9949

 accuracy on validation set epoch 9:0.9886
Early Stopping Triggered
