In [None]:
from sklearn.metrics import recall_score, precision_score, roc_auc_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
import os
import random
import torch
import torchvision
import torchvision.transforms.v2 as v2
from torch.utils.data import DataLoader
from torchvision import utils
from torch.utils.data import random_split
import pytorch_grad_cam
import torch.hub as hub
from torchvision.transforms.v2 import functional as F
from sklearn.model_selection import StratifiedKFold

# Compute device: the default CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Reproducibility: seed the Python, NumPy and PyTorch RNGs (CPU and all CUDA devices).
seed = 42  # set by user
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

channels = 3  # number of image channels
batch_size = 32  # set by user

# Whether to use a baseline model instead: VGG-16, DenseNet-121, ResNet-50,
# EfficientNetV2-S or ConvNeXt-Base (identifiers listed below).
baseline = False
baseline_model = ['densenet121', 'efficientnet_v2_s', 'resnet50', 'vgg16', 'convnext_base']

# Early-stopping criterion: 'loss', 'accuracy', 'loss or accuracy' or 'loss and accuracy'.
early_stop_mode = 'accuracy'

# Number of augmented variations to generate per image, for class label 0 and class label 1.
num_variations_per_image_0 = 4
num_variations_per_image_1 = 4

test_percent = 0.15  # fraction of all images held out as the test set
validation_percent = 0.1  # validation fraction, intended for runs with cross-validation disabled

# Cross-validation switch; with it disabled the pipeline runs a single "fold".
cross_validation = True
fold_num = 5 if cross_validation else 1  # edit 5 to change the number of CV folds

run_name = 'cell_classification_with_nucleus'  # log name
finetune = True  # whether to fine-tune the features from the backbone
main_structure = 'dino'  # one of 'dino', 'vit', 'dinov2', 'dinov3'

# Hard-coded Windows locations of the input data, the per-split output folders
# and the DINOv3 repository/weights.
# Every backslash is escaped explicitly. The previous literals relied on invalid
# escape sequences ("\c", "\s", "\w", "\d", "\M"), which Python currently keeps
# verbatim but warns about (SyntaxWarning since 3.12) and plans to reject.
# The resulting string values are identical to before.
# Paths that end with a separator are joined to file names by plain string
# concatenation (e.g. `test_autoseg_cancer_path + name`), so keep the trailing "\\".
grad_cam_base_path = "D:\\cell_image_XAI\\cell_40x"

dataset_autoseg_path = "D:\\cell_40x"
train_autoseg_path = "D:\\cell_autoseg_train\\split"
train_autoseg_cancer_path = 'D:\\cell_autoseg_train\\split\\cancer\\'
train_autoseg_normal_path = 'D:\\cell_autoseg_train\\split\\normal\\'
train_autoseg_cv_path = 'D:\\cell_autoseg_train\\cross validation\\'
test_autoseg_path = 'D:\\cell_autoseg_test\\split'
test_whole_autoseg_path = 'D:\\cell_autoseg_test\\whole\\'
test_autoseg_cancer_path = "D:\\cell_autoseg_test\\split\\cancer\\"
test_autoseg_normal_path = 'D:\\cell_autoseg_test\\split\\normal\\'
validation_autoseg_path = 'D:\\cell_autoseg_validation\\split'
validation_whole_autoseg_path = 'D:\\cell_autoseg_validation\\whole\\'
validation_autoseg_cancer_path = 'D:\\cell_autoseg_validation\\split\\cancer\\'
validation_autoseg_normal_path = 'D:\\cell_autoseg_validation\\split\\normal\\'
validation_autoseg_cv_path = 'D:\\cell_autoseg_validation\\cross validation\\'
REPO_DINOV3 = "D:\\cell_classification_pythonprojects\\dinov3\\dinov3"
WEIGHTS_DINOV3 = "D:\\cell_classification_pythonprojects\\My_model\\dinov3_vitb16_pretrain_lvd1689m-73cec8be.pth"

In [None]:
import logging
from pathlib import Path
from datetime import datetime

# The log file goes next to this script. When running as a notebook, `__file__`
# is undefined, so it goes into the current working directory instead. The file
# is named after the run plus the start timestamp.
if "__file__" in globals():
    SCRIPT_DIR = Path(__file__).resolve().parent
else:
    SCRIPT_DIR = Path.cwd()
log_file = SCRIPT_DIR / f"{run_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

logger = logging.getLogger(run_name)
logger.propagate = False  # do not forward records to ancestor (root) logger handlers
logger.setLevel(logging.INFO)

# Detach and close any handlers already attached to this named logger (e.g. from
# re-running this cell), so records are not emitted twice. Flushing is best-effort.
while logger.handlers:
    stale_handler = logger.handlers[0]
    try:
        stale_handler.flush()
    except Exception:
        pass
    stale_handler.close()
    logger.removeHandler(stale_handler)

fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")

# Console output, then a copy of the same records appended to the log file.
sh = logging.StreamHandler()
sh.setFormatter(fmt)
logger.addHandler(sh)

fh = logging.FileHandler(log_file, mode='a', encoding="utf-8")
fh.setFormatter(fmt)
logger.addHandler(fh)

In [None]:
from training_toolbox import ResizeWithPadding

# Deterministic preprocessing applied to every image:
#   1. pad-resize to 224x224 with the project's ResizeWithPadding;
#   2. convert to a float32 tensor scaled to [0, 1].
# v2.ToTensor is deprecated in torchvision.transforms.v2. Its documented
# replacement is ToImage() + ToDtype(torch.float32, scale=True), which gives the
# same values for uint8/PIL input and returns a tv_tensors.Image (a Tensor subclass).
transform = v2.Compose([
    ResizeWithPadding((224, 224)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

# Read the data. ImageFolder derives class indices from the sorted sub-directory names.
# The rest of the notebook assumes 0 is cancer and 1 is normal
# (presumably the sub-folders are named 'cancer' and 'normal' — TODO confirm).
dataset = torchvision.datasets.ImageFolder(dataset_autoseg_path, transform=transform)

# Random augmentation pipeline: flips, rotation and affine jitter.
# NOTE(review): RandomRotation and RandomAffine each rotate by up to 40 degrees,
# so the two rotations compound — confirm this is intended.
transform_augmented = v2.Compose([v2.RandomHorizontalFlip(),
                                  v2.RandomVerticalFlip(),
                                  v2.RandomRotation(degrees=40),
                                  v2.RandomAffine(degrees=40, translate=(0.1, 0.1), shear=(-8, 8, -8, 8),
                                                  scale=(0.9, 1.1)),
                                  ])  # Image Augmented Transformation

In [None]:
from sklearn.model_selection import train_test_split

# Flatten ImageFolder's (path, class_index) samples into parallel lists.
image_paths = [path for path, _ in dataset.samples]
image_labels = [label for _, label in dataset.samples]

# Stratified split: hold out `test_percent` of all images as the test set.
# The rest is kept for training + validation, with class proportions preserved.
train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
    image_paths, image_labels, test_size=test_percent, stratify=image_labels, random_state=seed)

logger.info(
    f'we have {train_val_labels.count(0)} cancer cells and {train_val_labels.count(1)} normal cells, {len(train_val_labels)} cells in total, for training and validating')
logger.info(
    f'we have {test_labels.count(0)} cancer cells and {test_labels.count(1)} normal cells, {len(test_paths)} cells in total, for testing')

# Save each test image, after the dataset transform has been applied (dataset[idx]
# applies it), into its per-class folder as "<original stem>_test.png".
# The dataset is indexed by path once. The previous nested scan over dataset.samples
# cost O(len(test_paths) * len(dataset)).
sample_index = {path: idx for idx, (path, _) in enumerate(dataset.samples)}
test_output_dirs = {0: test_autoseg_cancer_path, 1: test_autoseg_normal_path}
for output_dir in test_output_dirs.values():
    # save_image does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

for path, label in zip(test_paths, test_labels):
    output_dir = test_output_dirs.get(label)
    if output_dir is None:
        continue  # as before, only class 0 (cancer) and class 1 (normal) are exported
    # basename/splitext handle the platform's separators and keep inner dots.
    # The old split('.')[0] truncated "a.b.png" to "a", so dotted names could collide.
    stem = os.path.splitext(os.path.basename(path))[0]
    utils.save_image(dataset[sample_index[path]][0], output_dir + stem + "_test.png")

In [None]:
from training_toolbox import compute_mean_std

# Transformed train + validation images; the test split is excluded so it does
# not influence the normalization statistics.
# The dataset is indexed by path once, instead of scanning dataset.samples for
# every training path. The previous scan cost O(len(train_val_paths) * len(dataset)).
sample_index = {path: idx for idx, (path, _) in enumerate(dataset.samples)}
train_val_dataset = [dataset[sample_index[path]][0] for path in train_val_paths]

if finetune:
    # Fine-tuning the backbone: normalize with statistics estimated on this data.
    images_mean, images_std = compute_mean_std(train_val_dataset, channels)
else:
    # Not fine-tuning: keep the standard ImageNet normalization constants.
    images_mean = torch.tensor([0.485, 0.456, 0.406])
    images_std = torch.tensor([0.229, 0.224, 0.225])
logger.info(f"Mean: {images_mean}")
logger.info(f"Std: {images_std}")

# Inverse of Normalize(mean=images_mean, std=images_std). It restores images for
# saving them to the "whole" folders and for XAI display.
# Normalize computes (y - m') / s'. With m' = -m / s and s' = 1 / s this equals
# y * s + m, which undoes y = (x - m) / s.
# float() turns 0-d tensors / NumPy scalars into plain floats. zip() covers every
# channel instead of assuming exactly three.
transform_inverse = v2.Compose([v2.Normalize(
    mean=[float(-m / s) for m, s in zip(images_mean, images_std)],
    std=[float(1 / s) for s in images_std])])