# GBA pipeline

### By Guillaume Sallé

Here is the list of folders used and created by this notebook :

    to_augment3D and to_augmentSEG3D : images (and labels) to augment
    
    np3D and npSEG3D : naively-augmented images (and labels)
    
    sinGAN_input and slices2D : 2D slices, respectively where tumor appears (before being fed to SinGAN) and where tumor does not appear. In sinGAN_input, 2D masks are also stored as required by SinGAN.
    
    sinGAN_output_scaleXXker5 : 2D slices generated by SinGAN, where XX is the chosen scale for SinGAN harmonization (K-k* in our paper)
    
    augmented3D_scaleXXker5 : 3D volumes rebuilt from the 2D slices generated by SinGAN (GBA images)

## I. Import

In [None]:
import os
import re
import shutil
import nibabel as nib
import nibabel.processing as nibp
import numpy as np
import glob
import shutil
from scipy import ndimage
import matplotlib.pyplot as plt
import json
from skimage import measure

def connected_components(mask):
    """Return a binary mask that keeps only the largest connected component.

    Components are labelled with ``skimage.measure.label`` (label 0 is the
    background) and the label covering the most voxels is kept.

    If ``mask`` contains no foreground at all, a single placeholder voxel is
    set at index ``[0, 200, 0]`` so that exactly one component exists. This
    fallback therefore needs a 3D mask whose second axis is longer than 200
    and raises ``IndexError`` otherwise. The placeholder is written to a copy,
    so the caller's array is never modified.

    Args:
        mask: Segmentation mask (array-like).

    Returns:
        Array with the same shape as ``mask`` and the dtype of the label
        image: 1 inside the largest component, 0 everywhere else.
    """
    labels = measure.label(mask)
    # Voxel count per component; entry 0 (background) is dropped.
    sizes = np.bincount(labels.ravel())[1:]
    if sizes.size == 0:
        # No foreground: add the placeholder voxel on a private copy instead
        # of silently editing the array the caller passed in.
        patched = np.array(mask, copy=True)
        patched[0, 200, 0] = 1
        labels = measure.label(patched)
        sizes = np.bincount(labels.ravel())[1:]
    # +1 because the background entry was removed from ``sizes``.
    largest = np.argmax(sizes) + 1
    return (labels == largest).astype(labels.dtype)

def dice(pred, gt):
    """Compute the Dice overlap score between a prediction and a ground truth.

    Args:
        pred: Predicted mask (NumPy array).
        gt: Ground-truth mask (NumPy array); voxels equal to 1.0 count as
            foreground.

    Returns:
        Twice the sum of ``pred`` over the ground-truth foreground, divided by
        the sum of ``pred`` plus the sum of ``gt``.
    """
    overlap = np.sum(pred[gt == 1.0])
    total = np.sum(pred) + np.sum(gt)
    return (overlap * 2.0) / total

def nb_connected_components(mask):
    """Count the connected components found in ``mask``.

    The mask is labelled with ``skimage.measure.label``; the highest label
    value is returned as the number of components.
    """
    labelled = measure.label(mask)
    return labelled.max()

def atoi(text):
    """Convert ``text`` to an int when it is made of digits only.

    Any other string (including the empty string) is returned unchanged.
    """
    if text.isdigit():
        return int(text)
    return text

def natural_keys(text):
    """Build a sort key that orders embedded numbers by value.

    ``text`` is split into alternating non-digit and digit chunks; each chunk
    is passed through ``atoi`` so that digit runs compare as integers
    (e.g. "s2" sorts before "s10").
    """
    chunks = re.split(r'(\d+)', text)
    return list(map(atoi, chunks))

## II. Naive augmentation and SinGAN preprocessing

The cell below generates naively-augmented images (without SinGAN) and preprocesses them for the SinGAN model: each 3D volume is converted to 2D slices, and the slices are split into two directories depending on whether the tumor appears on the slice or not.

You can choose different lambda values and assign a probability to each of them.

In [None]:
# SinGAN preprocessing : 3D -> 2D with SinGAN format
aug_folder = "augmentationGBA_test01"
# Candidate contrast factors applied to the tumor intensities.
lambda_values = [0.5, 0.8, 1.5]
# One row per augmentation round; entry j of a row is the probability of
# drawing lambda_values[j] during that round.
probabilies_at_each_aug = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
nb_augmentations = 3 # number of variations you want to create for each image to augment
folder_img3D = os.path.join(".","..",aug_folder,"to_augment3D") # specify the directory
folder_seg3D = os.path.join(".","..",aug_folder,"to_augmentSEG3D") # specify the directory

folder_np3D = os.path.join(".","..",aug_folder,"np3D") # will be created
folder_npseg3D = os.path.join(".","..",aug_folder,"npSEG3D") # will be created
folder_img2DNOTUMOR = os.path.join(".","..",aug_folder,"slices2D") # will be created
folder_img2DinputSG = os.path.join(".","..",aug_folder,"sinGAN_input") # will be created

# Validate the settings with explicit errors: plain asserts are skipped when
# Python runs with -O and give no hint about what is wrong.
if nb_augmentations != len(probabilies_at_each_aug):
    raise ValueError(
        "probabilies_at_each_aug must contain one row per augmentation: expected "
        + str(nb_augmentations) + ", got " + str(len(probabilies_at_each_aug)))
for probabilities in probabilies_at_each_aug:
    if len(probabilities) != len(lambda_values):
        raise ValueError(
            "each row of probabilies_at_each_aug must contain one probability per lambda value: expected "
            + str(len(lambda_values)) + ", got " + str(len(probabilities)))

for output_folder in (folder_np3D, folder_img2DNOTUMOR, folder_img2DinputSG, folder_npseg3D):
    os.makedirs(output_folder, exist_ok=True)

# Natural sort so that image number k and mask number k refer to the same patient.
img_files = sorted(os.listdir(folder_img3D), key=natural_keys)
seg_files = sorted(os.listdir(folder_seg3D), key=natural_keys)
# Images and masks are paired by position further down, so the two folders
# must contain the same number of files.
if len(img_files) != len(seg_files):
    raise ValueError(
        "Found " + str(len(img_files)) + " images in " + folder_img3D
        + " but " + str(len(seg_files)) + " masks in " + folder_seg3D)

for k in range(nb_augmentations):
    # Draw one lambda per image for this augmentation round, following the
    # probabilities configured for the round.
    lambda_val = np.random.choice(lambda_values, size=len(img_files), p=probabilies_at_each_aug[k])
    for i, img_file in enumerate(img_files):
        # The patient number is expected between two underscore characters in the file name.
        m = re.search('_([0-9]+)_', img_file)
        if m is None:
            raise ValueError("No patient number found between underscores in file name: " + img_file)
        # Common stem of every output name for this image and this round:
        # first 18 characters of the input name + patient number + round index.
        base_name = img_file[:18]+m.group(1)+"_aug"+str(k+1)+"_fakeT2"

        # Load image and mask. Rescale the image to [0,1]
        # (the input is assumed to be in the range [-1,1]).
        img = nib.load(os.path.join(folder_img3D, img_file))
        img_arr = (img.get_fdata()+1)/2
        seg = nib.load(os.path.join(folder_seg3D, seg_files[i]))
        seg_arr = seg.get_fdata()
        # Naive contrast alteration: multiply the intensity of the tumor
        # voxels (mask == 1) by the lambda drawn for this image.
        img_arr[seg_arr==1]*=lambda_val[i]
        # Map the image back with the inverse of the rescaling above
        img_arr = img_arr*2-1.0

        # Save the naively-augmented 3D image
        img_name = base_name+".nii.gz"
        nib.save(nib.Nifti1Image(img_arr, None, img.header), os.path.join(folder_np3D, img_name))
        # Save every slice along the third axis in one of 2 folders, depending on whether the tumor appears.
        # If it appears, the mask slice is saved too, as SinGAN requires the mask.
        for j in range(img_arr.shape[2]):
            slice_stem = base_name+"_s"+format(j,"03d")
            img_name = slice_stem+".nii.gz"
            seg_name = slice_stem+"_mask.nii.gz"
            if (seg_arr[:,:,j]).max()>0.9:
                nib.save(nib.Nifti1Image(img_arr[:,:,j], None, img.header), os.path.join(folder_img2DinputSG, img_name))
                nib.save(nib.Nifti1Image(seg_arr[:,:,j], None, seg.header), os.path.join(folder_img2DinputSG, seg_name))
            else:
                nib.save(nib.Nifti1Image(img_arr[:,:,j], None, img.header), os.path.join(folder_img2DNOTUMOR, img_name))

# Fill npSEG3D folder with copies of the real masks
seg_files = os.listdir(folder_seg3D)
seg_files.sort(key=natural_keys)
# Build the output name stem of every mask once, before copying anything, so
# that a file name without a patient number is reported up front instead of
# failing half-way through with a partially filled folder.
seg_stems = []
for seg_file in seg_files:
    # The patient number is expected between two underscore characters in the file name.
    m = re.search('_([0-9]+)_', seg_file)
    if m is None:
        raise ValueError("No patient number found between underscores in file name: " + seg_file)
    seg_stems.append(seg_file[:18]+m.group(1))
# Each augmentation round reuses the unchanged real mask, under a name
# carrying the round index.
for k in range(nb_augmentations):
    for seg_file, seg_stem in zip(seg_files, seg_stems):
        img_name = seg_stem+"_aug"+str(k+1)+"_SEG.nii.gz"
        shutil.copy(os.path.join(folder_seg3D, seg_file), os.path.join(folder_npseg3D,img_name))

## III. Apply SinGAN using the hand-over folder "SinGAN_pour_harmonisation"

Use official SinGAN implementation on folder "sinGAN_input" to apply SinGAN harmonization.

## IV. SinGAN post-processing

The cell below post-processes the SinGAN outputs: for each patient, the generated slices and the slices without tumor are stacked back into a single 3D volume in the original space. We use a regular expression to make sure we don't mix patients together.

    m = re.search('202([0-9]+)_.*_([0-9]+)_aug([0-9]+)', aug2D[i])
    
In the CrossMoDA 2022 dataset, images from center A start with "2021_..." while images from center B start with "2022_...". We tell patients apart using this prefix together with the patient number and the augmentation index. Because slices are sorted by file name, we read each file name in turn and compare it with the previous one to decide which patient a given slice belongs to. You should adapt the regular expressions to your own file names.

We define a list of lists to separate the patients before gathering the corresponding slices.

In [None]:
# SinGAN postprocessing : from 2D slices to 3D volumes
aug_folder = "augmentationGBA_test01"
SGoutput = "sinGAN_output_scale13ker5"
augmented3D = "augmented3D_scale13ker5"

folder_ref = os.path.join(".","..",aug_folder,"np3D")
folder_img2DNOTUMOR = os.path.join(".","..",aug_folder,"slices2D")
folder_img2DoutputSG = os.path.join(".","..",aug_folder,SGoutput)
folder_augmented3D = os.path.join(".","..",aug_folder,augmented3D) # will be created

ref3D = os.listdir(folder_ref)
ref3D.sort(key=natural_keys)
aug2D = os.listdir(folder_img2DoutputSG) + os.listdir(folder_img2DNOTUMOR) # merging all slice files
aug2D.sort(key=natural_keys)

# Create list of list with all slices per patients.
# Slices are sorted by name, so a new group starts each time the identity
# extracted from the file name (digits after "202", patient number,
# augmentation index) differs from the previous one. The list grows as
# needed, so there is no upper bound on the number of volumes.
patients = []
prev = None
for slice_name in aug2D:
    m = re.search('202([0-9]+)_.*_([0-9]+)_aug([0-9]+)', slice_name)
    if m is None:
        raise ValueError("Cannot extract patient and augmentation numbers from file name: " + slice_name)
    identity = m.group(1)+"_"+m.group(2)+"_"+m.group(3)
    if identity != prev:
        patients.append([])
        prev = identity
    patients[-1].append(slice_name)

# Group number i takes its header and its output name from reference volume
# number i, so there must be at least as many references as groups.
if len(patients) > len(ref3D):
    raise ValueError(
        "Found " + str(len(patients)) + " groups of slices but only "
        + str(len(ref3D)) + " reference volumes in " + folder_ref)

# Reset the output folder only once the inputs have been checked, so that a
# naming problem does not wipe the results of a previous run.
if os.path.exists(folder_augmented3D):
    shutil.rmtree(folder_augmented3D)
os.makedirs(folder_augmented3D)

# Stacking slices to build 3D volumes
for i, slice_names in enumerate(patients):
    header = nib.load(os.path.join(folder_ref, ref3D[i])).header
    # initialize empty volume
    stacked = np.zeros((256, 256, len(slice_names))) # 256x256xZ crops only because of SinGAN memory
    # fill it with corresponding slices
    for j, slice_name in enumerate(slice_names):
        # Use the SinGAN output when it exists for this slice, otherwise the
        # tumor-free slice saved during preprocessing.
        generated_path = os.path.join(folder_img2DoutputSG, slice_name)
        if os.path.exists(generated_path):
            stacked[:,:,j] = nib.load(generated_path).get_fdata()
        else:
            stacked[:,:,j] = nib.load(os.path.join(folder_img2DNOTUMOR, slice_name)).get_fdata()
    nib.save(nib.Nifti1Image(stacked, None, header),
             os.path.join(folder_augmented3D, ref3D[i]))
#"""