# Transfer Learning of ResNet50

The purpose of this notebook is to reproduce the results of the following paper:

* https://link.springer.com/article/10.1007%2Fs00521-020-05437-x 



Data source: 

* https://www.kaggle.com/mloey1/covid19-chest-ct-image-augmentation-gan-dataset 

In [38]:
import sys

import matplotlib.pyplot as plt
import pandas as pd
import torch

sys.path.append("../")
from covidct.dataset import *

path = '../data/'

# Seed torch's global RNG so the random initialisation of the new classifier
# head and the DataLoader shuffling are repeatable on Restart & Run All
# (cuDNN kernels may still introduce some GPU nondeterminism).
SEED = 42
torch.manual_seed(SEED)

# with_aug / with_cgan disabled: presumably the plain images without classical
# or CGAN augmentation -- confirm against covidct.dataset.CovidCTDataset.
train_data = CovidCTDataset(path, with_aug=False, with_cgan=False, split = 'train')
val_data = CovidCTDataset(path, with_aug=False, with_cgan=False, split = 'val')
test_data = CovidCTDataset(path, with_aug=False, with_cgan=False, split = 'test')

# Number of samples in the train / val / test splits, in that order.
for split_data in (train_data, val_data, test_data):
    print(len(split_data))

425
118
199


# Loading ResNet50 pretrained model

Reference: PyTorch tutorial on fine-tuning torchvision models — https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

In [50]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet50 from the pinned torchvision v0.10.0 hub entry point.
resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)

# Replace the pretrained classification head with a fresh 2-output linear
# layer whose input width is taken from the existing head instead of being
# hard-coded.
num_ftrs = resnet.fc.in_features
resnet.fc = torch.nn.Linear(num_ftrs, 2)

resnet

Using cache found in C:\Users\Alex/.cache\torch\hub\pytorch_vision_v0.10.0


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

# Transfer learning with COVID-CT data to ResNet50 model

## Define the training loop

In [52]:
from tqdm import tqdm
import copy

def _run_phase(model, loader, criterion, optimizer, device, is_train):
    '''Run one pass over `loader`, updating weights only when `is_train` is True.

    Returns:
        (epoch_loss, epoch_acc): the loss averaged over ``len(loader.dataset)``
        samples (float) and the accuracy as a 0-d double tensor.
    '''
    # Toggle train/eval behaviour (dropout, batch-norm statistics) on the model
    # that was passed in; model.train(False) is equivalent to model.eval().
    model.train(is_train)

    running_loss = 0.0
    running_corrects = 0

    for x, y in loader:
        inputs = x.to(device)
        labels = y.to(device)

        if is_train:
            optimizer.zero_grad()

        # Build the autograd graph only when we are going to backpropagate.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Predicted class = index of the highest score along dim 1.
            _, predictions = torch.max(outputs, 1)

            if is_train:
                loss.backward()
                optimizer.step()

        # criterion returns a batch mean; re-weight by batch size so the epoch
        # average is per sample even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(predictions == labels)

    n_samples = len(loader.dataset)
    return running_loss / n_samples, running_corrects.double() / n_samples


def train_resnet(model, data, criterion, optimizer, n_epochs, device=None):
    '''Train `model` and restore the weights with the best validation accuracy.

    Each epoch makes one pass over ``data['train']`` with gradient updates,
    followed by one evaluation-only pass over ``data['val']``. The average loss
    and accuracy of each phase are printed.

    Args:
        model: classifier whose output has shape (batch, n_classes) and is
            compatible with ``criterion``.
        data: dict with 'train' and 'val' DataLoaders yielding (inputs, labels).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over (a subset of) ``model``'s parameters.
        n_epochs: number of train/val epochs to run.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        (model, val_acc_history): the same model object with the
        best-validation-accuracy weights loaded, and the per-epoch validation
        accuracies. If no epoch exceeds 0.0 accuracy (or ``n_epochs == 0``),
        the model keeps its initial weights.
    '''
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    val_acc_history = []
    best_acc = 0.0
    # Snapshot the starting weights so there is always something to restore,
    # even when no validation epoch improves on best_acc.
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in tqdm(range(n_epochs)):
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_phase(
                model, data[phase], criterion, optimizer, device, is_train)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if not is_train:
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    # Deep copy: state_dict() returns live references that the
                    # following epochs would otherwise overwrite.
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(epoch_acc)
        print()

    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


def set_parameter_requires_grad(model, feature_extracting):
    '''Freeze every parameter of `model` when `feature_extracting` is True.

    Sets ``requires_grad = False`` on all of ``model.parameters()`` so no
    gradients are computed for them. Does nothing when `feature_extracting` is
    False. Parameters of layers attached to the model afterwards keep their own
    ``requires_grad`` setting (True by default for new layers).
    '''
    if feature_extracting:
        for param in model.parameters():
            # The attribute must be spelled exactly `requires_grad`; any other
            # name just creates an unused attribute and freezes nothing.
            param.requires_grad = False

## Create the optimizer

In [54]:
# model to device
resnet = resnet.to(device)

feature_extract = True

# Gather the parameters to be optimized/updated in this run. If we are
#  finetuning we will be updating all parameters. However, if we are
#  doing feature extract method, we will only update the parameters
#  that we have just initialized, i.e. the parameters with requires_grad
#  is True.

params_to_update = resnet.parameters()
print('Parameters to learn:')

if feature_extract:
    params_to_update = []
    for name, param in resnet.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print('\t', name)
else:
    for name, param in resnet.named_parameters():
        if param.requires_grad == True:
            print('\t', name)

optimizer_ft = torch.optim.Adam(params_to_update, lr = 0.001)

Parameters to learn:
	 conv1.weight
	 bn1.weight
	 bn1.bias
	 layer1.0.conv1.weight
	 layer1.0.bn1.weight
	 layer1.0.bn1.bias
	 layer1.0.conv2.weight
	 layer1.0.bn2.weight
	 layer1.0.bn2.bias
	 layer1.0.conv3.weight
	 layer1.0.bn3.weight
	 layer1.0.bn3.bias
	 layer1.0.downsample.0.weight
	 layer1.0.downsample.1.weight
	 layer1.0.downsample.1.bias
	 layer1.1.conv1.weight
	 layer1.1.bn1.weight
	 layer1.1.bn1.bias
	 layer1.1.conv2.weight
	 layer1.1.bn2.weight
	 layer1.1.bn2.bias
	 layer1.1.conv3.weight
	 layer1.1.bn3.weight
	 layer1.1.bn3.bias
	 layer1.2.conv1.weight
	 layer1.2.bn1.weight
	 layer1.2.bn1.bias
	 layer1.2.conv2.weight
	 layer1.2.bn2.weight
	 layer1.2.bn2.bias
	 layer1.2.conv3.weight
	 layer1.2.bn3.weight
	 layer1.2.bn3.bias
	 layer2.0.conv1.weight
	 layer2.0.bn1.weight
	 layer2.0.bn1.bias
	 layer2.0.conv2.weight
	 layer2.0.bn2.weight
	 layer2.0.bn2.bias
	 layer2.0.conv3.weight
	 layer2.0.bn3.weight
	 layer2.0.bn3.bias
	 layer2.0.downsample.0.weight
	 layer2.0.downsample.1.we

In [55]:
# hyperparameters
n_epochs = 50
batch_size = 32

# Only the training split is shuffled. Val/test are read in a fixed order, so
# per-sample predictions collected from them are reproducible between runs.
data = {
    'train': torch.utils.data.DataLoader(train_data, shuffle = True, batch_size = batch_size),
    'val': torch.utils.data.DataLoader(val_data, shuffle = False, batch_size = batch_size),
    'test': torch.utils.data.DataLoader(test_data, shuffle = False, batch_size = batch_size),
}

In [56]:
# Cross-entropy over the logits produced by the replaced `fc` head.
criterion = torch.nn.CrossEntropyLoss()

# Train, then keep the model with the best validation accuracy and the
# per-epoch validation accuracy history.
resnet, hist = train_resnet(
    resnet,
    data,
    criterion,
    optimizer_ft,
    n_epochs,
)

  0%|          | 0/50 [00:00<?, ?it/s]

train Loss: 0.6852 Acc: 0.7012


  2%|▏         | 1/50 [03:15<2:39:18, 195.07s/it]

val Loss: 27.2574 Acc: 0.6186

train Loss: 0.5757 Acc: 0.7341


  4%|▍         | 2/50 [06:32<2:37:15, 196.57s/it]

val Loss: 12.6519 Acc: 0.4661

train Loss: 0.4505 Acc: 0.7929


  6%|▌         | 3/50 [09:51<2:34:38, 197.41s/it]

val Loss: 2.3632 Acc: 0.6017

train Loss: 0.4130 Acc: 0.8000


  8%|▊         | 4/50 [13:09<2:31:41, 197.86s/it]

val Loss: 1.6840 Acc: 0.4661

train Loss: 0.3331 Acc: 0.8682


 10%|█         | 5/50 [16:21<2:26:48, 195.75s/it]

val Loss: 2.0972 Acc: 0.6441

train Loss: 0.3394 Acc: 0.8541


 12%|█▏        | 6/50 [19:29<2:21:33, 193.03s/it]

val Loss: 2.5832 Acc: 0.6102

train Loss: 0.3178 Acc: 0.8612


 14%|█▍        | 7/50 [22:58<2:22:11, 198.41s/it]

val Loss: 1.1041 Acc: 0.5932

train Loss: 0.2431 Acc: 0.9176


 16%|█▌        | 8/50 [26:21<2:19:51, 199.80s/it]

val Loss: 1.2090 Acc: 0.6186

train Loss: 0.2130 Acc: 0.9388


 18%|█▊        | 9/50 [29:42<2:16:50, 200.26s/it]

val Loss: 1.3465 Acc: 0.6780

train Loss: 0.1046 Acc: 0.9765


 20%|██        | 10/50 [33:53<2:23:45, 215.64s/it]

val Loss: 2.1087 Acc: 0.5254

train Loss: 0.1860 Acc: 0.9224


 22%|██▏       | 11/50 [38:25<2:31:30, 233.08s/it]

val Loss: 1.1300 Acc: 0.6695

train Loss: 0.2203 Acc: 0.9153


 24%|██▍       | 12/50 [42:49<2:33:37, 242.56s/it]

val Loss: 1.4058 Acc: 0.5424

train Loss: 0.1076 Acc: 0.9553


 26%|██▌       | 13/50 [47:06<2:32:09, 246.73s/it]

val Loss: 0.8408 Acc: 0.7119

train Loss: 0.0461 Acc: 0.9882


 28%|██▊       | 14/50 [50:49<2:23:46, 239.62s/it]

val Loss: 3.2942 Acc: 0.5593

train Loss: 0.0308 Acc: 0.9929


 30%|███       | 15/50 [55:24<2:26:01, 250.33s/it]

val Loss: 1.4154 Acc: 0.6949

train Loss: 0.0964 Acc: 0.9671


 32%|███▏      | 16/50 [1:00:39<2:32:55, 269.86s/it]

val Loss: 1.8803 Acc: 0.7034

train Loss: 0.1051 Acc: 0.9647


 34%|███▍      | 17/50 [1:05:33<2:32:24, 277.12s/it]

val Loss: 2.4093 Acc: 0.6271

train Loss: 0.0720 Acc: 0.9718


 36%|███▌      | 18/50 [1:10:19<2:29:15, 279.85s/it]

val Loss: 1.1889 Acc: 0.7119

train Loss: 0.0483 Acc: 0.9882


 38%|███▊      | 19/50 [1:15:26<2:28:44, 287.87s/it]

val Loss: 2.9179 Acc: 0.6356

train Loss: 0.1828 Acc: 0.9412


 40%|████      | 20/50 [1:20:14<2:23:55, 287.85s/it]

val Loss: 5.0796 Acc: 0.4661

train Loss: 0.1999 Acc: 0.9365


 42%|████▏     | 21/50 [1:24:55<2:18:07, 285.77s/it]

val Loss: 1.4606 Acc: 0.7119

train Loss: 0.1107 Acc: 0.9600


 44%|████▍     | 22/50 [1:29:25<2:11:13, 281.18s/it]

val Loss: 4.0123 Acc: 0.5085

train Loss: 0.0511 Acc: 0.9835


 46%|████▌     | 23/50 [1:33:45<2:03:40, 274.84s/it]

val Loss: 0.5406 Acc: 0.7627

train Loss: 0.0305 Acc: 0.9906


 48%|████▊     | 24/50 [1:38:24<1:59:39, 276.12s/it]

val Loss: 1.1049 Acc: 0.7034

train Loss: 0.0178 Acc: 0.9953


 50%|█████     | 25/50 [1:42:58<1:54:46, 275.47s/it]

val Loss: 1.2885 Acc: 0.7542

train Loss: 0.0174 Acc: 0.9953


 52%|█████▏    | 26/50 [1:47:45<1:51:32, 278.86s/it]

val Loss: 0.9243 Acc: 0.7797

train Loss: 0.0142 Acc: 0.9953


 54%|█████▍    | 27/50 [1:52:18<1:46:10, 276.97s/it]

val Loss: 1.4615 Acc: 0.7542

train Loss: 0.0080 Acc: 0.9976


 56%|█████▌    | 28/50 [1:57:16<1:43:56, 283.47s/it]

val Loss: 1.3260 Acc: 0.7288

train Loss: 0.0252 Acc: 0.9929


 58%|█████▊    | 29/50 [2:02:32<1:42:35, 293.14s/it]

val Loss: 1.0864 Acc: 0.7458

train Loss: 0.0145 Acc: 0.9953


 60%|██████    | 30/50 [2:07:19<1:37:05, 291.29s/it]

val Loss: 2.4515 Acc: 0.6949

train Loss: 0.0243 Acc: 0.9859


 62%|██████▏   | 31/50 [2:12:03<1:31:30, 288.96s/it]

val Loss: 3.5543 Acc: 0.6017

train Loss: 0.1440 Acc: 0.9529


 64%|██████▍   | 32/50 [2:16:34<1:25:06, 283.72s/it]

val Loss: 2.5268 Acc: 0.6610

train Loss: 0.2529 Acc: 0.9200


 66%|██████▌   | 33/50 [2:21:11<1:19:47, 281.61s/it]

val Loss: 1.6882 Acc: 0.6864

train Loss: 0.1176 Acc: 0.9718


 68%|██████▊   | 34/50 [2:25:33<1:13:32, 275.80s/it]

val Loss: 1.0465 Acc: 0.6780

train Loss: 0.1958 Acc: 0.9412


 70%|███████   | 35/50 [2:30:08<1:08:52, 275.50s/it]

val Loss: 1.4789 Acc: 0.6271

train Loss: 0.1412 Acc: 0.9365


 72%|███████▏  | 36/50 [2:34:13<1:02:11, 266.54s/it]

val Loss: 0.8499 Acc: 0.6441

train Loss: 0.0810 Acc: 0.9694


 74%|███████▍  | 37/50 [2:38:40<57:45, 266.60s/it]  

val Loss: 2.2004 Acc: 0.6271

train Loss: 0.0248 Acc: 0.9953


 76%|███████▌  | 38/50 [2:42:44<51:58, 259.91s/it]

val Loss: 1.5279 Acc: 0.7203

train Loss: 0.0106 Acc: 0.9976


 78%|███████▊  | 39/50 [2:46:57<47:15, 257.78s/it]

val Loss: 1.2390 Acc: 0.6949

train Loss: 0.0107 Acc: 0.9976


 80%|████████  | 40/50 [2:50:54<41:54, 251.42s/it]

val Loss: 1.2000 Acc: 0.7627

train Loss: 0.0081 Acc: 0.9976


 82%|████████▏ | 41/50 [2:54:14<35:23, 235.91s/it]

val Loss: 1.3925 Acc: 0.7458

train Loss: 0.0056 Acc: 0.9976


 84%|████████▍ | 42/50 [2:57:20<29:28, 221.05s/it]

val Loss: 1.2750 Acc: 0.7373

train Loss: 0.0016 Acc: 1.0000


 86%|████████▌ | 43/50 [3:00:27<24:35, 210.73s/it]

val Loss: 1.4624 Acc: 0.6864

train Loss: 0.0024 Acc: 1.0000


 88%|████████▊ | 44/50 [3:03:32<20:19, 203.25s/it]

val Loss: 1.2732 Acc: 0.7373

train Loss: 0.0127 Acc: 0.9953


 90%|█████████ | 45/50 [3:06:35<16:25, 197.05s/it]

val Loss: 4.6056 Acc: 0.6102

train Loss: 0.1665 Acc: 0.9388


 92%|█████████▏| 46/50 [3:09:38<12:51, 192.97s/it]

val Loss: 12.5771 Acc: 0.5085

train Loss: 0.2976 Acc: 0.9200


 94%|█████████▍| 47/50 [3:12:45<09:32, 190.93s/it]

val Loss: 1.7582 Acc: 0.5508

train Loss: 0.1528 Acc: 0.9482


 96%|█████████▌| 48/50 [3:15:48<06:17, 188.64s/it]

val Loss: 1.1489 Acc: 0.6864

train Loss: 0.0677 Acc: 0.9788


 98%|█████████▊| 49/50 [3:18:54<03:07, 187.91s/it]

val Loss: 0.7642 Acc: 0.6864

train Loss: 0.0250 Acc: 0.9953


100%|██████████| 50/50 [3:22:11<00:00, 242.62s/it]

val Loss: 0.9440 Acc: 0.7458

Best val Acc: 0.779661





## Testing model on test dataset

In [91]:
with torch.no_grad():
    resnet.eval()
    y_pred = []
    y_true = []

    for img, lab in data['test']:
        # Inputs must be on the same device as the model, which was moved to
        # `device` before training. Labels stay on the CPU for sklearn.
        batch_pred = resnet(img.to(device))
        # Predicted class = index of the largest score, the same rule used
        # during training (torch.max over dim 1).
        y_pred.extend(batch_pred.argmax(dim=1).tolist())
        y_true.extend(lab.tolist())

print(y_pred)
print(y_true)

[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]
[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1,

In [109]:
import sklearn.metrics as metrics

# sklearn's confusion matrix has true classes as rows and predictions as columns.
print('true class')
print(metrics.confusion_matrix(y_true, y_pred))
print('--------')

# Binary scores with sklearn's default positive label (class 1).
score_functions = [
    ('accuracy', metrics.accuracy_score),
    ('precision', metrics.precision_score),
    ('recall', metrics.recall_score),
    ('f1 score', metrics.f1_score),
]
for score_name, score_fn in score_functions:
    print(f'{score_name}: {score_fn(y_true, y_pred)}')

true class
[[83 22]
 [40 54]]
--------
accuracy: 0.6884422110552764
precision: 0.7105263157894737
recall: 0.574468085106383
f1 score: 0.6352941176470589


## Saving model to repo

In [63]:
from pathlib import Path

# torch.save does not create missing directories, so ensure the target folder
# exists before writing the weights.
model_dir = Path('../models/ResNet50_raw')
model_dir.mkdir(parents=True, exist_ok=True)
torch.save(resnet.state_dict(), model_dir / 'model.wts')