# UNet++ Transformations

This notebook implements the basic transformations needed to prepare the dataset so that it can be passed to the UNet++ model.

In [None]:
from typing import List, Callable, Tuple

import numpy as np
import albumentations as A
from sklearn.externals._pilutil import bytescale
from skimage.util import crop

def normalize_01(inp: np.ndarray):
    """Linearly rescale ``inp`` to the value range [0, 1] (no clipping).

    The minimum of ``inp`` maps to 0 and the maximum to 1.

    Args:
        inp: array to rescale.

    Returns:
        A floating-point array with the same shape as ``inp``. A constant
        input (zero value range) yields all zeros instead of NaNs.
    """
    shifted = inp - np.min(inp)
    value_range = np.ptp(inp)
    # Guard against 0 / 0 for constant inputs: ``shifted`` is all zeros then,
    # so dividing by 1 returns zeros with the usual true-division dtype.
    return shifted / (value_range if value_range != 0 else 1)

def normalize(inp: np.ndarray, mean: float, std: float):
    """Standardize ``inp``: subtract ``mean``, then divide by ``std``.

    Args:
        inp: array to standardize.
        mean: value subtracted from every element.
        std: divisor applied after centering.

    Returns:
        ``(inp - mean) / std``.
    """
    centered = inp - mean
    return centered / std

def create_dense_target(tar: np.ndarray):
    """Relabel ``tar`` so its distinct values become consecutive labels 0..K-1.

    Distinct values are ranked in ascending order (as returned by
    ``np.unique``); the i-th smallest value is replaced by ``i``. The result
    has the same shape and dtype as ``tar``.
    """
    dense = np.zeros_like(tar)
    for label, class_value in enumerate(np.unique(tar)):
        dense[tar == class_value] = label
    return dense

def center_crop_to_size(x: np.ndarray, size: Tuple, copy: bool = False) -> np.ndarray:
    """Center-crop ``x`` to the shape ``size``.

    Args:
        x: array to crop.
        size: desired output shape, one entry per axis of ``x`` (a single
            entry is broadcast to every axis, as with NumPy broadcasting).
        copy: if True, return an independent copy; otherwise return a view
            into ``x``.

    Returns:
        The centered sub-array of ``x`` with shape ``size``.

    Raises:
        ValueError: if ``size`` is larger than ``x.shape`` along any axis.

    Odd size differences are supported: the extra element is trimmed from
    the trailing side of that axis.
    """
    # ``int`` instead of the removed ``np.int`` alias (gone since NumPy 1.24).
    margins = (np.asarray(x.shape) - np.asarray(size)).astype(int)
    if np.any(margins < 0):
        raise ValueError(f"cannot center-crop array of shape {x.shape} to larger size {size}")
    before = margins // 2
    after = margins - before  # takes the odd remainder, if any
    slices = tuple(slice(lo, dim - hi) for lo, hi, dim in zip(before, after, x.shape))
    cropped = x[slices]
    # Same semantics as ``skimage.util.crop``: a view by default, or a copy
    # that keeps the memory layout ('K' order) when requested.
    return np.array(cropped, order='K') if copy else cropped

def re_normalize(inp: np.ndarray, low: int = 0, high: int = 255):
    # normalize the data to a certain range
    # default: [0-255]
    inp_out = bytescale(inp, low=low, high=high)
    return inp_out

def random_flip(inp: np.ndarray, tar: np.ndarray, ndim_spatial: int):
    """Randomly flip an input-target pair along the same spatial axes.

    For each of the ``ndim_spatial`` spatial axes, an independent fair coin
    (``np.random.randint(0, 2)``) decides whether that axis is flipped.

    Args:
        inp: input array with one leading axis before the spatial axes
            (its spatial axis ``i`` is array axis ``i + 1``).
        tar: target array whose axes are exactly the spatial axes.
        ndim_spatial: number of spatial axes to consider.

    Returns:
        ``(inp_flipped, tar_flipped)`` — views of the inputs.
    """
    # One scalar draw per spatial axis, in axis order.
    coins = [np.random.randint(low=0, high=2) for _ in range(ndim_spatial)]
    tar_axes = tuple(axis for axis, coin in enumerate(coins) if coin == 1)
    # ``inp`` has an extra leading axis, so its spatial axes are shifted by one.
    inp_axes = tuple(axis + 1 for axis in tar_axes)
    return np.flip(inp, axis=inp_axes), np.flip(tar, axis=tar_axes)


class Repr:
    """Mixin giving a human-readable ``repr``: ``ClassName: {instance attrs}``.

    The string lists the instance ``__dict__``; it is meant for display and is
    not an ``eval``-able expression.
    """

    def __repr__(self):
        name = self.__class__.__name__
        attributes = self.__dict__
        return f'{name}: {attributes}'
    
    
class FunctionWrapperSingle(Repr):
    """Turn ``function`` plus pre-bound arguments into an input-only transform.

    Extra ``*args``/``**kwargs`` are frozen with ``functools.partial`` at
    construction time; at call time only the input array is supplied.
    """

    def __init__(self, function: Callable, *args, **kwargs):
        import functools
        self.function = functools.partial(function, *args, **kwargs)

    def __call__(self, inp: np.ndarray):
        bound = self.function
        return bound(inp)

    
class FunctionWrapperDouble(Repr):
    """Apply ``function`` to the input and/or target of an input-target pair.

    Args:
        function: single-argument callable; extra ``*args``/``**kwargs`` are
            bound to it with ``functools.partial``. Note that positional
            ``*args`` can only be passed if ``input`` and ``target`` are also
            given positionally, so keyword arguments are the practical way to
            bind extra parameters.
        input: if True, ``function`` is applied to the input.
        target: if True, ``function`` is applied to the target.
    """

    def __init__(self, function: Callable, input: bool = True, target: bool = False, *args, **kwargs):
        from functools import partial
        self.function = partial(function, *args, **kwargs)
        self.input = input
        self.target = target

    def __call__(self, inp: np.ndarray, tar: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(inp, tar)`` with ``function`` applied where enabled."""
        if self.input:
            inp = self.function(inp)
        if self.target:
            tar = self.function(tar)
        return inp, tar
    

class Compose:
    """Base class that chains several transforms together.

    Subclasses define ``__call__`` to apply ``self.transforms`` in order,
    either to an input alone or to an input-target pair.

    Args:
        transforms: transforms to apply, in order.
    """

    def __init__(self, transforms: List[Callable]):
        self.transforms = transforms

    def __repr__(self):
        # Must read the instance's own list; a bare ``transforms`` would look
        # up a (possibly unrelated or undefined) module-level name.
        return str(list(self.transforms))
    

class ComposeDouble(Compose):
    """Compose transforms that operate on an (input, target) pair.

    Each transform is called as ``t(inp, target)`` and must return the new
    ``(inp, target)`` pair, which is fed to the next transform.
    """

    def __call__(self, inp: np.ndarray, target: np.ndarray):
        for t in self.transforms:
            inp, target = t(inp, target)
        return inp, target
    

class ComposeSingle(Compose):
    """Compose transforms that operate on the input only.

    Each transform receives the previous transform's output; the final
    output is returned.
    """

    def __call__(self, inp: np.ndarray):
        result = inp
        for transform in self.transforms:
            result = transform(result)
        return result
    
class AlbuSeg2d(Repr):
    """Adapter that runs a 2D albumentations augmentation in this pipeline.

    The wrapped augmentation receives the input as ``image`` and the target
    as ``mask`` in a single call, so both get the same spatial transform.
    See https://github.com/albu/albumentations for more information.

    Expected input: (C, spatial_dims).
    Expected target: (spatial_dims) -> no channel dimension.
    NOTE(review): albumentations generally treats ``image`` as channels-last
    (H, W[, C]); confirm this matches the layout actually passed in.
    """

    def __init__(self, albumentation: Callable):
        self.albumentation = albumentation

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        augmented = self.albumentation(image=inp, mask=tar)
        return augmented['image'], augmented['mask']
    
class AlbuSeg3d(Repr):
    """Apply one 2D albumentations augmentation identically to every slice of a 3D stack.

    The augmentation is wrapped in ``A.ReplayCompose``. Its parameters are
    sampled once on the first input slice and then replayed on each
    (input slice, target slice) pair along the first axis.
    See https://github.com/albu/albumentations for more information.

    Expected input: (spatial_dims) -> no channel dimension.
    Expected target: (spatial_dims) -> no channel dimension.
    Results are written back into copies of the original arrays, so the
    augmentation must preserve each slice's shape.
    """

    def __init__(self, albumentation: Callable):
        self.albumentation = A.ReplayCompose([albumentation])

    def __call__(self, inp: np.ndarray, tar: np.ndarray):
        # The target is cast to uint8 before augmenting.
        tar = tar.astype(np.uint8)

        out_inp = np.copy(inp)
        out_tar = np.copy(tar)

        # Sample the augmentation parameters once, using the first input slice.
        replay = self.albumentation(image=inp[0])['replay']

        # TODO: consider cases with rgb 3d or multimodal 3d input.
        # zip stops at the shorter stack, so this assumes input and target
        # have the same number of slices.
        for i, (img_slice, mask_slice) in enumerate(zip(inp, tar)):
            replayed = A.ReplayCompose.replay(replay, image=img_slice, mask=mask_slice)
            out_inp[i] = replayed['image']
            out_tar[i] = replayed['mask']

        return out_inp, out_tar

The code above defines the functions and classes used to compose transforms. These transforms preprocess the input images (and their targets) before they are passed to the model. The next cell tests the functionality of these classes and functions.

In [None]:
import numpy as np
from skimage.transform import resize

# Smoke test: run a ComposeDouble pipeline on synthetic data and print the
# shapes, dtypes, value ranges and class labels before and after.
# x: random uint8 image of shape (128, 128, 3); its last axis is moved to the
# front by the np.moveaxis step below.
x = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
# y: label map with values 10..14; create_dense_target below remaps its
# distinct values to consecutive labels starting at 0.
y = np.random.randint(10, 15, size=(128, 128), dtype=np.uint8)

transforms = ComposeDouble([
    # Input only: resize to (64, 64, 3).
    FunctionWrapperDouble(resize, input=True, target=False, output_shape=(64, 64, 3)),
    # Target only: resize with order=0 and no anti-aliasing, preserving the
    # original value range, so label values are not interpolated.
    FunctionWrapperDouble(resize, input=False, target=True, output_shape=(64, 64), order=0, anti_aliasing=False, preserve_range=True),
    # Target only: relabel classes to 0..K-1.
    FunctionWrapperDouble(create_dense_target, input=False, target=True),
    # Input only: move the last axis to the front, (64, 64, 3) -> (3, 64, 64).
    FunctionWrapperDouble(np.moveaxis, input=True, target=False, source=-1, destination=0),
    # Default flags (input=True, target=False): rescale the input to [0, 1].
    FunctionWrapperDouble(normalize_01)
])

x_t, y_t = transforms(x, y)

print(f'x = shape: {x.shape}; type: {x.dtype}')
print(f'x = min: {x.min()}; max: {x.max()}')
print(f'x_t = shape: {x_t.shape}; type: {x_t.dtype}')
print(f'x_t = min: {x_t.min()}; max: {x_t.max()}')

print(f'y = shape: {y.shape}; class: {np.unique(y)}')
print(f'y_t = shape: {y_t.shape}; class: {np.unique(y_t)}')