In [1]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
from torchsummary import summary
import gc

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

  from .autonotebook import tqdm as notebook_tqdm


'cuda'

In [8]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010]
    )
])

dataset = datasets.CIFAR10('data', train=True, download=True, transform=transform)

split = int(np.floor(len(dataset) * 0.1))
indices = list(range(len(dataset)))

np.random.seed(42)
np.random.shuffle(indices)

train_index, val_index = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_index)
val_sampler = SubsetRandomSampler(val_index)

train_loader = DataLoader(dataset, batch_size=64, sampler=train_sampler)
valid_loader = DataLoader(dataset, batch_size=64, sampler=val_sampler)

test_dataset = datasets.CIFAR10('data', train=False,download=True, transform=transform,)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU())
        self.conv2 = nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
                        nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
    
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes = 10):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Sequential(
                        nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
                        nn.BatchNorm2d(64),
                        nn.ReLU())
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        self.layer0 = self._make_layer(block, 64, layers[0], stride = 1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride = 2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride = 2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride = 2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        # print(f"Shape Before Conv1: {x.shape}")
        x = self.conv1(x)
        # print(f"Shape after Conv1: {x.shape}")
        x = self.maxpool(x)
        # print(f"Shape after MaxPool: {x.shape}")
        x = self.layer0(x)
        # print(f"Shape after Layer0: {x.shape}")
        x = self.layer1(x)
        # print(f"Shape after Layer1: {x.shape}")
        x = self.layer2(x)
        # print(f"Shape after Layer2: {x.shape}")
        x = self.layer3(x)
        # print(f"Shape after Layer3: {x.shape}")
        x = self.avgpool(x)
        # print(f"Shape after avgpool: {x.shape}")
        x = x.view(x.size(0), -1)
        # print(f"Shape after x.view: {x.shape}")
        x = self.fc(x)
        # print(f"Shape after fc: {x.shape}")

        return x

In [4]:
num_classes = 10
num_epochs = 20
batch_size = 16
learning_rate = 0.01

model = ResNet(ResidualBlock, [3, 4, 6, 3]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate, weight_decay = 0.001, momentum = 0.9)  

In [5]:
summary(model, input_size=(3,224,224), batch_size=16)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [32, 64, 112, 112]           9,472
       BatchNorm2d-2         [32, 64, 112, 112]             128
              ReLU-3         [32, 64, 112, 112]               0
         MaxPool2d-4           [32, 64, 56, 56]               0
            Conv2d-5           [32, 64, 56, 56]          36,928
       BatchNorm2d-6           [32, 64, 56, 56]             128
              ReLU-7           [32, 64, 56, 56]               0
            Conv2d-8           [32, 64, 56, 56]          36,928
       BatchNorm2d-9           [32, 64, 56, 56]             128
             ReLU-10           [32, 64, 56, 56]               0
    ResidualBlock-11           [32, 64, 56, 56]               0
           Conv2d-12           [32, 64, 56, 56]          36,928
      BatchNorm2d-13           [32, 64, 56, 56]             128
             ReLU-14           [32, 64,

In [7]:
from tqdm.auto import tqdm

total_step = len(train_loader)

for epoch in tqdm(range(num_epochs)):
    for i, (images, labels) in enumerate(train_loader):  

        model.train()
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        del images, labels, outputs
        torch.cuda.empty_cache()
        gc.collect()

    print (f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
            
    model.eval()
    with torch.inference_mode():
        correct = 0
        total = 0
        for images, labels in valid_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs
    
        print(f'Accuracy of the network on the {5000} validation images: {100 * correct / total}') 

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch [1/20], Loss: 1.3763


  5%|▌         | 1/20 [06:13<1:58:25, 373.99s/it]

Accuracy of the network on the 5000 validation images: 58.74
Epoch [2/20], Loss: 0.6140


 10%|█         | 2/20 [12:31<1:52:49, 376.07s/it]

Accuracy of the network on the 5000 validation images: 69.34
Epoch [3/20], Loss: 0.8758


 15%|█▌        | 3/20 [18:47<1:46:34, 376.15s/it]

Accuracy of the network on the 5000 validation images: 73.46
Epoch [4/20], Loss: 0.4291


 20%|██        | 4/20 [25:01<1:40:02, 375.19s/it]

Accuracy of the network on the 5000 validation images: 78.0
Epoch [5/20], Loss: 0.1311


 25%|██▌       | 5/20 [31:18<1:33:57, 375.85s/it]

Accuracy of the network on the 5000 validation images: 76.82
Epoch [6/20], Loss: 0.7074


 30%|███       | 6/20 [37:34<1:27:44, 376.03s/it]

Accuracy of the network on the 5000 validation images: 82.44
Epoch [7/20], Loss: 0.3856


 35%|███▌      | 7/20 [43:45<1:21:06, 374.33s/it]

Accuracy of the network on the 5000 validation images: 81.06
Epoch [8/20], Loss: 0.2301


 40%|████      | 8/20 [49:58<1:14:46, 373.85s/it]

Accuracy of the network on the 5000 validation images: 82.32
Epoch [9/20], Loss: 0.5839


 45%|████▌     | 9/20 [56:12<1:08:32, 373.85s/it]

Accuracy of the network on the 5000 validation images: 80.34
Epoch [10/20], Loss: 0.1501


 50%|█████     | 10/20 [1:02:24<1:02:12, 373.28s/it]

Accuracy of the network on the 5000 validation images: 79.84
Epoch [11/20], Loss: 0.4623


 55%|█████▌    | 11/20 [1:08:34<55:50, 372.25s/it]  

Accuracy of the network on the 5000 validation images: 79.12
Epoch [12/20], Loss: 0.5279


 60%|██████    | 12/20 [1:14:44<49:31, 371.49s/it]

Accuracy of the network on the 5000 validation images: 83.02
Epoch [13/20], Loss: 0.0579


 65%|██████▌   | 13/20 [1:20:53<43:16, 370.90s/it]

Accuracy of the network on the 5000 validation images: 82.4
Epoch [14/20], Loss: 0.1985


 70%|███████   | 14/20 [1:27:02<37:02, 370.42s/it]

Accuracy of the network on the 5000 validation images: 81.68
Epoch [15/20], Loss: 0.2281


 75%|███████▌  | 15/20 [1:33:13<30:52, 370.41s/it]

Accuracy of the network on the 5000 validation images: 81.08
Epoch [16/20], Loss: 0.3019


 80%|████████  | 16/20 [1:39:23<24:40, 370.23s/it]

Accuracy of the network on the 5000 validation images: 83.82
Epoch [17/20], Loss: 0.2785


 85%|████████▌ | 17/20 [1:45:32<18:30, 370.12s/it]

Accuracy of the network on the 5000 validation images: 81.74
Epoch [18/20], Loss: 0.2078


 90%|█████████ | 18/20 [1:51:42<12:20, 370.06s/it]

Accuracy of the network on the 5000 validation images: 80.68
Epoch [19/20], Loss: 0.1852


 95%|█████████▌| 19/20 [1:57:52<06:09, 369.94s/it]

Accuracy of the network on the 5000 validation images: 83.92
Epoch [20/20], Loss: 0.5332


100%|██████████| 20/20 [2:04:04<00:00, 372.23s/it]

Accuracy of the network on the 5000 validation images: 83.74





In [9]:
model.eval()
with torch.inference_mode():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        del images, labels, outputs

    print('Accuracy of the network on the {} test images: {} %'.format(10000, 100 * correct / total))   

Accuracy of the network on the 10000 test images: 82.4 %
