<a href="https://colab.research.google.com/github/SabrineOuni/An-Automated-Method-for-Multiple-Sclerosis-Detection/blob/main/Untitled3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- 1. Download all needed datasets ---
!apt-get update
!apt-get install -y megatools


# Download MS dataset (MRI + clinical)
ms_mri_url = "https://mega.nz/file/O45BRBJD#CA8XAaACqlymX3MhcGkJzK8DNp8vZYoxxQW2pBcl4wM"
!megadl {ms_mri_url} --path ms_dataset.zip
!unzip -q ms_dataset.zip -d ms_dataset

ms_clinical_url = "https://mega.nz/file/v1RWUY7R#UIgjHQWMqC6BgtvrfeQxnr0sMR0Q79Ek37-LFrqeYIU"
!megadl {ms_clinical_url} --path ms_clinical.csv



0% [Working]            Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:4 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Hit:6 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Get:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,929 kB]
Get:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:11 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 Packages [3,520 kB]
Hit:12 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:13 http://archive.ubunt

In [None]:
!pip install tensorflow==2.19.0 segmentation-models==1.0.1


# ==============================================
# 2. IMPORTS
# ==============================================
import os
import numpy as np
import cv2
import nibabel as nib
import random
import albumentations as A
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Patch for segmentation-models and TensorFlow 2.19 compatibility:
# if the installed Keras exposes no `keras.utils.generic_utils`, alias that
# attribute to `keras.utils` itself. This has to run before
# `import segmentation_models` below — presumably the library looks up
# `keras.utils.generic_utils` while importing (TODO confirm for 1.0.1).
import keras
if not hasattr(keras.utils, 'generic_utils'):
    keras.utils.generic_utils = keras.utils

# On import, segmentation_models reports the framework it selected
# (the saved output shows "Segmentation Models: using `keras` framework.").
import segmentation_models as sm

# ==============================================
# 3. CONFIG
# ==============================================
IMG_SIZE = 256          # slices are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 4
SEED = 42
MRI_DIR = "ms_dataset"  # folder with patient subfolders

# Seed the global Python, NumPy and TensorFlow generators so that everything
# drawing from them (background-slice sampling, weight initialisation and,
# depending on the albumentations version, augmentation parameters) repeats
# from run to run. GPU kernels may still be non-deterministic.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# ==============================================
# 4. FUNCTION TO LOAD PATIENT SLICES
# ==============================================
def load_patient_slices(mri_path, mask_path):
    """Load a patient's MRI & mask slices and resize them to IMG_SIZE.

    Slices are taken along the third array axis of both volumes. Each MRI slice
    is min-max scaled to 0-255 on its own and stored as uint8. Each mask slice is
    binarised (value > 0 -> 1.0). Images are resized bilinearly. Masks use
    nearest-neighbour interpolation so no intermediate label values appear.

    Args:
        mri_path: path to the MRI NIfTI file.
        mask_path: path to the matching lesion-mask NIfTI file.

    Returns:
        (imgs, masks): equal-length lists. imgs[i] is a (IMG_SIZE, IMG_SIZE, 1)
        uint8 array and masks[i] is a (IMG_SIZE, IMG_SIZE, 1) float32 0/1 array.

    Raises:
        ValueError: if the two volumes differ in their first three dimensions.
            Without this check, a mask with fewer slices raises IndexError
            mid-loop, and a longer or differently sized one is silently
            misaligned with the MRI.
    """
    mri_img = nib.load(mri_path).get_fdata()
    mask_img = nib.load(mask_path).get_fdata()
    if mri_img.shape[:3] != mask_img.shape[:3]:
        raise ValueError(
            f"MRI volume {mri_img.shape} ({mri_path}) and mask volume "
            f"{mask_img.shape} ({mask_path}) do not have matching dimensions"
        )

    imgs, masks = [], []
    for i in range(mri_img.shape[2]):
        img_slice = mri_img[:, :, i]
        mask_slice = mask_img[:, :, i]

        # Normalize MRI slice to [0, 255]. Round instead of truncating so a
        # value like 254.9999 (floating-point error at the maximum) becomes 255.
        img_slice = cv2.normalize(img_slice, None, 0, 255, cv2.NORM_MINMAX)
        img_slice = np.rint(img_slice).astype(np.uint8)

        # Resize both image and mask
        img_slice = cv2.resize(img_slice, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
        mask_slice = cv2.resize(mask_slice, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

        imgs.append(img_slice[..., np.newaxis])
        masks.append((mask_slice[..., np.newaxis] > 0).astype(np.float32))

    return imgs, masks

# ==============================================
# 5. LOAD DATA
# ==============================================
# Every slice is tagged with the patient folder it came from, so the split in
# section 7 can keep all slices of one patient inside a single subset.
X_slices, Y_slices, slice_patients = [], [], []
patients = sorted(os.listdir(MRI_DIR))
for p in patients:
    patient_folder = os.path.join(MRI_DIR, p)
    if not os.path.isdir(patient_folder):
        continue
    patient_id = p.split('-')[-1]
    mri_path = os.path.join(patient_folder, f"{patient_id}-Flair.nii")
    mask_path = os.path.join(patient_folder, f"{patient_id}-LesionSeg-Flair.nii")
    if os.path.exists(mri_path) and os.path.exists(mask_path):
        imgs, masks = load_patient_slices(mri_path, mask_path)
        X_slices.extend(imgs)
        Y_slices.extend(masks)
        slice_patients.extend([p] * len(imgs))

if not X_slices:
    raise FileNotFoundError(
        f"No '<id>-Flair.nii' / '<id>-LesionSeg-Flair.nii' pairs found under {MRI_DIR!r}"
    )

X_slices = np.array(X_slices, np.float32)
Y_slices = np.array(Y_slices, np.uint8)
slice_patients = np.array(slice_patients)

print(f"✅ Loaded: {X_slices.shape[0]} slices, shape={X_slices.shape[1:]}")

# ==============================================
# 6. BALANCE LESION / BACKGROUND SLICES
# ==============================================
# Keep every slice with at least one lesion pixel, plus an equally sized random
# subset of lesion-free slices. A dedicated seeded generator makes this
# subsample (and the shuffle order) reproducible.
rng = np.random.default_rng(SEED)
lesion_mask = Y_slices.reshape(Y_slices.shape[0], -1).any(axis=1)
lesion_idx = np.flatnonzero(lesion_mask)
bg_idx = np.flatnonzero(~lesion_mask)

bg_sample_size = min(len(bg_idx), len(lesion_idx))
bg_idx = rng.choice(bg_idx, size=bg_sample_size, replace=False)

final_idx = np.concatenate([lesion_idx, bg_idx])
rng.shuffle(final_idx)

X_slices = X_slices[final_idx]
Y_slices = Y_slices[final_idx]
slice_patients = slice_patients[final_idx]

# ==============================================
# 7. TRAIN/VAL/TEST SPLIT
# ==============================================
# Split by PATIENT (70 / 15 / 15 of patients), not by slice. Neighbouring
# slices of one volume are highly similar, so a slice-level split puts the same
# patient in train and test and inflates validation/test scores.
unique_patients = np.unique(slice_patients)
train_patients, temp_patients = train_test_split(unique_patients, test_size=0.3, random_state=SEED)
val_patients, test_patients = train_test_split(temp_patients, test_size=0.5, random_state=SEED)

train_sel = np.isin(slice_patients, train_patients)
val_sel = np.isin(slice_patients, val_patients)
test_sel = np.isin(slice_patients, test_patients)

X_train, Y_train = X_slices[train_sel], Y_slices[train_sel]
X_val, Y_val = X_slices[val_sel], Y_slices[val_sel]
X_test, Y_test = X_slices[test_sel], Y_slices[test_sel]

print(
    f"Split by patient: train={len(train_patients)} ({len(X_train)} slices), "
    f"val={len(val_patients)} ({len(X_val)} slices), "
    f"test={len(test_patients)} ({len(X_test)} slices)"
)

# ==============================================
# 8. DATA AUGMENTATION
# ==============================================
# Train-time augmentation: flips/rotations, elastic and affine warps, and
# intensity changes. `alpha_affine` was removed from ElasticTransform. Recent
# albumentations releases no longer accept it, and the saved run output shows
# a validation warning on that line, so it was already being ignored.
# NOTE(review): alpha=1 with sigma=50 yields a very small elastic displacement;
# raise alpha if a visible elastic deformation is intended.
augmenter = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ElasticTransform(alpha=1, sigma=50, p=0.3),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.GaussNoise(p=0.3),
    A.RandomGamma(p=0.3)
])

# Encoder backbone for the U-Net, plus the matching input preprocessing that
# segmentation_models provides for it (using the tf.keras framework).
BACKBONE = 'resnet34'
sm.set_framework('tf.keras')
preprocess_input = sm.get_preprocessing(BACKBONE)

def aug_fn(image, mask):
    """Randomly augment one image/mask pair with `augmenter`.

    Called through tf.numpy_function from the dataset pipeline. The slices hold
    0-255 intensities stored as float32. Albumentations assumes float32 images
    lie in [0, 1]: brightness/contrast and noise results get clipped to that
    range, and RandomGamma raises the raw values to a power. The image is
    therefore augmented as uint8 and converted back to float32 afterwards,
    which is lossless because the values are whole numbers in [0, 255].

    Args:
        image: (H, W, C) float32 array with integral values in [0, 255].
        mask: (H, W, 1) float32 binary mask.

    Returns:
        (image, mask) as float32 arrays with a trailing channel axis, matching
        the [tf.float32, tf.float32] output types the caller declares.
    """
    image_u8 = np.clip(np.rint(image), 0, 255).astype(np.uint8)
    augmented = augmenter(image=image_u8, mask=mask)
    aug_image = np.asarray(augmented["image"], dtype=np.float32)
    aug_mask = np.asarray(augmented["mask"], dtype=np.float32)
    # Keep the channel axis even if a transform returned a 2-D array.
    if aug_image.ndim == 2:
        aug_image = aug_image[..., np.newaxis]
    if aug_mask.ndim == 2:
        aug_mask = aug_mask[..., np.newaxis]
    return aug_image, aug_mask

def preprocess_fn(image, mask):
    """Run the backbone's `preprocess_input` on the image; the mask is returned untouched."""
    return preprocess_input(image), mask

def make_dataset_aug(X, Y, batch_size, training=True):
    """Build a batched, prefetched tf.data pipeline of (image, mask) pairs.

    Args:
        X: grayscale slices, (N, IMG_SIZE, IMG_SIZE, 1), intensities 0-255.
        Y: binary masks, (N, IMG_SIZE, IMG_SIZE, 1).
        batch_size: samples per batch.
        training: if True, shuffle and apply `aug_fn` before preprocessing.

    Returns:
        tf.data.Dataset of float32 batches: images (B, IMG_SIZE, IMG_SIZE, 3)
        and masks (B, IMG_SIZE, IMG_SIZE, 1).
    """
    # Images stay single-channel in memory and are replicated to 3 channels per
    # sample inside the pipeline. Building a 3-channel float32 copy up front
    # (np.repeat) triples the array that from_tensor_slices has to hold.
    # Augmenting the single channel also keeps noise identical across the
    # replicated channels.
    X = np.asarray(X, dtype=np.float32)
    Y = np.asarray(Y, dtype=np.float32)
    ds = tf.data.Dataset.from_tensor_slices((X, Y))

    if training:
        ds = ds.shuffle(500, seed=SEED)

        def _augment(img, mask):
            img, mask = tf.numpy_function(aug_fn, [img, mask], [tf.float32, tf.float32])
            # tf.reshape (unlike set_shape) also restores a channel axis that a
            # transform may have dropped, and fails loudly on a size mismatch.
            img = tf.reshape(img, (IMG_SIZE, IMG_SIZE, 1))
            mask = tf.reshape(mask, (IMG_SIZE, IMG_SIZE, 1))
            return img, mask

        ds = ds.map(_augment, num_parallel_calls=tf.data.AUTOTUNE)

    def _prep(img, mask):
        # Grayscale -> 3 identical channels for the ImageNet-pretrained encoder.
        img = tf.repeat(img, 3, axis=-1)
        img, mask = tf.numpy_function(preprocess_fn, [img, mask], [tf.float32, tf.float32])
        img = tf.reshape(img, (IMG_SIZE, IMG_SIZE, 3))
        mask = tf.reshape(mask, (IMG_SIZE, IMG_SIZE, 1))
        return img, mask

    ds = ds.map(_prep, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Shuffled + augmented pipeline for training; plain, unshuffled one for validation.
train_ds = make_dataset_aug(X_train, Y_train, batch_size=BATCH_SIZE, training=True)
val_ds = make_dataset_aug(X_val, Y_val, batch_size=BATCH_SIZE, training=False)

# ==============================================
# 9. MODEL
# ==============================================
# U-Net on the BACKBONE encoder (ImageNet weights) with a transposed-convolution,
# batch-normalised decoder and a single sigmoid output channel.
model = sm.Unet(
    BACKBONE,
    classes=1,
    activation='sigmoid',
    encoder_weights='imagenet',
    decoder_block_type='transpose',
    decoder_use_batchnorm=True,
    decoder_filters=(512, 256, 128, 64, 32),
)

# Loss: equal-weight (0.5 / 0.5) sum of binary cross-entropy and Dice loss.
# Neither term carries a class weight.
bce = sm.losses.BinaryCELoss()
dice = sm.losses.DiceLoss()
total_loss = 0.5 * bce + 0.5 * dice

model.compile(
    loss=total_loss,
    optimizer=tf.keras.optimizers.Adam(1e-4),
    metrics=[sm.metrics.iou_score, 'accuracy'],
)

callbacks = [
    # Halve the learning rate after 3 epochs without val_loss improvement.
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3),
    # Stop after 10 stagnant epochs and restore the lowest-val_loss weights.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    # Independently keep the epoch with the highest validation IoU on disk.
    tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_iou_score', save_best_only=True, mode='max'),
]

# ==============================================
# 10. TRAIN
# ==============================================
# Train for at most 100 epochs; the EarlyStopping entry in `callbacks` may end
# the run sooner.
history = model.fit(
    train_ds,
    epochs=100,
    validation_data=val_ds,
    callbacks=callbacks,
)

# ==============================================
# 11. SAVE FINAL MODEL
# ==============================================
final_model_path = "model_final1.h5"
model.save(final_model_path)
print(f"✅ Model saved as {final_model_path}")

Segmentation Models: using `keras` framework.
✅ Loaded: 1451 slices, shape=(256, 256, 1)


  A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
  original_init(self, **validated_kwargs)


Downloading data from https://github.com/qubvel/classification_models/releases/download/0.0.1/resnet34_imagenet_1000_no_top.h5
[1m85521592/85521592[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step
Epoch 1/60
[1m  6/254[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m31:33[0m 8s/step - accuracy: 0.3777 - iou_score: 0.0035 - loss: 0.9309

KeyboardInterrupt: 