In [1]:
import csv
from glob import glob
from natsort import natsorted
import os
from pprint import pprint
import pandas as pd
import re
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from tqdm.notebook import tqdm
from multiprocessing import Pool
from time import sleep

In [2]:
# Input: one parcellation (with tumour label) per case folder vs_gk_*,
# natural-sorted so that e.g. vs_gk_2 is listed before vs_gk_10.
parc_paths = natsorted(glob(os.path.join("parcellation_data", "GIF_with_tumour_all_classes", "vs_gk_*", "parcellation_with_tumour.nii.gz")))
# Output root; create_multiclass_seg creates one sub-folder per case below it.
output_path = os.path.join("parcellation_data", "GIF_with_tumour_9classes")
# Number of parallel worker processes: at most 10 and fewer than the available
# CPU cores (one core is left free), but always at least 1.
# os.cpu_count() may return None, hence the `or 1` fallback.
n_processes = max(1, min(10, (os.cpu_count() or 1) - 1))

In [3]:
# Sanity check: list the parcellation files matched by the glob in the config cell
# (an empty list means the input folder/pattern matched nothing).
parc_paths

['parcellation_data/GIF_with_tumour_all_classes/vs_gk_1/parcellation_with_tumour.nii.gz',
 'parcellation_data/GIF_with_tumour_all_classes/vs_gk_3/parcellation_with_tumour.nii.gz',
 'parcellation_data/GIF_with_tumour_all_classes/vs_gk_4/parcellation_with_tumour.nii.gz']

In [4]:
# Parcellation labels to keep, in output-class order: the label at position i
# becomes class i + 1 in the output segmentation; everything else becomes 0
# (background). 8 kept labels + background = 9 classes.
labels_to_keep = [
    300,  # tumour
    35,   # pons
    36,   # brainstem
    72,   # vermal lobules
    73,   # vermal lobules
    74,   # vermal lobules
    39,   # Right Cerebellum Exterior (+ Right Cerebellum White Matter, 41, merged in)
    40,   # Left Cerebellum Exterior (+ Left Cerebellum White Matter, 42, merged in)
]

# (label_kept, label_removed) pairs: voxels carrying label_removed are relabelled
# to label_kept before the class mapping is applied.
labels_to_merge = [(39, 41), (40, 42)]

# A merged-away label no longer exists after the merge step, so listing it in
# labels_to_keep would only reserve a class index that is always empty.
assert not ({removed for _, removed in labels_to_merge} & set(labels_to_keep)), \
    "merged-away labels must not appear in labels_to_keep"

In [5]:
def create_multiclass_seg(parc_path):
    """Convert one parcellation into a compact multi-class segmentation.

    Uses the module-level settings ``labels_to_merge``, ``labels_to_keep`` and
    ``output_path``:

    1. every ``label_removed`` in ``labels_to_merge`` is relabelled to its
       ``label_kept`` partner;
    2. every label in ``labels_to_keep`` is mapped to class ``position + 1``;
       all other voxels become 0 (background);
    3. the result is saved as
       ``<output_path>/<case folder>/vs_gk_seg_GIF_multiclass_refT1.nii.gz``,
       where ``<case folder>`` is the parent directory name of ``parc_path``.

    Parameters
    ----------
    parc_path : str
        Path to a ``.../<case folder>/parcellation_with_tumour.nii.gz`` file.

    Returns
    -------
    str
        Path of the saved segmentation file.
    """
    print(parc_path + "\n")
    parc_nii = nib.load(parc_path)
    # get_fdata() returns floats; round to integers so label comparisons are
    # exact even if the stored values carry scaling/rounding noise.
    parc_data = np.rint(parc_nii.get_fdata()).astype(np.int32)

    # Merge first, so merged voxels receive the class of the kept label.
    for label_kept, label_removed in labels_to_merge:
        parc_data[parc_data == label_removed] = label_kept

    # Use the smallest unsigned integer dtype that can hold the highest class
    # index instead of float64: label maps are integral and this keeps files small.
    seg_dtype = np.min_scalar_type(len(labels_to_keep))
    seg_data = np.zeros(parc_data.shape, dtype=seg_dtype)
    for class_index, label in enumerate(labels_to_keep, start=1):
        seg_data[parc_data == label] = class_index

    # No header is passed, so the saved datatype follows seg_data's dtype.
    seg_nii = nib.Nifti1Image(seg_data, parc_nii.affine)

    # dirname/basename instead of split(os.sep): also correct when the path
    # mixes separators (e.g. "/" on Windows).
    folder_name = os.path.basename(os.path.dirname(parc_path))
    case_dir = os.path.join(output_path, folder_name)
    os.makedirs(case_dir, exist_ok=True)
    save_path = os.path.join(case_dir, "vs_gk_seg_GIF_multiclass_refT1.nii.gz")
    nib.save(seg_nii, save_path)
    return save_path

In [6]:
# Convert all cases in parallel. Workers print in completion order, so printed
# paths may appear out of order; `result` follows the order of parc_paths.
# The context manager shuts the pool down after map() returns instead of leaving
# idle worker processes alive. The worker count is clamped to the number of files
# (no idle workers) and to at least 1 (Pool(processes=0) raises ValueError when
# the glob matched nothing).
# NOTE(review): a function defined in a notebook can only be shipped to workers
# with the "fork" start method; under "spawn" (Windows/macOS default) this is
# expected to fail — move create_multiclass_seg into a .py module in that case.
with Pool(processes=max(1, min(n_processes, len(parc_paths)))) as pool:
    result = pool.map(create_multiclass_seg, parc_paths)

parcellation_data/GIF_with_tumour_all_classes/vs_gk_3/parcellation_with_tumour.nii.gz
parcellation_data/GIF_with_tumour_all_classes/vs_gk_1/parcellation_with_tumour.nii.gz
parcellation_data/GIF_with_tumour_all_classes/vs_gk_4/parcellation_with_tumour.nii.gz



