In [None]:
!pip install tensorflow==2.17.1 nibabel matplotlib scikit-learn faiss-cpu




In [None]:
import os, glob, random, tempfile, requests
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import tensorflow as tf

from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix
from collections import Counter

import faiss


In [None]:
import os, requests

# Root of the coregistered/resampled cross-sectional volumes in the open_ms_data repo.
base = (
    "https://raw.githubusercontent.com/muschellij2/open_ms_data/"
    "master/cross_sectional/coregistered_resampled"
)

# Remote folder names: patient01 ... patient30.
patients = ["patient{:02d}".format(number) for number in range(1, 31)]

# Label used in the local filename -> filename inside each remote patient folder.
modalities = dict(
    [
        ("FLAIR", "FLAIR.nii.gz"),
        ("Brain Mask", "brainmask.nii.gz"),
        ("Lesion Mask", "consensus_gt.nii.gz"),
    ]
)

# Local folder that receives every downloaded volume.
out_dir = "ms_data_resampled_labeled"
os.makedirs(out_dir, exist_ok=True)

def download_and_log(url, out, timeout=30):
    """Download ``url`` to the local path ``out`` and print the outcome.

    Args:
        url: Address of the file to fetch.
        out: Destination path. If a file already exists there, nothing is
            downloaded and the call counts as a success.
        timeout: Seconds passed to ``requests.get`` as its timeout so a
            stalled server cannot hang the notebook indefinitely.

    Returns:
        True if ``out`` already existed or was downloaded completely;
        False if the server answered with an error status or the transfer
        (or the local write) failed.
    """
    if os.path.exists(out):
        print("↪ Already exists:", out)
        return True

    # Stream into a sibling ".part" file and rename only after the transfer
    # finished. Otherwise an interrupted download would leave a truncated
    # file at `out` that the existence check above would accept next time.
    partial = out + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            if not r.ok:
                print("❌ Failed:", url, "status:", r.status_code)
                return False
            with open(partial, "wb") as f:
                for chunk in r.iter_content(1 << 20):  # 1 MB chunks
                    f.write(chunk)
        os.replace(partial, out)
    except (requests.RequestException, OSError) as exc:
        print("❌ Failed:", url, "error:", exc)
        # Do not leave the incomplete temporary file behind.
        if os.path.exists(partial):
            os.remove(partial)
        return False

    print("✅", out)
    return True

# A patient counts as downloaded only when every modality was fetched (or was
# already on disk). All modalities are attempted even after one has failed.
downloaded = []
for pid in patients:
    outcomes = []
    for label, fname in modalities.items():
        url = f"{base}/{pid}/{fname}"
        out = os.path.join(out_dir, f"{pid}_{label.replace(' ', '_')}.nii.gz")
        outcomes.append(download_and_log(url, out))
    if all(outcomes):
        downloaded.append(pid)

print(f"\nDownload complete. Successful patients: {len(downloaded)} / {len(patients)}")


✅ ms_data_resampled_labeled/patient01_FLAIR.nii.gz
✅ ms_data_resampled_labeled/patient01_Brain_Mask.nii.gz
✅ ms_data_resampled_labeled/patient01_Lesion_Mask.nii.gz
✅ ms_data_resampled_labeled/patient02_FLAIR.nii.gz
✅ ms_data_resampled_labeled/patient02_Brain_Mask.nii.gz
✅ ms_data_resampled_labeled/patient02_Lesion_Mask.nii.gz
✅ ms_data_resampled_labeled/patient03_FLAIR.nii.gz
✅ ms_data_resampled_labeled/patient03_Brain_Mask.nii.gz
✅ ms_data_resampled_labeled/patient03_Lesion_Mask.nii.gz
✅ ms_data_resampled_labeled/patient04_FLAIR.nii.gz
✅ ms_data_resampled_labeled/patient04_Brain_Mask.nii.gz
✅ ms_data_resampled_labeled/patient04_Lesion_Mask.nii.gz
✅ ms_data_resampled_labeled/patient05_FLAIR.nii.gz
✅ ms_data_resampled_labeled/patient05_Brain_Mask.nii.gz
✅ ms_data_resampled_labeled/patient05_Lesion_Mask.nii.gz
✅ ms_data_resampled_labeled/patient06_FLAIR.nii.gz
✅ ms_data_resampled_labeled/patient06_Brain_Mask.nii.gz
✅ ms_data_resampled_labeled/patient06_Lesion_Mask.nii.gz
✅ ms_data_resamp

In [None]:
# Dataset location and slice-sampling configuration.
DATA_DIR = "ms_data_resampled_labeled"
IMG_SIZE = (299, 299)
SLICES_PER_PATIENT = 12

# A patient ID is the part of each FLAIR filename before its first underscore.
flair_files = glob.glob(os.path.join(DATA_DIR, "*_FLAIR.nii.gz"))
patients = sorted({os.path.basename(path).partition("_")[0] for path in flair_files})
print("Patients found:", len(patients))

def slice_indices_centered(n_slices, k=SLICES_PER_PATIENT):
    """Return a run of consecutive slice indices around the middle of a volume.

    Args:
        n_slices: Number of slices available along the slicing axis.
        k: Desired number of indices.

    Returns:
        A list of consecutive indices clipped to ``[0, n_slices)``; it has
        fewer than ``k`` entries when the volume has fewer than ``k`` slices.
    """
    centre = n_slices // 2
    lo = max(centre - k // 2, 0)
    hi = min(lo + k, n_slices)
    # If the upper bound was clipped, pull the lower bound back so the window
    # keeps its width whenever enough slices exist.
    lo = max(hi - k, 0)
    return [*range(lo, hi)]

def preprocess_slice(slice_img):
    """Convert one 2-D slice into a 3-channel float image sized ``IMG_SIZE``.

    Intensities are windowed to the slice's own 0.5th-99.5th percentile range
    and scaled to [0, 1]; the result is resized and its single channel is
    repeated to form an RGB image.

    Args:
        slice_img: 2-D array holding one image slice.

    Returns:
        NumPy array of shape ``(IMG_SIZE[0], IMG_SIZE[1], 3)``.
    """
    pixels = slice_img.astype(np.float32)
    low = np.percentile(pixels, 0.5)
    high = np.percentile(pixels, 99.5)
    # The floor on the range avoids dividing by zero on a constant slice.
    span = max(high - low, 1e-6)
    scaled = np.clip((pixels - low) / span, 0, 1)
    resized = tf.image.resize(scaled[..., None], IMG_SIZE)
    rgb = tf.image.grayscale_to_rgb(resized)   # shape (299,299,3)
    return rgb.numpy()

# Build one (patient, slice index, label) record per selected slice; the label
# is 1 when the lesion mask has at least one non-zero voxel in that slice.
slice_records = []
for pid in patients:
    flair_path = os.path.join(DATA_DIR, f"{pid}_FLAIR.nii.gz")
    mask_path  = os.path.join(DATA_DIR, f"{pid}_Lesion_Mask.nii.gz")
    if not (os.path.exists(flair_path) and os.path.exists(mask_path)):
        continue  # skip patients whose download is incomplete
    # Only the slice count is needed from the FLAIR volume here.
    nz = nib.load(flair_path).shape[2]
    # Read the mask voxels once per patient instead of once per slice.
    mask_data = nib.load(mask_path).get_fdata()
    for z in slice_indices_centered(nz, SLICES_PER_PATIENT):
        label = 1 if np.count_nonzero(mask_data[:, :, z]) > 0 else 0
        slice_records.append((pid, z, label))
print("Total slices prepared (approx):", len(slice_records))


Patients found: 30
Total slices prepared (approx): 360


In [None]:
# Seed Python's `random`, NumPy and TensorFlow together. Seeding only `random`
# leaves tf.data shuffling, dropout and weight initialisation unseeded, so
# training runs would not be repeatable.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Split by patient (not by slice) so slices of one scan never end up in two subsets.
train_p, test_p = train_test_split(patients, test_size=0.2, random_state=SEED)
train_p, val_p = train_test_split(train_p, test_size=0.2, random_state=SEED)
print("Patients split -> train:", len(train_p), "val:", len(val_p), "test:", len(test_p))

def dataset_for_patients(pids, batch=16, augment=False, shuffle=True):
    """Build a batched ``tf.data`` pipeline of (slice image, label) pairs.

    For every patient the FLAIR and lesion-mask volumes are loaded from
    ``DATA_DIR``; each selected slice is preprocessed and labelled 1.0 when
    its mask slice contains any non-zero voxel, otherwise 0.0.

    Args:
        pids: Patient IDs to draw slices from.
        batch: Batch size.
        augment: If True, apply random flips and a small brightness jitter.
        shuffle: If True, shuffle the slices with a buffer of 1024 elements,
            reshuffling on every iteration.

    Returns:
        A batched and prefetched ``tf.data.Dataset``.
    """
    signature = (
        tf.TensorSpec(shape=(IMG_SIZE[0], IMG_SIZE[1], 3), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
    )

    def slice_stream():
        # Volumes are read lazily, one patient at a time, whenever the
        # dataset is iterated.
        for pid in pids:
            flair_vol = nib.load(os.path.join(DATA_DIR, f"{pid}_FLAIR.nii.gz")).get_fdata()
            mask_vol = nib.load(os.path.join(DATA_DIR, f"{pid}_Lesion_Mask.nii.gz")).get_fdata()
            for z in slice_indices_centered(flair_vol.shape[2], SLICES_PER_PATIENT):
                image = preprocess_slice(flair_vol[:, :, z])
                has_lesion = np.count_nonzero(mask_vol[:, :, z]) > 0
                yield image.astype(np.float32), np.float32(1 if has_lesion else 0)

    def random_augment(image, label):
        # Only the image is perturbed; the label passes through unchanged.
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        image = tf.image.random_brightness(image, 0.05)
        return image, label

    pipeline = tf.data.Dataset.from_generator(slice_stream, output_signature=signature)
    if shuffle:
        pipeline = pipeline.shuffle(1024, reshuffle_each_iteration=True)
    if augment:
        pipeline = pipeline.map(random_augment, num_parallel_calls=tf.data.AUTOTUNE)
    return pipeline.batch(batch).prefetch(tf.data.AUTOTUNE)

# Only the training split is shuffled and augmented; validation and test
# slices are produced in generator order without augmentation.
train_ds = dataset_for_patients(train_p, augment=True, shuffle=True, batch=16)
val_ds = dataset_for_patients(val_p, augment=False, shuffle=False, batch=16)
test_ds = dataset_for_patients(test_p, augment=False, shuffle=False, batch=16)


Patients split -> train: 19 val: 5 test: 6


In [None]:
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.optimizers import Adam

# Frozen ImageNet InceptionV3 used purely as a feature extractor.
base = InceptionV3(weights="imagenet", include_top=False, input_shape=(*IMG_SIZE, 3))
base.trainable = False

# preprocess_slice() produces pixels in [0, 1], whereas the ImageNet
# InceptionV3 weights expect inputs scaled to [-1, 1]. Rescaling inside the
# model keeps every caller (tf.data pipeline, patient_probs) unchanged.
inputs = tf.keras.Input(shape=(*IMG_SIZE, 3))
x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
# training=False keeps the backbone in inference mode while the head trains.
x = base(x, training=False)

# Small classification head: pooled features -> two dense layers with
# dropout -> a single sigmoid unit giving the lesion probability per slice.
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation="relu")(x)
x = Dropout(0.4)(x)
x = Dense(128, activation="relu")(x)
x = Dropout(0.3)(x)
out = Dense(1, activation="sigmoid")(x)

model = Model(inputs=inputs, outputs=out)
model.compile(
    optimizer=Adam(1e-4),
    loss="binary_crossentropy",
    metrics=[tf.keras.metrics.BinaryAccuracy(name="acc"), tf.keras.metrics.AUC(name="auc")]
)

model.summary()


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m87910968/87910968[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 0us/step


In [None]:
# Stop when validation AUC has not improved for 5 epochs and roll back to the
# best weights seen.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_auc",
    mode="max",
    patience=5,
    restore_best_weights=True,
)
# Halve the learning rate when validation loss has not improved for 2 epochs.
lr_decay = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=2,
)
callbacks = [early_stop, lr_decay]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=12,
    callbacks=callbacks,
)


Epoch 1/12
     15/Unknown [1m49s[0m 2s/step - acc: 0.7346 - auc: 0.4794 - loss: 0.5519



[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 3s/step - acc: 0.7383 - auc: 0.4805 - loss: 0.5499 - val_acc: 0.8000 - val_auc: 0.7561 - val_loss: 0.5372 - learning_rate: 1.0000e-04
Epoch 2/12
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 3s/step - acc: 0.8275 - auc: 0.6005 - loss: 0.4617 - val_acc: 0.8000 - val_auc: 0.8073 - val_loss: 0.4430 - learning_rate: 1.0000e-04
Epoch 3/12
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 3s/step - acc: 0.8650 - auc: 0.6461 - loss: 0.3829 - val_acc: 0.8000 - val_auc: 0.7812 - val_loss: 0.4710 - learning_rate: 1.0000e-04
Epoch 4/12
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 5s/step - acc: 0.8517 - auc: 0.8281 - loss: 0.3350 - val_acc: 0.8000 - val_auc: 0.8090 - val_loss: 0.4329 - learning_rate: 1.0000e-04
Epoch 5/12
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 3s/step - acc: 0.8401 - auc: 0.8138 - loss: 0.3654 - val_acc: 0.8000 - val_auc: 0.8307 - val_lo

In [None]:
# Collect predictions and true labels on the test set.
y_true, y_prob = [], []
for x_batch, y_batch in test_ds:
    probs = model.predict(x_batch, verbose=0).ravel()
    y_prob.append(probs)
    y_true.append(y_batch.numpy())
# The dataset yields float32 labels; cast to int so they share one label
# space with the thresholded predictions below.
y_true = np.concatenate(y_true).astype(int)
y_prob = np.concatenate(y_prob)

y_pred = (y_prob >= 0.5).astype(int)
# ROC AUC is undefined when the test slices contain only one class, so report
# that explicitly instead of letting the metric call fail.
if np.unique(y_true).size < 2:
    print("ROC AUC: undefined (only one class present in y_true)")
else:
    print("ROC AUC:", roc_auc_score(y_true, y_prob))
# Fixing labels=[0, 1] keeps the report rows and the 2x2 matrix layout stable
# even if one class is missing from this split.
print(classification_report(y_true, y_pred, labels=[0, 1], digits=4, zero_division=0))
print("Confusion matrix:\n", confusion_matrix(y_true, y_pred, labels=[0, 1]))


In [None]:

def patient_probs(pid):
    """Predict the lesion probability of every selected slice of one patient.

    Args:
        pid: Patient ID used to locate the FLAIR and lesion-mask volumes
            inside ``DATA_DIR``.

    Returns:
        A list of ``(z, prob, label)`` tuples, one per selected slice, where
        ``z`` is the slice index, ``prob`` the model output as a float and
        ``label`` is 1 when the mask slice has any non-zero voxel, else 0.
        The list is empty when no slice is selected.
    """
    flair_path = os.path.join(DATA_DIR, f"{pid}_FLAIR.nii.gz")
    mask_path  = os.path.join(DATA_DIR, f"{pid}_Lesion_Mask.nii.gz")
    flair = nib.load(flair_path).get_fdata()
    mask  = nib.load(mask_path).get_fdata()
    idxs = slice_indices_centered(flair.shape[2], SLICES_PER_PATIENT)
    if not idxs:
        return []

    # Run all selected slices through the model in one predict() call instead
    # of issuing a separate call for each single slice.
    batch = np.stack([preprocess_slice(flair[:, :, z]) for z in idxs])
    slice_probs = model.predict(batch, verbose=0).ravel()

    return [
        (z, float(p), int(np.count_nonzero(mask[:, :, z]) > 0))
        for z, p in zip(idxs, slice_probs)
    ]

# Evaluate on test patients: collapse the per-slice probabilities of each
# scan into a single score and compare it with "any labelled slice has a lesion".
scan_y_true, scan_y_prob = [], []
for pid in test_p:
    probs = patient_probs(pid)
    if not probs:
        continue  # no slices selected for this patient; nothing to score
    per_slice_probs = [p for (_, p, _) in probs]
    scan_prob_max = max(per_slice_probs)
    scan_prob_mean = float(np.mean(per_slice_probs))  # alternative aggregate; not used for scoring
    gt = int(any(lab for (_, _, lab) in probs))
    scan_y_true.append(gt)
    scan_y_prob.append(scan_prob_max)   # use max for demo

# With only a handful of test scans it is common for every scan to carry the
# same label; ROC AUC is undefined then, so report that instead of failing.
if len(set(scan_y_true)) < 2:
    print("Scan ROC AUC: undefined (only one class present among test scans)")
else:
    print("Scan ROC AUC:", roc_auc_score(scan_y_true, scan_y_prob))
scan_pred = [1 if p >= 0.5 else 0 for p in scan_y_prob]
# labels=[0, 1] keeps both classes in the report and a 2x2 confusion matrix
# even when one class is absent.
print(classification_report(scan_y_true, scan_pred, labels=[0, 1], digits=4, zero_division=0))
print("Confusion matrix (scan-level):\n", confusion_matrix(scan_y_true, scan_pred, labels=[0, 1]))


In [None]:
def show_patient(pid):
    """Plot each selected slice of one patient with its predicted probability.

    Every panel shows the FLAIR slice in grayscale, titled with the slice
    index and the model probability; slices labelled as containing a lesion
    also get the mask outline drawn in red.

    Args:
        pid: Patient ID used to locate the volumes inside ``DATA_DIR``.
    """
    slice_info = patient_probs(pid)
    flair_vol = nib.load(os.path.join(DATA_DIR, f"{pid}_FLAIR.nii.gz")).get_fdata()
    lesion = nib.load(os.path.join(DATA_DIR, f"{pid}_Lesion_Mask.nii.gz")).get_fdata() > 0

    n_panels = len(slice_info)
    fig = plt.figure(figsize=(15, 3))
    for panel, (z, prob, has_lesion) in enumerate(slice_info, start=1):
        ax = fig.add_subplot(1, n_panels, panel)
        # Transpose with origin="lower" so the slice is drawn in the same
        # orientation for the image and the contour.
        ax.imshow(flair_vol[:, :, z].T, cmap="gray", origin="lower")
        if has_lesion:
            ax.contour(lesion[:, :, z].T, levels=[0.5], colors='r')
        ax.set_title(f"z={z}\np={prob:.2f}")
        ax.axis("off")
    fig.suptitle(f"Patient {pid} - slice probs")
    plt.show()

# Render the per-slice panels for at most the first three test patients.
for pid in test_p[:3]:
    show_patient(pid)
