In [1]:
import model
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import os

In [2]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
network = model.get_network(model.Network.GROUP_NORM).to(device)
summary(network, input_size=(1, 28, 28))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 28, 28]             100
              ReLU-2           [-1, 10, 28, 28]               0
         GroupNorm-3           [-1, 10, 28, 28]              20
         Dropout2d-4           [-1, 10, 28, 28]               0
            Conv2d-5           [-1, 16, 28, 28]           1,456
              ReLU-6           [-1, 16, 28, 28]               0
         GroupNorm-7           [-1, 16, 28, 28]              32
         Dropout2d-8           [-1, 16, 28, 28]               0
         MaxPool2d-9           [-1, 16, 14, 14]               0
           Conv2d-10            [-1, 8, 14, 14]             136
             ReLU-11            [-1, 8, 14, 14]               0
        GroupNorm-12            [-1, 8, 14, 14]              16
    

  return F.log_softmax(x)


In [3]:
torch.manual_seed(1)
batch_size = 64

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.Compose([
                        #transforms.RandomRotation((-7.0, 7.0), fill=(1,)),
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)


In [4]:
from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        y_pred = model(data)
        pred = y_pred.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc= f'Epoch={epoch} Loss={loss.item()} batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}%')


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [5]:
network = model.get_network(model.Network.GROUP_NORM).to(device)
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
#scheduler = StepLR(optimizer, step_size=5, gamma=0.1)


for epoch in range(1, 16):
    train(network, device, train_loader, optimizer, epoch)
    test(network, device, test_loader)

Epoch=1 Loss=0.12392862141132355 batch_id=937 Accuracy=88.02%: 100%|██████████| 938/938 [00:25<00:00, 36.14it/s]



Test set: Average loss: 0.0681, Accuracy: 9808/10000 (98.08%)



Epoch=2 Loss=0.04138115793466568 batch_id=937 Accuracy=97.20%: 100%|██████████| 938/938 [00:22<00:00, 41.00it/s]



Test set: Average loss: 0.0516, Accuracy: 9847/10000 (98.47%)



Epoch=3 Loss=0.08628048002719879 batch_id=937 Accuracy=97.67%: 100%|██████████| 938/938 [00:23<00:00, 40.74it/s]



Test set: Average loss: 0.0389, Accuracy: 9878/10000 (98.78%)



Epoch=4 Loss=0.01633160375058651 batch_id=937 Accuracy=98.06%: 100%|██████████| 938/938 [00:24<00:00, 38.94it/s]



Test set: Average loss: 0.0359, Accuracy: 9892/10000 (98.92%)



Epoch=5 Loss=0.009081628173589706 batch_id=937 Accuracy=98.27%: 100%|██████████| 938/938 [00:23<00:00, 39.67it/s]



Test set: Average loss: 0.0368, Accuracy: 9890/10000 (98.90%)



Epoch=6 Loss=0.02881023846566677 batch_id=937 Accuracy=98.36%: 100%|██████████| 938/938 [00:23<00:00, 40.08it/s]



Test set: Average loss: 0.0335, Accuracy: 9905/10000 (99.05%)



Epoch=7 Loss=0.24833381175994873 batch_id=937 Accuracy=98.40%: 100%|██████████| 938/938 [00:23<00:00, 40.01it/s]



Test set: Average loss: 0.0314, Accuracy: 9905/10000 (99.05%)



Epoch=8 Loss=0.15637683868408203 batch_id=937 Accuracy=98.52%: 100%|██████████| 938/938 [00:23<00:00, 39.85it/s]



Test set: Average loss: 0.0306, Accuracy: 9898/10000 (98.98%)



Epoch=9 Loss=0.2758753299713135 batch_id=937 Accuracy=98.61%: 100%|██████████| 938/938 [00:23<00:00, 39.97it/s]



Test set: Average loss: 0.0276, Accuracy: 9910/10000 (99.10%)



Epoch=10 Loss=0.16807948052883148 batch_id=937 Accuracy=98.61%: 100%|██████████| 938/938 [00:23<00:00, 40.04it/s]



Test set: Average loss: 0.0301, Accuracy: 9914/10000 (99.14%)



Epoch=11 Loss=0.009748771786689758 batch_id=937 Accuracy=98.67%: 100%|██████████| 938/938 [00:23<00:00, 39.53it/s]



Test set: Average loss: 0.0294, Accuracy: 9906/10000 (99.06%)



Epoch=12 Loss=0.029194261878728867 batch_id=937 Accuracy=98.70%: 100%|██████████| 938/938 [00:23<00:00, 39.99it/s]



Test set: Average loss: 0.0310, Accuracy: 9906/10000 (99.06%)



Epoch=13 Loss=0.052009087055921555 batch_id=937 Accuracy=98.82%: 100%|██████████| 938/938 [00:23<00:00, 40.27it/s]



Test set: Average loss: 0.0302, Accuracy: 9907/10000 (99.07%)



Epoch=14 Loss=0.09740747511386871 batch_id=937 Accuracy=98.87%: 100%|██████████| 938/938 [00:23<00:00, 39.86it/s]



Test set: Average loss: 0.0289, Accuracy: 9912/10000 (99.12%)



Epoch=15 Loss=0.014592999592423439 batch_id=937 Accuracy=98.89%: 100%|██████████| 938/938 [00:23<00:00, 39.70it/s]



Test set: Average loss: 0.0269, Accuracy: 9920/10000 (99.20%)



In [6]:
network = model.get_network(model.Network.LAYER_NORM).to(device)
summary(network, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 28, 28]             100
              ReLU-2           [-1, 10, 28, 28]               0
         LayerNorm-3           [-1, 10, 28, 28]          15,680
         Dropout2d-4           [-1, 10, 28, 28]               0
            Conv2d-5           [-1, 16, 28, 28]           1,456
              ReLU-6           [-1, 16, 28, 28]               0
         LayerNorm-7           [-1, 16, 28, 28]          25,088
         Dropout2d-8           [-1, 16, 28, 28]               0
         MaxPool2d-9           [-1, 16, 14, 14]               0
           Conv2d-10            [-1, 8, 14, 14]             136
             ReLU-11            [-1, 8, 14, 14]               0
        LayerNorm-12            [-1, 8, 14, 14]           3,136
        Dropout2d-13            [-1, 8, 14, 14]               0
           Conv2d-14           [-1, 16,

  return F.log_softmax(x)


In [7]:
network = model.get_network(model.Network.LAYER_NORM).to(device)
optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
#scheduler = StepLR(optimizer, step_size=5, gamma=0.1)


for epoch in range(1, 16):
    train(network, device, train_loader, optimizer, epoch)
    test(network, device, test_loader)

Epoch=1 Loss=0.17002348601818085 batch_id=937 Accuracy=90.24%: 100%|██████████| 938/938 [00:22<00:00, 41.14it/s]



Test set: Average loss: 0.0764, Accuracy: 9766/10000 (97.66%)



Epoch=2 Loss=0.05211994796991348 batch_id=937 Accuracy=97.21%: 100%|██████████| 938/938 [00:23<00:00, 40.28it/s]



Test set: Average loss: 0.0530, Accuracy: 9831/10000 (98.31%)



Epoch=3 Loss=0.10989350825548172 batch_id=937 Accuracy=97.81%: 100%|██████████| 938/938 [00:23<00:00, 40.64it/s]



Test set: Average loss: 0.0441, Accuracy: 9876/10000 (98.76%)



Epoch=4 Loss=0.08162811398506165 batch_id=937 Accuracy=98.12%: 100%|██████████| 938/938 [00:22<00:00, 40.91it/s]



Test set: Average loss: 0.0378, Accuracy: 9873/10000 (98.73%)



Epoch=5 Loss=0.01683547906577587 batch_id=937 Accuracy=98.30%: 100%|██████████| 938/938 [00:23<00:00, 40.57it/s]



Test set: Average loss: 0.0340, Accuracy: 9890/10000 (98.90%)



Epoch=6 Loss=0.013773527927696705 batch_id=937 Accuracy=98.44%: 100%|██████████| 938/938 [00:23<00:00, 40.75it/s]



Test set: Average loss: 0.0323, Accuracy: 9901/10000 (99.01%)



Epoch=7 Loss=0.10910176485776901 batch_id=937 Accuracy=98.56%: 100%|██████████| 938/938 [00:22<00:00, 40.81it/s]



Test set: Average loss: 0.0356, Accuracy: 9890/10000 (98.90%)



Epoch=8 Loss=0.1834798902273178 batch_id=937 Accuracy=98.72%: 100%|██████████| 938/938 [00:23<00:00, 40.63it/s]



Test set: Average loss: 0.0341, Accuracy: 9881/10000 (98.81%)



Epoch=9 Loss=0.005808831658214331 batch_id=937 Accuracy=98.75%: 100%|██████████| 938/938 [00:22<00:00, 40.94it/s]



Test set: Average loss: 0.0328, Accuracy: 9897/10000 (98.97%)



Epoch=10 Loss=0.10751212388277054 batch_id=937 Accuracy=98.78%: 100%|██████████| 938/938 [00:22<00:00, 41.13it/s]



Test set: Average loss: 0.0280, Accuracy: 9909/10000 (99.09%)



Epoch=11 Loss=0.14841964840888977 batch_id=937 Accuracy=98.84%: 100%|██████████| 938/938 [00:23<00:00, 40.50it/s]



Test set: Average loss: 0.0311, Accuracy: 9906/10000 (99.06%)



Epoch=12 Loss=0.018209289759397507 batch_id=937 Accuracy=98.85%: 100%|██████████| 938/938 [00:22<00:00, 40.91it/s]



Test set: Average loss: 0.0252, Accuracy: 9915/10000 (99.15%)



Epoch=13 Loss=0.02086230181157589 batch_id=937 Accuracy=98.89%: 100%|██████████| 938/938 [00:23<00:00, 40.29it/s]



Test set: Average loss: 0.0280, Accuracy: 9914/10000 (99.14%)



Epoch=14 Loss=0.01370526384562254 batch_id=937 Accuracy=98.98%: 100%|██████████| 938/938 [00:23<00:00, 40.18it/s]



Test set: Average loss: 0.0287, Accuracy: 9904/10000 (99.04%)



Epoch=15 Loss=0.015676043927669525 batch_id=937 Accuracy=98.98%: 100%|██████████| 938/938 [00:22<00:00, 40.91it/s]



Test set: Average loss: 0.0269, Accuracy: 9911/10000 (99.11%)

