# Experimentation: Data Pre-Processing
---

In [None]:
# Import libraries
from monai.data import DataLoader
from monai.transforms import (EnsureChannelFirstd,
Compose, LoadImaged, ResampleToMatchd, MapTransform)

from monai.apps import TciaDataset
from monai.apps.auto3dseg import AutoRunner
from monai.bundle import ConfigParser

from monai.config import print_config
import json

# Print MONAI's configuration report (imported from `monai.config`) so the
# environment used for this run is visible in the cell output.
print_config()


In [None]:
# TCIA collection to download and the type tag of its segmentation objects
collection = "HCC-TACE-Seg"
seg_type = "SEG"

# Segment name -> integer class index; passed to the DICOM reader when the
# segmentation is loaded
label_dict = {
    "Liver": 0,
    "Tumor": 1,
    "vessels": 2,
    "aorta": 3,
}

class UndoOneHotEncoding(MapTransform):
    """Collapse one-hot (one channel per class) entries into a single-channel label map.

    The transform is used inside the ``Compose`` pipeline, i.e. on individual,
    un-batched samples that are channel-first after ``EnsureChannelFirstd``.
    The class axis of such a sample is axis 0; axis 1 is a spatial axis until a
    ``DataLoader`` adds the batch dimension, so reducing over axis 1 (as the
    previous implementation did) collapses the wrong axis.

    Args:
        keys: key(s) of the dictionary entries to convert.
        channel_dim: axis that holds the one-hot classes. Defaults to 0, the
            channel axis of an un-batched, channel-first sample.

    NOTE(review): ``argmax`` yields 0 wherever every channel is zero, so voxels
    that belong to no class become indistinguishable from class index 0 --
    confirm this is acceptable for the label mapping in use.
    """

    def __init__(self, keys, channel_dim=0):
        super().__init__(keys)
        self.channel_dim = channel_dim

    def __call__(self, data):
        """Return a shallow copy of ``data`` with every configured key replaced by its label map."""
        # Shallow copy so the dictionary handed in by the caller is not modified.
        converted = dict(data)
        for key in self.keys:
            # keepdim=True keeps a singleton channel axis, so the result stays
            # channel-first with exactly one channel.
            converted[key] = converted[key].argmax(dim=self.channel_dim, keepdim=True)
        return converted
    
# Pre-processing pipeline applied to every sample, in order:
#   1. read the DICOM image and segmentation,
#   2. move the channel axis to the front,
#   3. resample the image onto the grid of the segmentation,
#   4. collapse the one-hot segmentation into a label map.
preprocessing_steps = [
    LoadImaged(keys=["image", "seg"], reader="PydicomReader", label_dict=label_dict),
    EnsureChannelFirstd(keys=["image", "seg"]),
    ResampleToMatchd(keys="image", key_dst="seg"),
    UndoOneHotEncoding(keys=["seg"]),
]
transform = Compose(preprocessing_steps)

# Training section of the collection. val_frac=0.0, so presumably no samples
# are held out for validation; download_len=2 keeps the download small while
# experimenting; cache_rate=0.0 requests no caching of transformed samples.
dataset_options = dict(
    root_dir="../data",
    collection=collection,
    seg_type=seg_type,
    section="training",
    transform=transform,
    download=True,
    download_len=2,
    progress=True,
    val_frac=0.0,
    cache_rate=0.0,
)
train_dataset = TciaDataset(**dataset_options)

In [None]:
# Show the entries the dataset was built from
datalist_entries = train_dataset.datalist
print(datalist_entries)

In [None]:
# Wrap the dataset in a loader that yields one sample per batch, loading in the
# main process (no worker subprocesses)
train_loader = DataLoader(
    train_dataset,
    num_workers=0,
    batch_size=1,
)

In [None]:
# Pull the first batch out of the loader
batch_iterator = iter(train_loader)
batch = next(batch_iterator)

In [None]:
# Keys available in the batch
print(batch.keys())

batch_image = batch["image"]
batch_seg = batch["seg"]

# Shapes of the image and segmentation tensors
print(batch_image.shape, batch_seg.shape)

# Element types of the image and segmentation tensors
print(batch_image.dtype, batch_seg.dtype)

In [None]:
# Separate the image and segmentation from the batch
image, seg = batch["image"], batch["seg"]

# The pre-processing pipeline already applies UndoOneHotEncoding to "seg".
# Only collapse the channel axis (dim 1 of a batch) when it still holds more
# than one channel: argmax over a single-channel label map would overwrite
# every label with 0.
if seg.shape[1] > 1:
    seg = seg.argmax(dim=1, keepdim=True)

print(image.shape, seg.shape, seg.unique())


In [None]:
import torch
import matplotlib.pyplot as plt

# Index of the slice to display along the last axis, clamped so that volumes
# with fewer than 46 slices do not raise an IndexError
slice_idx = min(45, image.shape[-1] - 1, seg.shape[-1] - 1)

# Sample a slice from the image
CT_slice = image[0, 0, :, :, slice_idx]

# `seg` holds class labels in a single channel at this point, so the per-pixel
# class is the maximum *value* over the channel axis. The indices returned by
# torch.max would all be 0 for a singleton axis.
CT_seg_slice_max, _ = torch.max(seg[0, :, :, :, slice_idx], dim=0)

print(CT_slice.shape, CT_seg_slice_max.shape)

# Plot the image and segmentation slice side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(CT_slice, cmap="gray")
axes[0].set_title("CT Image")
# Keep the artist returned by imshow so the colorbar can reuse it instead of
# drawing the segmentation a second time
seg_artist = axes[1].imshow(CT_seg_slice_max, cmap="jet")
axes[1].set_title("CT Segmentation")
plt.colorbar(mappable=seg_artist, ax=axes[1])
plt.show()

## Set up AutoRunner for automatic segmentation model training and hyperparameter fine-tuning
---

In [None]:
# Build the Auto3DSeg entries from the dataset's datalist without modifying the
# existing dictionaries, so this cell can be re-run safely (popping "seg" in
# place raised KeyError on a second execution).
auto3dseg_items = []
for item in train_dataset.datalist:
    # Copy every field except "seg", which is re-added below under "label"
    entry = {key: value for key, value in item.items() if key != "seg"}
    # NOTE(review): every sample is assigned to fold 0 -- confirm this is the
    # split AutoRunner is expected to train and validate on.
    entry["fold"] = 0
    # Rename "seg" to "label"; on a re-run the entry already carries "label"
    entry["label"] = item["seg"] if "seg" in item else item["label"]
    auto3dseg_items.append(entry)
train_dataset.datalist = auto3dseg_items

# Only a "training" section is written; there is no separate test list
data_list = {"training": train_dataset.datalist}

datalist_file = "../auto3dseg_datalist.json"
with open(datalist_file, "w") as f:
    json.dump(data_list, f)

In [None]:
# Describe the Auto3DSeg task: dataset name, task type, imaging modality, the
# datalist JSON file and the root directory of the data
input_config = {
    "name": "HCC-TACE-Seg",
    "task": "segmentation",
    "modality": "CT",
    "datalist": "../auto3dseg_datalist.json",
    "dataroot": "../data",
}

# Export the configuration. The format is requested explicitly because the
# exporter's default format is JSON, which would not match the .yaml file name.
config_yaml = "./auto3dseg_config.yaml"
ConfigParser.export_config_file(input_config, config_yaml, fmt="yaml")

In [None]:
# Launch the Auto3DSeg pipeline with the in-memory configuration; its working
# files go under ../data/auto3dseg
runner = AutoRunner(
    work_dir="../data/auto3dseg",
    input=input_config,
)
runner.run()