In [1]:
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
import nrrd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import nnunet

# from nnunet.dataset_conversion.utils import dataset_conversion
from ipywidgets import interact, fixed
from IPython.display import display
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from typing import Tuple, List, Union
from skimage import io
from pathlib import Path
from typing import Tuple
from batchgenerators.utilities.file_and_folder_operations import *



Please cite the following paper when using nnUNet:

Isensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation." Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z


If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet



In [2]:
# Index of the GPU to use (can be changed if multiple GPUs are available).
cuda_dev = '0'

# Use the selected CUDA device when PyTorch can see one, otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:" + cuda_dev)
else:
    device = torch.device("cpu")

print('Device: ' + str(device))
if use_cuda:
    gpu_name = torch.cuda.get_device_name(int(cuda_dev))
    print('GPU: ' + str(gpu_name))

Device: cuda:0
GPU: GeForce GTX TITAN X


In [3]:
# directories
# change these to match where your existing data is and should be
# All values keep their trailing slash because later cells build file paths by plain concatenation.

# name of dataset directory
dataset_name = "label1and2/"
# name of task for nnU-Net to recognise
task_name = "Task104_Ovary/"

# folder holding the original images and segmentations
base_dir = '/vol/bitbucket/ma3720/thesis/datasets/' + dataset_name
# nnU-Net raw data folder; the task gets its own sub-folder in here
raw_data_dir = '/vol/bitbucket/ma3720/thesis/nnUNet_raw/nnUNet_raw_data/'
task_dir = raw_data_dir + task_name

# training / test images and labels of the task, plus its dataset.json
tr_img_dir, tr_label_dir, ts_img_dir, ts_label_dir = (
    task_dir + sub_dir for sub_dir in ('imagesTr/', 'labelsTr/', 'imagesTs/', 'labelsTs/')
)
json_dir = task_dir + 'dataset.json'

# model predictions and the overlay figures generated from them
predicted_dir = '/vol/bitbucket/ma3720/thesis/results/' + task_name
overlayed_dir = '/vol/bitbucket/ma3720/thesis/overlayed/' + task_name

In [5]:
# creates a folder at given path containing the task for nnunet
# The task folder is listed first because the image/label folders are created inside it.
# parents=False: the parent of every folder must already exist, otherwise mkdir raises.
# exist_ok=True: re-running the cell on an already prepared tree is harmless.
for folder in (
    raw_data_dir + task_name,
    tr_img_dir,
    ts_img_dir,
    tr_label_dir,
    ts_label_dir,
    predicted_dir,
    overlayed_dir,
):
    Path(folder).mkdir(parents=False, exist_ok=True)

In [6]:
# resample image to required size by using interpolation
# code obtained from https://stackoverflow.com/questions/48065117/simpleitk-resize-images
def resample_image(img, is_label):
    """Resample a SimpleITK image onto a fixed reference grid of 256 pixels per axis.

    The reference grid has a zero origin, an identity direction and a spacing chosen so that the
    grid spans the physical extent of ``img``. The image is centred on that grid before it is
    resampled.

    :param img: SimpleITK image to resample.
    :param is_label: when truthy, nearest-neighbour interpolation is used so that no new label
    values are invented; otherwise linear interpolation is used.
    :return: the resampled SimpleITK image. Pixels that fall outside the input are set to 0.
    """
    dimension = img.GetDimension()

    # Physical extent of the input along every axis (distance between first and last pixel centre).
    reference_physical_size = np.zeros(dimension)
    reference_physical_size[:] = [(sz-1)*spc if sz*spc>mx  else mx for sz,spc,mx in zip(img.GetSize(), img.GetSpacing(), reference_physical_size)]

    # Reference grid the image is resampled onto.
    reference_origin = np.zeros(dimension)
    reference_direction = np.identity(dimension).flatten()
    reference_size = [256]*dimension # Arbitrary sizes, smallest size that yields desired results.
    reference_spacing = [ phys_sz/(sz-1) for sz,phys_sz in zip(reference_size, reference_physical_size) ]

    reference_image = sitk.Image(reference_size, img.GetPixelIDValue())
    reference_image.SetOrigin(reference_origin)
    reference_image.SetSpacing(reference_spacing)
    reference_image.SetDirection(reference_direction)
    reference_center = np.array(reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize())/2.0))

    # Maps the reference grid's axes/origin onto those of the input image.
    transform = sitk.AffineTransform(dimension)
    transform.SetMatrix(img.GetDirection())
    transform.SetTranslation(np.array(img.GetOrigin()) - reference_origin)

    # Extra translation that makes the centre of the input coincide with the centre of the grid.
    centering_transform = sitk.TranslationTransform(dimension)
    img_center = np.array(img.TransformContinuousIndexToPhysicalPoint(np.array(img.GetSize())/2.0))
    centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))

    # SimpleITK 2.x removed Transform.AddTransform in favour of CompositeTransform; keep the old
    # spelling as a fallback so the function works with both major versions.
    if hasattr(sitk, "CompositeTransform"):
        centered_transform = sitk.CompositeTransform(transform)
    else:
        centered_transform = sitk.Transform(transform)
    centered_transform.AddTransform(centering_transform)

    interpolator = sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear
    return sitk.Resample(img, reference_image, centered_transform, interpolator, 0.0)

In [7]:
def crop_image(img, crop_fraction=0.1):
    """Remove a band of rows from the top of an image.

    :param img: image exposing ``GetSize()`` and 2D slicing (a SimpleITK image in this notebook).
    :param crop_fraction: fraction of the image height (second entry of ``GetSize()``) to remove
    from the top. Defaults to 0.1, the value that used to be hard-coded.
    :return: the image without its first ``int(height * crop_fraction)`` rows.
    :raises ValueError: if ``crop_fraction`` is not in the interval [0, 1).
    """
    if not 0 <= crop_fraction < 1:
        raise ValueError("crop_fraction must be in [0, 1) but is %s" % str(crop_fraction))
    to_crop = int(img.GetSize()[1] * crop_fraction) # removes this many rows from the top
    return img[:, to_crop:]

In [8]:
# used to create the nifti files for both segmentations and images to be used by nnunet
# obtained from nnunet library but some changes were made
def convert_2d_image_to_nifti_new(img, input_filename: str, output_filename_truncated: str, spacing=(999, 1, 1),
                              transform=None, is_seg: bool = False) -> None:
    """
    Converts an already loaded 2D image (numpy array) into a series of niftis.
    Unlike the nnunet original, this version does not read ``input_filename``; the file name is only used to
    detect .nrrd inputs. The image can have an arbitrary number of input channels which will be exported
    separately (_0000.nii.gz, _0001.nii.gz, etc for images and only .nii.gz for seg).
    Spacing can be ignored most of the time.
    !!!2D images are often natural images which do not have a voxel spacing that could be used for resampling. These images
    must be resampled by you prior to converting them to nifti!!!
    Datasets converted with this utility can only be used with the 2d U-Net configuration of nnU-Net
    If Transform is not None it will be applied to the image before anything else.
    Segmentations will be converted to np.uint32!
    :param img: numpy array, either 2D (rows, cols) or 3D with the color channel last (rows, cols, channels)
    :param is_seg: True if img is a segmentation (exactly one channel, written without the _XXXX suffix)
    :param transform: optional callable applied to img
    :param input_filename: name of the file img was loaded from. Arrays coming from .nrrd files are squeezed
    :param output_filename_truncated: do not use a file ending for this one! Example: output_name='./converted/image1'. This
    function will add the suffix (_0000) and file ending (.nii.gz) for you.
    :param spacing: spacing of the written nifti, given in reversed order compared to what SimpleITK expects
    :return:
    """
    if transform is not None:
        img = transform(img)

    if input_filename.endswith(".nrrd"):
        img = np.squeeze(img)
    elif len(img.shape) == 3 and img.shape[2] == 4:
        # 4-channel image: keep only the first three channels.
        # The dimensionality check keeps 2d grayscale images from failing on img.shape[2].
        img = img[:,:,0:3]

    if len(img.shape) == 2:  # 2d image with no color channels
        img = img[None, None]  # add dimensions
    else:
        assert len(img.shape) == 3, "image should be 3d with color channel last but has shape %s" % str(img.shape)
        # we assume that the color channel is the last dimension. Transpose it to be in first
        img = img.transpose((2, 0, 1))
        # add third dimension
        img = img[:, None]
    # image is now (c, 1, rows, cols): one single-slice volume per channel
    if is_seg:
        assert img.shape[0] == 1, 'segmentations can only have one color channel, not sure what happened here'

    for j, i in enumerate(img):

        if is_seg:
            i = i.astype(np.uint32)

        itk_img = sitk.GetImageFromArray(i)
        # numpy axis order is the reverse of SimpleITK's, hence the reversed spacing
        itk_img.SetSpacing(list(spacing)[::-1])
        if not is_seg:
            sitk.WriteImage(itk_img, output_filename_truncated + "_%04.0d.nii.gz" % j)
        else:
            sitk.WriteImage(itk_img, output_filename_truncated + ".nii.gz")

In [9]:
# obtained from nnunet library
def get_identifiers_from_splitted_files(folder: str):
    """
    Collects the unique case identifiers of the .nii.gz files found in a folder.
    :param folder: folder to look into
    :return: sorted numpy array of the unique file names with their last 12 characters removed
    (12 is the length of a channel suffix plus file ending such as '_0000.nii.gz')
    """
    nifti_names = subfiles(folder, suffix='.nii.gz', join=False)
    case_ids = [name[:-12] for name in nifti_names]
    return np.unique(case_ids)

def generate_dataset_json(output_file: str, imagesTr_dir: str, imagesTs_dir: str, modalities: Tuple,
                          labels: dict, dataset_name: str, license: str = "hands off!", dataset_description: str = "",
                          dataset_reference="", dataset_release='0.0'):
    """
    Builds the dataset description of a task from the files found in its image folders and saves it as json.
    :param output_file: full path of the dataset.json to write, i.e. output_file='DATASET_PATH/dataset.json'
    where DATASET_PATH is the folder that contains the imagesTr and labelsTr subfolders
    :param imagesTr_dir: path to the imagesTr folder of that dataset
    :param imagesTs_dir: path to the imagesTs folder of that dataset. Can be None, in which case no test cases
    are listed
    :param modalities: tuple of strings with modality names, in the same order as the image channels (first
    entry corresponds to _0000.nii.gz, etc). Example: ('T1', 'T2', 'FLAIR').
    :param labels: dict with int->str (key->value) mapping the label IDs to label names. Note that 0 is always
    supposed to be background! Example: {0: 'background', 1: 'edema', 2: 'enhancing tumor'}
    :param dataset_name: The name of the dataset. Can be anything you want
    :param license: stored under the 'licence' key
    :param dataset_description: stored under the 'description' key
    :param dataset_reference: website of the dataset, if available
    :param dataset_release: stored under the 'release' key
    :return:
    """
    training_cases = get_identifiers_from_splitted_files(imagesTr_dir)
    test_cases = [] if imagesTs_dir is None else get_identifiers_from_splitted_files(imagesTs_dir)

    # every training case is listed with its image and the label of the same name
    training_entries = []
    for case in training_cases:
        training_entries.append({'image': "./imagesTr/%s.nii.gz" % case, "label": "./labelsTr/%s.nii.gz" % case})

    json_dict = {
        'name': dataset_name,
        'description': dataset_description,
        'tensorImageSize': "4D",
        'reference': dataset_reference,
        'licence': license,
        'release': dataset_release,
        # keys are the channel / label indices as strings
        'modality': {str(index): modality for index, modality in enumerate(modalities)},
        'labels': {str(label_id): label_name for label_id, label_name in labels.items()},
        'numTraining': len(training_cases),
        'numTest': len(test_cases),
        'training': training_entries,
        'test': ["./imagesTs/%s.nii.gz" % case for case in test_cases],
    }

    named_as_expected = output_file.endswith("dataset.json")
    if not named_as_expected:
        print("WARNING: output file name is not dataset.json! This may be intentional or not. You decide. "
              "Proceeding anyways...")
    save_json(json_dict, os.path.join(output_file))

In [10]:
files =  os.listdir(base_dir)
# renames all files to remove whitespace. Uncomment this the first time you use on dataset to remove whitespace if needed
# for file in files:
#     os.rename(base_dir + file, base_dir + file.replace(" ", ""))

# Collects the images and segmentations in name order. Later cells rely on this order and expect
# every image to sit next to its segmentation.
sorted_files = []
# number of images
tif_count = 0
# number of segmentations
nrrd_count = 0
for file in sorted(files):
    if file.endswith(".nrrd"):
        nrrd_count += 1
    # The dot is part of the suffix so that this test matches the one used by the integrity check;
    # without it any name merely ending in the letters "tif" would be collected.
    elif file.endswith((".tif", ".tiff")):
        tif_count += 1
    else:
        continue
    sorted_files.append(file)

In [11]:
# check for dataset integrity and remove files manually. Do not move to next step before all assertions pass.
# The files are checked in consecutive pairs (i, i+1) of the name-sorted list: each pair has to consist of
# one image (.tif/.tiff) and one segmentation (.nrrd), in either order.
pairing_error = (
    "Invalid file and segmentation. This could simply an issue with naming as the files are sorted according to names. So check consistency."
    "             If names are fine then there might be an issue with the dataset such as a duplicate image file or a file without a segmentation."
    " The two files to check around (error is in surrounding images) are "
)
valid = True
for i in range(0, len(sorted_files), 2):
    # Uncomment this line if you would like to see the segmentations and images line by line
    # print(sorted_files[i], sorted_files[i+1])

    # With an odd number of files the last one has no partner; say so instead of raising a bare IndexError.
    assert i + 1 < len(sorted_files), "The last file has no matching image or segmentation to pair it with: " + sorted_files[i]

    img1 = sitk.ReadImage(base_dir + sorted_files[i])
    arr1 = np.squeeze(sitk.GetArrayFromImage(img1))
    img2 = sitk.ReadImage(base_dir + sorted_files[i+1])
    arr2 = np.squeeze(sitk.GetArrayFromImage(img2))

    # Work out which file of the pair is the segmentation and whether its partner has the expected type.
    if sorted_files[i].endswith(".nrrd"):
        valid = sorted_files[i+1].endswith((".tif", ".tiff"))
        seg_name, seg_arr = sorted_files[i], arr1
    elif sorted_files[i].endswith((".tif", ".tiff")):
        valid = sorted_files[i+1].endswith(".nrrd")
        seg_name, seg_arr = sorted_files[i+1], arr2
    else:
        # neither an image nor a segmentation: only the size check below applies
        seg_name, seg_arr = None, None

    if seg_name is not None:
        assert valid, pairing_error + sorted_files[i] + " and " + sorted_files[i+1] + " at image index " + str(i)
        # the squeezed segmentation is expected to be (rows, cols, 2): one layer per label
        assert len(seg_arr.shape) == 3, "Segmentation does not have a shape of 3 (usually an error with the software making segmentations or there is a missing label) " + seg_name + " shape: " + " ".join(map(str,seg_arr.shape))
        assert seg_arr.shape[2] == 2, "Segmentation does not have 2 labels"
        # This is commented out because there was an error when making segmentations
        # assert not (seg_arr==[0,1]).all(2).any(), "Label 2 is not contained within label 1 in segmentation file: "  + seg_name

    assert arr1.shape[0] == arr2.shape[0] and arr1.shape[1] == arr2.shape[1], "The image and segmentation do not have matching sizes. The files are: " + sorted_files[i] + " has shape " + " ".join(map(str,arr1.shape))\
            + " and " + sorted_files[i+1] + " has shape " + " ".join(map(str,arr2.shape))

assert nrrd_count == tif_count, "Inbalance in number of images and segmentations \n Segmentations: " + str(nrrd_count) + " Images: " + str(tif_count)


In [12]:
# chooses images to use in test set
# The seed is fixed so that the same split is drawn on every run.
np.random.seed(2)
# Pairs start at the even indices of sorted_files, so the test set is a sample of those indices.
pair_start_indices = range(0, len(sorted_files), 2)
# NOTE(review): this is 5% of the number of files (images + segmentations), i.e. about 10% of the
# image/segmentation pairs - confirm that this is the intended size of the test set.
n_test = int(len(sorted_files) * 0.05)
test_files = np.random.choice(pair_start_indices, n_test, replace=False)
# These files (only one of image or segmentation appears below) will be used in the test set
print(np.array(sorted_files)[test_files.astype(int)])
print(len(test_files))


['18872_0019L12.TIF.tif' '23556_0034L12.TIF.tif' '23530tiff_9L12.tif'
 '13355_6L12.tif' '18510_0025L12.TIF.tif' '17860_0032L12.TIF.tif'
 '18853_0033L12.TIF.tif' '23866_0018L12Segmentation.seg.nrrd'
 '24306_0015L12.TIF.tif' '14654_7L12.tif' '24306_0019L12.TIF.tif'
 '23522_-3L12.tif' '15346_0001L12.TIF.tif' '19623_0011L12.TIF.tif'
 '17303_0044L12.TIF.tif' '23941_0014L12Segmentation.seg.nrrd'
 '14117_1L12.tif' '13142_0009L12.TIF.tif' '15526_0028L12.TIF.tif'
 '23677_0001L12Segmentation.seg.nrrd' '23506_0048L12.TIF.tif'
 '18441_0026L12.TIF.tif' '23529tiff_5L12.tif' '19436_0028L12.TIF.tif'
 '23676tiff_2L12.tif' '17853_0019L12.TIF.tif' '18441_0024L12.TIF.tif'
 '17584_0009L12Segmentation.seg.nrrd' '23742_0015L12Segmentation.seg.nrrd'
 '23495tiff_7L12.Segmentation.seg.nrrd' '23649_0013L12.TIF.tif'
 '23644_6L12.tif' '16683_0018L12.TIF.tif' '23536.-3L12.tif'
 '23478tiff_1L12.tif' '14984_0035L12Segmentation_1.seg.nrrd'
 '12041_3L12Segmentation.seg.nrrd' '13291_0010L12.TIF.tif'
 '18786tiff_7L12.Seg

In [13]:
# fills up the folders with the dataset
# Every consecutive pair of sorted_files holds one image and its segmentation, in either order.
# Both are cropped, resampled and written as nifti into the training or test folders of the task.
seen_case_ids = set()
for i in range(0, len(sorted_files), 2):
    if sorted_files[i].endswith(".nrrd"):
        lbl_name, img_name = sorted_files[i], sorted_files[i+1]
    else:
        img_name, lbl_name = sorted_files[i], sorted_files[i+1]

    # The image file is the one to get the name from. Everything after the first dot is dropped, so two
    # images whose names only differ after that dot would silently overwrite each other: stop instead.
    new_file_pre = img_name.split(".")[0]
    if new_file_pre in seen_case_ids:
        raise ValueError("Two images map to the same case name '" + new_file_pre + "' (second one is " + img_name
                         + "). Rename the files so that the part before the first dot is unique.")
    seen_case_ids.add(new_file_pre)

    img = sitk.ReadImage(base_dir + img_name)
    lbl = sitk.ReadImage(base_dir + lbl_name)
    lbl_arr = np.squeeze(sitk.GetArrayFromImage(lbl))
    # If image has pixel with label 2 and not label 1 then add label 1
    label2_only = (lbl_arr==[0,1]).all(2)
    if label2_only.any():
        lbl_arr[label2_only] = [1,1]
    # Sums all values in 3rd dimension so that we can know to which class a pixel belongs. If both labels exist the label to be predicted will be 2 (both cyst and locule exist at this pixel).
    lbl_arr = lbl_arr.sum(axis=-1)
    lbl = sitk.GetImageFromArray(lbl_arr)

    img_cropped = crop_image(img)
    lbl_cropped = crop_image(lbl)
    resampled_lbl = resample_image(lbl_cropped, True)
    resampled_img = resample_image(img_cropped, False)
    resampled_lbl_arr = sitk.GetArrayFromImage(resampled_lbl)
    resampled_img_arr = sitk.GetArrayFromImage(resampled_img)

    # test_files holds the index of the first file of every pair chosen for the test set
    if i in test_files:
        img_out_dir, lbl_out_dir = ts_img_dir, ts_label_dir
    else:
        img_out_dir, lbl_out_dir = tr_img_dir, tr_label_dir
    convert_2d_image_to_nifti_new(resampled_img_arr, base_dir + img_name, img_out_dir + new_file_pre)
    convert_2d_image_to_nifti_new(resampled_lbl_arr, base_dir + lbl_name, lbl_out_dir + new_file_pre, is_seg=True)


In [14]:
# create dataset.json file
# The dataset name is taken from task_name (without its trailing slash) so that it cannot get out of
# sync with the task folder the json is written into.
task = task_name.rstrip('/')
generate_dataset_json(json_dir, tr_img_dir, ts_img_dir, ('Red', 'Green', 'Blue'), labels={0: 'background', 1: 'cyst', 2: 'locules'}, dataset_name=task)

In [None]:
# overlays segmentations from labels and the predictions from the model. This only works if labels are overlapping and one label is contained within another
label_name = 'label_1'
label_number = 1
Path(overlayed_dir+label_name).mkdir(parents=False, exist_ok=True)

# Where to get image and gold standard from (predictions are read from predicted_dir)
labels_dir = ts_label_dir
images_dir = ts_img_dir

extension = '.nii.gz'
results_files = os.listdir(predicted_dir)
for file in sorted(results_files):
    # only the predicted segmentations are of interest, skip anything else found in the folder
    if not file.endswith('gz'):
        continue
    current_img_name = file.split('.')[0]

    # Get the three channel files (_0000, _0001, _0002) of the image and stack them to reform the rgb image
    channel_arrays = []
    for channel in range(3):
        channel_img = sitk.ReadImage(images_dir + current_img_name + '_%04d' % channel + extension)
        channel_arrays.append(np.squeeze(sitk.GetArrayFromImage(channel_img)))
    img_arr = np.squeeze(np.stack(channel_arrays, axis=2))

    # Get the gold standard and predicted segmentations
    lbl_img = np.squeeze(sitk.GetArrayFromImage(sitk.ReadImage(labels_dir + current_img_name + extension)))
    prd_img = np.squeeze(sitk.GetArrayFromImage(sitk.ReadImage(predicted_dir + current_img_name + extension)))

    # Image at this point still has label 1 as only what isn't label 2, since that's how nnU-Net is trained (overlapping is not allowed). To get the actual label 1, we get the union of label 1 and label 2.
    # Copies are used so that the arrays read from disk are not modified through an alias.
    fixed_lbl = lbl_img.copy()
    fixed_lbl[fixed_lbl>label_number] = label_number
    fixed_prd = prd_img.copy()
    fixed_prd[fixed_prd>label_number] = label_number

    # gets the gold standard pixels
    mask1 = np.ma.masked_where(fixed_lbl != label_number, fixed_lbl)
    # gets the predicted pixels
    mask2 = np.ma.masked_where(fixed_prd != label_number, fixed_prd)
    # gold standard pixels that were not predicted (removes the overlap)
    mask3 = np.ma.masked_where(fixed_prd == label_number, mask1)
    # predicted pixels that are not in the gold standard (removes the overlap)
    mask4 = np.ma.masked_where(fixed_lbl == label_number, mask2)
    # gets the overlap pixels
    mask5 = np.ma.masked_where(fixed_prd != label_number, mask1)

    # One fresh figure per image, closed once it is saved and shown. Drawing on the implicit current
    # figure would pile the overlays of all images on top of each other on backends where
    # plt.show() does not clear it, and would keep every figure in memory.
    fig, ax = plt.subplots()
    ax.imshow(img_arr)
    ax.imshow(mask3, cmap='viridis_r', alpha=0.5)
    ax.imshow(mask4, cmap='autumn', alpha=0.5)
    ax.imshow(mask5, cmap='brg_r', alpha=0.5)

    fig.savefig(overlayed_dir + label_name + '/' + current_img_name + '.png', dpi=300)
    plt.show()
    plt.close(fig)

