In [None]:
import importlib.util
import logging
import sys
from pathlib import Path

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import SimpleITK as sitk
from radiomics import imageoperations

%matplotlib inline

In [None]:
# Repository root, taken as the parent of the current working directory
# (assumes the notebook is run from a direct subdirectory of the repo).
root_dir = (Path.cwd() / "..").resolve()

# Prepare dataset

In [None]:
def _path_component(path: str, prefix: str) -> str:
    """Return the first "/"-separated component of ``path`` that starts with ``prefix``.

    The subject/session directories are located by name instead of by a fixed
    position (previously ``split("/")[8]`` / ``[9]``), so the parsing survives a
    change in mount depth. The *first* match is returned so that a file name that
    repeats the prefix (presumably BIDS-style naming) does not shadow the directory.

    Raises:
        ValueError: if no component starts with ``prefix``.
    """
    for part in path.split("/"):
        if part.startswith(prefix):
            return part
    raise ValueError(f"no path component starting with {prefix!r} in {path!r}")


def _remount(path: str) -> str:
    """Rewrite the ``mnt`` mount point to ``mnt/ceib`` (first occurrence only).

    ``str.replace`` without a count would also rewrite any later "mnt"
    substring inside the path.
    """
    return path.replace("mnt", "mnt/ceib", 1)


midas_img_relation = pd.read_csv(
    root_dir.joinpath("data", "filtered_midas900_t2w.csv"), sep=","
)
midas_img_relation["Subject_MIDS"] = midas_img_relation["Image"].map(
    lambda p: _path_component(p, "sub-S")
)
midas_img_relation["Session_MIDS"] = midas_img_relation["Image"].map(
    lambda p: _path_component(p, "ses-E")
)
# MIDS "sub-S<number>" / "ses-E<number>" -> XNAT "ceibcs_S<number>" / "ceibcs_E<number>"
# (int() drops leading zeros).
midas_img_relation["Subject_XNAT"] = midas_img_relation["Subject_MIDS"].map(
    lambda x: f"ceibcs_S{int(x.split('sub-S')[1])}"
)
midas_img_relation["Session_XNAT"] = midas_img_relation["Session_MIDS"].map(
    lambda x: f"ceibcs_E{int(x.split('ses-E')[1])}"
)
midas_img_relation["Image"] = midas_img_relation["Image"].map(_remount)
midas_img_relation["Mask"] = midas_img_relation["Mask"].map(_remount)

In [None]:
# Disc level -> numeric key ("1" = L5-S ... "5" = L1-L2); these keys become the
# per-disc label columns used by the transforms below.
DISC_LEVEL_TO_INDEX = {
    "L5-S": "1",
    "L4-L5": "2",
    "L3-L4": "3",
    "L2-L3": "4",
    "L1-L2": "5",
}

# Disc grades per session, with rows containing any missing value dropped and the
# ID columns renamed to match midas_img_relation.
labels_df = (
    pd.read_csv(root_dir.joinpath("data", "midasdisclabelsJDCarlos.csv"), sep=",")
    .dropna()
    .rename(columns={"subject_ID": "Subject_XNAT", "ID": "Session_XNAT"})
)

# Attach image/mask paths to every labelled session.
id_labels = labels_df.merge(
    midas_img_relation, on=["Subject_XNAT", "Session_XNAT"]
).rename(columns=DISC_LEVEL_TO_INDEX)

In [None]:
from bimcv_aikit.data.genetic_train_test_split import separate_dataset

In [None]:
# Partition name -> fraction of subjects; grouping is done on Subject_XNAT and the
# L4-L5 grade (column "2") is passed as the class column.
split_fractions = {"train": 0.7, "val": 0.1, "test": 0.2}

partition_csv = separate_dataset(
    id_labels,
    column_subjects="Subject_XNAT",
    column_classes="2",
    new_column_name="Partition",
    label_partitions=list(split_fractions),
    label_percentages=list(split_fractions.values()),
    verbose=True,
)

In [None]:
# Persist the split so the training config (data_loader.args.path) can read it.
partition_path = root_dir / "data" / "filtered_midas900_t2w_partition.csv"
partition_csv.to_csv(partition_path, index=False)

# Inspect samples and define transforms

In [None]:
def show_image_and_masks(img_path, mask_path):
    """Show the three central slices of an image with its label mask overlaid.

    Slices are taken along the third array axis. Non-zero mask voxels are drawn
    with the "jet" colormap at 0.3 opacity, and a legend maps every label value
    present in the mask to its colour.

    Args:
        img_path: Path to the image volume (anything ``nibabel.load`` accepts).
        mask_path: Path to the label mask, assumed to share the image's grid.
    """
    image_data = nib.load(img_path).get_fdata()
    mask_data = nib.load(mask_path).get_fdata()

    n_slices = image_data.shape[2]
    center = n_slices // 2
    # Clamp so volumes with fewer than 3 slices neither wrap around (negative index)
    # nor index past the end.
    slice_indices = [min(max(center + offset, 0), n_slices - 1) for offset in (-1, 0, 1)]

    unique_labels = np.unique(mask_data)
    # One shared normalisation for every panel and the legend. Otherwise imshow
    # autoscales each slice on its own, so a label's colour differs between panels
    # and the legend (built from a single panel) can be wrong.
    mask_norm = plt.Normalize(vmin=unique_labels.min(), vmax=unique_labels.max())

    fig, axs = plt.subplots(1, len(slice_indices), figsize=(12, 12))
    for ax, z in zip(axs, slice_indices):
        mask_slice = mask_data[:, :, z]
        ax.imshow(image_data[:, :, z], cmap="gray")
        im = ax.imshow(
            mask_slice,
            cmap="jet",
            norm=mask_norm,
            alpha=np.where(mask_slice == 0, 0, 0.3),
        )
        ax.grid(False)
        ax.axis("off")

    patches = [
        mpatches.Patch(color=im.cmap(mask_norm(value)), label=f"{value}")
        for value in unique_labels
    ]
    axs[-1].legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
    fig.tight_layout()
    plt.show()

In [None]:
# Visual sanity check of one image/mask pair.
index = 1
preview_row = midas_img_relation.iloc[index]
show_image_and_masks(preview_row["Image"], preview_row["Mask"])

In [None]:
import monai.transforms as transforms

In [None]:
class CheckMaskVol(transforms.MapTransform):
    """Record which mask labels pass pyradiomics' ROI checks.

    Reads the image and mask *paths* stored under ``keys[0]`` and ``keys[1]`` with
    SimpleITK, so it must run before ``LoadImaged``. Calls
    ``imageoperations.checkMask`` for every non-zero label in the mask and stores
    the labels that pass (at most the first five, in ascending order) under
    ``"valid_labels"`` in a copy of the input dict.
    """

    def __init__(
        self,
        keys=("image", "mask"),
        minimum_roi_dimensions: int = 3,
        minimum_roi_size: int = 1000,
    ):
        super().__init__(keys)
        self.minimum_roi_dimensions = minimum_roi_dimensions
        self.minimum_roi_size = minimum_roi_size

    def __call__(self, x):
        d = dict(x)  # return a new dict instead of mutating the caller's (MONAI convention)
        image = sitk.ReadImage(d[self.keys[0]])
        mask = sitk.ReadImage(d[self.keys[1]])
        logger = logging.getLogger(__name__)

        valid_labels = []
        for label in np.unique(sitk.GetArrayFromImage(mask)):
            if label == 0:  # background
                continue
            try:
                imageoperations.checkMask(
                    image,
                    mask,
                    minimumROIDimensions=self.minimum_roi_dimensions,
                    minimumROISize=self.minimum_roi_size,
                    label=label,
                )
            except Exception as err:
                # Best-effort filter: a failing label is only skipped, but record why.
                logger.debug("Skipping label %s of %s: %s", label, d[self.keys[1]], err)
                continue
            valid_labels.append(label)

        d["valid_labels"] = valid_labels[:5]
        return d

In [None]:
class CropForegroundd(transforms.MapTransform):
    """Split one loaded sample into one cropped sample per valid disc.

    For every label in ``x["valid_labels"]`` (as produced by ``CheckMaskVol``),
    crops the array under ``keys[0]`` to the bounding box of that label in
    ``x[source_key]`` using MONAI's ``CropForegroundd``. Returns a list of
    ``{"image": crop, "label": target}`` dicts.
    """

    def __init__(
        self, keys=("image",), source_key="mask", margin=0, k_divisible=(64, 64, 1)
    ):
        super().__init__(keys)
        self.k_divisible = k_divisible
        self.margin = margin
        self.source_key = source_key

    def __call__(self, x):
        key = self.keys[0]
        # Use the configured key names: hard-coding "image"/"mask" here made MONAI's
        # cropper raise KeyError for any non-default keys/source_key.
        input_data = {key: x[key], self.source_key: x[self.source_key]}

        samples = []
        # NOTE(review): the target is chosen by *position* among the valid labels
        # (1st valid label -> x["1"], 2nd -> x["2"], ...), not by the mask label value.
        # If an earlier disc fails CheckMaskVol, later discs shift onto the wrong
        # target. Confirm the mask label encoding before relying on this.
        for position, disc in enumerate(x["valid_labels"], start=1):
            cropper = transforms.CropForegroundd(
                keys=key,
                source_key=self.source_key,
                select_fn=lambda m, disc=disc: m == disc,
                margin=self.margin,
                k_divisible=self.k_divisible,
            )
            cropped = cropper(input_data)
            samples.append({"image": cropped[key], "label": x[str(position)]})
        return samples

In [None]:
# Per-session pipeline: keep discs that pass the ROI checks, load and normalise the
# volume, cut one crop per valid disc, and turn each crop into a (3, 64, 64) tensor
# whose channels are 3 consecutive slices.
transforms_ = transforms.Compose(
    [
        CheckMaskVol(
            keys=["image", "mask"], minimum_roi_dimensions=3, minimum_roi_size=1000
        ),
        transforms.LoadImaged(
            keys=["image", "mask"], image_only=True, ensure_channel_first=True
        ),
        transforms.HistogramNormalized(keys=["image"]),
        transforms.ScaleIntensityd(keys=["image"]),
        CropForegroundd(
            keys=["image"], source_key="mask", margin=0, k_divisible=(64, 64, 1)
        ),
        # CenterSpatialCropd only crops, so a disc crop thinner than 3 slices (or smaller
        # than 64x64) would keep its smaller shape and break the 3-channel input and
        # batching. ResizeWithPadOrCropd zero-pads such crops and is otherwise identical.
        transforms.ResizeWithPadOrCropd(keys=["image"], roi_size=(64, 64, 3)),
        # Move the slice axis next to the channel axis, then drop the singleton channel
        # so the slices become channels (assumes channel-first (1, X, Y, Z) input).
        transforms.Transposed(keys=["image"], indices=(0, 3, 1, 2)),
        transforms.SqueezeDimd(keys=["image"], dim=0),
        transforms.ToTensord(keys=["image"]),
    ]
)

In [None]:
# Run the full pipeline on one labelled session; the result is a list with one
# entry per valid disc.
index = 5
sample_row = id_labels.iloc[index]
sample = transforms_(
    {
        "image": sample_row["Image"],
        "mask": sample_row["Mask"],
        **{disc: sample_row[disc] for disc in ("1", "2", "3", "4", "5")},
    }
)

In [None]:
# Show the first slice of every disc crop. Iterate over what the pipeline returned:
# discs that fail CheckMaskVol are dropped, so there can be fewer than five.
for position, disc_sample in enumerate(sample, start=1):
    print(disc_sample["label"])
    fig, ax = plt.subplots()
    ax.imshow(disc_sample["image"][0, :, :], cmap="gray")
    ax.set_title(f"Disc crop {position}")
    ax.axis("off")
    plt.show()

# Create dataloader

In [None]:
%%writefile bimcv_aikit/dataloaders/projects/MIDASDataLoader.py
from monai import transforms
from monai.data import CacheDataset, DataLoader
from numpy import array, float32, unique
from pandas import read_csv
from pathlib import Path
from radiomics import imageoperations
from SimpleITK import GetArrayFromImage, ReadImage
from sklearn.utils.class_weight import compute_class_weight
from torch import as_tensor, uint8
from torch.nn.functional import one_hot


class MIDASDataLoader:
    """
    A data loader for the MIDAS dataset and IVD degeneration.

    Each CSV row is one imaging session with an image path ("Image"), a mask path
    ("Mask"), one grade per disc in columns "1".."5" and a partition column. The
    transform pipeline expands every row into one sample per valid disc.
    """

    def __init__(
        self,
        path: str,
        sep: str = ",",
        test_run: bool = False,
        partition_column: str = "Partition",
        config: dict = None,
    ):
        """
        Initializes an instance of the MIDASDataLoader class.

        Args:
            path (str): Path to the CSV file containing the data.
            sep (str, optional): Separator used in the CSV file. Defaults to ",".
            test_run (bool, optional): If True, only the first 16 sessions of a partition are loaded. Defaults to False.
            partition_column (str, optional): The name of the column in the file that contains the partition. Defaults to "Partition".
            config (dict, optional): Keyword arguments forwarded to ``DataLoader`` (e.g. batch_size, shuffle). Defaults to None (no extra arguments).
        """
        df = read_csv(path, sep=sep)

        def onehot(grade):
            # Grades are stored 1-based; grade g -> class index g - 1 out of 5.
            return one_hot(as_tensor(int(grade) - 1), num_classes=5).float()

        for i in range(1, 6):
            df[f"onehot_{i}"] = df[str(i)].apply(onehot)
        self.groupby = df.groupby(partition_column)

        # NOTE(review): weights use only the column "2" grades of the "train" split,
        # although samples are produced for every disc; confirm this is intended.
        train_grades = self.groupby.get_group("train")["2"].values
        self._class_weights = compute_class_weight(
            class_weight="balanced",
            classes=unique(train_grades),
            y=train_grades,
        )
        self.transforms = transforms.Compose([
            CheckMaskVol(keys=["image", "mask"], minimum_roi_dimensions=3, minimum_roi_size=1000),
            transforms.LoadImaged(keys=["image", "mask"], image_only=True, ensure_channel_first=True),
            transforms.HistogramNormalized(keys=["image"]),
            transforms.ScaleIntensityd(keys=["image"]),
            CropForegroundd(keys=["image"], source_key="mask", margin=0, k_divisible=(64, 64, 1)),
            # Pad-or-crop guarantees a fixed (1, 64, 64, 3) crop; CenterSpatialCropd never pads,
            # so thin discs would yield fewer than 3 channels and break batching.
            transforms.ResizeWithPadOrCropd(keys=["image"], roi_size=(64, 64, 3)),
            # Slices become channels: (1, X, Y, 3) -> (1, 3, X, Y) -> (3, X, Y).
            transforms.Transposed(keys=["image"], indices=(0, 3, 1, 2)),
            transforms.SqueezeDimd(keys=["image"], dim=0),
            transforms.ToTensord(keys=["image"]),
        ])
        self.test_run = test_run
        # A fresh dict per instance instead of a shared mutable default argument.
        self.config_args = {} if config is None else config

    def __call__(self, partition: str):
        """
        Returns a DataLoader object for the specified partition.

        Args:
            partition (str): The partition to load data for (e.g. "train", "val", or "test").

        Returns:
            DataLoader: A PyTorch DataLoader object containing the specified partition's data.
        """
        rows = self.groupby.get_group(partition)
        if self.test_run:
            rows = rows.iloc[:16]

        disc_columns = [f"onehot_{i}" for i in range(1, 6)]
        data = []
        for img_path, mask_path, *disc_labels in zip(
            rows["Image"].values,
            rows["Mask"].values,
            *(rows[column].values for column in disc_columns),
        ):
            sample = {"image": img_path, "mask": mask_path}
            sample.update({str(i): label for i, label in enumerate(disc_labels, start=1)})
            # Default target (disc "2"); CropForegroundd replaces it with each disc's own label.
            sample["label"] = disc_labels[1]
            data.append(sample)

        dataset = CacheDataset(data=data, transform=self.transforms, num_workers=7)
        return DataLoader(dataset, **self.config_args)

    @property
    def class_weights(self):
        """
        Returns the class weights for the dataset.
        """
        return self._class_weights

class CheckMaskVol(transforms.MapTransform):
    """Record which mask labels pass pyradiomics' ROI checks.

    Reads the image and mask *paths* stored under ``keys[0]`` and ``keys[1]`` with
    SimpleITK (so it must run before ``LoadImaged``), calls
    ``imageoperations.checkMask`` for every non-zero label in the mask and stores the
    passing labels (at most the first five, ascending) under ``"valid_labels"`` in a
    copy of the input dict.
    """

    def __init__(self, keys=("image", "mask"), minimum_roi_dimensions: int = 3, minimum_roi_size: int = 1000):
        super().__init__(keys)
        self.minimum_roi_dimensions = minimum_roi_dimensions
        self.minimum_roi_size = minimum_roi_size

    def __call__(self, x):
        import logging  # local import keeps this class self-contained

        d = dict(x)  # return a new dict instead of mutating the caller's (MONAI convention)
        image = ReadImage(d[self.keys[0]])
        mask = ReadImage(d[self.keys[1]])
        logger = logging.getLogger(__name__)

        valid_labels = []
        for label in unique(GetArrayFromImage(mask)):
            if label == 0:  # background
                continue
            try:
                imageoperations.checkMask(image, mask, minimumROIDimensions=self.minimum_roi_dimensions, minimumROISize=self.minimum_roi_size, label=label)
            except Exception as err:
                # Best-effort filter: a failing label is only skipped, but record why.
                logger.debug("Skipping label %s of %s: %s", label, d[self.keys[1]], err)
                continue
            valid_labels.append(label)

        d["valid_labels"] = valid_labels[:5]
        return d

class CropForegroundd(transforms.MapTransform):
    """Split one loaded sample into one cropped sample per valid disc.

    For every label in ``x["valid_labels"]`` (as produced by ``CheckMaskVol``), crops
    the array under ``keys[0]`` to the bounding box of that label in ``x[source_key]``
    with MONAI's ``CropForegroundd`` and returns a list of
    ``{"image": crop, "label": target}`` dicts.
    """

    def __init__(self, keys=("image",), source_key="mask", margin=0, k_divisible=(64, 64, 1)):
        super().__init__(keys)
        self.k_divisible = k_divisible
        self.margin = margin
        self.source_key = source_key

    def __call__(self, x):
        key = self.keys[0]
        # Use the configured key names: hard-coding "image"/"mask" here made MONAI's
        # cropper raise KeyError for any non-default keys/source_key.
        input_data = {key: x[key], self.source_key: x[self.source_key]}

        samples = []
        # NOTE(review): the target is chosen by *position* among the valid labels
        # (1st valid label -> x["1"], 2nd -> x["2"], ...), not by the mask label value.
        # If an earlier disc fails CheckMaskVol, later discs shift onto the wrong
        # target. Confirm the mask label encoding before relying on this.
        for position, disc in enumerate(x["valid_labels"], start=1):
            cropper = transforms.CropForegroundd(keys=key,
                                                 source_key=self.source_key,
                                                 select_fn=lambda m, disc=disc: m == disc,
                                                 margin=self.margin,
                                                 k_divisible=self.k_divisible)
            cropped = cropper(input_data)
            samples.append({"image": cropped[key], "label": x[str(position)]})
        return samples

In [None]:
# `%%writefile` above only writes the module, it does not execute it, so
# MIDASDataLoader is undefined on a fresh kernel. Load it from the file just written
# (same CWD-relative path as the %%writefile target).
_loader_path = Path("bimcv_aikit/dataloaders/projects/MIDASDataLoader.py")
_loader_spec = importlib.util.spec_from_file_location("MIDASDataLoader", str(_loader_path))
_loader_module = importlib.util.module_from_spec(_loader_spec)
sys.modules[_loader_spec.name] = _loader_module  # register so its classes resolve by module name
_loader_spec.loader.exec_module(_loader_module)
MIDASDataLoader = _loader_module.MIDASDataLoader

dl = MIDASDataLoader(
    root_dir.joinpath("data", "filtered_midas900_t2w_partition.csv"),
    test_run=True,
    config={"batch_size": 32, "shuffle": True, "num_workers": 7},
)

# Create training config

In [None]:
%%writefile ../src/config.json
{
    "name": "all_discs/EfficientNetBN",
    "description": "",
    "task": "classification",
    "n_gpu": 1,
    "seed": 42,
    "arch": {
        "module": "monai.networks.nets",
        "type": "EfficientNetBN",
        "args": {
            "model_name": "efficientnet-b7",
            "spatial_dims": 2,
            "in_channels": 3,
            "num_classes": 5,
            "pretrained": false,
            "progress": false
        }
    },
    "data_loader": {
        "module": "bimcv_aikit.dataloaders",
        "type": "MIDASDataLoader",
        "partitions": {
            "train": "train",
            "val": "val",
            "test": "test"
        },
        "args": {
            "path": "../data/filtered_midas900_t2w_partition.csv",
            "test_run": false,
            "config": {
                "batch_size": 64,
                "drop_last": true,
                "shuffle": true
            }
        }
    },
    "optimizer": {
        "type": "Adadelta",
        "args": {
            "lr": 1.0,
            "rho": 0.95,
            "eps": 1e-07
        }
    },
    "loss": {
        "module": "torch.nn",
        "type": "CrossEntropyLoss",
        "args": {}
    },
    "metrics": {
        "accuracy": {
            "module": "torchmetrics.functional.classification",
            "type": "accuracy",
            "args": {
                "task": "multiclass",
                "average": "weighted",
                "num_classes": 5
            }
        },
        "f1": {
            "module": "torchmetrics.functional",
            "type": "f1_score",
            "args": {
                "task": "multiclass",
                "average": "weighted",
                "num_classes": 5
            }
        }
    },
    "trainer": {
        "type": "ClassificationTrainer",
        "epochs": 100,
        "save_dir": "../logs/",
        "save_period": 33,
        "verbosity": 2,
        "monitor": "min val_loss",
        "early_stop": 100,
        "tensorboard": true
    }
}

# Launch training (full data: `test_run` is false in the config above)

In [None]:
!bimcv_train -c ../src/config.json