In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from main import train, test
from model import Net
from utils import Transforms, show_batch

In [3]:
%pip install torchsummary
from torchsummary import summary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
Note: you may need to restart the kernel to use updated packages.


In [64]:
# Build both CIFAR-10 splits in one pass: the same flag picks the torchvision
# split (train=...) and the project's Transforms mode (Train=...), which
# presumably selects train-time vs. test-time preprocessing — see utils.Transforms.
train_data, test_data = (
    Transforms(datasets.CIFAR10('./data', train=is_train, download=True), Train=is_train)
    for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [65]:
SEED = 1
BATCH_SIZE = 64

cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# Seed torch's generators so torch-driven randomness (weight init, loader
# shuffling) repeats across runs.
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)

# Worker processes and pinned host memory are only requested when CUDA is
# available; pinned memory speeds up host-to-GPU batch copies.
if cuda:
    dataloader_args = dict(shuffle=True, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
else:
    dataloader_args = dict(shuffle=True, batch_size=BATCH_SIZE)

train_loader = torch.utils.data.DataLoader(train_data, **dataloader_args)
# Evaluation metrics do not depend on sample order, so the test split is read
# in a fixed order (identical batches every epoch). `dataloader_args` itself is
# left untouched; only this loader overrides `shuffle`.
test_loader = torch.utils.data.DataLoader(test_data, **{**dataloader_args, "shuffle": False})

CUDA Available? True


In [None]:
# Visual sanity check of the training pipeline: preview a batch drawn from
# train_loader via utils.show_batch (what it renders is defined in utils —
# not visible here).
show_batch(train_loader)

In [69]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

# Build the network on the chosen device and print torchsummary's per-layer
# table (output shapes and parameter counts) for (3, 32, 32) inputs.
model = Net().to(device)
summary(model, input_size=(3, 32, 32))

cuda
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
              GELU-2           [-1, 32, 32, 32]               0
       BatchNorm2d-3           [-1, 32, 32, 32]              64
           Dropout-4           [-1, 32, 32, 32]               0
            Conv2d-5           [-1, 32, 32, 32]           9,248
              GELU-6           [-1, 32, 32, 32]               0
       BatchNorm2d-7           [-1, 32, 32, 32]              64
           Dropout-8           [-1, 32, 32, 32]               0
            Conv2d-9           [-1, 32, 32, 32]           9,248
             GELU-10           [-1, 32, 32, 32]               0
      BatchNorm2d-11           [-1, 32, 32, 32]              64
          Dropout-12           [-1, 32, 32, 32]               0
           Conv2d-13           [-1, 32, 32, 32]           9,248
             GELU-14           [-1

In [72]:
EPOCHS = 25

# Start training from a freshly initialised network; Adam's weight_decay adds
# an L2 penalty on the parameters.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# NOTE(review): every test report below prints "Average loss: 0.0000" while
# accuracy is ~80% — the loss accumulation/normalisation inside main.test
# looks wrong; verify it there.
for epoch in range(EPOCHS):

    # train() receives the 0-based epoch index; only the printed label is 1-based.
    print("EPOCH:", epoch + 1)
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

EPOCH: 1


Loss=1.2398051023483276 Batch_id=781 Accuracy=74.83: 100%|██████████| 782/782 [00:13<00:00, 56.53it/s] 



Test set: Average loss: 0.0000, Accuracy: 7743/10000 (77.43%)

EPOCH: 2


Loss=0.3576366901397705 Batch_id=781 Accuracy=75.26: 100%|██████████| 782/782 [00:13<00:00, 56.97it/s] 



Test set: Average loss: 0.0000, Accuracy: 8037/10000 (80.37%)

EPOCH: 3


Loss=0.8153434991836548 Batch_id=781 Accuracy=75.54: 100%|██████████| 782/782 [00:13<00:00, 56.61it/s] 



Test set: Average loss: 0.0000, Accuracy: 7915/10000 (79.15%)

EPOCH: 4


Loss=0.9965046644210815 Batch_id=781 Accuracy=75.67: 100%|██████████| 782/782 [00:13<00:00, 56.52it/s] 



Test set: Average loss: 0.0000, Accuracy: 8126/10000 (81.26%)

EPOCH: 5


Loss=0.741019606590271 Batch_id=781 Accuracy=75.68: 100%|██████████| 782/782 [00:13<00:00, 56.63it/s]  



Test set: Average loss: 0.0000, Accuracy: 7972/10000 (79.72%)

EPOCH: 6


Loss=0.7906625866889954 Batch_id=781 Accuracy=76.03: 100%|██████████| 782/782 [00:13<00:00, 56.00it/s] 



Test set: Average loss: 0.0000, Accuracy: 8091/10000 (80.91%)

EPOCH: 7


Loss=0.32137531042099 Batch_id=781 Accuracy=76.03: 100%|██████████| 782/782 [00:13<00:00, 57.29it/s]   



Test set: Average loss: 0.0000, Accuracy: 8077/10000 (80.77%)

EPOCH: 8


Loss=0.8464739322662354 Batch_id=781 Accuracy=76.35: 100%|██████████| 782/782 [00:13<00:00, 56.92it/s] 



Test set: Average loss: 0.0000, Accuracy: 7914/10000 (79.14%)

EPOCH: 9


Loss=0.7226274609565735 Batch_id=781 Accuracy=76.27: 100%|██████████| 782/782 [00:13<00:00, 56.65it/s] 



Test set: Average loss: 0.0000, Accuracy: 7944/10000 (79.44%)

EPOCH: 10


Loss=0.6443967223167419 Batch_id=781 Accuracy=76.77: 100%|██████████| 782/782 [00:13<00:00, 56.81it/s] 



Test set: Average loss: 0.0000, Accuracy: 7907/10000 (79.07%)

EPOCH: 11


Loss=1.0718231201171875 Batch_id=781 Accuracy=76.77: 100%|██████████| 782/782 [00:13<00:00, 56.85it/s] 



Test set: Average loss: 0.0000, Accuracy: 8141/10000 (81.41%)

EPOCH: 12


Loss=0.9072191715240479 Batch_id=781 Accuracy=76.82: 100%|██████████| 782/782 [00:13<00:00, 57.01it/s] 



Test set: Average loss: 0.0000, Accuracy: 8019/10000 (80.19%)

EPOCH: 13


Loss=0.5506047606468201 Batch_id=781 Accuracy=76.98: 100%|██████████| 782/782 [00:13<00:00, 56.83it/s] 



Test set: Average loss: 0.0000, Accuracy: 8016/10000 (80.16%)

EPOCH: 14


Loss=0.9507471919059753 Batch_id=781 Accuracy=77.11: 100%|██████████| 782/782 [00:13<00:00, 56.49it/s] 



Test set: Average loss: 0.0000, Accuracy: 8125/10000 (81.25%)

EPOCH: 15


Loss=0.739951491355896 Batch_id=781 Accuracy=77.16: 100%|██████████| 782/782 [00:13<00:00, 57.46it/s]  



Test set: Average loss: 0.0000, Accuracy: 8117/10000 (81.17%)

EPOCH: 16


Loss=1.1136152744293213 Batch_id=781 Accuracy=76.97: 100%|██████████| 782/782 [00:13<00:00, 56.86it/s] 



Test set: Average loss: 0.0000, Accuracy: 8139/10000 (81.39%)

EPOCH: 17


Loss=0.3905988931655884 Batch_id=781 Accuracy=77.46: 100%|██████████| 782/782 [00:13<00:00, 57.61it/s] 



Test set: Average loss: 0.0000, Accuracy: 8187/10000 (81.87%)

EPOCH: 18


Loss=1.4809682369232178 Batch_id=781 Accuracy=77.14: 100%|██████████| 782/782 [00:13<00:00, 57.35it/s] 



Test set: Average loss: 0.0000, Accuracy: 8255/10000 (82.55%)

EPOCH: 19


Loss=0.5681827664375305 Batch_id=781 Accuracy=77.58: 100%|██████████| 782/782 [00:13<00:00, 56.48it/s] 



Test set: Average loss: 0.0000, Accuracy: 8059/10000 (80.59%)

EPOCH: 20


Loss=0.6262993812561035 Batch_id=781 Accuracy=77.61: 100%|██████████| 782/782 [00:13<00:00, 57.72it/s] 



Test set: Average loss: 0.0000, Accuracy: 8069/10000 (80.69%)

EPOCH: 21


Loss=0.5258033871650696 Batch_id=781 Accuracy=77.74: 100%|██████████| 782/782 [00:13<00:00, 57.45it/s] 



Test set: Average loss: 0.0000, Accuracy: 8174/10000 (81.74%)

EPOCH: 22


Loss=0.2310718148946762 Batch_id=781 Accuracy=77.79: 100%|██████████| 782/782 [00:13<00:00, 57.33it/s] 



Test set: Average loss: 0.0000, Accuracy: 8245/10000 (82.45%)

EPOCH: 23


Loss=0.667325496673584 Batch_id=781 Accuracy=77.98: 100%|██████████| 782/782 [00:13<00:00, 57.46it/s]  



Test set: Average loss: 0.0000, Accuracy: 8196/10000 (81.96%)

EPOCH: 24


Loss=0.6156541705131531 Batch_id=781 Accuracy=78.05: 100%|██████████| 782/782 [00:13<00:00, 56.34it/s] 



Test set: Average loss: 0.0000, Accuracy: 8035/10000 (80.35%)

EPOCH: 25


Loss=0.5938315987586975 Batch_id=781 Accuracy=77.95: 100%|██████████| 782/782 [00:13<00:00, 57.03it/s] 



Test set: Average loss: 0.0000, Accuracy: 8201/10000 (82.01%)

