In [1]:
import torch
import torchvision
from torchvision.datasets import CIFAR10
import torch.nn as nn
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from torchsummary import summary
import copy
from tensorboardX import SummaryWriter

In [2]:
class IdentityShortcut(nn.Module):
    """
        A class to make Identity Mapping Shortcut

        Attributes:
            pooling : A max pooling layer 
            extra_cahnnel : The difference between input an output channels
    """
    def __init__(self, in_channels, out_channels, stride):
        """
            Initialize the Identity Shortcut Class

            Args:
                in_channels: number of input channels
                out_channels: number of output channels
                stride: size of stride

            Returns:
                None
        """
        super().__init__()
        self.pooling = nn.MaxPool2d(1, stride=stride)
        self.extra_channel = out_channels - in_channels
    
    def forward(self, x):
        x = F.pad(x, (0, 0, 0, 0, 0, self.extra_channel))
        x = self.pooling(x)
        return x


In [3]:
class ResidualBlock(nn.Module):
    """
        A class about residual block

        When stride == 1, just add input value to output of residual block
        When stride == 2, process the input value according to shortcut type and add it to output of residual block
        Attributes:
            residual_blcok: A sequential container to sequentially configure the layers for residual learning
            shortcut: A shortcut connection of residual block
            relu: A relu activation layer
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, shortcut_type = None):
        """
            Initialize the residual block

            Args:
                in_channels: The number of input channels
                out_channels: The number of output channels
                kernel_size: The number of kernnel size in the convolution layers
                stride: The size of stride
                shortcut_type: The type of shortcut connection when stride is not 1
            Returns:
                None
        """
        super().__init__()
        self.residual_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding = 1, bias = False),
            nn.BatchNorm2d(out_channels)
        )

        self.shortcut = nn.Sequential()
        self.relu = nn.ReLU()
        
        if stride != 1:
            if shortcut_type == 'A':
                self.shortcut = IdentityShortcut(in_channels, out_channels, stride)
            elif shortcut_type == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1, stride),
                    nn.BatchNorm2d(out_channels)
                )
        
    
    def forward(self, x):
        x = self.residual_block(x) + self.shortcut(x)
        x = self.relu(x)
        return x

In [4]:
class ResNet(nn.Module):
    """
        A class about residual network

        Attributes:
            conv1_x: First set of layers
            conv2_x: Second set of residual blocks
            conv3_x: Third set of residual blocks
            conv4_x: Fourth set of residual blocks
            avg_pool: A average pooling layer of residual network
            flatten: A flatten layer for the fully connected layer
            fc: A fully connected layer to get probability of each class about given image
    """
    def __init__(self, conv_num, num_classes):
        """
            Initialize the residual network

            Args:
                conv_num: The number of residual blocks in each set of residual block
                num_classes: The number of classes to predict
            Returns:
                None
        """
        super().__init__()
        self.conv1_x = nn.Sequential(
                                    nn.Conv2d(3, 16, kernel_size = 3, stride = 1, padding=1, bias=False),
                                    nn.BatchNorm2d(16),
                                    nn.ReLU(),
                                    #nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)
        )
        
        self.conv2_x = nn.Sequential(*[ResidualBlock(16, 16, 3, 1) for _ in range(conv_num[0])])
        self.conv3_x = nn.Sequential(ResidualBlock(16, 32, 3, 2, 'A'), *[ResidualBlock(32, 32, 3, 1) for _ in range(conv_num[1] - 1)])
        self.conv4_x = nn.Sequential(ResidualBlock(32, 64, 3, 2, 'A'), *[ResidualBlock(64, 64, 3, 1) for _ in range(conv_num[2] - 1)])
        #self.conv5_x = nn.Sequential(ResidualBlock(256, 128, 3, 2), *[ResidualBlock(512, 512, 3, 1) for _ in range(conv_num[3])])

        #self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.avg_pool = nn.AvgPool2d(8, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features=64,out_features=num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity = 'relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1_x(x)
        x = self.conv2_x(x)
        x = self.conv3_x(x)
        x = self.conv4_x(x)
        #x = self.conv5_x(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [5]:
#model = ResNet([3, 3, 3], 10).to('cuda') # resnet20
model = ResNet([5, 5, 5], 10).to('cuda') # resnet32
#resnet_44 = ResNet([7, 7, 7], 10) # resnet44
#resnet_56 = ResNet([9, 9, 9], 10) # resnet56
#resnet_110 = ResNet([18, 18, 18], 10) # resnet110

In [6]:
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]           2,304
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
    ResidualBlock-10           [-1, 16, 32, 32]               0
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
             ReLU-13           [-1, 16, 32, 32]               0
           Conv2d-14           [-1, 16,

In [7]:
train_cifar10 = CIFAR10(root = 'datasets/cifar10', train=True, download=True)
val_cifar10 = CIFAR10(root = 'datasets/cifar10', train=False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
train_mean = train_cifar10.data.mean(axis=(0,1,2)) / 255
train_std = train_cifar10.data.std(axis=(0,1,2)) / 255

In [9]:
train_transforms = transforms.Compose([
                                        #transforms.Resize(224),
                                        transforms.RandomCrop(32, padding=4),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(train_mean, train_std),                                      
                                        ])

val_transforms = transforms.Compose([
                                        transforms.ToTensor(),
                                        #transforms.Resize(224),
                                        transforms.Normalize(train_mean, train_std)
                                        ])

train_cifar10.transform = train_transforms
val_cifar10.transform = train_transforms

train_dl = DataLoader(train_cifar10, batch_size=256, shuffle=True, num_workers=4)
val_dl = DataLoader(val_cifar10, batch_size=128, shuffle=True, num_workers=4)





In [10]:
def count_correct(output, target):
    pred = output.argmax(1, keepdim=True)
    corrects = pred.eq(target.view_as(pred)).sum().item()
    return corrects

In [11]:
train_losses = []
val_losses = []
train_accs = []
val_accs = []

In [12]:
#writer = SummaryWriter('resnet_logs')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=1e-4)
#optimizer = torch.optim.Adam(model.parameters())
patience = 0
device = 'cuda'
model.to(device)
epochs = 100
loss_func = nn.CrossEntropyLoss()
decay_epoch = [32000, 48000]
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_epoch, gamma=0.1)
#lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor = 0.1, patience=10)
#accumulation_steps = 2
best_model = None
best_accs = -1
for _ in tqdm(range(epochs)):
    global_loss = 0
    corrects = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_dl):
        data, target = data.to(device), target.to(device)
        
        output = model(data)
        loss = loss_func(output, target)
        global_loss = global_loss + loss.item()
        optimizer.zero_grad()
        loss.backward()
        #if (batch_idx + 1) % accumulation_steps == 0:
        optimizer.step()
        lr_scheduler.step()
        corrects += count_correct(output, target)

    train_losses.append(global_loss / (batch_idx + 1))
    train_accs.append(corrects / len(train_cifar10) * 100)
    
    model.eval()
    corrects = 0
    global_loss = 0
    for batch_idx, (data, target) in enumerate(val_dl):
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            output = model(data)
            loss = loss_func(output, target)
            global_loss = global_loss + loss.item()
            corrects += count_correct(output, target)

    val_losses.append(global_loss / (batch_idx + 1))
    val_accs.append(corrects / len(val_cifar10) * 100)
    
    
    if best_accs > val_accs[-1]:
        patience = patience + 1
        if patience == 3:
            val_accs[-1] = best_accs
            model = copy.deepcopy(best_model)
            patience = 0
    else:
        best_accs = val_accs[-1]
        best_model = copy.deepcopy(model)
    
    
    #writer.add_scalar('resnet_log/train_error', 100 - train_accs[-1], _ + 1)
    #writer.add_scalar('resnet_log/validation_error', 100 - val_accs[-1], _ + 1)
    
    if (_ + 1) % 10 == 0:
        print("Epoch %d | train_loss = %.2f |  train_acc = %.2f | val_loss = %.2f | val_acc = %.2f" % (_ + 1, train_losses[-1], train_accs[-1], val_losses[-1], val_accs[-1]))

 10%|█         | 10/100 [02:59<26:45, 17.84s/it]

Epoch 10 | train_loss = 0.49 |  train_acc = 83.18 | val_loss = 0.65 | val_acc = 78.16


 20%|██        | 20/100 [05:58<23:54, 17.94s/it]

Epoch 20 | train_loss = 0.44 |  train_acc = 84.63 | val_loss = 0.52 | val_acc = 82.22


 30%|███       | 30/100 [08:57<20:48, 17.83s/it]

Epoch 30 | train_loss = 0.45 |  train_acc = 84.43 | val_loss = 0.51 | val_acc = 82.58


 40%|████      | 40/100 [11:57<17:53, 17.89s/it]

Epoch 40 | train_loss = 0.45 |  train_acc = 84.51 | val_loss = 0.52 | val_acc = 82.72


 50%|█████     | 50/100 [14:56<14:53, 17.88s/it]

Epoch 50 | train_loss = 0.44 |  train_acc = 84.60 | val_loss = 0.51 | val_acc = 82.36


 55%|█████▌    | 55/100 [16:26<13:28, 17.97s/it]