<a href="https://colab.research.google.com/github/Abby-xu/PaperReading/blob/master/ResNet18.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Set Up
- Import libraries
- Set device
- Mount drive

In [1]:
from google.colab import drive
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
import numpy as np
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
import torch.optim as optim

In [2]:
# set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
drive.mount('/content/drive')

### ResNet 18

In [3]:
# ResNet Child Model: Residual Block
class BasicBlock(nn.Module):
    def __init__(self,inplanes: int,planes: int,stride: int = 1,downsample = None) -> None:
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes,planes, 3, stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, 3, stride=1,padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride


    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

In [4]:
class ResNet(nn.Module):

    def __init__(self,layers: list, num_classes: int = 1000,zero_init_residual: bool = False) -> None:
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.dilation = 1
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.layer1 = self._make_layer(64, layers[0], stride=1) # 'stride = 1' since output won't change (keep as 56x56 after maxpool?)
        self.layer2 = self._make_layer(128, layers[1], stride=2)
        self.layer3 = self._make_layer(256, layers[2], stride=2)
        self.layer4 = self._make_layer(512, layers[3], stride=2)


        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.do = nn.Dropout(0.2)
        self.fc = nn.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)


    def _make_layer(self, planes: int, blocks: int,stride: int = 1) -> nn.Sequential:
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes, 1, stride=stride, bias=False),
                                       nn.BatchNorm2d(planes))
        layers = []
        layers.append(BasicBlock(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        '''
        Decide the forward pass of the model

        1. input: conv1, bn1, relu, maxpool
        2. conv: layer1, layer2, layer3, layer4 (or stage)
        3. output: avgpool & fc
        '''

        # input
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Conv
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Output
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        #x = self.do(x)
        x = self.fc(x)

        return x


In [5]:
def resnet18(num_classes: int = 1000, pretrained: bool = False, path= None, progress: bool = True) -> ResNet:
    '''
    if the pretrained is not defined, the model will be trained from official resnet18-f37072fd.pth
    progress shows the download progress bar if True
    '''

    # [2, 2, 2, 2] for ReNet18 layers, [3, 4, 6, 3] for ResNet34 layers
    model=ResNet([2, 2, 2, 2], num_classes, zero_init_residual= True)
    if pretrained:
        if not path:
            pretrained_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/resnet18-f37072fd.pth',progress=progress)
            print('load pretrained model')
        else:
            pretrained_state_dict = torch.load(path)
        state_dict=model.state_dict()
        for k in pretrained_state_dict:
            if k in state_dict and k not in ['fc.weight', 'fc.bias']:
                state_dict[k] = pretrained_state_dict[k].data
        model.load_state_dict(state_dict)
    return model

#### Load Model

In [6]:
x = torch.randn(8, 3, 80, 80)
model = resnet18(2,pretrained=True)
# model = resnet18(2)
print(model)
a = model.state_dict()
o=model(x)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 181MB/s]


load pretrained model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu):

### Training

#### Pre-train - data loading

In [7]:
def read_dataset(batch_size=16,valid_size=0.2,num_workers=0,pic_path='dataset'):
    """
    batch_size: Number of loaded drawings per batch
    valid_size: Percentage of training set to use as validation
    num_workers: Number of subprocesses to use for data loading
    pic_path: The path of the pictrues
    """
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
    ])

    # transfer data to torch.FloatTensor and standardized
    train_data = datasets.CIFAR10(pic_path, train=True, download=True, transform=transform_train)
    valid_data = datasets.CIFAR10(pic_path, train=True, download=True, transform=transform_test)
    test_data = datasets.CIFAR10(pic_path, train=False, download=True, transform=transform_test)

    # obtain training indices that will be used for validation
    num_train = len(train_data)
    indices = list(range(num_train))
    # random indices
    np.random.shuffle(indices)
    # the ratio of split
    split = int(np.floor(valid_size * num_train))
    # divide data to radin_data and valid_data
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # prepare data loaders (combine dataset and sampler)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
        sampler=train_sampler, num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size,
        sampler=valid_sampler, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
        num_workers=num_workers)

    return train_loader,valid_loader,test_loader

In [11]:
batch_size = 128
train_loader,valid_loader,test_loader = read_dataset(batch_size=batch_size,pic_path='dataset')

# load model - non-pretrained
n_class = 10
model = resnet18()
model = model.to(device)
# using cross entropy loss func
criterion = nn.CrossEntropyLoss().to(device)

# start training
n_epochs = 50
valid_loss_min = np.Inf # track change in validation loss
accuracy = []
lr = 0.1
counter = 0
for epoch in tqdm(range(1, n_epochs+1)):

    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    total_sample = 0
    right_sample = 0

    if counter/10 ==1:
        counter = 0
        lr = lr*0.5
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    ##### model of training set #####
    model.train() # using batch normalization and drop out
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        # clear the gradients of all optimized variables
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data).to(device)  # == 'output = model.forward(data).to(device)'
        # calculate the batch loss
        loss = criterion(output, target)
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
        # update training loss
        train_loss += loss.item()*data.size(0)

    ##### model of validation set #####
    model.eval()
    for data, target in valid_loader:
        data = data.to(device)
        target = target.to(device)
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data).to(device)
        # calculate the batch loss
        loss = criterion(output, target)
        # update average validation loss
        valid_loss += loss.item()*data.size(0)
        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct_tensor = pred.eq(target.data.view_as(pred))
        # correct = np.squeeze(correct_tensor.to(device).numpy())
        total_sample += batch_size
        for i in correct_tensor:
            if i:
                right_sample += 1
    print("\nAccuracy:",100*right_sample/total_sample,"%")
    accuracy.append(right_sample/total_sample)

    # cal avg loss
    train_loss = train_loss/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)

    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss, valid_loss))

    # save model
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).'.format(valid_loss_min,valid_loss))
        # print('Saving model ...')
        # torch.save(model.state_dict(), 'checkpoint/resnet18_cifar10.pt')
        valid_loss_min = valid_loss
        counter = 0
    else:
        counter += 1
    print('------------------------------')

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


  2%|▏         | 1/50 [00:24<19:49, 24.28s/it]


Accuracy: 52.35363924050633 %
Epoch: 1 	Training Loss: 1.595356 	Validation Loss: 1.319011
Validation loss decreased (inf --> 1.319011).
------------------------------


  4%|▍         | 2/50 [00:48<19:26, 24.30s/it]


Accuracy: 60.68037974683544 %
Epoch: 2 	Training Loss: 1.079899 	Validation Loss: 1.111023
Validation loss decreased (1.319011 --> 1.111023).
------------------------------


  6%|▌         | 3/50 [01:12<19:02, 24.30s/it]


Accuracy: 65.92167721518987 %
Epoch: 3 	Training Loss: 0.864617 	Validation Loss: 0.958296
Validation loss decreased (1.111023 --> 0.958296).
------------------------------


  8%|▊         | 4/50 [01:37<18:41, 24.37s/it]


Accuracy: 68.40387658227849 %
Epoch: 4 	Training Loss: 0.747217 	Validation Loss: 0.885348
Validation loss decreased (0.958296 --> 0.885348).
------------------------------


 10%|█         | 5/50 [02:01<18:15, 24.35s/it]


Accuracy: 69.48180379746836 %
Epoch: 5 	Training Loss: 0.654395 	Validation Loss: 0.866201
Validation loss decreased (0.885348 --> 0.866201).
------------------------------


 12%|█▏        | 6/50 [02:25<17:49, 24.32s/it]


Accuracy: 70.44106012658227 %
Epoch: 6 	Training Loss: 0.582278 	Validation Loss: 0.865008
Validation loss decreased (0.866201 --> 0.865008).
------------------------------


 14%|█▍        | 7/50 [02:50<17:26, 24.33s/it]


Accuracy: 69.10601265822785 %
Epoch: 7 	Training Loss: 0.528030 	Validation Loss: 0.911217
------------------------------


 16%|█▌        | 8/50 [03:14<17:01, 24.33s/it]


Accuracy: 71.67721518987342 %
Epoch: 8 	Training Loss: 0.478582 	Validation Loss: 0.836266
Validation loss decreased (0.865008 --> 0.836266).
------------------------------


 18%|█▊        | 9/50 [03:39<16:47, 24.56s/it]


Accuracy: 70.18393987341773 %
Epoch: 9 	Training Loss: 0.435760 	Validation Loss: 0.929269
------------------------------


 20%|██        | 10/50 [04:04<16:22, 24.56s/it]


Accuracy: 72.37935126582279 %
Epoch: 10 	Training Loss: 0.392269 	Validation Loss: 0.831641
Validation loss decreased (0.836266 --> 0.831641).
------------------------------


 22%|██▏       | 11/50 [04:29<16:02, 24.69s/it]


Accuracy: 69.61036392405063 %
Epoch: 11 	Training Loss: 0.366163 	Validation Loss: 0.988677
------------------------------


 24%|██▍       | 12/50 [04:53<15:34, 24.60s/it]


Accuracy: 69.61036392405063 %
Epoch: 12 	Training Loss: 0.342061 	Validation Loss: 1.011468
------------------------------


 26%|██▌       | 13/50 [05:18<15:07, 24.54s/it]


Accuracy: 71.88488924050633 %
Epoch: 13 	Training Loss: 0.320673 	Validation Loss: 0.897379
------------------------------


 28%|██▊       | 14/50 [05:42<14:40, 24.45s/it]


Accuracy: 70.74762658227849 %
Epoch: 14 	Training Loss: 0.297359 	Validation Loss: 0.937570
------------------------------


 30%|███       | 15/50 [06:06<14:17, 24.50s/it]


Accuracy: 71.74643987341773 %
Epoch: 15 	Training Loss: 0.293721 	Validation Loss: 0.951103
------------------------------


 32%|███▏      | 16/50 [06:31<13:50, 24.42s/it]


Accuracy: 72.8935917721519 %
Epoch: 16 	Training Loss: 0.272644 	Validation Loss: 0.900356
------------------------------


 34%|███▍      | 17/50 [06:55<13:23, 24.35s/it]


Accuracy: 72.01344936708861 %
Epoch: 17 	Training Loss: 0.275835 	Validation Loss: 0.955970
------------------------------


 36%|███▌      | 18/50 [07:19<13:00, 24.39s/it]


Accuracy: 71.41020569620254 %
Epoch: 18 	Training Loss: 0.265940 	Validation Loss: 0.931372
------------------------------


 38%|███▊      | 19/50 [07:44<12:36, 24.39s/it]


Accuracy: 71.4003164556962 %
Epoch: 19 	Training Loss: 0.253923 	Validation Loss: 0.949594
------------------------------


 40%|████      | 20/50 [08:08<12:10, 24.35s/it]


Accuracy: 71.20253164556962 %
Epoch: 20 	Training Loss: 0.253392 	Validation Loss: 1.020884
------------------------------


 42%|████▏     | 21/50 [08:32<11:45, 24.32s/it]


Accuracy: 76.78995253164557 %
Epoch: 21 	Training Loss: 0.069627 	Validation Loss: 0.905729
------------------------------


 44%|████▍     | 22/50 [08:57<11:23, 24.40s/it]


Accuracy: 77.6503164556962 %
Epoch: 22 	Training Loss: 0.023470 	Validation Loss: 0.910688
------------------------------


 46%|████▌     | 23/50 [09:21<10:57, 24.34s/it]


Accuracy: 75.77136075949367 %
Epoch: 23 	Training Loss: 0.018261 	Validation Loss: 1.076827
------------------------------


 48%|████▊     | 24/50 [09:45<10:34, 24.39s/it]


Accuracy: 72.30023734177215 %
Epoch: 24 	Training Loss: 0.074131 	Validation Loss: 1.115674
------------------------------


 50%|█████     | 25/50 [10:10<10:07, 24.31s/it]


Accuracy: 73.05181962025317 %
Epoch: 25 	Training Loss: 0.112248 	Validation Loss: 1.038243
------------------------------


 52%|█████▏    | 26/50 [10:34<09:43, 24.32s/it]


Accuracy: 74.53520569620254 %
Epoch: 26 	Training Loss: 0.123569 	Validation Loss: 0.940431
------------------------------


 54%|█████▍    | 27/50 [10:58<09:19, 24.33s/it]


Accuracy: 74.1198575949367 %
Epoch: 27 	Training Loss: 0.113089 	Validation Loss: 0.975857
------------------------------


 56%|█████▌    | 28/50 [11:23<08:57, 24.41s/it]


Accuracy: 74.42642405063292 %
Epoch: 28 	Training Loss: 0.121116 	Validation Loss: 0.970545
------------------------------


 58%|█████▊    | 29/50 [11:47<08:32, 24.39s/it]


Accuracy: 74.66376582278481 %
Epoch: 29 	Training Loss: 0.124784 	Validation Loss: 0.993191
------------------------------


 60%|██████    | 30/50 [12:12<08:07, 24.36s/it]


Accuracy: 73.15071202531645 %
Epoch: 30 	Training Loss: 0.106638 	Validation Loss: 1.060766
------------------------------


 62%|██████▏   | 31/50 [12:36<07:42, 24.34s/it]


Accuracy: 78.74802215189874 %
Epoch: 31 	Training Loss: 0.024569 	Validation Loss: 0.851826
------------------------------


 64%|██████▍   | 32/50 [13:01<07:20, 24.46s/it]


Accuracy: 80.00395569620254 %
Epoch: 32 	Training Loss: 0.003937 	Validation Loss: 0.811809
Validation loss decreased (0.831641 --> 0.811809).
------------------------------


 66%|██████▌   | 33/50 [13:26<06:59, 24.66s/it]


Accuracy: 80.14240506329114 %
Epoch: 33 	Training Loss: 0.001773 	Validation Loss: 0.794922
Validation loss decreased (0.811809 --> 0.794922).
------------------------------


 68%|██████▊   | 34/50 [13:51<06:37, 24.86s/it]


Accuracy: 80.20174050632912 %
Epoch: 34 	Training Loss: 0.001421 	Validation Loss: 0.773721
Validation loss decreased (0.794922 --> 0.773721).
------------------------------


 70%|███████   | 35/50 [14:16<06:12, 24.81s/it]


Accuracy: 80.32041139240506 %
Epoch: 35 	Training Loss: 0.001402 	Validation Loss: 0.767851
Validation loss decreased (0.773721 --> 0.767851).
------------------------------


 72%|███████▏  | 36/50 [14:40<05:45, 24.71s/it]


Accuracy: 80.24129746835443 %
Epoch: 36 	Training Loss: 0.001322 	Validation Loss: 0.760757
Validation loss decreased (0.767851 --> 0.760757).
------------------------------


 74%|███████▍  | 37/50 [15:05<05:20, 24.63s/it]


Accuracy: 80.32041139240506 %
Epoch: 37 	Training Loss: 0.001426 	Validation Loss: 0.755303
Validation loss decreased (0.760757 --> 0.755303).
------------------------------


 76%|███████▌  | 38/50 [15:29<04:54, 24.57s/it]


Accuracy: 80.32041139240506 %
Epoch: 38 	Training Loss: 0.001445 	Validation Loss: 0.752183
Validation loss decreased (0.755303 --> 0.752183).
------------------------------


 78%|███████▊  | 39/50 [15:54<04:29, 24.53s/it]


Accuracy: 80.32041139240506 %
Epoch: 39 	Training Loss: 0.001386 	Validation Loss: 0.748416
Validation loss decreased (0.752183 --> 0.748416).
------------------------------


 80%|████████  | 40/50 [16:18<04:05, 24.56s/it]


Accuracy: 80.46875 %
Epoch: 40 	Training Loss: 0.001378 	Validation Loss: 0.747692
Validation loss decreased (0.748416 --> 0.747692).
------------------------------


 82%|████████▏ | 41/50 [16:43<03:40, 24.54s/it]


Accuracy: 80.37974683544304 %
Epoch: 41 	Training Loss: 0.001355 	Validation Loss: 0.752182
------------------------------


 84%|████████▍ | 42/50 [17:07<03:16, 24.57s/it]


Accuracy: 80.35996835443038 %
Epoch: 42 	Training Loss: 0.001337 	Validation Loss: 0.750868
------------------------------


 86%|████████▌ | 43/50 [17:32<02:52, 24.69s/it]


Accuracy: 80.28085443037975 %
Epoch: 43 	Training Loss: 0.001323 	Validation Loss: 0.754184
------------------------------


 88%|████████▊ | 44/50 [17:57<02:28, 24.76s/it]


Accuracy: 80.46875 %
Epoch: 44 	Training Loss: 0.001293 	Validation Loss: 0.748960
------------------------------


 90%|█████████ | 45/50 [18:23<02:04, 24.94s/it]


Accuracy: 80.55775316455696 %
Epoch: 45 	Training Loss: 0.001265 	Validation Loss: 0.750366
------------------------------


 92%|█████████▏| 46/50 [18:47<01:39, 24.88s/it]


Accuracy: 80.47863924050633 %
Epoch: 46 	Training Loss: 0.001259 	Validation Loss: 0.757070
------------------------------


 94%|█████████▍| 47/50 [19:12<01:14, 24.83s/it]


Accuracy: 80.55775316455696 %
Epoch: 47 	Training Loss: 0.001217 	Validation Loss: 0.756305
------------------------------


 96%|█████████▌| 48/50 [19:37<00:49, 24.79s/it]


Accuracy: 80.61708860759494 %
Epoch: 48 	Training Loss: 0.001215 	Validation Loss: 0.760608
------------------------------


 98%|█████████▊| 49/50 [20:01<00:24, 24.79s/it]


Accuracy: 80.43908227848101 %
Epoch: 49 	Training Loss: 0.001200 	Validation Loss: 0.761368
------------------------------


100%|██████████| 50/50 [20:26<00:00, 24.53s/it]


Accuracy: 80.48852848101266 %
Epoch: 50 	Training Loss: 0.001170 	Validation Loss: 0.767953
------------------------------



