# Transfer learning: fine-tuning on the `train_dataset`

This notebook demonstrates how a pre-trained CNN can be fine-tuned using randomly selected train splits. Here, we demonstrate fine-tuning a pre-trained CNN on a medical image classification task. In this example, fine-tuning serves to adapt the CNN to new measurement parameters. This code illustrates the procedure described in section `3. Materials and methods`.

Copyright (C) 2023, Zhao Bingqiang, All Rights Reserved

Email: zbqherb@163.com

2023-07-02

## Import Libs

In [None]:
import time
import os
import wandb
import random

import pandas as pd
import numpy as np
from tqdm import tqdm
from datetime import datetime
from matplotlib import colors as mcolors
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduler

from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score

import warnings
warnings.filterwarnings("ignore")

# Use the first CUDA device when PyTorch reports one, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

# Later cells write into these relative directories (np.save / to_csv -> 'table',
# plt.savefig -> 'figure', torch.save -> 'checkpoint'). None of those calls
# creates a missing parent directory, so make sure they exist up front.
for output_dir in ('table', 'figure', 'checkpoint'):
    os.makedirs(output_dir, exist_ok = True)

## Image data preprocessing

In [None]:
# COVID-19 CT
# Training pipeline: RandomResizedCrop to 512 and a random horizontal flip,
# followed by tensor conversion and normalisation with mean 0.5 / std 0.5.
_train_steps = [
    transforms.RandomResizedCrop(512),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: Resize(1000) then a 512 centre crop,
# with the same tensor conversion and normalisation as above.
_test_steps = [
    transforms.Resize(1000),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
test_transform = transforms.Compose(_test_steps)

## Load image dataset `data_split`

In [None]:
dataset_dir = 'data_split'
# The 'train' sub-folder feeds training; the 'val' sub-folder is what this notebook calls the test set.
train_path, test_path = (os.path.join(dataset_dir, split) for split in ('train', 'val'))
print('train_dataset Path', train_path)
print('test_dataset Path', test_path)

# Each split is read with ImageFolder and paired with its own transform pipeline.
train_dataset = datasets.ImageFolder(root = train_path, transform = train_transform)
test_dataset = datasets.ImageFolder(root = test_path, transform = test_transform)

# Summary of the training split
print('train_dataset number', len(train_dataset))
print('train_dataset class number', len(train_dataset.classes))
print('train_dataset class name', train_dataset.classes)

# Summary of the test split
print('test_dataset number', len(test_dataset))
print('test_dataset class number', len(test_dataset.classes))
print('test_dataset class name', test_dataset.classes)

## Class and Index mapping

In [None]:
# Mapping: class name -> index, taken from the training ImageFolder
class_to_idx = train_dataset.class_to_idx
# Mapping: index -> class name (the inverse of class_to_idx)
idx_to_labels = {index: name for name, index in class_to_idx.items()}

# Persist both mappings as .npy files
np.save('table/idx_to_labels.npy', idx_to_labels)
np.save('table/labels_to_idx.npy', class_to_idx)

print(class_to_idx)
print(idx_to_labels)

## Define DataLoader

In [None]:
BATCH_SIZE = 32

# Settings shared by both loaders
_loader_kwargs = dict(batch_size = BATCH_SIZE, num_workers = 4)

# train_dataset DataLoader: reshuffled on every pass
train_loader = DataLoader(train_dataset, shuffle = True, **_loader_kwargs)

# test_dataset DataLoader: fixed order
test_loader = DataLoader(test_dataset, shuffle = False, **_loader_kwargs)

## Visualize the image and annotation of a batch

In [None]:
# Draw a single batch from the training loader
images, labels = next(iter(train_loader))
print(images.shape)
print(labels)

# Tensor to np.array (the next cell indexes and transposes this array)
images = images.numpy()

# Histogram of every value of the sample at index 10 in the batch
fig, ax = plt.subplots(figsize = (15,10))
ax.hist(images[10].flatten(), bins = 100)
ax.tick_params(labelsize = 25)
fig.savefig('figure/hist.tif', dpi = 300, bbox_inches = 'tight')

In [None]:
# Pick one sample of the batch and show it before / after undoing the normalisation
idx = 1
label = labels[idx].item()
# Channel-first (C, H, W) -> channel-last (H, W, C), the layout imshow expects
image_hwc = images[idx].transpose((1,2,0))

# Left panel: the array exactly as it comes out of the transform pipeline
ax_preprocessed = plt.subplot(121)
ax_preprocessed.imshow(image_hwc)
ax_preprocessed.axis('off')
ax_preprocessed.set_title('Preprocessed:'+ idx_to_labels[label], fontsize = 10)

# Right panel: invert Normalize (x * std + mean) and clip into [0, 1]
ax_original = plt.subplot(122)
mean = np.array((0.5,))
std = np.array((0.5,))
# Alternative statistics kept from the original cell (not used):
# mean = np.array([0.485, 0.456, 0.406])
# std = np.array([0.229, 0.224, 0.225])

ax_original.imshow(np.clip(image_hwc * std + mean, 0, 1))
ax_original.axis('off')
ax_original.set_title('Original:'+ idx_to_labels[label], fontsize = 10)
plt.savefig('figure/image visualization2.tif', dpi = 300, bbox_inches = 'tight')

## Transfer learning fine-tuning options

https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

Now we set up a ResNet CNN and load weights that were previously trained on `ImageNet`.

In [None]:
# Number of classes in the training set; the option cells below use it as the
# output size of the replacement `fc` layer.
n_class = len(train_dataset.classes)

### 1. Fine-tuning fully connected layer only

In [None]:
# # Load pre_trained image classification model
# model = models.resnet18(pretrained = True) 
# # New layer default (requires_grad = True)
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.fc.parameters())

# print(model.fc)

### 2. Fine-tuning all layers

In [None]:
# ResNet-18 with pretrained weights; the classification head is replaced by a
# fresh linear layer with one output per class.
model = models.resnet18(pretrained = True)
fc_in_features = model.fc.in_features
model.fc = nn.Linear(fc_in_features, n_class)

# Every parameter of the network (not only the new head) is handed to Adam.
optimizer = optim.Adam(model.parameters())

print(model.fc)

### 3. Initialize all model weights randomly and train all layers from scratch

In [None]:
# # Only the model structure is loaded, not the pre-training weight parameters
# model = models.resnet18(pretrained=False) 
# model.fc = nn.Linear(model.fc.in_features, n_class)
# optimizer = optim.Adam(model.parameters())

## Model training Parameters

In [None]:
# Move the network onto the selected device; cross-entropy is the training objective.
model = model.to(device)
criterion = nn.CrossEntropyLoss()

EPOCHS = 100

# Multiply the learning rate by 0.1 every 50 scheduler steps.
# StepLR is reached through `optim` rather than through the imported `lr_scheduler`
# module name: this assignment rebinds `lr_scheduler` to the scheduler instance
# (the training loop calls `lr_scheduler.step()`), so `lr_scheduler.StepLR` would
# raise AttributeError the second time this cell is run.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 50, gamma = 0.1)

## Function: Train on the `train_dataset`

In [None]:
def train_one_batch(images, labels):
    '''
    Run one optimisation step on a single batch and return its training log.

    Relies on the notebook-level globals `model`, `criterion`, `optimizer` and
    `device`, and copies the globals `epoch` and `batch_idx` into the log as-is.

    Args:
        images: batch of input images; moved to `device` before the forward pass.
        labels: target class indices for `images`; moved to `device` as well.

    Returns:
        dict with the keys 'epoch', 'batch', 'train_loss', 'train_accuracy',
        'train_precision', 'train_recall' and 'train_f1-score'. The loss is a
        plain Python float; precision, recall and F1 are macro-averaged over
        this batch only.
    '''

    images = images.to(device)
    labels = labels.to(device)

    # Forward propagation
    outputs = model(images)
    # Loss of the current batch as returned by `criterion`
    loss = criterion(outputs, labels)

    # Back propagation, optimize and update the weights
    optimizer.zero_grad() # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()

    # Predicted class = index of the largest logit. These are the logits computed
    # before the weight update above.
    _, preds = torch.max(outputs, 1)
    preds = preds.cpu().numpy()
    labels = labels.detach().cpu().numpy()

    log_train = {}
    log_train['epoch'] = epoch
    log_train['batch'] = batch_idx
    # Classification evaluation index on train_dataset.
    # .item() stores a Python float instead of a 0-d ndarray, so the log column
    # is numeric rather than a column of array objects.
    log_train['train_loss'] = loss.item()
    log_train['train_accuracy'] = accuracy_score(labels, preds)
    log_train['train_precision'] = precision_score(labels, preds, average = 'macro')
    log_train['train_recall'] = recall_score(labels, preds, average = 'macro')
    log_train['train_f1-score'] = f1_score(labels, preds, average = 'macro')

    return log_train

## Function: Evaluate on the `test_dataset`

In [None]:
def evaluate_testset():
    '''
    Evaluate the model on every batch of `test_loader` and return the test log.

    Relies on the notebook-level globals `model`, `criterion`, `test_loader` and
    `device`, and copies the global `epoch` into the log as-is.

    The model is switched to evaluation mode for the duration of the call and put
    back into whatever mode it was in afterwards, so the result does not depend on
    the caller having called `model.eval()` first.

    Returns:
        dict with the keys 'epoch', 'test_loss', 'test_accuracy', 'test_precision',
        'test_recall' and 'test_f1-score'. Precision, recall and F1 are
        macro-averaged over the whole test set; 'test_loss' is the per-sample
        average loss (NaN if the loader yields no samples).
    '''

    loss_sum = 0.0
    n_samples = 0
    labels_list = []
    preds_list = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in test_loader: # Generate a batch of data and annotations
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images) # forward prediction

                # Predicted class = index of the largest logit
                _, preds = torch.max(outputs, 1)

                # `criterion` is expected to return the batch mean (the notebook builds
                # nn.CrossEntropyLoss() with its default reduction). Weighting by the
                # batch size keeps a smaller final batch from being over-represented.
                batch_size = labels.size(0)
                loss_sum += criterion(outputs, labels).item() * batch_size
                n_samples += batch_size

                labels_list.extend(labels.cpu().numpy())
                preds_list.extend(preds.cpu().numpy())
    finally:
        # Restore the caller's train/eval mode even if evaluation failed
        model.train(was_training)

    log_test = {}
    log_test['epoch'] = epoch

    # Classification evaluation index on test_dataset
    log_test['test_loss'] = loss_sum / n_samples if n_samples else float('nan')
    log_test['test_accuracy'] = accuracy_score(labels_list, preds_list)
    log_test['test_precision'] = precision_score(labels_list, preds_list, average = 'macro')
    log_test['test_recall'] = recall_score(labels_list, preds_list, average = 'macro')
    log_test['test_f1-score'] = f1_score(labels_list, preds_list, average = 'macro')

    return log_test

## Log record before training starts

In [None]:
# Counters read by train_one_batch / evaluate_testset and by the training loop
epoch = 0
batch_idx = 0
best_test_accuracy = 0

# log_train - train_dataset
# Row 0 of the training log. Note that train_one_batch performs a real optimizer
# step on this batch, not just a measurement.
images, labels = next(iter(train_loader))
log_train = {'epoch': 0, 'batch': 0}
log_train.update(train_one_batch(images, labels))
# Build the one-row frame directly instead of going through the private
# DataFrame._append helper.
df_train_log = pd.DataFrame([log_train])

df_train_log

In [None]:
# log_train - test_dataset
# The model is still in training mode at this point, so switch to evaluation mode
# before measuring the epoch-0 baseline. The training loop calls model.train()
# again at the start of every epoch.
model.eval()
log_test = {'epoch': 0}
log_test.update(evaluate_testset())
# Build the one-row frame directly instead of going through the private
# DataFrame._append helper.
df_test_log = pd.DataFrame([log_test])

df_test_log

## wandb visualization

In [None]:
# wandb.init(project = 'COVID', name = time.strftime('%m%d%H%M%S'))

## Train start

In [None]:
# Main loop: one pass over train_loader per epoch, then an evaluation on the test
# set; the checkpoint with the highest test accuracy seen so far is kept on disk.
for epoch in range(1, EPOCHS+1):

    print(f'{datetime.now()}, Epoch {epoch}/{EPOCHS}')

    ## Train phase
    model.train()
    # Buffer this epoch's per-batch rows and add them to the log with a single
    # concat. Appending one row at a time re-copies the whole log on every batch
    # and depended on the private DataFrame._append helper.
    epoch_train_rows = []
    for images, labels in tqdm(train_loader):
        batch_idx += 1
        log_train = train_one_batch(images, labels)
        epoch_train_rows.append(log_train)
#         wandb.log(log_train)
    df_train_log = pd.concat([df_train_log, pd.DataFrame(epoch_train_rows)], ignore_index = True)

    lr_scheduler.step()

    ## Test phase
    model.eval()
    log_test = evaluate_testset()
    df_test_log = pd.concat([df_test_log, pd.DataFrame([log_test])], ignore_index = True)
#     wandb.log(log_test)

    # Save the latest best model file
    if log_test['test_accuracy'] > best_test_accuracy:
        # Delete old best model files (if any); the file name embeds the accuracy
        # rounded to three decimals.
        old_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)
        if os.path.exists(old_best_checkpoint_path):
            os.remove(old_best_checkpoint_path)
        # Save the new best model file (the whole module object, not only its state_dict)
        best_test_accuracy = log_test['test_accuracy']
        new_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)
        torch.save(model, new_best_checkpoint_path)
        print('save best model', new_best_checkpoint_path)


# Persist both logs; the visualisation cells read them back from these files.
df_train_log.to_csv('table/Train Log-train_dataset.csv', index = False)
df_test_log.to_csv('table/Train Log-test_dataset.csv', index = False)

## Evaluate on `test_dataset`

In [None]:
# Reload the best checkpoint written by the training loop.
best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)

# NOTE: the file is a fully pickled nn.Module (saved with torch.save(model, ...)).
# Unpickling can execute arbitrary code, so only load checkpoints produced by this
# notebook.
#  - map_location places the weights on the current `device`, so a checkpoint saved
#    on a GPU still loads on a CPU-only machine.
#  - weights_only = False is passed explicitly because recent PyTorch releases
#    default it to True, which refuses to unpickle a whole module.
try:
    model = torch.load(best_checkpoint_path, map_location = device, weights_only = False)
except TypeError:
    # PyTorch releases that predate the `weights_only` keyword
    model = torch.load(best_checkpoint_path, map_location = device)

model.eval()
print(evaluate_testset())

## Visualize Train Log

In [None]:
# Read the two logs back from the CSV files written after training
df_train, df_test = (
    pd.read_csv(log_csv)
    for log_csv in ('table/Train Log-train_dataset.csv', 'table/Train Log-test_dataset.csv')
)

In [None]:
# Display the per-batch training log loaded from CSV
df_train

In [None]:
# Display the per-epoch test log loaded from CSV
df_test

### `train_dataset` loss function

In [None]:
# Training loss against the running batch counter
x = df_train['batch']
y = df_train['train_loss']

fig, ax = plt.subplots(figsize = (15, 10))
ax.plot(x, y, linewidth = 2)

ax.tick_params(labelsize = 25)
ax.set_xlabel('Batch', fontsize = 25)
ax.set_ylabel('Train Loss', fontsize = 25)
# ax.set_title('train_dataset Loss', fontsize=25)
fig.savefig('figure/train_dataset Loss.tif', dpi = 300, bbox_inches = 'tight')

### `train_dataset` accuracy

In [None]:
# Training accuracy against the running batch counter
x = df_train['batch']
y = df_train['train_accuracy']

fig, ax = plt.subplots(figsize = (15, 10))
ax.plot(x, y, linewidth = 2)

ax.tick_params(labelsize = 25)
ax.set_xlabel('Batch', fontsize = 25)
ax.set_ylabel('Train Accuracy', fontsize = 25)
# ax.set_title('train_dataset Accuracy', fontsize=25)
fig.savefig('figure/train_dataset Accuracy.tif', dpi = 300, bbox_inches = 'tight')

### `test_dataset` loss function

In [None]:
# Test loss against the epoch number
x = df_test['epoch']
y = df_test['test_loss']

fig, ax = plt.subplots(figsize = (15, 10))
ax.plot(x, y, linewidth = 2.5)

ax.tick_params(labelsize = 25)
ax.set_xlabel('Epoch', fontsize = 25)
ax.set_ylabel('Test Loss', fontsize = 25)
# ax.set_title('test_dataset Loss', fontsize=25)
fig.savefig('figure/test_dataset Loss.tif', dpi = 300, bbox_inches = 'tight')

### `test_dataset` Metrics

In [None]:
# Fixed seed so the randomly drawn line styles are the same on every run
random.seed(222)
colors = ['tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
linestyle = ['--', '-.', '-']

def get_line_arg():
    '''
    Return keyword arguments for a plot call: a colour and a line style drawn
    at random from `colors` and `linestyle`, plus a fixed line width of 2.
    '''
    # Keyword arguments are evaluated left to right: colour first, then line style.
    return dict(
        color = random.choice(colors),
        linestyle = random.choice(linestyle),
        linewidth = 2,
    )

# Columns of the test log drawn together on one set of axes
metrics = ['test_accuracy', 'test_precision', 'test_recall', 'test_f1-score']

fig, ax = plt.subplots(figsize = (15, 10))

x = df_test['epoch']
for y in metrics:
    # One curve per metric; colour and line style come from get_line_arg()
    ax.plot(x, df_test[y], label = y, **get_line_arg())

ax.tick_params(labelsize = 25)

ax.set_ylim([0, 1.05])
ax.set_xlabel('Epoch', fontsize = 25)
ax.set_ylabel('Test_Metrics', fontsize = 25)
ax.legend(loc = 4, fontsize = 20, frameon = False)
fig.savefig('figure/test_dataset performance.tif', dpi = 300, bbox_inches = 'tight')