# In [None]:
import numpy as np
import random
from scipy.ndimage import gaussian_filter, map_coordinates

class Advanced3DAugment:
    """Random spatial and photometric augmentation for a single 3D patch.

    Calling an instance on a channel-first patch applies, in order:

    1. a random rotation in each of the three axis planes,
    2. a random whole-voxel shift along each axis,
    3. a random isotropic rescale, center-cropped / zero-padded back to the
       input size,
    4. an elastic deformation,
    5. with probability ``p_intensity``, a random contrast/brightness change
       followed by clipping to [0, 1].

    Only ``patch[0]`` (the first channel) is augmented; the result carries a
    single leading channel axis.

    Randomness comes from the stdlib ``random`` module (rotation, shift,
    scale, intensity) AND the global ``np.random`` state (elastic field), so
    reproducing a run requires seeding both.

    Args:
        max_rotate: Maximum absolute rotation angle per plane, in degrees.
        max_shift: Maximum absolute shift per axis, in voxels (rounded to
            whole voxels when applied).
        max_scale: Maximum relative scale change (0.1 means 0.9x .. 1.1x).
        elastic_alpha: Multiplier applied to the smoothed displacement field.
        elastic_sigma: Gaussian sigma used to smooth the displacement field.
        p_intensity: Probability of applying the intensity transform.
    """

    def __init__(self,
                 max_rotate=15,  # degrees
                 max_shift=5,    # voxels
                 max_scale=0.1,  # ±10%
                 elastic_alpha=500, elastic_sigma=20,
                 p_intensity=0.5):
        self.max_rotate = max_rotate
        self.max_shift = max_shift
        self.max_scale = max_scale
        self.elastic_alpha = elastic_alpha
        self.elastic_sigma = elastic_sigma
        self.p_intensity = p_intensity

    def __call__(self, patch):
        """Augment a channel-first 3D patch.

        Args:
            patch: Array-like of shape (C, D, H, W); only channel 0 is used.

        Returns:
            float32 array of shape (1, D, H, W).

        Raises:
            ValueError: if ``patch[0]`` is not three-dimensional.
        """
        vol = np.asarray(patch)[0].astype(np.float32)
        if vol.ndim != 3:
            raise ValueError(
                "expected a channel-first 3D patch of shape (C, D, H, W), "
                f"got shape {np.shape(patch)}"
            )

        # 1. Random rotation in each axis plane (i.e. about each axis).
        for axes in ((1, 2), (0, 2), (0, 1)):
            angle = random.uniform(-self.max_rotate, self.max_rotate)
            vol = self.rotate_3d(vol, angle, axes)

        # 2. Random shift. np.roll only accepts integer offsets, so each draw
        #    is rounded to a whole voxel. np.roll wraps voxels that leave one
        #    face around to the opposite face.
        shifts = tuple(
            round(random.uniform(-self.max_shift, self.max_shift))
            for _ in range(3)
        )
        vol = np.roll(vol, shifts, axis=(0, 1, 2))

        # 3. Random isotropic scale (resample, then center crop/pad).
        scale = 1.0 + random.uniform(-self.max_scale, self.max_scale)
        vol = self.rescale(vol, scale)

        # 4. Elastic deformation.
        vol = self.elastic_transform(vol, self.elastic_alpha, self.elastic_sigma)

        # 5. Photometric variation, applied with probability p_intensity.
        if random.random() < self.p_intensity:
            vol = self.intensity_transform(vol)

        # Re-add the channel axis.
        return np.expand_dims(vol, axis=0)

    def rotate_3d(self, vol, angle, axes):
        """Rotate ``vol`` by ``angle`` degrees in the plane spanned by ``axes``.

        The output keeps the input shape (``reshape=False``), uses linear
        interpolation, and fills voxels rotated in from outside with 0.
        """
        from scipy.ndimage import rotate
        return rotate(vol, angle, axes=axes, reshape=False, order=1, mode='constant', cval=0)

    def rescale(self, vol, scale):
        """Zoom ``vol`` by ``scale`` and center-crop or zero-pad back to ``vol.shape``.

        Args:
            vol: N-D array.
            scale: Zoom factor. Either a scalar or one factor per axis, as
                accepted by ``scipy.ndimage.zoom``.

        Returns:
            Array with the same shape and dtype as ``vol``. Axes that grew are
            center-cropped; axes that shrank are centered in a zero background.
        """
        from scipy.ndimage import zoom
        zoomed = zoom(vol, scale, order=1)
        result = np.zeros_like(vol)

        src, dst = [], []
        for z_len, v_len in zip(zoomed.shape, vol.shape):
            offset = (z_len - v_len) // 2
            if offset >= 0:
                # Zoomed axis is at least as long: take a centered window.
                src.append(slice(offset, offset + v_len))
                dst.append(slice(None))
            else:
                # Zoomed axis is shorter: paste it centered into the zeros.
                src.append(slice(None))
                dst.append(slice(-offset, -offset + z_len))
        result[tuple(dst)] = zoomed[tuple(src)]
        return result

    def elastic_transform(self, vol, alpha, sigma):
        """Warp a 3D ``vol`` with a smooth random displacement field.

        Each axis gets an independent field: uniform noise in [-1, 1) from
        ``np.random.rand``, smoothed by a Gaussian of width ``sigma`` and
        multiplied by ``alpha``. The volume is resampled at the displaced
        coordinates with linear interpolation, reflecting at the borders.
        """
        shape = vol.shape

        def smooth_field():
            return gaussian_filter(np.random.rand(*shape) * 2 - 1, sigma) * alpha

        # Fields are drawn in the order axis 1, axis 0, axis 2. The order is
        # fixed so that a given np.random seed always maps to the same warp.
        disp1 = smooth_field()
        disp0 = smooth_field()
        disp2 = smooth_field()

        grid0, grid1, grid2 = np.meshgrid(
            np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]),
            indexing='ij',
        )
        coords = np.array([grid0 + disp0, grid1 + disp1, grid2 + disp2])
        return map_coordinates(vol, coords, order=1, mode='reflect')

    def intensity_transform(self, vol):
        """Random contrast (x0.9..1.1) then brightness (+-0.1), clipped to [0, 1].

        NOTE(review): the clip presumes intensities are already normalized to
        [0, 1]; inputs on another scale would be truncated. Confirm upstream
        normalization.
        """
        gain = random.uniform(0.9, 1.1)
        bias = random.uniform(-0.1, 0.1)
        return np.clip(vol * gain + bias, 0, 1)
