# Creating the synthetic dataset

In [None]:
# Name of the training run; selects the checkpoint folder (see CHECKPOINT_PATH below).
EXPERIMENT_NAME = "BraTS_W_PWA100__W_PWT100__Unet_FC_new"
UNET = "Unet_FC" # Generator variant to import: "Unet" or "Unet_FC"
# Checkpoint index: get_gen reads "<CHECKPOINT_PATH>/weights/<RESUME>_gen.pth".
RESUME = 990
DEVICE = "cuda:0"
HOME_DIR = "/projects"  # used to put the project's GANs/src on sys.path
WORK_DIR = "/projects"

from os.path import join
# Fixed variables #
DATA_LIST_KEY = "training"  # split read from the data-list JSON
DATA_LIST_FILE_PATH = join(WORK_DIR, "aritifcial-head-and-neck-cts/GANs/data/BraTS2023_GLI_data_split.json") # Data-split JSON read by load_decathlon_datalist
DATA_DIR = "../../brats2023/ASNR-MICCAI-BraTS2023-GLI-Challenge-TrainingData" # base_dir the relative paths in the data-split JSON are resolved against
# NOTE(review): DATA_DIR, CHECKPOINT_PATH and SAVE_DIR are relative to the current
# working directory, while DATA_LIST_FILE_PATH is built from WORK_DIR.
CHECKPOINT_PATH = f"../../checkpoint/style_256/{EXPERIMENT_NAME}"
IMAGE_SIZE = (256, 256, 256)  # spatial size passed to ResizeWithPadOrCropd
# Generator hyper-parameters (forwarded to Generator in get_gen):
DIM = 1024  # passed as in_channels
NOISE_DIM = 512  # passed as latent_dim, z_dim and w_dim
IN_CHANNEL_G = 3  # generator input channels (the label is fed as 3 channels)
OUT_CHANNEL_G = 1
SKIP_LATENT = False
TAHN_ACT = False
# nnU-Net raw dataset folder the synthetic images/labels are written into.
SAVE_DIR = "../../nnUNet/nnUNet_raw/Dataset996_BraTS_GAN"


In [None]:
from os.path import join, exists, dirname, basename
from os import listdir, makedirs, environ
def maybe_make_dir(directory):
    """
    Create ``directory`` (and any missing parent directories) if it does not exist.

    ``exist_ok=True`` lets os.makedirs do the existence check and the creation in
    one call, so there is no window between "check" and "create" in which another
    process could create the directory and make this call fail. If ``directory``
    exists but is a regular file, FileExistsError is raised instead of being
    silently ignored.
    """
    makedirs(directory, exist_ok=True)

In [None]:
import sys
import torch
import numpy as np
import nibabel as nib
from nilearn import plotting
from tqdm import tqdm
from monai.data import load_decathlon_datalist, DataLoader, CacheDataset, Dataset
from monai.transforms import (
    Compose, 
    LoadImaged,
    EnsureChannelFirstd, 
    EnsureTyped,
    Orientationd,
    ResizeWithPadOrCropd,
    ScaleIntensityd,
    RandAffined,
    RandFlipd,
    CropForegroundd,
    Invertd,
    RandZoomd,
    
)

sys.path.insert(1, join(HOME_DIR, "aritifcial-head-and-neck-cts/GANs/src"))
from utils.data_loader_utils import ConvertToMultiChannelBasedOnBratsClasses2023d

# Import the Generator/Critic implementation selected by UNET.
if UNET=="Unet_FC":
    print("USING THE UNET Fully Connected LIKE GENERATOR")
    from network.cWGAN_Style_Unet_256_FC import Generator, Critic
elif UNET=="Unet":
    print("USING THE UNET LIKE GENERATOR")
    from network.cWGAN_Style_Unet_256 import Generator, Critic
else:
    # Fail here: otherwise Generator stays undefined and get_gen breaks later
    # with a confusing NameError.
    raise ValueError(f"Unknown UNET variant {UNET!r}; expected 'Unet' or 'Unet_FC'")

In [None]:
def get_gen(checkpoint_path, RESUME):
    """
    Build the generator, load the weights of checkpoint ``RESUME`` and set eval mode.

    ARGS:
        checkpoint_path: experiment folder containing a ``weights`` sub-folder.
        RESUME: checkpoint index; the file read is ``weights/<RESUME>_gen.pth``.

    RETURN:
        the generator, on DEVICE, in eval mode.
    """
    print(f"Loading from: {checkpoint_path}, {RESUME}")
    generator = Generator(
        in_channels=DIM,
        latent_dim=NOISE_DIM,
        IN_CHANNEL_G=IN_CHANNEL_G,
        OUT_CHANNEL_G=OUT_CHANNEL_G,
        z_dim=NOISE_DIM,
        w_dim=NOISE_DIM,
        skip_latent=SKIP_LATENT,
        tahn_act=TAHN_ACT,
    )
    generator.to(DEVICE)

    weights_file = join(checkpoint_path, "weights", f"{RESUME}_gen.pth")
    # The checkpoint is a dict; the generator parameters live under 'model_state_dict'.
    state_dict = torch.load(weights_file, map_location=torch.device(DEVICE))['model_state_dict']
    generator.load_state_dict(state_dict)

    generator.eval()
    return generator

def generate_detection_train_transform(
    image_key,
    label_key,
    image_size,
    ConvertToMultiChannel_BackandForeground, 
):
    """
    Build the deterministic preprocessing pipeline applied to every case.

    Steps: load image+label (keeping metadata under "<key>_meta_dict"), move the
    channel axis first, cast to float32, reorient to RAS, pad/crop to
    ``image_size`` with zeros, split the label into multi-channel regions, and
    cast to float32 again. No intensity scaling is applied.

    ARGS:
        image_key: key of the image in the data-list dictionaries.
        label_key: key of the label in the data-list dictionaries.
        image_size: spatial size the volumes are padded/cropped to.
        ConvertToMultiChannel_BackandForeground: dictionary-transform class that
            is instantiated with ``keys=[label_key]`` to convert the label.

    RETURN:
        a monai Compose of the steps above.
    """
    both_keys = [image_key, label_key]
    dtype = torch.float32

    pipeline = [
        LoadImaged(keys=both_keys, meta_key_postfix="meta_dict", image_only=False),
        EnsureChannelFirstd(keys=both_keys),
        EnsureTyped(keys=both_keys, dtype=dtype),
        Orientationd(keys=both_keys, axcodes="RAS"),
        ResizeWithPadOrCropd(
            keys=both_keys,
            spatial_size=image_size,
            mode="constant",
            value=0,
        ),
        ConvertToMultiChannel_BackandForeground(keys=[label_key]),
        EnsureTyped(keys=both_keys, dtype=dtype),
    ]
    return Compose(pipeline)

def get_loader(IMAGE_SIZE, DATA_LIST_KEY, DATA_DIR, cache_rate=0.0):
    """
    Build the inference data loader over one split of the data-list JSON.

    ARGS:
        IMAGE_SIZE: spatial size the volumes are padded/cropped to.
        DATA_LIST_KEY: split to read from DATA_LIST_FILE_PATH (e.g. "training").
        DATA_DIR: base directory the relative paths in the data list resolve against.
        cache_rate: fraction of cases to pre-load with CacheDataset. Defaults to 0
            (plain Dataset, cases loaded lazily by the DataLoader workers). Every case
            is visited exactly once here, so caching gives no speed-up. Caching the
            whole split would keep every preprocessed IMAGE_SIZE float32 image and its
            multi-channel label in RAM at once (roughly 256 MiB per case at 256^3 with
            1 + 3 channels).

    RETURN:
        loader: DataLoader yielding one case per batch, in data-list order.
        ds: the underlying dataset (Dataset, or CacheDataset if cache_rate > 0).

    RAISES:
        ValueError: if the selected split contains no cases.
    """
    transforms = generate_detection_train_transform(
        image_key="t1c",
        label_key="seg",
        image_size=IMAGE_SIZE,
        ConvertToMultiChannel_BackandForeground=ConvertToMultiChannelBasedOnBratsClasses2023d,
    )

    data_set = load_decathlon_datalist(
        DATA_LIST_FILE_PATH,
        is_segmentation=True,
        data_list_key=DATA_LIST_KEY,
        base_dir=DATA_DIR,
    )
    if not data_set:
        raise ValueError(f"No cases found under key {DATA_LIST_KEY!r} in {DATA_LIST_FILE_PATH}")
    print(data_set[0])  # show one resolved entry as a sanity check of the paths

    if cache_rate > 0:
        ds = CacheDataset(
            data=data_set,
            transform=transforms,
            cache_rate=cache_rate,
            copy_cache=False,
            progress=True,
            num_workers=4,
        )
    else:
        ds = Dataset(data=data_set, transform=transforms)

    loader = DataLoader(
        ds,
        batch_size=1,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
        shuffle=False,
    )

    return loader, ds

In [None]:
def do_gen_infer(gen, data):
    """Run the generator ``gen`` on ``data`` and return its output unchanged."""
    return gen(data)

def get_affine(file_path):
    """Load the NIfTI file at ``file_path`` with nibabel and return its affine matrix."""
    return nib.load(file_path).affine

def save_nifti(data, reality, affine=None, save=None):
    """
    Wrap a volume in a NIfTI image and optionally write it to disk.

    ARGS:
        data: a torch.Tensor (detached, moved to CPU, squeezed and cast to float32)
            or anything nibabel accepts directly, e.g. a numpy array (used as-is).
        reality: title kept for the (currently disabled) preview plot; otherwise unused.
        affine: 4x4 voxel-to-world matrix; defaults to the identity (unit spacing,
            origin at 0).
        save: output path; when None nothing is written.

    RETURN:
        the nib.Nifti1Image that was built (and saved, if ``save`` was given).
    """
    if affine is None:
        affine = np.eye(4)
    # Dispatch on the type explicitly instead of using a bare except around the
    # tensor conversion, which also hid unrelated errors.
    if isinstance(data, torch.Tensor):
        volume = np.squeeze(data.detach().cpu().numpy()).astype(np.float32)
    else:
        volume = data
    nifti_image = nib.Nifti1Image(volume, affine=affine)
    #plotting.plot_img(nifti_image, title=reality, cut_coords=None, annotate=False, draw_cross=False, black_bg=True)
    if save is not None:
        nib.save(nifti_image, save)
    return nifti_image



In [None]:
def normalize_intensity(image, new_min, new_max):
    """
    Linearly map intensities from [-1, 1] to [new_min, new_max].

    Values are clipped to [-1, 1] first, so the result lies in [new_min, new_max].

    ARGS:
        image: torch.Tensor (``torch.clip`` is used, so numpy arrays are not accepted).
        new_min, new_max: target range; Python scalars or 0-d tensors.

    RETURN:
        tensor with the same shape as ``image``.
    """
    clipped_image = torch.clip(image, -1, 1)

    # Source range of the generator output that is mapped onto [new_min, new_max]
    original_min = -1
    original_max = 1

    normalized_image = (clipped_image - original_min) / (original_max - original_min) * (new_max - new_min) + new_min
    return normalized_image


# (x, y, z) crop applied to every post-processed volume. Presumably this undoes the
# zero padding ResizeWithPadOrCropd added to reach IMAGE_SIZE (e.g. 240x240x155 ->
# 256^3 gives 8/8 and 50/51 voxels of padding) -- TODO confirm for other inputs.
_POST_PROCESSING_CROP = (slice(8, -8), slice(8, -8), slice(50, -51))


def _crop_flip_to_numpy(volume):
    """Crop a 3-D tensor with _POST_PROCESSING_CROP, move it to CPU numpy and flip axes 0 and 1."""
    # np.flip cannot be applied to torch tensors (negative-step slicing is rejected
    # by torch, and CUDA tensors have no numpy view), so convert first.
    cropped = volume[_POST_PROCESSING_CROP].detach().cpu().numpy()
    return np.flip(cropped, axis=(0, 1))


def post_processing(fake_image, seg, ct_scan):
    """
    Post-process one generated case: rebuild the label map, normalise and crop.

    ARGS:
        fake_image: generator output tensor indexed as [batch, channel, x, y, z];
            only [0][0] is used.
        seg: 3-channel label tensor [batch, 3, x, y, z]; only batch 0 is used.
        ct_scan: real scan tensor [batch, channel, x, y, z]; only [0][0] is used.

    RETURN:
        (cropped_ct_scan, cropped_fake_scan, cropped_seg): numpy arrays cropped
        with _POST_PROCESSING_CROP and flipped along axes 0 and 1. The fake scan is
        rescaled from [-1, 1] to the min/max of the real scan. The label holds the
        values 0-3.
    """
    fake_volume = fake_image[0][0]
    ct_volume = ct_scan[0][0]

    # Converting segmentation back from regions: count how many region channels
    # cover each voxel and map that count to a label. Assumes the three channels
    # are nested BraTS regions -- TODO confirm against
    # ConvertToMultiChannelBasedOnBratsClasses2023d.
    region_count = seg[0][0] + seg[0][1] + seg[0][2]
    label_map = torch.zeros_like(region_count)
    label_map[region_count == 2] = 1
    label_map[region_count == 1] = 2
    label_map[region_count == 3] = 3

    # Normalise the synthetic intensities to the intensity range of the real scan
    fake_volume_norm = normalize_intensity(image=fake_volume, new_min=ct_volume.min(), new_max=ct_volume.max())

    # Cropping, then flipping to have the same orientation as the original cases
    cropped_ct_scan = _crop_flip_to_numpy(ct_volume)
    cropped_fake_scan = _crop_flip_to_numpy(fake_volume_norm)
    cropped_seg = _crop_flip_to_numpy(label_map)

    return cropped_ct_scan, cropped_fake_scan, cropped_seg

In [None]:
# Unshuffled loader (batch_size=1) over the DATA_LIST_KEY split; ds is the underlying dataset.
loader, ds = get_loader(IMAGE_SIZE=IMAGE_SIZE, DATA_LIST_KEY=DATA_LIST_KEY, DATA_DIR=DATA_DIR)

In [None]:
# Build the generator, load "<CHECKPOINT_PATH>/weights/<RESUME>_gen.pth" and switch it to eval mode.
gen = get_gen(checkpoint_path=CHECKPOINT_PATH, RESUME=RESUME)

In [None]:
# Output layout under SAVE_DIR. The generation loop below writes images to
# imagesTr/ and labels to labelsTr_origin/; labelsTr/ is only created here.
maybe_make_dir(directory=SAVE_DIR)
for sub_dir in ("imagesTr", "labelsTr", "labelsTr_origin"):
    maybe_make_dir(directory=join(SAVE_DIR, sub_dir))

In [None]:
def get_seg_transforms(prob, label_key, rotate_range=np.pi, shear_range=0.1, translate_range=10):
    """
    Build random spatial augmentations for a segmentation label.

    MONAI range semantics: for ``RandAffined`` a *scalar* range applies to the
    first spatial axis only, and the other axes get no rotation/shear/shift. The
    defaults keep the original scalar values, so by default the label is rotated
    about, sheared along and shifted along a single axis. Pass tuples, e.g.
    ``rotate_range=(np.pi,) * 3`` or ``translate_range=(10, 10, 10)``, to
    randomise every axis.

    ARGS:
        prob: probability for RandAffined and RandZoomd; each RandFlipd uses ``prob / 2``.
        label_key: key of the label in the input dictionary.
        rotate_range: RandAffined ``rotate_range`` in radians (default ``np.pi``: [-pi, pi)).
        shear_range: RandAffined ``shear_range`` (default 0.1).
        translate_range: RandAffined ``translate_range`` in voxels (default 10: [-10, 10)).

    RETURN:
        (affine_transforms, translate_transforms):
            affine_transforms: random rotation/shear, a random flip per spatial axis
                and a random zoom in [0.9, 1.1].
            translate_transforms: random shift followed by a cast to float32.
    """
    affine_transforms = Compose(
        [
            RandAffined(
                keys=[label_key],
                prob=prob,
                rotate_range=rotate_range,
                shear_range=shear_range,
                mode="nearest",  # nearest-neighbour keeps label values discrete
                padding_mode="zeros",
            ),
            # Independent random flips along each spatial axis
            RandFlipd(keys=[label_key], prob=prob / 2, spatial_axis=0),
            RandFlipd(keys=[label_key], prob=prob / 2, spatial_axis=1),
            RandFlipd(keys=[label_key], prob=prob / 2, spatial_axis=2),
            RandZoomd(
                keys=[label_key],
                prob=prob,
                min_zoom=0.9,
                max_zoom=1.1,
                mode="nearest",
            ),
        ]
    )

    translate_transforms = Compose(
        [
            RandAffined(
                keys=[label_key],
                prob=prob,
                translate_range=translate_range,
                mode="nearest",
                padding_mode="zeros",
            ),
            EnsureTyped(keys=[label_key], dtype=torch.float32),
        ]
    )
    return affine_transforms, translate_transforms

In [None]:
def coordenates_cropping_label(label):
    """
    Return the inclusive bounding box of the non-zero voxels of a 3-D label.

    ARGS:
        label: 3-D numpy array.

    RETURN:
        (min_x, max_x, min_y, max_y, min_z, max_z), all bounds inclusive.

    RAISES:
        ValueError: if ``label`` has no non-zero voxel (numpy cannot reduce an
            empty selection).
    """
    # Find non-background coordinates
    non_background_coords = np.where(label != 0)

    # Determine the cropping bounds
    min_x, max_x = np.min(non_background_coords[0]), np.max(non_background_coords[0])
    min_y, max_y = np.min(non_background_coords[1]), np.max(non_background_coords[1])
    min_z, max_z = np.min(non_background_coords[2]), np.max(non_background_coords[2])

    return min_x, max_x, min_y, max_y, min_z, max_z


def get_origin_center_coord(min_x, max_x, min_y, max_y, min_z, max_z):
    """Return the integer (floor) midpoint (x, y, z) of an inclusive bounding box."""
    center_x = (min_x + max_x) // 2
    center_y = (min_y + max_y) // 2
    center_z = (min_z + max_z) // 2

    return center_x, center_y, center_z


def _paste_bounds(center, size, limit):
    """Return (dst_start, dst_end, src_start, src_end) to paste a length-``size`` run centred on ``center`` into an axis of length ``limit``, clipping at the borders."""
    # (size - 1) // 2 pairs with the floor midpoint of get_origin_center_coord, so
    # an untransformed crop lands exactly where it was taken from (size // 2 shifted
    # even-sized extents by one voxel).
    start = center - (size - 1) // 2
    end = start + size
    dst_start, dst_end = max(0, start), min(limit, end)
    src_start = dst_start - start
    src_end = src_start + (dst_end - dst_start)
    return dst_start, dst_end, src_start, src_end


# Re-place a rotated/sheared tumour so that it is centred where the original tumour was.
def paste_rotated_seg(trans_seg, center_x, center_y, center_z):
    """
    Crop the foreground of ``trans_seg`` and paste it, centred on the given voxel, into an empty volume.

    ARGS:
        trans_seg: 5-D tensor [batch, channel, x, y, z]. The bounding box is taken
            from batch element 0 (summed over channels), and all batch elements
            and channels are copied.
        center_x, center_y, center_z: integer voxel where the crop is centred.

    RETURN:
        tensor shaped like ``trans_seg`` (same dtype/device), zero except for the
        pasted crop. Parts of the crop that fall outside the volume are dropped.

    RAISES:
        ValueError: if ``trans_seg`` has no foreground voxel.
    """
    # Cropping the rotated and sheared tumour (.cpu() so this also works on GPU tensors)
    foreground = torch.sum(trans_seg[0], dim=0).detach().cpu().numpy()
    min_x, max_x, min_y, max_y, min_z, max_z = (int(v) for v in coordenates_cropping_label(foreground))
    cropped_seg = trans_seg[:, :, min_x:max_x + 1, min_y:max_y + 1, min_z:max_z + 1]

    # Use the input itself as the shape/dtype/device reference; this previously
    # read the module-level `seg` variable by accident.
    new_seg = torch.zeros_like(trans_seg)
    dst = [slice(None), slice(None)]
    src = [slice(None), slice(None)]
    for axis, center in zip((2, 3, 4), (center_x, center_y, center_z)):
        d0, d1, s0, s1 = _paste_bounds(int(center), cropped_seg.shape[axis], trans_seg.shape[axis])
        dst.append(slice(d0, d1))
        src.append(slice(s0, s1))

    # Placing the new tumour with the same center
    new_seg[tuple(dst)] = cropped_seg[tuple(src)]
    return new_seg


def apply_rotate_and_shear(seg, affine_transforms):
    """
    Randomly transform the tumour in ``seg`` and re-centre it on its original position.

    ARGS:
        seg: 5-D label tensor [batch, channel, x, y, z]; only batch element 0 is
            transformed.
        affine_transforms: callable taking and returning a dict with the key "seg"
            (e.g. the first value returned by get_seg_transforms with label_key="seg").

    RETURN:
        5-D tensor with the transformed tumour pasted at the original tumour centre.
    """
    # Geting center of tumour
    foreground = torch.sum(seg[0], dim=0).detach().cpu().numpy()
    bounding_box = coordenates_cropping_label(foreground)
    center_x, center_y, center_z = get_origin_center_coord(*bounding_box)

    # Applying transforms
    trans_batch = affine_transforms({"seg": seg[0]})
    trans_seg = trans_batch["seg"].unsqueeze(0)

    return paste_rotated_seg(trans_seg, center_x, center_y, center_z)


## Generating the same number of cases as the original dataset (1000)
* No transformation is applied to the labels.

In [None]:
###
# For every case: generate a synthetic T1c from its label, post-process it, and save
# the synthetic image and the label into the nnU-Net raw dataset folder.
loop_train = tqdm(loader, leave=True)
for batch_idx, batch in enumerate(loop_train):
    with torch.no_grad():
        ct_scan, seg = batch["t1c"].to(DEVICE), batch["seg"].to(DEVICE)

        # Generating synthetic scan from the label
        fake_image = do_gen_infer(gen=gen, data=seg)

        # Normalising synthetic scan intensity to the same values as the original case,
        # and cropping to the same shape (cropped_ct_scan is not saved below)
        cropped_ct_scan, cropped_fake_scan, cropped_seg = post_processing(fake_image, seg, ct_scan)

        # Case name and affine are taken from the source T1c file
        ct_path = batch["t1c_meta_dict"]["filename_or_obj"][0]
        ct_name = basename(ct_path).split('-t1c')[0]
        affine_matrix = get_affine(ct_path)

        # Saving synthetic scan ("_0000" is presumably nnU-Net's channel-0 suffix)
        save_path = join(SAVE_DIR, f"imagesTr/{ct_name}_0000.nii.gz")
        save_nifti(data=cropped_fake_scan, reality="Fake", affine=affine_matrix, save=save_path)
        # Saving segmentation
        save_path = join(SAVE_DIR, f"labelsTr_origin/{ct_name}.nii.gz")
        save_nifti(data=cropped_seg, reality="Fake", affine=affine_matrix, save=save_path)

    loop_train.set_postfix(Case=ct_name)

    