In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [16]:
dropout_value = 0.1


def _conv_bn_relu(in_channels, out_channels, dropout=None):
    """Build one 3x3 (padding=0) conv -> BatchNorm -> ReLU [-> Dropout] stage.

    Layers are wrapped in an ``nn.Sequential`` in exactly this order, so the
    submodule indices (``.0`` conv, ``.1`` BN, ``.2`` ReLU, ``.3`` Dropout)
    and therefore the state_dict keys stay stable.

    Args:
        in_channels: number of input feature maps.
        out_channels: number of output feature maps.
        dropout: dropout probability appended after the ReLU, or ``None`` to
            omit the Dropout layer entirely.

    Returns:
        nn.Sequential: the assembled stage.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, 3),  # no padding: spatial size shrinks by 2
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    ]
    if dropout is not None:
        layers.append(nn.Dropout(dropout))
    return nn.Sequential(*layers)


class Net(nn.Module):
    """Small all-convolutional MNIST classifier (15,752 trainable parameters).

    Takes a batch of shape (N, 1, 28, 28) and returns log-probabilities of
    shape (N, 10). The per-layer annotations give the output shape
    (channels x height x width), rf = receptive field and j = jump (the
    cumulative stride).
    """

    def __init__(self):
        super().__init__()

        # INPUT BLOCK
        self.convblock1 = _conv_bn_relu(1, 8, dropout_value)     # 8x26x26,  rf=3,  j=1

        # CONV BLOCK 1
        self.convblock2 = _conv_bn_relu(8, 16, dropout_value)    # 16x24x24, rf=5,  j=1
        self.convblock3 = _conv_bn_relu(16, 16, dropout_value)   # 16x22x22, rf=7,  j=1
        self.convblock31 = _conv_bn_relu(16, 16, dropout_value)  # 16x20x20, rf=9,  j=1

        # TRANSITION BLOCK: 3x3 conv squeezing 16 -> 8 channels (no dropout),
        # followed by 2x2 max-pooling.
        self.convblock4 = _conv_bn_relu(16, 8)                   # 8x18x18,  rf=11, j=1
        self.pool1 = nn.MaxPool2d(2, 2)                          # 8x9x9,    rf=12, j=2

        # CONV BLOCK 2
        self.convblock5 = _conv_bn_relu(8, 16, dropout_value)    # 16x7x7,   rf=16, j=2
        self.convblock6 = _conv_bn_relu(16, 16, dropout_value)   # 16x5x5,   rf=20, j=2

        # OUTPUT BLOCK
        self.convblock7 = _conv_bn_relu(16, 32, dropout_value)   # 32x3x3,   rf=24, j=2
        # The 1x1 conv maps 32 channels to 10 class scores; it does not grow the rf.
        self.convblock8 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=10, kernel_size=(1, 1), padding=0, bias=False)
        )                                                        # 10x3x3,   rf=24, j=2
        # Global average pooling replaces a fully-connected head
        # (previously nn.Linear(32*18*18, 10)).
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1))
        )                                                        # 10x1x1,   rf=28

    def forward(self, x):
        """Return (N, 10) log-probabilities for a batch of (N, 1, 28, 28) images."""
        # BLOCK 1
        x = self.convblock1(x)
        x = self.convblock2(x)
        x = self.convblock3(x)
        x = self.convblock31(x)
        # TRANSITION
        x = self.convblock4(x)
        x = self.pool1(x)
        # BLOCK 2
        x = self.convblock5(x)
        x = self.convblock6(x)
        # OUTPUT
        x = self.convblock7(x)
        x = self.convblock8(x)
        x = self.gap(x)

        # (N, 10, 1, 1) -> (N, 10), keeping the batch dimension explicit.
        x = x.flatten(1)
        return F.log_softmax(x, dim=1)

In [17]:
SEED = 1

# Accelerator detection: NVIDIA CUDA first, then Apple Metal (MPS), else CPU.
use_cuda = torch.cuda.is_available()
# torch.backends.mps.is_available() is the documented MPS availability check
# (it returns False on builds/machines without MPS); the torch.mps.is_available
# helper is newer and not present in every PyTorch release.
use_mps = torch.backends.mps.is_available()

# Seed for repeatability. torch.manual_seed seeds the random generators on all
# devices; the explicit per-accelerator calls below just make the intent obvious.
torch.manual_seed(SEED)

if use_cuda:
    print(f"Use CUDA?:{use_cuda}")
    torch.cuda.manual_seed(SEED)
elif use_mps:
    print(f"Use MPS?: {use_mps}")
    torch.mps.manual_seed(SEED)
else:
    print("Using CPU")

Use MPS?: True


In [18]:
# Train-phase transformations: a small random rotation as augmentation, then
# tensor conversion and normalisation with the MNIST mean/std.
train_transforms = transforms.Compose([
    # RandomRotation runs on the PIL image (before ToTensor), so fill=(1,) pads
    # the rotated-in corners with pixel value 1 out of 255, i.e. near-black like
    # the MNIST background.
    transforms.RandomRotation((-7.0, 7.0), fill=(1,)),
    transforms.ToTensor(),
    # mean and std must be sequences, hence the trailing commas:
    # (0.1307,) is a tuple, (0.1307) is just a float.
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Test-phase transformations: no augmentation, same normalisation as training.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Named *_dataset so they do not collide with the train()/test() functions
# defined later in the notebook. With the bare names `train`/`test`, re-running
# this cell silently replaced those functions with dataset objects.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=train_transforms)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=test_transforms)

dataloader_args = dict(shuffle=True, batch_size=128)
if use_cuda or use_mps:
    # Worker processes help keep an accelerator fed. Pinned host memory only
    # speeds up CUDA host-to-device copies; recent PyTorch warns it is
    # unsupported on MPS, so enable it for CUDA only.
    dataloader_args.update(num_workers=4, pin_memory=use_cuda)

train_loader = torch.utils.data.DataLoader(train_dataset, **dataloader_args)
# Evaluation metrics are order-independent, so the test set is not shuffled.
test_loader = torch.utils.data.DataLoader(test_dataset, **{**dataloader_args, 'shuffle': False})

In [20]:
!pip install torchsummary

from torchsummary import summary

if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = Net()

summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              80
       BatchNorm2d-2            [-1, 8, 26, 26]              16
              ReLU-3            [-1, 8, 26, 26]               0
           Dropout-4            [-1, 8, 26, 26]               0
            Conv2d-5           [-1, 16, 24, 24]           1,168
       BatchNorm2d-6           [-1, 16, 24, 24]              32
              ReLU-7           [-1, 16, 24, 24]               0
           Dropout-8           [-1, 16, 24, 24]               0
            Conv2d-9           [-1, 16, 22, 22]           2,320
      BatchNorm2d-10           [-1, 16, 22, 22]              32
             ReLU-11           [-1, 16, 22, 22]               0
          Dropout-12           [-1, 16, 22, 22]               0
           Conv2d-13           [-1, 16, 20, 20]           2,320
      BatchNorm2d-14           [-1, 16,

In [22]:
from tqdm import tqdm

# define metrics to calc

train_losses = []
train_acc = []

test_losses = []
test_acc = []


def train(model, device, train_loader, optimizer):
    # set model to train mode
    model.train()

    # tqdm iterator
    pbar = tqdm(train_loader)

    # correct and processed vars
    correct = 0
    processed = 0

    # loop on batches of data
    for batch_idx, (data,target) in enumerate(pbar):
        #send data, targte to training device
        data, target = data.to(device), target.to(device)

        # Initialize grad to zero for the fresh batch grad accumuation
        optimizer.zero_grad()

        # pred with model
        y_pred = model(data)

        # calc loss
        batch_loss = F.nll_loss(y_pred, target)
        train_losses.append(batch_loss)

        # backprop loss to calc and acc grad w.r.t loss of batch
        batch_loss.backward()
        
        # update weights as per losses seen in this batch
        optimizer.step()

        # calculate correct pred count and acc for batch
        pred_labels = y_pred.argmax(dim=1, keepdim=True)
        correct_count_batch = pred_labels.eq(target.view_as(pred_labels)).sum().item()

        # update total correct and total processed so far
        correct+= correct_count_batch
        processed+= len(data)

        # set pbar desc
        pbar.set_description(desc=f'batch Loss = {batch_loss.item()} batch_id = {batch_idx} accuracy = {100*correct/processed:.02f}'
                            )
        #append train acc
        train_acc.append(100*correct/processed)


def test(model, device, test_loader):
    # set model to eval mode
    model.eval()

    # define var to calc correct and processed
    correct = 0
    processed = 0
    test_loss = 0 # seeing loss as the code runs has no value for test

    # set a no grad context
    with torch.no_grad():
        for data,target in test_loader:
            #send data, target to device
            data, target = data.to(device), target.to(device)
    
            # do pred
            y_pred = model(data)
    
            #calc loss for batch as summed and update total test loss
            batch_loss = F.nll_loss(y_pred, target, reduction='sum').item()
            test_loss+= batch_loss
            # collect loss
            test_losses.append(batch_loss)
    
            # count correct
            pred_labels = y_pred.argmax(dim=1, keepdim=True)
            correct_batch = pred_labels.eq(target.view_as(pred_labels)).sum().item()
    
            #update correct
            correct+= correct_batch
            processed+= len(data)

    # avg loss on test makes more sense to avg it
    test_loss/= processed
    # collect avg losses
    test_losses.append(test_loss)

    print(f'\n Test set avg loss: {test_loss:.4f} \
                Accuracy: {correct}/{processed}, {100*correct/processed:.2f}'
         )

    test_acc.append(100*correct/processed)
    return test_loss
        

# Note it is little weird in the sense that train losses and acc are 
# collected for each batch in the above code but test losses and acc are
# collected for each epoch
    

In [23]:
# Fresh model on the selected training device.
model = Net().to(device)

# Optimizer: Adam at lr=1e-3.
# (Earlier alternative: optim.SGD(model.parameters(), lr=0.02, momentum=0.9).)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# LR scheduler: multiply the LR by 0.1 once the monitored test loss has failed
# to improve for more than 5 consecutive epochs.
# (Earlier alternative: StepLR(optimizer, step_size=3, gamma=0.1).)
from torch.optim.lr_scheduler import StepLR
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

EPOCHS = 20

# The metric histories are module-level lists that train()/test() append to;
# clear them so re-running this cell records only this run.
for history in (train_losses, train_acc, test_losses, test_acc):
    history.clear()

for epoch in range(EPOCHS):
    print(f'EPOCH: {epoch}')
    train(model, device, train_loader, optimizer)
    test_loss = test(model, device, test_loader)
    # ReduceLROnPlateau needs the monitored metric passed to step().
    scheduler.step(test_loss)

EPOCH: 0



atch Loss = 0.12355833500623703 batch_id = 468 accuracy = 89.07: 100%|█| 469/469 [00:13<00:00, 34.59it/s]


 Test set avg loss: 0.0737                 Accuracy: 9810/10000, 98.10
EPOCH: 1



atch Loss = 0.05688127875328064 batch_id = 468 accuracy = 97.52: 100%|█| 469/469 [00:11<00:00, 41.84it/s]


 Test set avg loss: 0.0441                 Accuracy: 9876/10000, 98.76
EPOCH: 2



atch Loss = 0.049336958676576614 batch_id = 468 accuracy = 98.02: 100%|█| 469/469 [00:10<00:00, 44.06it/s


 Test set avg loss: 0.0349                 Accuracy: 9905/10000, 99.05
EPOCH: 3



atch Loss = 0.03311590105295181 batch_id = 468 accuracy = 98.25: 100%|█| 469/469 [00:10<00:00, 43.99it/s]


 Test set avg loss: 0.0325                 Accuracy: 9897/10000, 98.97
EPOCH: 4



atch Loss = 0.016051035374403 batch_id = 468 accuracy = 98.41: 100%|███| 469/469 [00:10<00:00, 44.00it/s]


 Test set avg loss: 0.0264                 Accuracy: 9922/10000, 99.22
EPOCH: 5



atch Loss = 0.03056023083627224 batch_id = 468 accuracy = 98.47: 100%|█| 469/469 [00:10<00:00, 42.87it/s]


 Test set avg loss: 0.0282                 Accuracy: 9911/10000, 99.11
EPOCH: 6



atch Loss = 0.16768985986709595 batch_id = 468 accuracy = 98.63: 100%|█| 469/469 [00:10<00:00, 44.12it/s]


 Test set avg loss: 0.0276                 Accuracy: 9906/10000, 99.06
EPOCH: 7



atch Loss = 0.00619245320558548 batch_id = 468 accuracy = 98.77: 100%|█| 469/469 [00:10<00:00, 43.87it/s]


 Test set avg loss: 0.0245                 Accuracy: 9925/10000, 99.25
EPOCH: 8



atch Loss = 0.006135950330644846 batch_id = 468 accuracy = 98.82: 100%|█| 469/469 [00:10<00:00, 43.67it/s


 Test set avg loss: 0.0223                 Accuracy: 9928/10000, 99.28
EPOCH: 9



atch Loss = 0.023698588833212852 batch_id = 468 accuracy = 98.75: 100%|█| 469/469 [00:10<00:00, 44.03it/s


 Test set avg loss: 0.0244                 Accuracy: 9925/10000, 99.25
EPOCH: 10



atch Loss = 0.07484211027622223 batch_id = 468 accuracy = 98.81: 100%|█| 469/469 [00:10<00:00, 43.57it/s]


 Test set avg loss: 0.0226                 Accuracy: 9935/10000, 99.35
EPOCH: 11



atch Loss = 0.056699950248003006 batch_id = 468 accuracy = 98.84: 100%|█| 469/469 [00:10<00:00, 43.63it/s


 Test set avg loss: 0.0216                 Accuracy: 9940/10000, 99.40
EPOCH: 12



atch Loss = 0.030623765662312508 batch_id = 468 accuracy = 98.91: 100%|█| 469/469 [00:10<00:00, 43.82it/s


 Test set avg loss: 0.0222                 Accuracy: 9930/10000, 99.30
EPOCH: 13



atch Loss = 0.004612436983734369 batch_id = 468 accuracy = 98.89: 100%|█| 469/469 [00:10<00:00, 43.96it/s


 Test set avg loss: 0.0243                 Accuracy: 9924/10000, 99.24
EPOCH: 14



atch Loss = 0.06747055798768997 batch_id = 468 accuracy = 98.94: 100%|█| 469/469 [00:11<00:00, 42.62it/s]


 Test set avg loss: 0.0206                 Accuracy: 9932/10000, 99.32
EPOCH: 15



atch Loss = 0.054090797901153564 batch_id = 468 accuracy = 99.01: 100%|█| 469/469 [00:10<00:00, 43.50it/s


 Test set avg loss: 0.0208                 Accuracy: 9934/10000, 99.34
EPOCH: 16



atch Loss = 0.0070099979639053345 batch_id = 468 accuracy = 98.95: 100%|█| 469/469 [00:11<00:00, 39.97it/


 Test set avg loss: 0.0221                 Accuracy: 9926/10000, 99.26
EPOCH: 17



atch Loss = 0.014665688388049603 batch_id = 468 accuracy = 98.97: 100%|█| 469/469 [00:10<00:00, 43.71it/s


 Test set avg loss: 0.0193                 Accuracy: 9942/10000, 99.42
EPOCH: 18



atch Loss = 0.015031891874969006 batch_id = 468 accuracy = 99.02: 100%|█| 469/469 [00:10<00:00, 44.16it/s


 Test set avg loss: 0.0214                 Accuracy: 9933/10000, 99.33
EPOCH: 19



atch Loss = 0.004560594912618399 batch_id = 468 accuracy = 99.03: 100%|█| 469/469 [00:10<00:00, 43.99it/s


 Test set avg loss: 0.0230                 Accuracy: 9932/10000, 99.32
