# Training a segmentation model

In [None]:
#Preparing data
import os
import nibabel as nib
import pandas as pd
from os.path import join
import numpy as np
from skimage.io import imread
import SimpleITK as sitk
import matplotlib.pyplot as plt
from shutil import copyfile
from nnunet.utils import generate_dataset_json

## Preprocessing

In [None]:
def load_tiff_convert_to_nifti(img_file, lab_file, img_out_base, anno_out, spacing):
    """
    Convert a TIFF image (and optionally its TIFF mask) to NIfTI files for nnU-Net.

    Parameters
    ----------
    img_file : str
        Path of the input image TIFF.
    lab_file : str or None
        Path of the mask TIFF. When None, only the image is written.
    img_out_base : str
        Output path prefix for the image; "_0000.nii.gz" is appended
        (nnU-Net's per-channel file suffix, channel 0).
    anno_out : str
        Output path prefix for the label; ".nii.gz" is appended.
    spacing : sequence of float
        Voxel spacing in numpy axis order. It is reversed before being passed
        to SimpleITK, whose image axes are the reverse of the numpy array axes.
    """
    itk_spacing = [float(s) for s in reversed(spacing)]

    img = imread(img_file)
    img_itk = sitk.GetImageFromArray(img.astype(np.float32))
    img_itk.SetSpacing(itk_spacing)
    sitk.WriteImage(img_itk, img_out_base + "_0000.nii.gz")

    if lab_file is not None:
        lab = imread(lab_file)
        # Binarize: every non-zero voxel is foreground (1), everything else 0.
        # Dividing by 255 and casting to int would truncate any value below
        # 255 (e.g. masks stored as 0/1) to 0 and erase the label.
        lab = (lab > 0).astype(np.uint8)
        lab_itk = sitk.GetImageFromArray(lab)
        lab_itk.SetSpacing(itk_spacing)
        sitk.WriteImage(lab_itk, anno_out + '.nii.gz')

In [None]:
# Example acquisition name; img_file / msk_file are not referenced elsewhere
# in this notebook (the conversion loop lists the folders instead).
idx = 'M1_2'
postfix = '_C0.tif'
base_name = 'KAKU4-wt--CRWN1-wt--CRWN4-wt_Cot_J13_STD_FIXE_H258_{}'.format(idx)

img_folder = 'images_sophie'
# join() inserts the path separator that plain concatenation was missing.
img_file = join(img_folder, base_name + postfix)

msk_folder = 'masks_sophie'
msk_file = join(msk_folder, base_name + postfix)

# nnU-Net raw-data layout for Task500 (train/test images and labels).
img_out_base_folder = 'nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTr/'
img_test_out_base_folder = 'nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs/'
img_out_base = img_out_base_folder+base_name

anno_out_folder = 'nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/labelsTr/'
anno_test_out_folder = 'nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/labelsTs/'
anno_out = anno_out_folder+base_name

# Prefix of the nnU-Net case identifiers (e.g. "Nucleus_000").
case = 'Nucleus'

target_base = 'nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus'

# Voxel spacing in numpy axis order (reversed for SimpleITK at conversion time).
# NOTE(review): units are not stated here — TODO confirm.
spacing=(0.2e-3, 0.1032e-3,0.1032e-3)

In [None]:
def add_zeros_before(i):
    """
    Left-pad `i` with zeros to three characters: 7 -> '007', 42 -> '042'.

    Values of 100 or more are returned as plain strings without padding.
    """
    padding = '00' if i < 10 else ('0' if i < 100 else '')
    return padding + str(i)

In [None]:
# Convert every image/mask pair to NIfTI, naming the outputs with nnU-Net case
# identifiers ("Nucleus_000", "Nucleus_001", ...). The mask is expected to have
# the same file name as the image, inside msk_folder.
# NOTE(review): the case index is the file's position in os.listdir(), which is
# not guaranteed to be sorted; the test-split cell relies on os.listdir()
# returning the same order again.
import tqdm
list_imgs = os.listdir(img_folder)
for i, fname in enumerate(list_imgs):
    # file name without its extension (not used further in this cell)
    base_name = fname[:fname.rfind('.')]

    img_out_base = f"{img_out_base_folder}{case}_{add_zeros_before(i)}"
    anno_out = f"{anno_out_folder}{case}_{add_zeros_before(i)}"

    load_tiff_convert_to_nifti(
        f"{img_folder}/{fname}",
        f"{msk_folder}/{fname}",
        img_out_base,
        anno_out,
        spacing)

In [None]:
# Move test images to the right folder.

# The split is read from the 'fold_x' column (instead of 'hold_out').
def get_train_test_df(df):
    """
    Split the image names listed in `df` into a train set and a test set.

    Rows with ``fold_x == 0`` go to the train set and rows with
    ``fold_x == 1`` to the test set. Rows with any other value are ignored.
    The image name is taken from the first column of `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Fold-assignment table with a 'fold_x' column.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        First-column values of the train rows and of the test rows.
    """
    # Use the `df` argument, not a global table.
    train_set = np.array(df[df['fold_x'] == 0].iloc[:, 0])
    test_set = np.array(df[df['fold_x'] == 1].iloc[:, 0])
    return train_set, test_set

# NOTE(review): df_path is not used anywhere else in this notebook.
df_path = '/home/mougeotg/all/data/nuclei/gred_val_all/'
# Fold-assignment table; get_train_test_df reads its first column and 'fold_x'.
df4 = pd.read_csv(filepath_or_buffer='folds_x_sophie.csv')

In [None]:
# Split the image names by fold and report how many fall into each set.
train_set, test_set = get_train_test_df(df4)
print(f"Size of train set {len(train_set)}")
print(f"Size of test set {len(test_set)}")

In [None]:
# Move the converted test cases from imagesTr/labelsTr to imagesTs/labelsTs.
# A case is a test case when its source file name appears in `test_set`.
# NOTE(review): case indices are recomputed from os.listdir(img_folder) and
# must match the order used when the images were converted.
def _move_to_test(src_folder, dst_folder, file_name):
    """Copy src_folder+file_name into dst_folder, then delete the source; no-op if it is missing."""
    src = src_folder + file_name
    if os.path.exists(src):
        # copyfile fails if the destination folder does not exist yet.
        os.makedirs(dst_folder, exist_ok=True)
        copyfile(src, dst_folder + file_name)
        os.remove(src)

# Build the lookup set once: membership in a set is O(1), in an array O(n).
test_names = set(test_set)
list_imgs = os.listdir(img_folder)
for i, base_name in enumerate(list_imgs):
    if base_name not in test_names:
        continue
    case_id = case + '_' + add_zeros_before(i)
    _move_to_test(img_out_base_folder, img_test_out_base_folder, case_id + '_0000.nii.gz')
    _move_to_test(anno_out_folder, anno_test_out_folder, case_id + '.nii.gz')

In [None]:
# Write dataset.json for Task500 from the imagesTr / imagesTs folders.
generate_dataset_json(
    join(target_base, 'dataset.json'),
    img_out_base_folder,
    img_test_out_base_folder,
    # One-element tuple of modality names, presumably one entry per input
    # channel. Note that ('D') without a comma is just the string 'D'.
    modalities=('D',),
    labels={0: 'background', 1: 'nucleus'},
    dataset_name=case,
    license='MIT',
)

In [None]:
# Preprocessing: nnU-Net planning and preprocessing for task 500, with a
# dataset integrity check. os.system only returns the exit status, so check
# it; otherwise a failure goes unnoticed by the training cells.
status = os.system('nnUNet_plan_and_preprocess -t 500 --verify_dataset_integrity')
if status != 0:
    raise RuntimeError('nnUNet_plan_and_preprocess failed (exit status {})'.format(status))

## Training

In [None]:
# Train fold 0 (3d_fullres, nnUNetTrainer_Experimental); stop if the command fails.
status = os.system('nnUNet_train 3d_fullres nnUNetTrainer_Experimental Task500_Nucleus 0 --npz')
if status != 0:
    raise RuntimeError('nnUNet_train failed for fold 0 (exit status {})'.format(status))

In [None]:
# Train fold 1 (3d_fullres, nnUNetTrainer_Experimental); stop if the command fails.
status = os.system('nnUNet_train 3d_fullres nnUNetTrainer_Experimental Task500_Nucleus 1 --npz')
if status != 0:
    raise RuntimeError('nnUNet_train failed for fold 1 (exit status {})'.format(status))

In [None]:
# Train fold 2 (3d_fullres, nnUNetTrainer_Experimental); stop if the command fails.
status = os.system('nnUNet_train 3d_fullres nnUNetTrainer_Experimental Task500_Nucleus 2 --npz')
if status != 0:
    raise RuntimeError('nnUNet_train failed for fold 2 (exit status {})'.format(status))

In [None]:
# Train fold 3 (3d_fullres, nnUNetTrainer_Experimental); stop if the command fails.
status = os.system('nnUNet_train 3d_fullres nnUNetTrainer_Experimental Task500_Nucleus 3 --npz')
if status != 0:
    raise RuntimeError('nnUNet_train failed for fold 3 (exit status {})'.format(status))

In [None]:
# Train fold 4 (3d_fullres, nnUNetTrainer_Experimental); stop if the command fails.
status = os.system('nnUNet_train 3d_fullres nnUNetTrainer_Experimental Task500_Nucleus 4 --npz')
if status != 0:
    raise RuntimeError('nnUNet_train failed for fold 4 (exit status {})'.format(status))

In [None]:
#os.system('nnUNet_find_best_configuration -m 3d_fullres -t 500')

## Prediction

In [None]:
# Predict the test images with the fold-0 model only; stop if the command fails.
# NOTE(review): the input path uses the $nnUNet_raw_data_base environment
# variable, while the conversion cells write to the relative folder
# 'nnUNet_raw_data_base/' — make sure both point to the same place.
status = os.system('nnUNet_predict -i $nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs/ -o output_for_fold_0 -t 500 -tr nnUNetTrainer_Experimental -m 3d_fullres -f 0')
if status != 0:
    raise RuntimeError('nnUNet_predict failed for fold 0 (exit status {})'.format(status))

In [None]:
# Predict the test images with the fold-1 model only; stop if the command fails.
# NOTE(review): input path comes from $nnUNet_raw_data_base — confirm it matches
# the relative 'nnUNet_raw_data_base/' folder written by the conversion cells.
status = os.system('nnUNet_predict -i $nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs/ -o output_for_fold_1 -t 500 -tr nnUNetTrainer_Experimental -m 3d_fullres -f 1')
if status != 0:
    raise RuntimeError('nnUNet_predict failed for fold 1 (exit status {})'.format(status))

In [None]:
# Predict the test images with the fold-2 model only; stop if the command fails.
# NOTE(review): input path comes from $nnUNet_raw_data_base — confirm it matches
# the relative 'nnUNet_raw_data_base/' folder written by the conversion cells.
status = os.system('nnUNet_predict -i $nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs/ -o output_for_fold_2 -t 500 -tr nnUNetTrainer_Experimental -m 3d_fullres -f 2')
if status != 0:
    raise RuntimeError('nnUNet_predict failed for fold 2 (exit status {})'.format(status))

In [None]:
# Predict the test images with the fold-3 model only; stop if the command fails.
# NOTE(review): input path comes from $nnUNet_raw_data_base — confirm it matches
# the relative 'nnUNet_raw_data_base/' folder written by the conversion cells.
status = os.system('nnUNet_predict -i $nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs/ -o output_for_fold_3 -t 500 -tr nnUNetTrainer_Experimental -m 3d_fullres -f 3')
if status != 0:
    raise RuntimeError('nnUNet_predict failed for fold 3 (exit status {})'.format(status))

In [None]:
# Predict the test images with the fold-4 model only; stop if the command fails.
# NOTE(review): input path comes from $nnUNet_raw_data_base — confirm it matches
# the relative 'nnUNet_raw_data_base/' folder written by the conversion cells.
status = os.system('nnUNet_predict -i $nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs/ -o output_for_fold_4 -t 500 -tr nnUNetTrainer_Experimental -m 3d_fullres -f 4')
if status != 0:
    raise RuntimeError('nnUNet_predict failed for fold 4 (exit status {})'.format(status))

In [None]:
# Predict the test images without -f, i.e. with nnUNet_predict's default folds
# (presumably all trained folds, ensembled — confirm). Results go to
# 'output_directory'. Stop if the command fails.
# NOTE(review): input path comes from $nnUNet_raw_data_base — confirm it matches
# the relative 'nnUNet_raw_data_base/' folder written by the conversion cells.
status = os.system('nnUNet_predict -i $nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs/ -o output_directory -t 500 -tr nnUNetTrainer_Experimental -m 3d_fullres')
if status != 0:
    raise RuntimeError('nnUNet_predict failed (exit status {})'.format(status))

## Postprocessing

In [None]:
def nii2np_single(img_path):
    """
    Read a NIfTI file (.nii.gz) with SimpleITK and return its voxels as a numpy array.
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(img_path))

In [None]:
def nii2tif_single(nii_path, tif_path, resample):
    """
    Load a NIfTI file, transform its array with `resample`, and save it as a TIFF.

    Parameters
    ----------
    nii_path : str
        Input NIfTI file.
    tif_path : str
        Output TIFF file.
    resample : callable
        Applied to the numpy array before saving.
    """
    # Import locally so this function does not depend on a global `io`
    # being bound by some other notebook cell before it is called.
    from skimage import io

    img = nii2np_single(nii_path)
    img = resample(img)
    io.imsave(tif_path, img)

In [None]:
def abs_path(root, listdir_):
    """
    Return a new list where each entry of `listdir_` is prefixed with `root + '/'`.

    `listdir_` itself is not modified. The results are only absolute paths
    if `root` is absolute.
    """
    return [root + '/' + entry for entry in listdir_]

def abs_listdir(path):
    """
    List the entries of directory `path`, each prefixed with `path + '/'`.
    """
    entries = os.listdir(path)
    return abs_path(path, entries)

In [None]:
def nii2tif_folder(nii_folder, tif_folder, resample):
    """
    Convert every NIfTI file (*.nii.gz) in `nii_folder` into a TIFF in `tif_folder`.

    Each output keeps the input's name with '.nii.gz' replaced by '.tif'
    (e.g. 'Nucleus_003.nii.gz' -> 'Nucleus_003.tif'). Other files in the
    folder are skipped. `resample` is applied to each array before it is
    saved (passed through to nii2tif_single). `tif_folder` is created if needed.
    """
    suffix = '.nii.gz'
    # List the folder once, so names and paths cannot get out of sync.
    names = os.listdir(nii_folder)
    os.makedirs(tif_folder, exist_ok=True)
    for i, name in enumerate(names):
        # Progress on a single console line that is overwritten each time.
        print('Loading index: {:d}/{}\r'.format(i + 1, len(names)), end='', flush=True)
        if not name.endswith(suffix):
            continue
        tif_path = os.path.join(tif_folder, name[:-len(suffix)] + '.tif')
        nii2tif_single(nii_folder + '/' + name, tif_path, resample)

In [None]:
#os.system('pip install torchio')

In [None]:
import torchio as tio
def resample(img, size=(128,128,128), rerange_image=False, rerange_label=False):
    """
    Resize an array to `size` with torchio's Resize, optionally re-ranging it.

    Parameters
    ----------
    img : numpy.ndarray
        Spatial array; a leading channel axis is added before resizing and
        removed from the result.
    size : tuple of int
        Target spatial shape.
    rerange_image : bool
        Min-max normalize the result to [0, 1]. A constant array becomes all
        zeros instead of NaN. Ignored when `rerange_label` is set.
    rerange_label : bool
        Binarize the result: voxels > 0 become 255, others 0 (uint8).

    Returns
    -------
    numpy.ndarray
        The resized (and optionally re-ranged) array, without the channel axis.
    """
    transform = tio.transforms.Resize(target_shape=size)
    img_tmp = transform(np.expand_dims(img, 0))
    if rerange_label:
        # NOTE(review): check which interpolation torchio's Resize uses here;
        # with a smooth interpolation, thresholding at > 0 slightly dilates
        # the mask. Nearest-neighbour interpolation may be preferable for labels.
        img_tmp = np.where(img_tmp > 0, 255, 0).astype(np.uint8)
        if len(np.unique(img_tmp[0])) != 2:
            # Only one distinct value: the resized mask is empty or all foreground.
            print('error')
    elif rerange_image:
        lo, hi = img_tmp.min(), img_tmp.max()
        if hi > lo:
            img_tmp = (img_tmp - lo) / (hi - lo)
        else:
            # Constant array: avoid 0/0 (NaN) and return zeros.
            img_tmp = np.zeros_like(img_tmp, dtype=np.float32)
    return img_tmp[0]

In [None]:
from skimage import io  # nii2tif_single saves through the global name `io`
# Convert the predictions in 'output_directory' to TIFF, resized with
# resample()'s default arguments.
# NOTE(review): rerange_label=True is not passed, so the resized label maps
# may not be strictly binary — confirm this is intended.
nii2tif_folder('output_directory', 'tiff_files', resample=resample)

In [None]:
# Convert the test images to TIFF with the same resizing, so they can be shown
# next to the predictions.
nii2tif_folder('nnUNet_raw_data_base/nnUNet_raw_data/Task500_Nucleus/imagesTs', 'ImagesTS', resample=resample)

In [None]:
# Show the middle slice of each test image next to its predicted label map.
for label_name in sorted(os.listdir('tiff_files')):
    label = imread(f"tiff_files/{label_name}")
    # Test images carry nnU-Net's channel suffix '_0000' before the extension.
    image_name = label_name[:-4] + '_0000.tif'
    img = imread(f"ImagesTS/{image_name}")
    z_img = img.shape[0] // 2
    z_label = label.shape[0] // 2
    # New figure for every pair. Otherwise each iteration draws over the same
    # two axes and only the last pair is visible.
    plt.figure()
    plt.suptitle(label_name)
    plt.subplot(1, 2, 1)
    plt.imshow(img[z_img])
    plt.title('Original image')
    plt.subplot(1, 2, 2)
    plt.imshow(label[z_label])  # label's own middle slice, in case shapes differ
    plt.title('Labeled image')
    plt.show()