In [9]:
import numpy as np
import torchio as tio
import torch
import SimpleITK as sitk

In [14]:
def _crop(image, cropping_params):

    low = cropping_params[::2]
    high = cropping_params[1::2]
    index_ini = low
    index_fin = np.array(image.shape) - high 
    i0, j0, k0 = index_ini
    i1, j1, k1 = index_fin
    image_cropped = image[i0:i1, j0:j1, k0:k1]

    return(image_cropped)

def _pad(image, padding_params):
    paddings = padding_params[:2], padding_params[2:4], padding_params[4:]
    image_padded = np.pad(image, paddings, mode = 'constant', constant_values = 0)  

    return(image_padded)

def invertCropOrPad(image, padding_params, cropping_params):
    """Undo a ``tio.CropOrPad`` transform on a 3D array.

    torchio's CropOrPad pads first and crops second, so the exact inverse runs
    in the opposite order: first re-pad (with zeros) what was cropped, then
    strip what was padded. Applying the steps in the forward order only works
    when padding and cropping never touch the same side of an axis.

    Args:
        image: 3D array in the transformed (target) space.
        padding_params: six-tuple ``(i_ini, i_fin, j_ini, j_fin, k_ini, k_fin)``
            of voxels that were padded, or ``None`` if nothing was padded.
        cropping_params: six-tuple of voxels that were cropped, or ``None``.

    Returns:
        Array in the original (pre-transform) space. Voxels that were cropped
        away are restored as zeros.
    """
    # Undo the last forward step (crop) first ...
    if cropping_params is not None:
        image = _pad(image, cropping_params)
    # ... then the first forward step (pad).
    if padding_params is not None:
        image = _crop(image, padding_params)
    return image

def getCroppingParams(subject, mask_name, target_shape):
    '''Reproduce the padding/cropping that ``tio.CropOrPad(target_shape, mask_name=mask_name)``
    applies to ``subject``, so the transformation can be inverted later on.

    The mask-centred branch mirrors torchio's ``_compute_mask_center_crop_or_pad``.
    Like torchio, it falls back to a plain centre crop/pad (with a RuntimeWarning)
    when ``mask_name`` is not in the subject or the mask has no foreground voxels.
    Without that fallback, the bounding box of an empty mask is undefined.

    Args:
        subject: torchio Subject; needs item access and ``spatial_shape``.
        mask_name: key of the label map whose foreground bounding box is centred.
            Only the first channel of that map is used.
        target_shape: sequence of three output sizes.

    Returns:
        ``(padding_params, cropping_params)``. Each is a six-tuple of ints
        ``(i_ini, i_fin, j_ini, j_fin, k_ini, k_fin)``, or ``None`` when no
        padding/cropping is needed.
    '''
    import warnings

    subject_shape = subject.spatial_shape

    if mask_name not in subject:
        warnings.warn(
            f'Mask name "{mask_name}" not found in subject; using volume center instead',
            RuntimeWarning,
            stacklevel=2,
        )
        return _center_crop_or_pad_params(subject_shape, target_shape)

    mask_data = subject[mask_name].data.bool().numpy()
    if not mask_data.any():
        warnings.warn(
            f'All values in the mask "{mask_name}" are zero; using volume center instead',
            RuntimeWarning,
            stacklevel=2,
        )
        return _center_crop_or_pad_params(subject_shape, target_shape)

    bb_min, bb_max = _bbox_mask(mask_data[0])
    center_mask = np.mean((bb_min, bb_max), axis=0)
    padding = []
    cropping = []

    for dim in range(3):
        target_dim = target_shape[dim]
        center_dim = center_mask[dim]
        subject_dim = subject_shape[dim]

        center_on_index = not (center_dim % 1)
        target_even = not (target_dim % 2)

        # Approximation when the center cannot be computed exactly
        # The output will be off by half a voxel, but this is just an
        # implementation detail. After this shift, begin/end are whole numbers.
        if target_even ^ center_on_index:
            center_dim -= 0.5

        begin = center_dim - target_dim / 2
        if begin >= 0:
            crop_ini = begin
            pad_ini = 0
        else:
            crop_ini = 0
            pad_ini = -begin

        end = center_dim + target_dim / 2
        if end <= subject_dim:
            crop_fin = subject_dim - end
            pad_fin = 0
        else:
            crop_fin = 0
            pad_fin = end - subject_dim

        padding.extend([pad_ini, pad_fin])
        cropping.extend([crop_ini, crop_fin])

    # Conversion for SimpleITK compatibility (values are whole numbers, so the cast is exact)
    padding_array = np.asarray(padding, dtype=int)
    cropping_array = np.asarray(cropping, dtype=int)
    padding_params = tuple(padding_array.tolist()) if padding_array.any() else None
    cropping_params = tuple(cropping_array.tolist()) if cropping_array.any() else None
    return padding_params, cropping_params


def _center_crop_or_pad_params(subject_shape, target_shape):
    '''Centre crop/pad parameters, mirroring torchio's ``_compute_center_crop_or_pad``.

    For each axis, the size difference is split so that the leading side gets
    ``ceil(n / 2)`` voxels and the trailing side gets ``floor(n / 2)``.
    Returns ``(padding_params, cropping_params)``; each is ``None`` when all zero.
    '''
    diff = np.asarray(target_shape, dtype=int) - np.asarray(subject_shape, dtype=int)

    def six_bounds(amounts):
        bounds = []
        for n in amounts:
            n = int(n)
            bounds.extend([n - n // 2, n // 2])  # (ceil(n/2), floor(n/2)) for n >= 0
        return tuple(bounds)

    padding = np.maximum(diff, 0)
    cropping = -np.minimum(diff, 0)
    padding_params = six_bounds(padding) if padding.any() else None
    cropping_params = six_bounds(cropping) if cropping.any() else None
    return padding_params, cropping_params

def _bbox_mask(mask_volume: np.ndarray):
        """Return 6 coordinates of a 3D bounding box from a given mask.

        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.

        Args:
            mask_volume: 3D NumPy array.
        """  # noqa: B950
        i_any = np.any(mask_volume, axis=(1, 2))
        j_any = np.any(mask_volume, axis=(0, 2))
        k_any = np.any(mask_volume, axis=(0, 1))
        i_min, i_max = np.where(i_any)[0][[0, -1]]
        j_min, j_max = np.where(j_any)[0][[0, -1]]
        k_min, k_max = np.where(k_any)[0][[0, -1]]
        bb_min = np.array([i_min, j_min, k_min])
        bb_max = np.array([i_max, j_max, k_max]) + 1
        return bb_min, bb_max

TESTING: points_list pre crop: [[70 48 59]
 [65 37 52]
 [64 63 73]
 [68 47 75]
 [67 91 62]]
TESTING: pad/crop ((1, 0, 0, 0, 2, 22), (0, 33, 1, 31, 0, 0))
TESTING: points list post crop: tensor([[76, 46, 70],
        [53, 36, 67],
        [74, 62, 66],
        [60, 47, 72],
        [63, 90, 69]])

In [91]:
# Voxel indices of the test points, in SimpleITK array axis order (same order as `image` below).
pts = np.array([[70, 48, 59],
                [65, 37, 52],
                [64, 63, 73],
                [68, 47, 75],
                [67, 91, 62]])

# NOTE(review): hardcoded absolute path; move to a config cell or an environment variable.
DATA_DIR = '/home/t722s/Desktop/Datasets/BratsMini'
CASE_ID = 'BraTS2021_01646'
TARGET_SHAPE = (128, 128, 128)

image = sitk.GetArrayFromImage(sitk.ReadImage(f'{DATA_DIR}/imagesTs/{CASE_ID}.nii.gz'))
label = sitk.GetArrayFromImage(sitk.ReadImage(f'{DATA_DIR}/labelsTs/{CASE_ID}.nii.gz'))

# Size the point mask from the image instead of hardcoding it, so it always matches.
# `tuple(pts.T)` is the portable spelling of `*pts.T` in a subscript (that is Python >= 3.11 only).
points_mask = np.zeros(image.shape)
points_mask[tuple(pts.T)] = 1

# Add a channel dimension to everything, and permute to x, y, z orientation.
subject_raw = tio.Subject(
    image=tio.ScalarImage(tensor=torch.from_numpy(image).permute(2, 1, 0).unsqueeze(0)),
    points_mask=tio.LabelMap(tensor=torch.from_numpy(points_mask).permute(2, 1, 0).float().unsqueeze(0)),
    label=tio.LabelMap(tensor=torch.from_numpy(label).permute(2, 1, 0).unsqueeze(0)),
)

# The parameters must be computed on the untransformed subject.
padding_params, cropping_params = getCroppingParams(subject_raw, 'points_mask', TARGET_SHAPE)

crop_or_pad = tio.CropOrPad(TARGET_SHAPE, mask_name='points_mask')
subject = crop_or_pad(subject_raw)
subject.image.data.shape

torch.Size([1, 128, 128, 128])

In [92]:
# Work in the torchio axis order used by the crop/pad parameters.
mask = points_mask.transpose(2, 1, 0)
# Swapping the two parameter arguments turns the inverse into the forward transform:
# it crops what CropOrPad cropped and pads what it padded.
mask_crop = invertCropOrPad(mask, cropping_params, padding_params)

mask_uncrop = invertCropOrPad(mask_crop, padding_params, cropping_params)
mask_uncrop = mask_uncrop.transpose(2, 1, 0)

# The inverse works: every point lands back on its original voxel after the round trip.
# (Only holds while no point falls outside the cropped field of view.)
assert np.array_equal(mask_uncrop, points_mask)


In [104]:
print(padding_params, cropping_params)
# Reverse the columns: SimpleITK array order -> the torchio axis order used by the params.
pts_turned = pts[:, ::-1]
# Forward CropOrPad shift of a voxel index: + leading padding - leading cropping, per axis.
pad_ini = np.asarray(padding_params[::2]) if padding_params is not None else np.zeros(3, dtype=int)
crop_ini = np.asarray(cropping_params[::2]) if cropping_params is not None else np.zeros(3, dtype=int)
pts_turned_trans = pts_turned + pad_ini - crop_ini


(0, 0, 0, 0, 0, 27) (0, 32, 0, 32, 3, 0)


(array([[59, 48, 67],
        [52, 37, 62],
        [73, 63, 61],
        [75, 47, 65],
        [62, 91, 64]]),
 array([[52, 37, 62],
        [59, 48, 67],
        [62, 91, 64],
        [73, 63, 61],
        [75, 47, 65]]),
 tensor([[ 0, 52, 37, 62],
         [ 0, 59, 48, 67],
         [ 0, 62, 91, 64],
         [ 0, 73, 63, 61],
         [ 0, 75, 47, 65]]))

In [149]:
a = np.arange(1, 4)
a.tolist()  # plain Python ints rather than NumPy scalars

[1, 2, 3]

In [120]:
import numpy as np

# Toy data: rows are points, columns are axes; one offset per axis.
pts_trans = np.array([[1, 2], [3, 4], [5, 6]])
axis_add = np.array([1, 5])
axis_sub = np.array([0, 1])

# Each (2,)-shaped offset broadcasts across every row; applied in place.
pts_trans += axis_add
pts_trans -= axis_sub
pts_trans


array([[ 2,  6],
       [ 4,  8],
       [ 6, 10]])

In [147]:
a = torch.tensor([5, 6])
# isinstance, unlike `type(a) == torch.Tensor`, also accepts Tensor subclasses (e.g. nn.Parameter).
isinstance(a, torch.Tensor)

True

In [138]:
# Indices of all nonzero voxels of the transformed point mask (leading channel axis squeezed).
pts_trans = subject.points_mask.data.squeeze(0).nonzero().numpy()

def transformPoints(pts, padding_params, cropping_params):
    '''Map voxel indices from the original image space into the CropOrPad output space.

    Along each axis, a voxel at index ``p`` moves to ``p + pad_ini - crop_ini``.
    ``pad_ini`` and ``crop_ini`` are the leading-side entries of the six-tuples
    returned by ``getCroppingParams``. Points inside a cropped-away margin are
    not filtered out; they come back with negative or out-of-range indices.

    Args:
        pts: ``(N, 3)`` voxel indices (array-like), in the same axis order as
            the parameters.
        padding_params: six-tuple of padding amounts, or ``None`` for no padding.
        cropping_params: six-tuple of cropping amounts, or ``None`` for no cropping.

    Returns:
        ``(N, 3)`` C-contiguous array of the shifted indices, with the column
        order reversed (presumably torchio x, y, z back to SimpleITK array
        order -- confirm against the caller).
    '''
    pts = np.asarray(pts)
    # Handle None types here. This could have been done in getCroppingParams, but that
    # function keeps torchio's (params | None) convention. Use integer zeros: float zeros
    # would silently turn integer voxel indices into floats.
    no_op = (0,) * 6
    if padding_params is None:
        padding_params = no_op
    if cropping_params is None:
        cropping_params = no_op

    axis_add = np.asarray(padding_params[::2])
    axis_sub = np.asarray(cropping_params[::2])

    shifted = pts + axis_add - axis_sub  # broadcasts the per-axis offset over every point
    # A reversed view has negative strides, which e.g. torch.from_numpy rejects.
    return np.ascontiguousarray(shifted[:, ::-1])

In [142]:
isinstance(torch.tensor([1, 2, 3]), torch.Tensor)  # idiomatic; also True for Tensor subclasses

True