In [None]:
import os
SEED = 45
# Set environment variables for reproducibility.
# NOTE: PYTHONHASHSEED is only read when an interpreter starts, so setting it here
# affects child processes (e.g. spawned DataLoader workers), not this one. It must
# be "random" or a decimal integer string -- the literal 'SEED' makes Python abort.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Fixed cuBLAS workspace, required for deterministic CUDA GEMMs once
# torch.use_deterministic_algorithms(True) is enabled; set before torch is imported.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import random
import numpy as np
import torch

# Seed every RNG the pipeline touches: Python, NumPy and PyTorch (CPU and all GPUs).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
# Raise on ops lacking a deterministic implementation and pin cuDNN to
# deterministic, non-autotuned kernels.
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


import pandas as pd
from PIL import Image
import cv2
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, recall_score, f1_score, confusion_matrix as cm
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import StratifiedGroupKFold
import torch.nn as nn
from torchvision import transforms
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.models import MobileNetV2, MobileNet_V2_Weights
from torchvision.models.inception import InceptionOutputs
from torchvision.models import resnet50,ResNet50_Weights, inception_v3,Inception_V3_Weights,mobilenet_v2,MobileNet_V2_Weights




class GradCAM:
    """Grad-CAM heatmap generator bound to one layer of a model.

    A forward hook on ``target_layer`` stores its output and a full backward
    hook stores the gradient flowing into that output. Calling the instance
    back-propagates one class score and returns a JET-coloured heatmap.
    Call :meth:`remove_hooks` when done; otherwise the hooks stay on the model.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # last grad_output[0] captured by the backward hook
        self.activations = None  # last output captured by the forward hook
        self.hook_handles = []

        # Register hooks to capture activations (forward) and gradients (backward)
        self.hook_handles.append(target_layer.register_forward_hook(self.save_activation))
        self.hook_handles.append(target_layer.register_full_backward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        """Forward hook: remember the target layer's output."""
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0]

    @staticmethod
    def _compute_cam(gradients, activations):
        """Combine one sample's gradients and activations into a CAM in [0, 1].

        Args:
            gradients: array of shape (C, H, W), d(score)/d(activation).
            activations: array of shape (C, H, W), the layer's output.

        Returns:
            float32 array of shape (H, W): channels weighted by their spatially
            averaged gradient, summed, ReLU-ed and divided by the maximum. If no
            location is positive the map is all zeros (avoids a 0/0 -> NaN).
        """
        weights = np.mean(gradients, axis=(1, 2))
        # Contract the channel axis: cam[h, w] = sum_c weights[c] * activations[c, h, w]
        cam = np.tensordot(weights, activations, axes=1).astype(np.float32)
        cam = np.maximum(cam, 0)
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam

    def __call__(self, input_tensor, target_category):
        """Return a uint8 (H, W, 3) BGR JET heatmap explaining ``target_category``.

        Args:
            input_tensor: model input of shape (N, C, H, W); only sample 0 is
                visualised. Its ``requires_grad`` flag is set to True in place.
            target_category: index of the output logit to explain.
        """
        self.model.eval()  # Ensure the model is in evaluation mode
        self.model.zero_grad()

        # Guarantees a graph through the hooked layer even if parameters are frozen
        input_tensor.requires_grad = True

        output = self.model(input_tensor)

        # Back-propagate a single scalar: sample 0's score for the target class.
        # (output[:, c] is a vector for N > 1 and .backward() would reject it.)
        target = output[0, target_category]
        self.model.zero_grad()
        target.backward(retain_graph=True)

        gradients = self.gradients.cpu().data.numpy()[0]
        activations = self.activations.cpu().data.numpy()[0]

        cam = np.uint8(255 * self._compute_cam(gradients, activations))
        # cv2.resize takes dsize as (width, height); the tensor layout is (N, C, H, W)
        height, width = input_tensor.shape[2], input_tensor.shape[3]
        cam = cv2.resize(cam, (width, height))
        cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)  # output is in BGR order

        return cam

    def remove_hooks(self):
        """Detach both hooks from ``target_layer``; safe to call more than once."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

            
def _gradcam_output_dir(root_path, title, which_grouping, folder):
    """Create (if needed) and return the folder one Grad-CAM figure is saved in.

    Layout is ``<root_path>/<outcome>/<grouping>/<folder>``. ``outcome`` is
    'Misclassified_images' when ``title == 'M'``, else
    'Correctly_classifed_images'. ``grouping`` is 'Erythema_groups' when
    ``which_grouping`` contains 'Erythema', else 'MST_groups'.
    """
    # NOTE: the 'Correctly_classifed_images' spelling (sic) is kept so existing folders still match.
    outcome_dir = "Misclassified_images" if title == 'M' else "Correctly_classifed_images"
    group_dir = 'Erythema_groups' if 'Erythema' in which_grouping else 'MST_groups'
    save_dir = os.path.join(root_path, outcome_dir, group_dir, folder)
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def visualize_images(grad_cam, image_paths, predictions, transform, which_grouping, title, which_model):
    """Save side-by-side (input image, Grad-CAM overlay) figures for a list of images.

    Args:
        grad_cam: a GradCAM instance; the device of its ``model`` is used for inputs.
        image_paths: '/'-separated image paths; the parent directory name
            (e.g. optical/thermal) becomes the output sub-folder.
        predictions: class index per image, passed to Grad-CAM as the target.
        transform: preprocessing applied to each PIL image before the model.
        which_grouping: category name (e.g. 'Light', 'Erythema<=5') that picks
            the Erythema_groups vs MST_groups tree.
        title: 'M' for misclassified images; anything else means correctly classified.
        which_model: 'MobileNetV2', 'InceptionNetV3' or 'ResNet50'.

    Raises:
        ValueError: if ``which_model`` is not one of the supported names.
    """
    root_paths = {
        'MobileNetV2': 'Update path to save images',
        'InceptionNetV3': 'Update path  to save images',
        'ResNet50': 'Update path to save images',
    }
    if which_model not in root_paths:
        raise ValueError("Unsupported model type. Please use 'MobileNetV2', 'InceptionNetV3', or 'ResNet50'.")
    root_path = root_paths[which_model]

    # Both top-level output folders are created up front, even if one stays empty.
    for outcome_dir in ("Misclassified_images", "Correctly_classifed_images"):
        os.makedirs(os.path.join(root_path, outcome_dir), exist_ok=True)

    # Feed images to whichever device the model lives on (training may leave it on CUDA).
    device = next(grad_cam.model.parameters()).device

    for image_path, prediction in zip(image_paths, predictions):
        # The parent folder name (e.g. optical/thermal) selects the output sub-folder.
        *_, folder, file_name = image_path.split('/')

        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(device)

        cam = grad_cam(image_tensor, prediction)

        # Min-max rescale the normalised input back to a displayable uint8 RGB image;
        # a constant image would otherwise divide 0 by 0.
        input_image = image_tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        low, high = input_image.min(), input_image.max()
        if high > low:
            input_image = (input_image - low) / (high - low)
        else:
            input_image = np.zeros_like(input_image)
        input_image = np.uint8(255 * input_image)

        # cv2.applyColorMap returns BGR but matplotlib displays RGB: convert before
        # blending so high-attention regions render red rather than blue.
        cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)
        if cam.shape[:2] != input_image.shape[:2]:
            # cv2.resize takes dsize as (width, height)
            cam = cv2.resize(cam, (input_image.shape[1], input_image.shape[0]))
        overlay = cv2.addWeighted(input_image, 0.5, cam, 0.5, 0)

        fig = plt.figure(figsize=(10, 5))
        try:
            plt.subplot(1, 2, 1)
            plt.imshow(input_image)
            plt.subplot(1, 2, 2)
            plt.imshow(overlay)
            plt.suptitle(file_name)

            save_dir = _gradcam_output_dir(root_path, title, which_grouping, folder)
            plt.savefig(os.path.join(save_dir, file_name), bbox_inches='tight')
        finally:
            plt.close(fig)

    


    
# Visualizing misclassified Monk Skin tones

# Skin-tone buckets: three Monk-scale ranges plus a two-way Erythema split.
_SKIN_TONE_CATEGORIES = ('Light', 'Medium', 'Dark', 'Erythema<=5', 'Erythema>5')


def _subject_id_from_path(path):
    """Return the subject id: the first two '_'-separated parts of the file name."""
    file_name = path.split('/')[-1]
    return '_'.join(file_name.split('_')[:2])


def _group_by_skin_tone(tone_table, paths, preds):
    """Bucket images (and their predictions) by their subject's Monk group.

    Each image lands in two buckets. The first is Light (1-3), Medium (4-6)
    or Dark (any other value). The second is Erythema<=5 or Erythema>5.
    Within a bucket, images are ordered by first appearance of their Monk
    value, then by input order.

    Args:
        tone_table: DataFrame with 'Subj_ID' and 'Monk_Group' columns; the first
            row matching a subject id is used.
        paths: '/'-separated image paths whose file names start with the subject id.
        preds: prediction per path, parallel to ``paths``.

    Returns:
        (counts, images, predictions): dicts keyed by category name.

    Raises:
        IndexError: if a subject id is missing from ``tone_table``.
    """
    counts = {category: 0 for category in _SKIN_TONE_CATEGORIES}
    images = {category: [] for category in _SKIN_TONE_CATEGORIES}
    predictions = {category: [] for category in _SKIN_TONE_CATEGORIES}

    # Group by raw Monk value first (dicts preserve first-seen order)
    by_tone = {}
    for path, pred in zip(paths, preds):
        subj_id = _subject_id_from_path(path)
        tone = tone_table.loc[tone_table['Subj_ID'] == subj_id, 'Monk_Group'].iloc[0]
        tone_paths, tone_preds = by_tone.setdefault(tone, ([], []))
        tone_paths.append(path)
        tone_preds.append(pred)

    for tone, (tone_paths, tone_preds) in by_tone.items():
        # Both range checks use the integer value, so string/float Monk cells work too
        level = tone if isinstance(tone, int) else int(tone)
        if 1 <= level <= 3:
            mst_category = 'Light'
        elif 4 <= level <= 6:
            mst_category = 'Medium'
        else:
            mst_category = 'Dark'
        erythema_category = 'Erythema<=5' if level <= 5 else 'Erythema>5'

        for category in (mst_category, erythema_category):
            counts[category] += len(tone_paths)
            images[category].extend(tone_paths)
            predictions[category].extend(tone_preds)

    return counts, images, predictions


def _file_names_by_category(images_per_cat, heading):
    """Print and return the bare file names of each category's images."""
    names_per_cat = {}
    for category, paths in images_per_cat.items():
        print(f"{heading} category: {category}")
        names = [path.split('/')[-1] for path in paths]
        for name in names:
            print(name)
        names_per_cat[category] = names
    return names_per_cat


def _save_category_csvs(names_per_cat, root_path, modality_name, tag, description):
    """Write one single-column CSV of file names per category into ``root_path``.

    NOTE(review): category names such as 'Erythema<=5' contain '<' / '>', which
    are not allowed in Windows file names; this assumes a POSIX file system.
    """
    for category, image_list in names_per_cat.items():
        df = pd.DataFrame({f'{category}': image_list})
        save_path = os.path.join(root_path, f"{modality_name}_{tag}_{category}_images.csv")
        df.to_csv(save_path, index=False)
        print(f"Saved {category} {description} images to: {save_path}")


def statistics_by_skin_tone(correctly_classified, correct_preds, misclassified, misclassified_preds, model_name, target_layer, val_transform, modality_name, which_model):
    """Break classification results down by Monk skin tone and visualise them.

    For both the misclassified and the correctly classified images this:
    prints per-category counts and file names, saves one CSV of file names per
    category under the model's root folder, and saves Grad-CAM figures via
    ``visualize_images``.

    Args:
        correctly_classified / misclassified: image paths.
        correct_preds / misclassified_preds: predictions parallel to those paths,
            used as the Grad-CAM target.
        model_name: the model object Grad-CAM hooks into.
        target_layer: layer of ``model_name`` whose activations are visualised.
        val_transform: preprocessing applied before the model.
        modality_name: prefix for the CSV file names.
        which_model: 'MobileNetV2', 'InceptionNetV3' or 'ResNet50'.

    Raises:
        ValueError: if ``which_model`` is not one of the supported names.
    """
    excel_path = r'Update path to file'
    subject_skin_tone_file = pd.read_excel(excel_path)

    root_paths = {
        'MobileNetV2': '/update path to save images',
        'InceptionNetV3': 'Update path to save images',
        'ResNet50': 'Update path to save images',
    }
    if which_model not in root_paths:
        raise ValueError("Unsupported model type. Please use 'MobileNetV2', 'InceptionNetV3', or 'ResNet50'.")
    root_path = root_paths[which_model]

    # Sanity print only; subject lookups below use the whole sheet.
    subject_skin_tone = subject_skin_tone_file[['Subj_ID', 'Monk_Group']].iloc[:35, :]
    print(len(subject_skin_tone))

    misclassified_skin_cat, misclass_images_per_cat, misclassified_pred_per_cat = _group_by_skin_tone(
        subject_skin_tone_file, misclassified, misclassified_preds)
    correctly_classified_skin_cat, correct_images_per_cat, correct_pred_per_cat = _group_by_skin_tone(
        subject_skin_tone_file, correctly_classified, correct_preds)

    print(f"Misclassified skin tone categories: {misclassified_skin_cat}")
    print('%' * 50)
    misclass_names = _file_names_by_category(misclass_images_per_cat, "Misclassified")

    print('%' * 50)
    print(f"Correctly classified skin tone categories: {correctly_classified_skin_cat}")
    correct_names = _file_names_by_category(correct_images_per_cat, "Correctly classified")

    os.makedirs(root_path, exist_ok=True)
    _save_category_csvs(misclass_names, root_path, modality_name, "misclass", "misclassified")
    _save_category_csvs(correct_names, root_path, modality_name, "correct", "correctly classified")

    # Grad-CAM for every correctly classified, then every misclassified, category
    grad_cam = GradCAM(model_name, target_layer)
    try:
        for category, images in correct_images_per_cat.items():
            visualize_images(grad_cam, images, correct_pred_per_cat[category], val_transform, category,
                             title="C", which_model=which_model)
        for category, images in misclass_images_per_cat.items():
            visualize_images(grad_cam, images, misclassified_pred_per_cat[category], val_transform, category,
                             title="M", which_model=which_model)
    finally:
        # Without this the hooks stay on the model and accumulate across calls
        grad_cam.remove_hooks()
     
     

# Erythema dataset: label sheet and per-modality image folders.
# (Module-level `global` declarations were no-ops and have been dropped.)
excel_path = r'Update path to excel file'
erythema_file = pd.read_excel(excel_path)
erythema_path = r'Update path to erythema images'

# os.path.join inserts the separator, so erythema_path need not end with one
# (a raw string literal cannot even end with a backslash).
erythema_optical_path = os.path.join(erythema_path, 'only_cupping_images_optical')
erythema_thermal_bw_path = os.path.join(erythema_path, 'only_cupping_images')
erythema_thermal_color_path = os.path.join(erythema_path, 'only_cupping_images_color')

# Monk skin tone groupings (sheet with 'Subj_ID' and 'Monk_Group' columns)
monk_filepath = r'Update path to file'



#  Load images for each modality
def get_valid_images(image_dir, table=None):
    """Return the listed images of one modality that exist on disk, with labels.

    Args:
        image_dir: folder holding one modality's images.
        table: DataFrame with 'Img Name' and 'Label' columns; defaults to the
            module-level ``erythema_file`` sheet.

    Returns:
        (paths, labels): parallel lists for every listed image found in
        ``image_dir``. Missing files are reported and skipped.
    """
    if table is None:
        table = erythema_file

    train_images = []
    train_labels = []
    for img_name, label in zip(table['Img Name'].values, table['Label'].values):
        # Spreadsheet cells may carry stray leading/trailing whitespace
        image_path = os.path.join(image_dir, img_name.strip())
        if os.path.exists(image_path):
            train_images.append(image_path)
            train_labels.append(label)
        else:
            print(f"Warning: Image not found: {image_path}")
    print('Total number of train images and train labels are:', len(train_images), len(train_labels))

    return train_images, train_labels



def get_fold_train_test(data, labels, test_subjects):
    """Split image paths and labels into train/test lists by subject id.

    The subject id is the first two '_'-separated parts of the file name
    (e.g. '.../P_01_x.png' -> 'P_01'). Images whose subject is in
    ``test_subjects`` go to the test split; all others go to train.

    Returns:
        (train_paths, train_labels, test_paths, test_labels)

    Raises:
        ValueError: if ``data`` and ``labels`` differ in length.
    """
    if len(data) != len(labels):
        raise ValueError(f"data and labels differ in length: {len(data)} != {len(labels)}")

    test_subjects = set(test_subjects)  # O(1) membership checks
    train_fold, train_fold_labels = [], []
    test_fold, test_fold_labels = [], []

    # Pair each path with its own label; a path->label dict would give every
    # duplicated path the label of its last occurrence.
    for file, label in zip(data, labels):
        file_name = file.split('/')[-1]
        subj_id = '_'.join(file_name.split('_')[:2])
        if subj_id in test_subjects:
            test_fold.append(file)
            test_fold_labels.append(label)
        else:
            train_fold.append(file)
            train_fold_labels.append(label)

    return train_fold, train_fold_labels, test_fold, test_fold_labels

# Stratified k-fold cv for folds
def stratified_kfold(n_splits=5):
    """Build subject-level cross-validation folds for the three image modalities.

    The first 35 subjects of the Monk sheet are split with StratifiedGroupKFold.
    The split is stratified on 'Monk_Group' and grouped by 'Subj_ID'. Each fold's
    test subjects are then applied to the optical, thermal B/W and thermal
    colour image sets, so all images of a subject share a split.

    Args:
        n_splits: number of folds (default 5).

    Returns:
        (optical_folds, thermal_bw_folds, thermal_color_folds): one list per
        modality with one dict per fold holding 'train_images', 'train_labels',
        'test_images' and 'test_labels'.
    """
    # Images and labels per modality: optical, thermal B/W, thermal colour
    modalities = [
        get_valid_images(erythema_optical_path),
        get_valid_images(erythema_thermal_bw_path),
        get_valid_images(erythema_thermal_color_path),
    ]
    folds_per_modality = [[] for _ in modalities]

    # Stratify on Monk group, grouping by subject id
    sub_monk_skin_tone = pd.read_excel(monk_filepath).iloc[0:35, :]
    subject_ids = sub_monk_skin_tone['Subj_ID']
    monk_scores = sub_monk_skin_tone['Monk_Group']

    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=SEED)

    for fold_idx, (_, test_idx) in enumerate(sgkf.split(subject_ids, monk_scores, groups=subject_ids)):
        # split() yields positions, so index positionally rather than by label
        test_subjects = subject_ids.iloc[test_idx].to_list()
        print(test_subjects)

        for (images, labels), folds in zip(modalities, folds_per_modality):
            train_images, train_labels, test_images, test_labels = get_fold_train_test(images, labels, test_subjects)
            folds.append({'train_images': train_images,
                          'train_labels': train_labels,
                          'test_images': test_images,
                          'test_labels': test_labels})

        print(f'Fold {fold_idx+1} extracted')

    optical_ery_folds, thermal_bw_ery_folds, thermal_color_ery_folds = folds_per_modality
    return optical_ery_folds, thermal_bw_ery_folds, thermal_color_ery_folds




# def model_train_evaluation(data_in_folds,name):


def model_train_evaluation(data_in_folds,name,model_config,img_size,which_model):
    print(f'Image size is {img_size}')


    global data_name
    batch_size = 32
    image_size =img_size
   
    # model = MobileNetV2(weights=MobileNet_V2_Weights.DEFAULT)
    # model= mobilenet_v2(weights = MobileNet_V2_Weights.DEFAULT)
    # which_model = 'MobileNetV2'
    model =  model_config

    data_name = name    
    model_accuracy = []
    model_auc = []
    model_specificity = []
    model_sensitivity = []
    model_f1 = []
    total_predicted_labels = []
    total_true_labels = []
    # best_overall_acc = 0.0
    all_predictions = []
    all_labels = []
    best_models = {}

    class CustomDataset(Dataset):
        """Image dataset over parallel lists of file paths and raw labels.

        Raw labels are mapped to contiguous class indices. By default the map is
        built from the *sorted* distinct labels. Enumerating a plain set depends
        on set iteration order (hash-randomised for strings), so the sorted map
        is reproducible across runs and identical for any datasets sharing a
        label set. Pass ``label_map`` (e.g. the training set's) when another
        split may be missing a class.
        """

        def __init__(self, image_paths, labels, transform=None, label_map=None):
            self.image_paths = image_paths
            self.labels = labels
            self.transform = transform
            if label_map is None:
                label_map = {label: idx for idx, label in enumerate(sorted(set(labels)))}
            self.label_map = label_map

        def __len__(self):
            return len(self.image_paths)

        def __getitem__(self, idx):
            """Return (transformed RGB image, class index as a long tensor)."""
            image_path = self.image_paths[idx]
            # Context manager closes the file handle promptly (matters in DataLoader loops)
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            label = self.labels[idx]

            if self.transform:
                image = self.transform(image)

            label_idx = self.label_map[label]
            return image, torch.tensor(label_idx, dtype=torch.long)

    train_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomRotation(20),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    val_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def visualize_augmentations(dataset, num_samples=3, img_indices=None):
        """Show each training augmentation separately for a few images.

        For every selected image a 1x4 figure is drawn: the resized image, a
        random rotation within +/-20 degrees, and a horizontal and a vertical
        flip (both forced with p=1.0).

        Parameters:
            dataset: a CustomDataset instance (only ``image_paths`` is read).
            num_samples: number of random images to show when ``img_indices`` is
                not given; capped at ``len(dataset)``.
            img_indices: explicit dataset indices to visualize (optional).

        Returns:
            List of matplotlib figures, one per image.
        """
        # If no specific indices are provided, pick random ones (without replacement)
        if img_indices is None:
            img_indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)

        # Individual components of the training transform
        resize = transforms.Resize(image_size)
        rotate = transforms.RandomRotation(20)
        h_flip = transforms.RandomHorizontalFlip(p=1.0)  # always flip
        v_flip = transforms.RandomVerticalFlip(p=1.0)    # always flip

        figures = []
        for idx in img_indices:
            image_path = dataset.image_paths[idx]
            fig, axes = plt.subplots(1, 4, figsize=(16, 4))

            with Image.open(image_path) as raw_image:
                original_img = raw_image.convert('RGB')
            resized_img = resize(original_img)

            # Evaluated left to right, so the transforms draw randomness in the same order
            panels = [
                (resized_img, 'Original Image'),
                (rotate(resized_img.copy()), 'Random rotation (up to 20 degrees)'),
                (h_flip(resized_img.copy()), 'Horizontal Flip'),
                (v_flip(resized_img.copy()), 'Vertical Flip'),
            ]
            for ax, (panel_img, panel_title) in zip(axes, panels):
                ax.imshow(panel_img)
                ax.set_title(panel_title)
                ax.axis('off')

            plt.tight_layout()
            plt.show()
            figures.append(fig)

        return figures

    # Run model on each fold
    print(f"Model training and evaluation on {name} modality")

    total_correctly_classified =[]
    total_correct_labels =[]
    total_correct_preds = []
    total_misclassified = []
    total_misclass_labels = []
    total_misclass_pred = []

    all_true_labels_across_folds = []
    all_pred_labels_across_folds = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for fold_idx, fold in enumerate(data_in_folds):
        print("="*50)
        
        print(f'Starting Fold {fold_idx+1}')
        print("="*50)
        print(fold.keys())

        fold_train_paths = fold['train_images']
        fold_train_labels = fold['train_labels']
        fold_val_paths = fold['test_images']
        fold_val_labels = fold['test_labels']

        # Print the total training and evaluation images
        print(f'Total training images and labels are:', len(fold_train_paths), len(fold_train_labels))
        print(f'Total evaluation images and labels are:', len(fold_val_paths), len(fold_val_labels))
        
        # Create datasets
        train_dataset = CustomDataset(fold_train_paths, fold_train_labels, transform=train_transform)
        val_dataset = CustomDataset(fold_val_paths, fold_val_labels, transform=val_transform)

        # DataLoader worker init function (redefined for each fold)
        def seed_worker(worker_id):
            """Seed Python's and NumPy's RNGs inside a DataLoader worker.

            Follows the PyTorch reproducibility recipe. Inside a worker,
            ``torch.initial_seed()`` is base_seed + worker_id, and base_seed is
            drawn from the loader's seeded ``generator``. Deriving the other seeds
            from it keeps runs reproducible while giving each worker its own
            stream. A shared constant would make all workers draw identical
            augmentations. Only invoked when the loader uses num_workers > 0.
            """
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        # Create a generator with fixed seed
        g = torch.Generator()
        g.manual_seed(SEED)

        # Create DataLoaders
        train_loader = DataLoader(
            train_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            worker_init_fn=seed_worker,
            generator=g
        )

        val_loader = DataLoader(
            val_dataset, 
            batch_size=batch_size, 
            shuffle=False,
            worker_init_fn=seed_worker,
            generator=g
        )

        # Initialize a completely new model for each fold
        if which_model == 'MobileNetV2':
            # Create a fresh instance of the model
            if 'mobilenet_v2' in globals():
                # If function is available directly
                model_name = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
            else:
                # Fallback to class-based initialization
                model_name = MobileNetV2(weights=MobileNet_V2_Weights.DEFAULT)
            
            # Reset seed before modifying the model
            torch.manual_seed(SEED)
            model_name.classifier[1] = torch.nn.Linear(model_name.classifier[1].in_features, len(set(fold_train_labels)))
            
            # Define initialization function
            def init_weights(m):
                if isinstance(m, nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
                    torch.nn.init.zeros_(m.bias)
            
            # Apply initialization
            model_name.classifier[1].apply(init_weights)
            print(f'Test labels are ', len(set(fold_train_labels)))
        
        elif which_model == 'InceptionNetV3':
            # Create a fresh instance of the model
            model_name = inception_v3(weights=Inception_V3_Weights.DEFAULT)
            
            # Reset seed before modifying the model
            torch.manual_seed(SEED)
            model_name.fc = nn.Linear(model_name.fc.in_features, len(set(fold_train_labels)))
            
            # Define initialization function (if not defined outside)
            def init_weights(m):
                if isinstance(m, nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
                    torch.nn.init.zeros_(m.bias)
            
            # Apply initialization
            model_name.fc.apply(init_weights)
        
        else:  # ResNet50
            # Create a fresh instance of the model
            model_name = resnet50(weights=ResNet50_Weights.DEFAULT)
            
            # Reset seed before modifying the model
            torch.manual_seed(SEED)
            model_name.fc = nn.Linear(model_name.fc.in_features, len(set(fold_train_labels)))
            
            # Define initialization function (if not defined outside)
            def init_weights(m):
                if isinstance(m, nn.Linear):
                    torch.nn.init.xavier_uniform_(m.weight, gain=1.0)
                    torch.nn.init.zeros_(m.bias)
            
            # Apply initialization
            model_name.fc.apply(init_weights)
        
        # Move model to device
        model_name.to(device)

        # Define criterion and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model_name.parameters(), lr=0.001, weight_decay=0.0005)
        
        best_val_auc = 0.0
        patience = 20
        epochs_no_improve = 0
        num_epochs = 100
        # Lists to track training and validation losses
        train_losses = []
        val_losses = []

        # Rest of your training and validation code continues...
        #     # Initialize the model
        #     if which_model =='MobileNetV2':
        #         model_name = model
        #         model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(set(fold_train_labels)))
                
        #         model_name.to(device)
        #         print(f'Test labels are ', len(set(fold_train_labels)))

        #     else:
        #         model_name = model
        #         model_name.fc = nn.Linear(model_name.fc.in_features, len(set(fold_train_labels)))
        #         model_name.to(device)
        #         print(f'Test labels are ', len(set(fold_train_labels)))

        print('Starting training loop')
        for epoch in range(1, num_epochs + 1):
            # Training phase
            model_name.train()
            running_loss = 0.0
            
            for data in train_loader:
                inputs, targets = data
                inputs, targets = inputs.to(device), targets.to(device)
                
                optimizer.zero_grad()
                outputs = model_name(inputs)
                
                if isinstance(outputs, InceptionOutputs):
                    logits = outputs.logits
                    aux_logits = outputs.aux_logits
                    loss = criterion(logits, targets) + 0.4 * criterion(aux_logits, targets)  # Add auxiliary loss
                else:
                    logits = outputs
                    loss = criterion(logits, targets)
                
                loss.backward()
                optimizer.step()
                running_loss += loss.item() * inputs.size(0)
            
            epoch_train_loss = running_loss / len(train_loader.dataset)
            train_losses.append(epoch_train_loss)
            
            # Validation phase
            model_name.eval()
            val_corrects = 0
            val_total = 0
            val_all_labels = []
            val_all_probs = []
            val_all_preds = []
            val_running_loss = 0.0

            correctly_classified = []
            correct_labels = []
            correct_preds = []
            misclassified = []
            misclass_labels = []
            misclass_preds = []

            with torch.no_grad():
                for i, (inputs, targets) in enumerate(val_loader):
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model_name(inputs)
                    
                    # Handle output for InceptionNetV3 and ResNet50
                    if which_model == 'InceptionNetV3' and isinstance(outputs, InceptionOutputs):
                        logits = outputs.logits  # Use logits for classification
                    else:
                        logits = outputs  # For ResNet50 and MobileNet, output is already logits
                    
                    # Calculate validation loss
                    val_loss = criterion(logits, targets)
                    val_running_loss += val_loss.item() * inputs.size(0)
                    
                    _, preds = torch.max(logits, 1)
                    
                    # Track image paths for analyzing errors
                    start_idx = i * val_loader.batch_size
                    for j in range(len(preds)):
                        if start_idx + j < len(fold_val_paths):  # Ensure index is in range
                            image_path = fold_val_paths[start_idx + j]
                            label = targets[j].item()
                            pred = preds[j].item()
                            if pred == label:
                                correctly_classified.append(image_path)
                                correct_labels.append(label)
                                correct_preds.append(pred)
                            else:
                                misclassified.append(image_path)
                                misclass_labels.append(label)
                                misclass_preds.append(pred)
                    
                    # Update metrics
                    val_total += targets.size(0)
                    val_corrects += (preds == targets).sum().item()
                    val_all_labels.extend(targets.cpu().numpy())
                    val_all_probs.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
                    val_all_preds.extend(preds.cpu().numpy())
            
            # Calculate epoch validation loss
            epoch_val_loss = val_running_loss / len(val_loader.dataset)
            val_losses.append(epoch_val_loss)
            # scheduler.step(epoch_val_loss)
            
            # Compute metrics
            val_acc = val_corrects / val_total
            val_auc = roc_auc_score(val_all_labels, val_all_probs)
            sensitivity = recall_score(val_all_labels, val_all_preds, pos_label=1)
            tn,fp,fn,tp = cm(val_all_labels, val_all_preds, labels= [0,1]).ravel()
            specificity = tn / (tn + fp)
            f1 = f1_score(val_all_labels, val_all_preds,pos_label=1)


            print(
                f"Epoch {epoch}/{num_epochs}, "
                f"Train Loss: {epoch_train_loss:.4f}, "
                f"Val Loss: {epoch_val_loss:.4f}, "
                f"Val Acc: {val_acc:.3f}, "
                f"AUC: {val_auc:.3f}, "
                f"Sens: {sensitivity:.3f}, "
                f"Spec: {specificity:.3f}, "
                f"F1: {f1:.3f}"
            )


            

            # Update all_predictions and all_labels for saving later
            all_predictions.extend(val_all_preds)
            all_labels.extend(val_all_labels)

            # modification done on Februray 26th,2025
    
            # Save the best model for this fold
                # Check and update the best model for this fold
            if val_auc > best_val_auc:
                best_val_auc= val_auc
                epochs_no_improve = 0
                best_model_state = model_name.state_dict()  # Save the state dict of the best model for this fold
                best_val_metrics = {
                        'val_acc': val_acc,
                        'val_auc': val_auc,
                        'sensitivity': sensitivity,
                        'specificity': specificity,
                        'f1': f1
                    } 
                temp_correctly_classified = correctly_classified
                temp_correct_labels = correct_labels
                temp_correct_preds = correct_preds
                temp_misclassified = misclassified
                temp_misclass_labels = misclass_labels
                temp_misclass_pred = misclass_preds
                temp_val_all_labels = val_all_labels
                temp_val_all_preds = val_all_preds
                
            else:
                epochs_no_improve += 1
                print(f'No performance gains at:  {epochs_no_improve}')
            # Early stopping
            if epochs_no_improve >= patience:
                print(f'Early stopping at epoch {epoch} for fold {fold_idx + 1}')
                break
        

        # Plot training and validation loss
        # plt.figure(figsize=(10, 5))
        # plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
        # plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Loss')
        # plt.xlabel('Epochs')
        # plt.ylabel('Loss')
        # plt.title(f'Training and Validation Loss - {name} Fold {fold_idx + 1}')
        # plt.legend()
        # plt.grid(True)
        
        # Define the save directory based on model type
        if which_model == 'MobileNetV2':
            save_dir = 'Update path'
        elif which_model == 'InceptionNetV3':
            save_dir = 'Update path'
        else:  # ResNet50
            save_dir = 'Update path'
            
        # Ensure directory exists
        os.makedirs(save_dir, exist_ok=True)
        
        # Save loss plot
        # plt.savefig(f'{save_dir}{name}_{which_model}_fold_{fold_idx+1}_loss_plot.png')
        # plt.show()

        # Update model_accuracy, model_auc, model_specificity, model_sensitivity, and model_f1
        model_accuracy.append(best_val_metrics['val_acc'])
        model_auc.append(best_val_metrics['val_auc'])
        model_specificity.append(best_val_metrics['specificity'])
        model_sensitivity.append(best_val_metrics['sensitivity'])
        model_f1.append(best_val_metrics['f1'])


            # best_auc_across_folds.append(f'{model_auc:.3f}')

        total_correctly_classified.extend(temp_correctly_classified)
        total_misclassified.extend(temp_misclassified)

        total_correct_preds.extend(temp_correct_preds)
        total_misclass_pred.extend(temp_misclass_pred)

        total_correct_labels.extend(temp_correct_labels)
        total_misclass_labels.extend(temp_misclass_labels)

        all_true_labels_across_folds.extend(temp_val_all_labels)
        all_pred_labels_across_folds.extend(temp_val_all_preds)




    # Get the mean and standard deviation of evaluation metrics
    model_accuracy_mean = np.mean(model_accuracy)
    model_accuracy_std = np.std(model_accuracy)
    model_auc_mean = np.mean(model_auc)
    model_auc_std = np.std(model_auc)
    model_specificity_mean = np.mean(model_specificity)
    model_specificity_std = np.std(model_specificity)
    model_sensitivity_mean = np.mean(model_sensitivity)
    model_sensitivity_std = np.std(model_sensitivity)
    model_f1_mean = np.mean(model_f1)
    model_f1_std = np.std(model_f1)

    #print metrics: mean and std
    print(f'Mean and standard deviation of evaluation metrics for {which_model}{name}:')
    print(f'Accuracy: {model_accuracy_mean:.3f}± {model_accuracy_std:.3f}')
    print(f'AUC: {model_auc_mean:.3f} ± {model_auc_std:.3f}')
    print(f'Sensitivity: {model_sensitivity_mean:.3f} ± {model_sensitivity_std:.3f}')
    print(f'Specificity: {model_specificity_mean:.3f} ± {model_specificity_std:.3f}')
    print(f'F1: {model_f1_mean:.3f} ± {model_f1_std:.3f}')

   


    # print(f'Correctly classified labels: {best_classification_outcomes["correctly_classified_labels"]}')
    # print(f'Misclassified labels: {best_classification_outcomes["misclassified_labels"]}')


    # Show the list of top aucs per fold
    print(f'Best AUCs across all folds for {name}: {model_auc}')
    # Confusion matrix across 5-folds 

    # all_val_true_labels =  classification_outcomes["true_labels"]
    # all_val_predicted_labels = classification_outcomes["predicted_labels"]
    # all_val_true_labels = all_labels
    # all_val_predicted_labels = all_predictions
    # Create confusion matrix display
    disp = ConfusionMatrixDisplay.from_predictions(
    all_true_labels_across_folds, 
    all_pred_labels_across_folds, 
    display_labels=['No Erythema','Erythema'],
    cmap='YlGnBu'
)

   # Access the axis object and set font sizes
    ax = disp.ax_
    plt.setp(ax.get_yticklabels(), fontsize=12)  # For y-axis labels
    plt.setp(ax.get_xticklabels(), fontsize=12)  # For x-axis labels
    # modify the fontsize of the x and y axes
    ax.set_xlabel(ax.get_xlabel(), fontsize=15)
    ax.set_ylabel(ax.get_ylabel(), fontsize=15)

    # To adjust the numbers inside the cells
    for im in ax.images:
        im.colorbar.ax.tick_params(labelsize=16)  # For colorbar text
        
        # Adjust only the numbers inside the cells
        for text in disp.ax_.texts:
            text.set_fontsize(18)  

    # Define the save directory based on model type
    if which_model == 'MobileNetV2':
        save_dir = 'update path'
    elif which_model == 'InceptionNetV3':
        save_dir = '/update path'
    else:  # ResNet50
        save_dir = ''
        
    # Ensure directory exists
    os.makedirs(save_dir, exist_ok=True)
    
    plt.title(f'{name} {which_model} Confusion Matrix')
    plt.savefig(f'{save_dir}{name}_{which_model}_confusion_matrix.png')
    plt.show()

     # Save the metrics to file
    output_path = os.path.join(save_dir, f'{name}_{which_model}_metrics.txt')
    with open(output_path, 'w') as f:
        f.write(f'Mean and standard deviation of evaluation metrics for {which_model}{name}:\n')
        f.write(f'Accuracy: {model_accuracy_mean:.3f}± {model_accuracy_std:.3f}\n')
        f.write(f'AUC: {model_auc_mean:.3f} ± {model_auc_std:.3f}\n')
        f.write(f'Sensitivity: {model_sensitivity_mean:.3f} ± {model_sensitivity_std:.3f}\n')
        f.write(f'Specificity: {model_specificity_mean:.3f} ± {model_specificity_std:.3f}\n')
        f.write(f'F1: {model_f1_mean:.3f} ± {model_f1_std:.3f}\n')


    # GRADCAM analysis for all models
    # Define model-specific target layers for GradCAM
    if which_model == 'MobileNetV2':
        target_layer = model_name.features[-1]
    elif which_model == 'InceptionNetV3':
        target_layer = model_name.Mixed_7c
    else:  # ResNet50
        target_layer = model_name.layer4[-1]
        
    validation_transform = val_transform
    
   
    # Make sure to pass which_model to the statistics_by_skin_tone function
    statistics_by_skin_tone(
        total_correctly_classified, 
        total_correct_preds, 
        total_misclassified, 
        total_misclass_pred, 
        model_name, 
        target_layer, 
        validation_transform, 
        name,
        which_model=which_model  # Pass the model type to save in correct directory
    )
    # except Exception as e:
    #     print(f"Error saving predictions: {e}")


#######################################################################################


# Calling model training and evaluation
#
# Build the k-fold splits once, then train/evaluate every model architecture on
# every dataset modality. All RNGs are re-seeded before each model/dataset run
# so each combination starts from the same random state regardless of the
# order in which the combinations are executed.

# stratified_kfold() is called exactly once. The fold lists captured in
# `datasets` below are the ones that get trained on, so calling it a second
# time would only redo the split and rebind names nothing reads afterwards.
optical_ery_folds, thermal_bw_ery_folds, thermal_color_ery_folds = stratified_kfold()

# Models to run. Alternative subsets used during experimentation:
# models = ['InceptionNetV3', 'ResNet50']
# models = ['MobileNetV2']
models = ['MobileNetV2', 'InceptionNetV3', 'ResNet50']

# Dataset modalities to train on. thermal_bw_ery_folds is produced by
# stratified_kfold() but is not used in this run.
datasets = [
    {'train_images': optical_ery_folds, 'name': 'Optical Erythema'},
    {'train_images': thermal_color_ery_folds, 'name': 'Thermal Color Erythema'},
]


def _reset_random_state(seed):
    """Re-seed the Python, NumPy and PyTorch RNGs (CPU and, if available, all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


# Run each model-modality combination with fresh random states
for model_a in models:
    for dataset in datasets:
        print("=" * 80)
        print("Training with patience of 20 + weight decay of 0.0005".upper())
        print(f"\nSTARTING NEW RUN: {model_a} on {dataset['name']}")
        print("=" * 80)

        # Reset ALL random states before each model-dataset combination so a
        # run's results do not depend on which combinations ran before it.
        _reset_random_state(SEED)

        # Load a fresh pretrained model and the input size used for it
        # (299x299 for InceptionNetV3, 224x224 for the other architectures).
        if model_a == 'MobileNetV2':
            model_loaded = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
            image_size = [224, 224]
        elif model_a == 'InceptionNetV3':
            model_loaded = inception_v3(weights=Inception_V3_Weights.DEFAULT)
            image_size = [299, 299]
        else:  # ResNet50
            model_loaded = resnet50(weights=ResNet50_Weights.DEFAULT)
            image_size = [224, 224]

        # Call model training with all the necessary parameters
        model_train_evaluation(
            data_in_folds=dataset['train_images'],
            name=dataset['name'],
            model_config=model_loaded,
            img_size=image_size,
            which_model=model_a,
        )

        print(f"COMPLETED: {model_a} on {dataset['name']}")
        print("-" * 80)


