In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
from ipywidgets import interact, IntSlider, ToggleButtons, fixed
import nibabel as nib
from other import pad_slice_to_target_shape
from tqdm.auto import tqdm
import h5py
from MRI_dataset import MRIdataset_all_dims_hdf5
from other import *
from torch.utils.data import DataLoader
from Unet import Unet
import torch.nn as nn
from torchvision import transforms
from torch.amp import autocast, GradScaler
import time
from ConbinedLoss import CombinedLoss
from metrics import *

### Train dataset visualizations

In [2]:
folder_paths = "MSLesSeg-Dataset/train/P{}"

def get_selected_folder_files(patient_num, folder_paths="MSLesSeg-Dataset/train/P{}"):
    """Interactive viewer for one training patient.

    Shows a dropdown of the sub-folders of ``folder_paths.format(patient_num)``.
    For the selected folder, every ``.nii.gz`` file is loaded with ``load_nii`` and
    displayed in a 3 x N grid (row 0 indexes axis 0 / "sagittal", row 1 axis 1 /
    "coronal", row 2 axis 2 / "axial"; one column per file), with one slider per plane.

    Args:
        patient_num: Patient number substituted into ``folder_paths``.
        folder_paths: Format string with one ``{}`` placeholder for the patient number.
    """
    patient_dir = folder_paths.format(patient_num)
    # sorted() so the dropdown order does not depend on the file system
    MRI_folders = [patient_dir + "/" + folder for folder in sorted(os.listdir(patient_dir))]

    def get_selected_files(MRI_folder):

        MRI_files = [MRI_folder + "/" + file for file in sorted(os.listdir(MRI_folder)) if file.endswith(".nii.gz")]

        def plot_files():
            if not MRI_files:
                print(f"No .nii.gz files in {MRI_folder}")
                return

            img_tensor = [torch.tensor(load_nii(file), dtype=torch.float32) for file in MRI_files]
            len_files = len(img_tensor)

            def explore_slices(layer_sagittal, layer_coronal, layer_axial, len_files):
                # squeeze=False keeps `ax` 2-D even when the folder holds a single volume
                fig, ax = plt.subplots(3, len_files, figsize=(10, 5), squeeze=False)
                for i in range(len_files):
                    volume = img_tensor[i]
                    planes = (volume[layer_sagittal, :, :], volume[:, layer_coronal, :], volume[:, :, layer_axial])
                    for row, plane in enumerate(planes):
                        ax[row][i].grid(False)
                        ax[row][i].imshow(plane.T, cmap="gray", origin="lower")
                    # Title only the top row; the file name labels the whole column
                    ax[0][i].set_title(MRI_files[i].split("/")[-1])

            # Each slider's range follows the axis it indexes (sagittal -> 0, coronal -> 1, axial -> 2).
            # NOTE(review): ranges come from the first volume only; assumes all files share its shape.
            interact(explore_slices,
                     layer_sagittal=(0, img_tensor[0].shape[0] - 1),
                     layer_coronal=(0, img_tensor[0].shape[1] - 1),
                     layer_axial=(0, img_tensor[0].shape[2] - 1),
                     len_files=fixed(len_files))

        interact(plot_files)

    interact(get_selected_files, MRI_folder=MRI_folders)

# Assumes training patients are numbered contiguously 1..N, N = number of entries in the train folder
interact(get_selected_folder_files, patient_num=list(range(1, len(os.listdir("MSLesSeg-Dataset/train")) + 1)), folder_paths=fixed(folder_paths))


interactive(children=(Dropdown(description='patient_num', options=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, …

<function __main__.get_selected_folder_files(patient_num, folder_paths='MSLesSeg-Dataset/train/P{}')>

### Test dataset visualizations

In [10]:
folder_paths = "MSLesSeg-Dataset/test/P{}"

def get_selected_folder_files(patient_num, folder_paths="MSLesSeg-Dataset/test/P{}"):
    """Interactive viewer for one test patient.

    Every ``.nii.gz`` file directly inside ``folder_paths.format(patient_num)`` is
    loaded with ``load_nii`` and shown in a 3 x N grid (row 0 indexes axis 0 /
    "sagittal", row 1 axis 1 / "coronal", row 2 axis 2 / "axial"; one column per
    file), with one slider per plane.

    Args:
        patient_num: Patient number substituted into ``folder_paths``.
        folder_paths: Format string with one ``{}`` placeholder for the patient number.
    """
    patient_dir = folder_paths.format(patient_num)
    # sorted() so the column order does not depend on the file system
    MRI_files = [patient_dir + "/" + file for file in sorted(os.listdir(patient_dir)) if file.endswith(".nii.gz")]

    def plot_files():
        if not MRI_files:
            print(f"No .nii.gz files in {patient_dir}")
            return

        img_tensor = [torch.tensor(load_nii(file), dtype=torch.float32) for file in MRI_files]
        len_files = len(img_tensor)

        def explore_slices(layer_sagittal, layer_coronal, layer_axial, len_files):
            # squeeze=False keeps `ax` 2-D even when only one volume is present
            fig, ax = plt.subplots(3, len_files, figsize=(13, 8), squeeze=False)
            for i in range(len_files):
                volume = img_tensor[i]
                planes = (volume[layer_sagittal, :, :], volume[:, layer_coronal, :], volume[:, :, layer_axial])
                for row, plane in enumerate(planes):
                    ax[row][i].grid(False)
                    ax[row][i].imshow(plane.T, cmap="gray", origin="lower")
                # Title only the top row; the file name labels the whole column
                ax[0][i].set_title(MRI_files[i].split("/")[-1])

        # NOTE(review): slider ranges come from the first volume only; assumes all files share its shape.
        interact(explore_slices,
                 layer_sagittal=(0, img_tensor[0].shape[0] - 1),
                 layer_coronal=(0, img_tensor[0].shape[1] - 1),
                 layer_axial=(0, img_tensor[0].shape[2] - 1),
                 len_files=fixed(len_files))

    interact(plot_files)

# Test patients are P54..P75
interact(get_selected_folder_files, patient_num=list(range(54, 76)), folder_paths=fixed(folder_paths))


interactive(children=(Dropdown(description='patient_num', options=(54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,…

<function __main__.get_selected_folder_files(patient_num, folder_paths='MSLesSeg-Dataset/test/P{}')>

### HDF5 dataset

In [None]:
# Path template of a test-set FLAIR volume. Forward slashes: "\P" is an invalid escape sequence in a normal string literal.
'MSLesSeg-Dataset/test/P{}/P{}_FLAIR.nii.gz'

In [None]:
def create_h5(output_file, start_num_patient, end_num_patient=53,
              data_path="MSLesSeg-Dataset/test/P{}/P{}_FLAIR.nii.gz",
              roi_path="MSLesSeg-Dataset/test/P{}/P{}_MASK.nii.gz",
              margin=20, target_shape=(224, 224), seed=None):
    """Cut each patient volume into 2-D slices along all three axes and save them, shuffled, to HDF5.

    For each patient id in ``[start_num_patient, end_num_patient]`` the image volume
    (``data_path``) and its mask (``roi_path``) are loaded with ``load_nii``. Slices are
    taken along axis 2, then axis 0, then axis 1 (labelled axial / sagittal / coronal
    in this notebook), skipping ``margin`` slices at both ends of each axis. Each slice
    goes through ``pad_slice_to_target_shape(..., target_shape=target_shape)``. Image
    and mask slices are shuffled with the same permutation and written to the
    ``"data"`` and ``"roi"`` datasets (gzip, level 9).

    Args:
        output_file: Path of the HDF5 file to (over)write.
        start_num_patient: First patient id (inclusive).
        end_num_patient: Last patient id (inclusive).
        data_path: Template with two ``{}`` placeholders, both filled with the patient id.
            Patients whose image file does not exist are skipped with a message.
        roi_path: Same kind of template for the mask; a missing mask yields an all-zero mask.
        margin: Number of slices dropped at each end of every axis.
        target_shape: Output slice shape handed to ``pad_slice_to_target_shape``.
        seed: Seed for the shuffle; ``None`` uses NumPy's global random state.

    Raises:
        ValueError: If no image volume was found for any patient in the range.
    """
    data_slices = []
    roi_slices = []

    for patient_id in tqdm(range(start_num_patient, end_num_patient + 1)):
        # Format into NEW names: reassigning data_path/roi_path would destroy the templates
        # and make every later patient reload the first patient's files.
        patient_data_path = data_path.format(patient_id, patient_id)
        patient_roi_path = roi_path.format(patient_id, patient_id)

        if not os.path.exists(patient_data_path):
            print(f"Skipping patient {patient_id}: {patient_data_path} not found")
            continue

        data_volume = load_nii(patient_data_path).astype(np.float32)
        if os.path.exists(patient_roi_path):
            roi_volume = load_nii(patient_roi_path).astype(np.float32)
        else:
            roi_volume = np.zeros_like(data_volume)

        # Axis order kept from the original: axial (z, axis 2), sagittal (x, axis 0), coronal (y, axis 1)
        for axis in (2, 0, 1):
            for slice_idx in range(margin, data_volume.shape[axis] - margin):
                data_slices.append(pad_slice_to_target_shape(np.take(data_volume, slice_idx, axis=axis), target_shape=target_shape))
                roi_slices.append(pad_slice_to_target_shape(np.take(roi_volume, slice_idx, axis=axis), target_shape=target_shape))

    if not data_slices:
        raise ValueError(f"No image volumes found for patients {start_num_patient}..{end_num_patient} with template {data_path!r}")

    # Convert to numpy arrays for efficient shuffling
    data_slices = np.array(data_slices)
    roi_slices = np.array(roi_slices)

    # Shuffle slices while keeping data and roi aligned
    if seed is None:
        shuffle_indices = np.random.permutation(data_slices.shape[0])
    else:
        shuffle_indices = np.random.default_rng(seed).permutation(data_slices.shape[0])
    data_slices = data_slices[shuffle_indices]
    roi_slices = roi_slices[shuffle_indices]

    # Save the shuffled slices to an HDF5 file
    with h5py.File(output_file, 'w') as hf:
        hf.create_dataset("data", data=data_slices, compression="gzip", compression_opts=9)
        hf.create_dataset("roi", data=roi_slices, compression="gzip", compression_opts=9)

    print(f"Flattened, shuffled slices from all dimensions saved to {output_file}")

In [14]:
# Forward slashes: "\P" inside a normal string literal is an invalid escape sequence.
data_path = "MSLesSeg-Dataset/test/P{}/P{}_FLAIR.nii.gz" # "MSLesSeg-Dataset/train/P{}/P{}_FLAIR.nii.gz"
roi_path = "MSLesSeg-Dataset/test/P{}/P{}_MASK.nii.gz" # "MSLesSeg-Dataset/train/P{}/P{}_MASK.nii.gz"

# Pass the templates explicitly; otherwise create_h5 silently falls back to its own defaults
# and switching the lines above to the train paths would have no effect.
create_h5("MSLesSeg-Dataset/test_data.h5", 54, 75, data_path=data_path, roi_path=roi_path)

  0%|          | 0/22 [00:00<?, ?it/s]

Flattened, shuffled slices from all dimensions saved to MSLesSeg-Dataset/test_data.h5


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor turns each 2-D slice into a (1, H, W) tensor (see the shape printout below)
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = MRIdataset_all_dims_hdf5(hdf5_path='MSLesSeg-Dataset/train_data.h5', transform=transform)
# Pinned host memory only helps host->GPU copies; on a CPU-only machine it just triggers a warning.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=device.type == "cuda")


Dataset loaded. Total slices: 24486


In [6]:
# Shapes of the image tensor and the mask tensor of the first training sample
for sample_part in (train_dataset[0][0], train_dataset[0][1]):
    print(sample_part.shape)

torch.Size([1, 224, 224])
torch.Size([1, 224, 224])


### Training loop

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# --- training hyper-parameters ---
EPOCHS = 20
LRate = 1e-4

# --- U-Net architecture hyper-parameters ---
depth, k_size, base_channels, inception = 3, 3, 64, True

model = Unet(
    in_channels=1,
    num_features=1,
    depth=depth,
    k_size=k_size,
    base_channels=base_channels,
    inception=inception,
).to(device)

# Alternatives tried: nn.BCEWithLogitsLoss(), DiceLoss()
criterion = CombinedLoss()
# Alternative tried: torch.optim.SGD(model.parameters(), lr=LRate, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=LRate, weight_decay=1e-5)
# Halve the learning rate once the monitored value has not improved for 3 scheduler steps
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

In [None]:
# Mixed precision is only used on CUDA. On a CPU-only machine the scaler and autocast are
# disabled explicitly instead of each warning that CUDA is unavailable.
use_amp = device.type == "cuda"
scaler = GradScaler("cuda", enabled=use_amp)

train_losses = []            # loss of every batch
epoch_train_iou = []
epoch_train_recall = []
epoch_train_precision = []
epoch_train_dice = []

epoch_train_losses = []      # mean loss per epoch
train_times = []             # wall-clock seconds per epoch

n_batches = len(train_loader)
if n_batches == 0:
    raise ValueError("train_loader yields no batches; cannot average per-epoch metrics")

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    running_train_iou = 0.0
    running_train_recall = 0.0
    running_train_precision = 0.0
    running_train_dice = 0.0
    start = time.time()

    # Training
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS}", leave=False)
    for data_slice, roi_slice in progress_bar:
        data_slice = data_slice.to(device)
        roi_slice = roi_slice.to(device)

        optimizer.zero_grad()

        with autocast("cuda", enabled=use_amp):
            outputs = model(data_slice)
            loss = criterion(outputs, roi_slice)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # .item() forces a device synchronisation; read the scalar once per batch
        batch_loss = loss.item()
        train_losses.append(batch_loss)
        running_loss += batch_loss

        # Per-batch metrics on a binary mask.
        # NOTE(review): thresholding raw `outputs` at 0.5 assumes the Unet ends in a sigmoid; if it
        # returns logits (cf. the BCEWithLogitsLoss alternative) the threshold should be 0 — confirm.
        preds = (outputs.detach() > 0.5).float()
        # Squeeze channel dimension so shape becomes [batch, H, W]
        preds_np = preds.squeeze(1).cpu().numpy()
        truth_np = roi_slice.squeeze(1).cpu().numpy()

        batch_iou = []
        batch_recall = []
        batch_precision = []
        batch_dice = []
        for pred_mask, truth_mask in zip(preds_np, truth_np):
            batch_iou.append(iou_score(truth_mask, pred_mask))
            batch_recall.append(recall_score_(truth_mask, pred_mask))
            batch_precision.append(precision_score_(truth_mask, pred_mask))
            batch_dice.append(dice_coef(truth_mask, pred_mask))
        mean_iou = np.mean(batch_iou)
        mean_recall = np.mean(batch_recall)
        mean_precision = np.mean(batch_precision)
        mean_dice = np.mean(batch_dice)
        running_train_iou += mean_iou
        running_train_recall += mean_recall
        running_train_precision += mean_precision
        running_train_dice += mean_dice

        progress_bar.set_postfix({'Batch Loss': batch_loss, 'Batch IoU': mean_iou, 'Batch Recall': mean_recall, 'Batch Precision': mean_precision, 'Batch Dice': mean_dice})

    progress_bar.close()

    # Epoch averages are means of per-batch means
    avg_loss = running_loss / n_batches
    avg_train_iou = running_train_iou / n_batches
    avg_train_recall = running_train_recall / n_batches
    avg_train_precision = running_train_precision / n_batches
    avg_train_dice = running_train_dice / n_batches
    epoch_train_losses.append(avg_loss)
    epoch_train_iou.append(avg_train_iou)
    epoch_train_recall.append(avg_train_recall)
    epoch_train_precision.append(avg_train_precision)
    epoch_train_dice.append(avg_train_dice)
    train_times.append(time.time() - start)

    # ReduceLROnPlateau monitors the training loss (no validation split in this notebook)
    scheduler.step(avg_loss)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Training Loss: {avg_loss:.4f}, Train IoU: {avg_train_iou:.4f}, Time: {train_times[-1]:.2f}s")

print("Training finished")

total = sum(train_times)
hours, rem = divmod(total, 3600)
minutes, seconds = divmod(rem, 60)
print(f"Total computation time: {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}")

print(f"Training time: {total:.2f}s")

In [None]:
# f-string so the hyper-parameters are actually substituted into the file name
# (without the `f` prefix the literal "{depth}" etc. ended up in the name).
# Saves the whole pickled module, so loading it later needs torch.load(..., weights_only=False).
torch.save(model, f"model_MSLesSeg_d{depth}_k{k_size}_b{base_channels}_incept{inception}.pth")

### Visualization

In [None]:
# One x value per recorded epoch; this also works if training was interrupted before EPOCHS.
# (The old cell plotted against `epoch`, the single int left over from the training loop.)
epochs = range(1, len(epoch_train_losses) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, epoch_train_losses, label='Training Loss')
ax.plot(epochs, epoch_train_iou, label='Training IoU')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss / IoU')
ax.set_title('Training Loss and IoU')
ax.legend()
plt.show()

In [None]:
# Per-epoch training metrics on a shared axis
epochs = range(1, EPOCHS + 1)
metric_series = (
    ('Train IoU', epoch_train_iou),
    ('Train Recall', epoch_train_recall),
    ('Train Precision', epoch_train_precision),
    ('Train Dice', epoch_train_dice),
)
fig, ax = plt.subplots(figsize=(10, 5))
for series_label, series_values in metric_series:
    ax.plot(epochs, series_values, label=series_label)
ax.set_title('Training Metrics')
ax.set_xlabel('Epoch')
ax.set_ylabel('Metrics')
ax.legend()
plt.show()

In [11]:
folder_paths = "MSLesSeg-Dataset/train/P{}"

def get_selected_folder_files(patient_num, folder_paths="MSLesSeg-Dataset/train/P{}", model_path="MSLesSeg-Dataset/Unet_MSL.pth"):
    """Interactive viewer comparing a training patient's FLAIR/mask with the model's prediction.

    Shows a dropdown of the sub-folders of ``folder_paths.format(patient_num)``. For
    the selected folder the FLAIR and MASK volumes are loaded and displayed in a
    3 x 3 grid: row 0 indexes axis 0 ("sagittal"), row 1 axis 1 ("coronal"), row 2
    axis 2 ("axial"). The last column holds the model's binary prediction for the
    FLAIR slice shown in that row.

    Args:
        patient_num: Patient number substituted into ``folder_paths`` and the file names.
        folder_paths: Format string with one ``{}`` placeholder for the patient number.
        model_path: Checkpoint holding a whole pickled model (it is used directly as a module).
    """
    patient_dir = folder_paths.format(patient_num)
    # sorted() so the dropdown order does not depend on the file system
    MRI_folders = [patient_dir + "/" + folder for folder in sorted(os.listdir(patient_dir))]

    def get_selected_files(MRI_folder):

        # NOTE(review): file names always carry the "T1" tag, whichever folder is selected;
        # confirm this matches the dataset layout for the other folders.
        MRI_files = [f"P{patient_num}_T1_FLAIR.nii.gz", f"P{patient_num}_T1_MASK.nii.gz"]
        MRI_files = [MRI_folder + "/" + file for file in MRI_files]

        def plot_files():

            img_tensor = [torch.tensor(load_nii(file), dtype=torch.float32) for file in MRI_files]
            len_files = len(img_tensor)

            # The checkpoint is a full pickled module, which torch.load refuses by default on
            # PyTorch >= 2.6 unless weights_only=False. Unpickling runs arbitrary code, so only load trusted files.
            # map_location/.to(device) let a checkpoint saved on another device run here.
            model = torch.load(model_path, map_location=device, weights_only=False)
            model.to(device)
            model.eval()

            def predict(plane):
                """Binary mask (numpy, same H x W as the model output) for one 2-D slice."""
                # NOTE(review): training slices were padded to 224x224 by pad_slice_to_target_shape
                # in create_h5, but the raw slice is fed here; confirm the Unet handles other sizes.
                # NOTE(review): the 0.5 threshold assumes the model outputs probabilities; confirm.
                with torch.no_grad():
                    batch = plane.unsqueeze(0).unsqueeze(0).to(device)
                    return (model(batch) > 0.5).float().squeeze(0).squeeze(0).cpu().numpy()

            def explore_slices(layer_sagittal, layer_coronal, layer_axial, len_files):

                fig, ax = plt.subplots(3, len_files + 1, figsize=(10, 5), squeeze=False)
                for i in range(len_files):
                    volume = img_tensor[i]
                    planes = (volume[layer_sagittal, :, :], volume[:, layer_coronal, :], volume[:, :, layer_axial])
                    for row, plane in enumerate(planes):
                        ax[row][i].grid(False)
                        ax[row][i].imshow(plane.T, cmap="gray", origin="lower")
                    # Title only the top row; the file name labels the whole column
                    ax[0][i].set_title(MRI_files[i].split("/")[-1])

                # Predictions come from the first volume (the FLAIR image) and fill the last column
                flair = img_tensor[0]
                predictions = (
                    predict(flair[layer_sagittal, :, :]),
                    predict(flair[:, layer_coronal, :]),
                    predict(flair[:, :, layer_axial]),
                )
                for row, pred in enumerate(predictions):
                    ax[row][len_files].grid(False)
                    ax[row][len_files].imshow(pred.T, cmap="gray", origin="lower")
                ax[0][len_files].set_title("Predictions")

            # Each slider's range follows the axis it indexes (sagittal -> 0, coronal -> 1, axial -> 2)
            interact(explore_slices,
                     layer_sagittal=(0, img_tensor[0].shape[0] - 1),
                     layer_coronal=(0, img_tensor[0].shape[1] - 1),
                     layer_axial=(0, img_tensor[0].shape[2] - 1),
                     len_files=fixed(len_files))

        interact(plot_files)

    interact(get_selected_files, MRI_folder=MRI_folders)

# Assumes training patients are numbered contiguously 1..N, N = number of entries in the train folder
interact(get_selected_folder_files, patient_num=list(range(1, len(os.listdir("MSLesSeg-Dataset/train")) + 1)), folder_paths=fixed(folder_paths))

interactive(children=(Dropdown(description='patient_num', options=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, …

<function __main__.get_selected_folder_files(patient_num, folder_paths='MSLesSeg-Dataset/train/P{}', model_path='MSLesSeg-Dataset/Unet_MSL.pth')>