In [28]:
from sklearn.metrics import recall_score, precision_score, roc_auc_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
import os
import random
import torch
import torchvision
import torchvision.transforms.v2 as v2
from torch.utils.data import DataLoader
from torchvision import utils
from torch.utils.data import random_split
import pytorch_grad_cam
import torch.hub as hub
from torchvision.transforms.v2 import functional as F
from sklearn.model_selection import StratifiedKFold
from src.backbones.vit.chada_vit import ChAdaViT
import hashlib

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seed = 42
np.random.seed(seed)
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

channels = 3
batch_size = 32
baseline = False # whether to use baseline model: VGG-16, DenseNet-121,ResNet-50, efficientnet_v2_small, convnext_base
baseline_model = ['densenet121', 'efficientnet_v2_s', 'resnet50', 'vgg16', 'convnext_base']
early_stop_mode = 'loss or accuracy'  # choose 'loss' mode, 'accuracy' mode, 'loss or accuracy' mode or 'loss and accuracy' mode, here using weighted F1-score to replace accuracy, especially in imbalanced dataset
# Number of variations to generate per image
num_variations_per_image_0 = 1
num_variations_per_image_1 = 8
test_percent = 0.15  # choose the proportion size of test set
validation_percent = 0.1  # choose the proportion size of validation set if close the cross validation
cross_validation = True  # choose open or close Cross-Validation
fold_num = 5  # choose the number of fold if open the cross-validation
if not cross_validation:
    fold_num = 1
CKPT_PATH = "weights.ckpt"
mixed_channels = False

dataset_path = "D:\cell_images_manual_extracted"
whole_dataset_path = "D:\cell_image_augmented_manual_extracted"
train_images_path = "D:\cell_image_train_manual_extracted\\"
test_images_path = "D:\cell_image_test_manual_extracted\\"
grad_cam_base_path = "D:\cell_image_XAI\Grad-CAM\\"
augmented_cancer_path = 'D:\cell_image_augmented\cancer\\'
augmented_normal_path = 'D:\cell_image_augmented\\normal\\'

dataset_autoseg_path = "D:\\cell_autoseg"
train_autoseg_path = "D:\\cell_autoseg_train\split"
train_autoseg_cancer_path = 'D:\\cell_autoseg_train\split\\cancer\\'
train_autoseg_normal_path = 'D:\\cell_autoseg_train\split\\normal\\'
train_autoseg_cv_path = 'D:\\cell_autoseg_train\\cross validation\\'
test_autoseg_path = 'D:\\cell_autoseg_test\split'
test_whole_autoseg_path = 'D:\\cell_autoseg_test\whole\\'
test_autoseg_cancer_path = "D:\\cell_autoseg_test\split\\cancer\\"
test_autoseg_normal_path = 'D:\\cell_autoseg_test\split\\normal\\'
validation_autoseg_path = 'D:\\cell_autoseg_validation\split'
validation_whole_autoseg_path = 'D:\\cell_autoseg_validation\whole\\'
validation_autoseg_cancer_path = 'D:\\cell_autoseg_validation\split\\cancer\\'
validation_autoseg_normal_path = 'D:\\cell_autoseg_validation\split\\normal\\'
validation_autoseg_cv_path = 'D:\\cell_autoseg_validation\\cross validation\\'

In [29]:
def check_hash(file_path, expected_hash):
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        while chunk := f.read(4096):
            md5.update(chunk)
    return md5.hexdigest() == expected_hash

In [30]:
check_hash(CKPT_PATH, "e8a24ac58b8e34bdce10e0024d507f2e")

True

In [48]:
def collate_images(batch: list):
    """
    Collate a batch of images into a list of channels and a mapping of the number of channels per image.
    
    Args:
        batch (list): A batch of images Tensor(B,C,H,W)

    Return:
        channels_list (torch.Tensor): A tensor of shape (X*num_channels, 1, height, width)
        num_channels_list (list): A list of the number of channels per image
    """
    num_channels_list = []
    channels_list = []

    # Iterate over the list of images and extract the channels
    for image in batch: 
        num_channels = image.shape[0]
        num_channels_list.append(num_channels) 

        for channel in range(num_channels):
            channel_image = image[channel, :, :].unsqueeze(0) 
            channels_list.append(channel_image)

    channels_list = torch.cat(channels_list, dim=0).unsqueeze(
        1
    )  # Shape: (X*num_channels, 1, height, width)

    return channels_list, num_channels_list

In [31]:
def compute_mean_std(ds, channels):  # calculate the mean and std for each channel in the dataset
    mean = torch.zeros(channels)
    std = torch.zeros(channels)
    for image in ds:
        for channel in range(channels):
            mean[channel] += image[channel, :, :].mean()
            std[channel] += image[channel, :, :].std()
    mean = mean / len(ds)
    std = std / len(ds)  #TODO: is this a correct way to calculate STD?
    return mean, std

In [32]:
class ResizeWithPadding:
    def __init__(self, size, fill=0):
        """
        size: tuple (width, height) target size
        fill: pixel value to fill
        """
        self.target_width, self.target_height = size
        self.fill = fill

    def __call__(self, img):
        # obtain the size of original image
        orig_width, orig_height = img.size

        # use the smaller ratio to scale the width and height equally.
        width_ratio = self.target_width / orig_width
        height_ratio = self.target_height / orig_height
        if width_ratio <= height_ratio:
            new_width = int(orig_width * width_ratio + 0.1)  # plus 0.1 to prevent the float error
            new_height = int(orig_height * width_ratio)
        else:
            new_width = int(orig_width * height_ratio)
            new_height = int(orig_height * height_ratio + 0.1)

        # resize
        img = F.resize(img, [new_height, new_width])  # the resize in F needs the format of input as (height, width)

        # calculate padding size
        pad_left = (self.target_width - new_width) // 2
        pad_top = (self.target_height - new_height) // 2
        pad_right = self.target_width - new_width - pad_left
        pad_bottom = self.target_height - new_height - pad_top

        # 添加 padding
        img = F.pad(img, [pad_left, pad_top, pad_right, pad_bottom], fill=self.fill)
        
        assert img.size[0] == self.target_width and img.size[1] == self.target_height, 'Output Image size is incorrect!'
        
        return img

In [33]:
# Define the path to the directory containing images
transform = v2.Compose([ResizeWithPadding((224, 224)), v2.ToTensor()])
# resize and transfer to tensor
dataset = torchvision.datasets.ImageFolder(dataset_autoseg_path, transform=transform)  # read data
# Assuming images are organized in subdirectories where each subdirectory name is the class label
# 0 is cancer, 1 is normal

transform_augmented = v2.Compose([v2.RandomHorizontalFlip(),
                                  v2.RandomVerticalFlip(),
                                  v2.RandomRotation(degrees=40),
                                  v2.RandomAffine(degrees=40, translate=(0.1, 0.1), shear=(-8, 8, -8, 8),
                                                  scale=(0.9, 1.1)),
                                  #v2.RandomPerspective(distortion_scale=0.1),
                                  #v2.GaussianBlur(kernel_size=2)
                                  #v2.ColorJitter(brightness=0.2,contrast=0.2,saturation=0.2,hue=0)
                                  # 0.2 means the brightness / contrast / saturation alternation range from [0.8~1.2] of original.
                                  ])  # Image Augmented Transformation
# here is a shuffle
# TODO: try other augmentation method



In [34]:
from sklearn.model_selection import train_test_split

image_paths = []
image_labels = []
for i in dataset.samples:
    image_paths.append(i[0])
    image_labels.append(i[1])

# split dataset into test set and (train set + validation set)
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    image_paths, image_labels, test_size=test_percent, stratify=image_labels, random_state=seed)
# here is a shuffle

print(
    f'we have {train_val_labels.count(0)} cancer cells and {train_val_labels.count(1)} normal cells, {len(train_val_labels)} cells in total, for training and validating')
print(
    f'we have {test_labels.count(0)} cancer cells and {test_labels.count(1)} normal cells, {len(test_paths)} cells in total, for testing')

# save the test image to target folders
for i in range(len(test_paths)):
    for j in range(len(dataset.samples)):
        if dataset.samples[j][0] == test_paths[i] and dataset.samples[j][1] == 0:
            utils.save_image(dataset[j][0],
                             test_autoseg_cancer_path + test_paths[i].split('\\')[-1].split('.')[0] + "_test.png")
            break
        elif dataset.samples[j][0] == test_paths[i] and dataset.samples[j][1] == 1:
            utils.save_image(dataset[j][0],
                             test_autoseg_normal_path + test_paths[i].split('\\')[-1].split(".")[0] + "_test.png")
            break

we have 892 cancer cells and 167 normal cells, 1059 cells in total, for training and validating
we have 158 cancer cells and 30 normal cells, 188 cells in total, for testing


In [35]:
# train and validation dataset for calculating the image mean and std for normalization
train_val_dataset = []
for i in range(len(train_val_paths)):
    for j in range(len(dataset.samples)):
        if dataset.samples[j][0] == train_val_paths[i]:
            train_val_dataset.append(dataset[j][0])
            break

images_mean, images_std = compute_mean_std(train_val_dataset, channels)
print("Mean:", images_mean)
print("Std:", images_std)

# inverse_transform, to restore the images when saving them to whole folder and displaying them in XAI
transform_inverse = v2.Compose([v2.Normalize(
    mean=[-images_mean[0] / images_std[0], -images_mean[1] / images_std[1], -images_mean[2] / images_std[2]],
    std=[1 / images_std[0], 1 / images_std[1], 1 / images_std[2]])])  # when mean = images_mean and std = images_std

Mean: tensor([0.0157, 0.0309, 0.0269])
Std: tensor([0.0377, 0.0673, 0.0844])


In [36]:
# save the validate image to target folders
def save_validation_images(val_paths, dataset, cancer_path, normal_path):
    for i in range(len(val_paths)):
        for j in range(len(dataset.samples)):
            if dataset.samples[j][0] == val_paths[i] and dataset.samples[j][1] == 0:
                utils.save_image(dataset[j][0],
                                 cancer_path + val_paths[i].split('\\')[-1].split('.')[0] + "_val.png")
                break
            elif dataset.samples[j][0] == val_paths[i] and dataset.samples[j][1] == 1:
                utils.save_image(dataset[j][0],
                                 normal_path + val_paths[i].split('\\')[-1].split(".")[0] + "_val.png")
                break

In [37]:
def generate_save_train_images(train_paths, dataset, cancer_path, normal_path):
    # original training images
    train_origin_cancer_dataset = []
    train_origin_normal_dataset = []
    for i in range(len(train_paths)):
        for j in range(len(dataset.samples)):
            if dataset.samples[j][0] == train_paths[i] and dataset.samples[j][1] == 0:
                train_origin_cancer_dataset.append(
                    {'image': dataset[j][0], 'filename': train_paths[i].split('\\')[-1].split('.')[0]})
                break
            elif dataset.samples[j][0] == train_paths[i] and dataset.samples[j][1] == 1:
                train_origin_normal_dataset.append(
                    {'image': dataset[j][0], 'filename': train_paths[i].split('\\')[-1].split('.')[0]})
                break
    
    # Initialize an empty list to store the augmented images
    augmented_images_class_0 = []
    augmented_images_class_1 = []
    
    # Image augmentation
    for image in train_origin_cancer_dataset:
        for i in range(num_variations_per_image_0):
            augmented_images_class_0.append({'image': transform_augmented(image['image']), 'filename': image['filename']})
    
    for image in train_origin_normal_dataset:
        for i in range(num_variations_per_image_1):
            augmented_images_class_1.append({'image': transform_augmented(image['image']), 'filename': image['filename']})
    
    # save training dataset (original + augmentation)
    if num_variations_per_image_0 > 0:
        for i in range(len(augmented_images_class_0)):
            utils.save_image(augmented_images_class_0[i]['image'],
                             cancer_path + str(int(i / num_variations_per_image_0)) + "_" + str(
                                 i % num_variations_per_image_0) + "_" + augmented_images_class_0[i][
                                 'filename'] + "_aug.png")
    
    if num_variations_per_image_1 > 0:
        for i in range(len(augmented_images_class_1)):
            utils.save_image(augmented_images_class_1[i]['image'],
                             normal_path + str(int(i / num_variations_per_image_1)) + "_" + str(
                                 i % num_variations_per_image_1) + "_" + augmented_images_class_1[i][
                                 'filename'] + "_aug.png")
    
    # save original images (after resize)
    for i in range(len(train_origin_cancer_dataset)):
        utils.save_image(train_origin_cancer_dataset[i]['image'],
                         cancer_path + str(i) + "_" + train_origin_cancer_dataset[i][
                             'filename'] + "_original.png")
    
    for i in range(len(train_origin_normal_dataset)):
        utils.save_image(train_origin_normal_dataset[i]['image'],
                         normal_path + str(i) + "_" + train_origin_normal_dataset[i][
                             'filename'] + "_original.png")
    
    print(
        f"we have generate {len(augmented_images_class_0)} augmented cancer cell images and {len(augmented_images_class_1)} augmented normal cell images.")
    print(
        f'Totally we have {len(train_origin_cancer_dataset) + len(augmented_images_class_0)} cancer cell images and {len(train_origin_normal_dataset) + len(augmented_images_class_1)} normal cell images for training')

In [38]:
def read_data(test_path, validation_path, train_path):
    
    # read test data
    test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform_whole_dataset)
    # read validation data
    val_dataset = torchvision.datasets.ImageFolder(validation_path, transform=transform_whole_dataset)
    # read train data
    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform_whole_dataset)
    
    return  test_dataset, val_dataset, train_dataset

In [39]:
def save_whole_image(val_dataset, test_dataset, validation_whole_path, test_whole_path): 
    # save test and validate image (un-shuffle, easy to find which one is misclassified)
    for i in range(len(val_dataset)):
        utils.save_image(transform_inverse(val_dataset[i][0]),
                         validation_whole_path + str(i) + "_" + str(val_dataset[i][1]) + "_" +
                         val_dataset.samples[i][0].split('\\')[-1].split('.')[0] + ".png")
    for i in range(len(test_dataset)):
        utils.save_image(transform_inverse(test_dataset[i][0]),
                         test_whole_path + str(i) + "_" + str(test_dataset[i][1]) + "_" +
                         test_dataset.samples[i][0].split('\\')[-1].split('.')[0] + ".png")

In [40]:
# This transform: to tensor and normalization is used for all images
transform_whole_dataset = v2.Compose([v2.ToTensor(), v2.Normalize(mean=images_mean, std=images_std)])

if not cross_validation:
    # split train_val_dataset into train set and validation set
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_val_paths, train_val_labels, test_size=validation_percent / (1 - test_percent), stratify=train_val_labels,
        random_state=seed) # here is a shuffle
    
    print(f'The cross validation openness is closed')
    print(
        f'we have {train_labels.count(0)} cancer cells and {train_labels.count(1)} normal cells, {len(train_paths)} cells in total, for training')
    print(
        f'we have {val_labels.count(0)} cancer cells and {val_labels.count(1)} normal cells, {len(val_paths)} cells in total, for validating')
    
    save_validation_images(val_paths, dataset, validation_autoseg_cancer_path, validation_autoseg_normal_path)
    generate_save_train_images(train_paths, dataset, train_autoseg_cancer_path, train_autoseg_normal_path)
    
    # read test, train, validate data from target folders
    test_dataset = []
    val_dataset = []
    train_dataset = []
    test, val, train = read_data(test_autoseg_path, validation_autoseg_path, train_autoseg_path)
    test_dataset.append(test)
    val_dataset.append(val)
    train_dataset.append(train)
    
    save_whole_image(val, test, validation_whole_autoseg_path, test_whole_autoseg_path)
else:
    print(f'The cross validation openness is opened')
    test_dataset = []
    val_dataset = []
    train_dataset = []
    skf = StratifiedKFold(n_splits=fold_num, shuffle=True, random_state=seed) # here is a shuffle
    for fold, (train_idx, val_idx) in enumerate(skf.split(train_val_paths, train_val_labels)):
        # use Stratified K-fold cross validation to create e.g. 10 folds, in each fold, the ratio between normal and cancer approximately keeping the same as the ratio in original image dataset.
        train_paths = [train_val_paths[i] for i in train_idx]
        train_labels = [train_val_labels[i] for i in train_idx]
        
        val_paths = [train_val_paths[i] for i in val_idx]
        val_labels = [train_val_labels[i] for i in val_idx]
        
        print(f"In fold {fold}")
        print(
        f'we have {train_labels.count(0)} cancer cells and {train_labels.count(1)} normal cells, {len(train_paths)} cells in total, for training')
        print(
        f'we have {val_labels.count(0)} cancer cells and {val_labels.count(1)} normal cells, {len(val_paths)} cells in total, for validating')
        
        # new folders to save validation images for each fold.
        validation_cv_fold_path = validation_autoseg_cv_path + "cross validation " + str(fold)
        validation_cv_fold_split_path = validation_autoseg_cv_path + "cross validation " + str(fold) + '\\split'
        validation_cv_fold_whole_path = validation_autoseg_cv_path + "cross validation " + str(fold) + '\\whole'
        validation_cv_fold_split_cancer_path = validation_autoseg_cv_path + "cross validation " + str(fold) + '\\split\\cancer'
        validation_cv_fold_split_normal_path = validation_autoseg_cv_path + "cross validation " + str(fold) + '\\split\\normal'
        os.makedirs(validation_cv_fold_path, exist_ok=True)
        os.makedirs(validation_cv_fold_split_path, exist_ok=True)
        os.makedirs(validation_cv_fold_whole_path, exist_ok=True)
        os.makedirs(validation_cv_fold_split_cancer_path, exist_ok=True)
        os.makedirs(validation_cv_fold_split_normal_path, exist_ok=True)
        
        save_validation_images(val_paths, dataset, validation_cv_fold_split_cancer_path + '\\', validation_cv_fold_split_normal_path + '\\')
        
        # new folders to save train images for each fold.
        train_cv_fold_path = train_autoseg_cv_path + "cross validation " + str(fold)
        train_cv_fold_split_path = train_autoseg_cv_path + "cross validation " + str(fold) + '\\split'
        train_cv_fold_whole_path = train_autoseg_cv_path + "cross validation " + str(fold) + '\\whole'
        train_cv_fold_split_cancer_path = train_autoseg_cv_path + "cross validation " + str(fold) + '\\split\\cancer'
        train_cv_fold_split_normal_path = train_autoseg_cv_path + "cross validation " + str(fold) + '\\split\\normal'
        os.makedirs(train_cv_fold_path, exist_ok=True)
        os.makedirs(train_cv_fold_split_path, exist_ok=True)
        os.makedirs(train_cv_fold_whole_path, exist_ok=True)
        os.makedirs(train_cv_fold_split_cancer_path, exist_ok=True)
        os.makedirs(train_cv_fold_split_normal_path, exist_ok=True)
        
        generate_save_train_images(train_paths, dataset, train_cv_fold_split_cancer_path + '\\', train_cv_fold_split_normal_path + '\\')
        
        # read test, train, validate data from target folders
        test, val, train = read_data(test_autoseg_path, validation_cv_fold_split_path, train_cv_fold_split_path)
        test_dataset.append(test)
        val_dataset.append(val)
        train_dataset.append(train)
        
        save_whole_image(val, test, validation_cv_fold_whole_path + '\\', test_whole_autoseg_path)



The cross validation openness is opened
In fold 0
we have 713 cancer cells and 134 normal cells, 847 cells in total, for training
we have 179 cancer cells and 33 normal cells, 212 cells in total, for validating
we have generate 713 augmented cancer cell images and 1072 augmented normal cell images.
Totally we have 1426 cancer cell images and 1206 normal cell images for training
In fold 1
we have 713 cancer cells and 134 normal cells, 847 cells in total, for training
we have 179 cancer cells and 33 normal cells, 212 cells in total, for validating
we have generate 713 augmented cancer cell images and 1072 augmented normal cell images.
Totally we have 1426 cancer cell images and 1206 normal cell images for training
In fold 2
we have 714 cancer cells and 133 normal cells, 847 cells in total, for training
we have 178 cancer cells and 34 normal cells, 212 cells in total, for validating
we have generate 714 augmented cancer cell images and 1064 augmented normal cell images.
Totally we have 14

# Basic Model

In [53]:
from torch.optim.lr_scheduler import StepLR
from torch import nn, optim
from transformers import get_cosine_schedule_with_warmup
from sklearn.model_selection import ParameterGrid
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt
import numpy as np

'''
# hyper-parameters set for training the model
# users can make the length of each list = 1 to close grid search
param_grid = {
'lr_backbone': [0.00001, 0.00005],
'lr_head': [0.00005, 0.0001],
'weight_decay_backbone': [0.001],
'weight_decay_head': [0.0001],
'dropout_p': [0.0, 0.3, 0.5],
'warmup_epoch': [3, 6],
'lr_decay_epoch': [12, 15],
'unfrozen_blocks': [[7,8,9,10,11], [0,1,2,3,4,5,6,7,8,9,10,11]]
}
'''
'''
param_grid = {
'lr_backbone': [0.00001, 0.00005],
'lr_head': [0.00005],
'weight_decay_backbone': [0.001],
'weight_decay_head': [0.0001],
'dropout_p': [0.3],
'warmup_epoch': [3],
'lr_decay_epoch': [12, 15],
'unfrozen_blocks': [[7,8,9,10,11]]
}   
'''
param_grid = {
'lr_backbone': [0.00001],
'lr_head': [0.00005],
'weight_decay_backbone': [0.001],
'weight_decay_head': [0.0001],
'dropout_p': [0.3],
'warmup_epoch': [3],
'lr_decay_epoch': [15],
'unfrozen_blocks': [[7,8,9,10,11]]
}    

grid = list(ParameterGrid(param_grid)) # generate all the hyper-parameter combination
gird_search_result = [] # store the cross-validation performance of each hyper-parameter set
for set_num, hyper_params in enumerate(grid):
    print(f"Running config {set_num + 1}/{len(grid)}: {hyper_params}")
    
    config = {
    'lr_backbone': hyper_params['lr_backbone'],
    'lr_head': hyper_params['lr_head'],
    'weight_decay_backbone': hyper_params['weight_decay_backbone'],
    'weight_decay_head': hyper_params['weight_decay_head'],
    'dropout_p': hyper_params['dropout_p'],
    'num_epochs': 30,
    'warmup_epoch': hyper_params['warmup_epoch'],
    'lr_decay_epoch': hyper_params['lr_decay_epoch'],
    'unfrozen_blocks': hyper_params['unfrozen_blocks']
    } # configure all the hyper-parameters, including not for grid search

    for fold in range(fold_num):
        print('###############################################')
        print(f'This is the fold: {fold}')
        print('###############################################')
        # Load the train, validate, and test dataset
        #torch.manual_seed(seed) # if the shuffle is displayed different each time, use torch.manual_seed(seed)
        train_loader = DataLoader(train_dataset[fold], batch_size=batch_size, shuffle=True)  # here is a shuffle
        val_loader = DataLoader(val_dataset[fold], batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset[fold], batch_size=batch_size, shuffle=False)
        
        for images, labels in train_loader:
            print(f"Train batch images shape: {images.shape}")
            print(f"Train batch labels: {labels}")
            break
        for images, labels in val_loader:
            print(f"Validate batch images shape: {images.shape}")
            print(f"Validate batch labels: {labels}")
            break
        for images, labels in test_loader:
            print(f"Test batch images shape: {images.shape}")
            print(f"Test batch labels: {labels}")
            break
        
        
        # Set gpu/cpu
        print(f"Use device: {device}")
        
        # Model Params
        PATCH_SIZE = 16
        EMBED_DIM = 192
        RETURN_ALL_TOKENS = False
        MAX_NUMBER_CHANNELS = 10
        
        # use chadavit model
        model = ChAdaViT(
            patch_size=PATCH_SIZE,
            embed_dim=EMBED_DIM,
            return_all_tokens=RETURN_ALL_TOKENS,
            max_number_channels=MAX_NUMBER_CHANNELS,
        )
        
        assert (
            CKPT_PATH.endswith(".ckpt")
            or CKPT_PATH.endswith(".pth")
            or CKPT_PATH.endswith(".pt")
        ) # ensure the CKPT_PATH ends correctly
        state = torch.load(CKPT_PATH, map_location="cpu")["state_dict"]
        for k in list(state.keys()):
            if "encoder" in k:
                state[k.replace("encoder", "backbone")] = state[k]
            if "backbone" in k:
                state[k.replace("backbone.", "")] = state[k]
            del state[k]
        model.load_state_dict(state, strict=False) # load the pre-trained parameter

        model = model.to(device)
        
        class ModifiedChadaViT(nn.Module):
            def __init__(self, base_model):
                super(ModifiedChadaViT, self).__init__()
                self.base_model = base_model
                self.head = nn.Sequential(nn.Linear(base_model.embed_dim, 768),
                                          nn.ReLU(),
                                          nn.Dropout(config['dropout_p']),
                                          nn.Linear(768, 2))  # modify the head from Identify to 2-class classification (Linear Layer) and add drop out (included in grid search)
        
            def forward(self, x, list_num_channels, index=0):
                # the features from ViT backbone (batch size, 768 (dim of class token))
                features = self.base_model(x=x, index=index, list_num_channels=[list_num_channels])
                # pass the classification head
                return self.head(features)
        
        
        model = ModifiedChadaViT(model).to(device)  # TODO: use the basic ViT-B-16 model, see the most below
        model.mixed_channels = mixed_channels # all of the inputs share the same channel number.
        
        for param in model.base_model.blocks.parameters():
            param.requires_grad = False
        for layer_index in config['unfrozen_blocks']:
            layer = model.base_model.blocks[layer_index]  
            # only un-froze the last few layers (included in grid search)
            for param in layer.parameters():
                param.requires_grad = True
        for param in model.head.parameters(): # un froze the classification head
            param.requires_grad = True
        
        class EarlyStopping:
            def __init__(self, patience=5, mode='loss', min_delta=0.0, fold=0, epoch_number = 10):
                # min_delta used to measure the extension of loss decrease, only the loss decrease > min_delta, we can say early stop check pass.
                self.patience = patience
                self.mode = mode
                self.fold = fold
                self.min_delta = min_delta
                self.counter = 0
                self.epoch_number = epoch_number
                self.best = None
                self.early_stop = False
        
            def __call__(self, val_loss, accuracy, epoch):
                print("the mode of early stopping is " + self.mode)
                if self.mode == 'loss':
                    if self.best is None:
                        print(f"This is the first epoch {epoch + 1}!")
                        self.best = val_loss
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best val_loss is: {self.best}")
                    elif val_loss < self.best - self.min_delta:
                        print(f"Epoch {epoch + 1} Early Stop Check Pass!")
                        self.best = val_loss
                        self.counter = 0
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best val_loss is: {self.best}")
                    else:
                        self.counter += 1
                        print(
                            f"Epoch {epoch + 1} not pass the Early Stopping! EarlyStopping counter: {self.counter} / {self.patience}")
                        print(f"The best val_loss is still: {self.best}")
                        if self.counter >= self.patience or (epoch + 1) == self.epoch_number:
                            self.early_stop = True
                elif self.mode == 'accuracy':
                    if self.best is None:
                        print(f"This is the first epoch {epoch + 1}!")
                        self.best = accuracy
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best accuracy / weighted f1 score is: {self.best}")
                    elif accuracy > self.best:
                        print(f"Epoch {epoch + 1} Early Stop Check Pass!")
                        self.best = accuracy
                        self.counter = 0
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best accuracy / weighted f1 score is: {self.best}")
                    else:
                        self.counter += 1
                        print(
                            f"Epoch {epoch + 1} not pass the Early Stopping! EarlyStopping counter: {self.counter} / {self.patience}")
                        print(f"The best accuracy / weighted f1 score is still: {self.best}")
                        if self.counter >= self.patience or (epoch + 1) == self.epoch_number:
                            self.early_stop = True
                elif self.mode == 'loss and accuracy':
                    if self.best is None:
                        print(f"This is the first epoch {epoch + 1}!")
                        self.best = [val_loss, accuracy]
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best val_loss and accuracy / weighted f1 score are: {self.best[0]}, {self.best[1]}")
                    elif val_loss < self.best[0] - self.min_delta and accuracy > self.best[1]:
                        print(f"Epoch {epoch + 1} Early Stop Check Pass!")
                        self.best = [val_loss, accuracy]
                        self.counter = 0
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best val_loss and accuracy / weighted f1 score are: {self.best[0]}, {self.best[1]}")
                    else:
                        self.counter += 1
                        print(
                            f"Epoch {epoch + 1} not pass the Early Stopping! EarlyStopping counter: {self.counter} / {self.patience}")
                        print(f"The best val_loss and accuracy / weighted f1 score are still: {self.best[0]}, {self.best[1]}")
                        if self.counter >= self.patience or (epoch + 1) == self.epoch_number:
                            self.early_stop = True
                elif self.mode == 'loss or accuracy':
                    if self.best is None:
                        print(f"This is the first epoch {epoch + 1}!")
                        self.best = [val_loss, accuracy]
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best val_loss and accuracy / weighted f1 score are: {self.best[0]}, {self.best[1]}")
                    elif val_loss < self.best[0] - self.min_delta or accuracy > self.best[1]:
                        print(f"Epoch {epoch + 1} Early Stop Check Pass!")
                        self.best = [val_loss, accuracy]
                        self.counter = 0
                        torch.save(model.state_dict(), "best_model_parameter_" + str(self.fold) + ".pth")
                        print(f"The best val_loss and accuracy / weighted f1 score are: {self.best[0]}, {self.best[1]}")
                    else:
                        self.counter += 1
                        print(
                            f"Epoch {epoch + 1} not pass the Early Stopping! EarlyStopping counter: {self.counter} / {self.patience}")
                        print(f"The best val_loss and accuracy / weighted f1 score are still: {self.best[0]}, {self.best[1]}")
                        if self.counter >= self.patience or (epoch + 1) == self.epoch_number:
                            self.early_stop = True
        
        num_epochs = config['num_epochs']
        warm_up_epochs = config['warmup_epoch']
        total_steps = (config['warmup_epoch'] + config['lr_decay_epoch']) * len(train_loader)
        warm_up_steps = warm_up_epochs * len(train_loader)
        # Set L.F., optimizer
        #weights = torch.tensor([1.0, 1.5]).to(device) # set higher weights for positive class in CEL
        #criterion = nn.CrossEntropyLoss(weight=weights)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW([{'params': filter(lambda p: p.requires_grad, model.base_model.parameters()), 'weight_decay': config['weight_decay_backbone'], 'lr': config['lr_backbone']}, {'params': model.head.parameters(), 'weight_decay': config['weight_decay_head'], 'lr': config['lr_head']}])  # maybe can use AdamW for better combination with weight decay
        # optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
        # weight_decay is not always good, see notebook 55 for detail
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warm_up_steps, num_training_steps=total_steps)  # the learning rate will warm-up firstly, then cosine decay to 0
        
        
        # TODO: the normalization need to be mean = [0.485, 0.456, 0.406] std  = [0.229, 0.224, 0.225]. ------ not good, use the image_mean and image_std will let the model converge better.
        # TODO: Set an Early stop and try epoch = 10 or more. ------ Done
        # TODO: Try another kind of lr-scheduler, lr value or let the gamma smaller, maybe 0.25
        # TODO: In optimizer, use momentum? ------ Adam and AdamW don't need momentum, they are carried with momentum. Only SGD needs momentum, but prefer to using AdamW in ViT training
        # TODO: Use another optimizer? AdamW, SGD, etc. ------ Done
        # TODO: Use weight decay? dropout in the classifier ? ------ Done
        # TODO: Use warm-up + learning rate scheduler combination? ------ Done
        # TODO: Change the size of Batch?
        # TODO: Use hyperparameter search algorithm, like random search, grid search, etc. ------ Done
        # TODO: Use TensorBoard to monitor the training process
        # TODO: Use other kinds of augmented methods
        # TODO: add label smoothing for small dataset and focal loss for hard to classify sample
        # TODO: Freeze fewer blocks, and un-freeze position embedding, layernorm ------ Done
        # TODO: Try hybrid ViT, Swin-ViT (both have CNN’s properties) or Dei-T (better on small dataset)
        # TODO: 再看几篇paper的experiment，看看别人是怎么做的
        # TODO: ViT should be put into comparison model set
        
        # train func
        def train(model, loader, criterion, optimizer, device, epoch, train_dataset):
            model.train()  # set model to train mode
            running_loss = 0.0
            for i, data in enumerate(loader, 0):  # each batch input
                inputs, targets = data
                targets = targets.to(device)
                inputs = collate_images(inputs) # collate the input image batch to a sequence of channel
                X, list_num_channels = inputs
                X = X.to(device)
                optimizer.zero_grad()  # reset the grads
                outputs = model(X, list_num_channels, index=0)
                loss = criterion(outputs, targets)
                loss.backward()  # back propagation
                optimizer.step()  # update parameter
                running_loss += loss.item() * targets.size(0)
                print("In epoch " + str(epoch + 1) + ", batch: " + str(i + 1) + ", average loss per image: " + str(
                    loss.item()))
        
                correct_train = 0
                _, predicted_train = outputs.max(1)  # model predicted class
                correct_train += (predicted_train == targets).sum().item()
                print("Accuracy of the network on the train set: " + str(correct_train / targets.size(0)))
                
                scheduler.step()  # regularize the learning rate
        
            return running_loss / len(train_dataset)
        
        
        # validate func
        def validate(model, loader, device, val_dataset):
            model.eval()  # set model to validate mode
            correct = 0
            total = 0
            false = []
            running_loss = 0.0
            y_true = []
            y_pred = []
            with torch.no_grad():
                for i, data in enumerate(loader, 0):
                    inputs, targets = data
                    targets = targets.to(device)
                    inputs = collate_images(inputs) # collate the input image batch to a sequence of channel
                    X, list_num_channels = inputs
                    X = X.to(device)
                    outputs = model(X, list_num_channels, index=0)
                    _, predicted = outputs.max(1)  # model predicted class
                    correct += (predicted == targets).sum().item()
                    total += targets.size(0)
                    loss = criterion(outputs, targets)
                    running_loss += loss.item() * targets.size(0) # the sum of val_loss in one batch            
                    
                    for target in targets:
                        y_true.append(target.item())
                    for predict in predicted:
                        y_pred.append(predict.item())
                    
                    for result in range(len(predicted)): # collect wrong samples
                        if predicted[result] != targets[result]:
                            false.append(i * batch_size + result)
                    
                y_true = np.array(y_true)
                y_pred = np.array(y_pred)
                
            return correct / total, false, running_loss / len(val_dataset), f1_score(y_true, y_pred, average='weighted')
        
        
        # train and validate
        early_stopping = EarlyStopping(patience=5, mode=early_stop_mode, fold=fold, epoch_number = num_epochs)
        wrong_number = []
        
        for epoch in range(num_epochs):
            print(f"The learning rate of backbone is: {optimizer.param_groups[0]['lr']}, of head is {optimizer.param_groups[1]['lr']}")
            train_loss = train(model, train_loader, criterion, optimizer, device, epoch, train_dataset[fold])
            accuracy, wrong_predicted, val_loss, weighted_f1 = validate(model, val_loader, device, val_dataset[fold])
            wrong_number.append(wrong_predicted)
            print("====================" + str(epoch + 1) + "====================")
            print(f'Epoch {epoch + 1}/{num_epochs}')
            print(f'Average train loss per image: {train_loss:.7f}')
            print(f'Average validate loss per image: {val_loss:.7f}')
            print(f'Validate accuracy: {accuracy:.4f}')
            print("====================" + str(epoch + 1) + "====================")
        
            early_stopping(val_loss, weighted_f1, epoch)
        
            if early_stopping.early_stop:
                print(" 🔥 Early stopping, Stop Training")
                print(f"select the epoch: {epoch - early_stopping.counter + 1}")
                for wrong_result in wrong_number[epoch - early_stopping.counter]:
                    print("The number " + str(wrong_result) + " is wrong!")
                break
    
            if epoch == num_epochs - 1:
                print("train until the last epoch!")
                for wrong_result in wrong_predicted:
                    print("The number " + str(wrong_result) + " is wrong!")
                    
    model_accuracy = []
    model_recall = []
    model_precision = []
    model_auc_roc_macro = []
    model_auc_roc_micro = []
    model_auc_roc_weighted = []
    model_f1_macro = []
    model_f1_micro = []
    model_f1_weighted = []
    
    
    for fold in range(fold_num):
        model.load_state_dict(torch.load("best_model_parameter_" + str(fold) + ".pth"))
        model.eval()
        y_true = []
        y_pred = []
        y_prob = []
        
        for data in range(len(val_dataset[fold])):
            y_true.append(val_dataset[fold][data][1])
        
            inputs = val_dataset[fold][data][0]
            inputs = inputs.to(device)
            input_tensor = collate_images(inputs.unsqueeze(0))
            X, list_num_channels = input_tensor
            X = X.to(device)
            outputs = model(X, list_num_channels, index=0)
            predicted_class = outputs.argmax(dim=1).item()
            y_pred.append(predicted_class)
        
            prob = torch.softmax(outputs, dim=1)[:, 1].item()
            y_prob.append(prob)
        
        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        y_prob = np.array(y_prob)
        
        fpr, tpr, thresholds = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        
        plt.figure()
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic')
        plt.legend(loc="lower right")
        plt.show()
        
        recall = recall_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        auc_roc_macro = roc_auc_score(y_true, y_prob)
        auc_roc_micro = roc_auc_score(y_true, y_prob)
        auc_roc_weighted = roc_auc_score(y_true, y_prob)
        f1_macro = f1_score(y_true, y_pred, average='macro')
        f1_micro = f1_score(y_true, y_pred, average='micro')
        f1_weighted = f1_score(y_true, y_pred, average='weighted')
        
        accuracy = 0.0
        for i in range(len(y_pred)):
            if y_pred[i] == y_true[i]:
                accuracy += 1.0
        model_accuracy.append(accuracy/len(y_pred))
        model_recall.append(recall)
        model_precision.append(precision)
        model_auc_roc_macro.append(auc_roc_macro)
        model_auc_roc_micro.append(auc_roc_micro)
        model_auc_roc_weighted.append(auc_roc_weighted)
        model_f1_macro.append(f1_macro)
        model_f1_micro.append(f1_micro)
        model_f1_weighted.append(f1_weighted)
        
        print(f"Accuracy: {accuracy/len(y_pred):.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"AUC-ROC Macro: {auc_roc_macro:.4f}")
        print(f"AUC-ROC Micro: {auc_roc_micro:.4f}")
        print(f"AUC-ROC Weighted: {auc_roc_weighted:.4f}")
        print(f"F1 Macro: {f1_macro:.4f}")
        print(f"F1 Micro: {f1_micro:.4f}")
        print(f"F1 Weighted: {f1_weighted:.4f}")
        
    print(f"The mean and std of accuracy are: {np.array(model_accuracy).mean()}, and {np.array(model_accuracy).std()}")
    print(f"The mean and std of recall are: {np.array(model_recall).mean()}, and {np.array(model_recall).std()}")
    print(f"The mean and std of precision are: {np.array(model_precision).mean()}, and {np.array(model_precision).std()}")
    print(f"The mean and std of auc_roc_macro are: {np.array(model_auc_roc_macro).mean()}, and {np.array(model_auc_roc_macro).std()}")
    print(f"The mean and std of auc_roc_micro are: {np.array(model_auc_roc_micro).mean()}, and {np.array(model_auc_roc_micro).std()}")
    print(f"The mean and std of auc_roc_weighted are: {np.array(model_auc_roc_weighted).mean()}, and {np.array(model_auc_roc_weighted).std()}")
    print(f"The mean and std of f1_macro are: {np.array(model_f1_macro).mean()}, and {np.array(model_f1_macro).std()}")
    print(f"The mean and std of f1_micro are: {np.array(model_f1_micro).mean()}, and {np.array(model_f1_micro).std()}")
    print(f"The mean and std of f1_weighted are: {np.array(model_f1_weighted).mean()}, and {np.array(model_f1_weighted).std()}")
    
    gird_search_result.append({'hyperparams': hyper_params,
                               'accuracy mean':np.array(model_accuracy).mean(),
                               'accuracy std': np.array(model_accuracy).std(),
                               'recall mean': np.array(model_recall).mean(),
                               'recall std': np.array(model_recall).std(),
                               'precision mean': np.array(model_precision).mean(),
                               'precision std': np.array(model_precision).std(),
                               'auc_roc_macro mean': np.array(model_auc_roc_macro).mean(),
                               'auc_roc_macro std': np.array(model_auc_roc_macro).std(),
                               'auc_roc_micro mean': np.array(model_auc_roc_micro).mean(),
                               'auc_roc_micro std': np.array(model_auc_roc_micro).std(),
                               'auc_roc_weighted mean': np.array(model_auc_roc_weighted).mean(),
                               'auc_roc_weighted std': np.array(model_auc_roc_weighted).std(),
                               'f1_macro mean': np.array(model_f1_macro).mean(),
                               'f1_macro std': np.array(model_f1_macro).std(),
                               'f1_micro mean': np.array(model_f1_micro).mean(),
                               'f1_micro std': np.array(model_f1_micro).std(),
                               'f1_weighted mean': np.array(model_f1_weighted).mean(),
                               'f1_weighted std': np.array(model_f1_weighted).std()})


Running config 1/1: {'dropout_p': 0.3, 'lr_backbone': 1e-05, 'lr_decay_epoch': 15, 'lr_head': 5e-05, 'unfrozen_blocks': [7, 8, 9, 10, 11], 'warmup_epoch': 3, 'weight_decay_backbone': 0.001, 'weight_decay_head': 0.0001}
###############################################
This is the fold: 0
###############################################
Train batch images shape: torch.Size([32, 3, 224, 224])
Train batch labels: tensor([0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1,
        1, 1, 1, 1, 0, 0, 0, 1])
Validate batch images shape: torch.Size([32, 3, 224, 224])
Validate batch labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
Test batch images shape: torch.Size([32, 3, 224, 224])
Test batch labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
Use device: cpu
#trainable params: 71
   base_model.cls_token
   base_model.channel_token
   b

RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 5906720256 bytes.