In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import os
import sys
import time
import math
import torch.nn.init as init

In [3]:
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse

In [4]:
# Data
# CIFAR-10 per-channel normalization constants shared by both pipelines.
# NOTE(review): this std tuple is the one widely copied from
# kuangliu/pytorch-cifar; the per-pixel std over the CIFAR-10 training set is
# usually quoted as roughly (0.2470, 0.2435, 0.2616). Kept unchanged so the
# preprocessing matches the recorded run and its checkpoint.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
DATA_ROOT = './data'

print('==> Preparing data..')

# Training-time augmentation: pad-by-4 random crop back to 32x32 and random
# horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Display names for the 10 label indices, in label order.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [5]:
class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers (ResNet-18 style).

    Args:
        in_planes: number of input channels.
        planes: output channels of both 3x3 convolutions.
        stride: stride of the first 3x3 convolution and of the shortcut.

    NOTE(review): unlike ``Bottleneck`` (and the reference ResNet), the shortcut
    here is *always* a 1x1 conv + BatchNorm projection, even when
    ``stride == 1`` and ``in_planes == planes`` and an identity would do. That
    adds parameters to every block; it is kept because the recorded training
    run and any saved checkpoint depend on this architecture.
    """
    # Output width multiplier: a block emits planes * expansion channels.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes
        # Submodules are created in the original order so that parameter
        # initialization and state_dict key order are unchanged.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Unconditional projection shortcut (see class docstring).
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=1,
                      stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)

In [6]:
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50 style).

    Args:
        in_planes: number of input channels.
        planes: width of the inner 1x1/3x3 convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the middle 3x3 convolution and of the projection
            shortcut.

    The shortcut is the identity when the input already has the output shape,
    otherwise a 1x1 conv + BatchNorm projection.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Only the middle 3x3 convolution carries the stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        shape_changes = stride != 1 or in_planes != out_planes
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + self.shortcut(x))

In [7]:
"""
block: Block type, we have two types of base block.
       BasicBlock which is designed in ResNet-18.
       Bottleneck which is designed in ResNet-50.
num_blocks: the number of each layer in the block
channels: The size of each filter.
          channels[0]: Convl_CH
          channels[1]: Layer1_CH
          channels[2]: Layer2_CH
          channels[3]: Layer3_CH
"""
class ResNet_three_layers(nn.Module):
    def __init__(self, block, num_blocks, channels, num_classes=10):
        super(ResNet_three_layers, self).__init__()
        self.in_planes = channels[0]

        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.layer1 = self._make_layer(block, channels[1], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, channels[2], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, channels[3], num_blocks[2], stride=2)
        self.linear = nn.Linear(channels[-1]*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [8]:
"""
block: Block type, we have two types of base block.
       BasicBlock which is designed in ResNet-18.
       Bottleneck which is designed in ResNet-50.
num_blocks: the number of each layer in the block
channels: The size of each filter.
          channels[0]: Convl_CH
          channels[1]: Layer1_CH
          channels[2]: Layer2_CH
          channels[3]: Layer3_CH
          channels[4]: Layer4_CH
"""
class ResNet_four_layers(nn.Module):
    def __init__(self, block, num_blocks, channels, num_classes=10):
        super(ResNet_four_layers, self).__init__()
        self.in_planes = channels[0]

        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.layer1 = self._make_layer(block, channels[1], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, channels[2], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, channels[3], num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, channels[4], num_blocks[3], stride=2)
        self.linear = nn.Linear(channels[-1]*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [9]:
def ResNet_Custom(block=None, num_blocks=None, channels=None, num_classes=10):
    """Build the ResNet variant used in this notebook.

    Called with no arguments it returns the configuration that was trained
    here: ``ResNet_four_layers(BasicBlock, [2, 2, 2, 2], [16, 32, 64, 128, 256])``.

    Args:
        block: ``BasicBlock`` or ``Bottleneck``; defaults to ``BasicBlock``.
        num_blocks: blocks per stage. Its length selects the architecture:
            3 -> ``ResNet_three_layers`` (``channels`` needs 4 entries),
            4 -> ``ResNet_four_layers`` (``channels`` needs 5 entries).
            Defaults to ``[2, 2, 2, 2]``.
        channels: stem width followed by one width per stage; defaults to
            ``[16, 32, 64, 128, 256]``.
        num_classes: number of output logits.

    Returns:
        An un-initialized-on-device ``nn.Module`` (the caller moves it).

    Raises:
        ValueError: if ``num_blocks`` does not have 3 or 4 entries, or
            ``channels`` does not have exactly ``len(num_blocks) + 1`` entries.

    Configurations tried so far:
        three stages (e.g. ``ResNet_Custom(BasicBlock, [7, 7, 7], [16, 16, 32, 64])``):
            num_blocks: [2, 6, 2], [2, 8, 2], [3, 3, 3], [5, 5, 5], [7, 7, 7]
            channels:   [16, 16, 32, 64], [32, 32, 64, 128], [64, 64, 128, 128]
        four stages:
            num_blocks: [2, 2, 2, 2], [2, 2, 4, 4], [4, 4, 2, 2], [3, 4, 6, 3], [3, 4, 5, 3]
            channels:   [16, 32, 64, 128, 256], [32, 32, 64, 128, 256],
                        [16, 16, 32, 64, 128], [32, 128, 128, 128, 256],
                        [32, 64, 128, 128, 256]
    """
    # Defaults are resolved here (not in the signature) so the function can be
    # defined before the block classes and no mutable defaults are shared.
    if block is None:
        block = BasicBlock
    if num_blocks is None:
        num_blocks = [2, 2, 2, 2]
    if channels is None:
        channels = [16, 32, 64, 128, 256]

    builders = {3: ResNet_three_layers, 4: ResNet_four_layers}
    builder = builders.get(len(num_blocks))
    if builder is None:
        raise ValueError('num_blocks must have 3 or 4 entries, got %r' % (num_blocks,))
    if len(channels) != len(num_blocks) + 1:
        raise ValueError('channels must have %d entries for %d stages, got %r'
                         % (len(num_blocks) + 1, len(num_blocks), channels))
    return builder(block, num_blocks, channels, num_classes=num_classes)

In [10]:
print('==> Building model..')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = ResNet_Custom().to(device)
if device == 'cuda':
    # Replicate the model across all visible GPUs. The wrapper's state_dict
    # keys carry a 'module.' prefix, which affects how checkpoints reload.
    net = torch.nn.DataParallel(net)
    # Let cuDNN benchmark conv algorithms; CIFAR-10 inputs are a fixed 32x32.
    cudnn.benchmark = True

==> Building model..


In [11]:
pip install torch-summary

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [12]:
from torchsummary import summary

In [13]:
# Per-layer output shapes and parameter counts for one 3x32x32 CIFAR-10 image.
# summary() prints the table itself and also returns an object whose repr is
# the same table, which is why the output previously appeared twice; the
# trailing ';' suppresses Jupyter's echo of that return value.
summary(net, (3, 32, 32));

Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet_four_layers: 1-1                [-1, 10]                  --
|    └─Conv2d: 2-1                       [-1, 16, 32, 32]          432
|    └─BatchNorm2d: 2-2                  [-1, 16, 32, 32]          32
|    └─Sequential: 2-3                   [-1, 32, 32, 32]          --
|    |    └─BasicBlock: 3-1              [-1, 32, 32, 32]          14,528
|    |    └─BasicBlock: 3-2              [-1, 32, 32, 32]          19,648
|    └─Sequential: 2-4                   [-1, 64, 16, 16]          --
|    |    └─BasicBlock: 3-3              [-1, 64, 16, 16]          57,728
|    |    └─BasicBlock: 3-4              [-1, 64, 16, 16]          78,208
|    └─Sequential: 2-5                   [-1, 128, 8, 8]           --
|    |    └─BasicBlock: 3-5              [-1, 128, 8, 8]           230,144
|    |    └─BasicBlock: 3-6              [-1, 128, 8, 8]           312,064
|    └─Sequential: 2-6                   [-1, 256, 4, 4]  

Layer (type:depth-idx)                   Output Shape              Param #
├─ResNet_four_layers: 1-1                [-1, 10]                  --
|    └─Conv2d: 2-1                       [-1, 16, 32, 32]          432
|    └─BatchNorm2d: 2-2                  [-1, 16, 32, 32]          32
|    └─Sequential: 2-3                   [-1, 32, 32, 32]          --
|    |    └─BasicBlock: 3-1              [-1, 32, 32, 32]          14,528
|    |    └─BasicBlock: 3-2              [-1, 32, 32, 32]          19,648
|    └─Sequential: 2-4                   [-1, 64, 16, 16]          --
|    |    └─BasicBlock: 3-3              [-1, 64, 16, 16]          57,728
|    |    └─BasicBlock: 3-4              [-1, 64, 16, 16]          78,208
|    └─Sequential: 2-5                   [-1, 128, 8, 8]           --
|    |    └─BasicBlock: 3-5              [-1, 128, 8, 8]           230,144
|    |    └─BasicBlock: 3-6              [-1, 128, 8, 8]           312,064
|    └─Sequential: 2-6                   [-1, 256, 4, 4]  

In [14]:
# Training bookkeeping:
#   best_acc    - best test accuracy (percent) seen so far; test() raises it and
#                 writes a checkpoint whenever a new best is reached.
#   start_epoch - index of the first epoch the training loop runs
#                 (0 for a fresh run, or the last checkpoint epoch when resuming).
best_acc, start_epoch = 0, 0

In [15]:
# Classification loss on raw logits.
criterion = nn.CrossEntropyLoss()
# SGD with momentum and L2 weight decay, starting from lr=0.1.
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
# Cosine-anneal the learning rate over T_max=200 scheduler steps; the training
# loop steps the scheduler once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [16]:
# Training
def train(epoch):
    """Run one training epoch over ``trainloader`` and return its statistics.

    Uses the notebook-level ``net``, ``trainloader``, ``optimizer``,
    ``criterion`` and ``device``.

    Args:
        epoch: epoch index. Not used inside the function; kept so the call
            mirrors ``test(epoch)``.

    Returns:
        ``(mean_batch_loss, accuracy_percent)`` for the epoch, or
        ``(0.0, 0.0)`` if the loader yielded no batches. Returned so callers
        can log or plot training progress.
    """
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    if total == 0:
        # Empty loader: avoid ZeroDivisionError.
        return 0.0, 0.0
    return train_loss / num_batches, 100. * correct / total

In [17]:
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint on a new best accuracy.

    Uses the notebook-level ``net``, ``testloader`` and ``device``, and
    updates the global ``best_acc``. When this evaluation's accuracy beats
    ``best_acc``, saves ``{'net', 'acc', 'epoch'}`` to
    ``./checkpoint/ckpt1.pth`` and prints the new best accuracy.

    Args:
        epoch: epoch index, stored in the checkpoint.

    Returns:
        Test accuracy in percent (``0.0`` if the loader yielded no samples,
        in which case nothing is saved).
    """
    global best_acc
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = net(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        # Nothing evaluated: avoid ZeroDivisionError and never checkpoint.
        return 0.0
    acc = 100. * correct / total

    # Save checkpoint.
    if acc > best_acc:
        state = {
            # NOTE(review): when `net` is wrapped in DataParallel these keys
            # carry a 'module.' prefix; reload into a wrapped model or strip it.
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # makedirs(exist_ok=True) avoids the check-then-create race of
        # os.path.isdir() followed by os.mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt1.pth')
        best_acc = acc
        print(best_acc)
    return acc

In [18]:
from tqdm import tqdm

In [19]:
# Main loop. Each epoch does one training pass, then one evaluation (which
# checkpoints on a new best accuracy), then advances the cosine LR schedule.
epochs = range(start_epoch, start_epoch + 200)
for epoch in tqdm(epochs):
    train(epoch)
    test(epoch)
    scheduler.step()

  0%|          | 1/200 [00:19<1:04:48, 19.54s/it]

48.76


  1%|          | 2/200 [00:38<1:02:28, 18.93s/it]

57.44


  2%|▏         | 3/200 [00:56<1:01:47, 18.82s/it]

58.68


  2%|▏         | 4/200 [01:15<1:02:00, 18.98s/it]

72.17


  2%|▎         | 5/200 [01:34<1:01:43, 18.99s/it]

75.89


  4%|▎         | 7/200 [02:12<1:00:55, 18.94s/it]

78.08


  5%|▌         | 10/200 [03:09<1:00:01, 18.95s/it]

79.54


  6%|▌         | 12/200 [03:47<59:43, 19.06s/it]

79.76


  7%|▋         | 14/200 [04:26<59:18, 19.13s/it]

80.17


  8%|▊         | 15/200 [04:45<59:05, 19.17s/it]

80.26


  8%|▊         | 16/200 [05:04<58:48, 19.18s/it]

81.66


 10%|█         | 21/200 [06:40<57:05, 19.14s/it]

82.35


 12%|█▎        | 25/200 [07:56<55:52, 19.16s/it]

83.55


 16%|█▌        | 32/200 [10:11<53:51, 19.23s/it]

84.34


 21%|██        | 42/200 [13:22<50:28, 19.17s/it]

85.42


 30%|███       | 60/200 [19:08<44:43, 19.16s/it]

85.57


 32%|███▏      | 63/200 [20:05<43:50, 19.20s/it]

85.61


 34%|███▎      | 67/200 [21:22<42:26, 19.15s/it]

86.38


 34%|███▍      | 69/200 [22:00<41:47, 19.14s/it]

86.55


 40%|████      | 81/200 [25:50<38:06, 19.21s/it]

87.16


 41%|████      | 82/200 [26:09<37:44, 19.19s/it]

87.53


 43%|████▎     | 86/200 [27:26<36:26, 19.18s/it]

87.82


 49%|████▉     | 98/200 [31:16<32:30, 19.13s/it]

88.14


 50%|█████     | 101/200 [32:13<31:34, 19.13s/it]

88.94


 56%|█████▌    | 112/200 [35:43<28:03, 19.13s/it]

89.27


 60%|██████    | 121/200 [38:35<25:09, 19.10s/it]

89.83


 62%|██████▏   | 123/200 [39:13<24:31, 19.11s/it]

89.92


 62%|██████▏   | 124/200 [39:32<24:15, 19.14s/it]

90.43


 66%|██████▋   | 133/200 [42:25<21:23, 19.15s/it]

90.96


 70%|███████   | 140/200 [44:38<19:06, 19.11s/it]

91.38


 73%|███████▎  | 146/200 [46:33<17:17, 19.21s/it]

91.42


 74%|███████▍  | 149/200 [47:31<16:20, 19.23s/it]

91.72


 76%|███████▌  | 151/200 [48:10<15:42, 19.24s/it]

92.42


 76%|███████▌  | 152/200 [48:29<15:23, 19.24s/it]

92.51


 77%|███████▋  | 154/200 [49:08<14:47, 19.29s/it]

92.66


 79%|███████▉  | 158/200 [50:25<13:27, 19.24s/it]

92.84


 82%|████████▏ | 163/200 [52:01<11:50, 19.20s/it]

93.01


 83%|████████▎ | 166/200 [52:58<10:52, 19.19s/it]

93.03


 84%|████████▎ | 167/200 [53:17<10:32, 19.17s/it]

93.26


 84%|████████▍ | 168/200 [53:37<10:16, 19.27s/it]

93.82


 84%|████████▍ | 169/200 [53:56<09:56, 19.24s/it]

93.84


 86%|████████▌ | 172/200 [54:53<08:57, 19.21s/it]

93.95


 86%|████████▋ | 173/200 [55:13<08:38, 19.19s/it]

94.02


 88%|████████▊ | 175/200 [55:51<08:00, 19.21s/it]

94.07


 88%|████████▊ | 176/200 [56:10<07:40, 19.18s/it]

94.09


 88%|████████▊ | 177/200 [56:29<07:22, 19.22s/it]

94.23


 90%|████████▉ | 179/200 [57:08<06:43, 19.20s/it]

94.38


 90%|█████████ | 180/200 [57:27<06:25, 19.25s/it]

94.4


 91%|█████████ | 182/200 [58:05<05:44, 19.16s/it]

94.56


 92%|█████████▏| 184/200 [58:44<05:07, 19.22s/it]

94.57


 92%|█████████▎| 185/200 [59:03<04:47, 19.18s/it]

94.69


 93%|█████████▎| 186/200 [59:22<04:28, 19.17s/it]

94.75


 96%|█████████▌| 191/200 [1:00:58<02:52, 19.14s/it]

94.8


 96%|█████████▌| 192/200 [1:01:17<02:33, 19.17s/it]

94.84


 96%|█████████▋| 193/200 [1:01:36<02:14, 19.15s/it]

94.9


 97%|█████████▋| 194/200 [1:01:55<01:54, 19.15s/it]

94.93


100%|██████████| 200/200 [1:03:50<00:00, 19.15s/it]
