# **MONKEYPOX DETECTION USING DEEP LEARNING MODELS AND EXPLAINABLE AI WITH FEDERATED LEARNING**



---



RATIONALE FOR PROJECT

Mpox, formerly known as monkeypox, is a rare but potentially serious viral disease caused by the monkeypox virus, a species of the genus Orthopoxvirus (World Health Organization, 2024). The virus has two distinct clades: clade I (with subclades Ia and Ib) and clade II (with subclades IIa and IIb), with clade II being the clade behind the global outbreak ongoing in 2024 (CDC, 2024). That outbreak has caused more than 100,000 cases in 122 countries, including 115 countries where Mpox was not reported previously (CDC, 2024). This project examines how Machine Learning models can be used to diagnose Mpox accurately through classification and detection techniques while accounting for racial bias.

The fatality rate of Mpox can reach up to 11% depending on the strain and the health condition of the affected individual, with clade I having a mortality rate of around 3.6% (Mpox Virus: Clade I and Clade II, n.d.). There is no specific treatment for Mpox. Antiviral medications such as tecovirimat (TPOXX) can help mitigate symptoms, but these treatments are not accessible in underprivileged areas (CDC, 2024). In the Democratic Republic of Congo, ongoing violence has severely hindered efforts to control Mpox, resulting in higher transmission rates and delayed responses.

The traditional diagnostic methods are PCR testing and serological testing. PCR, the primary method, uses samples from skin lesions and is preferred for its high sensitivity and specificity. Serological tests detect antibodies but are less reliable because of cross-reactivity with other orthopoxviruses (Khehra, Padda and Swift, 2023). Both methods have limitations, including the need for specialized equipment and trained personnel and the potential for false negatives due to viral mutations. Enhancing diagnostic accuracy and accessibility, especially in resource-limited environments, therefore requires a reliable computer-based framework for detecting Mpox. The goal is to provide an accessible, reliable tool that aids healthcare professionals in conflict zones and underprivileged areas.

AIMS AND OBJECTIVES

AIM

To implement a robust, privacy-preserving, racially fair, and explainable Deep Learning framework integrated with Federated Learning for early and accurate Mpox detection.

OBJECTIVES

1.    Implement a Deep Learning model using Transfer Learning and Federated Learning on diverse skin lesion datasets to accurately classify Mpox.

2.    Assess and mitigate the impact of skin tone variations on model performance to ensure racial fairness in diagnosis.

3.    Integrate and evaluate Explainable AI techniques such as Grad-CAM and LIME to improve transparency and trustworthiness of the model's predictions.

4.    Benchmark the proposed framework against traditional diagnostic methods and recent Deep Learning approaches using robust cross-validation techniques.
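
Objective 1 combines Transfer Learning with Federated Learning. As a rough illustration of the aggregation step only, the sketch below applies federated averaging (FedAvg) to per-client weight arrays using NumPy. The function name `federated_average`, the toy client weights, and the sample counts are assumptions made for illustration and are not taken from this notebook's pipeline.

```python
import numpy as np

def federated_average(client_weights, client_sizes):
    """Weighted average of per-client model weights (FedAvg aggregation step).

    client_weights: one entry per client, each a list of np.ndarray layers.
    client_sizes:   number of local training samples held by each client.
    """
    total = float(sum(client_sizes))
    n_layers = len(client_weights[0])
    averaged = []
    for layer in range(n_layers):
        # Each client's layer contributes in proportion to its data size.
        layer_sum = sum(
            (size / total) * weights[layer]
            for weights, size in zip(client_weights, client_sizes)
        )
        averaged.append(layer_sum)
    return averaged

# Hypothetical example: two clients, one 2x2 layer each.
client_a = [np.ones((2, 2))]
client_b = [np.zeros((2, 2))]
global_weights = federated_average([client_a, client_b], client_sizes=[300, 100])
print(global_weights[0])  # every entry is 0.75
```

In a Keras setting the per-client lists would typically come from `model.get_weights()` and the result would be pushed back with `model.set_weights()`; only the averaged weights leave each site, not the images.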





---



## 1. IMPORT IMPORTANT LIBRARIES

In [1]:
pip install imgaug opencv-python

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install "numpy<2"

Note: you may need to restart the kernel to use updated packages.


In [3]:
!pip install tensorflow



In [4]:
import os
import shutil
import zipfile
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from skimage import io
%matplotlib inline

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import resample

#import tensorflow as tf
#import imgaug.augmenters as iaa
#from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import DenseNet201
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense, BatchNormalization, Dropout
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adamax
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping



In [5]:
import numpy as np
import seaborn as sns
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from skimage import io
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    mean_squared_error,
    precision_recall_fscore_support,
    roc_auc_score,
)
from tensorflow.keras.applications import ResNet152V2
from tensorflow.keras.layers import Input, GlobalAveragePooling2D, Dense, BatchNormalization, Dropout
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adamax
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

## 2. IMPORT DATASETS AND CREATION OF DATASETS

In [9]:
# Replace with your image path
image_path = '/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Keywords.jpg'

# Open and display the image
image = Image.open(image_path)
plt.imshow(image)
plt.axis('off')  # Hide axis
plt.show()

FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Keywords.jpg'

In [8]:
log=pd.read_excel('/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/datalog.xlsx')
print(log)

FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/datalog.xlsx'

In [None]:
# Define paths (these are already directories, not zip files)
aug_folder = "/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Augmented Images"
orig_folder = "/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Original Images"

# Example: List some image files from the augmented folder
for root, dirs, files in os.walk(aug_folder):
    for file in files[:5]:  # Just print first 5 files
        print(os.path.join(root, file))
    break  # Remove this if you want to walk through all subdirectories

# Alternative for Google Colab, where the datasets are stored as ZIP archives on Drive
# Paths to ZIP files
aug_zip = "/content/drive/MyDrive/PROJECT DATASETS/Augmented Images.zip"
orig_zip = "/content/drive/MyDrive/PROJECT DATASETS/Original Images.zip"

# Unzip both folders
with zipfile.ZipFile(aug_zip, 'r') as zip_ref:
    zip_ref.extractall("Augmented_Images")

with zipfile.ZipFile(orig_zip, 'r') as zip_ref:
    zip_ref.extractall("Original_Images")

### Augmented Images

In [None]:
# Classes (folder names)
#diagnosis_classes = ['Monkeypox', 'Chickenpox', 'Measles', 'Cowpox', 'Healthy', 'HFMD']

In [None]:
import os
import pandas as pd

# Adjust this to your actual root path
base_dir = "/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Augmented Images/Augmented Images/FOLDS_AUG"

# Define folds and splits
folds = ['fold1_AUG', 'fold2_AUG', 'fold3_AUG', 'fold4_AUG', 'fold5_AUG']
splits = ['Train']

# Initialize records list
records = []

for fold in folds:
    for split in splits:
        split_path = os.path.join(base_dir, fold, split)
        if not os.path.exists(split_path):
            continue

        # Dynamically detect class folders under each split
        for diagnosis in os.listdir(split_path):
            class_path = os.path.join(split_path, diagnosis)
            if not os.path.isdir(class_path):
                continue  # Skip files

            # Iterate over image files
            for img_file in os.listdir(class_path):
                if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    image_id = os.path.splitext(img_file)[0]
                    lesion_id = "_".join(image_id.split('_')[:3])
                    image_path = os.path.join(class_path, img_file)

                    records.append({
                        "lesion_Id": lesion_id,
                        "image_Id": image_id,
                        "diagnosis": diagnosis,
                        "image_Path": image_path,
                        "fold": fold,
                        "split": split
                    })

# Convert to DataFrame
df = pd.DataFrame(records)

# Save to CSV (separate file name so the original-image metadata saved later does not overwrite it)
df.to_csv("augmented_folds_dataset.csv", index=False)
print(f"✅ Saved dataset with {len(df)} entries to 'augmented_folds_dataset.csv'")
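
The lesion identifier in the cell above is derived from the file name by keeping the first three underscore-separated tokens. A minimal sketch of that parsing follows. The helper name `parse_ids` is hypothetical, and the two file names are built from IDs that appear elsewhere in this notebook (`MKP_124_02_ORIGINAL`, `HEALTHY_104_01`) with an assumed `.jpg` extension.

```python
import os

def parse_ids(file_name):
    """Return (image_id, lesion_id) for an image file name."""
    image_id = os.path.splitext(file_name)[0]      # drop the extension
    lesion_id = "_".join(image_id.split("_")[:3])  # keep the first three tokens
    return image_id, lesion_id

print(parse_ids("MKP_124_02_ORIGINAL.jpg"))  # ('MKP_124_02_ORIGINAL', 'MKP_124_02')
print(parse_ids("HEALTHY_104_01.jpg"))       # ('HEALTHY_104_01', 'HEALTHY_104_01')
```

A name with three or fewer tokens is returned unchanged as the lesion ID, so the grouping only works if every file follows the same naming pattern.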

In [None]:
df

### Original Datasets

In [None]:
import os
import pandas as pd

# Adjust this to your actual root path
base_dir = "/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Original Images/Original Images/FOLDS"

# Define folds and splits
folds = ['fold1', 'fold2', 'fold3', 'fold4', 'fold5']
splits = ['Train', 'Test', 'Valid']

# Initialize records list
records = []

for fold in folds:
    for split in splits:
        split_path = os.path.join(base_dir, fold, split)
        if not os.path.exists(split_path):
            continue

        # Dynamically detect class folders under each split
        for diagnosis in os.listdir(split_path):
            class_path = os.path.join(split_path, diagnosis)
            if not os.path.isdir(class_path):
                continue  # Skip files

            # Iterate over image files
            for img_file in os.listdir(class_path):
                if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    image_id = os.path.splitext(img_file)[0]
                    lesion_id = "_".join(image_id.split('_')[:3])
                    image_path = os.path.join(class_path, img_file)

                    records.append({
                        "lesion_Id": lesion_id,
                        "image_Id": image_id,
                        "diagnosis": diagnosis,
                        "image_Path": image_path,
                        "fold": fold,
                        "split": split
                    })

# Convert to DataFrame
df_all = pd.DataFrame(records)

# Save to CSV
df_all.to_csv("all_folds_dataset.csv", index=False)
print(f"✅ Saved dataset with {len(df_all)} entries to 'all_folds_dataset.csv'")

In [None]:
df_all

In [None]:
df_all['diagnosis'].nunique()

In [None]:
# Concatenate all dataframes
combined_df = pd.concat([df,df_all], ignore_index=True)

# Optional: Shuffle the combined dataset
combined_df = combined_df.sample(frac=1).reset_index(drop=True)

# Save to CSV
combined_df.to_csv("combined_dataset.csv", index=False)

print("✅ Concatenated dataset saved as 'combined_dataset.csv'")

In [None]:
combined_df

In [None]:
# Load the dataset
df = pd.read_csv("combined_dataset.csv")

## 2.1 EXPLORATORY DATA ANALYSIS

In [None]:
print("📊 Basic Info:")
print(df.info())

This dataset is a DataFrame containing 40,819 entries and 6 columns, all of which are of type object. Each row represents a skin image sample with the following details (at this point the columns are still named lesion_Id, image_Id, and so on; they are renamed in Section 3.1):

1. Lesion_Id: Identifier for each lesion.
2. Image_Id: Identifier for each image.
3. Diagnosis: The skin disease label (e.g., Mpox, Chickenpox).
4. Image_Path: The file path to the image.
5. Fold: Indicates the fold assignment (e.g., for cross-validation).
6. Split: Indicates the data split (e.g., Train, Test, Validation).

There are no missing values in the dataset.

In [None]:
print("\n🔍 First few rows:")
print(df.head())

In [None]:
unique_counts = df.nunique()

print("\nUnique entries in each column:")
print(unique_counts)

Here's a summary of the unique entries in each column of the dataset:

*   Lesion_Id (755 unique): Many images belong to the same lesion, indicating multiple views per lesion.
*   Image_Id (11,325 unique): Multiple records share the same image ID, most likely because the same image appears in more than one fold.
*   Diagnosis (6 unique): There are six distinct classes (Monkeypox, Chickenpox, Measles, Cowpox, HFMD, and Healthy).
*   Image_Path (40,819 unique): Each entry refers to a unique image file path.
*   Fold (10 unique): Five augmented folds (fold1_AUG to fold5_AUG) and five original folds (fold1 to fold5), used for cross-validation.
*   Split (3 unique): Images are categorized into Train, Validation, and Test sets.

This structure supports tasks like classification, lesion-level grouping, and model evaluation using cross-validation.
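
Because many images share a lesion, lesion-level grouping matters at evaluation time: images of one lesion should not sit in both the training and the test data of the same fold. The sketch below checks for that on a toy frame. The function `find_leaky_lesions` and the toy rows are assumptions for illustration; the column names are the lowercase ones present at this stage of the notebook.

```python
import pandas as pd

def find_leaky_lesions(frame, train_label="Train", test_label="Test"):
    """Return (fold, lesion_Id) pairs that occur in both the train and test split."""
    leaks = []
    for fold, fold_rows in frame.groupby("fold"):
        train_ids = set(fold_rows.loc[fold_rows["split"] == train_label, "lesion_Id"])
        test_ids = set(fold_rows.loc[fold_rows["split"] == test_label, "lesion_Id"])
        leaks.extend((fold, lesion) for lesion in sorted(train_ids & test_ids))
    return leaks

# Hypothetical rows: MKP_01_01 is in both Train and Test of fold1.
toy = pd.DataFrame({
    "lesion_Id": ["MKP_01_01", "MKP_01_01", "CHP_02_01", "CHP_02_01"],
    "fold":      ["fold1",     "fold1",     "fold1",     "fold2"],
    "split":     ["Train",     "Test",      "Train",     "Test"],
})
print(find_leaky_lesions(toy))  # [('fold1', 'MKP_01_01')]
```

An empty list would indicate that, within each fold, no lesion is shared between the two splits being compared.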

In [None]:
# Replace with your image path
image_path = '/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Keywords.jpg'

# Open and display the image
image = Image.open(image_path)
plt.imshow(image)
plt.axis('off')  # Hide axis
plt.show()

In [None]:
df.describe(include='all')

Here is a brief summary of the dataset based on the provided statistics:

* Total Records: 40,819 images

* Lesion_Id: 755 unique lesions, with the most frequent lesion (MKP_78_05) appearing 75 times, suggesting multiple images per lesion.

* Image_Id: 11,325 unique images, with the most common (HEALTHY_104_01) appearing 5 times, indicating some images may be reused or augmented.

* Diagnosis: 6 skin disease categories, with Monkeypox being the most common diagnosis (15,490 images).

* Image_Path: All 40,819 entries have unique paths, confirming each path refers to a distinct image.

* Fold: 10 unique folds (e.g., fold5_AUG) for cross-validation, with fold5_AUG being the most populated (7,532 images).

* Split: 3 data splits (Train, Validation, Test), with the Train split containing the majority (39,690 images).

This dataset is structured for lesion-level classification with support for robust training and evaluation through folds and splits.

In [None]:
print(df.isnull().sum())

This output shows that there are no missing values in the dataset: all 40,819 rows have complete entries across all six columns.

✅ The metadata is complete, so no imputation or removal of incomplete records is needed before model training and evaluation.

In [None]:
# Set plot style
sns.set(style="whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)

In [None]:
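# Column names as they exist before the rename in Section 3.1
categorical_columns = ['lesion_Id', 'image_Id', 'diagnosis']

In [None]:
# Print the value counts of every categorical column
for column in categorical_columns:
    print(f"\nValue counts for {column}:")
    print(df[column].value_counts())

df['diagnosis'].value_counts()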

In [None]:
plt.figure()
sns.countplot(data=df, x='diagnosis', order=df['diagnosis'].value_counts().index)
plt.title('Diagnosis Distribution')
plt.xticks(rotation=45)
plt.show()

In [None]:
plt.figure()
sns.countplot(x='split', data=df)
plt.title("Image Count per Data Split (Train/Test/Validation)")
plt.show()

In [None]:
plt.figure()
sns.countplot(x='fold', data=df)
plt.title("Image Count per Fold")
plt.show()

In [None]:
plt.figure()
sns.countplot(data=df, x='diagnosis', hue='split', order=df['diagnosis'].value_counts().index)
plt.title("Diagnosis Count per Split")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
plt.figure()
sns.countplot(data=df, x='diagnosis', hue='fold', order=df['diagnosis'].value_counts().index)
plt.title("Diagnosis Count per Fold")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
lesion_counts = df.groupby("diagnosis")["lesion_Id"].nunique().sort_values(ascending=False)
print("\n🦠 Unique Lesion IDs per Diagnosis:")
print(lesion_counts)

lesion_counts.plot(kind='bar', title="Unique Lesions per Diagnosis", ylabel="Lesion Count")
plt.xticks(rotation=45)
plt.show()

## 3. DATA PREPROCESSING

### 3.1 Rename Feature Names

In [None]:
df = df.rename(columns={
    "lesion_Id": 'Lesion_Id',
    "image_Id": 'Image_Id',
    "diagnosis": 'Diagnosis',
    "image_Path": 'Image_Path',
    "fold": 'Fold',
    "split": 'Split'
})

In [None]:
df.info()

In [None]:
# Replace with your image path
image_path = '/kaggle/input/mpox-skin-lesion-dataset-version-20-msld-v20/Keywords.jpg'

# Open and display the image
image = Image.open(image_path)
plt.imshow(image)
plt.axis('off')  # Hide axis
plt.show()

### 3.2 Map Classes To Disease Code

In [None]:
diagnosis_mapping = {
    'Monkeypox': 'MKP',
    'Chickenpox': 'CHP',
    'Measles': 'MSL',
    'Cowpox' : 'CWP',
    'Healthy': 'HEALTHY',
    'HFMD': 'HFMD'
}

df['Updated_Diagnosis'] = df['Diagnosis'].map(diagnosis_mapping)
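
`Series.map` returns NaN for any value that is missing from the mapping dictionary, so a class folder with an unexpected name would silently produce a missing code. A small check along the following lines can surface that. The shortened mapping and the `Rubella` entry are made up purely to trigger the case.

```python
import pandas as pd

# Hypothetical, shortened mapping and labels for illustration only.
mapping = {"Monkeypox": "MKP", "Chickenpox": "CHP"}
labels = pd.Series(["Monkeypox", "Chickenpox", "Rubella"])

codes = labels.map(mapping)                      # unknown labels become NaN
unmapped = labels[codes.isna()].unique().tolist()
print(unmapped)  # ['Rubella']
```

Running the same two lines on `df['Diagnosis']` with `diagnosis_mapping` should give an empty list when every class has a code.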

In [None]:
df.head()

In [None]:
df.info()

### 3.3 Label Encoder

In [None]:
# Use correct column name
label_encoder = LabelEncoder()
df['Updated_Diagnosis_Label'] = label_encoder.fit_transform(df['Diagnosis'])
# Get mapping from label to encoded value
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
# Display result
print("✅ Label Encoding Mapping:")
print(label_mapping)

In [None]:
df.head()

### 3.4 View Samples of the Disease Classes

In [None]:
for i in range(len(df)):
    if not os.path.isfile(df['Image_Path'].iloc[i]):
        raise FileNotFoundError(f"Image file not found: {df['Image_Path'].iloc[i]}")

class_images = df.drop_duplicates(subset='Diagnosis')

def plot_class_images(class_images):
    plt.figure(figsize=(15, 5))
    num_classes = len(class_images)
    for i in range(num_classes):
        plt.subplot(1, num_classes, i + 1)
        img = plt.imread(class_images['Image_Path'].iloc[i])
        plt.imshow(img)
        plt.axis('off')
        plt.title(class_images['Diagnosis'].iloc[i], fontsize=12)
    plt.tight_layout()
    plt.show()
plot_class_images(class_images)

# IMAGE PROCESSING TO ELIMINATE SKIN TONE VARIATION

In [None]:
pip install pandas opencv-python tqdm

In [None]:
# Save the metadata to CSV so the augmentation cell below can reload it
# (the path must match csv_path in the settings of that cell)
output_csv_path = '/kaggle/working/output.csv'
df.to_csv(output_csv_path, index=False)

print(f"CSV saved successfully at {output_csv_path}")

In [None]:
# ==== SETTINGS ====
csv_path = r'/kaggle/working/output.csv'              # Path to your input CSV file
image_column = 'Image_Path'                     # Column name for image file paths
image_id_column = 'Image_Id'                    # Column name for Image IDs (e.g., MKP_124_02_ORIGINAL)
output_dir = r'/kaggle/working/'        # Directory to save augmented images
output_csv = r'augmented_images.csv'            # Output CSV to save new augmented paths and IDs
num_augments = 180                              # Number of augmentations per image
np.random.seed(42)                              # Set seed for reproducibility

# Create output directory if not exists
os.makedirs(output_dir, exist_ok=True)

# ==== HSV Color Space Augmentation Function ====
def hsv_color_space_augmentation(image, num_augments=180):
    augmented_images = []
    for i in range(num_augments):
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float32)
        
        # Shift hue (0 to 179 incrementally)
        hue_shift = i % 180
        hsv[..., 0] = (hsv[..., 0] + hue_shift) % 180

        # Slight random variation for Saturation and Value
        sat_mult = np.random.uniform(0.9, 1.1)
        val_mult = np.random.uniform(0.9, 1.1)
        hsv[..., 1] *= sat_mult
        hsv[..., 2] *= val_mult

        # Clip values and convert back to BGR
        hsv = np.clip(hsv, 0, 255).astype(np.uint8)
        aug_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

        augmented_images.append(aug_img)
    return augmented_images

# ==== LOAD CSV AND FILTER FOR ORIGINAL IMAGES ONLY ====
df = pd.read_csv(csv_path)
df_original = df[df[image_id_column].str.contains("ORIGINAL", case=False, na=False)]

# ==== AUGMENT IMAGES AND COLLECT DATA ====
augmented_records = []

for _, row in df_original.iterrows():
    img_path = row[image_column]
    img_id = row[image_id_column]

    if not os.path.isfile(img_path):
        print(f"Missing file: {img_path}")
        continue

    image = cv2.imread(img_path)
    if image is None:
        print(f"Failed to load: {img_path}")
        continue

    aug_images = hsv_color_space_augmentation(image, num_augments=num_augments)
    base_name = os.path.splitext(os.path.basename(img_path))[0]

    for i, aug_img in enumerate(aug_images):
        aug_name = f"{base_name}_hsvaug_{i+1:03d}.jpg"
        save_path = os.path.join(output_dir, aug_name)
        cv2.imwrite(save_path, aug_img)

        # Append data for CSV
        augmented_records.append({
            'augmented_image_path': save_path,
            'original_Image_Id': img_id
        })

# ==== SAVE NEW CSV FILE ====
aug_df = pd.DataFrame(augmented_records)
aug_df.to_csv(output_csv, index=False)

print(f"All ORIGINAL images augmented and saved to {output_dir}")
print(f"New CSV with augmented image paths saved as {output_csv}")
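
In OpenCV's 8-bit HSV representation the hue channel spans 0 to 179 while saturation and value span 0 to 255, which is why the function above wraps hue modulo 180 and clips the result. The NumPy-only sketch below isolates that arithmetic on a single toy pixel. The function `shift_hsv` and the pixel values are hypothetical, and no color-space conversion is performed here.

```python
import numpy as np

def shift_hsv(hsv, hue_shift, sat_mult=1.0, val_mult=1.0):
    """Shift hue with wrap-around and scale saturation/value with clipping."""
    out = hsv.astype(np.float32)
    out[..., 0] = (out[..., 0] + hue_shift) % 180          # hue wraps at 180
    out[..., 1] = np.clip(out[..., 1] * sat_mult, 0, 255)  # saturation
    out[..., 2] = np.clip(out[..., 2] * val_mult, 0, 255)  # value (brightness)
    return out.astype(np.uint8)

pixel = np.array([[[170, 250, 100]]], dtype=np.uint8)  # one HSV pixel
shifted = shift_hsv(pixel, hue_shift=20, sat_mult=1.1, val_mult=0.5)
print(shifted[0, 0])  # hue 10, saturation 255, value 50
```

Casting to float before the arithmetic matters: adding to a `uint8` array directly would overflow at 255 instead of wrapping at 180.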

In [None]:
new=pd.read_csv('/kaggle/working/augmented_images.csv')
new

In [None]:
# Assume df_original is already filtered to contain ORIGINAL images
# Example: df_original = df[df['Image_Id'].str.contains('ORIGINAL', case=False, na=False)]

# Settings
num_augments = 180
output_dir = r'/kaggle/working'  # Replace with your actual output directory

augmented_records = []

for _, row in df_original.iterrows():
    lesion_id = row['Lesion_Id']
    original_image_id = row['Image_Id']
    diagnosis = row['Diagnosis']
    fold = row['Fold']
    split = row['Split']
    base_name = os.path.splitext(os.path.basename(row['Image_Path']))[0]

    for i in range(num_augments):
        aug_image_id = f"{original_image_id}_hsvaug_{i+1:03d}"
        aug_filename = f"{base_name}_hsvaug_{i+1:03d}.jpg"
        aug_image_path = os.path.join(output_dir, aug_filename)

        augmented_records.append({
            'lesion_Id': lesion_id,
            'image_Id': aug_image_id,
            'diagnosis': diagnosis,
            'image_Path': aug_image_path,
            'fold': fold,
            'split': split
        })

# Create DataFrame from records
augmented_df = pd.DataFrame(augmented_records)

# Example: Save to CSV if needed
# augmented_df.to_csv('augmented_metadata.csv', index=False)

print(augmented_df.head())

In [None]:
augmented_df

In [None]:
augmented_df = augmented_df.rename(columns={
    "lesion_Id": 'Lesion_Id',
    "image_Id": 'Image_Id',
    "diagnosis": 'Diagnosis',
    "image_Path": 'Image_Path',
    "fold": 'Fold',
    "split": 'Split'
})
augmented_df

In [None]:
augmented_df['Diagnosis'].value_counts()

In [None]:
# Verify that every augmented image file exists on disk
for path in augmented_df['Image_Path']:
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Image file not found: {path}")

# One sample image per diagnosis class, taken from the augmented set
class_images = augmented_df.drop_duplicates(subset='Diagnosis')

def plot_class_images(class_images):
    plt.figure(figsize=(15, 5))
    num_classes = len(class_images)
    for i in range(num_classes):
        plt.subplot(1, num_classes, i + 1)
        img = plt.imread(class_images['Image_Path'].iloc[i])
        plt.imshow(img)
        plt.axis('off')
        plt.title(class_images['Diagnosis'].iloc[i], fontsize=12)
    plt.tight_layout()
    plt.show()
plot_class_images(class_images)

In [None]:
plt.figure()
sns.countplot(data=augmented_df, x='Diagnosis', order=augmented_df['Diagnosis'].value_counts().index)
plt.title('Diagnosis Distribution')
plt.xticks(rotation=45)
plt.show()

## 4. AUGMENT AND BALANCE THE DATASETS

In [None]:
pip install imgaug opencv-python

In [None]:
# ==== Load Data ====
df = augmented_df  # Your input DataFrame with 'Diagnosis' and 'Image_Path'

# ==== Define Target Range ====
min_target = 101340
max_target = 180900

balanced_samples = []

for label, current_count in df['Diagnosis'].value_counts().items():
    class_subset = df[df['Diagnosis'] == label]

    if current_count < min_target:
        # Calculate how many images to augment to reach at least min_target
        n_to_augment = min_target - current_count
        print(f"Augmenting '{label}' from {current_count} to at least {min_target} samples.")

        # Simple augmentation: randomly sample with replacement
        augmented_subset = class_subset.sample(n=n_to_augment, replace=True, random_state=42)
        class_subset = pd.concat([class_subset, augmented_subset]).reset_index(drop=True)

        # Ensure we don't exceed max_target
        if len(class_subset) > max_target:
            print(f"After augmentation, '{label}' exceeds max target. Reducing to {max_target}.")
            class_subset = class_subset.sample(n=max_target, random_state=42)

    elif current_count > max_target:
        # Reduce samples if above max_target
        print(f"Reducing '{label}' from {current_count} to {max_target} samples.")
        class_subset = class_subset.sample(n=max_target, random_state=42)
    else:
        print(f"'{label}' is within the target range [{min_target}, {max_target}]. Using all {current_count} samples.")

    balanced_samples.append(class_subset)

# ==== Final Balanced Dataset ====
balanced_dataset = pd.concat(balanced_samples).reset_index(drop=True)

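# Save to CSV (create the target folder first, since it may not exist yet)
os.makedirs('/kaggle/working/Datasets', exist_ok=True)
balanced_dataset.to_csv('/kaggle/working/Datasets/balanced_dataset.csv', index=False)
print("Balanced dataset saved at /kaggle/working/Datasets/balanced_dataset.csv")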
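
Note that the "augmentation" in the balancing cell is sampling with replacement: it repeats existing metadata rows rather than creating new image files. The sketch below shows the effect on a toy class subset. The helper `oversample_to` and the three file names are assumptions for illustration.

```python
import pandas as pd

def oversample_to(frame, target, seed=42):
    """Pad a class subset to `target` rows by resampling its existing rows."""
    shortfall = target - len(frame)
    if shortfall <= 0:
        return frame.reset_index(drop=True)
    extra = frame.sample(n=shortfall, replace=True, random_state=seed)
    return pd.concat([frame, extra]).reset_index(drop=True)

# Hypothetical subset with three images of one class.
subset = pd.DataFrame({
    "Image_Path": ["a.jpg", "b.jpg", "c.jpg"],
    "Diagnosis": ["Cowpox"] * 3,
})
padded = oversample_to(subset, target=10)
print(len(padded), padded["Image_Path"].nunique())  # 10 3
```

The row count reaches the target while the number of distinct images stays the same, so any variety in the minority classes has to come from the image-level augmentation applied earlier or during training.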

In [None]:
# Continue working with the balanced dataset under the name df
df = balanced_dataset
df

In [None]:
plt.figure(figsize=(10,4))
sns.countplot(data=df, x='Diagnosis', order=df['Diagnosis'].value_counts().index)
plt.title('Diagnosis Distribution')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Use correct column name
label_encoder = LabelEncoder()
df['Updated_Diagnosis_Label'] = label_encoder.fit_transform(df['Diagnosis'])
# Get mapping from label to encoded value
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))
# Display result
print("✅ Label Encoding Mapping:")
print(label_mapping)

In [None]:
# Calculate class weights, keyed by the encoded label itself
# (value_counts() is sorted by frequency, so its row position is not the label)
class_counts = df['Updated_Diagnosis_Label'].value_counts().to_dict()
total_samples = sum(class_counts.values())
class_weights = {label: total_samples / count for label, count in class_counts.items()}

print("Class weights:", class_weights)
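
Class weights are only useful if each weight is attached to the right encoded label. As a point of comparison, the sketch below computes the common "balanced" heuristic, n_samples / (n_classes * count), which is the formula scikit-learn uses for `class_weight='balanced'`. The function name `balanced_class_weights` and the toy label list are assumptions for illustration.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Map each label to n_samples / (n_classes * count_of_label)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

# Hypothetical encoded labels: class 0 is most frequent, class 1 is rarest.
toy_labels = [0, 0, 0, 1, 2, 2]
weights = balanced_class_weights(toy_labels)
print(weights)
```

The rarest class receives the largest weight, and a perfectly balanced dataset yields a weight of 1.0 for every class, in which case passing class weights to training has no effect.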