# Experimentation: Data Pre-Processing
---

In [None]:
# Import libraries
from monai.data import DataLoader
from monai.transforms import (EnsureChannelFirstd,
Compose, LoadImaged, ResampleToMatchd, MapTransform, SaveImaged, LoadImage)

from monai.apps import TciaDataset
from monai.apps.auto3dseg import AutoRunner
from monai.bundle import ConfigParser

from monai.config import print_config
import json

# Print the configuration reported by monai.config so the environment is recorded in the cell output
print_config()


In [None]:
# TCIA collection to work with and the segmentation type requested from it
collection = "HCC-TACE-Seg"
seg_type = "SEG"

# Class name -> integer label; this mapping is passed to LoadImaged when the
# loading transform is built
label_dict = {
    name: index
    for index, name in enumerate(
        ("Liver", "Mass", "Necrosis", "Portal vein", "Abdominal aorta")
    )
}

class UndoOneHotEncoding(MapTransform):
    """Dictionary transform that collapses one-hot encoded entries into index maps.

    For every configured key, the entry is replaced by the argmax over
    dimension 0, with that dimension kept as a leading axis of size 1.

    Args:
        keys: keys of the entries to convert.
        allow_missing_keys: if True, configured keys that are absent from the
            input are skipped instead of raising ``KeyError``.
    """

    def __init__(self, keys, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        """Return a copy of ``data`` in which each selected entry is an index map.

        Args:
            data: mapping holding the entries named by ``keys``.

        Returns:
            A new dictionary; the mapping passed in is not modified.
        """
        # Shallow copy so the caller's dictionary is not mutated as a side effect
        d = dict(data)
        # key_iterator honours allow_missing_keys; plain self.keys would not
        for key in self.key_iterator(d):
            d[key] = d[key].argmax(dim=0).unsqueeze(0)
        return d
    
# Per-sample pipeline: load the image and segmentation with the "PydicomReader" reader
# (label_dict is passed along to LoadImaged) and move the channel axis first for both.
# The resampling, one-hot removal and saving steps below are currently disabled.
transform = Compose([
    LoadImaged(keys=["image", "seg"], reader="PydicomReader", label_dict=label_dict),
    EnsureChannelFirstd(keys=["image", "seg"]),
    # Disabled steps, kept for reference:
    #ResampleToMatchd(keys="image", key_dst="seg"),
    #UndoOneHotEncoding(keys="seg"),
    #SaveImaged(keys="seg", output_dir="/segmentations", output_postfix="seg", output_ext=".dcm", output_dtype="torch.float32", data_root_dir="../data/HCC-TACE-Seg"),
])

# Dataset for the "training" section of the collection.
# val_frac=0.0 is passed, so presumably nothing is held out for validation - confirm
# against the TciaDataset documentation. Caching is switched off (cache_rate=0.0).
train_dataset = TciaDataset(
    root_dir="../data",
    collection=collection,
    seg_type=seg_type,
    section="training",
    transform=transform,
    download=True,
    download_len=2,
    val_frac=0.0,
    seed=0,
    cache_rate=0.0,
    progress=True,
)

# Dataset for the "test" section of the same collection, built with the same transform
# and download settings as the training dataset but a different seed (100).
# NOTE(review): confirm in the TciaDataset documentation how the "test" section is
# derived, to make sure it does not overlap with the training samples.
test_dataset = TciaDataset(
    root_dir="../data",
    collection=collection,
    seg_type=seg_type,
    section="test",
    transform=transform,
    download=True,
    download_len=2,
    val_frac=0.0,
    seed=100,
    cache_rate=0.0,
    progress=True,
)

In [None]:
from monai.transforms import LoadImage, Transform, SaveImage
import pydicom
import os
import re

class UndoOneHotEncoding(Transform):
    """Array transform that turns a one-hot encoded input into a single-channel index map.

    Note: this class has the same name as the dictionary-based transform
    defined in an earlier cell, so after this cell runs the name refers to
    this array version.
    """

    def __call__(self, img):
        """Return the argmax of ``img`` over dimension 0, restored as a leading axis of size 1."""
        class_index = img.argmax(dim=0)
        return class_index.unsqueeze(0)

# Transform instances used by the conversion loop below:
# a loader that reads a file with the "PydicomReader" reader and returns only the
# image, channel-first, and the transform that collapses the one-hot channels
loader = LoadImage(
    reader="PydicomReader",
    ensure_channel_first=True,
    image_only=True,
)
undo_one_hot = UndoOneHotEncoding()



# Regular expression to match the patient directories
patient_dir_pattern = re.compile(r"HCC_\d{3}")
data_root = "../data/HCC-TACE-Seg"

# Walk the dataset folder in a stable order so repeated runs process files identically
for directory in sorted(os.listdir(data_root)):
    # Ignore entries that are not named like a patient directory
    if not patient_dir_pattern.match(directory):
        continue

    patient_seg_dir = os.path.join(data_root, directory, "300", "seg")

    # The "300/seg" sub-folder is hard-coded; skip patients that do not have it
    # instead of letting os.listdir raise FileNotFoundError and abort the run
    if not os.path.isdir(patient_seg_dir):
        print(f"Skipping {directory}: no segmentation directory at {patient_seg_dir}")
        continue

    # The writer only depends on the patient folder, so build it once per patient
    # rather than once per file.
    # NOTE(review): with an empty postfix and the source folder as output_dir, the
    # result looks like it is written over the input file, which would make this
    # cell non-idempotent - confirm.
    save_image = SaveImage(output_dir=patient_seg_dir, output_postfix="", output_ext=".dcm", output_dtype="torch.float32", separate_folder=False)

    # Convert the One-Hot-Encoded DICOM segmentations to a single DICOM segmentation
    for seg_file in sorted(os.listdir(patient_seg_dir)):
        if not seg_file.endswith(".dcm"):
            continue

        seg_path = os.path.join(patient_seg_dir, seg_file)

        # Load the DICOM file
        dicom = loader(seg_path)

        # Collapse the one-hot channels into a single index map
        seg = undo_one_hot(dicom)

        # SaveImage's second positional parameter is `meta_data` (a dict), not a
        # path, so the file path is no longer passed there. The output name is
        # expected to come from the metadata carried by `seg` - confirm.
        save_image(seg)

        # Remove the One-Hot-Encoded segmentation (disabled)
        # if not seg_file.endswith("_seg.dcm"):
        #     os.remove(seg_path)
                

In [None]:
# Loader over the training dataset: one sample per batch, no worker processes
train_loader = DataLoader(
    train_dataset,
    num_workers=0,
    batch_size=1,
)

In [None]:
# Take the first batch from a fresh iterator over the training loader
_batch_iter = iter(train_loader)
batch = next(_batch_iter)


In [None]:
# A new iterator is created here, so this starts from the beginning of the loader
# again rather than continuing after `batch`
_batch2_iter = iter(train_loader)
batch2 = next(_batch2_iter)

In [None]:
# Inspect the sampled batch: its keys first, then the shape and the dtype
# of the image and segmentation entries (one line per attribute)
print(batch.keys())
for attribute in ("shape", "dtype"):
    print(getattr(batch["image"], attribute), getattr(batch["seg"], attribute))

In [None]:
# Pull the image and the segmentation out of the batch dictionary
image = batch["image"]
seg = batch["seg"]

# Disabled: collapsing a one-hot segmentation along dim 1
# seg = seg.argmax(dim=1)
# seg = seg.unsqueeze(1)

# Report both shapes and the distinct values present in the segmentation
print(image.shape, seg.shape, seg.unique())


In [None]:
import torch
import matplotlib.pyplot as plt

# Index along the last axis of the slice to display
slice_idx = 60

# First batch item, first channel, one slice of the image
CT_slice = image[0, 0, :, :, slice_idx]

# Same position in the segmentation (first channel; no argmax is applied here)
CT_seg_slice = seg[0, 0, :, :, slice_idx]

print(CT_slice.shape, CT_seg_slice.shape)

# Plot the image and segmentation slice side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(CT_slice, cmap="gray")
axes[0].set_title("CT Image")
# Draw the segmentation once and keep the returned mappable for the colorbar,
# instead of calling imshow a second time on the same axes
seg_mappable = axes[1].imshow(CT_seg_slice, cmap="jet")
axes[1].set_title("CT Segmentation")
plt.colorbar(mappable=seg_mappable, ax=axes[1])
plt.show()