In [1]:
import numpy as np
import nibabel as nib
import pathlib as plb
import cc3d
import csv
import sys
import os
import matplotlib.pyplot as plt
import torch
import SimpleITK as sitk
import json
import pandas as pd


def nii2numpy(nii_path):
    """Load a NIfTI segmentation file and compute the volume of one voxel.

    Args:
        nii_path: Location of the NIfTI file (converted with ``str`` before
            loading, so path-like objects are accepted).

    Returns:
        tuple: ``(mask, voxel_vol)`` where ``mask`` is the image data as
        returned by ``get_fdata()`` and ``voxel_vol`` is the product of the
        three spatial ``pixdim`` entries divided by 1000 (i.e. millilitres,
        assuming the header spacing is in millimetres -- TODO confirm).
    """
    image = nib.load(str(nii_path))

    # pixdim[1:4] holds the voxel spacing along the three spatial axes.
    spacing = image.header['pixdim'][1:4]
    voxel_vol = spacing[0] * spacing[1] * spacing[2] / 1000

    return image.get_fdata(), voxel_vol

def con_comp(seg_array):
    """Split a segmentation array into individually indexed connected components.

    Args:
        seg_array: Segmentation array (expected to be binary).

    Returns:
        Array in which each connected component found by
        ``cc3d.connected_components`` carries its own index. Components are
        built with 18-connectivity (per cc3d: voxels sharing a face or an edge).
    """
    return cc3d.connected_components(seg_array, connectivity=18)

def process_array(arr, min_size):
    """Binarise a labelled array, discarding labels that occur too rarely.

    Args:
        arr: Array (or array-like) of labels, e.g. connected-component
            indices, in which 0 marks background.
        min_size: Minimum number of elements a label needs in order to be
            kept. Converted with ``int()``, so ``"10"`` or ``10.0`` work too.

    Returns:
        A new array with the shape and dtype of ``arr`` holding 1 wherever
        the label is non-zero and occurs at least ``min_size`` times, and 0
        everywhere else. The input is not modified.
    """
    min_size = int(min_size)
    arr = np.asarray(arr)

    labels, counts = np.unique(arr, return_counts=True)

    # Labels whose element count falls below the threshold get removed.
    small_labels = labels[counts < min_size]

    # A single vectorised membership test replaces one full-array comparison
    # per small label, which matters on large volumes with many tiny components.
    keep = (arr != 0) & ~np.isin(arr, small_labels)

    return keep.astype(arr.dtype)

def post_process_nifti(nifti_pred_path, output_path, min_size):
    """Remove small connected components from a NIfTI prediction and save it.

    The prediction is split into connected components with ``con_comp``,
    components with fewer than ``min_size`` voxels are dropped by
    ``process_array``, and the resulting binary mask is written to
    ``output_path``.

    The output reuses the affine and header of the input image so the
    processed mask stays spatially aligned with the original prediction
    (saving with an identity affine would discard voxel spacing, orientation
    and origin).

    Args:
        nifti_pred_path: Path of the NIfTI prediction to process.
        output_path: Path the processed NIfTI image is written to.
        min_size: Minimum component size (in voxels) that is kept.

    Returns:
        None. The result is written to disk.
    """
    # Load the image directly so its spatial metadata is available for saving.
    pred_nii = nib.load(str(nifti_pred_path))
    pred_array = pred_nii.get_fdata()

    pred_conn_comp = con_comp(pred_array)

    # Binary mask with the small connected components removed.
    processed_arr = process_array(pred_conn_comp, min_size)

    # Copy the header so the loaded image is left untouched, and make the
    # on-disk dtype follow the processed array rather than the input file.
    header = pred_nii.header.copy()
    header.set_data_dtype(processed_arr.dtype)

    nifti_img = nib.Nifti1Image(processed_arr, affine=pred_nii.affine, header=header)
    nib.save(nifti_img, output_path)

In [None]:
# Maps a model key to [<csv file name>, <model folder relative to base_path>].
# Only the second entry (the model folder) is used in this cell.
dict_of_paths = {'h':['h_219_inf_results.csv','nnUNetTrainer_1500epochs__nnUNetPlans__3d_fullres']}
base_path = 'nnUNet_results/Dataset219_PETCT/'

# Connected components smaller than this many voxels are removed. The output
# folder name is derived from it so the two can never drift apart.
MIN_COMPONENT_SIZE = 10

# iterate through each model
for key, model_entry in dict_of_paths.items():
    print(key)
    relative_path = model_entry[1]

    # get non-processed inference images; listing the input first means a wrong
    # path fails here, before any output folder is created
    inference_folder = os.path.join(base_path, relative_path, 'inferTs')
    inference_images = sorted(
        x for x in os.listdir(inference_folder) if x.endswith('.nii.gz')
    )

    # create folder for post-processed images (no error if it already exists)
    processed_inference_folder = os.path.join(
        base_path, relative_path, 'pp{}_inferTs'.format(MIN_COMPONENT_SIZE)
    )
    os.makedirs(processed_inference_folder, exist_ok=True)

    # apply post processing and save images to folder
    for image in inference_images:
        post_process_nifti(
            os.path.join(inference_folder, image),
            os.path.join(processed_inference_folder, image),
            MIN_COMPONENT_SIZE,
        )

    print('post-processing finished')