In [None]:
from sklearn.metrics import recall_score, precision_score, roc_auc_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
import os
import random
import torch
import torchvision
import torchvision.transforms.v2 as v2
from torch.utils.data import DataLoader
from torchvision import utils
from torch.utils.data import random_split
import pytorch_grad_cam
import torch.hub as hub
from torchvision.transforms.v2 import functional as F
from sklearn.model_selection import StratifiedKFold

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seed python, numpy, torch (CPU) and every CUDA device.
seed = 42 # set by user
np.random.seed(seed)
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

channels = 2 # without blue channel
batch_size = 32 # set by user
baseline = False # whether to use baseline model: VGG-16, DenseNet-121,ResNet-50, efficientnet_v2_small, convnext_base
baseline_model = ['densenet121'] # baseline model can input 'densenet121', 'efficientnet_v2_s', 'resnet50', 'vgg16', 'convnext_base'
early_stop_mode = 'accuracy'  # choose 'loss' mode, 'accuracy' mode, 'loss or accuracy' mode or 'loss and accuracy' modes
# Number of variations to generate per image
num_variations_per_image_0 = 4
num_variations_per_image_1 = 4
test_percent = 0.15  # choose the proportion size of test set
validation_percent = 0.1  # choose the proportion size of validation set if close the cross validation
cross_validation = True  # choose open or close Cross-Validation
fold_num = 5  # choose the number of fold if open the cross-validation
run_name = 'cell_classification_without_nucleus' # log name
if not cross_validation:
    fold_num = 1
finetune = True # choose to whether fine-tuning the features from the backbone
main_structure = 'dino' # choose 'dino', 'vit', 'dinov2', 'dinov3'

# Windows data locations. Every backslash is written explicitly as "\\". The previous mix of "\\" and bare "\"
# only worked because sequences such as "\s", "\w", "\c", "\d" are invalid escapes that Python keeps verbatim
# (SyntaxWarning since Python 3.12); a folder starting with n/t/r/... would silently have become "\n", "\t", ...
# The resulting path strings are unchanged.
dataset_autoseg_noblue_path = "D:\\cell_40x_noblue"
train_autoseg_noblue_path = "D:\\cell_autoseg_train_noblue\\split"
train_autoseg_noblue_cancer_path = 'D:\\cell_autoseg_train_noblue\\split\\cancer\\'
train_autoseg_noblue_normal_path = 'D:\\cell_autoseg_train_noblue\\split\\normal\\'
train_autoseg_noblue_cv_path = 'D:\\cell_autoseg_train_noblue\\cross validation\\'
test_autoseg_noblue_path = 'D:\\cell_autoseg_test_noblue\\split'
test_whole_autoseg_noblue_path = 'D:\\cell_autoseg_test_noblue\\whole\\'
test_autoseg_noblue_cancer_path = "D:\\cell_autoseg_test_noblue\\split\\cancer\\"
test_autoseg_noblue_normal_path = 'D:\\cell_autoseg_test_noblue\\split\\normal\\'
validation_autoseg_noblue_path = 'D:\\cell_autoseg_validation_noblue\\split'
validation_whole_autoseg_noblue_path = 'D:\\cell_autoseg_validation_noblue\\whole\\'
validation_autoseg_noblue_cancer_path = 'D:\\cell_autoseg_validation_noblue\\split\\cancer\\'
validation_autoseg_noblue_normal_path = 'D:\\cell_autoseg_validation_noblue\\split\\normal\\'
validation_autoseg_noblue_cv_path = 'D:\\cell_autoseg_validation_noblue\\cross validation\\'
REPO_DINOV3 = "D:\\cell_classification_pythonprojects\\dinov3\\dinov3"
WEIGHTS_DINOV3 = "D:\\cell_classification_pythonprojects\\My_model\\dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"
grad_cam_base_path = "D:\\cell_image_XAI\\cell_40x"

In [None]:
import logging
from pathlib import Path
from datetime import datetime

# Log next to the script when run as a file, otherwise next to the notebook's working directory.
SCRIPT_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
log_file = SCRIPT_DIR / f"{run_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

logger = logging.getLogger(run_name)
logger.setLevel(logging.INFO)
logger.propagate = False

# Re-running this cell reuses the same named logger; detach its old handlers so messages are not duplicated.
for stale_handler in list(logger.handlers):
    try:
        stale_handler.flush()
    except Exception:
        pass  # best effort: a broken stream must not prevent replacing it
    stale_handler.close()
    logger.removeHandler(stale_handler)

fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")

# Mirror every record to the console (sh) and to the timestamped log file (fh).
sh = logging.StreamHandler()
fh = logging.FileHandler(log_file, mode='a', encoding="utf-8")
for new_handler in (sh, fh):
    new_handler.setFormatter(fmt)
    logger.addHandler(new_handler)

# Image processing

In [None]:
from training_toolbox import ResizeWithPadding


# Base pipeline for every image: pad/resize to 224x224, then convert to a tensor.
transform = v2.Compose([ResizeWithPadding((224, 224)), v2.ToTensor()])

# ImageFolder derives labels from the (sorted) sub-directory names: 0 is cancer, 1 is normal.
dataset = torchvision.datasets.ImageFolder(dataset_autoseg_noblue_path, transform=transform)

# Random geometric augmentations applied when generating extra training variations.
augmentation_steps = [
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
    v2.RandomRotation(degrees=40),
    v2.RandomAffine(degrees=40, translate=(0.1, 0.1), shear=(-8, 8, -8, 8), scale=(0.9, 1.1)),
]
transform_augmented = v2.Compose(augmentation_steps)

In [None]:
from sklearn.model_selection import train_test_split

# Flatten ImageFolder's (path, label) pairs into two parallel lists.
image_paths = [sample_path for sample_path, _ in dataset.samples]
image_labels = [sample_label for _, sample_label in dataset.samples]

# Hold out a stratified test set first; the remainder is later split into train/validation (or CV folds).
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    image_paths,
    image_labels,
    test_size=test_percent,
    stratify=image_labels,
    random_state=seed,
)

logger.info(
    f'we have {train_val_labels.count(0)} cancer cells and {train_val_labels.count(1)} normal cells, {len(train_val_labels)} cells in total, for training and validating')
logger.info(
    f'we have {test_labels.count(0)} cancer cells and {test_labels.count(1)} normal cells, {len(test_paths)} cells in total, for testing')

# Save every held-out test image (after the resize/to-tensor transform) into its class folder as
# "<original file stem>_test.png".
os.makedirs(test_autoseg_noblue_cancer_path, exist_ok=True)
os.makedirs(test_autoseg_noblue_normal_path, exist_ok=True)
test_class_dirs = {0: test_autoseg_noblue_cancer_path, 1: test_autoseg_noblue_normal_path}

# Build a path -> index map once instead of rescanning dataset.samples for every test image (was O(n*m)).
sample_index = {}
for sample_idx, (sample_path, _) in enumerate(dataset.samples):
    sample_index.setdefault(sample_path, sample_idx)  # keep the first occurrence, like the old linear scan

for test_path in test_paths:
    sample_idx = sample_index.get(test_path)
    if sample_idx is None:
        continue
    target_dir = test_class_dirs.get(dataset.samples[sample_idx][1])
    if target_dir is None:  # labels other than 0/1 were skipped before as well
        continue
    # basename/splitext instead of split('\\')[-1].split('.')[0]: works with any path separator, and names
    # containing extra dots ("a.1.png", "a.2.png") no longer collapse to the same file and overwrite each other.
    file_stem = os.path.splitext(os.path.basename(test_path))[0]
    utils.save_image(dataset[sample_idx][0], target_dir + file_stem + "_test.png")

In [None]:
from training_toolbox import compute_mean_std_noblue

# train and validation dataset for calculating the image mean and std for normalization
# Transformed (resized, to-tensor) train+validation images, used here to compute normalisation statistics.
# A path -> index map built once replaces the per-path rescan of dataset.samples (was O(n*m)).
sample_index = {}
for sample_idx, (sample_path, _) in enumerate(dataset.samples):
    sample_index.setdefault(sample_path, sample_idx)  # keep the first occurrence, like the old linear scan
train_val_dataset = [dataset[sample_index[path]][0] for path in train_val_paths if path in sample_index]

if finetune:
    images_mean, images_std = compute_mean_std_noblue(train_val_dataset, channels)
else:
    # Standard ImageNet statistics when the backbone stays frozen.
    images_mean = torch.tensor([0.485, 0.456, 0.406])
    images_std = torch.tensor([0.229, 0.224, 0.225])
logger.info(f"Mean: {images_mean}")
logger.info(f"Std: {images_std}")

# inverse_transform, to restore the images when saving them to whole folder and displaying them in XAI.
# Normalize(mean=-m/s, std=1/s) maps x to (x + m/s) * s = x * s + m, i.e. it undoes Normalize(mean=m, std=s).
transform_inverse = v2.Compose([v2.Normalize(
    mean=[-images_mean[c] / images_std[c] for c in range(len(images_mean))],
    std=[1 / images_std[c] for c in range(len(images_mean))])])

In [None]:
from training_toolbox import save_validation_images, generate_save_train_images, read_data, save_whole_image

# This transform: to tensor and normalization is used for all images
transform_whole_dataset = v2.Compose([v2.ToTensor(), v2.Normalize(mean=images_mean, std=images_std)])

# NOTE(review): none of the target folders below is emptied first. If the notebook is re-run with another seed or
# split, images written by the earlier run remain on disk; if read_data loads whole folders (presumably), they
# would leak between train / validation / test. Clear the folders manually when changing the split.


def _make_fold_dirs(cv_root, fold_index):
    """Create the "cross validation <fold_index>" folder tree under ``cv_root``.

    The tree is the fold folder plus its 'split' (with 'cancer' and 'normal') and 'whole' sub-folders. Paths are
    built by plain string concatenation with Windows separators, exactly as in earlier runs.

    Returns:
        tuple: (fold_path, split_path, whole_path, split_cancer_path, split_normal_path), without trailing
        separators.
    """
    fold_path = cv_root + "cross validation " + str(fold_index)
    split_path = fold_path + '\\split'
    whole_path = fold_path + '\\whole'
    split_cancer_path = split_path + '\\cancer'
    split_normal_path = split_path + '\\normal'
    for folder in (fold_path, split_path, whole_path, split_cancer_path, split_normal_path):
        os.makedirs(folder, exist_ok=True)
    return fold_path, split_path, whole_path, split_cancer_path, split_normal_path


if not cross_validation:
    # split train_val_dataset into train set and validation set
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_val_paths, train_val_labels, test_size=validation_percent / (1 - test_percent), stratify=train_val_labels,
        random_state=seed) # here is a shuffle

    logger.info('The cross validation openness is closed')
    logger.info(
        f'we have {train_labels.count(0)} cancer cells and {train_labels.count(1)} normal cells, {len(train_paths)} cells in total, for training')
    logger.info(
        f'we have {val_labels.count(0)} cancer cells and {val_labels.count(1)} normal cells, {len(val_paths)} cells in total, for validating')

    # The CV branch creates its folders; do the same here so a fresh machine does not fail on missing directories.
    for folder in (validation_autoseg_noblue_cancer_path, validation_autoseg_noblue_normal_path,
                   train_autoseg_noblue_cancer_path, train_autoseg_noblue_normal_path,
                   validation_whole_autoseg_noblue_path, test_whole_autoseg_noblue_path):
        os.makedirs(folder, exist_ok=True)

    save_validation_images(val_paths, dataset, validation_autoseg_noblue_cancer_path, validation_autoseg_noblue_normal_path)
    generate_save_train_images(train_paths, dataset, train_autoseg_noblue_cancer_path, train_autoseg_noblue_normal_path, num_variations_per_image_0, num_variations_per_image_1, transform_augmented, logger)

    # read test, train, validate data from target folders (a single "fold")
    test, val, train = read_data(test_autoseg_noblue_path, validation_autoseg_noblue_path, train_autoseg_noblue_path, transform_whole_dataset)
    test_dataset = [test]
    val_dataset = [val]
    train_dataset = [train]

    save_whole_image(val, test, validation_whole_autoseg_noblue_path, test_whole_autoseg_noblue_path, transform_inverse)
else:
    logger.info('The cross validation openness is opened')
    test_dataset = []
    val_dataset = []
    train_dataset = []
    skf = StratifiedKFold(n_splits=fold_num, shuffle=True, random_state=seed)
    for fold, (train_idx, val_idx) in enumerate(skf.split(train_val_paths, train_val_labels)):
        # Stratified K-fold (fold_num folds): each fold keeps roughly the cancer/normal ratio of the full train+val set.
        train_paths = [train_val_paths[i] for i in train_idx]
        train_labels = [train_val_labels[i] for i in train_idx]

        val_paths = [train_val_paths[i] for i in val_idx]
        val_labels = [train_val_labels[i] for i in val_idx]

        logger.info(f"In fold {fold}")
        logger.info(
        f'we have {train_labels.count(0)} cancer cells and {train_labels.count(1)} normal cells, {len(train_paths)} cells in total, for training')
        logger.info(
        f'we have {val_labels.count(0)} cancer cells and {val_labels.count(1)} normal cells, {len(val_paths)} cells in total, for validating')

        # new folders to save validation images for each fold.
        (validation_cv_fold_path, validation_cv_fold_split_path, validation_cv_fold_whole_path,
         validation_cv_fold_split_cancer_path, validation_cv_fold_split_normal_path) = _make_fold_dirs(
            validation_autoseg_noblue_cv_path, fold)

        save_validation_images(val_paths, dataset, validation_cv_fold_split_cancer_path + '\\', validation_cv_fold_split_normal_path + '\\')

        # new folders to save train images for each fold.
        (train_cv_fold_path, train_cv_fold_split_path, train_cv_fold_whole_path,
         train_cv_fold_split_cancer_path, train_cv_fold_split_normal_path) = _make_fold_dirs(
            train_autoseg_noblue_cv_path, fold)

        generate_save_train_images(train_paths, dataset, train_cv_fold_split_cancer_path + '\\', train_cv_fold_split_normal_path + '\\', num_variations_per_image_0, num_variations_per_image_1, transform_augmented, logger)

        # read test, train, validate data from target folders
        test, val, train = read_data(test_autoseg_noblue_path, validation_cv_fold_split_path, train_cv_fold_split_path, transform_whole_dataset)
        test_dataset.append(test)
        val_dataset.append(val)
        train_dataset.append(train)

        save_whole_image(val, test, validation_cv_fold_whole_path + '\\', test_whole_autoseg_noblue_path, transform_inverse)

# Model training

In [None]:
from torch.optim.lr_scheduler import StepLR
from torch import nn, optim
from transformers import get_cosine_schedule_with_warmup
from sklearn.model_selection import ParameterGrid
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models import ViT_B_16_Weights
from training_toolbox import replace_classifier, get_cosine_with_warmup_tail, plot_loss_curves, inspect_model_and_optimizer

if not baseline:
    model_banner = "#####################################################################"
    logger.info(model_banner)
    logger.info(f'Using model: {main_structure}')
    logger.info(model_banner)

    # Hyper-parameter search space: settings shared by both regimes, plus the values that differ between
    # fine-tuning the backbone and training the head alone.
    shared_space = {
        'warmup_epoch': [5],
        'lr_decay_epoch': [40],
        'unfrozen_blocks': [list(range(12))],  # block indices 0..11
        'grad_clip': [False],
        'label_smoothing': [0.0],
    }
    if finetune:
        regime_space = {
            'lr_backbone': [0.00002, 0.0001],
            'lr_head': [0.0002, 0.001],
            'weight_decay_backbone': [0.01, 0.001],
            'weight_decay_head': [0.000],
            'dropout_p': [0.3],
        }
    else:
        regime_space = {
            'lr_backbone': [0.00002],
            'lr_head': [0.0002, 0.001],
            'weight_decay_backbone': [0.01],
            'weight_decay_head': [0.000, 0.001],
            'dropout_p': [0.3, 0.0],
        }
    param_grid = {**regime_space, **shared_space}
    grid = list(ParameterGrid(param_grid))  # every combination of the values above
    gird_search_result = []  # cross-validation performance of each hyper-parameter set
    for set_num, hyper_params in enumerate(grid):

        # Combine the searched values with the fixed ones. The explicit key order keeps the logged dict
        # identical in layout across runs.
        fixed_settings = {'num_epochs': 60}
        config_keys = ('lr_backbone', 'lr_head', 'weight_decay_backbone', 'weight_decay_head', 'dropout_p',
                       'num_epochs', 'warmup_epoch', 'lr_decay_epoch', 'unfrozen_blocks', 'grad_clip',
                       'label_smoothing')
        config = {key: fixed_settings[key] if key in fixed_settings else hyper_params[key] for key in config_keys}

        logger.info(f"Running config {set_num + 1}/{len(grid)}: {config}")

        for fold in range(fold_num):
            logger.info('###############################################')
            logger.info(f'This is the fold: {fold}')
            logger.info('###############################################')

            # Per-fold shuffling generator, so every hyper-parameter set sees the same batch order within a fold.
            # Derived from the user-set `seed` (previously a hard-coded 42 that ignored the config value; the
            # result is identical while seed == 42).
            g = torch.Generator()
            g.manual_seed(seed + fold)

            # pin_memory only helps host-to-GPU copies; without a GPU recent PyTorch versions warn and ignore it.
            use_pinned_memory = device.type == 'cuda'
            train_loader = DataLoader(train_dataset[fold], batch_size=batch_size, shuffle=True, generator=g, pin_memory=use_pinned_memory)
            val_loader = DataLoader(val_dataset[fold], batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory)
            test_loader = DataLoader(test_dataset[fold], batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory)

            # Set gpu/cpu
            logger.info(f"Use device: {device}")

            # Load the pretrained backbone; `num_features` becomes the input width of the classification head.
            backbone_name = main_structure.lower()
            if backbone_name == 'dino':
                model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16', pretrained=True)
                num_features = model.embed_dim
            elif backbone_name == 'dinov2':
                model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', pretrained=True)
                num_features = model.embed_dim
            elif backbone_name == 'dinov3':
                model = torch.hub.load(REPO_DINOV3, 'dinov3_vitb16', source='local', weights=WEIGHTS_DINOV3)
                num_features = model.embed_dim
            elif backbone_name == 'vit':
                model = torch.hub.load('pytorch/vision', 'vit_b_16', pretrained=True)
                model, num_features = replace_classifier(model)
            else:
                # Without this branch `model` would be undefined, or silently reuse a backbone left over from a
                # previous fold / notebook run.
                raise ValueError(f"Unknown main_structure {main_structure!r}; choose 'dino', 'vit', 'dinov2' or 'dinov3'")
            model = model.to(device)


            class ModifiedViT(nn.Module):
                """Pretrained backbone followed by a dropout + linear 2-class classification head.

                Args:
                    base_model: backbone whose forward output is fed to the head; assumed to be one feature
                        vector per image, shape (batch, in_features) — TODO confirm for each backbone.
                    in_features: width of the backbone output.
                    dropout_p: dropout probability before the linear layer. When omitted it falls back to the
                        current grid-search ``config['dropout_p']`` (the previous, implicit behaviour).
                """

                def __init__(self, base_model, in_features, dropout_p=None):
                    super().__init__()
                    if dropout_p is None:
                        dropout_p = config['dropout_p']
                    self.base_model = base_model
                    # Replaces the backbone's own head with dropout + a 2-class linear layer.
                    self.head = nn.Sequential(nn.Dropout(dropout_p),
                                              nn.Linear(in_features, 2))

                def forward(self, x):
                    features = self.base_model(x)  # backbone features, one vector per image
                    return self.head(features)  # 2-class logits


            model = ModifiedViT(model, num_features).to(device)

            # Decide what trains. Fine-tuning: the transformer blocks listed in config['unfrozen_blocks'] plus the
            # head, with all other blocks frozen. Otherwise only the head trains.
            # NOTE(review): in the fine-tune branch only the block list is toggled, so other backbone parameters
            # (patch embedding, position/class tokens, final norm — whatever the backbone has) keep their loaded
            # requires_grad, presumably True. Confirm that this is intended when fewer blocks are unfrozen.
            if finetune:
                backbone_blocks = (model.base_model.encoder.layers if main_structure.lower() == 'vit'
                                   else model.base_model.blocks)
                backbone_blocks.requires_grad_(False)
                for layer_index in config['unfrozen_blocks']:
                    backbone_blocks[layer_index].requires_grad_(True)
            else:
                model.requires_grad_(False)
            model.head.requires_grad_(True)  # the classification head always trains


            class EarlyStopping:
                """Checkpoint the best model of a fold and decide when training should stop.

                ``mode`` selects what counts as an improvement:
                  * 'loss'              -- val_loss drops by more than ``min_delta``;
                  * 'accuracy'          -- the score increases, or stays equal while val_loss decreases
                                           (the training loop passes the weighted F1 score as the score);
                  * 'loss and accuracy' -- val_loss drops by more than ``min_delta`` AND the score increases;
                  * 'loss or accuracy'  -- either of the two improves.

                The first call and every improvement save the state_dict of the module-level ``model`` to
                ``<name>_best_model_parameter_<hyper_index>_<fold>.pth`` and reset ``counter``. ``early_stop`` is
                set once ``patience`` consecutive calls fail to improve, or when the call for the final epoch
                (``epoch + 1 == epoch_number``) fails to improve.

                Raises:
                    ValueError: if ``mode`` is not one of the four modes above. An unknown mode used to be
                        ignored silently, so no checkpoint was ever written and training never stopped early.
                """

                MODES = ('loss', 'accuracy', 'loss and accuracy', 'loss or accuracy')

                def __init__(self, patience=5, mode='loss', min_delta=0.0, fold=0, epoch_number = 10, name = None, hyper_index = 0):
                    if mode not in self.MODES:
                        raise ValueError(f"Unknown early stopping mode {mode!r}; expected one of {self.MODES}")
                    self.patience = patience
                    self.mode = mode
                    self.fold = fold
                    self.min_delta = min_delta
                    self.counter = 0  # consecutive calls without improvement
                    self.epoch_number = epoch_number
                    self.name = name
                    self.hyper_index = hyper_index
                    self.best = None  # float in 'loss' mode, [val_loss, score] in the other modes
                    self.early_stop = False

                def _checkpoint_path(self):
                    """File the best state_dict of this (hyper-parameter set, fold) pair is written to."""
                    return self.name + "_best_model_parameter_" + str(self.hyper_index) + "_" + str(self.fold) + ".pth"

                def _describe_best(self, still=False):
                    """Log line describing the current best value(s) for this mode."""
                    suffix = " still" if still else ""
                    if self.mode == 'loss':
                        return f"The best val_loss is{suffix}: {self.best}"
                    if self.mode == 'accuracy':
                        return f"The best accuracy / weighted f1 score is{suffix}: {self.best[1]}"
                    return f"The best val_loss and accuracy / weighted f1 score are{suffix}: {self.best[0]}, {self.best[1]}"

                def _improvement_headline(self, val_loss, accuracy, epoch):
                    """Return the log headline if this epoch improves on ``self.best`` (not None), else None."""
                    passed = f"Epoch {epoch + 1} Early Stop Check Pass!"
                    if self.mode == 'loss':
                        return passed if val_loss < self.best - self.min_delta else None
                    if self.mode == 'accuracy':
                        if accuracy > self.best[1]:
                            return passed
                        if accuracy == self.best[1] and val_loss < self.best[0]:
                            return (f"Epoch {epoch + 1} Early Stop Check Pass! val_loss is smaller although "
                                    "the accuracy / weighted f1 score keeps the same.")
                        return None
                    loss_improved = val_loss < self.best[0] - self.min_delta
                    score_improved = accuracy > self.best[1]
                    if self.mode == 'loss and accuracy':
                        return passed if (loss_improved and score_improved) else None
                    return passed if (loss_improved or score_improved) else None

                def __call__(self, val_loss, accuracy, epoch):
                    logger.info("the mode of early stopping is " + self.mode)
                    if self.best is None:
                        headline = f"This is the first epoch {epoch + 1}!"
                    else:
                        headline = self._improvement_headline(val_loss, accuracy, epoch)

                    if headline is None:
                        self.counter += 1
                        logger.info(
                            f"Epoch {epoch + 1} not pass the Early Stopping! EarlyStopping counter: {self.counter} / {self.patience}")
                        logger.info(self._describe_best(still=True))
                        if self.counter >= self.patience or (epoch + 1) == self.epoch_number:
                            self.early_stop = True
                        return

                    logger.info(headline)
                    self.best = val_loss if self.mode == 'loss' else [val_loss, accuracy]
                    self.counter = 0  # already 0 on the first call
                    torch.save(model.state_dict(), self._checkpoint_path())
                    logger.info(self._describe_best())

            num_epochs = config['num_epochs']
            warm_up_epochs = config['warmup_epoch']
            steps_per_epoch = len(train_loader)
            # The scheduler is stepped once per batch: warm-up steps followed by the decay window.
            total_steps = (warm_up_epochs + config['lr_decay_epoch']) * steps_per_epoch
            warm_up_steps = warm_up_epochs * steps_per_epoch

            train_criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
            # Summed (not averaged) so validate() can divide by the exact number of images.
            val_criterion = nn.CrossEntropyLoss(reduction="sum")

            # Separate parameter groups so backbone and head get their own lr / weight decay;
            # frozen backbone parameters are left out.
            backbone_group = {
                'params': [p for p in model.base_model.parameters() if p.requires_grad],
                'weight_decay': config['weight_decay_backbone'],
                'lr': config['lr_backbone'],
            }
            head_group = {
                'params': model.head.parameters(),
                'weight_decay': config['weight_decay_head'],
                'lr': config['lr_head'],
            }
            optimizer = optim.AdamW([backbone_group, head_group])
            # Per the original author: warm up first, then cosine-decay to 0.1 of the initial lr.
            scheduler = get_cosine_with_warmup_tail(optimizer, num_warmup_steps=warm_up_steps, num_training_steps=total_steps, min_lr_factor=0.1)


            # train func
            def train(model, loader, criterion, optimizer, device, epoch, train_dataset):
                """Run one training epoch and return the mean training loss per image.

                Logs the loss and accuracy of every batch. The enclosing-scope ``scheduler`` is stepped after
                every batch, and gradients are clipped to norm 1.0 when ``config['grad_clip']`` is set.
                The per-image mean divides by ``len(train_dataset)``.
                """
                model.train()
                loss_sum = 0.0
                for batch_idx, (inputs, targets) in enumerate(loader):
                    if batch_idx == 0:
                        logger.info(f"First train batch labels: {targets}")
                    inputs = inputs.to(device)
                    targets = targets.to(device)

                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    if config['grad_clip']:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                    n_in_batch = targets.size(0)
                    batch_loss = loss.item()
                    loss_sum += batch_loss * n_in_batch  # criterion returns a batch mean; re-weight by batch size
                    logger.info(f"In epoch {epoch + 1}, batch: {batch_idx + 1}, average loss per image: {batch_loss}")

                    n_correct = (outputs.max(1).indices == targets).sum().item()
                    logger.info(f"Accuracy of the network on the train set: {n_correct / n_in_batch}")

                    scheduler.step()  # per-batch learning-rate schedule

                return loss_sum / len(train_dataset)


            # validate func
            def validate(model, loader, criterion, device, val_dataset):
                """Evaluate ``model`` on ``loader`` with gradients disabled.

                Args:
                    model: network to evaluate (switched to eval mode).
                    loader: validation DataLoader; expected not to shuffle, so positions map back to samples.
                    criterion: loss with ``reduction="sum"`` (``val_criterion``), so dividing by the number of
                        samples gives the mean loss per image.
                    device: device the batches are moved to.
                    val_dataset: unused; kept so existing call sites keep working.

                Returns:
                    tuple: (accuracy, wrong_indices, mean_loss_per_image, weighted_f1), where ``wrong_indices``
                    lists the loader-order positions of misclassified samples.
                """
                model.eval()
                correct = 0
                total = 0
                wrong_indices = []
                running_loss = 0.0
                y_true = []
                y_pred = []
                with torch.no_grad():
                    for inputs, targets in loader:
                        inputs, targets = inputs.to(device), targets.to(device)
                        outputs = model(inputs)
                        _, predicted = outputs.max(1)  # model predicted class
                        running_loss += criterion(outputs, targets).item()  # summed loss of this batch

                        # One host transfer per batch instead of one .item() per sample.
                        batch_true = targets.tolist()
                        batch_pred = predicted.tolist()
                        # Offset by the samples already seen instead of `i * batch_size`: the global batch_size
                        # only gave correct indices because the loader happened to be built with it.
                        wrong_indices.extend(total + k for k, (t, p) in enumerate(zip(batch_true, batch_pred)) if t != p)
                        correct += sum(t == p for t, p in zip(batch_true, batch_pred))
                        total += len(batch_true)
                        y_true.extend(batch_true)
                        y_pred.extend(batch_pred)

                y_true = np.array(y_true)
                y_pred = np.array(y_pred)
                return correct / total, wrong_indices, running_loss / total, f1_score(y_true, y_pred, average='weighted')


            # train and validate
            early_stopping = EarlyStopping(patience=10, mode=early_stop_mode, fold=fold, epoch_number = num_epochs, name = main_structure, hyper_index = set_num)
            wrong_number = []  # wrong_number[e]: misclassified validation indices after epoch e

            training_loss_list = []
            validating_loss_list = []
            for epoch in range(num_epochs):
                logger.info(f"The learning rate of backbone is: {optimizer.param_groups[0]['lr']}, of head is {optimizer.param_groups[1]['lr']}")
                train_loss = train(model, train_loader, train_criterion, optimizer, device, epoch, train_dataset[fold])
                accuracy, wrong_predicted, val_loss, weighted_f1 = validate(model, val_loader, val_criterion, device, val_dataset[fold])
                wrong_number.append(wrong_predicted)
                logger.info("====================" + str(epoch + 1) + "====================")
                logger.info(f'Epoch {epoch + 1}/{num_epochs}')
                logger.info(f'Average train loss per image: {train_loss:.7f}')
                logger.info(f'Average validate loss per image: {val_loss:.7f}')
                logger.info(f'Validate accuracy: {accuracy:.4f}')
                logger.info("====================" + str(epoch + 1) + "====================")

                training_loss_list.append(train_loss)
                validating_loss_list.append(val_loss)
                # The weighted F1 score is what EarlyStopping receives as its `accuracy` argument.
                early_stopping(val_loss, weighted_f1, epoch)

                if early_stopping.early_stop:
                    # Log message restored from mojibake ("ðŸ”¥" was the fire emoji mis-decoded as cp1252).
                    logger.info(" 🔥 Early stopping, Stop Training")
                    # `counter` epochs have passed since the last improving (checkpointed) epoch.
                    logger.info(f"select the epoch: {epoch - early_stopping.counter + 1}")
                    for wrong_result in wrong_number[epoch - early_stopping.counter]:
                        logger.info("The number " + str(wrong_result) + " is wrong!")
                    break

                if epoch == num_epochs - 1:
                    logger.info("train until the last epoch!")
                    for wrong_result in wrong_predicted:
                        logger.info("The number " + str(wrong_result) + " is wrong!")

            plot_loss_curves(training_loss_list, validating_loss_list, fold_number = fold, hyper_setnum=set_num, model_type = main_structure)

        # Inspect (presumably log) the model/optimizer left over from the last fold trained above.
        inspect_model_and_optimizer(model, optimizer, logger)

        for _banner_line in ("#####################################################################",
                             f'Using model: {main_structure}',
                             "#####################################################################"):
            logger.info(_banner_line)

        # One list per metric; each receives one value per cross-validation fold.
        (model_accuracy, model_recall, model_precision, model_specificity,
         model_auc_roc_macro, model_auc_roc_micro, model_auc_roc_weighted,
         model_f1_macro, model_f1_micro, model_f1_weighted) = ([] for _ in range(10))


        # Reload each fold's best checkpoint and score it on that fold's validation split.
        # Class 1 is treated as the positive class throughout (softmax[:, 1], default pos_label=1).
        for fold in range(fold_num):
            checkpoint_path = main_structure + "_best_model_parameter_" + str(set_num) + "_" + str(fold) + ".pth"
            # map_location lets a checkpoint saved on another device load on the current one.
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            model.eval()
            y_true = []
            y_pred = []
            y_prob = []

            # Inference only: no_grad avoids building an autograd graph for every sample.
            with torch.no_grad():
                for data in range(len(val_dataset[fold])):
                    # Index the dataset once per sample; indexing it twice re-loads (and
                    # re-transforms) the image just to read its label.
                    sample = val_dataset[fold][data]
                    y_true.append(sample[1])

                    input_tensor = sample[0].unsqueeze(0).to(device)
                    outputs = model(input_tensor)
                    y_pred.append(outputs.argmax(dim=1).item())
                    # probability assigned to the positive class (index 1)
                    y_prob.append(torch.softmax(outputs, dim=1)[:, 1].item())

            y_true = np.array(y_true)
            y_pred = np.array(y_pred)
            y_prob = np.array(y_prob)

            fpr, tpr, thresholds = roc_curve(y_true, y_prob)
            roc_auc = auc(fpr, tpr)

            fig, ax = plt.subplots()
            ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
            ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            ax.set_xlim([0.0, 1.0])
            ax.set_ylim([0.0, 1.05])
            ax.set_xlabel('False Positive Rate')
            ax.set_ylabel('True Positive Rate')
            ax.set_title(f'Receiver Operating Characteristic (fold {fold})')
            ax.legend(loc="lower right")
            plt.show()
            plt.close(fig)  # one figure per fold and per hyper-parameter set; release it

            recall = recall_score(y_true, y_pred)
            precision = precision_score(y_true, y_pred)
            # For a binary target roc_auc_score ignores `average`, so these three are the same
            # number; the separate names are kept for the summary table.
            auc_roc_macro = auc_roc_micro = auc_roc_weighted = roc_auc_score(y_true, y_prob)
            f1_macro = f1_score(y_true, y_pred, average='macro')
            f1_micro = f1_score(y_true, y_pred, average='micro')
            f1_weighted = f1_score(y_true, y_pred, average='weighted')

            accuracy = float(np.mean(y_pred == y_true))
            # labels=[0, 1] keeps the matrix 2x2 even when a fold's labels/predictions contain a
            # single class; otherwise the 4-way unpack raises ValueError.
            tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
            specificity = tn / (tn + fp) if (tn + fp) > 0 else float('nan')
            model_accuracy.append(accuracy)
            model_recall.append(recall)
            model_precision.append(precision)
            model_specificity.append(specificity)
            model_auc_roc_macro.append(auc_roc_macro)
            model_auc_roc_micro.append(auc_roc_micro)
            model_auc_roc_weighted.append(auc_roc_weighted)
            model_f1_macro.append(f1_macro)
            model_f1_micro.append(f1_micro)
            model_f1_weighted.append(f1_weighted)

            logger.info(f"Accuracy: {accuracy:.4f}")
            logger.info(f"Recall: {recall:.4f}")
            logger.info(f"Precision: {precision:.4f}")
            logger.info(f"Specificity: {specificity:.4f}")
            logger.info(f"AUC-ROC Macro: {auc_roc_macro:.4f}")
            logger.info(f"AUC-ROC Micro: {auc_roc_micro:.4f}")
            logger.info(f"AUC-ROC Weighted: {auc_roc_weighted:.4f}")
            logger.info(f"F1 Macro: {f1_macro:.4f}")
            logger.info(f"F1 Micro: {f1_micro:.4f}")
            logger.info(f"F1 Weighted: {f1_weighted:.4f}")

        # Summarise each metric across folds (mean and numpy's default population std) and
        # record one result row for this hyper-parameter set.
        _fold_metrics = (("accuracy", model_accuracy), ("recall", model_recall),
                         ("precision", model_precision), ("specificity", model_specificity),
                         ("auc_roc_macro", model_auc_roc_macro), ("auc_roc_micro", model_auc_roc_micro),
                         ("auc_roc_weighted", model_auc_roc_weighted), ("f1_macro", model_f1_macro),
                         ("f1_micro", model_f1_micro), ("f1_weighted", model_f1_weighted))
        for _metric_name, _metric_values in _fold_metrics:
            _metric_arr = np.array(_metric_values)
            logger.info(f"The mean and std of {_metric_name} are: {_metric_arr.mean()}, and {_metric_arr.std()}")

        _set_summary = {'model name': main_structure, 'hyperparams': hyper_params}
        for _metric_name, _metric_values in _fold_metrics:
            _set_summary[_metric_name + ' mean'] = np.array(_metric_values).mean()
            _set_summary[_metric_name + ' std'] = np.array(_metric_values).std()
        gird_search_result.append(_set_summary)

else:
    # Hyper-parameter search space. With a frozen backbone (finetune=False) only the head is
    # trained, so that sweep varies head regularisation instead of backbone lr / weight decay.
    param_grid = {
        'lr_backbone': [0.00002, 0.0001] if finetune else [0.00002],
        'lr_head': [0.0002, 0.001],
        'weight_decay_backbone': [0.01, 0.001] if finetune else [0.01],
        'weight_decay_head': [0.000] if finetune else [0.000, 0.001],
        'dropout_p': [0.3] if finetune else [0.3, 0.0],
        'warmup_epoch': [5],
        'lr_decay_epoch': [40],
        'unfrozen_blocks': [[0,1,2,3,4,5,6,7,8,9,10,11]],
        'grad_clip': [False],
        'label_smoothing': [0.0]
    }
    grid = list(ParameterGrid(param_grid))  # every hyper-parameter combination
    gird_search_result = []  # cross-validation summary of each hyper-parameter set
    for set_num, hyper_params in enumerate(grid):

        # Full run configuration: the grid-searched values plus settings that are not searched
        # (num_epochs). Key order is kept stable because the dict is logged right below.
        config = {key: hyper_params[key] for key in ('lr_backbone', 'lr_head', 'weight_decay_backbone',
                                                     'weight_decay_head', 'dropout_p')}
        config['num_epochs'] = 60
        config.update((key, hyper_params[key]) for key in ('warmup_epoch', 'lr_decay_epoch', 'unfrozen_blocks',
                                                           'grad_clip', 'label_smoothing'))

        logger.info(f"Running config {set_num + 1}/{len(grid)}: {config}")

        for model_name in baseline_model:
            for _banner_line in ("#####################################################################",
                                 f'Using model: {model_name}',
                                 "#####################################################################"):
                logger.info(_banner_line)

            for fold in range(fold_num):
                logger.info('###############################################')
                logger.info(f'This is the fold: {fold}')
                logger.info('###############################################')

                # Shuffle generator seeded from the configured `seed` (was a hard-coded 42), so a
                # fold's training order is identical across hyper-parameter sets and follows the
                # user's seed setting.
                g = torch.Generator()
                g.manual_seed(seed + fold)

                # Pinned memory only speeds up host->GPU copies; skip it on CPU-only machines
                # (where PyTorch ignores it and warns).
                use_pin_memory = device.type == 'cuda'

                # Load the train, validate, and test dataset, ensure the sequences of training samples in the same fold of different hyperparameter sets to keep same.
                train_loader = DataLoader(train_dataset[fold], batch_size=batch_size, shuffle=True, generator=g, pin_memory=use_pin_memory)
                val_loader = DataLoader(val_dataset[fold], batch_size=batch_size, shuffle=False, pin_memory=use_pin_memory)
                test_loader = DataLoader(test_dataset[fold], batch_size=batch_size, shuffle=False, pin_memory=use_pin_memory)


                # Set gpu/cpu
                logger.info(f"Use device: {device}")

                # Pretrained torchvision backbone; replace_classifier returns it with its
                # classifier swapped out plus the feature width used to size the new head.
                model = torch.hub.load('pytorch/vision', model_name, pretrained=True)

                model, num_features = replace_classifier(model)
                model = model.to(device)

                class ModifiedCNN(nn.Module):
                    """Backbone (classifier removed) followed by a Dropout + Linear classification head.

                    Args:
                        base_model: backbone returned by ``replace_classifier``; assumed to map an image
                            batch to a ``(batch, in_features)`` feature tensor — TODO confirm.
                        in_features: width of the backbone's feature vector.
                        dropout_p: dropout probability before the linear layer. ``None`` (default)
                            falls back to the grid-search value ``config['dropout_p']``.
                        num_classes: number of output logits; 2 for this binary task.
                    """

                    def __init__(self, base_model, in_features, dropout_p=None, num_classes=2):
                        super().__init__()
                        if dropout_p is None:
                            dropout_p = config['dropout_p']
                        self.base_model = base_model
                        # New trainable head replacing the backbone's original classifier
                        # (presumably left as an Identity by replace_classifier — verify).
                        self.head = nn.Sequential(nn.Dropout(dropout_p),
                                                  nn.Linear(in_features, num_classes))

                    def forward(self, x):
                        """Return class logits computed from the backbone features of ``x``."""
                        features = self.base_model(x)
                        return self.head(features)

                model = ModifiedCNN(model, num_features).to(device)

                # Fine-tuning trains every parameter. Otherwise freeze the backbone, train only the
                # head, and keep the backbone in eval mode (e.g. BatchNorm uses running statistics).
                for param in model.parameters():
                    param.requires_grad = bool(finetune)
                if not finetune:
                    for param in model.head.parameters():
                        param.requires_grad = True
                    model.base_model.eval()

                class EarlyStopping:
                    """Checkpoint the best model of a fold and decide when training should stop.

                    Supported ``mode`` values:
                      * ``'loss'``: improvement = val_loss drops by more than ``min_delta``.
                      * ``'accuracy'``: improvement = the score rises, or stays equal while val_loss drops
                        (``min_delta`` is not used in this mode).
                      * ``'loss and accuracy'``: val_loss drops by more than ``min_delta`` AND the score rises.
                      * ``'loss or accuracy'``: either of the two conditions above.

                    The "accuracy" passed to ``__call__`` is whatever score the caller supplies; the
                    training loop in this notebook passes the weighted F1 from ``validate``.

                    On the first call and on every improvement the model's ``state_dict`` is saved to
                    ``<name>_best_model_parameter_<hyper_index>_<fold>.pth``. ``early_stop`` becomes True
                    once ``patience`` consecutive epochs fail to improve, or when the final epoch
                    (``epoch_number``) fails to improve. ``counter`` is the number of non-improving
                    epochs since the last checkpoint.

                    Raises:
                        ValueError: if ``mode`` is not one of the supported values.
                    """

                    _MODES = ('loss', 'accuracy', 'loss and accuracy', 'loss or accuracy')

                    def __init__(self, patience=5, mode='loss', min_delta=0.0, fold=0, epoch_number = 10, name = None, hyper_index = 0):
                        if mode not in self._MODES:
                            # An unknown mode used to be silently ignored: no checkpoint was ever written
                            # and the evaluation stage only failed later on the missing .pth file.
                            raise ValueError(f"Unknown early stopping mode {mode!r}; expected one of {self._MODES}")
                        self.patience = patience
                        self.mode = mode
                        self.fold = fold
                        self.min_delta = min_delta
                        self.counter = 0
                        self.epoch_number = epoch_number
                        self.name = name
                        self.hyper_index = hyper_index
                        self.best = None  # float in 'loss' mode, [val_loss, score] in the other modes
                        self.early_stop = False

                    def _checkpoint_path(self):
                        # Same naming scheme the evaluation stage uses when reloading checkpoints.
                        return self.name + "_best_model_parameter_" + str(self.hyper_index) + "_" + str(self.fold) + ".pth"

                    def _record_best(self, val_loss, accuracy, net):
                        """Store the new best value(s) and checkpoint ``net``'s weights."""
                        self.best = val_loss if self.mode == 'loss' else [val_loss, accuracy]
                        torch.save(net.state_dict(), self._checkpoint_path())

                    def _best_message(self, still=False):
                        """Log line describing the current best value(s) for this mode."""
                        suffix = " still" if still else ""
                        if self.mode == 'loss':
                            return f"The best val_loss is{suffix}: {self.best}"
                        if self.mode == 'accuracy':
                            return f"The best accuracy / weighted f1 score is{suffix}: {self.best[1]}"
                        return f"The best val_loss and accuracy / weighted f1 score are{suffix}: {self.best[0]}, {self.best[1]}"

                    def _improvement_note(self, val_loss, accuracy):
                        """Return None if this epoch does not improve, else extra text for the pass log line."""
                        if self.mode == 'loss':
                            return "" if val_loss < self.best - self.min_delta else None
                        if self.mode == 'accuracy':
                            if accuracy > self.best[1]:
                                return ""
                            if accuracy == self.best[1] and val_loss < self.best[0]:
                                return " val_loss is smaller although the accuracy / weighted f1 score keeps the same."
                            return None
                        loss_better = val_loss < self.best[0] - self.min_delta
                        score_better = accuracy > self.best[1]
                        if self.mode == 'loss and accuracy':
                            return "" if (loss_better and score_better) else None
                        return "" if (loss_better or score_better) else None  # 'loss or accuracy'

                    def __call__(self, val_loss, accuracy, epoch, net=None):
                        """Update the stopping state with one epoch's validation results.

                        Args:
                            val_loss: validation loss of this epoch (lower is better).
                            accuracy: validation score of this epoch (higher is better).
                            epoch: 0-based epoch index.
                            net: model to checkpoint; defaults to the notebook-global ``model``.
                        """
                        if net is None:
                            net = model
                        logger.info("the mode of early stopping is " + self.mode)
                        if self.best is None:
                            logger.info(f"This is the first epoch {epoch + 1}!")
                            self._record_best(val_loss, accuracy, net)
                            logger.info(self._best_message())
                            return

                        note = self._improvement_note(val_loss, accuracy)
                        if note is not None:
                            logger.info(f"Epoch {epoch + 1} Early Stop Check Pass!{note}")
                            self._record_best(val_loss, accuracy, net)
                            self.counter = 0
                            logger.info(self._best_message())
                        else:
                            self.counter += 1
                            logger.info(
                                f"Epoch {epoch + 1} not pass the Early Stopping! EarlyStopping counter: {self.counter} / {self.patience}")
                            logger.info(self._best_message(still=True))
                            if self.counter >= self.patience or (epoch + 1) == self.epoch_number:
                                self.early_stop = True

                num_epochs = config['num_epochs']
                warm_up_epochs = config['warmup_epoch']
                steps_per_epoch = len(train_loader)
                # The LR schedule spans warm-up + decay epochs, counted in optimizer steps (batches).
                warm_up_steps = warm_up_epochs * steps_per_epoch
                total_steps = (warm_up_epochs + config['lr_decay_epoch']) * steps_per_epoch
                # Losses: label smoothing only for training; validation sums per-sample losses so
                # they can be averaged over all validation images afterwards.
                train_criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])
                val_criterion = nn.CrossEntropyLoss(reduction="sum")
                # Two AdamW parameter groups, each with its own lr / weight decay:
                # trainable backbone weights, then the classification head.
                backbone_group = {'params': filter(lambda p: p.requires_grad, model.base_model.parameters()),
                                  'weight_decay': config['weight_decay_backbone'],
                                  'lr': config['lr_backbone']}
                head_group = {'params': model.head.parameters(),
                              'weight_decay': config['weight_decay_head'],
                              'lr': config['lr_head']}
                optimizer = optim.AdamW([backbone_group, head_group])
                # warm-up first, then cosine decay to 0.1 of the initial lr
                scheduler = get_cosine_with_warmup_tail(optimizer, num_warmup_steps=warm_up_steps, num_training_steps=total_steps, min_lr_factor=0.1)


                # train func
                def train(model, loader, criterion, optimizer, device, epoch, train_dataset):
                    """Run one training epoch and return the average loss per training image.

                    For every batch: forward, backward, optional gradient clipping
                    (``config['grad_clip']``), optimizer step, per-batch loss/accuracy logging, and one
                    ``scheduler.step()`` (the LR schedule advances per batch, not per epoch).

                    Args:
                        model: wrapper with ``base_model``; when ``finetune`` is off the backbone stays in
                            eval mode while the rest trains.
                        loader: iterable of ``(inputs, targets)`` batches.
                        criterion: training loss; assumed to return a per-batch mean (it is re-weighted
                            by the batch size below).
                        optimizer: optimizer stepped once per batch.
                        device: device the batches are moved to.
                        epoch: 0-based epoch index, used only for logging.
                        train_dataset: only its length is used, as the averaging denominator.

                    Returns:
                        Sum over batches of ``loss * batch_size`` divided by ``len(train_dataset)``.
                    """
                    model.train()
                    if not finetune:
                        model.base_model.eval()

                    loss_sum = 0.0
                    for batch_idx, (inputs, targets) in enumerate(loader):
                        if batch_idx == 0:
                            logger.info(f"First train batch labels: {targets}")
                        inputs = inputs.to(device)
                        targets = targets.to(device)

                        optimizer.zero_grad()
                        logits = model(inputs)
                        batch_loss = criterion(logits, targets)
                        batch_loss.backward()
                        if config['grad_clip']:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()

                        n_in_batch = targets.size(0)
                        loss_sum += batch_loss.item() * n_in_batch
                        logger.info("In epoch " + str(epoch + 1) + ", batch: " + str(batch_idx + 1) + ", average loss per image: " + str(batch_loss.item()))

                        hits = (logits.max(1)[1] == targets).sum().item()
                        logger.info("Accuracy of the network on the train set: " + str(hits / n_in_batch))

                        scheduler.step()

                    return loss_sum / len(train_dataset)


                # validate func
                def validate(model, loader, criterion, device, val_dataset):
                    """Evaluate ``model`` on ``loader`` without gradient tracking.

                    Args:
                        model: network to evaluate (switched to eval mode here).
                        loader: iterable of ``(inputs, targets)`` batches; it must not shuffle if the
                            returned indices are to map back to dataset positions.
                        criterion: loss built with ``reduction="sum"`` (as ``val_criterion`` is), so that
                            dividing by the sample count gives a per-sample mean.
                        device: device the batches are moved to.
                        val_dataset: unused; kept for call-site compatibility.

                    Returns:
                        ``(accuracy, wrong_indices, mean_loss_per_sample, weighted_f1)`` where
                        ``wrong_indices`` are positions (in loader order) of misclassified samples.

                    Raises:
                        ValueError: if ``loader`` yields no samples.
                    """
                    model.eval()
                    correct = 0
                    total = 0
                    false = []
                    running_loss = 0.0
                    y_true = []
                    y_pred = []
                    with torch.no_grad():
                        for inputs, targets in loader:
                            inputs, targets = inputs.to(device), targets.to(device)
                            outputs = model(inputs)
                            _, predicted = outputs.max(1)  # model predicted class
                            loss = criterion(outputs, targets)
                            running_loss += loss.item()  # summed loss of this batch

                            # Offset by the number of samples already seen instead of
                            # i * batch_size: correct for any loader batch size, not only the
                            # notebook-global one.
                            mismatched = (predicted != targets).nonzero(as_tuple=True)[0]
                            false.extend((mismatched + total).tolist())

                            correct += (predicted == targets).sum().item()
                            total += targets.size(0)
                            y_true.extend(targets.tolist())
                            y_pred.extend(predicted.tolist())

                    if total == 0:
                        raise ValueError("validate() received a loader with no samples")
                    weighted_f1 = f1_score(np.array(y_true), np.array(y_pred), average='weighted')
                    return correct / total, false, running_loss / total, weighted_f1


                # train and validate
                # Train / validate this fold for up to num_epochs; EarlyStopping checkpoints the best epoch.
                early_stopping = EarlyStopping(patience=10, mode=early_stop_mode, fold=fold, epoch_number = num_epochs, name = model_name, hyper_index = set_num)
                wrong_number = []  # wrong_number[e]: misclassified validation indices at epoch e

                training_loss_list = []
                validating_loss_list = []
                for epoch in range(num_epochs):
                    logger.info(f"The learning rate of backbone is: {optimizer.param_groups[0]['lr']}, of head is {optimizer.param_groups[1]['lr']}")
                    train_loss = train(model, train_loader, train_criterion, optimizer, device, epoch, train_dataset[fold])
                    accuracy, wrong_predicted, val_loss, weighted_f1 = validate(model, val_loader, val_criterion, device, val_dataset[fold])
                    wrong_number.append(wrong_predicted)
                    logger.info("====================" + str(epoch + 1) + "====================")
                    logger.info(f'Epoch {epoch + 1}/{num_epochs}')
                    logger.info(f'Average train loss per image: {train_loss:.7f}')
                    logger.info(f'Average validate loss per image: {val_loss:.7f}')
                    logger.info(f'Validate accuracy: {accuracy:.4f}')
                    logger.info("====================" + str(epoch + 1) + "====================")

                    training_loss_list.append(train_loss)
                    validating_loss_list.append(val_loss)

                    # The weighted F1 (not the plain accuracy) is the score EarlyStopping tracks.
                    early_stopping(val_loss, weighted_f1, epoch)

                    if early_stopping.early_stop:
                        # counter = non-improving epochs since the checkpointed (best) epoch.
                        best_epoch = epoch - early_stopping.counter
                        logger.info(" 🔥 Early stopping, Stop Training")
                        logger.info(f"select the epoch: {best_epoch + 1}")
                        for wrong_result in wrong_number[best_epoch]:
                            logger.info("The number " + str(wrong_result) + " is wrong!")
                        break

                    if epoch == num_epochs - 1:
                        logger.info("train until the last epoch!")
                        for wrong_result in wrong_predicted:
                            logger.info("The number " + str(wrong_result) + " is wrong!")

                plot_loss_curves(training_loss_list, validating_loss_list, fold_number = fold, hyper_setnum=set_num, model_type = model_name)

            # After all folds of this backbone: inspect (presumably log) the model and optimizer
            # left over from the last fold trained in the loop above.
            inspect_model_and_optimizer(model, optimizer, logger)

        for model_name in baseline_model:
            for _banner_line in ("#####################################################################",
                                 f'Using model: {model_name}',
                                 "#####################################################################"):
                logger.info(_banner_line)

            # One list per metric; each receives one value per cross-validation fold.
            (model_accuracy, model_recall, model_precision, model_specificity,
             model_auc_roc_macro, model_auc_roc_micro, model_auc_roc_weighted,
             model_f1_macro, model_f1_micro, model_f1_weighted) = ([] for _ in range(10))

            # Rebuild the same backbone as in training so the saved checkpoints fit it.
            model = torch.hub.load('pytorch/vision', model_name, pretrained=True)
            model, num_features = replace_classifier(model)
            model = model.to(device)

            class ModifiedCnnVal(nn.Module):
                """Evaluation-time copy of the training wrapper: backbone + Dropout/Linear head.

                Attribute names (``base_model``, ``head``) and layer layout match the training
                wrapper, so its saved ``state_dict`` checkpoints load into this model.

                Args:
                    base_model: backbone returned by ``replace_classifier``; assumed to map an image
                        batch to a ``(batch, in_features)`` feature tensor — TODO confirm.
                    in_features: width of the backbone's feature vector.
                    dropout_p: dropout probability before the linear layer. ``None`` (default) falls
                        back to ``config['dropout_p']``; it has no effect in eval mode.
                    num_classes: number of output logits; 2 for this binary task.
                """

                def __init__(self, base_model, in_features, dropout_p=None, num_classes=2):
                    super().__init__()
                    if dropout_p is None:
                        dropout_p = config['dropout_p']
                    self.base_model = base_model
                    # Head replacing the backbone's original classifier (presumably left as an
                    # Identity by replace_classifier — verify).
                    self.head = nn.Sequential(nn.Dropout(dropout_p),
                                              nn.Linear(in_features, num_classes))

                def forward(self, x):
                    """Return class logits computed from the backbone features of ``x``."""
                    features = self.base_model(x)
                    return self.head(features)

            # Evaluation wrapper; each fold's (strictly loaded) checkpoint overwrites all its weights.
            model = ModifiedCnnVal(model, num_features).to(device)


            # Reload each fold's best checkpoint and score it on that fold's validation split.
            # Class 1 is treated as the positive class throughout (softmax[:, 1], default pos_label=1).
            for fold in range(fold_num):
                checkpoint_path = model_name + "_best_model_parameter_" + str(set_num) + "_" + str(fold) + ".pth"
                # map_location lets a checkpoint saved on another device load on the current one.
                model.load_state_dict(torch.load(checkpoint_path, map_location=device))
                model.eval()
                y_true = []
                y_pred = []
                y_prob = []

                # Inference only: no_grad avoids building an autograd graph for every sample.
                with torch.no_grad():
                    for data in range(len(val_dataset[fold])):
                        # Index the dataset once per sample; indexing it twice re-loads (and
                        # re-transforms) the image just to read its label.
                        sample = val_dataset[fold][data]
                        y_true.append(sample[1])

                        input_tensor = sample[0].unsqueeze(0).to(device)
                        outputs = model(input_tensor)
                        y_pred.append(outputs.argmax(dim=1).item())
                        # probability assigned to the positive class (index 1)
                        y_prob.append(torch.softmax(outputs, dim=1)[:, 1].item())

                y_true = np.array(y_true)
                y_pred = np.array(y_pred)
                y_prob = np.array(y_prob)

                fpr, tpr, thresholds = roc_curve(y_true, y_prob)
                roc_auc = auc(fpr, tpr)

                fig, ax = plt.subplots()
                ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
                ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
                ax.set_xlim([0.0, 1.0])
                ax.set_ylim([0.0, 1.05])
                ax.set_xlabel('False Positive Rate')
                ax.set_ylabel('True Positive Rate')
                ax.set_title(f'Receiver Operating Characteristic ({model_name}, fold {fold})')
                ax.legend(loc="lower right")
                plt.show()
                plt.close(fig)  # one figure per fold, model and hyper-parameter set; release it

                recall = recall_score(y_true, y_pred)
                precision = precision_score(y_true, y_pred)
                # For a binary target roc_auc_score ignores `average`, so these three are the same
                # number; the separate names are kept for the summary table.
                auc_roc_macro = auc_roc_micro = auc_roc_weighted = roc_auc_score(y_true, y_prob)
                f1_macro = f1_score(y_true, y_pred, average='macro')
                f1_micro = f1_score(y_true, y_pred, average='micro')
                f1_weighted = f1_score(y_true, y_pred, average='weighted')

                accuracy = float(np.mean(y_pred == y_true))
                # labels=[0, 1] keeps the matrix 2x2 even when a fold's labels/predictions contain a
                # single class; otherwise the 4-way unpack raises ValueError.
                tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
                specificity = tn / (tn + fp) if (tn + fp) > 0 else float('nan')
                model_accuracy.append(accuracy)
                model_recall.append(recall)
                model_precision.append(precision)
                model_specificity.append(specificity)
                model_auc_roc_macro.append(auc_roc_macro)
                model_auc_roc_micro.append(auc_roc_micro)
                model_auc_roc_weighted.append(auc_roc_weighted)
                model_f1_macro.append(f1_macro)
                model_f1_micro.append(f1_micro)
                model_f1_weighted.append(f1_weighted)

                logger.info(f"Accuracy: {accuracy:.4f}")
                logger.info(f"Recall: {recall:.4f}")
                logger.info(f"Precision: {precision:.4f}")
                logger.info(f"Specificity: {specificity:.4f}")
                logger.info(f"AUC-ROC Macro: {auc_roc_macro:.4f}")
                logger.info(f"AUC-ROC Micro: {auc_roc_micro:.4f}")
                logger.info(f"AUC-ROC Weighted: {auc_roc_weighted:.4f}")
                logger.info(f"F1 Macro: {f1_macro:.4f}")
                logger.info(f"F1 Micro: {f1_micro:.4f}")
                logger.info(f"F1 Weighted: {f1_weighted:.4f}")

            logger.info(f"The mean and std of accuracy are: {np.array(model_accuracy).mean()}, and {np.array(model_accuracy).std()}")
            logger.info(f"The mean and std of recall are: {np.array(model_recall).mean()}, and {np.array(model_recall).std()}")
            logger.info(f"The mean and std of precision are: {np.array(model_precision).mean()}, and {np.array(model_precision).std()}")
            logger.info(f"The mean and std of specificity are: {np.array(model_specificity).mean()}, and {np.array(model_specificity).std()}")
            logger.info(f"The mean and std of auc_roc_macro are: {np.array(model_auc_roc_macro).mean()}, and {np.array(model_auc_roc_macro).std()}")
            logger.info(f"The mean and std of auc_roc_micro are: {np.array(model_auc_roc_micro).mean()}, and {np.array(model_auc_roc_micro).std()}")
            logger.info(f"The mean and std of auc_roc_weighted are: {np.array(model_auc_roc_weighted).mean()}, and {np.array(model_auc_roc_weighted).std()}")
            logger.info(f"The mean and std of f1_macro are: {np.array(model_f1_macro).mean()}, and {np.array(model_f1_macro).std()}")
            logger.info(f"The mean and std of f1_micro are: {np.array(model_f1_micro).mean()}, and {np.array(model_f1_micro).std()}")
            logger.info(f"The mean and std of f1_weighted are: {np.array(model_f1_weighted).mean()}, and {np.array(model_f1_weighted).std()}")

            gird_search_result.append({'model name': model_name,
                                   'hyperparams': hyper_params,
                                   'accuracy mean':np.array(model_accuracy).mean(),
                                   'accuracy std': np.array(model_accuracy).std(),
                                   'recall mean': np.array(model_recall).mean(),
                                   'recall std': np.array(model_recall).std(),
                                   'precision mean': np.array(model_precision).mean(),
                                   'precision std': np.array(model_precision).std(),
                                   "specificity mean": np.array(model_specificity).mean(),
                                   "specificity std": np.array(model_specificity).std(),
                                   'auc_roc_macro mean': np.array(model_auc_roc_macro).mean(),
                                   'auc_roc_macro std': np.array(model_auc_roc_macro).std(),
                                   'auc_roc_micro mean': np.array(model_auc_roc_micro).mean(),
                                   'auc_roc_micro std': np.array(model_auc_roc_micro).std(),
                                   'auc_roc_weighted mean': np.array(model_auc_roc_weighted).mean(),
                                   'auc_roc_weighted std': np.array(model_auc_roc_weighted).std(),
                                   'f1_macro mean': np.array(model_f1_macro).mean(),
                                   'f1_macro std': np.array(model_f1_macro).std(),
                                   'f1_micro mean': np.array(model_f1_micro).mean(),
                                   'f1_micro std': np.array(model_f1_micro).std(),
                                   'f1_weighted mean': np.array(model_f1_weighted).mean(),
                                   'f1_weighted std': np.array(model_f1_weighted).std()})

# Print the model architecture

In [None]:
# Log the architecture of whichever `model` is currently in the kernel (presumably the last one
# built by the grid-search cell above) and switch it to inference mode.
logger.info(model)
model.eval()

# Print the validation performance of each hyper-parameter set

In [None]:
# Each grid-search entry is a summary dict for one hyper-parameter set; emit it as one
# "key: value" log line per field, entries in list order.
for field_name, field_value in (item for entry in gird_search_result for item in entry.items()):
    logger.info(f"{field_name}: {field_value}")

# After selecting the best hyper-parameter set (from the validation-set results of the grid search), we use its 5 fold models to predict the test set and report the mean and std across folds of the performance metrics (accuracy, recall, precision, etc.).

In [None]:
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
from torch import nn

# (metric key used in the per-fold lists and the mean/std summary log, label used in the per-fold log)
_METRIC_LABELS = [
    ('accuracy', 'Accuracy'),
    ('recall', 'Recall'),
    ('precision', 'Precision'),
    ('specificity', 'Specificity'),
    ('auc_roc_macro', 'AUC-ROC Macro'),
    ('auc_roc_micro', 'AUC-ROC Micro'),
    ('auc_roc_weighted', 'AUC-ROC Weighted'),
    ('f1_macro', 'F1 Macro'),
    ('f1_micro', 'F1 Micro'),
    ('f1_weighted', 'F1 Weighted'),
]


class ModifiedTestViT(nn.Module):
    """Backbone followed by a Dropout + 2-way Linear classification head.

    The attribute names ``base_model`` / ``head`` and the head's layer order define the
    state_dict keys, so they must match the checkpoints loaded below.
    """

    def __init__(self, base_model, in_features):
        super(ModifiedTestViT, self).__init__()
        self.base_model = base_model
        # `dropout` is read from the notebook's global config at construction time
        self.head = nn.Sequential(nn.Dropout(dropout),
                                  nn.Linear(in_features, 2))

    def forward(self, x):
        # features from the backbone (e.g. (batch size, 768) class-token features for ViT-B)
        features = self.base_model(x)
        # pass the classification head
        return self.head(features)


class ModifiedCnnTest(ModifiedTestViT):
    """Same backbone + head wrapper, used for the torchvision CNN baselines."""


def _load_vit_backbone(structure):
    """Load the backbone named by `structure` and return (backbone, feature_dim).

    Raises:
        ValueError: for an unrecognized name, instead of silently reusing a stale `model`
            left in the kernel by an earlier cell.
    """
    name = structure.lower()
    if name == 'dino':
        backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16', pretrained=True)
        return backbone, backbone.embed_dim
    if name == 'dinov2':
        backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', pretrained=True)
        return backbone, backbone.embed_dim
    if name == 'dinov3':
        backbone = torch.hub.load(REPO_DINOV3, 'dinov3_vitb16', source='local', weights=WEIGHTS_DINOV3)
        return backbone, backbone.embed_dim
    if name == 'vit':
        backbone = torch.hub.load('pytorch/vision', 'vit_b_16', pretrained=True)
        return replace_classifier(backbone)
    raise ValueError(f"Unsupported main_structure {structure!r}; expected 'dino', 'dinov2', 'dinov3' or 'vit'")


def _predict_dataset(eval_model, dataset):
    """Score every sample of `dataset` one at a time.

    Each item is indexed as item[0] = input image tensor, item[1] = label.
    Returns (y_true, y_pred, y_prob) numpy arrays; y_prob is the softmax probability of class 1.
    """
    eval_model.eval()
    y_true, y_pred, y_prob = [], [], []
    # Inference only: no_grad avoids building an autograd graph for every test image.
    with torch.no_grad():
        for idx in range(len(dataset)):
            sample = dataset[idx]  # fetch once so image and label come from the same __getitem__ call
            outputs = eval_model(sample[0].unsqueeze(0).to(device))
            y_true.append(sample[1])
            y_pred.append(outputs.argmax(dim=1).item())
            y_prob.append(torch.softmax(outputs, dim=1)[:, 1].item())
    return np.array(y_true), np.array(y_pred), np.array(y_prob)


def _plot_roc(y_true, y_prob):
    """Show the ROC curve (with its AUC in the legend) for one fold."""
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic')
    ax.legend(loc="lower right")
    plt.show()
    plt.close(fig)


def _compute_metrics(y_true, y_pred, y_prob):
    """Return one fold's binary-classification metrics as a dict keyed like `_METRIC_LABELS`.

    Class 1 is the positive class for recall / precision; specificity is the class-0 recall.
    """
    # labels=[0, 1] pins the matrix to 2x2 regardless of which classes appear, so the unpack is stable.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    # With binary labels and 1-D scores, roc_auc_score ignores `average`, so the macro / micro /
    # weighted entries are the same number; all three are kept to preserve the report format.
    auc_roc = roc_auc_score(y_true, y_prob)
    return {
        'accuracy': float(np.mean(y_pred == y_true)),
        'recall': recall_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred),
        'specificity': tn / (tn + fp) if (tn + fp) > 0 else float('nan'),
        'auc_roc_macro': auc_roc,
        'auc_roc_micro': auc_roc,
        'auc_roc_weighted': auc_roc,
        'f1_macro': f1_score(y_true, y_pred, average='macro'),
        'f1_micro': f1_score(y_true, y_pred, average='micro'),
        'f1_weighted': f1_score(y_true, y_pred, average='weighted'),
    }


def _evaluate_folds(eval_model, checkpoint_prefix, set_idx):
    """Evaluate each fold's checkpoint of `eval_model` on that fold's test split.

    Loads "<checkpoint_prefix>_best_model_parameter_<set_idx>_<fold>.pth" for every fold in
    range(fold_num), shows the ROC curve, logs per-fold metrics and then their mean/std.
    Returns a tuple of per-fold metric lists ordered like `_METRIC_LABELS`.
    """
    per_fold = {key: [] for key, _ in _METRIC_LABELS}
    for fold in range(fold_num):
        checkpoint = checkpoint_prefix + "_best_model_parameter_" + str(set_idx) + "_" + str(fold) + ".pth"
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
        eval_model.load_state_dict(torch.load(checkpoint, map_location=device))
        y_true, y_pred, y_prob = _predict_dataset(eval_model, test_dataset[fold])
        _plot_roc(y_true, y_prob)
        metrics = _compute_metrics(y_true, y_pred, y_prob)
        for key, label in _METRIC_LABELS:
            per_fold[key].append(metrics[key])
            logger.info(f"{label}: {metrics[key]:.4f}")

    for key, _ in _METRIC_LABELS:
        values = np.array(per_fold[key])
        logger.info(f"The mean and std of {key} are: {values.mean()}, and {values.std()}")
    return tuple(per_fold[key] for key, _ in _METRIC_LABELS)


if not baseline:

    logger.info("#####################################################################")
    logger.info(f'Using model: {main_structure}')
    logger.info("#####################################################################")

    set_num = 0 # users can choose the best hyper-parameter set based on the grid search
    dropout = 0.3 # in all-layer finetune, this can use the default value 0.3, but in linear probe, the dropout value of the best hyperparameter set should be entered by users. 0.3 or 0.0
    logger.info("#####################################################################")
    logger.info(f'Using hyper-parameter set: {set_num}, dropout: {dropout}')
    logger.info("#####################################################################")

    backbone, num_features = _load_vit_backbone(main_structure)
    model = ModifiedTestViT(backbone.to(device), num_features).to(device)

    (model_accuracy, model_recall, model_precision, model_specificity,
     model_auc_roc_macro, model_auc_roc_micro, model_auc_roc_weighted,
     model_f1_macro, model_f1_micro, model_f1_weighted) = _evaluate_folds(model, main_structure, set_num)
else:
    set_num = 0 # users can choose the best hyper-parameter set based on the grid search
    dropout = 0.3 # in all-layer finetune, this can use the default value 0.3, but in linear probe, the dropout value of the best hyperparameter set should be entered by users. 0.3 or 0.0
    logger.info("#####################################################################")
    logger.info(f'Using hyper-parameter set: {set_num}, dropout: {dropout}')
    logger.info("#####################################################################")

    for model_name in baseline_model:
        logger.info("#####################################################################")
        logger.info(f'Using model: {model_name}')
        logger.info("#####################################################################")

        backbone = torch.hub.load('pytorch/vision', model_name, pretrained=True)
        backbone, num_features = replace_classifier(backbone)
        model = ModifiedCnnTest(backbone.to(device), num_features).to(device)

        (model_accuracy, model_recall, model_precision, model_specificity,
         model_auc_roc_macro, model_auc_roc_micro, model_auc_roc_weighted,
         model_f1_macro, model_f1_micro, model_f1_weighted) = _evaluate_folds(model, model_name, set_num)

# Model explanation

In [None]:
from torchcam.utils import overlay_mask
from torchvision.transforms.v2.functional import to_pil_image
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, XGradCAM, EigenGradCAM, GradCAMElementWise
import math
from torch import nn
import cv2
from pytorch_grad_cam.utils.image import show_cam_on_image, \
    preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# XAI is run only for the ViT-family backbone selected by `main_structure` (not the CNN baselines).
logger.info("#####################################################################")
logger.info(f'Using model: {main_structure}')
logger.info("#####################################################################")

xai_method = "HiResCAM" # user input: the XAI method
set_num = 3 # users can choose the best hyper-parameter set based on the grid search
dropout = 0.3 # in all-layer finetune, this can use the default value 0.3, but in linear probe, the dropout value of the best hyperparameter set should be entered by users. 0.3 or 0.0
fold_idx = 1 # users can choose the best fold index within the best hyper-parameter set based on the performance of test set.
logger.info("#####################################################################")
logger.info(f'Using hyper-parameter set: {set_num}, fold: {fold_idx}, dropout: {dropout}, XAI method: {xai_method}')
logger.info("#####################################################################")


# Load the pretrained backbone. An unrecognized name must fail loudly rather than fall through
# and silently reuse whatever `model` an earlier cell left in the kernel.
if main_structure.lower() == 'dino':
    model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16', pretrained=True)
    num_features = model.embed_dim
elif main_structure.lower() == 'dinov2':
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14', pretrained=True)
    num_features = model.embed_dim
elif main_structure.lower() == 'dinov3':
    model = torch.hub.load(REPO_DINOV3, 'dinov3_vitb16', source='local', weights=WEIGHTS_DINOV3)
    num_features = model.embed_dim
elif main_structure.lower() == 'vit':
    model = torch.hub.load('pytorch/vision', 'vit_b_16', pretrained=True)
    model, num_features = replace_classifier(model)
else:
    raise ValueError(f"Unsupported main_structure {main_structure!r}; expected 'dino', 'dinov2', 'dinov3' or 'vit'")
model = model.to(device)


class ModifiedTestViT(nn.Module):
    """Wrap a feature-extracting backbone with a Dropout -> Linear(in_features, 2) head.

    `base_model` and `head` (Dropout at index 0, Linear at index 1) are the state_dict key
    prefixes, so they must stay exactly as named to load the saved checkpoints. The dropout
    probability comes from the notebook-level `dropout` setting when the wrapper is built.
    """

    def __init__(self, base_model, in_features):
        super().__init__()
        self.base_model = base_model
        classifier_layers = [nn.Dropout(dropout), nn.Linear(in_features, 2)]
        self.head = nn.Sequential(*classifier_layers)

    def forward(self, x):
        # backbone features (e.g. the 768-d class token for ViT-B) -> 2 class logits
        return self.head(self.base_model(x))


model = ModifiedTestViT(model, num_features).to(device)
# Select the checkpoint and the figure output folder for the training regime being explained.
# NOTE(review): these checkpoint names ("..._all_layer_finetune_without_blue_..." / "..._linear_probe_without_blue_...")
# differ from the "<structure>_best_model_parameter_<set>_<fold>.pth" names loaded by the test-set
# evaluation cell — confirm both match what the training cell actually saved.
if finetune:
    logger.info("here is the XAI for all-layer finetune, without nucleus") # change
    checkpoint_path = main_structure + "_all_layer_finetune_without_blue_best_model_parameter_" + str(set_num) + "_" + str(fold_idx) + ".pth" # change
    save_path = os.path.join(grad_cam_base_path, "without_nucleus_all_layer_finetune_" + main_structure + "_" + str(set_num) + "_" + str(fold_idx) + "_" + xai_method) # change
else:
    logger.info("here is the XAI for linear probe, without nucleus") # change
    checkpoint_path = main_structure + "_linear_probe_without_blue_best_model_parameter_" + str(set_num) + "_" + str(fold_idx) + ".pth" # change
    save_path = os.path.join(grad_cam_base_path, "without_nucleus_linear_probe_" + main_structure + "_" + str(set_num) + "_" + str(fold_idx) + "_" + xai_method) # change
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

os.makedirs(save_path, exist_ok=True)

# CAM target: the `norm1` layer of the last transformer block.
# NOTE(review): assumes the backbone exposes a `.blocks` list (as DINO-style ViTs do); the torchvision
# vit_b_16 used for main_structure='vit' presumably names its layers differently — verify.
last_block = model.base_model.blocks[-1]
penultimate_block = model.base_model.blocks[-2]
# target_layers = [penultimate_block.norm1, penultimate_block.norm2, last_block.norm1]
target_layers = [last_block.norm1]

def reshape_transform(tensor):
    """Turn ViT token activations into a CNN-style feature map for pytorch-grad-cam.

    Args:
        tensor: activations of shape (batch, num_tokens, dim). The first token is [CLS];
            DINOv3 additionally places 4 register tokens after it; the remaining tokens
            must form a square grid of patch tokens.

    Returns:
        Tensor of shape (batch, dim, height, width) with height == width.

    Raises:
        AssertionError: if neither num_tokens - 1 nor num_tokens - 5 is a perfect square.
    """
    num_tokens = tensor.size(1)
    # Detect the layout from the token count itself instead of hard-coding 196/256 patches,
    # so other input resolutions (e.g. a 24x24 grid) are not misrouted to the register branch.
    side = math.isqrt(num_tokens - 1)
    if side * side == num_tokens - 1:
        prefix_tokens = 1  # [CLS] only
    else:
        logger.info("here is using DINOv3, remove the register tokens!")
        prefix_tokens = 5  # remove 1 [CLS] token and 4 register tokens
        side = math.isqrt(num_tokens - prefix_tokens)
        assert side * side == num_tokens - prefix_tokens, "the square root of input tensor is not integer!"
    # (batch, side, side, dim), e.g. (1,14,14,768) or (1,16,16,768)
    result = tensor[:, prefix_tokens:, :].reshape(tensor.size(0), side, side, tensor.size(2))
    # put the channel in the second dimension
    return result.permute(0, 3, 1, 2)

# Map the configured method name to its pytorch-grad-cam class so `xai_method` actually selects the
# algorithm (it is also used in the output folder and figure titles, so they must agree).
cam_classes = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "GradCAMPlusPlus": GradCAMPlusPlus,
    "XGradCAM": XGradCAM,
    "EigenGradCAM": EigenGradCAM,
    "GradCAMElementWise": GradCAMElementWise,
}
if xai_method not in cam_classes:
    raise ValueError(f"Unknown xai_method {xai_method!r}; expected one of {sorted(cam_classes)}")

cam = cam_classes[xai_method](model=model, target_layers=target_layers, reshape_transform=reshape_transform)
cam.batch_size = 32

# Explain the test split that belongs to the selected fold's model; the evaluation cell pairs
# fold k's checkpoint with test_dataset[k].
xai_dataset = test_dataset[fold_idx]

for data in range(len(xai_dataset)):
    sample = xai_dataset[data]  # fetch once: image and label come from the same __getitem__ call
    inputs, target_class = sample[0], sample[1]

    input_tensor = inputs.unsqueeze(0).to(device)
    input_image = to_pil_image(transform_inverse(input_tensor.squeeze(0)))

    # targets=None: pytorch-grad-cam explains the highest-scoring (predicted) class
    activation_map = cam(input_tensor=input_tensor, targets=None, eigen_smooth=False)
    activation_map = activation_map[0, :]

    overlay = overlay_mask(input_image, to_pil_image(activation_map, mode='F'), alpha=0.5)

    outputs = cam.outputs
    predicted_class = outputs.argmax(dim=1).item()
    logger.info(f"the image index is: {data}, the output is: {outputs}, the predicted and ground truth class are: {predicted_class} and {target_class}")

    # side-by-side figure: original image and CAM overlay
    fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(10, 5))
    ax_orig.set_title('Original Image')
    ax_orig.imshow(input_image)
    ax_orig.axis('off')
    ax_cam.set_title(xai_method)
    ax_cam.imshow(overlay)
    ax_cam.axis('off')

    image_save_path = os.path.join(save_path, str(data) + "_" + str(target_class) + "_XAI.png")
    fig.savefig(image_save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    # overlay-only figure
    fig, ax_cam = plt.subplots(figsize=(5, 5))
    ax_cam.set_title(xai_method)
    ax_cam.imshow(overlay)
    ax_cam.axis('off')

    image_save_path = os.path.join(save_path, str(data) + "_" + str(target_class) + "_" + xai_method + ".png")
    fig.savefig(image_save_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

In [None]:
# End of the notebook run: flush and close all logging handlers so the log file is complete.
logging.shutdown()