In [25]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.optim.lr_scheduler as lr_scheduler

In [90]:
class Net(nn.Module):
    """Small CNN classifier: four 3x3 convolutions, global average pooling and two linear layers.

    ``forward`` returns log-probabilities over 10 classes (``log_softmax`` on
    the last linear layer), so the output is meant to be paired with
    ``F.nll_loss``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        self.conv3 = nn.Conv2d(16, 16, kernel_size=3)
        self.conv4 = nn.Conv2d(16, 64, kernel_size=3)
        # fc1 is not used by forward(); it is kept so existing state_dict keys
        # and the order in which parameters are randomly initialised stay the same.
        self.fc1 = nn.Linear(32, 24)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 10)

        # Batch norm is applied after conv2 (16 channels) and conv4 (64 channels).
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(64)

        self.dropout1 = nn.Dropout(0.1)
        # Global average pooling collapses each of the 64 feature maps to 1x1.
        self.gap = nn.AdaptiveAvgPool2d(1)

        # Not used by forward() (which reshapes with .view); kept for compatibility.
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Map a batch of single-channel images to per-class log-probabilities."""
        # F.relu's second positional argument is `inplace`; the stray `2` that
        # used to be passed here silently switched on in-place mode, so it is dropped.
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.bn1(self.conv2(x)), 2))
        x = self.dropout1(x)
        x = F.relu(self.conv3(x))
        x = F.relu(F.max_pool2d(self.bn2(self.conv4(x)), 2))
        x = self.gap(x)
        # (N, 64, 1, 1) -> (N, 64) to feed the linear head.
        x = x.view(-1, 64)
        # NOTE(review): there is no non-linearity between fc3 and fc4, so the two
        # layers compose into a single affine map - confirm this is intended.
        x = self.fc3(x)
        x = self.fc4(x)
        return F.log_softmax(x, dim=1)

In [91]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = Net().to(device)
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              80
            Conv2d-2           [-1, 16, 24, 24]           1,168
       BatchNorm2d-3           [-1, 16, 24, 24]              32
           Dropout-4           [-1, 16, 12, 12]               0
            Conv2d-5           [-1, 16, 10, 10]           2,320
            Conv2d-6             [-1, 64, 8, 8]           9,280
       BatchNorm2d-7             [-1, 64, 8, 8]             128
 AdaptiveAvgPool2d-8             [-1, 64, 1, 1]               0
            Linear-9                   [-1, 32]           2,080
           Linear-10                   [-1, 10]             330
Total params: 15,418
Trainable params: 15,418
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.27
Params size (MB): 0.06
Estimated Tot

In [96]:
# Seed torch's global RNG so that anything drawing from it is repeatable.
torch.manual_seed(1)
batch_size = 128

# Mean / std passed to Normalize for both the train and test splits.
NORM_MEAN, NORM_STD = (0.1307,), (0.3081,)

# Extra DataLoader options are only enabled when running on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# Training images get a random affine augmentation before normalisation.
train_transform = transforms.Compose([
    transforms.RandomAffine(
        degrees=15,            # Rotate by ±15 degrees
        translate=(0.1, 0.1),  # Translate by ±10%
        scale=(0.9, 1.1)       # Scale between 90% to 110%
    ),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Test images are only converted and normalised - no augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=train_transform),
    batch_size=batch_size, shuffle=True, **kwargs)

# download=True so this loader does not rely on the train split having fetched
# the files first; shuffle=False because evaluation order does not affect the
# metrics and shuffling would needlessly consume the global RNG.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=test_transform),
    batch_size=batch_size, shuffle=False, **kwargs)


In [97]:
# History lists used to plot accuracy and loss graphs
train_losses, test_losses = [], []
train_acc, test_acc = [], []

# Holder for misclassified test samples (presumably filled in elsewhere - verify)
test_incorrect_pred = dict(images=[], ground_truths=[], predicted_vals=[])

In [98]:
from tqdm import tqdm
def GetCorrectPredCount(pPrediction, pLabels):
  """Count the rows of ``pPrediction`` whose arg-max along dim 1 equals the entry in ``pLabels``.

  Returns the count as a plain Python number (via ``.item()``).
  """
  predicted_classes = torch.argmax(pPrediction, dim=1)
  hits = torch.eq(predicted_classes, pLabels)
  return hits.sum().item()


def train(model, device, train_loader, optimizer, epoch):
  """Run one optimisation pass over ``train_loader``.

  For every batch: move it to ``device``, compute the NLL loss, back-propagate
  and step ``optimizer``. A tqdm bar shows the latest batch loss and the
  running accuracy. When the pass ends, the accuracy (in percent) is appended
  to the module-level ``train_acc`` and the mean per-batch loss to
  ``train_losses``.

  ``epoch`` is accepted but not used inside the body.
  """
  model.train()
  progress = tqdm(train_loader)

  running_loss, n_correct, n_seen = 0, 0, 0

  for step, batch in enumerate(progress):
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Clear gradients left over from the previous step before the new backward pass.
    optimizer.zero_grad()
    logits = model(inputs)
    loss = F.nll_loss(logits, labels)
    batch_loss = loss.item()
    running_loss += batch_loss
    loss.backward()
    optimizer.step()

    n_correct += GetCorrectPredCount(logits, labels)
    n_seen += len(inputs)

    progress.set_description(
        desc=f'loss={batch_loss} batch_id={step} Accuracy={100*n_correct/n_seen:0.2f}')

  # Record this pass's accuracy and loss (averaged over the number of batches).
  train_acc.append(100*n_correct/n_seen)
  train_losses.append(running_loss/len(train_loader))


def test(model, device, test_loader):
  """Evaluate ``model`` on ``test_loader`` and report average loss and accuracy.

  Runs in eval mode with gradients disabled. The NLL loss is summed over all
  samples and divided by ``len(test_loader.dataset)``. The average loss and
  the accuracy (in percent) are printed and, mirroring what ``train`` does
  for the training metrics, appended to the module-level ``test_losses`` and
  ``test_acc`` lists so they are available for plotting.
  """
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      # reduction='sum' so the division below yields a true per-sample average
      # even when the last batch is smaller than the others.
      test_loss += F.nll_loss(output, target, reduction='sum').item()
      # Same arg-max comparison that train() uses.
      correct += GetCorrectPredCount(output, target)

  num_samples = len(test_loader.dataset)
  test_loss /= num_samples
  accuracy = 100. * correct / num_samples

  # Previously these history lists were declared but never filled.
  test_losses.append(test_loss)
  test_acc.append(accuracy)

  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
      test_loss, correct, num_samples, accuracy))

In [99]:
# Number of passes over the training set.
NUM_EPOCHS = 20

# Start from a freshly initialised network for this run.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Multiply the learning rate by 0.5 every 3 epochs. The `verbose` flag is
# deprecated in recent PyTorch releases, so the current rate is printed
# explicitly via get_last_lr() instead.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)


for epoch in range(1, NUM_EPOCHS + 1):
  print(f"epoch {epoch} lr={scheduler.get_last_lr()[0]}")
  train(model, device, train_loader, optimizer, epoch)
  test(model, device, test_loader)
  # Step the schedule once per epoch, after the optimizer updates of that epoch.
  scheduler.step()
  print(f"epoch {epoch} Done")

loss=0.16894526779651642 batch_id=468 Accuracy=87.48: 100%|██████████| 469/469 [00:25<00:00, 18.47it/s]



Test set: Average loss: 0.1109, Accuracy: 9620/10000 (96.200%)

epoch 1 Done


loss=0.10723813623189926 batch_id=468 Accuracy=95.60: 100%|██████████| 469/469 [00:23<00:00, 19.84it/s]



Test set: Average loss: 0.0596, Accuracy: 9805/10000 (98.050%)

epoch 2 Done


loss=0.12064648419618607 batch_id=468 Accuracy=96.40: 100%|██████████| 469/469 [00:23<00:00, 19.86it/s]



Test set: Average loss: 0.0558, Accuracy: 9831/10000 (98.310%)

epoch 3 Done


loss=0.10346948355436325 batch_id=468 Accuracy=97.52: 100%|██████████| 469/469 [00:25<00:00, 18.76it/s]



Test set: Average loss: 0.0348, Accuracy: 9880/10000 (98.800%)

epoch 4 Done


loss=0.055307190865278244 batch_id=468 Accuracy=97.61: 100%|██████████| 469/469 [00:24<00:00, 19.25it/s]



Test set: Average loss: 0.0326, Accuracy: 9894/10000 (98.940%)

epoch 5 Done


loss=0.06990155577659607 batch_id=468 Accuracy=97.86: 100%|██████████| 469/469 [00:24<00:00, 19.51it/s]



Test set: Average loss: 0.0262, Accuracy: 9914/10000 (99.140%)

epoch 6 Done


loss=0.0882933959364891 batch_id=468 Accuracy=98.15: 100%|██████████| 469/469 [00:24<00:00, 19.36it/s]



Test set: Average loss: 0.0217, Accuracy: 9927/10000 (99.270%)

epoch 7 Done


loss=0.0489351861178875 batch_id=468 Accuracy=98.24: 100%|██████████| 469/469 [00:23<00:00, 19.70it/s]



Test set: Average loss: 0.0228, Accuracy: 9918/10000 (99.180%)

epoch 8 Done


loss=0.13096056878566742 batch_id=468 Accuracy=98.24: 100%|██████████| 469/469 [00:23<00:00, 19.74it/s]



Test set: Average loss: 0.0227, Accuracy: 9923/10000 (99.230%)

epoch 9 Done


loss=0.06855706125497818 batch_id=468 Accuracy=98.47: 100%|██████████| 469/469 [00:23<00:00, 19.72it/s]



Test set: Average loss: 0.0185, Accuracy: 9943/10000 (99.430%)

epoch 10 Done


loss=0.016401926055550575 batch_id=468 Accuracy=98.55: 100%|██████████| 469/469 [00:23<00:00, 19.62it/s]



Test set: Average loss: 0.0178, Accuracy: 9933/10000 (99.330%)

epoch 11 Done


loss=0.012807087041437626 batch_id=468 Accuracy=98.48: 100%|██████████| 469/469 [00:23<00:00, 19.70it/s]



Test set: Average loss: 0.0183, Accuracy: 9942/10000 (99.420%)

epoch 12 Done


loss=0.09036624431610107 batch_id=468 Accuracy=98.55: 100%|██████████| 469/469 [00:23<00:00, 19.57it/s]



Test set: Average loss: 0.0167, Accuracy: 9943/10000 (99.430%)

epoch 13 Done


loss=0.08084597438573837 batch_id=468 Accuracy=98.69: 100%|██████████| 469/469 [00:23<00:00, 19.64it/s]



Test set: Average loss: 0.0168, Accuracy: 9944/10000 (99.440%)

epoch 14 Done


loss=0.03757338598370552 batch_id=468 Accuracy=98.67: 100%|██████████| 469/469 [00:23<00:00, 19.89it/s]



Test set: Average loss: 0.0166, Accuracy: 9940/10000 (99.400%)

epoch 15 Done


loss=0.0012255562469363213 batch_id=468 Accuracy=98.71: 100%|██████████| 469/469 [00:24<00:00, 19.50it/s]



Test set: Average loss: 0.0157, Accuracy: 9942/10000 (99.420%)

epoch 16 Done


loss=0.026653125882148743 batch_id=468 Accuracy=98.78: 100%|██████████| 469/469 [00:24<00:00, 19.44it/s]



Test set: Average loss: 0.0168, Accuracy: 9942/10000 (99.420%)

epoch 17 Done


loss=0.007751979399472475 batch_id=468 Accuracy=98.73: 100%|██████████| 469/469 [00:24<00:00, 19.43it/s]



Test set: Average loss: 0.0166, Accuracy: 9939/10000 (99.390%)

epoch 18 Done


loss=0.03018137253820896 batch_id=468 Accuracy=98.72: 100%|██████████| 469/469 [00:24<00:00, 19.41it/s]



Test set: Average loss: 0.0158, Accuracy: 9941/10000 (99.410%)

epoch 19 Done


loss=0.0848696231842041 batch_id=468 Accuracy=98.82: 100%|██████████| 469/469 [00:23<00:00, 19.59it/s]



Test set: Average loss: 0.0155, Accuracy: 9942/10000 (99.420%)

epoch 20 Done
