In [20]:
import numpy as np
from torchvision.datasets import CIFAR10
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
import torch, torch.nn as nn
import torch.nn.functional as F

In [21]:
# Per-channel constants handed to transforms.Normalize in both pipelines below.
# NOTE(review): these look like the commonly quoted CIFAR-10 statistics — confirm.
means = np.array((0.4914, 0.4822, 0.4465))
stds = np.array((0.2023, 0.1994, 0.2010))

# Training pipeline: random geometric / colour perturbations first, then
# conversion to a tensor and per-channel normalisation.
transform_augment = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomRotation(degrees=[-30, 30]),
    # No extra rotation here; only random shear and zoom in/out.
    transforms.RandomAffine(degrees=0, scale=(0.8, 1.2), shear=10),
    # Random brightness / contrast / saturation jitter.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds),
])


# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transform_augment_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds),
])

# Number of samples per mini-batch for both the training and the test loader.
batch_size = 400


# NOTE: despite their names, `train_loader` / `test_loader` hold the CIFAR10
# *dataset* objects; the actual DataLoaders are `trainloader` / `testloader`.
# The names are kept unchanged so other cells that use them keep working.
#
# download=True fetches the archive into ./cifar_data/ when it is missing, so
# the notebook also runs on a fresh checkout; an existing copy is reused.
train_loader = CIFAR10("./cifar_data/", train=True, transform=transform_augment,
                       download=True)
trainloader = torch.utils.data.DataLoader(train_loader,
                                          batch_size=batch_size,
                                          shuffle=True,  # reshuffle the training set every epoch
                                          num_workers=2)


test_loader = CIFAR10("./cifar_data/", train=False, transform=transform_augment_test,
                      download=True)
testloader = torch.utils.data.DataLoader(test_loader,
                                         batch_size=batch_size,
                                         shuffle=False,  # fixed order for evaluation
                                         num_workers=2)

In [22]:
import sys
 
# Add the parent directory to the module search path before importing the
# project-local `models` module (presumably located there — verify the layout).
# The path is relative, so it is resolved against the kernel's working directory.
sys.path.append('../')
from models import train_model

In [23]:
def conv_block(in_channels, out_channels, pool=False):
    """Return a Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU stack.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        pool: when true, a 2x2 max-pooling layer is appended, halving the
            spatial resolution.

    Returns:
        An ``nn.Sequential`` holding the three (or four) layers in order.
    """
    stages = (
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
    if pool:
        stages = stages + (nn.MaxPool2d(2),)
    return nn.Sequential(*stages)

class ResNet9(nn.Module):
    """ResNet-9 style image classifier built from ``conv_block`` stages.

    Layout: two conv blocks -> residual stage -> two conv blocks -> residual
    stage -> global max pool -> linear layer. Three of the conv blocks
    max-pool, so the spatial resolution is halved three times before the
    classifier.

    Args:
        in_channels: number of channels of the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        # Residual branch keeps the channel count so it can be added to its input.
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))

        # Global max pooling instead of a fixed MaxPool2d(4): for 32x32 inputs
        # the feature map here is 4x4, so the result is identical, but other
        # input resolutions also collapse to 1x1 and still match the 512-wide
        # linear layer. The pooling layer has no parameters, so state_dict
        # keys (and existing checkpoints) are unaffected.
        self.classifier = nn.Sequential(nn.AdaptiveMaxPool2d(1),
                                        nn.Flatten(),
                                        nn.Linear(512, num_classes))

    def forward(self, xb):
        """Compute class logits for a batch of images.

        Args:
            xb: input batch with ``in_channels`` channels.

        Returns:
            One row of ``num_classes`` logits per input sample.
        """
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out  # skip connection
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out  # skip connection
        out = self.classifier(out)
        return out

In [26]:
# Train on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ResNet9(in_channels=3, num_classes=10).to(device)

# Only needed for the weight-decay variant noted below; the active optimizer
# does not use it.
weight_decay = 1e-4

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Alternative with L2 regularisation:
#   torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=weight_decay)

# TensorBoard writer; created here but not passed to train_model in this cell.
tb = SummaryWriter()

model, optimizer = train_model(
    model,
    criterion,
    optimizer,
    train_dataloader=trainloader,
    test_dataloader=testloader,
    device=device,
    n_epochs=100,
    batch_size=batch_size,
)

INFO Epoch 0
INFO The mean accuracy train: 0.388
INFO The mean accuracy test: 0.538
INFO -------------------------------
INFO Epoch 1
INFO The mean accuracy train: 0.553
INFO The mean accuracy test: 0.640
INFO -------------------------------
INFO Epoch 2
INFO The mean accuracy train: 0.631
INFO The mean accuracy test: 0.679
INFO -------------------------------
INFO Epoch 3
INFO The mean accuracy train: 0.676
INFO The mean accuracy test: 0.700
INFO -------------------------------
INFO Epoch 4
INFO The mean accuracy train: 0.707
INFO The mean accuracy test: 0.724
INFO -------------------------------
INFO Epoch 5
INFO The mean accuracy train: 0.727
INFO The mean accuracy test: 0.755
INFO -------------------------------
INFO Epoch 6
INFO The mean accuracy train: 0.741
INFO The mean accuracy test: 0.781
INFO -------------------------------
INFO Epoch 7
INFO The mean accuracy train: 0.759
INFO The mean accuracy test: 0.787
INFO -------------------------------
INFO Epoch 8
INFO The mean accur

INFO The mean accuracy test: 0.884
INFO -------------------------------
INFO Epoch 68
INFO The mean accuracy train: 0.947
INFO The mean accuracy test: 0.868
INFO -------------------------------
INFO Epoch 69
INFO The mean accuracy train: 0.946
INFO The mean accuracy test: 0.875
INFO -------------------------------
INFO Epoch 70
INFO The mean accuracy train: 0.948
INFO The mean accuracy test: 0.878
INFO -------------------------------
INFO Epoch 71
INFO The mean accuracy train: 0.949
INFO The mean accuracy test: 0.881
INFO -------------------------------
INFO Epoch 72
INFO The mean accuracy train: 0.952
INFO The mean accuracy test: 0.880
INFO -------------------------------
INFO Epoch 73
INFO The mean accuracy train: 0.952
INFO The mean accuracy test: 0.867
INFO -------------------------------
INFO Epoch 74
INFO The mean accuracy train: 0.951
INFO The mean accuracy test: 0.886
INFO -------------------------------
INFO Epoch 75
INFO The mean accuracy train: 0.951
INFO The mean accuracy t

In [25]:
correct = 0
total = 0

# Remember the current mode so it can be restored once evaluation is done.
was_training = model.training
# Evaluation mode: the BatchNorm layers use their running statistics instead of
# per-batch statistics and do not update those statistics from test data
# (which would otherwise end up in the saved checkpoint).
model.eval()
try:
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            # calculate outputs by running images through the network
            outputs = model(images)
            # the class with the highest logit is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
finally:
    # Leave the model in whatever mode it was in before this cell ran.
    model.train(was_training)

print(f'Accuracy of test images: {100 * correct // total} %')

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f981c7ccb80>
Traceback (most recent call last):
  File "/home/work/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/home/work/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
    if w.is_alive():Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f981c7ccb80>  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive

Traceback (most recent call last):
  File "/home/work/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
        self._shutdown_workers()
assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/home/work/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
AssertionError    : can only test a child processif w.is_aliv

Accuracy of test images: 86 %


In [None]:
# Bundle the model weights and the optimizer state into one dict and write it
# to 'resnet.pth' in the current working directory.
checkpoint = dict(
    model=model.state_dict(),
    optimizer=optimizer.state_dict(),
)
torch.save(checkpoint, 'resnet.pth')