In [1]:
!pip install tensorflow==2.17.1 nibabel matplotlib gradio scikit-learn

Collecting tensorflow==2.17.1
  Downloading tensorflow-2.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting ml-dtypes<0.5.0,>=0.3.1 (from tensorflow==2.17.1)
  Downloading ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 (from tensorflow==2.17.1)
  Downloading protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes)
Collecting tensorboard<2.18,>=2.17 (from tensorflow==2.17.1)
  Downloading tensorboard-2.17.1-py3-none-any.whl.metadata (1.6 kB)
Collecting numpy<2.0.0,>=1.26.0 (from tensorflow==2.17.1)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Downloading tensorflow-2.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x8

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [2]:

import os, glob, random, tempfile, requests
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import tensorflow as tf
import gradio as gr

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K

from sklearn.model_selection import train_test_split

from google.colab import drive
# Mount Google Drive at /content/drive; the trained model is later saved to and
# reloaded from a path under this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Root URL under which each patient folder of the open_ms_data repository lives.
base = (
    "https://raw.githubusercontent.com/muschellij2/open_ms_data/"
    "master/cross_sectional/coregistered_resampled"
)

# patient01 ... patient30
patients = list(map("patient{:02d}".format, range(1, 31)))

# Human-readable label -> file name inside each remote patient folder.
modalities = dict([
    ("FLAIR", "FLAIR.nii.gz"),
    ("Brain Mask", "brainmask.nii.gz"),
    ("Lesion Mask", "consensus_gt.nii.gz"),
])

# Local folder that receives the downloaded files (created if missing).
out_dir = "ms_data_resampled_labeled"
os.makedirs(out_dir, exist_ok=True)

def download_and_log(url, out):
    """Download ``url`` to the local path ``out`` unless it is already present.

    The body is streamed into ``out + ".part"`` and only renamed to ``out``
    after every chunk has been written, so an interrupted transfer cannot be
    mistaken for a finished file by the ``os.path.exists`` check on a later run.

    Args:
        url: HTTP(S) address of the file to fetch.
        out: Destination path on the local filesystem.

    Returns:
        True if ``out`` already existed or was downloaded completely; False on
        a non-OK HTTP status or on a network / filesystem error (the reason is
        printed).
    """
    if os.path.exists(out):
        print("↪ Already exists:", out)
        return True

    partial = out + ".part"
    try:
        # The context manager releases the connection on every exit path; the
        # (connect, read) timeout keeps a stalled server from hanging the cell.
        with requests.get(url, stream=True, timeout=(10, 60)) as r:
            if not r.ok:
                print("❌ Failed:", url, "status:", r.status_code)
                return False
            with open(partial, "wb") as f:
                for chunk in r.iter_content(1 << 20):  # 1 MB chunks
                    f.write(chunk)
        # Same-directory rename: the final name appears only once complete.
        os.replace(partial, out)
    except (requests.RequestException, OSError) as exc:
        print("❌ Failed:", url, "error:", exc)
        if os.path.exists(partial):
            os.remove(partial)
        return False

    print("✅", out)
    return True

# Fetch every modality for every patient; a patient counts as downloaded only
# when all of its files are available locally.
downloaded = []
for pid in patients:
    # A list (not a generator) so every file is attempted even after a failure.
    outcomes = [
        download_and_log(
            f"{base}/{pid}/{fname}",
            os.path.join(out_dir, f"{pid}_{label.replace(' ', '_')}.nii.gz"),
        )
        for label, fname in modalities.items()
    ]
    if all(outcomes):
        downloaded.append(pid)

print(f"\nDownload complete. Successful patients: {len(downloaded)} / {len(patients)}")

↪ Already exists: ms_data_resampled_labeled/patient01_FLAIR.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient01_Brain_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient01_Lesion_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient02_FLAIR.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient02_Brain_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient02_Lesion_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient03_FLAIR.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient03_Brain_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient03_Lesion_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient04_FLAIR.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient04_Brain_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient04_Lesion_Mask.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient05_FLAIR.nii.gz
↪ Already exists: ms_data_resampled_labeled/patient05_Brain_Mask.nii.gz
↪ Already exi

In [4]:
DATA_DIR = "ms_data_resampled_labeled"

# Patient ids are the file-name prefix before the first underscore of every
# FLAIR volume found in DATA_DIR (de-duplicated, sorted).
patients = sorted({
    os.path.basename(flair_path).partition("_")[0]
    for flair_path in glob.glob(os.path.join(DATA_DIR, "*_FLAIR.nii.gz"))
})
print("Patients found:", len(patients))

IMG_SIZE = (256, 256)  # U-Nets work well with powers of 2
SLICES_PER_PATIENT = 12

def slice_indices_centered(n_slices, k=SLICES_PER_PATIENT):
    """Return the indices of up to ``k`` consecutive slices around the middle.

    The window starts ``k // 2`` before the middle index, is clipped to
    ``[0, n_slices)``, and is shifted back towards 0 when clipping at the upper
    end would otherwise shorten it. Volumes with fewer than ``k`` slices yield
    every index.
    """
    centre = n_slices // 2
    stop = min(n_slices, max(0, centre - k // 2) + k)
    first = max(0, stop - k)
    return [*range(first, stop)]

def preprocess_slice(slice_img):
    """Normalise one 2D slice and shape it for the segmentation network.

    Intensities are rescaled so the 0.5th / 99.5th percentiles map to 0 / 1
    (values outside are clipped), the slice is resized to ``IMG_SIZE`` and the
    single channel is replicated to three. Returns a NumPy array.
    """
    pixels = slice_img.astype(np.float32)
    low = np.percentile(pixels, 0.5)
    high = np.percentile(pixels, 99.5)
    span = max(high - low, 1e-6)  # guard against a constant slice
    normalised = np.clip((pixels - low) / span, 0, 1)
    resized = tf.image.resize(normalised[..., None], IMG_SIZE)
    # U-Net input expects 3 channels
    return tf.image.grayscale_to_rgb(resized).numpy()

Patients found: 30


In [5]:

# Seed Python's `random`, NumPy and TensorFlow together. Seeding only `random`
# left the tf.data shuffle, the augmentation and the weight initialisation
# unseeded, so runs could not be repeated.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Split by patient (not by slice) so no patient contributes to two subsets:
# 20% test first, then 20% of the remainder for validation.
train_p, test_p = train_test_split(patients, test_size=0.2, random_state=SEED)
train_p, val_p = train_test_split(train_p, test_size=0.2, random_state=SEED)
print("Patients split -> train:", len(train_p), "val:", len(val_p), "test:", len(test_p))


def dataset_for_segmentation(pids, batch=16, augment=False, shuffle=True):
    """Build a batched ``tf.data`` pipeline of (FLAIR slice, lesion mask) pairs.

    For every patient id the FLAIR volume and lesion mask are read from
    ``DATA_DIR`` and the centred block of ``SLICES_PER_PATIENT`` slices is
    emitted one slice at a time.

    Args:
        pids: Iterable of patient ids (file-name prefixes in ``DATA_DIR``).
        batch: Batch size.
        augment: If True, apply a random left/right flip (same flip for image
            and mask) and a small random brightness shift (image only).
        shuffle: If True, shuffle slices with a 1024-element buffer, reshuffled
            every epoch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, mask)`` batches with per-item
        shapes ``IMG_SIZE + (3,)`` and ``IMG_SIZE + (1,)``, both float32.
    """
    def gen():
        for pid in pids:
            flair_path = os.path.join(DATA_DIR, f"{pid}_FLAIR.nii.gz")
            mask_path  = os.path.join(DATA_DIR, f"{pid}_Lesion_Mask.nii.gz")
            flair = nib.load(flair_path).get_fdata()
            mask  = nib.load(mask_path).get_fdata()
            nz = flair.shape[2]
            idxs = slice_indices_centered(nz, SLICES_PER_PATIENT)
            for z in idxs:
                # Preprocess the FLAIR image slice (output shape: 256, 256, 3)
                s = preprocess_slice(flair[:, :, z])

                # Preprocess the corresponding mask slice (output shape: 256, 256, 1)
                m = mask[:, :, z].astype(np.float32)
                # 'nearest' keeps mask values from being blended by interpolation
                m = tf.image.resize(m[..., None], IMG_SIZE, method='nearest')

                yield s, m

    out_sig = (
        tf.TensorSpec(shape=(IMG_SIZE[0], IMG_SIZE[1], 3), dtype=tf.float32),
        tf.TensorSpec(shape=(IMG_SIZE[0], IMG_SIZE[1], 1), dtype=tf.float32) # Mask output
    )
    ds = tf.data.Dataset.from_generator(gen, output_signature=out_sig)

    # Reading and resizing the NIfTI volumes gives the same result every epoch,
    # so keep the decoded slices in memory after the first pass. The cache sits
    # BEFORE shuffle/augment so those steps stay random on every epoch.
    ds = ds.cache()

    if shuffle:
        ds = ds.shuffle(1024, reshuffle_each_iteration=True)

    if augment:
        def aug(x, y):
            # One seed per example, shared by both flips so image and mask
            # are mirrored together.
            seed = tf.random.uniform((2,), 0, tf.int32.max, dtype=tf.int32)
            x = tf.image.stateless_random_flip_left_right(x, seed)
            y = tf.image.stateless_random_flip_left_right(y, seed)
            x = tf.image.stateless_random_brightness(x, 0.05, seed)
            return x, y
        ds = ds.map(aug, num_parallel_calls=tf.data.AUTOTUNE)

    ds = ds.batch(batch).prefetch(tf.data.AUTOTUNE)
    return ds

# Create the new datasets: only the training pipeline shuffles and augments.
train_ds_seg = dataset_for_segmentation(pids=train_p, batch=8, augment=True, shuffle=True)
val_ds_seg, test_ds_seg = (
    dataset_for_segmentation(pids=held_out, batch=8, augment=False, shuffle=False)
    for held_out in (val_p, test_p)
)

Patients split -> train: 19 val: 5 test: 6


In [6]:
# @title Step 5: Define and Build the U-Net Model
def _double_conv(x, filters, dropout_rate):
    """Apply Conv(3x3, ReLU) -> Dropout -> Conv(3x3, ReLU) with ``filters`` channels."""
    conv_kwargs = dict(activation='relu', kernel_initializer='he_normal', padding='same')
    x = Conv2D(filters, (3, 3), **conv_kwargs)(x)
    x = Dropout(dropout_rate)(x)
    return Conv2D(filters, (3, 3), **conv_kwargs)(x)


def unet_model(input_size=(256, 256, 3), num_classes=1):
    """Build a three-level U-Net with a sigmoid output map.

    Args:
        input_size: Shape of one input image, passed to ``Input``.
        num_classes: Number of output channels of the final 1x1 convolution.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inputs = Input(input_size)

    # Encoder path: remember each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters, rate in ((16, 0.1), (32, 0.1), (64, 0.2)):
        x = _double_conv(x, filters, rate)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = _double_conv(x, 128, 0.3)

    # Decoder path: upsample, join the matching encoder output, refine.
    for filters, rate in ((64, 0.2), (32, 0.1), (16, 0.1)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skips.pop()])
        x = _double_conv(x, filters, rate)

    outputs = Conv2D(num_classes, (1, 1), activation='sigmoid')(x)

    return Model(inputs=[inputs], outputs=[outputs])

# Instantiate the model for 3-channel inputs of IMG_SIZE and print its layers.
seg_model = unet_model(input_size=(*IMG_SIZE, 3))
seg_model.summary()

In [8]:
# @title Step 6: Define Loss Function and Train the Model
# Dice Coefficient Metric and Loss
def dice_coef(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient over all elements of the batch.

    Computes ``(2 * sum(y_true * y_pred) + smooth) /
    (sum(y_true) + sum(y_pred) + smooth)`` after flattening both inputs.

    Args:
        y_true: Ground-truth mask. It is cast to the dtype of ``y_pred`` so
            integer or boolean masks (or a mixed-precision model) do not fail
            on the element-wise product.
        y_pred: Predicted probabilities, same number of elements as ``y_true``.
        smooth: Small constant that keeps the ratio defined (and equal to 1)
            when both inputs are all zeros.

    Returns:
        A scalar tensor.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    """Loss that falls as overlap rises: one minus ``dice_coef(y_true, y_pred)``."""
    overlap = dice_coef(y_true, y_pred)
    return 1 - overlap

# Compile the model: optimise the Dice loss and report the Dice coefficient.
seg_model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss=dice_loss,
    metrics=[dice_coef],
)

# Define callbacks
# Stop once validation Dice has not improved for 8 epochs and roll back to the
# best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_dice_coef",
    mode="max",
    patience=8,
    restore_best_weights=True,
)
# Halve the learning rate after 3 epochs without a better validation loss.
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
)
callbacks = [early_stopping, lr_on_plateau]

# Train the model
history_seg = seg_model.fit(
    train_ds_seg,
    validation_data=val_ds_seg,
    epochs=12,  # Segmentation requires more training
    callbacks=callbacks,
    verbose=1,  # Explicitly set verbose to 1 for progress bar
)

# Define the save path in your Google Drive, then save the model there.
save_path = "/content/drive/MyDrive/unet_segmentation_model.h5"
seg_model.save(save_path)
print(f"Segmentation model saved to: {save_path}")

[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m161s[0m 5s/step - dice_coef: 0.0095 - loss: 0.9905 - val_dice_coef: 0.0187 - val_loss: 0.9777 - learning_rate: 1.0000e-04




Segmentation model saved to: /content/drive/MyDrive/unet_segmentation_model.h5


In [11]:

!pip install -q scikit-image pandas

from skimage import measure
import pandas as pd
from scipy.ndimage import label, center_of_mass

# Define the path to your saved segmentation model
MODEL_PATH = "/content/drive/MyDrive/unet_segmentation_model.h5"

# Load your trained segmentation model. The custom loss/metric functions are
# passed by name via `custom_objects`. On any failure the reason is printed and
# `loaded_seg_model` stays None, which the prediction function checks for.
loaded_seg_model = None
try:
    loaded_seg_model = tf.keras.models.load_model(
        MODEL_PATH,
        custom_objects=dict(dice_loss=dice_loss, dice_coef=dice_coef),
    )
except Exception as e:
    print(f"❌ Error loading model: {e}")
else:
    print("✅ Segmentation model loaded successfully.")


def _lesion_table_markdown(lesion_details):
    """Render the per-lesion records (a non-empty list of dicts) as a Markdown table."""
    try:
        return pd.DataFrame(lesion_details).to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown needs the optional `tabulate` package; without
        # it, fall back to a hand-built pipe table so the report still renders.
        columns = list(lesion_details[0])
        lines = [
            "| " + " | ".join(columns) + " |",
            "|" + "|".join(" --- " for _ in columns) + "|",
        ]
        lines.extend(
            "| " + " | ".join(str(record[col]) for col in columns) + " |"
            for record in lesion_details
        )
        return "\n".join(lines)


def predict_and_segment_mri(nii_file):
    """Segment the middle slice of a FLAIR volume and summarise detected lesions.

    The slice at index ``shape[2] // 2`` is preprocessed, passed through
    ``loaded_seg_model`` and thresholded at 0.5. All plotting and measurements
    are done on the transposed arrays, i.e. in the orientation that is shown,
    so overlay, contours, centroid markers and the reported "Location (Y, X)"
    all refer to the same coordinates.

    Args:
        nii_file: Path to a NIfTI file, or an object whose ``.name`` attribute
            is that path (as an upload widget may supply), or None.

    Returns:
        ``(figure, markdown_summary)``. ``figure`` is None when no file or
        model is available or the file cannot be used; the summary then holds
        the reason.
    """
    if nii_file is None or loaded_seg_model is None:
        return None, "Model not loaded or file not uploaded."

    # 1. Load and process the MRI data
    nii_path = getattr(nii_file, "name", nii_file)
    try:
        flair = nib.load(nii_path).get_fdata()
    except Exception as e:
        return None, f"Error reading NIfTI file: {e}"
    if flair.ndim != 3 or flair.shape[2] == 0:
        return None, (
            "Error reading NIfTI file: expected a non-empty 3D volume, "
            f"got shape {flair.shape}"
        )

    z_slice_index = flair.shape[2] // 2
    preprocessed_slice = preprocess_slice(flair[:, :, z_slice_index])
    model_input = np.expand_dims(preprocessed_slice, axis=0)

    # 2. Predict the segmentation mask
    pred_mask_raw = loaded_seg_model.predict(model_input)[0, :, :, 0]

    # Work in display orientation from here on (the image is shown transposed).
    display_slice = preprocessed_slice[:, :, 0].T  # reuse; no second preprocessing pass
    display_prob = pred_mask_raw.T
    display_mask = (display_prob > 0.5).astype(np.uint8)

    # 3. Generate a focused, two-panel visualization
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    fig.patch.set_facecolor('#F0F0F0')
    panel_titles = (f'Original MRI Slice (z={z_slice_index})', 'Lesion Localization')
    for ax, title in zip(axes, panel_titles):
        ax.imshow(display_slice, cmap='gray', origin='lower')
        ax.set_title(title, fontsize=16)
        ax.axis('off')

    # 4. Analyze and Overlay Lesions
    labeled_mask, num_features = label(display_mask)

    if num_features > 0:
        # Create a semi-transparent overlay for the lesion area
        mask_overlay = np.ma.masked_where(display_mask == 0, display_mask)
        axes[1].imshow(mask_overlay, cmap='Reds', alpha=0.5, origin='lower')

        # Find contours for a sharp outline
        for contour in measure.find_contours(display_mask, 0.5):
            axes[1].plot(contour[:, 1], contour[:, 0], linewidth=2.5, color='cyan')

        # Find centroids to mark the center. They are computed on the displayed
        # (transposed) mask, so (row, col) really is (y, x) for plotting.
        centroids = center_of_mass(display_mask, labeled_mask, range(1, num_features + 1))
        lesion_details = []
        for lesion_id, (cy, cx) in enumerate(centroids, start=1):
            region = labeled_mask == lesion_id
            axes[1].plot(cx, cy, 'w+', markersize=15, markeredgewidth=2)  # White '+' marker
            lesion_details.append({
                "Lesion ID": lesion_id,
                "Area (pixels)": int(np.sum(region)),
                "Avg. Confidence": f"{np.mean(display_prob[region]):.2%}",
                "Location (Y, X)": f"({int(cy)}, {int(cx)})"
            })

        status = f"🧠 Lesion{'s' if num_features > 1 else ''} Detected"
        summary = (
            f"### **Analysis Complete**\n---\n"
            f"**Status:** {status} (**{num_features}** distinct object{'s' if num_features > 1 else ''} found).\n\n"
            f"{_lesion_table_markdown(lesion_details)}"
        )
    else:
        status = "✅ Normal"
        summary = (
            f"### **Analysis Complete**\n---\n"
            f"**Status:** {status}\n\n"
            "No significant lesions were detected on this slice."
        )

    # Lay out THIS figure (not pyplot's implicit "current" one) and drop it from
    # pyplot's registry so repeated calls do not accumulate open figures; the
    # returned Figure object can still be rendered.
    fig.tight_layout()
    plt.close(fig)
    return fig, summary

# Create the Gradio Interface
# Components are built in the same order as before: file input, then the two outputs.
upload_input = gr.File(label="Upload MRI FLAIR scan (.nii.gz)")
result_outputs = [
    gr.Plot(label="Comparison View"),
    gr.Markdown(label="Quantitative Analysis"),
]
# Offer the FLAIR volumes of the first two held-out test patients as examples.
example_scans = [os.path.join(DATA_DIR, p + "_FLAIR.nii.gz") for p in test_p[:2]]

demo = gr.Interface(
    fn=predict_and_segment_mri,
    inputs=upload_input,
    outputs=result_outputs,
    title="Focused MRI Brain Lesion Localization",
    description="Upload a NIfTI file. The model shows the original scan and an enhanced view pinpointing the exact size, shape, and location of any detected lesions.",
    allow_flagging="never",
    examples=example_scans,
)

demo.launch(debug=True)



✅ Segmentation model loaded successfully.
It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://1ef1ed09fa21019364.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 485ms/step
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://1ef1ed09fa21019364.gradio.live


