# UniverSeg

- From the following paper: https://universeg.csail.mit.edu/
- From the git repo: https://github.com/JJGO/UniverSeg
- From the example Google Colab: https://colab.research.google.com/drive/1TiNAgCehFdyHMJsS90V9ygUw0rLXdW0r?usp=sharing

The following passage, quoted from the tutorial, sets up the context:

"""Given a new segmentation task (e.g. new biomedical domain, new image type, new region of interest, etc), most existing strategies involve training or fine-tuning a segmentation model (often a UNet-like CNN) that takes as input an image $x$ and outputs the segmentation map $y$.

This process works well in machine-learning labs, but is challenging in many applied settings, such as for scientists or clinical researchers who drive important scientific questions, but often lack the machine-learning expertise and computational resources necessary.

UniverSeg enables users to tackle a new segmentation task without the need to train or fine-tune a model, removing the ML experience requirements and computational burden. The key idea is to have a *single* global model which adapts to a new segmentation task at inference. Given a new segmentation task, defined by a few example image-segmentation pairs $\mathcal{T} = \{x_n, y_n\}$, UniverSeg segments a new image $x$ by taking as input both $x$ and the task examples $\mathcal{T}$ and outputs the segmentation map $f(x, \mathcal{T}) = y$."""
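In the UniverSeg repository README, the model is called as `model(image, support_images, support_labels)`: the query image has shape (B, 1, H, W), the support images and labels have shape (B, S, 1, H, W), and the prediction comes back as (B, 1, H, W). The sketch below packs a query $x$ and a task $\mathcal{T} = \{x_n, y_n\}$ of S example pairs into that layout with NumPy only. The helper name `make_task`, the batch size of 1 and the random example data are assumptions for illustration.

```python
import numpy as np


def make_task(query, pairs):
    """Hypothetical helper: pack a query image x and a task T = {(x_n, y_n)} into the
    (B, 1, H, W) / (B, S, 1, H, W) layout described in the UniverSeg README, with B = 1."""
    height, width = query.shape
    for x_n, y_n in pairs:
        # Every example pair must share the query's spatial size
        assert x_n.shape == y_n.shape == (height, width), "all slices must be H x W"

    support_images = np.stack([x_n for x_n, _ in pairs])[None, :, None]  # (1, S, 1, H, W)
    support_labels = np.stack([y_n for _, y_n in pairs])[None, :, None]  # (1, S, 1, H, W)
    image = query[None, None]                                            # (1, 1, H, W)
    return (image.astype(np.float32),
            support_images.astype(np.float32),
            support_labels.astype(np.float32))


# Example: a task defined by S = 4 random image/segmentation pairs at 128x128
rng = np.random.default_rng(0)
pairs = [(rng.random((128, 128)), (rng.random((128, 128)) > 0.5).astype(np.float32))
         for _ in range(4)]
image, support_images, support_labels = make_task(rng.random((128, 128)), pairs)
print(image.shape, support_images.shape, support_labels.shape)
# (1, 1, 128, 128) (1, 4, 1, 128, 128) (1, 4, 1, 128, 128)
```

With PyTorch installed, these arrays can be converted with `torch.from_numpy` before being passed to the model; the same support set is reused for every new query image, which is what lets a single model adapt to the task without retraining.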

## Setup

In [1]:
import os 
import subprocess

def setup_data_vars(mine = True, overwrite = True):
    """
    From within any directory under `radiotherapy`, backtrack to the data folder and load
    the environment variables for the current session.

    When `mine` is True, the `data_vars.sh` script in the data folder is executed; it is
    assumed to echo one `NAME: value` pair per line. When `mine` is False, the hard-coded
    `export NAME=value` lines in `/vol/biomedic3/bglocker/nnUNet/exports` are read instead.

    My own model has not finished training yet, so to get this task done I will use Ben's
    pretrained nnUNet for now and switch to my own once it has finished training. In the
    meantime, `mine` selects between Ben's pretrained model and my own.
    """

    # Only (re)load the variables if they are missing or an overwrite was requested;
    # otherwise assume a custom set has already been provided.
    if os.environ.get('nnUNet_raw') is None or overwrite:
        if mine:
            # Walk up to the `radiotherapy` directory and locate its `data` folder
            cwd = os.getcwd().split('/')
            data_dir = os.path.join('/'.join(cwd[:cwd.index('radiotherapy') + 1]), 'data')

            # Assuming the data_vars.sh script echoes the environment variables
            script = os.path.join(data_dir, 'data_vars.sh')
            output = subprocess.run([script], capture_output=True)

            assert len(output.stdout) != 0, (f"Please check {script} and make sure it "
                                             "echoes the environment variables.")

            output = output.stdout.decode('utf-8')
        else:
            data_dir = '/vol/biomedic3/bglocker/nnUNet'

            # Assuming this file won't change: it contains hard-coded `export` lines
            script = os.path.join(data_dir, 'exports')

            with open(script, 'r') as file:
                output = file.read()

        for line in output.split('\n'):
            if line != '':
                if mine:
                    # `NAME: value`; split once so values containing ': ' survive
                    name, value = line.split(': ', 1)
                else:
                    # `export NAME=value`; split once, then drop the `export` keyword
                    name, value = line.split('=', 1)
                    name = name.split(' ')[1]
                os.environ[name] = value

    assert os.environ.get('nnUNet_raw') is not None, (
        "Environment variables not set. Please run the data_vars.sh script in the data folder.")
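The two branches of `setup_data_vars` assume two different line formats: `NAME: value` echoed by `data_vars.sh`, and `export NAME=value` in the `exports` file. A minimal, side-effect-free sketch of that parsing makes both formats easy to check before anything touches `os.environ`. The function name `parse_env_lines` is hypothetical, and unlike the loop above it also strips surrounding whitespace (for example a stray `\r`) from names and values.

```python
def parse_env_lines(text: str, mine: bool) -> dict[str, str]:
    """Hypothetical helper: parse `NAME: value` lines (mine=True) or
    `export NAME=value` lines (mine=False) into a dictionary."""
    variables = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            continue  # skip blank lines
        if mine:
            # Split on the first ': ' only, so the value may itself contain ': '
            name, value = line.split(': ', 1)
        else:
            assignment = line.split(' ', 1)[1]  # drop the leading `export`
            # Split on the first '=' only, so the value may itself contain '='
            name, value = assignment.split('=', 1)
        variables[name.strip()] = value.strip()
    return variables


print(parse_env_lines("nnUNet_raw: /vol/bitbucket/az620/radiotherapy/data/nnUNet_raw\n"
                      "data_trainingImages: imagesTr\n", mine=True))
print(parse_env_lines("export data_trainingLabels=labelsTr\n", mine=False))
```

Keeping the parser pure means the format assumptions can be tested on a string, while `setup_data_vars` only has to copy the resulting dictionary into `os.environ`.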

In [80]:
from tqdm import tqdm
import SimpleITK as sitk
import numpy as np
import torchio as tio

def universeg_preprocess(path_to_images: str, normalize: bool, overwrite = True):
    """Given an input path to images, preprocess the images according to the
    specification supplied by the paper: resize each image in-plane to 128x128 (WxH)
    and, optionally, rescale the CT intensities to the range [0, 1]. The results are
    saved to a new directory with a sensible name. Output files are named after their
    inputs, so if the output directory already contains files, any input whose
    processed counterpart already exists is skipped (unless `overwrite` is set).

    Args:
        path_to_images (str): A path to the directory containing the images that need to
        be preprocessed. The output directory is derived from
        `os.environ.get('nnUNet_raw')`.

        normalize (bool): A flag to indicate whether the input pixel values should be
        rescaled to the [0, 1] range (using the 0.5th and 99.5th percentiles).

        overwrite (bool): A flag to indicate whether existing outputs should be
        overwritten.

    Returns:
        str: The output directory containing the preprocessed images.
    """

    # Check that the input directory exists
    assert os.path.exists(path_to_images), f"Path to images: {path_to_images} does not exist."
    assert os.environ.get('nnUNet_raw') is not None, "Environment variables not set. \
Please run the data_vars.sh script in the data folder."

    def resize_image(original_CT, width_new = 128, height_new = 128):
        # Resample the in-plane (x, y) grid to width_new x height_new while keeping the
        # number of slices along z. Adapted from:
        # https://github.com/InsightSoftwareConsortium/SimpleITK-Notebooks/blob/master/Python/70_Data_Augmentation.ipynb
        # https://stackoverflow.com/questions/48065117/simpleitk-resize-images

        reference_dimension = original_CT.GetDimension()
        reference_origin = original_CT.GetOrigin()
        reference_direction = original_CT.GetDirection()
        reference_size = original_CT.GetSize()
        reference_spacing = original_CT.GetSpacing()

        # Physical extent of the original image along each axis; the resampled image
        # must cover the same extent
        reference_physical_size = [(sz - 1) * spc for sz, spc in zip(reference_size, reference_spacing)]

        # New grid: 128x128 in-plane with the same number of slices. The spacing is
        # stretched so that the new grid spans the same physical extent.
        reference_size = [width_new, height_new, reference_size[2]]
        reference_spacing = [phys_sz / (sz - 1) for sz, phys_sz in zip(reference_size, reference_physical_size)]

        # Create a blank image with the desired size after transformation
        reference_image = sitk.Image(reference_size, original_CT.GetPixelIDValue())
        reference_image.SetOrigin(reference_origin)
        reference_image.SetSpacing(reference_spacing)
        reference_image.SetDirection(reference_direction)

        # Affine transform that matches the direction matrix of the original image
        reference_center = np.array(reference_image.TransformContinuousIndexToPhysicalPoint(np.array(reference_image.GetSize()) / 2.0))

        transform = sitk.AffineTransform(reference_dimension)
        transform.SetMatrix(original_CT.GetDirection())
        transform.SetTranslation(np.array(original_CT.GetOrigin()) - reference_origin)

        # Translation that moves the center of the original image onto the center of the
        # new image; it is added to the composite transform after the affine one
        centering_transform = sitk.TranslationTransform(reference_dimension)
        img_center = np.array(original_CT.TransformContinuousIndexToPhysicalPoint(np.array(original_CT.GetSize()) / 2.0))
        centering_transform.SetOffset(np.array(transform.GetInverse().TransformPoint(img_center) - reference_center))
        centered_transform = sitk.CompositeTransform(transform)
        centered_transform.AddTransform(centering_transform)

        # Linear interpolation for intensities; nearest neighbour for label maps (assumed
        # to be the inputs that are not normalized) so that label values stay integral
        interpolator = sitk.sitkLinear if normalize else sitk.sitkNearestNeighbor

        return sitk.Resample(original_CT, reference_image, centered_transform, interpolator, 0.0)

    def normalize_image(image):
        # https://www.imaios.com/en/resources/blog/ct-images-normalization-zero-centering-and-standardization
        # Convert the SimpleITK image to a TorchIO image. GetArrayFromImage returns the
        # axes in (z, y, x) order while TorchIO expects (C, x, y, z), so the axes must be
        # transposed; a plain reshape would scramble the voxels.
        array = sitk.GetArrayFromImage(image).astype(np.float32)
        torchio_image = tio.ScalarImage(tensor = np.ascontiguousarray(array.transpose(2, 1, 0))[np.newaxis])

        # Map the 0.5th and 99.5th intensity percentiles onto the range [0, 1]
        normalization_transform = tio.transforms.RescaleIntensity(out_min_max = (0, 1), percentiles = (0.5, 99.5))
        normalized_image = normalization_transform(torchio_image)

        # Convert the TorchIO image back to a SimpleITK image
        normalized_sitk_image = normalized_image.as_sitk()

        # Copy the origin, spacing and direction from the original image
        normalized_sitk_image.CopyInformation(image)

        return normalized_sitk_image

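    # Derive the output directory, e.g.
    # <parent of nnUNet_raw>/UniverSegPreprocessed/Dataset001_Anorectum/imagesTr
    path_to_images = os.path.normpath(path_to_images)
    class_name = os.path.basename(os.path.dirname(path_to_images))
    image_type = os.path.basename(path_to_images)
    output_dir = os.path.join(os.path.dirname(os.path.normpath(os.environ.get('nnUNet_raw'))), 'UniverSegPreprocessed', class_name, image_type)
    os.makedirs(output_dir, exist_ok=True)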

    # Arrange files to pre-process
    input_files = sorted([file for file in os.listdir(path_to_images) if file.endswith('.nii.gz')])
    output_files = sorted([file for file in os.listdir(output_dir) if file.endswith('.nii.gz')])

    prefix = 'universeg_'

    to_process = sorted(set(input_files)- set([file[len(prefix):] for file in output_files])) if not overwrite else input_files

    for file in tqdm(to_process):
        input_file = os.path.join(path_to_images, file)
        output_file = os.path.join(output_dir, prefix + file)

        # Load data
        data = sitk.ReadImage(input_file)
        # Resize data
        data = resize_image(data)
        # Normalize data
        data = data if not normalize else normalize_image(data)

        # Save data
        sitk.WriteImage(data, output_file)          

    return output_dir
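The normalization step above delegates to TorchIO's `RescaleIntensity` with `percentiles=(0.5, 99.5)`, which maps those intensity percentiles onto `out_min_max=(0, 1)`. To make the arithmetic visible, here is a NumPy sketch of the same percentile-based min-max idea. It is a swapped-in approximation rather than TorchIO's code: it explicitly clips values outside the percentile range so the output stays in [0, 1], and the synthetic Hounsfield-unit volume is made up for illustration.

```python
import numpy as np


def percentile_rescale(volume, low=0.5, high=99.5):
    """Sketch: clip intensities to the [low, high] percentiles, then scale to [0, 1]."""
    lo, hi = np.percentile(volume, [low, high])
    clipped = np.clip(np.asarray(volume, dtype=np.float64), lo, hi)
    if hi == lo:
        return np.zeros_like(clipped)  # constant image: nothing to stretch
    return (clipped - lo) / (hi - lo)


# A synthetic CT-like volume with one extreme outlier voxel
rng = np.random.default_rng(0)
volume = rng.normal(loc=40.0, scale=300.0, size=(8, 128, 128))
volume[0, 0, 0] = 3000.0
normalized = percentile_rescale(volume)
print(normalized.min(), normalized.max())  # 0.0 1.0
```

Using percentiles instead of the raw minimum and maximum means a handful of extreme voxels, such as the outlier above, cannot compress the intensities of the tissue of interest into a narrow band.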

## Running Inference

In [None]:
def universeg_run_inference(path_to_image, path_to_labels, path_to_output):
    """Performs the forward pass of the model on the input image according to the
    walkthrough at https://github.com/JJGO/UniverSeg

    Args:
        path_to_image (str): path to image

        path_to_labels (str): path to labels

        path_to_output (str): path to save predictions in
    """

    assert os.path.exists(path_to_image), f"Path to image: {path_to_image} does not exist."
    assert os.path.exists(path_to_labels), f"Path to labels: {path_to_labels} does not exist."

    os.makedirs(path_to_output, exist_ok=True)

    # Load the data

    # Separate the z slices of the image into individual 2D inputs for the model to
    # predict. Do the same with the ground-truth segmentations, and check that each
    # segmentation slice is aligned with its image slice and has been transformed in
    # the same way.

    # Run the inference

    # Save the output to the path_to_output
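The comments above describe splitting each volume along z into 2D slices and checking that images and labels line up. A minimal sketch, assuming both volumes have already been loaded as NumPy arrays in SimpleITK's (z, y, x) order (which is what `sitk.GetArrayFromImage` returns) and that any non-zero label voxel is foreground. The helper name `volume_to_slices` is hypothetical; the returned foreground indices are one possible way to choose which slices go into a support set.

```python
import numpy as np


def volume_to_slices(image, label):
    """Hypothetical helper: split aligned (Z, H, W) image and label volumes into
    per-slice arrays of shape (Z, 1, H, W) and report which slices contain foreground."""
    if image.ndim != 3:
        raise ValueError(f"expected a (Z, H, W) volume, got shape {image.shape}")
    # Alignment check: each image slice must have a label slice of the same size
    if image.shape != label.shape:
        raise ValueError(f"image {image.shape} and label {label.shape} are not aligned")

    image_slices = image[:, np.newaxis].astype(np.float32)        # (Z, 1, H, W)
    label_slices = (label[:, np.newaxis] > 0).astype(np.float32)  # non-zero voxels -> 1.0
    # Slices in which the structure is present at all
    has_foreground = label_slices.reshape(len(label_slices), -1).any(axis=1)
    return image_slices, label_slices, np.flatnonzero(has_foreground)


# Example: a 5-slice volume whose label is non-empty only on slices 1 and 3
image = np.random.default_rng(0).random((5, 128, 128))
label = np.zeros((5, 128, 128), dtype=np.uint8)
label[1, 40:60, 40:60] = 1
label[3, 70:90, 30:50] = 1
image_slices, label_slices, foreground = volume_to_slices(image, label)
print(image_slices.shape, foreground)  # (5, 1, 128, 128) [1 3]
```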

    

## Main

In [81]:
if __name__ == "__main__":
    setup_data_vars()

    print(f'Raw directory {os.environ.get("nnUNet_raw")}')
    print(f'Data Training images {os.environ.get("data_trainingImages")}')
    print(f'Data Training labels {os.environ.get("data_trainingLabels")}')

    classes = [os.environ.get('data_Anorectum'), 
        os.environ.get('data_Bladder'), 
        os.environ.get('data_CTVn'), 
        os.environ.get('data_CTVp'), 
        os.environ.get('data_Parametrium'), 
        os.environ.get('data_Uterus'), 
        os.environ.get('data_Vagina')]

    # Suppose we try to predict labels for the Anorectum
    class_id = 0

    image_path = os.path.join(os.environ.get("nnUNet_raw"), classes[class_id], os.environ.get("data_trainingImages"))
    label_path = os.path.join(os.environ.get("nnUNet_raw"), classes[class_id], os.environ.get("data_trainingLabels"))

    print(f'Image_Path: {image_path} {"(Warning: os cannot find this path)" if not os.path.isdir(image_path) else ""}')
    print(f'Label_Path: {label_path} {"(Warning: os cannot find this path)" if not os.path.isdir(label_path) else ""}')

    # Ensure that image intensities are rescaled to the range [0, 1] (labels are left
    # as-is) and that every slice has spatial dimensions (H, W) = (128, 128)

    universeg_image_path = universeg_preprocess(image_path, normalize=True, overwrite=False)
    universeg_label_path = universeg_preprocess(label_path, normalize=False, overwrite=False)

    assert os.path.isdir(universeg_image_path), f'Directory `{universeg_image_path}` doesn\'t exist'
    assert os.path.isdir(universeg_label_path), f'Directory `{universeg_label_path}` doesn\'t exist'

    # Attempt to perform inference on 2D slices of the preprocessed image



Raw directory /vol/bitbucket/az620/radiotherapy/data/nnUNet_raw
Data Training images imagesTr
Data Training labels labelsTr
Image_Path: /vol/bitbucket/az620/radiotherapy/data/nnUNet_raw/Dataset001_Anorectum/imagesTr 
Label_Path: /vol/bitbucket/az620/radiotherapy/data/nnUNet_raw/Dataset001_Anorectum/labelsTr 



100%|██████████| 83/83 [01:28<00:00,  1.06s/it]
0it [00:00, ?it/s]
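The remaining step in `__main__` is to run the model slice by slice and restack the predictions into a volume. The sketch below keeps the model abstract: `model` is any callable with the `(image, support_images, support_labels)` signature from the UniverSeg README, and it assumes, as the example Colab linked at the top appears to do, that the output is logits, which are passed through a sigmoid and thresholded at 0.5. The names `segment_volume` and `dummy_model` are hypothetical stand-ins so the example runs without PyTorch or the pretrained weights.

```python
import numpy as np


def segment_volume(model, image_slices, support_images, support_labels, threshold=0.5):
    """Hypothetical sketch: segment each (1, H, W) query slice with the same support set
    and stack the binary predictions back into a (Z, H, W) volume."""
    predictions = []
    for query in image_slices:                                  # query: (1, H, W)
        logits = model(query[None], support_images[None], support_labels[None])[0]  # (1, H, W)
        probabilities = 1.0 / (1.0 + np.exp(-logits))           # sigmoid
        predictions.append((probabilities > threshold)[0])      # (H, W)
    return np.stack(predictions).astype(np.uint8)


def dummy_model(image, support_images, support_labels):
    """Stand-in with the README's call signature that 'segments' bright pixels."""
    return (image - 0.5) * 20.0  # logits are positive exactly where intensity > 0.5


image_slices = np.random.default_rng(0).random((3, 1, 128, 128)).astype(np.float32)
support_images = np.zeros((4, 1, 128, 128), dtype=np.float32)  # (S, 1, H, W)
support_labels = np.zeros((4, 1, 128, 128), dtype=np.float32)
volume = segment_volume(dummy_model, image_slices, support_images, support_labels)
print(volume.shape)  # (3, 128, 128)
```

With the real model, the NumPy arrays would be converted with `torch.from_numpy` and the sigmoid applied with `torch.sigmoid`; the per-slice loop and the final stacking stay the same, and the stacked volume can then be compared against the ground-truth labels slice by slice.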
