In [5]:
from batchgenerators.transforms.local_transforms import (
    BrightnessGradientAdditiveTransform,
    LocalGammaTransform,
    LocalSmoothingTransform,
    LocalContrastTransform
)
from batchgenerators.transforms.abstract_transforms import (
    Compose,
    AbstractTransform
)

import os, sys
sys.path.append('../../')

from data_utils import *

In [3]:
# nnU-Net-style 2D training augmentation: one spatial transform followed by
# intensity transforms and random mirroring.
# The SimulateLowResolution / GaussianNoise transforms and the I/O tail
# (RemoveLabel / Rename / NumpyToTensor) are not part of this list;
# `original_transforms` (defined below) lists those same five classes.
# NOTE(review): unlike `nnUNet_val_augmentations`, this list does not end
# with the I/O transforms -- confirm the consumer appends them.
nnUNet_train_augmentations = [
    # Rotation by angle_x in [-pi, pi] (angle_y / angle_z ranges are zero),
    # scaling in [0.7, 1.4] with one factor for all axes, no elastic
    # deformation, centre patch of 256x256 (random_crop=False).
    # Out-of-bounds segmentation pixels are filled with -1.
    SpatialTransform(
        independent_scale_for_each_axis=False,
        p_rot_per_sample=0.2,
        p_scale_per_sample=0.2,
        p_el_per_sample=0.2,
        data_key='data',
        label_key='seg',
        patch_size=np.array([256, 256]),
        patch_center_dist_from_border=None,
        do_elastic_deform=False,
        alpha=(0.0, 200.0),
        sigma=(9.0, 13.0),
        do_rotation=True,
        angle_x=(-np.pi, np.pi),
        angle_y=(-0.0, 0.0),
        angle_z=(-0.0, 0.0),
        do_scale=True,
        scale=(0.7, 1.4),
        border_mode_data='constant',
        border_cval_data=0,
        order_data=3,
        border_mode_seg='constant',
        border_cval_seg=-1,
        order_seg=1,
        random_crop=False,
        p_rot_per_axis=1,
        p_independent_scale_per_axis=1,
    ),
    GaussianBlurTransform(
        p_per_sample=0.2,
        different_sigma_per_channel=True,
        p_per_channel=0.5,
        data_key='data',
        blur_sigma=(0.5, 1.0),
        different_sigma_per_axis=False,
        p_isotropic=0,
    ),
    BrightnessMultiplicativeTransform(
        p_per_sample=0.15,
        data_key='data',
        multiplier_range=(0.75, 1.25),
        per_channel=True,
    ),
    ContrastAugmentationTransform(
        p_per_sample=0.15,
        data_key='data',
        contrast_range=(0.75, 1.25),
        preserve_range=True,
        per_channel=True,
        p_per_channel=1,
    ),
    # Two gamma transforms: one on the inverted image, one on the original.
    GammaTransform(
        p_per_sample=0.1,
        retain_stats=True,
        per_channel=True,
        data_key='data',
        gamma_range=(0.7, 1.5),
        invert_image=True,
    ),
    GammaTransform(
        p_per_sample=0.3,
        retain_stats=True,
        per_channel=True,
        data_key='data',
        gamma_range=(0.7, 1.5),
        invert_image=False,
    ),
    MirrorTransform(
        p_per_sample=1,
        data_key='data',
        label_key='seg',
        axes=(0, 1),
    ),
]



# Validation pipeline: no augmentation, only label clean-up, renaming and
# NumpyToTensor conversion of 'data' / 'target' (cast to float).
nnUNet_val_augmentations = [
    # -1 is presumably the out-of-bounds fill value used for segmentations
    # (border_cval_seg=-1 in the training SpatialTransform); replace it by 0.
    RemoveLabelTransform(remove_label=-1, replace_with=0,
                         input_key='seg', output_key='seg'),
    # Move the labels from the 'seg' key to 'target', dropping 'seg'.
    RenameTransform(in_key='seg', out_key='target', delete_old=True),
    NumpyToTensor(cast_to='float', keys=['data', 'target']),
]


# Transform *classes* (not instances) that are not part of
# `nnUNet_train_augmentations`. Not referenced elsewhere in the visible
# notebook -- NOTE(review): confirm whether it is still needed.
original_transforms = (SimulateLowResolutionTransform, GaussianNoiseTransform,
                       RemoveLabelTransform, RenameTransform, NumpyToTensor)
    
# Local (spatially varying) intensity augmentations. The same `scale` value
# is passed to every transform; each fires with p_per_sample=0.2 and, when
# it fires, touches every channel (p_per_channel=1).
# NOTE(review): other cells may reference the module-level name `scale`.
scale = 200.0
local_transforms = [
    BrightnessGradientAdditiveTransform(p_per_channel=1, p_per_sample=0.2, max_strength=4, scale=scale),
    LocalGammaTransform(p_per_channel=1, p_per_sample=0.2, gamma=(2, 5), scale=scale),
    LocalSmoothingTransform(p_per_channel=1, p_per_sample=0.2, smoothing_strength=(0.5, 1), scale=scale),
    LocalContrastTransform(p_per_channel=1, p_per_sample=0.2, new_contrast=(1, 3), scale=scale),
]

In [15]:
class Transforms:
    """Registry of named batchgenerators augmentation pipelines.

    All pipelines are built once in ``__init__`` and stored in the
    ``self.transforms`` dict (pipeline name -> list of transforms):

    * ``'io_transforms'``: ``RemoveLabelTransform`` (label -1 -> 0 in
      ``'seg'``), ``RenameTransform`` (``'seg'`` -> ``'target'``) and
      ``NumpyToTensor`` on ``'data'`` / ``'target'`` (cast to float).
    * ``'global_nonspatial_transforms'``: simulated low resolution and
      Gaussian noise, then the io transforms.
    * ``'global_transforms'``: nnU-Net-style spatial and intensity
      augmentation (rotation/scaling, blur, brightness, contrast, gamma,
      mirroring), then the io transforms.
    * ``'local_transforms'``: the global non-spatial transforms, then the
      local brightness-gradient / gamma / smoothing / contrast transforms,
      then the io transforms.

    The io transform instances (and the non-spatial ones) are shared
    between the pipelines that contain them.

    Args:
        local_scale: ``scale`` argument given to each local transform.
            Defaults to ``200.``, the value the notebook assigns to its
            module-level ``scale``. The class itself reads no module-level
            configuration, so it works on a fresh kernel.
    """

    def __init__(self, local_scale: float = 200.):
        self.local_scale = local_scale

        io_transforms = self._io_transforms()
        global_nonspatial_transforms = self._global_nonspatial_transforms()

        # Key insertion order: io, global non-spatial, global, local.
        self.transforms = {
            'io_transforms': io_transforms,
            'global_nonspatial_transforms':
                global_nonspatial_transforms + io_transforms,
            'global_transforms': self._global_transforms() + io_transforms,
            'local_transforms': (
                global_nonspatial_transforms
                + self._local_transforms(local_scale)
                + io_transforms
            ),
        }

    @staticmethod
    def _io_transforms() -> list:
        """Label clean-up, renaming and NumpyToTensor; appended to every pipeline."""
        return [
            # -1 is the fill value SpatialTransform uses for out-of-bounds
            # segmentation pixels (border_cval_seg=-1 below); replace it by 0.
            RemoveLabelTransform(
                output_key='seg',
                input_key='seg',
                replace_with=0,
                remove_label=-1,
            ),
            RenameTransform(
                delete_old=True,
                out_key='target',
                in_key='seg',
            ),
            NumpyToTensor(
                keys=['data', 'target'],
                cast_to='float',
            ),
        ]

    @staticmethod
    def _global_nonspatial_transforms() -> list:
        """Whole-image degradations that do not move pixels: low resolution and noise."""
        return [
            SimulateLowResolutionTransform(
                order_upsample=3,
                order_downsample=0,
                channels=None,
                per_channel=True,
                p_per_channel=0.5,
                p_per_sample=0.25,
                data_key='data',
                zoom_range=(0.5, 1),
                ignore_axes=None,
            ),
            GaussianNoiseTransform(
                p_per_sample=0.1,
                data_key='data',
                noise_variance=(0, 0.1),
                p_per_channel=1,
                per_channel=False,
            ),
        ]

    @staticmethod
    def _global_transforms() -> list:
        """nnU-Net-style spatial + intensity augmentation for 256x256 patches."""
        return [
            # Rotation by angle_x in [-pi, pi] (angle_y / angle_z ranges are
            # zero), scaling in [0.7, 1.4], no elastic deformation.
            SpatialTransform(
                independent_scale_for_each_axis=False,
                p_rot_per_sample=0.2,
                p_scale_per_sample=0.2,
                p_el_per_sample=0.2,
                data_key='data',
                label_key='seg',
                patch_size=np.array([256, 256]),
                patch_center_dist_from_border=None,
                do_elastic_deform=False,
                alpha=(0.0, 200.0),
                sigma=(9.0, 13.0),
                do_rotation=True,
                angle_x=(-np.pi, np.pi),
                angle_y=(-0.0, 0.0),
                angle_z=(-0.0, 0.0),
                do_scale=True,
                scale=(0.7, 1.4),
                border_mode_data='constant',
                border_cval_data=0,
                order_data=3,
                border_mode_seg='constant',
                border_cval_seg=-1,
                order_seg=1,
                random_crop=False,
                p_rot_per_axis=1,
                p_independent_scale_per_axis=1,
            ),
            GaussianBlurTransform(
                p_per_sample=0.2,
                different_sigma_per_channel=True,
                p_per_channel=0.5,
                data_key='data',
                blur_sigma=(0.5, 1.0),
                different_sigma_per_axis=False,
                p_isotropic=0,
            ),
            BrightnessMultiplicativeTransform(
                p_per_sample=0.15,
                data_key='data',
                multiplier_range=(0.75, 1.25),
                per_channel=True,
            ),
            ContrastAugmentationTransform(
                p_per_sample=0.15,
                data_key='data',
                contrast_range=(0.75, 1.25),
                preserve_range=True,
                per_channel=True,
                p_per_channel=1,
            ),
            # Gamma on the inverted image, then gamma on the original image.
            GammaTransform(
                p_per_sample=0.1,
                retain_stats=True,
                per_channel=True,
                data_key='data',
                gamma_range=(0.7, 1.5),
                invert_image=True,
            ),
            GammaTransform(
                p_per_sample=0.3,
                retain_stats=True,
                per_channel=True,
                data_key='data',
                gamma_range=(0.7, 1.5),
                invert_image=False,
            ),
            MirrorTransform(
                p_per_sample=1,
                data_key='data',
                label_key='seg',
                axes=(0, 1),
            ),
        ]

    @staticmethod
    def _local_transforms(scale: float) -> list:
        """Spatially varying intensity transforms, all sharing the given ``scale``."""
        return [
            BrightnessGradientAdditiveTransform(
                scale=scale,
                max_strength=4,
                p_per_sample=0.2,
                p_per_channel=1,
            ),
            LocalGammaTransform(
                scale=scale,
                gamma=(2, 5),
                p_per_sample=0.2,
                p_per_channel=1,
            ),
            LocalSmoothingTransform(
                scale=scale,
                smoothing_strength=(0.5, 1),
                p_per_sample=0.2,
                p_per_channel=1,
            ),
            LocalContrastTransform(
                scale=scale,
                new_contrast=(1, 3),
                p_per_sample=0.2,
                p_per_channel=1,
            ),
        ]

    def get_transforms(
        self,
        arg: str
    ) -> "AbstractTransform":
        """Return the named pipeline wrapped in a batchgenerators ``Compose``.

        Args:
            arg: pipeline name; one of the keys of ``self.transforms``.

        Returns:
            ``Compose`` built from the stored transform list.

        Raises:
            KeyError: if ``arg`` is not a known pipeline name. The message
                lists the valid names.
        """
        try:
            pipeline = self.transforms[arg]
        except KeyError:
            raise KeyError(
                f"Unknown transform set {arg!r}; "
                f"expected one of {sorted(self.transforms)}"
            ) from None
        return Compose(pipeline)

In [16]:
# Build every named pipeline once; later cells fetch them via get_transforms().
# NOTE(review): the execution counts show the inspection cell below (In [14])
# ran before this one (In [16]) -- re-run the notebook top to bottom.
transforms = Transforms()

In [14]:
# Inspect the 'global_transforms' pipeline; as the last expression, the repr of
# the returned Compose is displayed as the cell output.
transforms.get_transforms('global_transforms')

Compose ( [SpatialTransform( independent_scale_for_each_axis = False, p_rot_per_sample = 0.2, p_scale_per_sample = 0.2, p_el_per_sample = 0.2, data_key = 'data', label_key = 'seg', patch_size = array([256, 256]), patch_center_dist_from_border = None, do_elastic_deform = False, alpha = (0.0, 200.0), sigma = (9.0, 13.0), do_rotation = True, angle_x = (-3.141592653589793, 3.141592653589793), angle_y = (-0.0, 0.0), angle_z = (-0.0, 0.0), do_scale = True, scale = (0.7, 1.4), border_mode_data = 'constant', border_cval_data = 0, order_data = 3, border_mode_seg = 'constant', border_cval_seg = -1, order_seg = 1, random_crop = False, p_rot_per_axis = 1, p_independent_scale_per_axis = 1 ), GaussianBlurTransform( p_per_sample = 0.2, different_sigma_per_channel = True, p_per_channel = 0.5, data_key = 'data', blur_sigma = (0.5, 1.0), different_sigma_per_axis = False, p_isotropic = 0 ), BrightnessMultiplicativeTransform( p_per_sample = 0.15, data_key = 'data', multiplier_range = (0.75, 1.25), per_cha