## Exploratory analysis of the MM-WHS dataset

In [25]:
# imports 
import numpy as np
import os
import monai
from monai.data import CacheDataset, DataLoader, GridPatchDataset
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    EnsureTyped,
    LoadImaged,
    Resized,
    ScaleIntensityd,
    SqueezeDimd,
    SaveImaged,
    MapLabelValued,
)
import itk

import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
import torch

# Example CT dataset

In [2]:
data_dir = "../MMWHS_Dataset/ct_train"
images = sorted(glob.glob(os.path.join(data_dir, "ct_train_1001_image.nii.gz")))
labels = sorted(glob.glob(os.path.join(data_dir, "ct_train_1001_label.nii.gz")))

example_dataset = [{"img": img, "seg": seg} for img, seg in zip(images, labels)]

In [3]:
volume_transforms = Compose(
    [
        LoadImaged(keys=["img", "seg"]),
        EnsureChannelFirstd(keys=["img", "seg"]),
        ScaleIntensityd(keys="img"), # normalization between 0 and 1
        EnsureTyped(keys=["img", "seg"]),
    ]
)

In [4]:
volume_ds_person1 = CacheDataset(data=example_dataset, transform=volume_transforms, cache_rate=1.0, num_workers=4) # experiment with last two parameters
check_loader = DataLoader(volume_ds_person1, batch_size=1)
check_data = monai.utils.misc.first(check_loader)

Loading dataset:   0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset: 100%|██████████| 1/1 [01:03<00:00, 63.42s/it]


In [5]:
print("first volume's shape", check_data["img"].shape, check_data["seg"].shape)

middle_image = check_data["img"][0, :, :, :, 170]
middle_label = check_data["seg"][0, :, :, :, 170]
print("image shape", middle_image.shape, "label shape", middle_label.shape)

first volume's shape torch.Size([1, 1, 512, 512, 363]) torch.Size([1, 1, 512, 512, 363])
image shape torch.Size([1, 512, 512]) label shape torch.Size([1, 512, 512])


In [21]:
output_dir = "/processed/ct_train"

In [32]:
patch_func = monai.data.PatchIterd(
    keys=["img", "seg"],
    patch_size=(None, None, 1),  # dynamic first two dimensions
    start_pos=(0, 0, 0)
)
patch_transform = Compose(
    [
        SqueezeDimd(keys=["img", "seg"], dim=-1),  # squeeze the last dim
        #Resized(keys=["img", "seg"], spatial_size=[224, 224]),
        MapLabelValued(keys=["seg"], orig_labels=[205, 420, 500, 550, 600, 820, 850], target_labels=[1, 2, 3, 4, 5 , 6 , 7]),
    ]
)

example_patch_ds = GridPatchDataset(data=volume_ds_person1, patch_iter=patch_func, transform=patch_transform)
patch_data_loader = DataLoader(example_patch_ds, batch_size=1, num_workers=2, pin_memory=torch.cuda.is_available())


In [33]:
i = 0
output_dir = "/preprocessed/ct_train/"
for batch in patch_data_loader:
    file_name = f"ct_train_image_{i}.nii.gz"
    label_name = f"ct_train_label_{i}.nii.gz"
    images, labels = batch[0]["img"], batch[0]["seg"]
    itk.imwrite(images, os.path.join(output_dir, file_name))

    

AttributeError: 'MetaTensor' object has no attribute 'UpdateOutputInformation'

In [None]:
print("first image shape", check_data_patch["img"].shape, check_data_patch["seg"].shape)


plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.imshow(middle_image.T, cmap="gray")
plt.subplot(1, 2, 2)
plt.axis("off")
plt.imshow(middle_label.T)
