In [None]:
from glob import glob
from transformers import AutoModelForImageClassification
import torch
import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random
import json

from Code_py import CustomMatrixDataset,CustomMatrixDataset_augmentation, BalancedBatchSampler
from owkin_code import TorchTrainer, Chowder, auc2, slide_level_train_step, slide_level_val_step, get_cv_metrics

### Import images/patches

In [None]:
# Magnification level of the tiles to load
res = "5X"
# NOTE(review): glob is called without recursive=True, so '**' matches exactly one directory level here — confirm intended
path_to_images = 'E:\\01_DVLPMT_DATASET\\02_TILES_224\\**\\' + res + '\\*HES'
patch_list = glob('{}\\*'.format(path_to_images))

len(patch_list)

In [None]:
# Extract the patient (slide) identifier from a tile path.
# The file name is taken after the last backslash, cut at the first '[' and its last 5 characters dropped.
def get_patient_id_from_path(path):
    """Return the patient/slide id encoded in a Windows-style tile path.

    The id is the part of the file name before the first '[' minus its trailing 5 characters
    (presumably a fixed suffix such as '_HES ' — confirm against the tile naming scheme).
    """
    file_name = path.rsplit('\\', 1)[-1]
    before_bracket = file_name.partition('[')[0]
    return str(before_bracket[:-5])

In [None]:
# Load one tile image from disk. It is paired with `show_image` below: if displaying a tile fails,
# check this loader first.

def load_image(path):
    """Read an image file with OpenCV.

    Args:
        path: path to the image file.

    Returns:
        HxWx3 uint8 numpy array in OpenCV's BGR channel order (cv2.imread default colour mode).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None instead of raising).
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return image

def show_image(image, bgr=True):
    """Display an image in a new matplotlib figure and return the figure.

    Args:
        image: HxWxC numpy array, or a CxHxW torch tensor / array (channels-first input is moved
            to channels-last before display).
        bgr: if True (default), the channel order is reversed before display, because images read
            with `load_image` (cv2.imread) are BGR while matplotlib expects RGB.

    Returns:
        The matplotlib Figure holding the image.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    if image.shape[0] == 3:
        # channels-first -> channels-last (works for numpy arrays, unlike Tensor.permute)
        image = np.transpose(image, (1, 2, 0))
    if bgr and image.ndim == 3 and image.shape[-1] == 3:
        image = image[..., ::-1]
    fig, ax = plt.subplots()
    ax.imshow(image)
    return fig
#show_image(load_image('path_to_one_of_your_images.tiff'))

In [None]:
# Sanity check on a single tile: read it, print the patient id parsed from its path, and display it
path = patch_list[20]
img = load_image(path)
print(get_patient_id_from_path(path))
fig = show_image(img)

In [None]:
# Group the tile paths by patient: patient id -> list of that patient's tile paths (in glob order)
whole_patient_dict = {}
for tile_path in patch_list:
    whole_patient_dict.setdefault(get_patient_id_from_path(tile_path), []).append(tile_path)

In [None]:
# Number of distinct patients found among the tiles
len(whole_patient_dict.keys())

#### Encoding those patches

In [None]:
## Load the Phikon model pretrained by Owkin
encoder = AutoModelForImageClassification.from_pretrained(
    "owkin/phikon",
    ignore_mismatched_sizes=False,
)
## Replace the classification head with an empty Sequential (an identity mapping) so the model
## returns the features fed to the head (768 per the original note) instead of class logits
encoder.classifier = torch.nn.Sequential()

# Use the first GPU if available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Evaluation mode disables training-only behaviour (e.g. dropout) so that a tile always gets the
# same embedding; the model is moved to the device once here
encoder = encoder.eval().to(device)

In [None]:
# Every patch must be normalized: first scaled to [0, 1] by dividing by 255 (highest 8-bit value),
# then standardized per channel with the ImageNet statistics, as is common in the literature.

def normalize_image(tensor,mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]):
    """Return a new channels-first float tensor scaled to [0, 1] and standardized per channel.

    Args:
        tensor: image tensor, either CxHxW with C == 3 or HxWxC (then permuted to CxHxW).
        mean: per-channel means subtracted from the first 3 channels.
        std: per-channel standard deviations dividing the first 3 channels.

    The input tensor is not modified.
    """
    chw = tensor if tensor.shape[0] == 3 else tensor.permute(2, 0, 1)
    normalized = torch.div(chw, 255.0)
    for channel in range(3):
        normalized[channel] = (normalized[channel] - mean[channel]) / std[channel]
    return normalized

In [None]:
# To avoid loading every patch in memory, we iterate over the paths and encode the patches one at a time.
# To save the encodings (and avoid re-encoding later, if you have the disk space), pass a save_path to the following function.
import os

def encode_patch(path, encoder, save_path=None):
    """Encode one tile image into a feature vector with `encoder`.

    Args:
        path: path to the tile image (read with `load_image`).
        encoder: model whose output's first element is indexed as [batch][features]
            (presumably (1, 768) features once the classifier head is removed — confirm).
        save_path: optional directory; if given, the embedding is also saved there as
            `<file name before first '.'>_enc.pt`.

    Returns:
        Detached tensor holding the features of this single tile (on `device`).
    """
    img = load_image(path)
    encoder = encoder.to(device)
    # NOTE(review): load_image returns BGR (OpenCV) while the ImageNet mean/std in normalize_image
    # are RGB-ordered — confirm this channel order matches how the saved encodings were produced
    tensor = normalize_image(torch.from_numpy(img).to(device))
    tensor = tensor.unsqueeze(0)  # add a batch dimension of size 1
    # Inference only: no_grad avoids building an autograd graph for every tile
    with torch.no_grad():
        enc = encoder(tensor)[0][0].clone().detach()

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        patch_name = path.split('\\')[-1].split('.')[0]
        torch.save(enc, save_path + '\\' + patch_name + '_enc.pt')
    return enc

In [None]:
# Patient-level labels: ';'-separated CSV with (at least) 'patient_id' and 'label' columns
label_file = "E:\\01_DVLPMT_DATASET\\labels_402.csv"
labels = pd.read_csv(label_file, delimiter=";")
# Ids derived from tile paths are strings, so compare on strings
labels = labels.astype({'patient_id': str})

nb_classes = len(set(labels['label']))
print('Classification for', nb_classes, 'classes')

In [None]:
## Add a stratified FOLD column (patients assigned round-robin to 5 folds within each class)
## and a positional ID column (row number, used by CustomMatrixDataset to look patients up)
labels = labels.sort_values('label').reset_index(drop=True)
# cumcount numbers the rows of each class 0, 1, 2, ...; modulo 5 spreads every class evenly over
# the folds, for any number of classes and whatever other columns the label file contains
new_labels = labels.assign(
    fold=labels.groupby('label').cumcount() % 5,   # 5 folds: must match n_outer_cv used for cross-validation
    ID=np.arange(len(labels)),
)
new_labels

In [None]:
# Smallest number of tiles available for any single patient
len_list = [len(tiles) for tiles in whole_patient_dict.values()]
mini = min(len_list)

print('Your minimum number of patches is', mini)

In [None]:
# Keep only the patients having more than nb_patch tiles.
# NOTE(review): nb_patch = mini - 1 and mini is the global minimum, so this filter never removes
# anyone — confirm the intended threshold
nb_patch = mini - 1
patient_dict = {pid: tiles for pid, tiles in whole_patient_dict.items() if len(tiles) > nb_patch}
removed = len(whole_patient_dict) - len(patient_dict)
print('You removed', removed, 'patients')

In [None]:
# Creation of a new dictionary: for each patient, store the list of encoded tiles instead of the list of paths
def create_patient_encoded_dict(patient_dict, encoder, patch_tiling_infos=None, save_path=None, is_saved=None):
    """Encode every tile of every patient, or load a previously saved encoded dictionary.

    Args:
        patient_dict: mapping patient_id -> list of tile image paths.
        encoder: feature extractor forwarded to `encode_patch`.
        patch_tiling_infos: unused; kept for backward compatibility.
        save_path: if given, `encode_patch` also writes each tile embedding to this directory.
        is_saved: path of a `torch.save`-d dictionary; when given it is loaded and returned as is
            (`patient_dict` and `encoder` are then ignored).

    Returns:
        dict mapping patient_id -> list of tile embeddings as returned by `encode_patch`.
    """
    if is_saved is not None:
        print('Loading patient dict')
        # NOTE: torch.load unpickles the file — only load files you produced yourself
        return torch.load(is_saved)

    patient_encoded = {}
    n_tiles = sum(len(tiles) for tiles in patient_dict.values())
    # The bar's total counts tiles, so it must advance once per encoded tile (not once per patient)
    with tqdm.tqdm(total=n_tiles, position=0, desc="Processing") as progress_bar:
        for i, patient in enumerate(patient_dict):
            progress_bar.set_description("Processing patient {}/{}".format(i + 1, len(patient_dict)))
            patient_encoded[patient] = []
            for tile_path in patient_dict[patient]:
                patient_encoded[patient].append(encode_patch(tile_path, encoder, save_path=save_path))
                progress_bar.update(1)
    return patient_encoded

patient_encoded = create_patient_encoded_dict(patient_dict, encoder, patch_tiling_infos=None, save_path="D:\\02_TILES\\004_encoded\\5X", is_saved=None)

# Shuffle the patient order. random.shuffle works in place and returns None, so call it on its
# own instead of assigning its result with '='. NOTE(review): no random seed is set — the order
# is not reproducible across runs
patient_list = list(patient_encoded.keys())
random.shuffle(patient_list)
patient_encoded_shuffled = {pid: patient_encoded[pid] for pid in patient_list}
del patient_encoded

In [None]:
#torch.save(patient_encoded_shuffled, f"E:\\02_TILES\\004_encoded\\{res}_patient_dict_shuffle.pt")

In [None]:
# Load the previously encoded (and shuffled) dictionary instead of re-encoding every tile
encoded_dict_file = f"E:\\01_DVLPMT_DATASET\\03_TILES_ENC_PHIKON\\{res}_patient_dict_shuffle_402.pt"
patient_encoded = create_patient_encoded_dict(patient_dict, encoder, save_path=None, is_saved=encoded_dict_file)

In [None]:
from torch.utils.data import Dataset

class CustomMatrixDataset(Dataset):
    """Patient-level dataset returning a fixed-size bag of tile embeddings and the patient label.

    Item `idx` is the row of `labels` whose `ID` equals `idx`; its `patient_id` is the key into
    `patient_dict`.

    Args:
        patient_dict: mapping patient_id -> list of tile-embedding tensors (stacked on return).
        labels: DataFrame with at least `ID`, `patient_id` and `label` columns.
        nb_patch: number of tiles returned per patient (sampled, or padded by repetition).
        fixed: if True, the selection drawn for a patient the first time is cached and reused.
    """

    def __init__(self, patient_dict, labels, nb_patch, fixed=False):
        self.mat_dict = patient_dict
        self.nb_patch = nb_patch
        self.labels = labels
        self.fixed = fixed
        if self.fixed:
            self.patch_per_patient = {}

    def random_patch_selection(self, patch_list, slide):
        """Return exactly `nb_patch` tiles drawn from `patch_list`.

        With enough tiles, `nb_patch` distinct tiles are sampled; otherwise all tiles are repeated
        as often as needed and the remainder is sampled. `patch_list` itself is never modified.

        Raises:
            ValueError: if `patch_list` is empty (no tile can be drawn).
        """
        if not patch_list:
            raise ValueError(f"No tiles available for patient {slide}")
        if len(patch_list) >= self.nb_patch:
            patch_sel = random.sample(patch_list, self.nb_patch)
        else:
            # Copy: extending an alias would grow the list stored in self.mat_dict on every call
            patch_sel = list(patch_list)
            n = self.nb_patch - len(patch_list)
            while n > len(patch_list):
                patch_sel.extend(patch_list)
                n -= len(patch_list)
            patch_sel.extend(random.sample(patch_list, n))
        if self.fixed:
            self.patch_per_patient[slide] = patch_sel
        return patch_sel

    def __len__(self):
        # __getitem__ resolves idx through the ID column (0 .. len(labels)-1), so the dataset size is
        # the number of labelled patients, not the number of entries in the tile dictionary
        return len(self.labels)

    def __getitem__(self, idx):
        slide = self.labels.loc[self.labels['ID'] == idx]['patient_id'].iloc[0]
        patch_list = self.mat_dict[slide]
        label = torch.tensor(self.labels.loc[self.labels['patient_id'] == slide]['label'].iloc[0]).to(torch.int64)

        if self.fixed and slide in self.patch_per_patient:
            patch_sel = self.patch_per_patient[slide]
        else:
            patch_sel = self.random_patch_selection(patch_list, slide)

        return torch.stack(patch_sel), label

In [None]:
# Dataset over all labelled patients; every item draws `mini` random tiles of the patient
whole_ds = CustomMatrixDataset(patient_dict=patient_encoded, labels=new_labels, nb_patch=mini, fixed=False)

In [None]:
# Loss, optimizer and tracked metric for training.
# CrossEntropyLoss takes one logit per class (the model has out_features=nb_classes), binary case included.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam   # the optimizer class, not an instance (presumably instantiated by TorchTrainer)
metrics = {"auc": auc2}        # AUC is the tracking metric

In [None]:
import warnings
from copy import deepcopy
from datetime import datetime

from IPython.display import clear_output

from sklearn.model_selection import StratifiedKFold
#from trainer import TorchTrainer, slide_level_train_step, slide_level_val_step

# Nested cross-validation: each of the n_outer_cv folds is held out once as the test set, and the
# remaining folds give n_inner_cv train/validation splits (one model is trained per inner split).
n_outer_cv = 5
n_inner_cv = 4
# The 'fold' column of new_labels defines the folds; the outer loop must cover exactly those folds
assert new_labels['fold'].nunique() == n_outer_cv, "n_outer_cv must equal the number of folds in new_labels"
assert n_inner_cv <= n_outer_cv - 1, "at most n_outer_cv - 1 folds are left for inner validation"
train_metrics, val_metrics = [], []
train_losses, val_losses = [], []
test_logits = []

cv_start_time = datetime.now()

chowder = Chowder(
    in_features=768,                     # output dimension of Phikon
    out_features=nb_classes,             # one logit per class
    n_top=5,                             # number of top scores in Chowder
    n_bottom=5,                          # number of bottom scores in Chowder
    mlp_hidden=[200, 100],               # MLP hidden layers after the max-min layer
    mlp_activation=torch.nn.Sigmoid(),   # MLP activation
    bias=True                            # bias for first 1D convolution which computes scores
)

# best_model_inner_cv[test_fold][inner_split] = trainer_2.best_sd for that run (presumably the best state dict)
best_model_inner_cv = {}

for test_fold in range(n_outer_cv):
    best_model_inner_cv[test_fold] = {}
    print(f"Running cross-validation #{test_fold+1}")
    # Held-out test fold (folds are already stratified by label)
    test_idx = list(new_labels[new_labels['fold'] == test_fold]['ID'])
    test_set = torch.utils.data.Subset(whole_ds, test_idx)

    # Inner splits: each remaining fold in turn is the validation set, the others are used for training
    cv_splits = []
    for j in range(n_inner_cv):
        val_fold = j if j < test_fold else j + 1   # skip the held-out test fold
        val_idx = list(new_labels[new_labels['fold'] == val_fold]['ID'])
        train_idx = list(new_labels[(new_labels['fold'] != test_fold) & (new_labels['fold'] != val_fold)]['ID'])
        cv_splits.append((train_idx, val_idx))

    for i, (train_indices, val_indices) in enumerate(cv_splits):
        fold_start_time = datetime.now()
        trainer_2 = TorchTrainer(
            model=deepcopy(chowder),
            criterion=criterion,
            metrics=metrics,
            batch_size=64,                           # you can tweak this
            num_epochs=100,                          # you can tweak this
            learning_rate=1e-3,                      # you can tweak this
            weight_decay=0.0,                        # you can tweak this
            device=device,
            balanced=True,                           # you can tweak this
            num_workers=0,
            optimizer=deepcopy(optimizer),
            train_step=slide_level_train_step,
            val_step=slide_level_val_step,
            nb_classes=nb_classes
        )

        print(f"Running cross-validation on split #{i+1}")
        train_train_dataset = torch.utils.data.Subset(
            whole_ds, indices=train_indices
        )
        train_val_dataset = torch.utils.data.Subset(
            whole_ds, indices=val_indices
        )

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            # Training step for the given number of epochs
            local_train_metrics, local_val_metrics, local_train_losses, local_val_losses = trainer_2.train(
                train_train_dataset, train_val_dataset
            )
            # Predictions on the held-out test fold (second element of predict's result)
            local_test_logits = trainer_2.predict(test_set)[1]

        train_metrics.append(local_train_metrics)
        val_metrics.append(local_val_metrics)
        train_losses.append(local_train_losses)
        val_losses.append(local_val_losses)
        test_logits.append(local_test_logits)
        fold_end_time = datetime.now()
        fold_running_time = fold_end_time - fold_start_time
        best_model_inner_cv[test_fold][i] = trainer_2.best_sd
        print("\n-----------------------------Finished in {}---------------------------------------\n".format(fold_running_time))
    clear_output()
cv_end_time = datetime.now()

cv_running_time = cv_end_time - cv_start_time
print("\nFinished cross-validation in {}".format(cv_running_time))

In [None]:
#save your model
#torch.save(trainer_2.model, 'E:\\04_RETURN_TO_TILES\\MODEL_TEST')

In [None]:
# Plot train/validation loss curves: one column per outer (test) fold, one row per inner split.
# squeeze=False keeps `axs` 2-D whatever n_outer_cv / n_inner_cv are.
figure, axs = plt.subplots(ncols=n_outer_cv, nrows=n_inner_cv, figsize=(10*n_outer_cv, 5*n_inner_cv), squeeze=False)
for j in range(n_outer_cv):
    for i in range(n_inner_cv):
        ax = axs[i, j]
        run = j * n_inner_cv + i   # runs were appended outer-fold-major by the CV loop
        Y_train = train_losses[run]
        Y_val = val_losses[run]
        # x axis from the number of recorded values (robust if a run recorded fewer epochs)
        X = np.arange(len(Y_train))
        ax.plot(X, Y_train, label='train_loss')
        ax.plot(np.arange(len(Y_val)), Y_val, label='val_loss')
        if i == 0:
            ax.title.set_text('Outer fold ' + str(j+1))
        if j == 0:
            ax.set_ylabel('Inner split ' + str(i+1), rotation=0, labelpad=20)
        ax.legend(loc='lower left')
# Lay out *this* figure (the name `fig` still refers to a figure from an earlier cell)
figure.tight_layout()

In [None]:
#evaluate your model
from sklearn.metrics import f1_score, confusion_matrix, ConfusionMatrixDisplay, cohen_kappa_score, roc_auc_score
from torch.utils.data import DataLoader

# For every outer (test) fold, score the n_inner_cv inner models on the held-out test fold and keep
# the one with the best macro-F1: best_model_outer_cv[test_fold] = (state_dict, predictions).
# One confusion-matrix panel per outer fold (squeeze=False + [0] keeps a 1-D array even for one fold)
fig, axs_cm = plt.subplots(1, n_outer_cv, figsize=(15, 10), squeeze=False)
axs_cm = axs_cm[0]

# To store true labels and predicted probabilities for the ROC curves, and Se / Sp per fold
roc_data = {'true_labels': [], 'probas': []}
sensitivity_scores = []
specificity_scores = []
auc_scores = []

best_model_outer_cv = {}
for test_fold in tqdm.tqdm(range(n_outer_cv)):
    test_idx = list(new_labels[new_labels['fold'] == test_fold]['ID'])
    test_set = torch.utils.data.Subset(whole_ds, test_idx)
    test_dl = DataLoader(test_set)

    # Start below any achievable F1 so a model is always selected, even if every F1 is 0
    best_f1 = -1.0
    best_probas = None

    for val_fold in range(n_inner_cv):
        model_test = deepcopy(trainer_2.model)
        model_test.load_state_dict(best_model_inner_cv[test_fold][val_fold])
        true_lab = []
        logits = []
        with torch.no_grad():
            model_test = model_test.eval()
            model_test = model_test.to(device)
            for batch in test_dl:
                matrix, lab = batch
                matrix = matrix.to(device)
                true_lab.append(lab)
                logits.append(model_test(matrix).squeeze(1))

            probas = torch.nn.functional.softmax(torch.stack(logits).squeeze(1), dim=1).to('cpu')
            preds = torch.argmax(probas, 1)

            # scikit-learn signature is f1_score(y_true, y_pred)
            test_f1 = f1_score(torch.cat(true_lab).numpy(), preds.numpy(), average='macro')
            if test_f1 > best_f1:
                best_f1 = test_f1
                # Keep the selected model's probabilities: `probas` is overwritten by later inner models
                best_probas = probas
                best_model_outer_cv[test_fold] = (best_model_inner_cv[test_fold][val_fold], preds)

    # True labels are identical for every inner model (same unshuffled test loader)
    true_lab = torch.cat(true_lab).numpy()
    preds = best_model_outer_cv[test_fold][1].numpy()

    if len(true_lab) == len(preds):  # Ensure lengths match before computing the confusion matrix
        # Fixed label set: the matrix stays nb_classes x nb_classes even if a class is absent in this fold
        test_cm = confusion_matrix(true_lab, preds, labels=list(range(nb_classes)))

        # Binary layout: rows = true class, columns = predicted class
        TN, FP, FN, TP = test_cm.ravel()

        sensitivity = TP / (TP + FN)  # True Positive Rate (Recall)
        specificity = TN / (TN + FP)  # True Negative Rate (Specificity)

        sensitivity_scores.append(sensitivity)
        specificity_scores.append(specificity)

        disp = ConfusionMatrixDisplay(test_cm)
        disp.plot(ax=axs_cm[test_fold], colorbar=False)
        axs_cm[test_fold].set_title(f'F1: {best_f1:.3f}')

        # ROC data of the *selected* model of this fold
        roc_data['true_labels'].append(true_lab)
        roc_data['probas'].append(best_probas.numpy())

    else:
        print(f"Warning: True labels and predictions length mismatch for fold {test_fold}")

# Display the confusion matrices plot
plt.tight_layout()
plt.show()

#fig.savefig("confusion_matrices.png", dpi=300)  # Save the confusion matrix with 300 DPI to the working directory, default is desktop

In [None]:
# Per-fold sensitivity / specificity computed by the evaluation cell above
print("\nSensitivity and Specificity for each fold:")
for i, (se, sp) in enumerate(zip(sensitivity_scores, specificity_scores)):
    print(f"Fold {i}: Sensitivity = {se:.3f}, Specificity = {sp:.3f}")

# Mean and standard deviation over the folds
mean_sensitivity = np.mean(sensitivity_scores)
std_sensitivity = np.std(sensitivity_scores)
mean_specificity = np.mean(specificity_scores)
std_specificity = np.std(specificity_scores)

print(f"\nSummary on internal dataset:")
print(f"Mean Sensitivity: {mean_sensitivity:.3f} ± {std_sensitivity:.3f}")
print(f"Mean Specificity: {mean_specificity:.3f} ± {std_specificity:.3f}")

# Macro-F1 of the selected model on each test fold, recomputed from the stored predictions
# (previously hard-coded numbers, which silently go stale when the notebook is re-run).
# The test loader iterates the fold's rows of new_labels in order, so these are the matching true labels.
f1_scores = []
for fold, (_, best_preds) in sorted(best_model_outer_cv.items()):
    y_true = new_labels.loc[new_labels['fold'] == fold, 'label'].to_numpy()
    f1_scores.append(f1_score(y_true, best_preds.numpy(), average='macro'))
mean_f1 = np.mean(f1_scores)
std_f1 = np.std(f1_scores)

print(f"Mean F1 Score: {mean_f1:.3f} ± {std_f1:.3f}")

In [None]:
from sklearn.metrics import roc_auc_score, auc, roc_curve, RocCurveDisplay
from sklearn.model_selection import StratifiedKFold

# ROC curve of the selected model on each internal test fold, plus the mean curve obtained by
# interpolating every fold's TPR on a common FPR grid. The reported "mean AUC" is the AUC of that
# mean curve; the ± is the std of the per-fold AUCs.
mean_fpr = np.linspace(0, 1, 100)
auc_scores = []
tprs = []

fig, ax = plt.subplots(figsize=(8, 6))

for test_fold in range(n_outer_cv):
    fold_true = roc_data['true_labels'][test_fold]
    fold_scores = roc_data['probas'][test_fold][:, 1]   # probability of the positive class (1)

    fpr, tpr, _ = roc_curve(fold_true, fold_scores)
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0   # every curve starts at (0, 0)
    tprs.append(interp_tpr)

    auc_score = roc_auc_score(fold_true, fold_scores)
    auc_scores.append(auc_score)

    RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc_score).plot(ax=ax, alpha=0.3, lw=1, label=f'Fold {test_fold} (AUC: {auc_score:.3f})')

# Chance level
ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="grey", label="Chance", alpha=0.8)

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0   # the mean curve ends at (1, 1)
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(auc_scores)
ax.plot(mean_fpr, mean_tpr, color="b", label=f'Mean ROC (AUC = {mean_auc:.2f} ± {std_auc:.2f})', lw=2, alpha=0.8)

# ±1 std band around the mean TPR, kept inside [0, 1]
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.clip(mean_tpr + std_tpr, None, 1)
tprs_lower = np.clip(mean_tpr - std_tpr, 0, None)
ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color="grey", alpha=0.2, label=r"$\pm$ 1 std. dev.")

ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], title="Receiver Operating Characteristic (ROC) Curve")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(loc="lower right")
ax.grid(False)

plt.tight_layout()
plt.show()

print(f"Mean AUC: {mean_auc:.3f} ± {std_auc:.3f}")

#fig.savefig("roc_curves_all.png", dpi=300)  # Save the ROC figure at 300 DPI in the working directory

In [None]:
# Load the external validation dataset: encoded tiles per patient and the ';'-separated label table.
# NOTE: torch.load unpickles the file — only load files you produced yourself.
external_encoded = torch.load("E:\\03_PROSPECTIV2025_DATASET\\5X_patient_dict_shuffle_prospectiv2025.pt")
# Read patient ids as strings (as for the internal labels) so they can match the string keys of the
# encoded dictionary; dtype=str also keeps leading zeros
external_label = pd.read_csv("E:\\03_PROSPECTIV2025_DATASET\\labels_prospectiv2025.csv", delimiter=';', dtype={'patient_id': str})

In [None]:
# Keep only the labelled patients that have encoded tiles; `not_found` holds the others
has_tiles = external_label['patient_id'].isin(list(external_encoded.keys()))
not_found = external_label[~has_tiles]
external_label = external_label[has_tiles]
print(f'{len(not_found)} labelled patient(s) without encoded tiles were dropped')

## Add a stratified FOLD column (round-robin over 5 folds within each class) and a positional
## ID column (row number, used by CustomMatrixDataset to look patients up)
external_label = external_label.sort_values('label').reset_index(drop=True)
new_external_label = external_label.assign(
    fold=external_label.groupby('label').cumcount() % 5,
    ID=np.arange(len(external_label)),
)
new_external_label

In [None]:
class CustomMatrixDataset(Dataset):
    """Patient-level dataset of tile embeddings (redefinition adding the `all_patchs` option).

    Item `idx` is the row of `labels` whose `ID` equals `idx`; its `patient_id` is the key into
    `patient_dict`.

    Args:
        patient_dict: mapping patient_id -> list of tile-embedding tensors (stacked on return).
        labels: DataFrame with at least `ID`, `patient_id` and `label` columns.
        nb_patch: number of tiles returned per patient when `all_patchs` is False.
        fixed: if True, the selection drawn for a patient the first time is cached and reused.
        all_patchs: if True (default), every tile of the patient is returned and `nb_patch`/`fixed`
            are ignored.
    """

    def __init__(self, patient_dict, labels, nb_patch, fixed=False, all_patchs=True):
        self.mat_dict = patient_dict
        self.nb_patch = nb_patch
        self.labels = labels
        self.fixed = fixed
        self.all_patchs = all_patchs
        if self.fixed:
            self.patch_per_patient = {}

    def random_patch_selection(self, patch_list, slide):
        """Return exactly `nb_patch` tiles drawn from `patch_list` (never modifying it).

        With enough tiles, `nb_patch` distinct tiles are sampled; otherwise all tiles are repeated
        as often as needed and the remainder is sampled.

        Raises:
            ValueError: if `patch_list` is empty (no tile can be drawn).
        """
        if not patch_list:
            raise ValueError(f"No tiles available for patient {slide}")
        if len(patch_list) >= self.nb_patch:
            patch_sel = random.sample(patch_list, self.nb_patch)
        else:
            # Copy: extending an alias would grow the list stored in self.mat_dict on every call
            patch_sel = list(patch_list)
            n = self.nb_patch - len(patch_list)
            while n > len(patch_list):
                patch_sel.extend(patch_list)
                n -= len(patch_list)
            patch_sel.extend(random.sample(patch_list, n))
        if self.fixed:
            self.patch_per_patient[slide] = patch_sel
        return patch_sel

    def __len__(self):
        # __getitem__ resolves idx through the ID column (0 .. len(labels)-1), so the dataset size is
        # the number of labelled patients, not the number of entries in the tile dictionary
        return len(self.labels)

    def __getitem__(self, idx):
        slide = self.labels.loc[self.labels['ID'] == idx]['patient_id'].iloc[0]
        patch_list = self.mat_dict[slide]
        label = torch.tensor(self.labels.loc[self.labels['patient_id'] == slide]['label'].iloc[0]).to(torch.int64)

        if self.all_patchs:
            return torch.stack(patch_list), label

        if self.fixed and slide in self.patch_per_patient:
            patch_sel = self.patch_per_patient[slide]
        else:
            patch_sel = self.random_patch_selection(patch_list, slide)

        return torch.stack(patch_sel), label

In [None]:
from collections import Counter

# Evaluate each selected fold model on the external cohort (all tiles of every patient), then
# combine the fold models: majority vote for the labels, mean softmax probability for the AUC.
fig, axs = plt.subplots(1, n_outer_cv, figsize=(15, 10), squeeze=False)
axs = axs[0]   # 1-D array of Axes, one per fold model (also when n_outer_cv == 1)

# External dataset: all_patchs=True, so every tile of a patient is used and nb_patch is ignored
external_ds = CustomMatrixDataset(external_encoded, new_external_label, nb_patch=20, fixed=False, all_patchs=True)
external_dl = DataLoader(external_ds)

all_preds = {}
all_probas = {}
true_labels = []
auc_scores = {}

for test_fold in best_model_outer_cv.keys():
    # Rebuild the architecture and load the inner model selected for this outer fold
    model_best = deepcopy(trainer_2.model)
    model_best.load_state_dict(best_model_outer_cv[test_fold][0])
    model_best = model_best.eval().to(device)

    logits, true_lab = [], []
    with torch.no_grad():
        for matrix, lab in external_dl:
            logits.append(model_best(matrix.to(device)).squeeze(1))
            true_lab.append(lab)

    probas = torch.nn.functional.softmax(torch.cat(logits), dim=1).cpu().numpy()
    preds = np.argmax(probas, axis=1)
    true_lab = torch.cat(true_lab).cpu().numpy()

    all_preds[test_fold] = preds
    all_probas[test_fold] = probas

    ext_f1 = f1_score(true_lab, preds, average='macro')
    auc_scores[test_fold] = roc_auc_score(true_lab, probas[:, 1])   # AUC for class 1

    # Fixed label set: each panel stays nb_classes x nb_classes even if a class is never predicted
    ext_cm = confusion_matrix(true_lab, preds, labels=list(range(nb_classes)))
    disp_ext = ConfusionMatrixDisplay(ext_cm)
    disp_ext.plot(ax=axs[test_fold], colorbar=False)
    axs[test_fold].set_title(f'Fold {test_fold} F1 = {ext_f1:.3f}')

    # Same patients in the same order for every fold model: keep the labels once
    if len(true_labels) == 0:
        true_labels = true_lab

# Majority vote across the fold models, patient by patient
# (Counter.most_common orders equal counts by first occurrence, i.e. by fold order)
nb_pat = len(next(iter(all_preds.values())))
pred_voting = []
for pat in range(nb_pat):
    pat_votes = [all_preds[fold][pat] for fold in all_preds.keys()]
    pred_voting.append(Counter(pat_votes).most_common(1)[0][0])

voting_f1 = f1_score(true_labels, np.array(pred_voting), average='macro')
print('Voting F1 =', voting_f1)

# AUC of the ensemble: softmax probabilities averaged over the fold models
avg_probas = np.mean(np.array([all_probas[fold] for fold in all_probas]), axis=0)
voting_auc = roc_auc_score(true_labels, avg_probas[:, 1])

print('Voting AUC =', voting_auc)

In [None]:
# Confusion matrix of the majority-vote predictions on the external cohort.
# labels=[0, 1] keeps it 2x2 even if a class is never predicted, so the TN/FP/FN/TP unpacking is valid.
voting_cm = confusion_matrix(true_labels, np.array(pred_voting), labels=[0, 1])

fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=voting_cm, display_labels=[0, 1])
disp.plot(ax=ax, cmap='Blues', colorbar=True)

# Binary layout: rows = true class, columns = predicted class
TN, FP, FN, TP = voting_cm.ravel()

Se = TP / (TP + FN)  # Sensitivity / recall of class 1
Sp = TN / (TN + FP)  # Specificity

ax.set_title(f'Voting Confusion Matrix (F1 = {voting_f1:.3f})')
plt.show()

print(f'Sensitivity (Se): {Se:.3f}')
print(f'Specificity (Sp): {Sp:.3f}')

#fig.savefig("voting_confusion_matrices.png", dpi=300)  # Save the figure at 300 DPI in the working directory

In [None]:
# ROC curves on the external cohort: one per fold model (subdued colours) and the ensemble
# obtained by averaging the fold models' softmax outputs (bold black)
fig, ax = plt.subplots(figsize=(10, 8))

fold_colors = ['lightblue', 'lightcoral', 'lightgreen', 'lightskyblue', 'lightsalmon']

for i, test_fold in enumerate(best_model_outer_cv.keys()):
    fpr, tpr, _ = roc_curve(true_labels, all_probas[test_fold][:, 1])
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, color=fold_colors[i], alpha=0.8, label=f'Fold {test_fold} (AUC = {roc_auc:.3f})')

avg_probas = np.array([all_probas[fold] for fold in all_probas]).mean(axis=0)
fpr_voting, tpr_voting, _ = roc_curve(true_labels, avg_probas[:, 1])
roc_auc_voting = auc(fpr_voting, tpr_voting)
ax.plot(fpr_voting, tpr_voting, color='black', linewidth=3, label=f'Voting Model (AUC = {roc_auc_voting:.3f})')

# Random-classifier reference
ax.plot([0, 1], [0, 1], color='gray', linestyle='--')

ax.set_xlabel('False Positive Rate (FPR)')
ax.set_ylabel('True Positive Rate (TPR)')
ax.set_title('ROC Curves for Each Fold and Voting Model')
ax.legend(loc='lower right')
ax.grid(False)

#plt.savefig("roc_curve.png", dpi=300, bbox_inches='tight') 

plt.show()

In [None]:
# Per-patient results table: id, ground truth, each fold model's prediction, majority vote and the
# fold-averaged probability of class 1. Rows follow the ID order used by external_ds.
patient_ids = new_external_label['patient_id'].values
final_true = np.array(true_labels)

fold_preds = {f"fold_{fold}_pred": fold_labels for fold, fold_labels in all_preds.items()}
fold_preds_df = pd.DataFrame(fold_preds)

results_df = (
    pd.concat([pd.DataFrame({"patient_id": patient_ids, "ground_truth": final_true}), fold_preds_df], axis=1)
    .assign(voting_pred=np.array(pred_voting), avg_proba_class1=avg_probas[:, 1])
)

# Show the full table
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
print(results_df)

# Optionally save to CSV for external inspection
#results_df.to_csv("external_validation_predictions_with_folds.csv", index=False)

In [None]:
# Distribution of the fold-averaged class-1 probability, split by ground-truth class
fig, ax = plt.subplots(figsize=(8, 6))
# Passing ax makes pandas draw into this figure; without it, boxplot(by=...) opens a second figure
# and the 8x6 one is left empty
results_df.boxplot(column="avg_proba_class1", by="ground_truth", ax=ax)
ax.axhline(0.5, color="red", linestyle="--")
ax.set_xlabel("Ground Truth Class")
ax.set_ylabel("Average probability for class 1")
ax.set_title("Distribution of avg_proba_class1 by True Class")
fig.suptitle("")  # remove default pandas title

#plt.savefig("BoxplotDistributionperClassEVC.png", dpi=300, bbox_inches='tight') 

plt.show()

In [None]:
# Histogram of the fold-averaged class-1 probability over all external patients,
# with the 0.5 decision threshold
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(results_df["avg_proba_class1"], bins=20, edgecolor="black", alpha=0.7)
ax.axvline(0.5, color="red", linestyle="--", label="Decision threshold (0.5)")
ax.set_xlabel("Average probability for class 1")
ax.set_ylabel("Number of patients")
ax.set_title("Distribution of avg_proba_class1 across patients")
ax.legend()

#plt.savefig("HistogramDistributionperAvgProbEVC.png", dpi=300, bbox_inches='tight') 

plt.show()

### Back to the tiles ("Retour aux tuiles"): which tiles drive the prediction

In [None]:
import os
import json
import torch
import tqdm
import matplotlib.pyplot as plt

# Load the trained model saved to disk; it is only used here to read the architecture
# hyper-parameters — the weights actually used come from best_model_outer_cv below.
model_path = 'E:/04_RETURN_TO_TILES/MODEL_TEST_5X'
# The file presumably holds a whole pickled nn.Module (torch.save(trainer_2.model, ...)), not a plain
# state dict: PyTorch >= 2.6 defaults to weights_only=True and would refuse to load it.
# Unpickling can execute arbitrary code — only load files you created yourself.
chowder_model = torch.load(model_path, weights_only=False)

# Recreate the same architecture with return_indices=True so the forward pass also returns the
# indices of the selected top/bottom tiles.
# NOTE(review): assumes chowder_model.mlp is [Sequential(Linear, act), Sequential(Linear, act), Linear] — confirm
chowder_tiles = Chowder(
    in_features=768,
    out_features=chowder_model.mlp[2].out_features,
    n_top=chowder_model.extreme_layer.n_top,
    n_bottom=chowder_model.extreme_layer.n_bottom,
    return_indices=True,
    mlp_hidden=[
        chowder_model.mlp[0][0].out_features,
        chowder_model.mlp[1][0].out_features
    ],
    mlp_activation=torch.nn.Sigmoid(),
    bias=True
)

# Load the weights of the inner model selected for this outer test fold
test_fold = 4
chowder_tiles.load_state_dict(best_model_outer_cv[test_fold][0])
chowder_tiles.eval()
chowder_tiles.to(device)

In [None]:
# Re-run the selected model on every patient of that test fold using *all* of the patient's tiles,
# and keep the class probabilities plus the indices of the tiles chosen by the extreme layer.
# NOTE(review): pred_logits[0][0] assumes the logits come out as (1, 1, n_classes) — confirm
test_patients = new_labels.loc[new_labels['fold'] == test_fold, 'patient_id'].tolist()
patient_pred = {}

with torch.no_grad():
    for patient_id in tqdm.tqdm(test_patients):
        bag = torch.stack(patient_encoded[patient_id]).to(device).unsqueeze(0)   # (1, n_tiles, features)
        pred_logits, indices = chowder_tiles(bag)
        proba = torch.nn.functional.softmax(pred_logits[0][0].cpu(), dim=-1)
        patient_pred[patient_id] = {
            'proba': proba,
            'pred': int(torch.argmax(proba)),
            'indices': indices.squeeze().cpu(),
        }

In [None]:
# Map the tile indices selected by the model back to tile file paths.
# NOTE(review): assumes patient_encoded[pid] and patient_dict[pid] list the tiles in the same order — confirm
patient_patches = {}
for pid, pred_info in patient_pred.items():
    patient_patches[pid] = [patient_dict[pid][idx] for idx in pred_info['indices']]

In [None]:
def new_feature(path, res, clss):
    """Build a GeoJSON polygon Feature covering one tile.

    The tile's top-left corner is parsed from the last `x=<int>` and `y=<int>` fields of `path`
    (whatever follows them: ',' or ']'), and the square side is 224 px times a scale factor
    depending on the magnification.

    Args:
        path: tile file path containing `x=` and `y=` integer fields.
        res: magnification, one of '20X', '10X', '5X'.
        clss: annotation class name; 'top' is drawn blue, anything else dark red.

    Returns:
        dict: a GeoJSON Feature with a closed 5-point Polygon.

    Raises:
        ValueError: if `path` has no `x=`/`y=` integer field.
        KeyError: if `res` is not a supported magnification.
    """
    import re

    color = [0, 0, 255] if clss == 'top' else [128, 0, 0]

    def _coord(name):
        # Last "<name>=<int>" occurrence in the path
        matches = re.findall(rf'{name}=\s*(-?\d+)', path)
        if not matches:
            raise ValueError(f"No '{name}=' coordinate found in tile path: {path}")
        return int(matches[-1])

    x = _coord('x')
    y = _coord('y')

    scale = {'20X': 2, '10X': 4, '5X': 8}[res]
    sq_size = 224 * scale

    coords = [[[x, y], [x + sq_size, y], [x + sq_size, y + sq_size], [x, y + sq_size], [x, y]]]

    return {
        "type": "Feature",
        "geometry": {"type": "Polygon", "coordinates": coords},
        "properties": {
            "objectType": "annotation",
            "classification": {"name": clss, "color": color}
        }
    }

def export_geojson(patient_patches, res, output_dir, n_top=None):
    """Write one GeoJSON FeatureCollection per patient with its selected tiles as polygons.

    The first `n_top` paths of each patient are labelled 'top', the remaining ones 'bottom'
    (presumably the model returns top-tile indices before bottom-tile indices — confirm).

    Args:
        patient_patches: mapping patient_id -> list of tile paths.
        res: magnification passed to `new_feature` ('20X', '10X' or '5X').
        output_dir: directory receiving `<patient_id>_annot.json` files (created if missing).
        n_top: number of leading paths tagged 'top'; defaults to `chowder_model.extreme_layer.n_top`.
    """
    if n_top is None:
        n_top = chowder_model.extreme_layer.n_top
    os.makedirs(output_dir, exist_ok=True)

    for patient_id, paths in patient_patches.items():
        features = [
            new_feature(p, res, 'top' if i < n_top else 'bottom')
            for i, p in enumerate(paths)
        ]

        geo_json = {"type": "FeatureCollection", "features": features}
        out_path = os.path.join(output_dir, f"{patient_id}_annot.json")

        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(geo_json, f, indent=2)

# Export the predictive tiles of each test patient as GeoJSON annotations
res = "5X"  # must match the magnification the tiles were extracted at
geojson_output_dir = "E:/04_RETURN_TO_TILES/predictive_patches_coords_{}/".format(res)
export_geojson(patient_patches, res, geojson_output_dir)

In [None]:
# Grid of the tiles selected by the extreme layer: one row per test patient, a first column with
# true vs predicted label (green if correct, red otherwise), then the top tiles, then the bottom tiles.
# Counts come from chowder_tiles, the model that produced the tile indices.
n_top = chowder_tiles.extreme_layer.n_top
n_bottom = chowder_tiles.extreme_layer.n_bottom
n_col = n_top + n_bottom
# squeeze=False keeps axs 2-D even when there is a single patient
fig, axs = plt.subplots(nrows=len(patient_patches), ncols=n_col + 1, figsize=(n_col + 1, (n_col + 1)*5), squeeze=False)
plt.subplots_adjust(left=0.05, bottom=0.1, right=0.98, top=0.975, wspace=0.05, hspace=0.3)

for i, key in enumerate(patient_patches.keys()):
    # Distinct name so the global `patch_list` (all tile paths) is not overwritten
    tile_paths = patient_patches[key]
    true_label = labels[labels["patient_id"] == key]["label"].iloc[0]
    pred_label = patient_pred[key]["pred"]
    # NOTE(review): the prediction cell never stores a "confidence" key, so this is always None
    confidence = patient_pred[key].get("confidence", None)

    correct = (true_label == pred_label)
    color = "green" if correct else "red"

    label_str = f'Patient {key}\nTrue {true_label} | Pred {pred_label}'
    if confidence is not None:
        label_str += f'\nConf: {confidence:.2f}'

    # First column: text label only
    axs[i, 0].axis("off")
    axs[i, 0].text(-0.05, 0.5, label_str, va='center', ha='right', fontsize=10, color=color, transform=axs[i, 0].transAxes)

    # Remaining columns: the selected tiles
    for j in range(n_col):
        ax = axs[i, j + 1]
        # load_image returns BGR (OpenCV); reverse the channel axis so matplotlib shows true colours
        ax.imshow(load_image(tile_paths[j])[..., ::-1])
        ax.tick_params(left=False, right=False, labelleft=False,
                       labelbottom=False, bottom=False)

# Dashed separator between top and bottom tiles (figure coordinates, shifted by the label column)
line_x = (n_top + 1.2) / (n_col + 1)
line = plt.Line2D((line_x, line_x), (.1, .975), linestyle='--', color='red', transform=fig.transFigure)
fig.add_artist(line)

fig.suptitle(f'Patches selected to predict, top {n_top} on the left, bottom {n_bottom} on the right');

In [None]:
# Save the tile grid; the file name records the outer test fold the tiles come from
# (no extension given, so matplotlib appends the default format's extension)
fig.savefig(f'predicted_tiles_fold_{test_fold}', dpi=300, bbox_inches='tight')