In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import SimpleITK as sitk
from monai import transforms

from utils import CheckMaskVol

In [None]:
# Absolute path of the parent of the current working directory; the data CSV
# is looked up relative to this directory.
root_dir = Path("..").resolve()

In [None]:
# Table of the dataset; its "Image" and "Mask" columns hold the file paths
# that the cells below read.
img_relation = pd.read_csv(root_dir / "data" / "filtered_midas900_t2w.csv")

In [None]:
def read_spacing(image_path):
    """Return the voxel spacing of an image file without loading its pixels.

    Args:
        image_path: path of an image file readable by SimpleITK.

    Returns:
        The spacing tuple reported by the file header.
    """
    reader = sitk.ImageFileReader()
    reader.SetFileName(str(image_path))
    # Header only: the pixel buffer is never read, which keeps the scan over
    # the whole dataset cheap.
    reader.ReadImageInformation()
    return reader.GetSpacing()


# One row per image, one column per spatial axis.
spacings = np.vstack([read_spacing(image_path) for image_path in img_relation["Image"]])
print(f"Median spacing: {np.median(spacings, axis=0)}")
print(f"Mean spacing: {np.mean(spacings, axis=0)}")
print(f"Std spacing: {np.std(spacings, axis=0)}")

In [None]:
class CropForegroundd(transforms.MapTransform):
    """Cut one foreground crop per valid mask label out of the keyed image(s).

    Voxels whose mask value is not one of ``x["valid_labels"]`` are replaced by
    NaN; then, for every valid label, ``monai.transforms.CropForegroundd`` crops
    the masked image(s) to the region selected by that label.

    Args:
        keys: key(s) of the image entries to mask and crop.
        source_key: key of the label mask used to locate each region.
        margin: forwarded unchanged to ``monai.transforms.CropForegroundd``.
        k_divisible: forwarded unchanged to ``monai.transforms.CropForegroundd``.
    """

    def __init__(
        self, keys=("image",), source_key="mask", margin=0, k_divisible=(64, 64, 1)
    ):
        super().__init__(keys)
        self.k_divisible = k_divisible
        self.margin = margin
        self.source_key = source_key

    def __call__(self, x):
        """Return one cropped sample per valid label.

        Args:
            x: dict holding the entries named by ``keys`` and ``source_key``,
                a ``"valid_labels"`` sequence of mask values to extract, and one
                entry per label position (``"1"``, ``"2"``, ...) whose value is
                reported as that crop's ``"label"``.
                NOTE(review): ``"valid_labels"`` is not produced in this class;
                presumably an earlier transform adds it - confirm.

        Returns:
            A list with one dict per valid label, containing the cropped
            entries for ``keys`` plus ``"label"``. The list is empty when
            ``"valid_labels"`` is empty.

        Raises:
            KeyError: if a required entry is missing from ``x``.
        """
        mask = x[self.source_key]
        valid_labels = x["valid_labels"]

        # Foreground = voxels carrying any of the valid labels. A zero-valued
        # label never counts as foreground.
        mask_array = np.asarray(mask)
        foreground = np.zeros(mask_array.shape, dtype=bool)
        for disc in valid_labels:
            foreground |= np.asarray(mask == disc)
        foreground &= mask_array != 0

        # Use the configured keys (not hard-coded "image"/"mask") so the dict
        # matches what the inner MONAI transform is told to look up.
        input_data = {
            image_key: np.where(foreground, x[image_key], np.nan)
            for image_key in self.keys
        }
        input_data[self.source_key] = mask

        samples = []
        for position, disc in enumerate(valid_labels, start=1):
            crop = transforms.CropForegroundd(
                keys=self.keys,
                source_key=self.source_key,
                # Bind the current label explicitly; avoids late binding and
                # shadowing of the outer ``x``.
                select_fn=lambda values, disc=disc: values == disc,
                margin=self.margin,
                k_divisible=self.k_divisible,
            )(input_data)
            cropped = {image_key: crop[image_key] for image_key in self.keys}
            # Labels are looked up by 1-based position, not by mask value.
            cropped["label"] = x[str(position)]
            samples.append(cropped)
        return samples

# Preprocessing pipeline: check the mask, load both volumes, produce one crop
# per valid label, then reshape every crop so slices can be indexed on axis 0.
transforms_ = transforms.Compose(
    [
        # Project helper imported from utils.
        # NOTE(review): presumably this is what adds the "valid_labels" entry
        # read by CropForegroundd below - confirm in utils.
        CheckMaskVol(
            keys=["image", "mask"], minimum_roi_dimensions=3, minimum_roi_size=1000
        ),
        # Read both files and put the channel axis first.
        transforms.LoadImaged(
            keys=["image", "mask"], image_only=True, ensure_channel_first=True
        ),
        # Intensity normalisation steps, currently disabled:
        # transforms.HistogramNormalized(keys=["image"]),
        # transforms.ScaleIntensityd(keys=["image"]),
        # The CropForegroundd class defined in this cell (not MONAI's): returns
        # a list with one dict per valid label. Compose applies the remaining
        # transforms to every item of that list.
        CropForegroundd(
            keys=["image"], source_key="mask", margin=0, k_divisible=(1, 1, 1)
        ),
        # Centre-crop the first two spatial axes to 64; -1 leaves the last axis as is.
        transforms.CenterSpatialCropd(keys=["image"], roi_size=(64, 64, -1)),
        # Move the last axis to position 1, right after the channel axis.
        transforms.Transposed(keys=["image"], indices=(0, 3, 1, 2)),
        # Drop axis 0 (the channel axis), leaving the former last axis first.
        transforms.SqueezeDimd(keys=["image"], dim=0),
        transforms.ToTensord(keys=["image"]),
    ]
)

In [None]:
# Run the pipeline on a single row of the table. Entries "1" to "5" are all
# set to 1; the crop transform reports them as each crop's "label".
index = 5
sample = transforms_(
    dict(
        image=img_relation.iloc[index]["Image"],
        mask=img_relation.iloc[index]["Mask"],
        **{str(position): 1 for position in range(1, 6)},
    )
)

In [None]:
# Show up to the first 9 slices of the second crop. Zero voxels are turned
# into NaN so they are drawn as blanks instead of black.
disc_volume = np.asarray(sample[1]["image"])
disc_volume = np.where(disc_volume != 0, disc_volume, np.nan)  # computed once, not per slice
for i in range(min(9, disc_volume.shape[0])):  # crops may hold fewer than 9 slices
    plt.imshow(disc_volume[i, :, :], cmap="gray")
    plt.title(f"Slice {i}")
    plt.show()

In [None]:
# Intensity distribution of each crop (at most the first 5): exact value
# counts as a line, plus a 100-bin histogram.
for disc_sample in sample[:5]:  # slicing tolerates fewer than 5 crops
    disc_image = np.asarray(disc_sample["image"])
    # Drop the NaN background once and reuse it for both plots; NaN entries
    # cannot be drawn and only inflate the np.unique result.
    intensities = disc_image[~np.isnan(disc_image)].ravel()
    plt.plot(*np.unique(intensities, return_counts=True))
    plt.hist(intensities, bins=100)
    plt.show()