## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
import torch.optim as optim
from torchvision import datasets, transforms

from dataloader.dloader import dloader
from dataset.split import get_train_test_dataset
from models import mnist
from utils.utils import get_summary
from trainer.fit import fit

%matplotlib inline
import matplotlib.pyplot as plt

# Run configuration: pick the device and the training length.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
epoch = 20  # number of training epochs; also sizes the OneCycleLR schedule

# Seed torch's RNGs (CPU and all CUDA devices) so weight init, dropout and
# data shuffling/augmentation repeat across Restart & Run All.
# NOTE(review): GPU runs may still differ slightly unless cuDNN determinism is also enabled.
SEED = 1
torch.manual_seed(SEED)

## Transforms

In [2]:
# MNIST normalization statistics, shared by both pipelines so they cannot drift apart.
# Normalize expects sequences, hence the one-element tuples.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Training pipeline: light rotation augmentation, then tensor conversion and normalization.
train_transforms = transforms.Compose([
    # Rotate by a random angle in [-15, 15] degrees. Pixels exposed by the rotation get value 1.
    # NOTE(review): on an 8-bit PIL image, 1 is near-black, i.e. close to the MNIST background.
    # Confirm this is intended.
    transforms.RandomRotation((-15.0, 15.0), fill=(1,)),
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Test pipeline: no augmentation, same normalization as training.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

## Dataset

In [3]:
# Build the train and test datasets, each wired to its own transform pipeline.
train_data, test_data = get_train_test_dataset(
    train_transforms,
    test_transforms,
)

## DataLoader

In [4]:
BATCH_SIZE = 64

# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so enable it
# only when CUDA is available. This resolves the former "remedy if not cuda" TODO.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=use_cuda)

# Shuffle the training data every epoch.
train_loader = dloader(train_data, shuffle=True, **loader_kwargs)

# Evaluation metrics do not depend on sample order, so keep the test order deterministic.
test_loader = dloader(test_data, shuffle=False, **loader_kwargs)

In [5]:
# NOTE(review): 0.05 is presumably the dropout rate (the model summary lists Dropout layers).
# Confirm in models/mnist.py.
DROPOUT = 0.05
MAX_LR = 0.1

net = mnist.MNIST(DROPOUT)
model = net.to(device)

# SGD with momentum. OneCycleLR (below) overrides both values given here:
# - the starting lr becomes max_lr / div_factor;
# - with its default cycle_momentum=True, momentum is cycled as well.
# So lr=0.01 and momentum=0.9 are effectively placeholders.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One-cycle schedule sized for one scheduler.step() per training batch.
scheduler = OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    steps_per_epoch=len(train_loader),
    epochs=epoch,
)

In [6]:
# Layer-by-layer summary for a single-channel 28x28 input (batch dimension omitted).
input_size = (1, 28, 28)
model_summary = get_summary(model, input_size)
print(model_summary)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 26, 26]              90
       BatchNorm2d-2           [-1, 10, 26, 26]              20
           Dropout-3           [-1, 10, 26, 26]               0
              ReLU-4           [-1, 10, 26, 26]               0
            Conv2d-5           [-1, 10, 24, 24]             900
       BatchNorm2d-6           [-1, 10, 24, 24]              20
           Dropout-7           [-1, 10, 24, 24]               0
              ReLU-8           [-1, 10, 24, 24]               0
            Conv2d-9           [-1, 10, 22, 22]             900
      BatchNorm2d-10           [-1, 10, 22, 22]              20
          Dropout-11           [-1, 10, 22, 22]               0
             ReLU-12           [-1, 10, 22, 22]               0
           Conv2d-13           [-1, 10, 20, 20]             900
      BatchNorm2d-14           [-1, 10,

In [7]:

# Train for `epoch` epochs, evaluating on the test set after each epoch (see log below).
# fit returns train accuracy/loss and test accuracy/loss histories.
# NOTE(review): presumably lists; confirm their granularity (per batch vs per epoch) in trainer/fit.py.
train_acc, train_losses, test_acc, test_losses = fit(
    train_loader, test_loader, model, device, optimizer, epoch, scheduler,
)


  0%|          | 0/938 [00:00<?, ?it/s]

EPOCH: 0


Loss=0.29457753896713257 Batch_id=937 Accuracy=91.22: 100%|██████████| 938/938 [00:04<00:00, 188.17it/s] 

Epoch: 0 LR: [0.010433014025713727]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0768, Accuracy: 9798/10000 (97.98%)

EPOCH: 1


Loss=0.15550827980041504 Batch_id=937 Accuracy=96.39: 100%|██████████| 938/938 [00:04<00:00, 189.40it/s] 

Epoch: 1 LR: [0.028007736542228656]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0435, Accuracy: 9873/10000 (98.73%)

EPOCH: 2


Loss=0.029139598831534386 Batch_id=937 Accuracy=96.97: 100%|██████████| 938/938 [00:04<00:00, 190.03it/s]

Epoch: 2 LR: [0.05201339936426283]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0396, Accuracy: 9885/10000 (98.85%)

EPOCH: 3


Loss=0.08663179725408554 Batch_id=937 Accuracy=97.40: 100%|██████████| 938/938 [00:04<00:00, 188.04it/s] 


Epoch: 3 LR: [0.07601547059053707]


  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0333, Accuracy: 9896/10000 (98.96%)

EPOCH: 4


Loss=0.0987091064453125 Batch_id=937 Accuracy=97.70: 100%|██████████| 938/938 [00:05<00:00, 184.05it/s]   

Epoch: 4 LR: [0.09358038101918463]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0366, Accuracy: 9894/10000 (98.94%)

EPOCH: 5


Loss=0.033830538392066956 Batch_id=937 Accuracy=97.86: 100%|██████████| 938/938 [00:05<00:00, 179.21it/s] 

Epoch: 5 LR: [0.09999999856920878]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0390, Accuracy: 9885/10000 (98.85%)

EPOCH: 6


Loss=0.25889888405799866 Batch_id=937 Accuracy=97.98: 100%|██████████| 938/938 [00:05<00:00, 184.36it/s]  

Epoch: 6 LR: [0.09874373753452848]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0273, Accuracy: 9913/10000 (99.13%)

EPOCH: 7


Loss=0.030015498399734497 Batch_id=937 Accuracy=98.18: 100%|██████████| 938/938 [00:05<00:00, 178.52it/s] 

Epoch: 7 LR: [0.09504327199257961]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0290, Accuracy: 9899/10000 (98.99%)

EPOCH: 8


Loss=0.006113620009273291 Batch_id=937 Accuracy=98.16: 100%|██████████| 938/938 [00:05<00:00, 179.72it/s] 

Epoch: 8 LR: [0.08908415873743396]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0249, Accuracy: 9915/10000 (99.15%)

EPOCH: 9


Loss=0.037148527801036835 Batch_id=937 Accuracy=98.33: 100%|██████████| 938/938 [00:05<00:00, 186.37it/s] 

Epoch: 9 LR: [0.08116521259079444]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0311, Accuracy: 9906/10000 (99.06%)

EPOCH: 10


Loss=0.011555401608347893 Batch_id=937 Accuracy=98.41: 100%|██████████| 938/938 [00:05<00:00, 187.08it/s] 

Epoch: 10 LR: [0.07168352257909298]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0241, Accuracy: 9916/10000 (99.16%)

EPOCH: 11


Loss=0.057635582983493805 Batch_id=937 Accuracy=98.52: 100%|██████████| 938/938 [00:04<00:00, 187.62it/s] 


Epoch: 11 LR: [0.0611145402316058]


  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0224, Accuracy: 9925/10000 (99.25%)

EPOCH: 12


Loss=0.06819038838148117 Batch_id=937 Accuracy=98.55: 100%|██████████| 938/938 [00:04<00:00, 188.80it/s]  

Epoch: 12 LR: [0.0499882384554628]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0242, Accuracy: 9923/10000 (99.23%)

EPOCH: 13


Loss=0.02456093020737171 Batch_id=937 Accuracy=98.68: 100%|██████████| 938/938 [00:04<00:00, 188.02it/s]  


Epoch: 13 LR: [0.03886253648110994]


  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0219, Accuracy: 9920/10000 (99.20%)

EPOCH: 14


Loss=0.10218770056962967 Batch_id=937 Accuracy=98.71: 100%|██████████| 938/938 [00:04<00:00, 188.98it/s]  

Epoch: 14 LR: [0.028295323462426913]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0226, Accuracy: 9929/10000 (99.29%)

EPOCH: 15


Loss=0.07317556440830231 Batch_id=937 Accuracy=98.81: 100%|██████████| 938/938 [00:04<00:00, 188.84it/s]  

Epoch: 15 LR: [0.018816483585009303]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0191, Accuracy: 9936/10000 (99.36%)

EPOCH: 16


Loss=0.09370893985033035 Batch_id=937 Accuracy=98.75: 100%|██████████| 938/938 [00:04<00:00, 188.33it/s]  

Epoch: 16 LR: [0.010901325460499198]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0187, Accuracy: 9939/10000 (99.39%)

EPOCH: 17


Loss=0.047258324921131134 Batch_id=937 Accuracy=98.86: 100%|██████████| 938/938 [00:05<00:00, 186.82it/s] 

Epoch: 17 LR: [0.004946748168081546]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0178, Accuracy: 9940/10000 (99.40%)

EPOCH: 18


Loss=0.01396920159459114 Batch_id=937 Accuracy=98.97: 100%|██████████| 938/938 [00:04<00:00, 189.90it/s]  

Epoch: 18 LR: [0.001251339077347576]



  0%|          | 0/938 [00:00<?, ?it/s]


Test set: Average loss: 0.0177, Accuracy: 9942/10000 (99.42%)

EPOCH: 19


Loss=0.0676586776971817 Batch_id=937 Accuracy=98.97: 100%|██████████| 938/938 [00:05<00:00, 186.93it/s]    

Epoch: 19 LR: [4.0143079121884904e-07]






Test set: Average loss: 0.0172, Accuracy: 9938/10000 (99.38%)

