In [None]:
# Necessary Imports
import os
import numpy as np # linear algebra
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (f1_score, roc_auc_score, accuracy_score, balanced_accuracy_score,
                            precision_score, recall_score, roc_curve, auc, confusion_matrix, classification_report)
from pathlib import Path
import cv2
import optuna
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.utils import to_categorical
import keras
import pickle
import gc
from tensorflow.keras import mixed_precision
import multiprocessing
import subprocess
import shutil
import time
from tensorflow.keras.models import Model
import json
from optuna.samplers import NSGAIISampler

In [None]:
# Enable on-demand GPU memory allocation for the first visible GPU.
# The original indexed `[0]` unconditionally, which raises IndexError on a
# CPU-only machine; `gpus` is now None in that case and the call is skipped.
_gpu_devices = tf.config.experimental.list_physical_devices('GPU')
gpus = _gpu_devices[0] if _gpu_devices else None
if gpus is not None:
    tf.config.experimental.set_memory_growth(gpus, True)

In [None]:
class ModelConfig:
    """Container for the image-shape and training settings shared by the pipeline.

    Any of the fields below may be overridden by keyword argument; keyword
    arguments that are not listed are ignored.
    """

    def __init__(self, **kwargs):
        # (attribute name, fallback value) pairs, assigned in this order.
        field_defaults = (
            ("img_height", 512),
            ("img_width", 512),
            ("channels", 1),  # Grayscale
            ("num_classes", 4),
            ("batch_size", 16),
            ("epochs", 50),
        )
        for field_name, fallback in field_defaults:
            setattr(self, field_name, kwargs.get(field_name, fallback))

In [None]:
class DataProcessor:
    """Build tf.data input pipelines (loading, augmentation, normalisation) for MRI images.

    The directory loaders expect ``base_dir/<class_name>/<image file>``; the
    hardware loader expects ``base_dir/<class_name>/<run>/<4 image files>``.

    Training pipeline: shuffle paths -> decode/resize -> scale to [0, 1] ->
    augment -> per-image z-score.  Validation pipeline: decode/resize ->
    per-image z-score.  Normalising *after* augmentation matters: the
    augmentation ends with a clip to [0, 1], which would otherwise destroy
    z-scored data and make training inputs differ from validation inputs.
    """

    # Lower-case file extensions that are treated as images.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Number of images every hardware run directory must contain; they are
    # stacked along the channel axis.
    _IMAGES_PER_RUN = 4

    def __init__(self, config: "ModelConfig"):
        self.config = config

    def create_kfold_data(self, base_dir: str, n_splits: int = 5, seed: int = 42):
        """Yield ``(fold, train_ds, val_ds)`` for each stratified fold of ``base_dir``.

        Args:
            base_dir: Directory with one sub-directory per class.
            n_splits: Number of stratified folds.
            seed: Seed for the fold split and for the seeded augmentation ops.

        Side effect: stores the class-name -> index mapping in ``self.class_indices``.
        """
        df, self.class_indices = self._load_image_paths_and_labels(base_dir)
        num_classes = len(self.class_indices)
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

        for fold, (train_idx, val_idx) in enumerate(skf.split(df.filepath, df.label)):
            # The folder emoji is written as an escape so that it cannot be
            # mis-encoded again (the original literal was mojibake).
            print(f"\n\U0001F4C2 Fold {fold + 1}/{n_splits}")
            train_ds = self._build_dataset(
                df.filepath.iloc[train_idx].values,
                df.label.iloc[train_idx].values,
                num_classes,
                training=True,
                seed=seed,
            )
            val_ds = self._build_dataset(
                df.filepath.iloc[val_idx].values,
                df.label.iloc[val_idx].values,
                num_classes,
                training=False,
            )
            yield fold, train_ds, val_ds

    def create_full_dataset(self, train_dir: str, validation_dir: str):
        """Return ``(train_ds, val_ds)`` built from two class-structured directories.

        Validation labels are encoded with the *training* class mapping so that
        both datasets agree on class indices even if the directories differ.

        Raises:
            ValueError: If the validation directory contains a class that is
                absent from the training directory.
        """
        train_df, class_indices = self._load_image_paths_and_labels(train_dir)
        val_df, _ = self._load_image_paths_and_labels(validation_dir)
        num_classes = len(class_indices)

        val_labels = val_df.class_name.map(class_indices)
        if val_labels.isna().any():
            unknown = sorted(val_df.class_name[val_labels.isna()].unique())
            raise ValueError(f"Validation classes missing from training set: {unknown}")

        train_ds = self._build_dataset(
            train_df.filepath.values, train_df.label.values, num_classes, training=True, seed=42
        )
        val_ds = self._build_dataset(
            val_df.filepath.values, val_labels.astype("int64").values, num_classes, training=False
        )
        return train_ds, val_ds

    def create_hardware_validation_dataset(self, base_dir: str):
        """Create a validation dataset from hardware Gabor-filtered runs.

        Every run directory must hold exactly four images; they are loaded as
        single-channel images, z-scored individually and concatenated along the
        channel axis.  Labels are one-hot with depth ``config.num_classes``.

        Returns:
            ``(dataset, class_names)`` where ``class_names`` lists, in sorted
            order, the class directories that have a label mapping.  (The
            original returned every entry of ``base_dir``, stray files included.)

        Raises:
            ValueError: If no valid run is found.
        """
        label_map = {'glioma': 0, 'meningioma': 1, 'notumor': 2}
        run_paths, run_labels, class_names = [], [], []

        for class_name in sorted(os.listdir(base_dir)):
            class_path = os.path.join(base_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            if class_name not in label_map:
                # Previously such a class produced a NaN label and a crash in tf.one_hot.
                print(f"Warning: Class {class_path} has no label mapping. Skipping.")
                continue
            class_names.append(class_name)
            for run_name in sorted(os.listdir(class_path)):
                run_path = os.path.join(class_path, run_name)
                if not os.path.isdir(run_path):
                    continue
                run_images = sorted(
                    f for f in os.listdir(run_path) if f.lower().endswith(self._IMAGE_EXTENSIONS)
                )
                if len(run_images) != self._IMAGES_PER_RUN:
                    print(f"Warning: Run {run_path} has {len(run_images)} images, expected {self._IMAGES_PER_RUN}. Skipping.")
                    continue
                run_paths.append([os.path.join(run_path, img) for img in run_images])
                run_labels.append(label_map[class_name])

        if not run_paths:
            raise ValueError("No valid runs found in hardware dataset directory")

        def load_and_stack_images(image_paths, label):
            # tf.unstack yields a Python list of scalar string tensors, so the
            # comprehension runs at trace time (looping over the tensor itself
            # cannot append to a Python list inside a traced function).
            planes = [
                self._normalize_tensor(self._read_image(path, channels=1))
                for path in tf.unstack(image_paths, num=self._IMAGES_PER_RUN)
            ]
            return tf.concat(planes, axis=-1), tf.one_hot(label, depth=self.config.num_classes)

        # Plain nested lists convert to a [num_runs, 4] string tensor; a pandas
        # Series of lists (as used originally) is not convertible.
        dataset = tf.data.Dataset.from_tensor_slices((run_paths, run_labels))
        dataset = dataset.map(load_and_stack_images, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(self.config.batch_size).prefetch(tf.data.AUTOTUNE)
        return dataset, class_names

    def compute_class_weights_from_dataset(self, dataset):
        """Compute balanced class weights from a batched dataset with one-hot labels.

        The weight of class ``c`` is ``total / (num_classes * count_c)`` where
        ``num_classes`` is the highest observed class index plus one; the weight
        of class 0 is additionally doubled.  Classes with no samples are left
        out of the result instead of producing a division by zero.

        Returns:
            dict mapping ``int`` class index -> ``float`` weight.  Plain Python
            numbers are used so that ``str()`` / JSON serialisation of the dict
            does not depend on the NumPy scalar repr.
        """
        labels = [int(np.argmax(label.numpy())) for _, label in dataset.unbatch()]
        if not labels:
            return {}

        class_counts = np.bincount(np.asarray(labels, dtype=np.int64))
        total_samples = len(labels)
        num_classes = len(class_counts)

        class_weights = {}
        for cls, count in enumerate(class_counts):
            if count == 0:
                continue
            weight = total_samples / (num_classes * int(count))
            if cls == 0:
                weight *= 2.0  # hard-coded extra emphasis on class 0
            class_weights[cls] = float(weight)
        return class_weights

    def _build_dataset(self, paths, labels, num_classes, training, seed=42):
        """Assemble a batched, prefetched dataset of ``(image, one_hot_label)`` pairs.

        Args:
            paths: Sequence of image file paths.
            labels: Integer class index per path.
            num_classes: One-hot depth.
            training: When true, shuffle and augment.
            seed: Op-level seed forwarded to the seeded augmentation ops.
        """
        dataset = tf.data.Dataset.from_tensor_slices((paths, labels))

        if training:
            # Shuffle the (cheap) path strings over the whole set before decoding.
            # Files are listed class by class, so a small buffer applied after
            # decoding mixes classes poorly and keeps many decoded images in memory.
            dataset = dataset.shuffle(max(len(paths), 1))

            def prepare(path, label):
                image = self._read_image(path) / 255.0
                image = self._augment_image(image, seed)
                return self._normalize_tensor(image), tf.one_hot(label, num_classes)
        else:
            def prepare(path, label):
                image = self._normalize_tensor(self._read_image(path))
                return image, tf.one_hot(label, num_classes)

        dataset = dataset.map(prepare, num_parallel_calls=tf.data.AUTOTUNE)
        return dataset.batch(self.config.batch_size).prefetch(tf.data.AUTOTUNE)

    def _read_image(self, path, channels=None):
        """Read, decode and resize one image file to a float32 ``[H, W, channels]`` tensor.

        ``channels`` defaults to ``config.channels``.  Pixel values keep the
        decoded 0-255 scale.
        """
        if channels is None:
            channels = self.config.channels
        raw = tf.io.read_file(path)
        # Per the TensorFlow documentation decode_jpeg also accepts PNG data,
        # which is why .png files are allowed by the directory loaders.
        image = tf.image.decode_jpeg(raw, channels=channels)
        image = tf.image.resize(image, [self.config.img_height, self.config.img_width])
        image.set_shape([self.config.img_height, self.config.img_width, channels])
        return image

    def _augment_image(self, image, seed):
        """Randomly augment one image whose values lie in [0, 1].

        Steps: horizontal flip, brightness, contrast, zoom (0.9-1.1), rotation
        (up to 15 degrees) combined with shear, Gaussian noise, random crop and
        re-pad (translation), and a final clip to [0, 1].
        """
        height, width = self.config.img_height, self.config.img_width

        image = tf.image.random_flip_left_right(image, seed=seed)
        image = tf.image.random_brightness(image, max_delta=0.1, seed=seed)
        image = tf.image.random_contrast(image, lower=0.9, upper=1.1, seed=seed)

        # Zoom: resize by a random factor, then centre-crop / zero-pad back.
        scale = tf.random.uniform([], 0.9, 1.1, dtype=tf.float32)
        zoomed_height = tf.cast(tf.cast(height, tf.float32) * scale, tf.int32)
        zoomed_width = tf.cast(tf.cast(width, tf.float32) * scale, tf.int32)
        image = tf.image.resize(image, [zoomed_height, zoomed_width])
        image = tf.image.resize_with_crop_or_pad(image, height, width)

        image = self._random_rotate_and_shear(image)

        noise = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=0.02, dtype=tf.float32)
        image = image + noise

        # Translation: crop away 2 * shift_fraction of each side length, then
        # paste the crop back at a random offset inside a zero canvas.
        shift_fraction = 0.05
        crop_height = int(height * (1 - 2 * shift_fraction))
        crop_width = int(width * (1 - 2 * shift_fraction))
        image = tf.image.random_crop(
            image,
            size=[crop_height, crop_width, self.config.channels],
            seed=seed
        )
        max_offset_height = int(height * shift_fraction)
        max_offset_width = int(width * shift_fraction)
        offset_height = tf.random.uniform([], 0, max_offset_height, dtype=tf.int32)
        offset_width = tf.random.uniform([], 0, max_offset_width, dtype=tf.int32)
        image = tf.image.pad_to_bounding_box(
            image,
            offset_height=offset_height,
            offset_width=offset_width,
            target_height=height,
            target_width=width
        )
        return tf.clip_by_value(image, 0.0, 1.0)

    def _random_rotate_and_shear(self, image):
        """Apply a random rotation about the image centre and a random x-shear.

        The original code passed ``int(angle * 4 / (2*pi))`` to ``tf.image.rot90``;
        for angles within 15 degrees that integer is always 0, so no rotation
        ever happened.  Rotation and shear are composed into one projective
        transform here so the image is resampled only once.
        """
        height, width = self.config.img_height, self.config.img_width
        max_angle = 15 * np.pi / 180
        angle = tf.random.uniform([], -max_angle, max_angle, dtype=tf.float32)
        shear = tf.random.uniform([], -0.1, 0.1, dtype=tf.float32)

        cos_a, sin_a = tf.cos(angle), tf.sin(angle)
        zero, one = tf.constant(0.0), tf.constant(1.0)
        center_x = (width - 1) / 2.0
        center_y = (height - 1) / 2.0

        # Both matrices map output pixel coordinates to input coordinates,
        # which is the convention of ImageProjectiveTransformV3.
        rotation = tf.reshape(tf.stack([
            cos_a, -sin_a, center_x - cos_a * center_x + sin_a * center_y,
            sin_a, cos_a, center_y - sin_a * center_x - cos_a * center_y,
            zero, zero, one,
        ]), [3, 3])
        shear_matrix = tf.reshape(tf.stack([
            one, shear, zero,
            zero, one, zero,
            zero, zero, one,
        ]), [3, 3])

        # The op takes the first 8 entries of the row-major 3x3 matrix (the
        # ninth is implicitly 1, which holds for a product of affine matrices).
        transform = tf.reshape(tf.matmul(rotation, shear_matrix), [9])[:8]
        return tf.raw_ops.ImageProjectiveTransformV3(
            images=tf.expand_dims(image, 0),
            transforms=tf.expand_dims(transform, 0),
            output_shape=[height, width],
            fill_value=0.0,
            interpolation='BILINEAR'
        )[0]

    def _load_image_paths_and_labels(self, base_dir: str):
        """Scan ``base_dir/<class>/<image>`` and return ``(DataFrame, label_map)``.

        The DataFrame has columns ``filepath``, ``class_name`` and ``label``.
        Classes and files are visited in sorted order so the row order - and
        therefore any seeded split built on it - does not depend on the order
        in which the filesystem lists entries.
        """
        records = []
        for class_name in sorted(os.listdir(base_dir)):
            class_path = os.path.join(base_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            for fname in sorted(os.listdir(class_path)):
                if fname.lower().endswith(self._IMAGE_EXTENSIONS):
                    records.append((os.path.join(class_path, fname), class_name))
        df = pd.DataFrame(records, columns=["filepath", "class_name"])
        label_map = {name: idx for idx, name in enumerate(sorted(df.class_name.unique()))}
        df["label"] = df.class_name.map(label_map)
        return df, label_map

    @staticmethod
    def _normalize_tensor(image):
        """Z-score a tensor with native TF ops (same formula as ``_normalize_image``).

        Staying inside the graph avoids ``tf.py_function``, which forces the
        normalisation of every image through the Python interpreter.
        """
        mean = tf.reduce_mean(image)
        std = tf.math.reduce_std(image)
        return (image - mean) / tf.maximum(std, 1e-7)

    @staticmethod
    def _normalize_image(x: np.ndarray) -> np.ndarray:
        """Normalize image using z-score normalization (NumPy version).

        The standard deviation is floored at 1e-7 so constant images map to zeros.
        """
        return (x - np.mean(x)) / (np.maximum(np.std(x), 1e-7))

In [None]:
class Evaluator:
    """Compute classification metrics and diagnostic plots for a trained model.

    ``model`` must expose ``predict``; ``num_classes`` is the width of the
    model output and of the one-hot labels.
    """

    def __init__(self, model, num_classes):
        self.model = model
        self.num_classes = num_classes

    def get_predictions_and_labels(self, data, class_indices=None):
        """Run the model over a batched ``tf.data.Dataset`` and collect labels.

        Predictions are computed batch by batch in a single pass, so only one
        batch of images is held in memory and predictions stay aligned with
        labels even if the dataset reshuffles between iterations.

        Args:
            data: Batched dataset yielding ``(images, labels)``; labels may be
                one-hot (2-D) or integer class indices (1-D).
            class_indices: Optional list of original class indices to keep,
                e.g. ``[0, 1, 3]``.  Samples of other classes are dropped and
                the kept classes are renumbered by their position in the list.

        Returns:
            ``(predictions, y_true, y_true_onehot)`` as NumPy arrays.

        Raises:
            ValueError: If ``data`` is not a ``tf.data.Dataset`` or is empty.
        """
        if not isinstance(data, tf.data.Dataset):
            raise ValueError("Expected tf.data.Dataset")

        prediction_batches, label_batches = [], []
        for batch_images, batch_labels in data:
            prediction_batches.append(np.asarray(self.model.predict(batch_images, verbose=0)))
            label_batches.append(batch_labels.numpy())
        if not prediction_batches:
            raise ValueError("Dataset yielded no batches")

        predictions = np.concatenate(prediction_batches, axis=0)
        labels = np.concatenate(label_batches, axis=0)

        if labels.ndim > 1:
            y_true = np.argmax(labels, axis=1)
            y_true_onehot = labels
        else:
            y_true = labels
            y_true_onehot = to_categorical(y_true, num_classes=self.num_classes)

        if class_indices is None:
            return predictions, y_true, y_true_onehot
        return self._restrict_to_classes(predictions, y_true, y_true_onehot, class_indices)

    @staticmethod
    def _restrict_to_classes(predictions, y_true, y_true_onehot, class_indices):
        """Keep only samples and columns belonging to ``class_indices``.

        Returns ``(predictions, y_true, y_true_onehot)`` where ``y_true`` holds
        positions within ``class_indices`` rather than original class ids.
        """
        columns = [int(c) for c in class_indices]
        position_of = {}
        for new_index, original_index in enumerate(columns):
            # setdefault keeps the first occurrence, like list.index does.
            position_of.setdefault(original_index, new_index)

        mapped = np.array([position_of.get(int(label), -1) for label in y_true], dtype=np.int64)
        keep = mapped != -1
        return predictions[keep][:, columns], mapped[keep], y_true_onehot[keep][:, columns]

    @staticmethod
    def _binary_auc_or_nan(y_true_column, scores):
        """ROC-AUC of one class, or NaN when the class has no positives or no negatives.

        ``roc_auc_score`` raises in that situation, which used to abort the
        whole evaluation when the dataset lacked a class.
        """
        if np.unique(y_true_column).size < 2:
            return float('nan')
        return roc_auc_score(y_true_column, scores)

    @staticmethod
    def _normalize_rows(matrix):
        """Divide every row by its sum; rows that sum to zero stay all-zero."""
        matrix = np.asarray(matrix, dtype=float)
        row_totals = matrix.sum(axis=1, keepdims=True)
        return np.divide(matrix, row_totals, out=np.zeros_like(matrix), where=row_totals != 0)

    def evaluate(self, data, class_indices=None):
        """Return a dict of accuracy, macro precision/recall/F1 and ROC-AUC values.

        ``macro_roc_auc`` is the mean of the per-class one-vs-rest AUCs that are
        defined; per-class AUCs that are undefined are reported as NaN.
        Returns ``{}`` when no sample is left after class filtering.
        """
        predictions, y_true, y_true_onehot = self.get_predictions_and_labels(data, class_indices)
        # Skip evaluation if no valid samples after filtering
        if len(y_true) == 0:
            print("No valid samples to evaluate after filtering classes.")
            return {}

        y_pred = np.argmax(predictions, axis=1)
        num_classes = len(class_indices) if class_indices is not None else self.num_classes
        class_aucs = [
            self._binary_auc_or_nan(y_true_onehot[:, i], predictions[:, i])
            for i in range(num_classes)
        ]
        defined_aucs = [value for value in class_aucs if not np.isnan(value)]

        results = {
            'accuracy': accuracy_score(y_true, y_pred),
            'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
            'macro_precision': precision_score(y_true, y_pred, average='macro'),
            'macro_recall': recall_score(y_true, y_pred, average='macro'),
            'macro_f1': f1_score(y_true, y_pred, average='macro'),
            'macro_roc_auc': float(np.mean(defined_aucs)) if defined_aucs else float('nan'),
        }
        for i, value in enumerate(class_aucs):
            results[f'class_{i}_roc_auc'] = value
        return results

    def plot_confusion_matrix(self, data, class_names=None, class_indices=None):
        """Plot the row-normalised confusion matrix and print a classification report."""
        predictions, y_true, _ = self.get_predictions_and_labels(data, class_indices)
        if len(y_true) == 0:
            print("No valid samples to plot confusion matrix.")
            return
        y_pred = np.argmax(predictions, axis=1)
        cm = confusion_matrix(y_true, y_pred)
        # A class that is only ever predicted has an empty row; plain division
        # would turn that row into NaNs.
        cm_normalized = self._normalize_rows(cm)
        plt.figure(figsize=(12, 8))
        sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                    xticklabels=class_names if class_names else 'auto',
                    yticklabels=class_names if class_names else 'auto')
        plt.title('Normalized Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.show()
        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, target_names=class_names if class_names else None))

    def plot_roc_curves(self, data, class_names=None, class_indices=None):
        """Plot per-class, micro-average and macro-average ROC curves."""
        predictions, _, y_true_onehot = self.get_predictions_and_labels(data, class_indices)
        if len(y_true_onehot) == 0:
            print("No valid samples to plot ROC curves.")
            return
        plt.figure(figsize=(12, 8))
        plt.title('Receiver Operating Characteristic (ROC) Curves')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.grid(True)
        num_classes = len(class_indices) if class_indices is not None else self.num_classes

        # Each class curve is computed once and reused for the macro average.
        curves = [
            roc_curve(y_true_onehot[:, i], predictions[:, i])[:2]
            for i in range(num_classes)
        ]

        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the
        # equivalent that still accepts the number of colours.
        cmap = plt.get_cmap('nipy_spectral', num_classes)
        colors = cmap(np.linspace(0, 1, num_classes))
        for i, (fpr, tpr) in enumerate(curves):
            roc_auc = auc(fpr, tpr)
            name = class_names[i] if class_names else f'Class {i}'
            plt.plot(fpr, tpr, color=colors[i], lw=2, label=f'{name} (AUC = {roc_auc:.2f})')

        fpr_micro, tpr_micro, _ = roc_curve(y_true_onehot.ravel(), predictions.ravel())
        roc_auc_micro = auc(fpr_micro, tpr_micro)
        plt.plot(fpr_micro, tpr_micro, label=f'Micro-average (AUC = {roc_auc_micro:.2f})',
                 color='deeppink', linestyle=':', linewidth=4)

        # Macro average: interpolate every class curve onto the union of all
        # false-positive-rate grid points and average the true-positive rates.
        all_fpr = np.unique(np.concatenate([fpr for fpr, _ in curves]))
        mean_tpr = np.zeros_like(all_fpr)
        for fpr, tpr in curves:
            order = np.argsort(fpr)
            mean_tpr += np.interp(all_fpr, fpr[order], tpr[order])
        mean_tpr /= num_classes
        roc_auc_macro = auc(all_fpr, mean_tpr)
        plt.plot(all_fpr, mean_tpr, label=f'Macro-average (AUC = {roc_auc_macro:.2f})',
                 color='navy', linestyle='--', linewidth=4)
        plt.plot([0, 1], [0, 1], 'k--', label='Random chance')
        plt.legend(loc='center left', bbox_to_anchor=(1.05, 0.5), fontsize='small')
        plt.tight_layout()
        plt.show()

In [None]:
## this function is for determining only the Gabor filter hyperparameters

# def optimise_hyperparameters(n_trials,class_weights): 

#     def objective(trial):

#         angles_deg = np.arange(0, 180, 11.25)
#         theta = np.radians(angles_deg)
        
#         hps = {
#             'sigma': trial.suggest_float('sigma', 3, 6, step=0.2),
#             'gamma': trial.suggest_float('gamma', 0.2, 0.8, step=0.05),
#             'lambd': trial.suggest_float('lambd', 4, 8, step=0.2),
#             'theta': theta.tolist(),
#             # 'learning_rate': trial.suggest_loguniform('learning_rate', 1e-4, 1e-2)
#             'learning_rate': 1e-3
#         }


#         args = [
#             'python',
#             'optimise_filter.py',
#             str(hps),
#             'Training',
#             'Testing',
#             str(class_weights)
#         ]
        
#         gc.collect()
#         tf.keras.backend.clear_session()
#         time.sleep(15)

#         print("subprocess about to start")

#         subprocess.run(args,check=True,stdout=subprocess.DEVNULL)

#         with open("fold_metric_optuna.txt","r") as f:
#             for line in f:
#                 return(float(line))
    

#     study = optuna.create_study(directions=['maximize'],sampler=NSGAIISampler())
#     study.optimize(objective, n_trials=n_trials)
#     return study.best_params

In [None]:
## this function is for determining the bit-widths, or number of clusters or both for fixed Gabor filter hyperparameters

def read_last_metric(path):
    """Return the last non-blank line of ``path`` parsed as a float.

    Raises:
        ValueError: If the file holds no non-blank line (the original loop left
            the result variable unbound in that case) or the line is not a number.
    """
    with open(path, "r") as f:
        entries = [line.strip() for line in f if line.strip()]
    if not entries:
        raise ValueError(f"No metric value found in {path}")
    return float(entries[-1])


def select_pareto_front(trials_df):
    """Return the rows of ``trials_df`` that form the Pareto front.

    ``trials_df`` needs the columns ``objective_value`` (to maximise) and
    ``penalty`` (to minimise).  Rows are swept from the highest objective
    downwards and kept whenever they lower the best penalty seen so far.  Ties
    in the objective are ordered by ascending penalty so that, of two trials
    with the same score, only the cheaper one can enter the front.
    """
    ordered = trials_df.sort_values(by=['objective_value', 'penalty'], ascending=[False, True])
    kept_index = []
    lowest_penalty = float('inf')
    for row_index, penalty in ordered['penalty'].items():
        if penalty < lowest_penalty:
            kept_index.append(row_index)
            lowest_penalty = penalty
    if not kept_index:
        return trials_df.iloc[0:0]
    return trials_df.loc[kept_index]


def optimise_hyperparameters(n_trials,class_weights):
    """Search the per-tap bit-widths of a fixed Gabor filter bank with Optuna.

    Every trial samples 12 bit-widths in [2, 6], mirrors them around a fixed
    centre value of 2 into a 5x5 grid, and launches ``optimise_filter.py`` in a
    subprocess with the hyperparameters and ``class_weights`` passed as their
    ``str()`` representation.  The trial's score is the last value found in
    ``fold_metric_optuna.txt``; the second objective (minimised) is the sum of
    all bit-widths.

    Args:
        n_trials: Number of Optuna trials.
        class_weights: Class-weight mapping forwarded to the subprocess.

    Returns:
        ``(study.best_trials, trials_data)`` where ``trials_data`` is a list of
        ``[accuracy, penalty, hps]`` entries in trial order.

    Note:
        The author also recorded alternative fixed Gabor settings, labelled
        4 / 8 / 12 / 16 (presumably the number of orientations - confirm):
        8 -> sigma 5.8, gamma 0.65, lambd 6.8; 12 -> sigma 5.2, gamma 0.45,
        lambd 6.2; 16 -> sigma 5.8, gamma 0.5, lambd 6.2.  Those variants
        searched a ``clusters`` count (2-9) and/or one bit-width per cluster,
        with the penalty adapted accordingly.  The active setting is "4".
    """
    import sys  # local import: not part of the notebook's import cell

    metric_path = "fold_metric_optuna.txt"
    trials_data = []

    def objective(trial):
        # Four equally spaced orientations over [0, 180) degrees.
        angles_deg = np.arange(0, 180, 180/4)
        theta = np.radians(angles_deg)

        # 12 free values + fixed centre + mirrored copy = 25 values, i.e. a
        # 5x5 grid that is symmetric about its centre.
        half = [trial.suggest_int(f'bit_width{i}', 2, 6) for i in range(12)]
        bit_widths = half + [2] + half[::-1]
        reshaped_bit_widths = [bit_widths[i:i+5] for i in range(0, len(bit_widths), 5)]

        hps = {
            'sigma': 5,
            'gamma': 0.55,
            'lambd': 6.0,
            'theta': theta.tolist(),
            'learning_rate': 1e-3,
            'bit_widths': reshaped_bit_widths
        }

        args = [
            sys.executable,  # same interpreter/environment as this kernel
            'optimise_filter.py',
            str(hps),
            'Training',
            'Testing',
            str(class_weights)
        ]

        gc.collect()
        tf.keras.backend.clear_session()
        # Fixed pause kept from the original; presumably lets GPU memory be
        # released before the child process starts - confirm.
        time.sleep(15)

        # Drop the previous trial's result so a child that exits cleanly
        # without writing a metric cannot be scored with a stale value.
        try:
            os.remove(metric_path)
        except FileNotFoundError:
            pass

        print("subprocess about to start")

        subprocess.run(args,check=True,stdout=subprocess.DEVNULL)

        penalty = sum(sum(row) for row in hps['bit_widths'])
        accuracy = read_last_metric(metric_path)

        trials_data.append([accuracy, penalty, hps])

        return accuracy, penalty

    def plot_pareto(trials_data):
        """Scatter all trials (score vs penalty) and highlight the Pareto front."""
        df = pd.DataFrame(trials_data, columns=['objective_value', 'penalty', 'hyperparameters'])
        pareto_front = select_pareto_front(df)

        plt.figure(figsize=(10, 6))
        plt.scatter(df['objective_value'], df['penalty'], label='All Trials', color='blue', alpha=0.5)
        plt.scatter(pareto_front['objective_value'], pareto_front['penalty'], label='Pareto Front', color='red', marker='*', s=200)

        # Annotate each point with its trial index.
        for i, row in df.iterrows():
            plt.text(row['objective_value'], row['penalty'], str(i), fontsize=9, color='black', ha='right')

        plt.xlabel('Objective Value (Score)')
        plt.ylabel('Penalty')
        plt.title('Pareto Front: Trade-off Between Score and Penalty')
        plt.legend()
        plt.grid(True)
        plt.show()

    study = optuna.create_study(directions=['maximize','minimize'])
    study.optimize(objective, n_trials=n_trials)
    plot_pareto(trials_data)
    return study.best_trials, trials_data

In [None]:
def main(n_trials: int = 100):
    """Run the bit-width search end to end and save all trials to ``trials_data.json``.

    Steps: configure the first GPU (if any), seed NumPy and TensorFlow, build
    the training dataset from ``Training`` / ``Testing`` to derive class
    weights, run ``optimise_hyperparameters`` and dump one JSON record per trial.

    Args:
        n_trials: Number of Optuna trials to run.

    Raises:
        ValueError: If the ``Training`` or ``Testing`` directory is missing.

    Errors raised during dataset creation or the search are reported (message
    and traceback) but not re-raised; the Keras session is cleared either way.
    """
    import traceback  # local import: not part of the notebook's import cell

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            tf.config.experimental.set_memory_growth(gpus[0], True)
            tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
            print("GPU available:", gpus[0])
        except RuntimeError as e:
            print("GPU Initialization Error:", e)
    else:
        print("No GPU detected. Using CPU.")

    np.random.seed(42)
    tf.random.set_seed(42)

    config = ModelConfig(img_height=512, img_width=512, channels=1, batch_size=10)
    train_dir, validation_dir = 'Training', 'Testing'
    if not all(map(os.path.exists, [train_dir, validation_dir])):
        raise ValueError("Training or validation directory missing")

    data_processor = DataProcessor(config)

    try:
        train_ds, _ = data_processor.create_full_dataset(train_dir, validation_dir)
        class_weights = data_processor.compute_class_weights_from_dataset(train_ds)
        # The weights reach the training subprocess via str(); plain Python
        # numbers keep that text independent of the NumPy scalar repr
        # (NumPy >= 2 prints e.g. "np.float64(1.5)").
        class_weights = {int(cls): float(weight) for cls, weight in class_weights.items()}

        _, trials_data = optimise_hyperparameters(n_trials=n_trials, class_weights=class_weights)

        data_to_save = [
            {"index": idx, "objective_value": trial[0], "penalty": trial[1], "hyperparameters": trial[2]}
            for idx, trial in enumerate(trials_data)
        ]

        with open('trials_data.json', 'w') as f:
            json.dump(data_to_save, f, indent=4)

    except Exception as e:
        print(f"Error: {e}")
        # Show where the failure happened; the one-line message alone is
        # rarely enough to debug a failed trial.
        traceback.print_exc()
    finally:
        tf.keras.backend.clear_session()

if __name__ == "__main__":
    main()