In [None]:
import argparse
import glob
import json
import os
import shutil
import copy

import numpy as np
import matplotlib.pyplot as plt

from monai.config import print_config
from monai.transforms import (
    LoadNiftid,
    AsChannelFirstd,
    Spacingd,
    Orientationd,

    AddChanneld,
    Resized,
    NormalizeIntensityd,
    ToTensord
)

from byoc import (
    SpatialCropForegroundd
)

In [None]:
# Target in-plane size (H, W) for the per-region crop/resize pipeline.
roi_size = [256, 256]

# One training image/label pair, referenced by file path; LoadNiftid
# replaces these paths with the loaded arrays.
data = {
    'image': '/workspace/data/52432/Training/img/img0001.nii.gz',
    'label': '/workspace/data/52432/Training/label/label0001.nii.gz',
}

# Index of the 2-D slice that gets plotted / cropped below.
slice_idx = 111

# Every volume-level transform operates on the same pair of keys.
image_label_keys = ('image', 'label')

# Volume-level pipeline: applied once, step by step, to the whole pair.
transforms = [
    LoadNiftid(keys=image_label_keys),
    AsChannelFirstd(keys=image_label_keys),
    # Two pixdim values -> 2-D in-plane resampling; bilinear for the image,
    # nearest for the label so label values are not interpolated.
    Spacingd(keys=image_label_keys, pixdim=(1.0, 1.0), mode=('bilinear', 'nearest')),
    Orientationd(keys=image_label_keys, axcodes="RAS"),
]

# Per-region, per-slice pipeline: crop around the label foreground, resize
# to roi_size, then normalise image intensities with fixed constants.
pre_transforms = [
    AddChanneld(keys=image_label_keys),
    SpatialCropForegroundd(keys=image_label_keys, source_key='label', spatial_size=roi_size),
    Resized(keys=image_label_keys, spatial_size=roi_size, mode=('area', 'nearest')),
    NormalizeIntensityd(keys='image', subtrahend=208.0, divisor=388.0),
    # ToTensord is left disabled -- presumably so the outputs stay NumPy
    # arrays for the inspection/plotting below; confirm before enabling.
    #ToTensord(keys=('image', 'label'))
]

def show_image(image, label):
    """Display an image slice and its label side by side, then show the figure.

    Args:
        image: 2-D array drawn in grayscale in the left panel.
        label: 2-D array drawn with matplotlib's default colormap in the
            right panel.

    Both panels get a colorbar. The figure is identified by the name
    "check" and sized (12, 6); ``plt.show()`` is called before returning.
    """
    plt.figure("check", (12, 6))

    # (title, data, extra imshow kwargs) for each panel, left to right.
    panels = (
        ("image", image, {"cmap": "gray"}),
        ("label", label, {}),
    )
    for position, (title, array, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.imshow(array, **imshow_kwargs)
        plt.colorbar()

    plt.show()


# Apply the volume-level transforms one at a time, printing shapes and
# plotting slice `slice_idx` after every step so each effect can be inspected.
for t in transforms:
    tname = type(t).__name__

    data = t(data)
    image = data['image']
    label = data['label']

    print(f"{tname} => image shape: {image.shape}, label shape: {label.shape}")

    # Compare with == rather than `in ('LoadNiftid')`: the parenthesised
    # literal is a plain str, so `in` would do a substring search.
    # The indexing assumes the slice axis is last straight after LoadNiftid
    # and first once the later transforms have run.
    if tname == 'LoadNiftid':
        image, label = image[:, :, slice_idx], label[:, :, slice_idx]
    else:
        image, label = image[slice_idx, :, :], label[slice_idx, :, :]
    show_image(image, label)
    
# Label values to inspect, one binary mask per value. The original note says
# label value 6 is liver; add it here to view that region as well.
region_ids = [2]

# pre_transforms whose class name appears here are passed over (the data is
# forwarded unchanged). Kept from the original, which skipped
# 'CropForegroundd' in favour of SpatialCropForegroundd.
skipped_transforms = {'CropForegroundd'}

for i in region_ids:
    # Deep copy so the per-region pipeline, which appears to write into the
    # meta dicts (see 'foreground_cropped_shape' below), never mutates the
    # shared `data` produced by the volume-level transforms.
    pdata = copy.deepcopy(data)

    # Take slice `slice_idx` first, then binarise the label for region i.
    # This is identical to binarising the full volume and then slicing,
    # but avoids a full-volume comparison.
    image = pdata['image'][slice_idx, :, :]
    label = (pdata['label'][slice_idx, :, :] == i).astype(np.float32)

    # Region i does not appear on this slice: nothing to crop or show.
    if not np.any(label):
        continue

    pdata['image'] = image
    pdata['label'] = label

    for t in pre_transforms:
        tname = type(t).__name__
        if tname not in skipped_transforms:
            pdata = t(pdata)

        if tname == 'SpatialCropForegroundd':
            print("Cropped size: {}".format(pdata['image_meta_dict']['foreground_cropped_shape']))

        image = pdata['image']
        label = pdata['label']
        print(f"region-{i}:: {tname} => image shape: {image.shape}, label shape: {label.shape};  sum: {np.sum(label)}; min: {np.min(label)}; max: {np.max(label)}")

        # AddChanneld runs first, so index 0 drops the added channel axis.
        show_image(image[0], label[0])
