In [None]:
import os
import copy
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy import ndimage
import nibabel as nib
import torch


import monai
print(monai.__version__, monai.__file__)

import monai.transforms.spatial.old_array as old
import monai.transforms.spatial.old_dictionary as oldd
from monai.transforms.lazy.functional import apply_pending
from monai.transforms.spatial.functional import spacing
# from monai.utils.mapping_stack import MetaMatrix
from monai.transforms import Invert, AddChannel, Compose, Crop, LoadImage, LoadImaged, EnsureChannelFirst, EnsureChannelFirstd
from monai.transforms.spatial.array import Flip, RandFlip, Resize, Rotate, RandRotate, Rotate90, RandRotate90, Spacing, Zoom, RandZoom
from monai.transforms.spatial.array import RandGridDistortion, Rand2DElastic
from monai.transforms.spatial.dictionary import Spacingd, Resized, RandFlipd, RandRotated, RandRotate90d, RandZoomd
from monai.transforms.croppad.functional import croppad
from monai.transforms.croppad.old_array import RandSpatialCrop
from monai.transforms.croppad.array import CropPad
from monai.transforms.croppad.old_dictionary import RandSpatialCropd
from monai.transforms.croppad.dictionary import RandCropPadd
from monai.data.meta_tensor import MetaTensor
from monai.losses.dice import DiceLoss
# !pip list | grep monai

In [None]:
display_images = {}
# Root of the preprocessed BraTS (Task01) data. Set BRATS_DATA_DIR to run on another
# machine; the default is the original author's local path.
base_dir_ = os.environ.get('BRATS_DATA_DIR', '/home/ben/data/preprocessed/Task01_BrainTumour/orig')
sample_str_ = 'BRATS_{}_{}.nii.gz'


def brats_match_(x):
    """Return True for file names of the form 'BRATS*.nii.gz'.

    Names that do not start with 'BRATS' are rejected, which also excludes
    hidden '._BRATS_*' companion files.
    """
    return x.startswith('BRATS') and x.endswith('.nii.gz')

def get_source_id(name):
    """Return the sample id of a file name, e.g. 'BRATS_001.nii.gz' -> '001'.

    The id is the second '_'-separated field of the text before the first '.'.
    Raises IndexError when that text contains no '_'.
    """
    stem = name.partition('.')[0]
    fields = stem.split('_')
    return fields[1]

def entry_type(entry):
    """Classify a path as 'image' or 'label'.

    Only the final '/'-separated component is inspected: it is an 'image' when
    that component contains the substring 'image', otherwise a 'label'.
    """
    last_component = entry.rsplit('/', 1)[-1]
    return 'image' if 'image' in last_component else 'label'

In [None]:
def generate_entry_list(datapath, pattern_match, get_source_id, get_entry_type):
    """Group matching files below ``datapath`` by sample id.

    Args:
        datapath: directory walked recursively with ``os.walk``.
        pattern_match: predicate on a bare file name; non-matching files are skipped.
        get_source_id: maps a bare file name to its sample id.
        get_entry_type: currently unused; kept for call compatibility.

    Returns:
        A list of ``(sample_id, [path, path])`` tuples sorted by sample id. The
        paths of each id are sorted lexicographically. Only ids with exactly two
        matching files are kept.
    """
    samples = {}
    # os.walk's second positional parameter is `topdown`, not a filter, so the
    # predicate is applied per file name instead.
    for root, _dirs, files in os.walk(datapath):
        for fname in files:
            if pattern_match(fname):
                samples.setdefault(get_source_id(fname), []).append(os.path.join(root, fname))

    # Sort each id's paths once rather than after every insertion.
    return [(sample_id, sorted(paths))
            for sample_id, paths in sorted(samples.items(), key=lambda item: item[0])
            if len(paths) == 2]

# Pair every BraTS sample id found under base_dir_ with its two sorted file paths.
entries_ = generate_entry_list(
    datapath=base_dir_,
    pattern_match=brats_match_,
    get_source_id=get_source_id,
    get_entry_type=entry_type,
)

In [None]:
def get_img(size, dtype=torch.float32, offset=0):
    """Build a channel-first "ramp" test image whose voxel values are unique.

    Every element holds its row-major (C-order) flat index plus ``offset``. For
    example, ``size=(2, 3)`` gives ``[[0, 1, 2], [3, 4, 5]]``. Any rank is supported.

    Args:
        size: spatial shape of the image.
        dtype: torch dtype of the values before conversion to numpy.
        offset: constant added to every value.

    Returns:
        A numpy array of shape ``(1, *size)``.
    """
    size = tuple(size)
    # Proper row-major strides: the original multiplied the row index by size[0]
    # instead of the width, which produced duplicate values for non-square images.
    ramp = (torch.arange(math.prod(size)) + offset).reshape(size).to(dtype)
    return np.expand_dims(ramp.numpy(), 0)


def load_sample(sample, get_image=True, get_label=True):
    """Load the image and/or label volume of one ``(id, [image_path, label_path])`` entry.

    Args:
        sample: a pair whose second element lists the image path first and the
            label path second (the layout produced by ``generate_entry_list``;
            which file is which depends on the sorted path names).
        get_image: when exactly ``False``, the image is neither read nor returned.
        get_label: when exactly ``False``, the label is neither read nor returned.

    Returns:
        ``(img, lbl)``. ``img`` is the 4D image volume with its last axis moved to
        the front. ``lbl`` is the label volume with a leading channel axis added.
        Either item is ``None`` when it was not requested.
    """
    image_path, label_path = sample[1][0], sample[1][1]
    img = None
    lbl = None
    # Only touch the disk for what was requested; the original read both files regardless.
    if get_image is not False:
        # Assumes a 4D volume; moves the last axis (presumably the modality axis — verify) first.
        img = np.transpose(nib.load(image_path).get_fdata(), axes=(3, 0, 1, 2))
    if get_label is not False:
        lbl = np.expand_dims(nib.load(label_path).get_fdata(), axis=0)
    return img, lbl


def plot_datas(datas, cols=4, tight=False, size=20, axis=False, titles=None, font='arial'):
    """Show a sequence of images in a grid with ``cols`` columns.

    Args:
        datas: non-empty sequence of arrays/tensors. Items with more than two
            dims are displayed via ``d[0, ...]`` (e.g. a (1, H, W) image as (H, W)).
        cols: number of grid columns; the number of rows is derived from len(datas).
        tight: call ``plt.tight_layout()`` once everything is drawn.
        size: figure width in inches; height scales with rows / cols.
        axis: keep axis frames/ticks visible when truthy.
        titles: optional per-image titles; must have the same length as ``datas``.
        font: 'arial' or 'timesnewroman'. Any other value prints a warning and
            falls back to Times New Roman.

    Raises:
        ValueError: if ``datas`` is empty or ``titles`` has the wrong length.
    """
    fonts = ('arial', 'timesnewroman')
    if font not in fonts:
        print(f"unrecognised font {font}. Must be one of {fonts}")
    fontspec = {'fontname': 'Arial' if font == 'arial' else 'Times New Roman'}

    if len(datas) == 0:
        raise ValueError("datas must contain at least one image")
    if titles is not None and len(titles) != len(datas):
        raise ValueError("titles must be the same length as data if set")

    rows = math.ceil(len(datas) / cols)
    # squeeze=False always returns a 2D (rows, cols) array of Axes, so [row, col]
    # indexing also works for a single row or a single column (cols=1 used to fail).
    fig, ax = plt.subplots(rows, cols, figsize=(size, size * rows / cols), squeeze=False)

    # A shared colour scale is only applied when all images fit on a single row.
    shared_scale = {}
    if len(datas) <= cols:
        shared_scale = {'vmin': min(d.min() for d in datas), 'vmax': max(d.max() for d in datas)}

    for i_d, d in enumerate(datas):
        cell_ax = ax[i_d // cols, i_d % cols]
        if not axis:
            cell_ax.axis('off')
        if titles is not None:
            cell_ax.set_title(titles[i_d], **fontspec)
            cell_ax.title.set_fontsize(28)
        cell_ax.imshow(d[0, ...] if len(d.shape) > 2 else d, **shared_scale)

    # Hide grid cells left over when len(datas) is not a multiple of cols.
    for i_empty in range(len(datas), rows * cols):
        ax[i_empty // cols, i_empty % cols].axis('off')

    # Layout must be computed after titles/images exist to have any effect.
    if tight:
        plt.tight_layout()

def rand_seed(rng):
    """Draw an int32 seed in [0, 2**31 - 1) from ``rng`` to seed a child RandomState."""
    upper_exclusive = np.int32(2 ** 31 - 1)
    return rng.randint(upper_exclusive, dtype=np.int32)


class RNGWrapper(np.random.RandomState):
    """Tracing proxy around a ``np.random.RandomState``.

    It subclasses RandomState so it passes ``isinstance(x, np.random.RandomState)``
    checks. The sampling methods defined below are delegated to the wrapped ``rng``.
    Each delegated call increments ``calls`` and prints ``tag``, the running call
    count and the drawn value.
    """

    def __init__(self, tag, rng):
        # Initialise the base class so inherited, non-forwarded RandomState methods
        # operate on a valid (independently seeded) generator instead of an
        # uninitialised one. Only the methods below draw from ``rng``.
        super().__init__()
        self.tag = tag
        self.rng = rng
        self.calls = 0

    def _traced(self, method, *args, **kwargs):
        """Call ``self.rng.<method>(*args, **kwargs)``, log the result and return it."""
        self.calls += 1
        value = getattr(self.rng, method)(*args, **kwargs)
        print(self.tag, self.calls, value)
        return value

    def rand(self, *args, **kwargs):
        return self._traced("rand", *args, **kwargs)

    def randint(self, *args, **kwargs):
        return self._traced("randint", *args, **kwargs)

    def uniform(self, *args, **kwargs):
        return self._traced("uniform", *args, **kwargs)

    def random_sample(self, *args, **kwargs):
        return self._traced("random_sample", *args, **kwargs)

    def random(self, *args, **kwargs):
        return self._traced("random", *args, **kwargs)

    def choice(self, *args, **kwargs):
        return self._traced("choice", *args, **kwargs)

    def normal(self, *args, **kwargs):
        return self._traced("normal", *args, **kwargs)

    
def find_mid_label_z(label):
    """Locate the labelled extent along the last axis of ``label``.

    Returns ``(first, last, mid)``: the first and last indices along the last axis
    whose slice contains any non-zero value, and ``int((first + last) / 2)``.
    When nothing is labelled, returns ``(0, depth, depth // 2)`` where ``depth``
    is the size of the last axis.
    """
    depth = label.shape[-1]
    occupied = [z for z in range(depth) if np.count_nonzero(label[..., z]) > 0]
    if not occupied:
        return 0, depth, depth // 2
    lo, hi = occupied[0], occupied[-1]
    return lo, hi, int((lo + hi) / 2)


def find_mid_label(label):
    """Find the labelled extent of ``label`` along each of its three spatial axes.

    Assumes ``label`` is 4D with a leading channel axis: the three slicers below
    select axes 1, 2 and 3 of a 4D array. For each spatial axis, returns
    ``(first, last, (first + last) // 2)`` over the slices that contain any
    non-zero value. When nothing is labelled, every axis gets
    ``(0, size, size // 2)``.
    """
    slicers = (
        lambda im, v: im[:, v, ...],
        lambda im, v: im[..., v, :],
        lambda im, v: im[..., v],
    )

    extents = []
    for axis, take in enumerate(slicers):
        hits = [v for v in range(label.shape[axis + 1]) if np.count_nonzero(take(label, v)) > 0]
        extents.append((hits[0], hits[-1]) if hits else None)

    if extents[0] is None:
        return tuple((0, label.shape[a + 1], label.shape[a + 1] // 2) for a in range(3))

    return tuple((lo, hi, (hi + lo) // 2) for lo, hi in extents)


def sanitized_range_from_extents(mid_v, max_v, range_v):
    """Return a ``[start, end)`` window of length ``range_v`` centred on ``mid_v``.

    The window is shifted as needed to stay inside ``[0, max_v)``. If ``range_v``
    exceeds ``max_v``, the window starts at 0 (it cannot fit).

    Returns:
        ``(start, end)`` with ``end - start == range_v`` always. Previously an odd
        ``range_v`` in the unclamped case produced a window one element short.
    """
    start = mid_v - range_v // 2
    # Clamp so the window ends by max_v and never starts below 0
    # (a negative start would silently wrap when used as a Python slice).
    start = max(0, min(start, max_v - range_v))
    return start, start + range_v


# def entropy(vol):
#     jh, _ = np.histogram(vol.ravel(), bins=256, density=True)
#     # Add epsillon values to compensate for 0 bins
#     jh = jh + np.finfo(np.float32).eps
#     # Compute and return the joint entropy
#     return np.sum(jh.ravel()*np.log(jh.ravel()))


def entropy(img_data, bins=256):
    """Shannon entropy, in bits, of the intensity histogram of ``img_data``.

    Args:
        img_data: array-like of intensities (anything ``np.histogram`` accepts).
        bins: number of histogram bins (default 256, as before).

    Returns:
        ``-sum(p * log2(p))`` over the non-empty bins. Returns 0.0 for empty input.
    """
    counts, _ = np.histogram(img_data, bins=bins)
    total = counts.sum()
    if total == 0:
        return 0.0
    # Drop empty bins *before* taking the log: np.where evaluates both branches,
    # so the old form emitted divide-by-zero/invalid-value warnings on every call.
    p = counts[counts > 0] / total
    return -np.sum(p * np.log2(p))

class Dots:
    """Console progress indicator: prints one '.' per step, ``length`` dots per line."""

    def __init__(self, length=10):
        self._cur = 0          # dots printed so far
        self._length = length  # dots per output line

    def dot(self):
        """Print one dot, starting a new line first when the current line is full."""
        line_full = self._cur > 0 and self._cur % self._length == 0
        print(("\n" if line_full else "") + ".", end="")
        self._cur += 1

    def done(self):
        """End the current line, but only if at least one dot was printed."""
        if self._cur <= 0:
            return
        print()

In [None]:
def trad_pipeline():
    """Eager augmentation pipeline built from the old dictionary transforms.

    Stages: Resized to (192, 192, 72), RandFlipd, RandRotate90d, RandZoomd and
    RandRotated (range_z of +/- pi/4), applied to 'image' and 'label'. Each random
    transform gets its own RandomState. These are seeded from a master
    RandomState(12345678) in pipeline order, so the pipeline is reproducible.
    """
    keys = ('image', 'label')
    masterrng = np.random.RandomState(12345678)

    def _seeded(transform):
        # Attach a fresh RandomState seeded from the next master draw.
        transform.set_random_state(state=np.random.RandomState(rand_seed(masterrng)))
        return transform

    return Compose([
        oldd.Resized(keys=keys, spatial_size=(192, 192, 72), mode=("area", "nearest")),
        _seeded(oldd.RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2])),
        _seeded(oldd.RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1))),
        _seeded(oldd.RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25,
                               mode=("area", "nearest"), keep_size=True)),
        _seeded(oldd.RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi/4, torch.pi/4),
                                 mode=("bilinear", "nearest"), align_corners=True)),
    ])


def trad_pipeline_patch_first():
    """Eager pipeline that crops a (160, 160, 155) patch *before* augmenting.

    After the crop: Spacingd to pixdim (1, 1, 155/72), RandFlipd, RandRotate90d,
    RandZoomd and RandRotated. Seeds come from a master RandomState(12345678),
    drawn in the order flip, rotate90, zoom, rotate, patch.
    """
    keys = ('image', 'label')
    masterrng = np.random.RandomState(12345678)
    # Draw every seed up front, in a fixed order (the generator is consumed left to right).
    randfliprng, rotate90rng, zoomrng, rotaterng = (
        np.random.RandomState(rand_seed(masterrng)) for _ in range(4))
    patchrng = np.random.RandomState(rand_seed(masterrng))

    patchd = RandSpatialCropd(keys=keys, roi_size=(160, 160, 155), random_size=False)
    resized = oldd.Spacingd(keys=keys, pixdim=(1.0, 1.0, 155/72), mode=("bilinear", "nearest"))
    randflipd = oldd.RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2])
    rotate90d = oldd.RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1))
    zoomd = oldd.RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25,
                           mode=("area", "nearest"), keep_size=True)
    rotated = oldd.RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi/4, torch.pi/4),
                               mode=("bilinear", "nearest"), align_corners=True)

    for transform, rng in ((patchd, patchrng), (randflipd, randfliprng), (rotate90d, rotate90rng),
                           (zoomd, zoomrng), (rotated, rotaterng)):
        transform.set_random_state(state=rng)

    return Compose([patchd, resized, randflipd, rotate90d, zoomd, rotated])


def trad_pipeline_patch_last():
    """Eager pipeline that crops a (160, 160, 72) patch *after* the augmentations.

    Stages: Spacingd to pixdim (1, 1, 155/72), RandFlipd, RandRotate90d, RandZoomd,
    RandRotated, then RandSpatialCropd. Seeds come from a master
    RandomState(12345678), drawn in the order flip, rotate90, zoom, rotate, patch.
    The patch RNG is wrapped in RNGWrapper("trad", ...) so its draws are printed.
    """
    keys = ('image', 'label')
    masterrng = np.random.RandomState(12345678)
    randfliprng, rotate90rng, zoomrng, rotaterng = (
        np.random.RandomState(rand_seed(masterrng)) for _ in range(4))
    patch_seed = rand_seed(masterrng)
    # Labelled "trad" (this is the eager pipeline), matching the RNGWrapper tag below;
    # it used to print "lazy", which made the output indistinguishable from the lazy builder's.
    print("trad patch seed:", patch_seed)
    patchrng = RNGWrapper("trad", np.random.RandomState(patch_seed))

    resized = oldd.Spacingd(keys=keys, pixdim=(1.0, 1.0, 155/72), mode=("bilinear", "nearest"))
    randflipd = oldd.RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2])
    randflipd.set_random_state(state=randfliprng)
    rotate90d = oldd.RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1))
    rotate90d.set_random_state(state=rotate90rng)
    zoomd = oldd.RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25, mode=("area", "nearest"), keep_size=True)
    zoomd.set_random_state(state=zoomrng)
    rotated = oldd.RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi/4, torch.pi/4), mode=("bilinear", "nearest"), align_corners=True)
    rotated.set_random_state(state=rotaterng)
    patchd = RandSpatialCropd(keys=keys, roi_size=(160, 160, 72), random_size=False)
    patchd.set_random_state(state=patchrng)

    return Compose([resized, randflipd, rotate90d, zoomd, rotated, patchd])


def trad_pipeline_label_only():
    """Eager array-transform pipeline for a single label file (nearest-neighbour throughout).

    Stages: LoadImage, EnsureChannelFirst, Resize to (192, 192, 72), RandFlip,
    RandRotate90, RandZoom and RandRotate. Each random transform is seeded in
    order from a master RandomState(12345678).
    """
    print("trad_pipeline_label_only")
    masterrng = np.random.RandomState(12345678)

    def _seeded(transform):
        # Attach a fresh RandomState seeded from the next master draw.
        transform.set_random_state(state=np.random.RandomState(rand_seed(masterrng)))
        return transform

    return Compose([
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        old.Resize(spatial_size=(192, 192, 72), mode="nearest"),
        _seeded(old.RandFlip(prob=0.5, spatial_axis=[1, 2])),
        _seeded(old.RandRotate90(prob=0.5, spatial_axes=(0, 1))),
        _seeded(old.RandZoom(prob=1.0, min_zoom=0.75, max_zoom=1.25, mode="nearest", keep_size=True)),
        _seeded(old.RandRotate(prob=1.0, range_z=(-torch.pi/4, torch.pi/4), mode="nearest", align_corners=True)),
    ])


def lazy_pipeline(lazy=True):
    """New dictionary-transform counterpart of ``trad_pipeline``; ``lazy`` is passed to every stage.

    Seeds are drawn from a master RandomState(12345678) in the order flip,
    rotate90, zoom, rotate.
    """
    keys = ('image', 'label')
    masterrng = np.random.RandomState(12345678)
    flip_state, rot90_state, zoom_state, rotate_state = (
        np.random.RandomState(rand_seed(masterrng)) for _ in range(4))

    stages = [
        Resized(keys=keys, spatial_size=(192, 192, 72), mode=("bilinear", "nearest"), lazy=lazy),
        RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2], lazy=lazy, state=flip_state),
        RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1), lazy=lazy, state=rot90_state),
        RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25, mode=("bilinear", "nearest"),
                  keep_size=True, lazy=lazy, state=zoom_state),
        RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi/4, torch.pi/4), mode=("bilinear", "nearest"),
                    align_corners=True, lazy=lazy, state=rotate_state),
    ]
    return Compose(stages)


# def lazy_pipeline_patch_first(lazy=True):
#     keys = ('image', 'label')
#     masterrng = np.random.RandomState(12345678)
#     randfliprng = np.random.RandomState(rand_seed(masterrng))
#     rotate90rng = np.random.RandomState(rand_seed(masterrng))
#     zoomrng = np.random.RandomState(rand_seed(masterrng))
#     rotaterng = np.random.RandomState(rand_seed(masterrng))
#     patch_seed = rand_seed(masterrng)
#     print("lazy patch seed:", patch_seed)
#     patchrng = np.random.RandomState(patch_seed)

#     pipeline = Compose([
#             RandCropPadd(keys=keys, sizes=(160, 160, 155), lazy_evaluation=lazy, state=patchrng),
#             Spacingd(keys=keys, pixdim=(1.0, 1.0, 155/72), mode=("bilinear", "nearest"), lazy_evaluation=lazy),
#             RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2], lazy_evaluation=lazy, state=randfliprng),
#             RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1), lazy_evaluation=lazy, state=rotate90rng),
#             RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25, mode=("bilinear", "nearest"), keep_size=True, lazy_evaluation=lazy, state=zoomrng),
#             RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi/4, torch.pi/4), mode=("bilinear", "nearest"), align_corners=True, lazy_evaluation=lazy, state=rotaterng),        
#     ])
    
#     return pipeline


def lazy_pipeline_patch_first(lazy=True, **kwargs):
    """Lazy counterpart of ``trad_pipeline_patch_first``: crop (160, 160, 155), then augment.

    Extra keyword arguments are accepted and ignored. Seeds are drawn from a
    master RandomState(12345678) in the order flip, rotate90, zoom, rotate, patch.
    """
    keys = ('image', 'label')
    masterrng = np.random.RandomState(12345678)
    randfliprng, rotate90rng, zoomrng, rotaterng = (
        np.random.RandomState(rand_seed(masterrng)) for _ in range(4))
    patchrng = np.random.RandomState(rand_seed(masterrng))

    stages = [
        RandCropPadd(keys=keys, sizes=(160, 160, 155), lazy=lazy, state=patchrng),
        Resized(keys=keys, spatial_size=(160, 160, 72), mode=("bilinear", "nearest"), lazy=lazy),
        RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2], lazy=lazy, state=randfliprng),
        RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1), lazy=lazy, state=rotate90rng),
        RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25, mode=("bilinear", "nearest"),
                  keep_size=True, lazy=lazy, state=zoomrng),
        RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi / 4, torch.pi / 4), mode=("bilinear", "nearest"),
                    align_corners=True, lazy=lazy, state=rotaterng),
    ]
    return Compose(stages)


# def lazy_pipeline_patch_last(lazy=True):
#     keys = ('image', 'label')
#     masterrng = np.random.RandomState(12345678)
#     randfliprng = np.random.RandomState(rand_seed(masterrng))
#     rotate90rng = np.random.RandomState(rand_seed(masterrng))
#     zoomrng = np.random.RandomState(rand_seed(masterrng))
#     rotaterng = np.random.RandomState(rand_seed(masterrng))
#     patch_seed = rand_seed(masterrng)
#     print("lazy patch seed:", patch_seed)
#     patchrng = np.random.RandomState(patch_seed)

#     pipeline = Compose([
#             Spacingd(keys=keys, pixdim=(1.0, 1.0, 155/72), mode=("bilinear", "nearest"), lazy_evaluation=lazy),
#             RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2], lazy_evaluation=lazy, state=randfliprng),
#             RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1), lazy_evaluation=lazy, state=rotate90rng),
#             RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25, mode=("bilinear", "nearest"), keep_size=True, lazy_evaluation=lazy, state=zoomrng),
#             RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi/4, torch.pi/4), mode=("bilinear", "nearest"), align_corners=True, lazy_evaluation=lazy, state=rotaterng),        
#             RandCropPadd(keys=keys, sizes=(160, 160, 72), lazy_evaluation=lazy, state=patchrng),
#     ])
    
#     return pipeline


def lazy_pipeline_patch_last(lazy=True, **kwargs):
    """Lazy pipeline that resizes to (240, 240, 72), augments, then crops (160, 160, 72).

    Extra keyword arguments are accepted and ignored. Seeds are drawn from a
    master RandomState(12345678) in the order flip, rotate90, zoom, rotate, patch.
    The patch RNG is wrapped in RNGWrapper("lazy", ...) so its draws are printed.
    """
    keys = ('image', 'label')
    masterrng = np.random.RandomState(12345678)
    randfliprng, rotate90rng, zoomrng, rotaterng = (
        np.random.RandomState(rand_seed(masterrng)) for _ in range(4))
    patch_seed = rand_seed(masterrng)
    print("lazy patch seed:", patch_seed)
    patchrng = RNGWrapper("lazy", np.random.RandomState(patch_seed))

    stages = [
        Resized(keys=keys, spatial_size=(240, 240, 72), mode=("bilinear", "nearest"), lazy=lazy),
        RandFlipd(keys=keys, prob=0.5, spatial_axis=[1, 2], lazy=lazy, state=randfliprng),
        RandRotate90d(keys=keys, prob=0.5, spatial_axes=(0, 1), lazy=lazy, state=rotate90rng),
        RandZoomd(keys=keys, prob=1.0, min_zoom=0.75, max_zoom=1.25, mode=("bilinear", "nearest"),
                  keep_size=True, lazy=lazy, state=zoomrng),
        RandRotated(keys=keys, prob=1.0, range_z=(-torch.pi / 4, torch.pi / 4), mode=("bilinear", "nearest"),
                    align_corners=True, lazy=lazy, state=rotaterng),
        RandCropPadd(keys=keys, sizes=(160, 160, 72), lazy=lazy, state=patchrng),
    ]
    return Compose(stages)


def lazy_pipeline_label_only(lazy=True):
    """Lazy counterpart of ``trad_pipeline_label_only`` (nearest-neighbour throughout).

    Seeds are drawn from a master RandomState(12345678) in the order flip,
    rotate90, zoom, rotate.
    """
    masterrng = np.random.RandomState(12345678)

    print("lazy_pipeline_label_only")
    flip_state, rot90_state, zoom_state, rotate_state = (
        np.random.RandomState(rand_seed(masterrng)) for _ in range(4))

    stages = [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Resize(spatial_size=(192, 192, 72), mode="nearest", lazy=lazy),
        RandFlip(prob=0.5, spatial_axis=[1, 2], lazy=lazy, state=flip_state),
        RandRotate90(prob=0.5, spatial_axes=(0, 1), lazy=lazy, state=rot90_state),
        RandZoom(prob=1.0, min_zoom=0.75, max_zoom=1.25, mode="nearest", keep_size=True, lazy=lazy,
                 state=zoom_state),
        RandRotate(prob=1.0, range_z=(-torch.pi/4, torch.pi/4), mode="nearest", align_corners=True, lazy=lazy,
                   state=rotate_state),
    ]
    return Compose(stages)

# Interpolation modes

In [None]:
data_1d = np.expand_dims(np.linspace(0, 31, 32), 0)
data_2d = get_img((32, 24))
data_3d = get_img((32, 24, 16))

# Compare torch interpolation modes by upsampling the 2D ramp image 2x.
# The original cell always raised: interpolate() was given an empty list, and
# datas.append() was called with no argument.
# F.interpolate expects a batched tensor (N, C, H, W), so add a batch axis.
interp_modes = ('nearest', 'bilinear', 'bicubic', 'area')
src_2d = torch.as_tensor(data_2d, dtype=torch.float32).unsqueeze(0)

datas = [data_2d]
for m in interp_modes:
    # align_corners is only meaningful (and only accepted) for the linear/cubic modes.
    extra = {'align_corners': False} if m in ('bilinear', 'bicubic') else {}
    result = torch.nn.functional.interpolate(src_2d, scale_factor=2, mode=m, **extra)
    datas.append(result[0].numpy())

plot_datas(datas, cols=5, axis=True, titles=('base',) + interp_modes)


# Pad Modes

In [None]:
data_2d = get_img((32, 24))
data_2d[0, 7:9, 7:9] = 1096
data_2d[0, 15:17, :] = 1160
data_2d[0, 0, :] = 1224
data_2d[0, :, 0] = 1286
data_2d = np.squeeze(data_2d, axis=0)

# Pad by `border` (rows, cols) on *both* sides of each axis. np.pad(x, (16, 12))
# would pad 16 before / 12 after on every axis. That gave a 60x52 result, while the
# scipy panels used a 64x48 canvas with a different shift, so they could not be compared.
border = (16, 12)
pad_width = tuple((b, b) for b in border)
padded_shape = tuple(s + 2 * b for s, b in zip(data_2d.shape, border))

datas = [data_2d]
np_padding_modes = ('constant', 'edge', 'linear_ramp', 'maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric', 'wrap', 'empty')
for p in np_padding_modes:
    datas.append(np.pad(data_2d, pad_width, mode=p))

# scipy.ndimage has no pad function. Emulate padding with a pure translation by
# -border (homogeneous 3x3 matrix) onto the padded canvas; `mode` fills the
# out-of-bounds samples.
ndi_padding_modes = ('constant', 'grid-constant', 'nearest', 'reflect', 'mirror', 'grid-mirror', 'wrap', 'grid-wrap')
shift = np.eye(3)
shift[:2, 2] = [-b for b in border]
for p in ndi_padding_modes:
    datas.append(ndimage.affine_transform(data_2d, shift, output_shape=padded_shape, mode=p, order=0))

plot_datas(datas, cols=5, axis=True, titles=('base',) + np_padding_modes + ndi_padding_modes)

# Rotate: old vs. new (eager and lazy)

In [None]:
data = get_img((32, 32))
print(data.shape)

angle = torch.pi / 8
datas = [data]

# Reference: the old Rotate implementation.
r1 = old.Rotate(angle, keep_size=True, padding_mode="zeros")
data1 = r1(data)
datas.append(data1)
print("data1:", data1.shape)

# New Rotate, executed eagerly.
r2 = Rotate(angle, padding_mode="zeros", keep_size=True, lazy=False)
data2 = r2(data)
datas.append(data2)
print("data2:", data2.shape)

# New Rotate, executed lazily. Laziness is requested through the constructor: the
# old `r3.lazy_evalution = True` was a misspelled, unused attribute, so r3 ran eagerly.
# apply_pending returns (data, pending_ops); only the data is needed here.
r3 = Rotate(angle, padding_mode="zeros", keep_size=True, lazy=True)
data3a = r3(data)
data3, _ = apply_pending(data3a)
datas.append(data3)
print("data3:", data3.shape)

diff2_1 = data2 - data1
diff3_1 = data3 - data1
diff3_2 = data3 - data2
datas += [diff2_1, diff3_1, diff3_2]
print(np.unique(diff3_1, return_counts=True))
plot_datas(datas)

# Rotate: interpolation mode and keep_size

In [None]:
data = get_img((32, 32))
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)

angle = torch.pi / 8
# (mode, keep_size) pairs, applied in this order by both implementations.
rotate_variants = (("nearest", True), ("bilinear", True), ("nearest", False), ("bilinear", False))


def old_rotate(keep_size=True):
    """Return a runner that applies the old Rotate with each of ``rotate_variants``.

    ``keep_size`` is accepted for call compatibility but ignored: both keep_size
    settings are always exercised via ``rotate_variants``. The runner maps an
    image to ``[image, *outputs]``.
    """
    transforms = [old.Rotate(angle, mode=mode, padding_mode="zeros", keep_size=keep)
                  for mode, keep in rotate_variants]

    def _inner(imgs):
        return [imgs] + [t(imgs) for t in transforms]

    return _inner


def new_rotate(keep_size=True):
    """Same as ``old_rotate`` but with the new (eager) Rotate; prints the first output's affine."""
    transforms = [Rotate(angle, mode=mode, padding_mode="zeros", keep_size=keep, lazy=False)
                  for mode, keep in rotate_variants]

    def _inner(imgs):
        outputs = [t(imgs) for t in transforms]
        print(outputs[0].affine)  # was printed twice (leftover debugging)
        return [imgs] + outputs

    return _inner


old_results = old_rotate(False)(data)
new_results = new_rotate(False)(data)

diffs = [n - o for o, n in zip(old_results, new_results)]

plot_datas(old_results + new_results + diffs, cols=5)

# Rotate by 90 degrees

In [None]:
# Ramp image with markers (2x2 block, two-row band, first row, first column).
data = get_img((32, 32))
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)

# Flip the rows, then swap the two spatial axes: together they compose a
# 90-degree rotation. Each stage is kept for side-by-side display.
datas = [torch.tensor(data)]
datas.append(datas[-1].flip([1]))
datas.append(datas[-1].permute((0, 2, 1)))
data = datas[-1]

plot_datas(datas, cols=3)

In [None]:
data = get_img((32, 32))
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)

# (angle, mode, keep_size) for each transform, in display order; all angles are multiples of pi/2.
quarter_turn_configs = (
    (torch.pi / 2, "nearest", True),
    (torch.pi, "bilinear", True),
    (3 * torch.pi / 2, "nearest", False),
    (2 * torch.pi, "bilinear", False),
)


def old_rotate(keep_size=True):
    """Return a runner that applies the old Rotate with each of ``quarter_turn_configs``.

    ``keep_size`` is accepted for call compatibility but ignored: the keep_size of
    each transform comes from ``quarter_turn_configs``. The runner maps an image
    to ``[image, *outputs]``.
    """
    transforms = [old.Rotate(a, mode=m, padding_mode="zeros", keep_size=k)
                  for a, m, k in quarter_turn_configs]

    def _inner(imgs):
        return [imgs] + [t(imgs) for t in transforms]

    return _inner


def new_rotate(keep_size=True):
    """Same as ``old_rotate`` but with the new (eager) Rotate; prints the first output's affine."""
    transforms = [Rotate(a, mode=m, padding_mode="zeros", keep_size=k, lazy=False)
                  for a, m, k in quarter_turn_configs]

    def _inner(imgs):
        outputs = [t(imgs) for t in transforms]
        print(outputs[0].affine)  # was printed twice (leftover debugging)
        return [imgs] + outputs

    return _inner


old_results = old_rotate(False)(data)
new_results = new_rotate(False)(data)

diffs = [n - o for o, n in zip(old_results, new_results)]

plot_datas(old_results + new_results + diffs, cols=5)

# Zoom

In [None]:
data = get_img((32, 32))
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)

# (zoom factor, mode, padding_mode) per transform, in display order. The padding-mode
# names differ between implementations for the 0.5 zooms ('constant' vs 'zeros'), as in
# the original comparison. NOTE(review): confirm the two settings are equivalent.
old_zoom_configs = ((2, "nearest", "zeros"), (2, "bilinear", "zeros"),
                    (0.5, "nearest", "constant"), (0.5, "bilinear", "constant"))
new_zoom_configs = ((2, "nearest", "zeros"), (2, "bilinear", "zeros"),
                    (0.5, "nearest", "zeros"), (0.5, "bilinear", "zeros"))


def old_zoom(keep_size=True):
    """Return a runner mapping an image to ``[image, *outputs]`` for the old Zoom configs."""
    transforms = [old.Zoom(z, mode=m, padding_mode=p, keep_size=keep_size) for z, m, p in old_zoom_configs]

    def _inner(imgs):
        return [imgs] + [t(imgs) for t in transforms]

    return _inner


def new_zoom(keep_size=True):
    """Same as ``old_zoom`` with the new (eager) Zoom; prints the first output's affine."""
    transforms = [Zoom(z, mode=m, padding_mode=p, keep_size=keep_size, lazy=False)
                  for z, m, p in new_zoom_configs]

    def _inner(imgs):
        outputs = [t(imgs) for t in transforms]
        print(outputs[0].affine)  # was printed twice (leftover debugging)
        return [imgs] + outputs

    return _inner


old_results = old_zoom(True)(data)
new_results = new_zoom(True)(data)

# Difference images remain disabled, as in the original cell. TODO: confirm old/new
# outputs are comparable and re-enable with [n - o for o, n in zip(old_results, new_results)].
diffs = []

plot_datas(old_results + new_results + diffs, axis=True, cols=5)

# Resize

In [None]:
data = get_img((32, 32))
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)

zoom_in = (48, 48)
zoom_out = (24, 24)


def old_resize():
    """Return a runner: image -> [image, *old Resize outputs] for in/out sizes x nearest/bilinear."""
    transforms = [old.Resize(target, mode=m)
                  for target, m in ((zoom_in, "nearest"), (zoom_in, "bilinear"),
                                    (zoom_out, "nearest"), (zoom_out, "bilinear"))]

    def _inner(imgs):
        return [imgs] + [t(imgs) for t in transforms]

    return _inner


def new_resize():
    """Same as ``old_resize`` using the new Resize with lazy=False."""
    transforms = [Resize(target, mode=m, lazy=False)
                  for target, m in ((zoom_in, "nearest"), (zoom_in, "bilinear"),
                                    (zoom_out, "nearest"), (zoom_out, "bilinear"))]

    def _inner(imgs):
        return [imgs] + [t(imgs) for t in transforms]

    return _inner


old_results = old_resize()(data)
new_results = new_resize()(data)

# Difference images are disabled in this comparison.
diffs = []

plot_datas(old_results + new_results + diffs, cols=5)

# Spacing

In [None]:
data = get_img((32, 32))
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)

dest_pixdim = (1.0, 1.0)
big_src_pixdim = (2.0, 2.0)
sml_src_pixdim = (0.5, 0.5)


def old_spacing():
    """Return a runner: image -> [image, *old Spacing outputs].

    The image is resampled to ``dest_pixdim`` from a big then a small source pixdim,
    each with nearest and bilinear modes. The source pixdim is encoded by replacing
    the MetaTensor's affine with a diagonal matrix.
    """
    transforms = [old.Spacing(dest_pixdim, mode=m) for m in ("nearest", "bilinear", "nearest", "bilinear")]

    def _with_affine(imgs, pixdim):
        mt = MetaTensor(imgs)
        mt.affine = torch.tensor([[pixdim[0], 0.0, 0.0, 0.0],
                                  [0.0, pixdim[1], 0.0, 0.0],
                                  [0.0, 0.0, 1.0, 0.0],
                                  [0.0, 0.0, 0.0, 1.0]])
        return mt

    def _inner(imgs):
        src_pixdims = (big_src_pixdim, big_src_pixdim, sml_src_pixdim, sml_src_pixdim)
        return [imgs] + [t(_with_affine(imgs, pix)) for t, pix in zip(transforms, src_pixdims)]

    return _inner


def new_spacing():
    """Same as ``old_spacing`` but with the new Spacing (lazy=False).

    Here the source pixdim is written into the diagonal of the default affine in place.
    """
    transforms = [Spacing(dest_pixdim, mode=m, lazy=False) for m in ("nearest", "bilinear", "nearest", "bilinear")]

    def _with_pixdim(imgs, pixdim):
        mt = MetaTensor(imgs)
        mt.affine[0, 0] = pixdim[0]
        mt.affine[1, 1] = pixdim[1]
        return mt

    def _inner(imgs):
        src_pixdims = (big_src_pixdim, big_src_pixdim, sml_src_pixdim, sml_src_pixdim)
        return [imgs] + [t(_with_pixdim(imgs, pix)) for t, pix in zip(transforms, src_pixdims)]

    return _inner


old_results = old_spacing()(data)
new_results = new_spacing()(data)

for d in old_results + new_results:
    print(d.shape)

# Difference images are disabled in this comparison.
diffs = []

plot_datas(old_results + new_results + diffs, cols=5)

# Flip

In [None]:
data = get_img((32, 32))
# Intensity markers (2x2 block, 2-wide band, first row, first column) make the
# flip direction visible in the plots.
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)


def old_resize():
    """Old-API Flip over spatial axis 0, axis 1 and both axes.

    NOTE: despite the name (shared with other cells), this builds Flip
    transforms. The returned callable maps ``imgs`` to
    ``[imgs, flip_0, flip_1, flip_01]``.
    """
    flips = [old.Flip(axes) for axes in (0, 1, (0, 1))]

    def _inner(imgs):
        return [imgs, *(flip(imgs) for flip in flips)]

    return _inner


def new_resize():
    """New-API (non-lazy) counterpart of ``old_resize``: Flip over axis 0, axis 1 and both."""
    flips = [Flip(axes, lazy=False) for axes in (0, 1, (0, 1))]

    def _inner(imgs):
        return [imgs, *(flip(imgs) for flip in flips)]

    return _inner


old_results = old_resize()(data)
new_results = new_resize()(data)

diffs = []
# for o, n in zip(old_results, new_results):
#     diffs.append(n - o)

plot_datas(old_results + new_results + diffs, cols=4, axis=True)

# Rotate90

In [None]:
data = get_img((32, 24))
# Non-square input plus intensity markers (2x2 block, 2-wide band, first row,
# first column) so every quarter turn is distinguishable in the plots.
data[0, 7:9, 7:9] = 1096
data[0, 15:17, :] = 1160
data[0, 0, :] = 1224
data[0, :, 0] = 1286
print(data.shape)


def old_resize():
    """Old-API Rotate90 with k = 0, 1, 2 and 3 quarter turns.

    NOTE: the name is shared with the Flip cell; this definition builds
    Rotate90 transforms and rebinds that name. The returned callable maps
    ``imgs`` to ``[imgs, rot_k0, rot_k1, rot_k2, rot_k3]``.
    """
    rotations = [old.Rotate90(k) for k in range(4)]

    def _inner(imgs):
        outputs = [imgs]
        for rotate in rotations:
            outputs.append(rotate(imgs))
        return outputs

    return _inner


def new_resize():
    """New-API (non-lazy) counterpart of ``old_resize``: Rotate90 with k = 0..3."""
    rotations = [Rotate90(k, lazy=False) for k in range(4)]

    def _inner(imgs):
        outputs = [imgs]
        for rotate in rotations:
            outputs.append(rotate(imgs))
        return outputs

    return _inner


old_results = old_resize()(data)
new_results = new_resize()(data)

diffs = []
# for o, n in zip(old_results, new_results):
#     diffs.append(n - o)

plot_datas(old_results + new_results + diffs, cols=5)

# Rotate / Crop

In [None]:
data = get_img((32, 32))
# Intensity markers (2x2 block, 2-wide band, first row, first column).
data[0,7:9,7:9] = 1096
data[0,15:17,:] = 1160
data[0,0,:] = 1224
data[0,:,0] = 1286
print(data.shape)

# Crop window used by every rotate/crop variant below.
crop_slices = (slice(0, 16), slice(0, 16))


def _two_step(first, second):
    """Return a callable mapping ``imgs`` to ``[imgs, first(imgs), second(first(imgs))]``."""
    def _inner(imgs):
        stage1 = first(imgs)
        return [imgs, stage1, second(stage1)]
    return _inner


def old_rotate_then_crop():
    """Old API: rotate by pi/4 with keep_size=False, then crop ``crop_slices``."""
    c1 = Crop()
    r1 = old.Rotate(torch.pi / 4, keep_size=False, padding_mode="zeros")
    return _two_step(r1, lambda x: c1(x, slices=crop_slices))


def old_crop_then_rotate():
    """Old API: crop ``crop_slices``, then rotate by pi/4 with keep_size=False."""
    c1 = Crop()
    r1 = old.Rotate(torch.pi / 4, keep_size=False, padding_mode="zeros")
    return _two_step(lambda x: c1(x, slices=crop_slices), r1)


def new_rotate_then_crop(lazy=False):
    """New API: rotate (lazily if ``lazy``), then crop with a non-lazy CropPad."""
    r1 = Rotate(torch.pi / 4, keep_size=False, padding_mode="zeros", lazy=lazy)
    c1 = CropPad(padding_mode="zeros", lazy=False)
    return _two_step(r1, lambda x: c1(x, slices=crop_slices))


def new_crop_then_rotate(lazy=False):
    """New API: crop (lazily if ``lazy``), then rotate with a non-lazy Rotate."""
    c1 = CropPad(padding_mode="zeros", lazy=lazy)
    r1 = Rotate(torch.pi / 4, keep_size=False, padding_mode="zeros", lazy=False)
    return _two_step(lambda x: c1(x, slices=crop_slices), r1)


# crops = (old_rotate_then_crop(), new_rotate_then_crop(False), new_rotate_then_crop(True),
#          old_crop_then_rotate(), new_crop_then_rotate(False), new_crop_then_rotate(True))
crops = (old_rotate_then_crop(), old_crop_then_rotate(), new_crop_then_rotate(True))
for t in crops:
    datas = t(data)
    # Print the shape after every stage; datas[0] alone is always the unchanged input.
    print([tuple(x.shape) for x in datas])
    plot_datas(datas, 3)


# Functional croppad

In [None]:
def do_functional_croppad():
    """Compare the functional ``croppad`` (materialised via ``apply_pending``) with ``Crop``.

    Plots the input, an 8x8 crop at the origin, the same crop shifted by one
    along the first spatial axis, and that shifted crop produced by ``Crop``.
    """
    data = get_img((16, 16))
    data[0,7:9,7:9] = 1096
    # NOTE(review): on a 16-wide axis, 15:17 only reaches index 15 (the last
    # row); these marker positions look copied from the 32x32 cells — confirm.
    data[0,15:17,:] = 1160
    data[0,0,:] = 1224
    data[0,:,0] = 1286
    print(data.shape)

    # apply_pending returns a (data, pending) pair (cf. ``apply_pending(data1)[0]``
    # in the Rotate90 cell), so keep only the materialised tensor for plotting.
    img00 = croppad(data, slices=(slice(0, 8), slice(0, 8)), padding_mode="zeros")
    actual00 = apply_pending(img00)[0]

    img10 = croppad(data, slices=(slice(1, 9), slice(0, 8)), padding_mode="zeros")
    actual10 = apply_pending(img10)[0]

    c = Crop()
    actual10_2 = c(data, slices=(slice(1, 9), slice(0, 8)))

    plot_datas([data, actual00, actual10, actual10_2])

do_functional_croppad()

# Trad vs lazy - forward pass results

## Entropy - whole volume

In [None]:
# entropy test
import numpy as np

entropy_vals = list()


def entropy_test(samples):
    """Plot one sample's image intensity histogram and return its Shannon entropy.

    Parameters
    ----------
    samples
        A single dataset entry (despite the plural name); ``samples[1][0]`` is
        the path of the image file, matching how the patch-first cell reads
        entries.

    Returns
    -------
    float
        Entropy, in bits, of the normalised 256-bin histogram of all voxels.
    """
    img_data = nib.load(samples[1][0]).get_fdata()
    counts, edges = np.histogram(img_data, bins=256)
    p = counts / np.sum(counts)
    print(np.sum(p))
    # Empty bins contribute 0 (the limit of p*log p as p -> 0); dropping them
    # avoids evaluating log2(0), which emits divide/invalid RuntimeWarnings.
    nonzero = p[p > 0]
    e = -np.sum(nonzero * np.log2(nonzero))
    print(e)
    fig, ax = plt.subplots(1, 1, figsize=(24, 8))
    ax.set_yscale('log')
    ax.plot(edges[:-1], p)
    return e


entropy_test(entries_[0]);

In [None]:
import time

import numpy as np
import nibabel as nib

def show_images(sample):
    """Run the traditional and lazy whole-volume pipelines on one sample and plot slices.

    Each of the ``num_samples`` iterations appends three 2D slices (channel 0)
    to the plot grid: the input at the mid slice from ``find_mid_label_z``, and
    the traditional and lazy outputs at the last-axis index taken from
    ``find_mid_label`` on each output. Cumulative time spent in each pipeline is
    printed, and the first three slices are stored in
    ``display_images['forward_segs_whole']``.
    """
    img, lbl = load_sample(sample)
    ddict = {'image': img, 'label': lbl}

    tp = trad_pipeline()
    lp = lazy_pipeline(True)

    # NOTE(review): the *_label helpers are applied to the image, not the
    # label, here and below — confirm this is intended.
    pre_first_z, pre_last_z, pre_mid_slice = find_mid_label_z(ddict['image'])

    num_samples = 8
    print(pre_first_z, pre_last_z, pre_mid_slice)
    results = []
    t_time = 0
    l_time = 0
    for _ in range(num_samples):
        t_start = time.time()
        t_out = tp(ddict)
        t_time += time.time() - t_start
        t_slice_idx = find_mid_label(t_out['image'])[2][2]

        l_start = time.time()
        l_out = lp(ddict)
        l_time += time.time() - l_start
        l_slice_idx = find_mid_label(l_out['image'])[2][2]

        results.extend([ddict['image'][0, ..., pre_mid_slice],
                        t_out['image'][0, ..., t_slice_idx],
                        l_out['image'][0, ..., l_slice_idx]])

    print(f"trad time: {t_time}, lazy time: {l_time}")
    plot_datas(results, 3, tight=True, size=12)
    display_images['forward_segs_whole'] = copy.deepcopy(results[0:3])

show_images(entries_[0])

In [None]:
import time

def check_entropy_whole_pipeline(sample): # base_dir, sample_str):
    """Entropy and timing of the whole-volume trad vs lazy pipelines for one sample.

    Only image channel 0 is kept; both pipelines are run ``num_samples`` times
    on the same input.

    Returns
    -------
    tuple
        ``(e_origs, e_trads, e_lazys, t_times, l_times)``, one value per run:
        entropy of the input, traditional and lazy output (channel 0) and the
        wall-clock seconds of each pipeline call.
    """
    img, lbl = load_sample(sample)
    ddict = {'image': img[0:1, ...], 'label': lbl}

    tp = trad_pipeline()
    lp = lazy_pipeline(True)

    num_samples = 8
    e_origs, e_trads, e_lazys = list(), list(), list()
    t_times, l_times = list(), list()
    for _ in range(num_samples):
        t_start = time.time()
        t_out = tp(ddict)
        t_times.append(time.time() - t_start)

        l_start = time.time()
        l_out = lp(ddict)
        l_times.append(time.time() - l_start)

        # Entropies are computed per run (outside the timed regions) rather than
        # holding every output volume until the loop ends; this also keeps the
        # number of results tied to num_samples instead of a hard-coded 8.
        e_origs.append(entropy(ddict['image'][0, ...]))
        e_trads.append(entropy(t_out['image'][0, ...]))
        e_lazys.append(entropy(l_out['image'][0, ...]))

    return e_origs, e_trads, e_lazys, t_times, l_times

# Sanity checks of the entropy helper on two simple inputs.
print(entropy(np.arange(256)))
print(entropy(np.concatenate([np.arange(128), np.arange(128)])))

all_e_origs = list()
all_e_trads = list()
all_e_lazys = list()
all_t_times = list()
all_l_times = list()
d = Dots(50)
for entry in entries_:
    d.dot()
    e_origs, e_trads, e_lazys, t_times, l_times = check_entropy_whole_pipeline(entry)

    all_e_origs.append(e_origs)
    all_e_trads.append(e_trads)
    all_e_lazys.append(e_lazys)
    all_t_times.append(t_times)
    all_l_times.append(l_times)
d.done()
# Shape (num_entries, num_samples) for each array.
final_e_origs = np.asarray(all_e_origs)
final_e_trads = np.asarray(all_e_trads)
final_e_lazys = np.asarray(all_e_lazys)
final_t_times = np.asarray(all_t_times)
final_l_times = np.asarray(all_l_times)

print(final_e_origs.shape, final_e_trads.shape, final_e_lazys.shape, final_t_times.shape, final_l_times.shape)


## Result: Trad vs lazy entropy for the same 8 random patch operations across all samples

In [None]:
# Per entry: mean/std of entropy over the runs (axis 1); then averaged over entries.
trad_means, trad_stds = final_e_trads.mean(axis=1), final_e_trads.std(axis=1)
lazy_means, lazy_stds = final_e_lazys.mean(axis=1), final_e_lazys.std(axis=1)

for name, means, stds in (("trad", trad_means, trad_stds), ("lazy", lazy_means, lazy_stds)):
    print(f"{name}: {means.mean()} +/- {stds.mean()}")

# Mean wall-clock seconds per pipeline call over all entries and runs.
trad_time, lazy_time = final_t_times.mean(), final_l_times.mean()
print(trad_time, lazy_time)

In [None]:
# entropy_by_vol =np.asarray([
#     [-0.5254445716988609, -0.4482721994159677],
#     [-0.5053248873702056, -0.4289296268729898],
#     [-0.5257666333637427, -0.44651717009131064],
#     [-0.4185541374082884, -0.36717499347126215],
#     [-0.6223610710233693, -0.5163025613180311],
#     [-0.5706122085389952, -0.48348110324083005],
#     [-0.5901356124892108, -0.4983918712111845],
#     [-0.45334079076036055, -0.38482593263728376],
# ])

# print(entropy_by_vol[:, 0])
# print(entropy_by_vol[:, 1])
# print("mean:", entropy_by_vol[:, 0].mean(), entropy_by_vol[:, 1].mean())
# print("std:", entropy_by_vol[:, 0].std(), entropy_by_vol[:, 1].std())

## Entropy - patch first

In [None]:
import time

import numpy as np
import nibabel as nib

def check_entropy_patch_first_pipeline(sample):
    """Plot slices and print entropies for the patch-first trad vs lazy pipelines.

    ``sample[1][0]`` / ``sample[1][1]`` are the image and label file paths. The
    image's last axis is moved to the front and the label gets a leading axis.
    Each of ``num_samples`` runs contributes one entropy triple (channel 0 of
    input, traditional output and lazy output) and three plotted slices.

    NOTE: the next cell defines a function with the same name, which replaces
    this one after it runs.
    """
    img = nib.load(sample[1][0])
    lbl = nib.load(sample[1][1])

    ddict = {'image': np.transpose(img.get_fdata(), axes=(3, 0, 1, 2)),
             'label': np.expand_dims(lbl.get_fdata(), axis=0)}

    print(ddict['image'].shape, ddict['label'].shape)

    tp = trad_pipeline_patch_first()
    lp = lazy_pipeline_patch_first(True)

    pre_first_z, pre_last_z, pre_mid_slice = find_mid_label_z(ddict['image'])

    num_samples = 8
    print(pre_first_z, pre_last_z, pre_mid_slice)
    entropies = []
    results = []
    t_time = 0
    l_time = 0
    for _ in range(num_samples):
        t_start = time.time()
        t_out = tp(ddict)
        t_time += time.time() - t_start
        t_label_exts = find_mid_label(t_out['image'])

        l_start = time.time()
        l_out = lp(ddict)
        l_time += time.time() - l_start
        l_label_exts = find_mid_label(l_out['image'])

        # Use channel 0 of every volume so the three entropies are comparable.
        entropies.append((entropy(ddict['image'][0, ...]),
                          entropy(t_out['image'][0, ...]),
                          entropy(l_out['image'][0, ...])))

        results.extend([ddict['image'][0, ..., pre_mid_slice],
                        t_out['image'][0, ..., t_label_exts[2][2]],
                        l_out['image'][0, ..., l_label_exts[2][2]]])

    for e_orig, e_trad, e_lazy in entropies:
        print("ovol:", e_orig,
              "tvol:", e_trad,
              "lvol:", e_lazy)

    print(f"trad time: {t_time}, lazy time: {l_time}")
    plot_datas(results, 3, tight=True, size=10)
    display_images['forward_images_patch'] = copy.deepcopy(results[6:9])
    
check_entropy_patch_first_pipeline(entries_[0])

In [None]:
import time

import numpy as np
import nibabel as nib

def check_entropy_patch_first_pipeline(sample):
    """Entropy and timing of the patch-first trad vs lazy pipelines for one sample.

    Rebinds the name defined in the previous cell. Only image channel 0 is
    kept; both pipelines are run ``num_samples`` times on the same input.

    Returns
    -------
    tuple
        ``(e_origs, e_trads, e_lazys, t_times, l_times)``, one value per run:
        entropy of the input, traditional and lazy output (channel 0) and the
        wall-clock seconds of each pipeline call.
    """
    img, lbl = load_sample(sample)
    ddict = {'image': img[0:1, ...], 'label': lbl}

    tp = trad_pipeline_patch_first()
    lp = lazy_pipeline_patch_first(True)

    num_samples = 8
    e_origs, e_trads, e_lazys = list(), list(), list()
    t_times, l_times = list(), list()
    for _ in range(num_samples):
        t_start = time.time()
        t_out = tp(ddict)
        t_times.append(time.time() - t_start)

        l_start = time.time()
        l_out = lp(ddict)
        l_times.append(time.time() - l_start)

        # Entropies are computed per run (outside the timed regions) rather than
        # holding every output volume until the loop ends; this also keeps the
        # number of results tied to num_samples instead of a hard-coded 8.
        e_origs.append(entropy(ddict['image'][0, ...]))
        e_trads.append(entropy(t_out['image'][0, ...]))
        e_lazys.append(entropy(l_out['image'][0, ...]))

    return e_origs, e_trads, e_lazys, t_times, l_times

all_e_origs = list()
all_e_trads = list()
all_e_lazys = list()
all_t_times = list()
all_l_times = list()
d = Dots(50)
for entry in entries_:
    d.dot()
    e_origs, e_trads, e_lazys, t_times, l_times = check_entropy_patch_first_pipeline(entry)

    all_e_origs.append(e_origs)
    all_e_trads.append(e_trads)
    all_e_lazys.append(e_lazys)
    all_t_times.append(t_times)
    all_l_times.append(l_times)
d.done()
# Shape (num_entries, num_samples) for each array.
final_e_origs = np.asarray(all_e_origs)
final_e_trads = np.asarray(all_e_trads)
final_e_lazys = np.asarray(all_e_lazys)
final_t_times = np.asarray(all_t_times)
final_l_times = np.asarray(all_l_times)

print(final_e_origs.shape, final_e_trads.shape, final_e_lazys.shape, final_t_times.shape, final_l_times.shape)


In [None]:
# Per entry: mean/std of entropy over the runs (axis 1); then averaged over entries.
summary = {
    "trad": (final_e_trads.mean(axis=1), final_e_trads.std(axis=1)),
    "lazy": (final_e_lazys.mean(axis=1), final_e_lazys.std(axis=1)),
}
trad_means, trad_stds = summary["trad"]
lazy_means, lazy_stds = summary["lazy"]

for name, (means, stds) in summary.items():
    print(f"{name}: {means.mean()} +/- {stds.mean()}")

# Mean wall-clock seconds per pipeline call over all entries and runs.
trad_time = final_t_times.mean()
lazy_time = final_l_times.mean()
print(trad_time, lazy_time)

In [None]:
# Hard-coded per-run entropies (not recomputed by this cell), kept under the
# patch-first section. Presumably column 0 = traditional and column 1 = lazy
# pipeline — TODO confirm.
entropy_by_vol = np.asarray([
    [-0.6474312370041032, -0.5746644873304773],
    [-0.5820207582029828, -0.5154538747569736],
    [-0.6303136122149402, -0.5550618036447265],
    [-0.5225025523734912, -0.45886897465822146],
    [-0.7313230393674943, -0.6375291263882021],
    [-0.6009636467276522, -0.5386724596584971],
    [-0.658219893839086, -0.5855626119150623],
    [-0.5656265423842755, -0.481065145373282],
])

first_col, second_col = entropy_by_vol.T
print(first_col)
print(second_col)
print("mean:", first_col.mean(), second_col.mean())
print("std:", first_col.std(), second_col.std())

## Entropy - patch last

In [None]:
import time

import numpy as np
import nibabel as nib

def check_entropy_patch_last_pipeline(sample=None):
    """Plot slices and print entropies for the patch-last trad vs lazy pipelines.

    Parameters
    ----------
    sample : optional
        Dataset entry whose ``sample[1][0]`` / ``sample[1][1]`` are the image
        and label file paths. When omitted, BRATS_001 under the hard-coded base
        directory is used (the previous behaviour).

    The image's last axis is moved to the front and the label gets a leading
    axis. Each of ``num_samples`` runs contributes one entropy triple (channel
    0 of input, traditional output and lazy output) and three plotted slices;
    per-run extents are printed for debugging.
    """
    if sample is None:
        base_dir = '/home/ben/data/preprocessed/Task01_BrainTumour/orig'
        sample_str = 'BRATS_{}_{}.nii.gz'
        img_path = os.path.join(base_dir, sample_str.format('001', 'image'))
        lbl_path = os.path.join(base_dir, sample_str.format('001', 'label'))
    else:
        img_path, lbl_path = sample[1][0], sample[1][1]

    img = nib.load(img_path)
    lbl = nib.load(lbl_path)

    ddict = {'image': np.transpose(img.get_fdata(), axes=(3, 0, 1, 2)),
             'label': np.expand_dims(lbl.get_fdata(), axis=0)}

    tp = trad_pipeline_patch_last()
    lp = lazy_pipeline_patch_last(True)

    pre_first_z, pre_last_z, pre_mid_slice = find_mid_label_z(ddict['image'])
    p_label_exts = find_mid_label(ddict['image'])

    num_samples = 8
    print(pre_first_z, pre_last_z, pre_mid_slice)
    entropies = []
    results = []
    t_time = 0
    l_time = 0
    for _ in range(num_samples):
        t_start = time.time()
        t_out = tp(ddict)
        t_time += time.time() - t_start
        t_label_exts = find_mid_label(t_out['image'])
        print("t_label_exts:", t_label_exts)
        t_slice_x = slice(*sanitized_range_from_extents(t_label_exts[0][2], t_out['image'].shape[1], 100))
        t_slice_y = slice(*sanitized_range_from_extents(t_label_exts[1][2], t_out['image'].shape[2], 100))
        print("extents:", t_slice_x, t_slice_y)

        l_start = time.time()
        l_out = lp(ddict)
        l_time += time.time() - l_start
        print("l_out:", l_out['image'].shape)
        l_label_exts = find_mid_label(l_out['image'])
        print("l_label_exts:", l_label_exts)
        l_slice_x = slice(*sanitized_range_from_extents(l_label_exts[0][2], l_out['image'].shape[1], 100))
        l_slice_y = slice(*sanitized_range_from_extents(l_label_exts[1][2], l_out['image'].shape[2], 100))
        print("extents:", l_slice_x, l_slice_y)

        # Use channel 0 of every volume so the three entropies are comparable.
        entropies.append((entropy(ddict['image'][0, ...]),
                          entropy(t_out['image'][0, ...]),
                          entropy(l_out['image'][0, ...])))

        results.extend([ddict['image'][0, ..., p_label_exts[2][2]],
                        t_out['image'][0, ..., t_label_exts[2][2]],
                        l_out['image'][0, ..., l_label_exts[2][2]]])

    for e_orig, e_trad, e_lazy in entropies:
        print("ovol:", e_orig,
              "tvol:", e_trad,
              "lvol:", e_lazy)

    print(f"trad time: {t_time}, lazy time: {l_time}")
    plot_datas(results, 3, tight=True, size=10)
    
check_entropy_patch_last_pipeline()

In [None]:
# Hard-coded per-run entropies (not recomputed by this cell). Presumably
# column 0 = traditional and column 1 = lazy pipeline — TODO confirm.
# NOTE(review): these values are identical to the table in the patch-first
# section; they look copy-pasted rather than measured for patch-last — re-run
# and update.
entropy_by_vol = np.asarray([
    [-0.6474312370041032, -0.5746644873304773],
    [-0.5820207582029828, -0.5154538747569736],
    [-0.6303136122149402, -0.5550618036447265],
    [-0.5225025523734912, -0.45886897465822146],
    [-0.7313230393674943, -0.6375291263882021],
    [-0.6009636467276522, -0.5386724596584971],
    [-0.658219893839086, -0.5855626119150623],
    [-0.5656265423842755, -0.481065145373282],
])

columns = entropy_by_vol.T
for col in columns:
    print(col)
print("mean:", columns[0].mean(), columns[1].mean())
print("std:", columns[0].std(), columns[1].std())

# Entropy - all numbers

In [None]:
# import time

# import numpy as np
# import nibabel as nib

# def check_entropy_whole_pipeline():
#     base_dir = '/home/ben/data/preprocessed/Task01_BrainTumour/orig'
#     sample_str = 'BRATS_{}_{}.nii.gz'
#     sample = '001'

#     img = nib.load(os.path.join(base_dir, sample_str.format(sample, 'image')))
#     lbl = nib.load(os.path.join(base_dir, sample_str.format(sample, 'label')))

#     ddict = {'image': img.get_fdata(), 'label': lbl.get_fdata()}

#     ddict['image'] = np.transpose(ddict['image'], axes=(3, 0, 1, 2))
#     ddict['label'] = np.expand_dims(ddict['label'], axis=0)

#     # print(ddict['image'].shape, ddict['label'].shape)

#     tp = trad_pipeline()
#     lp = lazy_pipeline(True)

#     num_samples = 8
#     results = []
#     for i in range(num_samples):
#         t_out = tp(ddict)
#         l_out = lp(ddict)
#         results.append([
#             entropy(ddict['image']),
#             entropy(t_out['image']),
#             entropy(l_out['image'])
#         ])
        
#     results = np.asarray(results)
#     print(results.shape)
        
# check_entropy_whole_pipeline()

## Whole pipeline

In [None]:
# import time

# import numpy as np
# import nibabel as nib

# def check_whole_pipeline_forward():
#     base_dir = '/home/ben/data/preprocessed/Task01_BrainTumour/orig'
#     sample_str = 'BRATS_{}_{}.nii.gz'
#     sample = '001'

#     img = nib.load(os.path.join(base_dir, sample_str.format(sample, 'image')))
#     lbl = nib.load(os.path.join(base_dir, sample_str.format(sample, 'label')))

#     ddict = {'image': img.get_fdata(), 'label': lbl.get_fdata()}

#     ddict['image'] = np.transpose(ddict['image'], axes=(3, 0, 1, 2))
#     ddict['label'] = np.expand_dims(ddict['label'], axis=0)

#     # print(ddict['image'].shape, ddict['label'].shape)

#     tp = trad_pipeline()
#     lp = lazy_pipeline(True)

#     pre_first_z, pre_last_z, pre_mid_slice = find_mid_label_z(ddict['label'])
#     p_label_exts = find_mid_label(ddict['label'])
#     # print("mid-label:", p_label_exts, tuple(p[1] - p[0] for p in p_label_exts))
#     p_slice_x = slice(*sanitized_range_from_extents(p_label_exts[0][2], ddict['label'].shape[1], int(100 * 240 / 192)))
#     p_slice_y = slice(*sanitized_range_from_extents(p_label_exts[1][2], ddict['label'].shape[2], int(100 * 240 / 192)))
#     # print("extents:", p_slice_x, p_slice_y)

#     print(pre_first_z, pre_last_z, pre_mid_slice)
#     results = []
#     t_time = 0
#     l_time = 0
#     for i in range(8):
#         t_start = time.time()
#         t_out = tp(ddict)
#         t_time += time.time() - t_start
#         # t_first_z, t_last_z, t_mid_slice = find_mid_label_z(t_out['label'])
#         t_label_exts = find_mid_label(t_out['label'])
#         # print("mid-label:", t_label_exts, tuple(t[1] - t[0] for t in t_label_exts))
#         t_slice_x = slice(*sanitized_range_from_extents(t_label_exts[0][2], t_out['label'].shape[1], 100))
#         t_slice_y = slice(*sanitized_range_from_extents(t_label_exts[1][2], t_out['label'].shape[2], 100))
#         # print("extents:", t_slice_x, t_slice_y)
#         # print(t_first_z, t_last_z, t_mid_slice)

#         l_start = time.time()
#         l_out = lp(ddict)
#         l_time += time.time() - l_start
#         # l_first_z, l_last_z, l_mid_slice = find_mid_label_z(l_out['label'])
#         l_label_exts = find_mid_label(l_out['label'])
#         # print("mid-label:", l_label_exts, tuple(l[1] - l[0] for l in l_label_exts))
#         l_slice_x = slice(*sanitized_range_from_extents(l_label_exts[0][2], l_out['label'].shape[1], 100))
#         l_slice_y = slice(*sanitized_range_from_extents(l_label_exts[1][2], l_out['label'].shape[2], 100))
#         # print("extents:", l_slice_x, l_slice_y)
#         # print(l_first_z, l_last_z, l_mid_slice)
    
#         results.extend([ddict['label'][0, p_slice_x, p_slice_y, pre_mid_slice],
#                         t_out['label'][0, t_slice_x, t_slice_y, t_label_exts[2][2]],
#                         l_out['label'][0, l_slice_x, l_slice_y, l_label_exts[2][2]]])

#     print(f"trad time: {t_time}, lazy time: {l_time}")
    
#     display_images['forward_pass_whole'] = copy.deepcopy(results[18:21])
#     # plot_datas(results, 6)
#     plot_datas(results, 3, tight=True)
    
# check_whole_pipeline_forward()

## Patch first

In [None]:
import time

import numpy as np
import nibabel as nib

def check_patch_first_pipeline_forward():
    """Plot label crops from the patch-first trad and lazy pipelines on BRATS_001.

    Each iteration adds three 2D label views (channel 0) to the plot grid: the
    input cropped around its label extents at the mid slice from
    ``find_mid_label_z``, and the traditional and lazy outputs cropped around
    their own label extents. Cumulative pipeline times are printed and the
    third iteration's views are stored in ``display_images['forward_pass_patch']``.
    """
    base_dir = '/home/ben/data/preprocessed/Task01_BrainTumour/orig'
    sample_str = 'BRATS_{}_{}.nii.gz'
    sample = '001'
    iterations = 4

    def _path(kind):
        """Full path of the ``kind`` ('image' or 'label') file for ``sample``."""
        return os.path.join(base_dir, sample_str.format(sample, kind))

    img = nib.load(_path('image'))
    lbl = nib.load(_path('label'))
    ddict = {
        'image': np.transpose(img.get_fdata(), axes=(3, 0, 1, 2)),
        'label': np.expand_dims(lbl.get_fdata(), axis=0),
    }

    tp = trad_pipeline_patch_first()
    lp = lazy_pipeline_patch_first(True)

    def _extent_slices(label, margin):
        """Return ``find_mid_label(label)`` and the x/y crop slices around it."""
        exts = find_mid_label(label)
        sx = slice(*sanitized_range_from_extents(exts[0][2], label.shape[1], margin))
        sy = slice(*sanitized_range_from_extents(exts[1][2], label.shape[2], margin))
        return exts, sx, sy

    def _timed(pipeline):
        """Run ``pipeline`` on ``ddict``; return (output, elapsed seconds)."""
        start = time.time()
        out = pipeline(ddict)
        return out, time.time() - start

    pre_first_z, pre_last_z, pre_mid_slice = find_mid_label_z(ddict['label'])
    # The input's crop margin is the outputs' margin of 100 scaled by 240/192.
    _, p_slice_x, p_slice_y = _extent_slices(ddict['label'], int(100 * 240 / 192))

    results = []
    t_time = 0
    l_time = 0
    for _ in range(iterations):
        t_out, elapsed = _timed(tp)
        t_time += elapsed
        t_label_exts, t_slice_x, t_slice_y = _extent_slices(t_out['label'], 100)

        l_out, elapsed = _timed(lp)
        l_time += elapsed
        l_label_exts, l_slice_x, l_slice_y = _extent_slices(l_out['label'], 100)

        results += [ddict['label'][0, p_slice_x, p_slice_y, pre_mid_slice],
                    t_out['label'][0, t_slice_x, t_slice_y, t_label_exts[2][2]],
                    l_out['label'][0, l_slice_x, l_slice_y, l_label_exts[2][2]]]
    print(f"trad time: {t_time}, lazy time: {l_time}")

    display_images['forward_pass_patch'] = copy.deepcopy(results[6:9])

    plot_datas(results, 3, tight=True)
    
check_patch_first_pipeline_forward()

## Patch last

In [None]:
import time

import numpy as np
import nibabel as nib

def check_patch_last_pipeline_forward(base_dir='/home/ben/data/preprocessed/Task01_BrainTumour/orig',
                                      sample='001', iterations=4):
    """Compare the traditional and lazy *patch-last* pipelines on one BraTS sample.

    Loads ``BRATS_<sample>_image.nii.gz`` and ``BRATS_<sample>_label.nii.gz`` from ``base_dir``,
    runs ``trad_pipeline_patch_last()`` and ``lazy_pipeline_patch_last(True)`` ``iterations``
    times on the same input, accumulates the wall-clock time spent in each, and plots one label
    slice per result in rows of (original, traditional, lazy).

    This cell previously re-defined ``check_patch_first_pipeline_forward``, silently replacing
    the patch-first check that the preceding cell runs; it now has its own name.

    Args:
        base_dir: directory containing the BraTS NIfTI files.
        sample: zero-padded sample id substituted into the file names.
        iterations: number of forward passes per pipeline.
    """
    sample_str = 'BRATS_{}_{}.nii.gz'
    img = nib.load(os.path.join(base_dir, sample_str.format(sample, 'image')))
    lbl = nib.load(os.path.join(base_dir, sample_str.format(sample, 'label')))

    # image: move the last axis to the front (channel-first); label: add a leading channel axis
    ddict = {'image': np.transpose(img.get_fdata(), axes=(3, 0, 1, 2)),
             'label': np.expand_dims(lbl.get_fdata(), axis=0)}

    tp = trad_pipeline_patch_last()
    lp = lazy_pipeline_patch_last(True)

    pre_first_z, pre_last_z, pre_mid_slice = find_mid_label_z(ddict['label'])
    print(pre_first_z, pre_last_z, pre_mid_slice)

    results = []
    t_time = 0.0
    l_time = 0.0
    for _ in range(iterations):
        # perf_counter is the monotonic high-resolution clock meant for measuring intervals
        t_start = time.perf_counter()
        t_out = tp(ddict)
        t_time += time.perf_counter() - t_start
        t_first_z, t_last_z, t_mid_slice = find_mid_label_z(t_out['label'])
        print("t:", t_first_z, t_last_z, t_mid_slice)

        l_start = time.perf_counter()
        l_out = lp(ddict)
        l_time += time.perf_counter() - l_start
        l_first_z, l_last_z, l_mid_slice = find_mid_label_z(l_out['label'])
        print("l:", l_first_z, l_last_z, l_mid_slice)

        results.extend([ddict['label'][0, ..., pre_mid_slice],
                        t_out['label'][0, ..., t_mid_slice],
                        l_out['label'][0, ..., l_mid_slice]])
    print(f"trad time: {t_time}, lazy time: {l_time}")
    plot_datas(results, 3, tight=True)


check_patch_last_pipeline_forward()

# Roundtrip experiment with issues — probably caused by not deep-copying the tensor, although that really ought not to be an issue (I think it is a MONAI issue)

In [None]:
def roundtrip_experiment(base_dir='/home/ben/data/preprocessed/Task01_BrainTumour/orig',
                         sample='001', iterations=8):
    """Forward-transform one BraTS sample with ``trad_pipeline()`` / ``lazy_pipeline(True)`` and invert it.

    Each iteration runs both pipelines on the sample, inverts their outputs with ``Invert`` and
    appends channel 0 of the image at slice ``pre_mid_slice`` (from ``find_mid_label_z`` on the
    original label) for the original, the traditional round trip and the lazy round trip. The
    rows are then plotted.

    Every pipeline call receives its own deep copy of the input dict and every inverter its own
    deep copy of the forward output, matching the ``deepcopy`` that ``test_invert`` performs
    before inverting. Missing copies are the suspected cause of this experiment's issues, so no
    object is shared between forward and inverse passes.

    Args:
        base_dir: directory containing the BraTS NIfTI files.
        sample: zero-padded sample id substituted into the file names.
        iterations: number of forward/inverse round trips per pipeline.
    """
    sample_str = 'BRATS_{}_{}.nii.gz'
    img = nib.load(os.path.join(base_dir, sample_str.format(sample, 'image')))
    lbl = nib.load(os.path.join(base_dir, sample_str.format(sample, 'label')))

    # image: move the last axis to the front (channel-first); label: add a leading channel axis
    ddict = {'image': np.transpose(img.get_fdata(), axes=(3, 0, 1, 2)),
             'label': np.expand_dims(lbl.get_fdata(), axis=0)}

    tp = trad_pipeline()
    lp = lazy_pipeline(True)
    print("tp:", tp)
    print("lp:", lp)

    tinverter = Invert(transform=tp, nearest_interp=True, device="cpu", post_func=torch.as_tensor)
    linverter = Invert(transform=lp, nearest_interp=True, device="cpu", post_func=torch.as_tensor)
    pre_first_z, pre_last_z, pre_mid_slice = find_mid_label_z(ddict['label'])
    print(pre_first_z, pre_last_z, pre_mid_slice)

    results = []
    for _ in range(iterations):
        t_out = tp(copy.deepcopy(ddict))
        t_inv = tinverter(copy.deepcopy(t_out))
        l_out = lp(copy.deepcopy(ddict))
        l_inv = linverter(copy.deepcopy(l_out))

        # inverted images are back in the original space, so the same slice index applies
        results.extend([ddict['image'][0, ..., pre_mid_slice],
                        t_inv['image'][0, ..., pre_mid_slice],
                        l_inv['image'][0, ..., pre_mid_slice]])

    plot_datas(results, 3, tight=True)


roundtrip_experiment()



# RandGridDistortion

In [None]:
def do_rand_grid_distortion():
    """Paint bright markers onto a 64x64 test image and plot it next to two random warps.

    The markers (a small square, a two-row band, the first row and the first column) are set
    to values above the image maximum so they stand out. The plot shows the original, then
    ``RandGridDistortion`` output, then ``Rand2DElastic`` output.
    """
    image = get_img((64, 64))
    peak = image.max()
    # (index, offset above the original max), applied in order: later markers overwrite
    # earlier ones where they overlap, e.g. the column-0 line wins over row 0 at [0, 0, 0].
    markers = (
        ((0, slice(7, 9), slice(7, 9)), 64),
        ((0, slice(15, 17), slice(None)), 128),
        ((0, 0, slice(None)), 192),
        ((0, slice(None), 0), 256),
    )
    for index, offset in markers:
        image[index] = peak + offset
    print(image.shape)

    grid_warp = RandGridDistortion(9, 1.0, (-0.1, 0.1), padding_mode="zeros")
    grid_warped = grid_warp(image)

    elastic_warp = Rand2DElastic((4, 4), (-0.1, 0.1), 1.0, padding_mode="zeros")
    elastic_warped = elastic_warp(image)

    plot_datas([image, grid_warped, elastic_warped])

do_rand_grid_distortion()

# Inversion test - whole images

In [None]:
def _roundtrip_dice_scores(d_orig, d_roundtrip, dice_loss, num_classes, verbose=False):
    """Per-class Dice score (``1 - DiceLoss``) between an original label and its round trip.

    Both labels must be 4-D and channel-first with one channel, e.g. ``(1, H, W, D)``; the
    ``permute(0, 4, 1, 2, 3)`` below requires that. They are one-hot encoded with a fixed
    ``num_classes``, so the shapes agree even if a class vanishes in the round trip. Returns a
    numpy array, presumably of shape ``(1, num_classes)`` (DiceLoss with ``reduction="none"``
    averaged over the spatial axes).
    """
    d_orig_1h = torch.nn.functional.one_hot(d_orig.long(), num_classes=num_classes).permute(0, 4, 1, 2, 3)
    d_roundtrip_1h = torch.nn.functional.one_hot(d_roundtrip.long(), num_classes=num_classes).permute(0, 4, 1, 2, 3)
    if verbose:
        print("orig hist:", d_orig_1h.unique(return_counts=True))
        print("roundtrip_hist:", d_roundtrip_1h.unique(return_counts=True))
        print("loss arg shapes:", d_orig_1h.shape, d_roundtrip_1h.shape)
    loss = dice_loss(d_orig_1h, d_roundtrip_1h)
    if verbose:
        print("loss shape:", loss.shape)
    loss = loss.mean(dim=(2, 3, 4)).detach().cpu().numpy()
    if verbose:
        print("loss:", loss)
    return 1 - loss


def _clip_around_mid_label(volume, size):
    """Crop a ``size`` x ``size`` in-plane window around the label of a channel-first volume.

    The slice on the last axis is the one ``find_mid_label`` reports (presumably the middle of
    the labelled range).
    """
    exts = find_mid_label(volume)
    slice_x = slice(*sanitized_range_from_extents(exts[0][2], volume.shape[1], size))
    slice_y = slice(*sanitized_range_from_extents(exts[1][2], volume.shape[2], size))
    return volume[0, slice_x, slice_y, exts[2][2]]


def test_invert(labels, verbose=False, rows_to_display=8, num_classes=4, crop_size=120):
    """Round-trip label files through the traditional and lazy pipelines and score the inversion.

    Each label is loaded (``LoadImage`` + ``EnsureChannelFirst``), transformed by
    ``trad_pipeline_label_only()`` and by ``lazy_pipeline_label_only()``, inverted with
    ``Invert`` and compared with the untransformed label by per-class Dice score.
    The first ``rows_to_display`` samples are plotted as rows of (original, trad, lazy), and
    the last displayed row is stored in ``display_images['round_trip_whole']``.

    Both passes start from ``set_determinism(seed=0)``. Previously the seed was cleared after
    the trad pass, so the lazy pass ran unseeded.

    Args:
        labels: non-empty sequence of label file paths accepted by ``LoadImage``.
        verbose: print per-sample diagnostics instead of a progress bar.
        rows_to_display: number of samples shown in the plot.
        num_classes: number of label classes used for the one-hot encoding.
        crop_size: size of the in-plane display window around the label.

    Returns:
        ``(trad_losses, lazy_losses)``: one per-class array per sample. Despite the names these
        hold ``1 - DiceLoss``, i.e. Dice scores where higher is better.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    import sys
    from copy import deepcopy
    from monai.utils import set_determinism
    from monai.data import DataLoader, Dataset, decollate_batch
    from monai.transforms import Compose, EnsureChannelFirst, Invert, LoadImage

    if len(labels) == 0:
        raise ValueError("test_invert needs at least one label file")
    if verbose:
        print(labels)

    set_determinism(seed=0)
    base_images = Compose([LoadImage(image_only=True), EnsureChannelFirst()])
    transform_old = trad_pipeline_label_only()
    transform_new = lazy_pipeline_label_only()

    # rows of (original, trad round trip, lazy round trip)
    results = [None] * (rows_to_display * 3)
    dl = DiceLoss(reduction="none")
    trad_losses = list()
    lazy_losses = list()
    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2

    try:
        for i_tx, tx in enumerate([transform_old, transform_new]):
            print("trad pass" if i_tx == 0 else "lazy pass")
            # same starting random state for both passes; cleared only after both have run
            set_determinism(seed=0)
            losses = trad_losses if i_tx == 0 else lazy_losses

            base_dataset = Dataset(labels, transform=base_images)
            base_loader = DataLoader(base_dataset, suppress_rng=True, num_workers=num_workers, batch_size=1)
            dataset = Dataset(labels, transform=tx)
            loader = DataLoader(dataset, suppress_rng=True, num_workers=num_workers, batch_size=1)
            inverter = Invert(transform=tx, nearest_interp=True, device="cpu", post_func=torch.as_tensor)

            dots = Dots(50)
            for i_d, (d_orig, d_full) in enumerate(zip(base_loader, loader)):
                if not verbose:
                    dots.dot()
                assert len(d_orig) == 1 and len(d_full) == 1
                # batch_size=1: decollate and take the single sample
                d_orig = decollate_batch(d_orig)[0]
                d_full = decollate_batch(d_full)[0]
                if verbose:
                    print(d_orig.shape)
                    print(f"sample {i_d}:", np.unique(d_full, return_counts=True))
                if i_tx == 0 and i_d < rows_to_display:
                    results[i_d * 3] = d_orig

                # invert a copy in case Invert modifies its input
                d_roundtrip = inverter(deepcopy(d_full))
                if verbose:
                    print(d_orig.shape, d_roundtrip.shape)
                    print("d_orig hist:", d_orig.unique(return_counts=True))
                    print("d_roundtrip hist:", d_roundtrip.unique(return_counts=True))
                if i_d < rows_to_display:
                    results[i_d * 3 + i_tx + 1] = d_roundtrip

                losses.append(_roundtrip_dice_scores(d_orig, d_roundtrip, dl, num_classes, verbose))
            if not verbose:
                dots.done()

            # voxel-wise check of the last round trip against the file it came from
            reverted = d_roundtrip.detach().cpu().numpy().astype(np.int32)
            original = LoadImage(image_only=True)(labels[-1])
            n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))
            print("invert diff", reverted.size - n_good)
    finally:
        set_determinism(seed=None)

    # rows beyond len(labels) were never filled
    clipped_results = [_clip_around_mid_label(r, crop_size) for r in results if r is not None]
    # keep the last displayed (original, trad, lazy) row for later figures
    display_images['round_trip_whole'] = copy.deepcopy(clipped_results[-3:])
    plot_datas(clipped_results, 3, tight=True)

    return trad_losses, lazy_losses

print(entries_[0])
# NOTE(review): e[1][1] is presumably the label path of each entry - confirm against generate_entry_list
trad_rt_losses_, lazy_rt_losses_ = test_invert([e[1][1] for e in entries_])

In [None]:
# Summaries of the round-trip results returned by test_invert. test_invert appends
# ``1 - DiceLoss`` values, so despite the ``*_losses`` names these are Dice scores (higher is better).
trad_rt_losses_np = np.concatenate(trad_rt_losses_)
lazy_rt_losses_np = np.concatenate(lazy_rt_losses_)
# rows are samples, columns the per-class scores; average over classes -> one score per sample
trad_rt_means = trad_rt_losses_np.mean(axis=1)
lazy_rt_means = lazy_rt_losses_np.mean(axis=1)
print(f"trad - min: {trad_rt_means.min()}, max: {trad_rt_means.max()}, mean: {trad_rt_means.mean()}, std: {trad_rt_means.std()}")
print(f"lazy - min: {lazy_rt_means.min()}, max: {lazy_rt_means.max()}, mean: {lazy_rt_means.mean()}, std: {lazy_rt_means.std()}")

NEAR_ZERO_SCORE = 0.01  # per-class scores below this are counted as "near zero"

# number of classes taken from the data instead of a hard-coded 4
for i in range(trad_rt_losses_np.shape[-1]):
    trad_rt_class = trad_rt_losses_np[..., i]
    lazy_rt_class = lazy_rt_losses_np[..., i]
    print(f"trad {i} - count near zero: {np.count_nonzero(trad_rt_class < NEAR_ZERO_SCORE)}")
    print(f"lazy {i} - count near zero: {np.count_nonzero(lazy_rt_class < NEAR_ZERO_SCORE)}")
    print(f"trad {i} - min: {trad_rt_class.min()}, max: {trad_rt_class.max()}, mean: {trad_rt_class.mean()}, std: {trad_rt_class.std()}")
    print(f"lazy {i} - min: {lazy_rt_class.min()}, max: {lazy_rt_class.max()}, mean: {lazy_rt_class.mean()}, std: {lazy_rt_class.std()}")
    

In [None]:
# NOTE(review): the original comment here read "these are wrong, see" and is truncated; the values
# below are known to be incorrect - find the missing reference before relying on them.
# Rows: 8 samples; columns: 4 per-class scores.
trad_class_losses = np.asarray([
    [0.9976, 0.8967, 0.9199, 0.9024],
    [0.9985, 0.9153, 0.8667, 0.9218],
    [0.9975, 0.8710, 0.8469, 0.9291],
    [0.9988, 0.8790, 0.9178, 0.9098],
    [0.9994, 0.9362, 0.9530, 0.9339],
    [0.9968, 0.8567, 0.8951, 0.8981],
    [0.9981, 0.8157, 0.8726, 0.8884],
    [0.9977, 0.8998, 0.8801, 0.9043]
])

lazy_class_losses = np.asarray([
    [0.9996, 0.9714, 0.9692, 0.9769],
    [0.9997, 0.9795, 0.9274, 0.9832],
    [0.9992, 0.9609, 0.9249, 0.9806],
    [0.9996, 0.9470, 0.9736, 0.9507],
    [0.9997, 0.9692, 0.9736, 0.9690],
    [0.9995, 0.9695, 0.9692, 0.9774],
    [0.9997, 0.9547, 0.9657, 0.9825],
    [0.9992, 0.9564, 0.9441, 0.9612]
])

# per-class mean / std over the samples
trad_class_mean, trad_class_std = trad_class_losses.mean(axis=0), trad_class_losses.std(axis=0)
lazy_class_mean, lazy_class_std = lazy_class_losses.mean(axis=0), lazy_class_losses.std(axis=0)

print(trad_class_losses.shape)
print(lazy_class_losses.shape)
print(trad_class_mean, trad_class_std)
# previously printed the *trad* std next to the lazy mean
print(lazy_class_mean, lazy_class_std)


## Earlier inversion test - what is it for?

In [None]:
def test_invert():
    """Round-trip four BraTS label volumes through fixed old-style and new-style (lazy) pipelines.

    Both pipelines load the label, make it channel-first and apply Spacing, Flip, Rotate90, Zoom
    and Rotate with nearest interpolation before casting to uint8 (no random transforms are
    active). Each output is inverted with ``Invert``. Per-class Dice scores against the loaded
    original are printed, and the slice chosen by ``find_mid_label_z`` is plotted for every
    original / old / new result.

    The Dice scores are now computed volumetrically per class (one-hot with a fixed class count,
    classes on the channel axis), replacing the earlier per-slice calculation in which the class
    axis was treated as the batch axis.
    """
    import sys
    from monai.utils import set_determinism
    from monai.data import DataLoader, Dataset, decollate_batch
    # Spacing is imported locally on purpose: it shadows the module-level Spacing in transform_new
    from monai.transforms import CastToType, Compose, EnsureChannelFirst, Invert, LoadImage, Spacing

    mode = 'nearest'
    lazy = True
    num_classes = 4  # labels are one-hot encoded with classes 0..3

    data = ['/home/ben/data/preprocessed/Task01_BrainTumour/orig/BRATS_001_label.nii.gz',
            '/home/ben/data/preprocessed/Task01_BrainTumour/orig/BRATS_002_label.nii.gz',
            '/home/ben/data/preprocessed/Task01_BrainTumour/orig/BRATS_003_label.nii.gz',
            '/home/ben/data/preprocessed/Task01_BrainTumour/orig/BRATS_004_label.nii.gz']
    num_rows = len(data)
    print(data)

    base_images = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
        ]
    )
    # NOTE(review): the old pipeline rotates with angle=(np.pi, 0, 0) but the new one with
    # angle=(0, 0, np.pi); confirm whether the two Rotate implementations order angles
    # differently or whether this is a copy/paste mismatch.
    transform_old = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            old.Spacing(pixdim=(1.2, 1.01, 0.9), mode=mode, dtype=np.float32),
            old.Flip(spatial_axis=[1, 2]),
            old.Rotate90(spatial_axes=(1, 2)),
            old.Zoom(zoom=0.75, keep_size=True),
            old.Rotate(angle=(np.pi, 0, 0), mode=mode, align_corners=True, dtype=np.float64),
            CastToType(dtype=torch.uint8),
        ]
    )
    transform_new = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            Spacing(pixdim=(1.2, 1.01, 0.9), mode=mode, dtype=np.float32, lazy_evaluation=lazy),
            Flip(spatial_axis=[1, 2], lazy_evaluation=lazy),
            Rotate90(spatial_axes=(1, 2), lazy_evaluation=lazy),
            Zoom(zoom=0.75, keep_size=True, lazy_evaluation=lazy),
            Rotate(angle=(0, 0, np.pi), mode=mode, align_corners=True, dtype=np.float64,
                   lazy_evaluation=lazy),
            CastToType(dtype=torch.uint8),
        ]
    )

    # rows of (original, old round trip, new round trip)
    results = [None] * (num_rows * 3)
    # originals: loaded once rather than once per pass
    for row, path in enumerate(data):
        original = base_images(path)
        print(original.shape)
        results[row * 3] = original

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
    try:
        for i_tx, tx in enumerate((transform_old, transform_new)):
            # same starting random state for both passes; cleared only after both have run
            set_determinism(seed=0)
            loader = DataLoader(Dataset(data, transform=tx), num_workers=num_workers, batch_size=1)
            inverter = Invert(transform=tx, nearest_interp=True, device="cpu", post_func=torch.as_tensor)

            for i_d, batch in enumerate(loader):
                for item in decollate_batch(batch):
                    print(np.unique(item, return_counts=True))
                    inverted = inverter(item)
                    print(item.shape, inverted.shape)
                    results[i_d * 3 + i_tx + 1] = inverted

            # voxel-wise check of the last inverted label against the file it came from
            reverted = inverted.detach().cpu().numpy().astype(np.int32)
            original = LoadImage(image_only=True)(data[-1])
            n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))
            print("invert diff", reverted.size - n_good)
    finally:
        set_determinism(seed=None)

    dl = DiceLoss(reduction="none")

    def to_one_hot(label):
        # (1, H, W, D) label -> (1, C, H, W, D); the fixed class count keeps shapes equal even if
        # a class vanishes in the round trip
        return torch.nn.functional.one_hot(label.long(), num_classes=num_classes).permute(0, 4, 1, 2, 3)

    for row in range(num_rows):
        r0h, r1h, r2h = (to_one_hot(results[row * 3 + k]) for k in range(3))
        # per-class Dice score: average DiceLoss over the spatial axes, then 1 - loss
        dice_old = 1 - dl(r0h, r1h).mean(dim=(2, 3, 4))
        dice_new = 1 - dl(r0h, r2h).mean(dim=(2, 3, 4))
        print(dice_old, dice_new)

    plot_datas([r[0, ..., find_mid_label_z(r)[2]] for r in results], 3, tight=True)
    
test_invert()

## Invert test copied here after being run in a test file for debugging — can be removed

In [None]:
def test_invert():
    """Debugging copy of MONAI's ``Invert`` test, run on a synthetic label volume.

    Writes one synthetic 101x100x107 label image (``create_test_image_3d`` + ``make_nifti_image``)
    and uses it for every sample. The samples are pushed through fixed old-style and new-style
    (lazy) pipelines, each result is inverted with ``Invert``, and the middle slice of the last
    axis is shown for original / old / new in one row per sample. Presumably needs the MONAI
    source tree on ``sys.path`` for ``tests.utils``.

    Both passes now start from ``set_determinism(seed=0)`` (the seed used to be cleared after
    the first pass), and the written NIfTI file is removed afterwards.
    """
    import sys
    from monai.utils import set_determinism
    from tests.utils import make_nifti_image
    from monai.data import DataLoader, Dataset, create_test_image_3d, decollate_batch
    # Spacing is imported locally on purpose: it shadows the module-level Spacing in transform_new
    from monai.transforms import CastToType, Compose, EnsureChannelFirst, Invert, LoadImage, Spacing

    mode = 'nearest'
    lazy = True
    num_samples = 12
    set_determinism(seed=0)
    # label image, discrete
    im_fname = make_nifti_image(create_test_image_3d(101, 100, 107, noise_max=100)[1])
    data = [im_fname] * num_samples

    base_images = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
        ]
    )
    # NOTE(review): old Rotate uses angle=(np.pi, 0, 0), new Rotate uses angle=(0, 0, np.pi);
    # confirm the angle conventions really differ between the two implementations.
    transform_old = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            old.Spacing(pixdim=(1.2, 1.01, 0.9), mode=mode, dtype=np.float32),
            old.Flip(spatial_axis=[1, 2]),
            old.Rotate90(spatial_axes=(1, 2)),
            old.Zoom(zoom=0.75, keep_size=True),
            old.Rotate(angle=(np.pi, 0, 0), mode=mode, align_corners=True, dtype=np.float64),
            CastToType(dtype=torch.uint8),
        ]
    )
    transform_new = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            Spacing(pixdim=(1.2, 1.01, 0.9), mode=mode, dtype=np.float32, lazy_evaluation=lazy),
            Flip(spatial_axis=[1, 2], lazy_evaluation=lazy),
            Rotate90(spatial_axes=(1, 2), lazy_evaluation=lazy),
            Zoom(zoom=0.75, keep_size=True, lazy_evaluation=lazy),
            Rotate(angle=(0, 0, np.pi), mode=mode, align_corners=True, dtype=np.float64,
                   lazy_evaluation=lazy),
            CastToType(dtype=torch.uint8),
        ]
    )

    fig, ax = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples))
    # originals in column 0: loaded once rather than once per pass
    for row, path in enumerate(data):
        original = base_images(path)
        print(original.shape)
        ax[row, 0].imshow(original[0, ..., original.shape[-1] // 2])

    # num workers = 0 for mac or gpu transforms
    num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
    try:
        for i_tx, tx in enumerate((transform_old, transform_new)):
            # same starting random state for both passes; cleared only after both have run
            set_determinism(seed=0)
            loader = DataLoader(Dataset(data, transform=tx), num_workers=num_workers, batch_size=1)
            inverter = Invert(transform=tx, nearest_interp=True, device="cpu", post_func=torch.as_tensor)

            for i_d, batch in enumerate(loader):
                for item in decollate_batch(batch):
                    inverted = inverter(item)
                    print(item.shape, inverted.shape)
                    ax[i_d, i_tx + 1].imshow(inverted[0, ..., inverted.shape[-1] // 2])

            # voxel-wise check of the last inverted label against the file it came from
            reverted = inverted.detach().cpu().numpy().astype(np.int32)
            original = LoadImage(image_only=True)(data[-1])
            n_good = np.sum(np.isclose(reverted, original.numpy(), atol=1e-3))
            print("invert diff", reverted.size - n_good)
    finally:
        set_determinism(seed=None)
        # the NIfTI file written by make_nifti_image is not removed anywhere else
        if os.path.exists(im_fname):
            os.remove(im_fname)
    
test_invert()


# Interpolation / resampling artifacts

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import monai

def generate_interpolation_artifacts(zoom=1.2, crop=(slice(56, 72), slice(48, 80)),
                                     out_path="moire_artifacts.svg"):
    """Show how a zoom / inverse-zoom round trip distorts checkerboards of different periods.

    Uses three 128x128 checkerboards with 1-, 2- and 4-pixel squares, one per row. Each row
    shows ``crop`` of the original, of the image after ``Zoom(zoom)``, and of the image after
    zooming back with ``Zoom(1 / zoom)``. A fourth column compares histograms (10 bins over
    [0, 1]) of the crop before and after the round trip. The figure is saved to ``out_path``.

    Args:
        zoom: forward zoom factor; the inverse zoom is ``1 / zoom``.
        crop: ``(rows, cols)`` slices of the 2-D image that are displayed and histogrammed.
        out_path: file the figure is saved to.
    """
    def checkerboard(square, reps):
        # (2*square)x(2*square) tile with ones on the diagonal blocks, tiled `reps` times per axis
        tile = np.zeros((1, 2 * square, 2 * square))
        tile[0, :square, :square] = 1.0
        tile[0, square:, square:] = 1.0
        return torch.tensor(tile, dtype=torch.double).tile(1, reps, reps)

    starting_images = (checkerboard(1, 64), checkerboard(2, 32), checkerboard(4, 16))

    # squeeze=False keeps `ax` two-dimensional, so ax[row][col] works for any number of rows
    fig, ax = plt.subplots(len(starting_images), 4, figsize=(32, 4 * len(starting_images)), squeeze=False)
    base_bin_vals = np.linspace(0.0, 1.0, 11)
    bin_vals = torch.tensor(base_bin_vals, dtype=torch.double)
    label_enums = [f"{lo:1.1f}-{hi:1.1f}" for lo, hi in zip(base_bin_vals[:-1], base_bin_vals[1:])]
    print(label_enums)

    scale = monai.transforms.Zoom(zoom, lazy=False)
    inv_scale = monai.transforms.Zoom(1 / zoom, lazy=False)

    for row, original in enumerate(starting_images):
        zoomed = scale(original)
        restored = inv_scale(zoomed)

        for col, img in enumerate((original, zoomed, restored)):
            ax[row][col].imshow(img[(0, *crop)], vmin=0.0, vmax=1.0)
            ax[row][col].get_xaxis().set_visible(False)
            ax[row][col].get_yaxis().set_visible(False)

        hbefore = torch.histogram(original[(0, *crop)].to(dtype=torch.double), bin_vals)
        hafter = torch.histogram(restored[(0, *crop)].to(dtype=torch.double), bin_vals)
        hist_ax = ax[row][3]
        hist_ax.bar(bin_vals[:-1] - 0.015, hbefore[0], width=0.06)
        hist_ax.bar(bin_vals[:-1] + 0.015, hafter[0], width=0.06)
        hist_ax.set_xticks(bin_vals[:-1], label_enums, minor=False)

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")

generate_interpolation_artifacts()

In [None]:
def generate_interpolation_artifacts_minimal(zoom=1.2, crop=(slice(48, 80), slice(48, 80)),
                                             out_path="moire_artifacts_minimal.svg"):
    """Minimal zoom round-trip artifact figure for a single checkerboard.

    Uses a 128x128 checkerboard of 4-pixel squares. The row shows ``crop`` of the original, the
    same crop after ``Zoom(zoom)`` followed by ``Zoom(1 / zoom)``, and a histogram (10 bins over
    [0, 1]) of the crop before vs after the round trip. The figure is saved to ``out_path``.

    Args:
        zoom: forward zoom factor; the inverse zoom is ``1 / zoom``.
        crop: ``(rows, cols)`` slices of the 2-D image that are displayed and histogrammed.
        out_path: file the figure is saved to.
    """
    # 8x8 tile with ones on the diagonal 4x4 blocks, tiled 16x16 -> (1, 128, 128)
    t4 = np.zeros((1, 8, 8))
    t4[0, 0:4, 0:4] = 1.0
    t4[0, 4:8, 4:8] = 1.0
    t4 = torch.tensor(t4, dtype=torch.double).tile(1, 16, 16)

    starting_images = (t4,)

    # squeeze=False keeps `ax` two-dimensional; one row per starting image
    fig, ax = plt.subplots(len(starting_images), 3, figsize=(32, 8 * len(starting_images)), squeeze=False)
    base_bin_vals = np.linspace(0.0, 1.0, 11)
    bin_vals = torch.tensor(base_bin_vals, dtype=torch.double)
    label_enums = [f"{lo:1.1f}-{hi:1.1f}" for lo, hi in zip(base_bin_vals[:-1], base_bin_vals[1:])]
    print(label_enums)
    count_ticks = (0, 64, 128, 192, 256, 320, 384, 448, 512)

    scale = monai.transforms.Zoom(zoom, lazy=False)
    inv_scale = monai.transforms.Zoom(1 / zoom, lazy=False)

    for row, original in enumerate(starting_images):
        restored = inv_scale(scale(original))

        for col, img in enumerate((original, restored)):
            ax[row][col].imshow(img[(0, *crop)], vmin=0.0, vmax=1.0)
            ax[row][col].get_xaxis().set_visible(False)
            ax[row][col].get_yaxis().set_visible(False)

        hbefore = torch.histogram(original[(0, *crop)].to(dtype=torch.double), bin_vals)
        hafter = torch.histogram(restored[(0, *crop)].to(dtype=torch.double), bin_vals)
        hist_ax = ax[row][2]
        hist_ax.bar(bin_vals[:-1] - 0.015, hbefore[0], width=0.06)
        hist_ax.bar(bin_vals[:-1] + 0.015, hafter[0], width=0.06)
        hist_ax.set_ylabel("pixel count", fontsize=20)
        hist_ax.set_yticks(count_ticks, count_ticks, fontsize=20)
        hist_ax.set_xlabel("pixel values", fontsize=20)
        hist_ax.set_xticks(bin_vals[:-1], label_enums, minor=False, fontsize=20, rotation=45)

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")

generate_interpolation_artifacts_minimal()

In [None]:
def generate_interpolation_artifacts_image(iterations=2, zoom=1.2, crop=(slice(48, 80), slice(48, 80)),
                                           out_path="hgram_for_grid.svg"):
    """Histogram of a checkerboard crop after repeated zoom / inverse-zoom round trips.

    The image is a 128x128 checkerboard of 8-pixel squares. Each round trip applies
    ``Zoom(zoom)`` then ``Zoom(1 / zoom)``. Histograms (10 bins over [0, 1]) of ``crop`` are
    taken before any round trip and after each one. Bars for the original, after
    ``iterations // 2`` round trips and after ``iterations`` round trips are drawn, and the
    figure is saved to ``out_path``.

    Args:
        iterations: number of zoom / inverse-zoom round trips.
        zoom: forward zoom factor; the inverse zoom is ``1 / zoom``.
        crop: ``(rows, cols)`` slices of the 2-D image that are histogrammed.
        out_path: file the figure is saved to.
    """
    def resample_label(n):
        # singular/plural legend label for n round trips (previously only the middle bar handled n == 1)
        return f"{n} resample" if n == 1 else f"{n} resamples"

    # 16x16 tile with ones on the diagonal 8x8 blocks, tiled 8x8 -> (1, 128, 128)
    t4 = np.zeros((1, 16, 16))
    t4[0, 0:8, 0:8] = 1.0
    t4[0, 8:16, 8:16] = 1.0
    t4 = torch.tensor(t4, dtype=torch.double).tile(1, 8, 8)

    scale = monai.transforms.Zoom(zoom, lazy=False)
    inv_scale = monai.transforms.Zoom(1 / zoom, lazy=False)

    base_bin_vals = np.linspace(0.0, 1.0, 11)
    bin_vals = torch.tensor(base_bin_vals, dtype=torch.double)
    label_enums = [f"{lo:1.1f}-{hi:1.1f}" for lo, hi in zip(base_bin_vals[:-1], base_bin_vals[1:])]

    def crop_histogram(img):
        return torch.histogram(img[(0, *crop)].to(dtype=torch.double), bin_vals)

    hgrams = [crop_histogram(t4)]
    for _ in range(iterations):
        t4 = inv_scale(scale(t4))
        hgrams.append(crop_histogram(t4))

    to_show = (0, iterations // 2, iterations)
    count_ticks = (0, 64, 128, 192, 256, 320, 384, 448, 512)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.bar(bin_vals[:-1] - 0.01, hgrams[to_show[0]][0], width=0.04, label="original")
    ax.bar(bin_vals[:-1], hgrams[to_show[1]][0], width=0.04, label=resample_label(to_show[1]))
    ax.bar(bin_vals[:-1] + 0.01, hgrams[to_show[2]][0], width=0.04, label=resample_label(to_show[2]))
    ax.set_ylabel("pixel count", fontsize=20)
    ax.set_yticks(count_ticks, count_ticks, fontsize=20)
    ax.set_xlabel("pixel values", fontsize=20)
    ax.set_xticks(bin_vals[:-1], label_enums, minor=False, fontsize=20, rotation=45)
    ax.legend(loc=9, fontsize=20)
    fig.savefig(out_path, bbox_inches='tight', pad_inches=0)

generate_interpolation_artifacts_image()

In [None]:
def generate_interpolation_artifacts_brain_rotation_hgram(iterations=2, zoom_factor=1.2, filename=None):
    """Plot intensity histograms of a brain patch before and after repeated zoom round trips.

    Takes the first entry of ``display_images["forward_images_patch"]`` and adds a
    leading axis. It then applies ``Zoom(zoom_factor)`` followed by
    ``Zoom(1 / zoom_factor)`` ``iterations`` times. A 64-bin histogram over
    [0, 1024] is computed before resampling and after each round trip. The
    original, mid-point (``iterations // 2``) and final histograms are drawn as
    line plots.

    Args:
        iterations: number of zoom / inverse-zoom round trips (default 2).
        zoom_factor: forward zoom factor; the inverse step uses ``1 / zoom_factor``.
        filename: optional path; when given, the figure is saved there. By default
            nothing is written, matching the previous behaviour.

    Note:
        Despite the function name, the transform exercised here is a zoom, not a
        rotation. NOTE(review): intensities outside [0, 1024] fall outside the
        bin edges; confirm the patch's intensity range fits.
    """
    with torch.no_grad():
        # Add a leading (channel) axis in front of the stored patch.
        t4 = torch.tensor(display_images["forward_images_patch"][0])[None, ...]
        print(t4.shape)

        scale = monai.transforms.Zoom(zoom_factor, lazy=False)
        inv_scale = monai.transforms.Zoom(1 / zoom_factor, lazy=False)

        # 65 edges -> 64 bins of width 16 over [0, 1024].
        bin_vals = torch.tensor(np.linspace(0.0, 1024.0, 65), dtype=torch.double)

        def _histogram(img):
            # Histogram over the whole of the first channel.
            return torch.histogram(img[0].to(dtype=torch.double), bin_vals)

        def _resample_label(n):
            # Singular/plural legend text for a number of round trips.
            return f"{n} resample" if n == 1 else f"{n} resamples"

        hgrams = [_histogram(t4)]
        for _ in range(iterations):
            t4 = inv_scale(scale(t4))
            hgrams.append(_histogram(t4))

        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        to_show = (0, iterations // 2, iterations)
        labels = ("original", _resample_label(to_show[1]), _resample_label(to_show[2]))
        # The first bin ([0, 16)) is left out of every curve; x is each remaining bin's left edge.
        for hist_index, label in zip(to_show, labels):
            ax.plot(bin_vals[1:-1], hgrams[hist_index][0][1:], label=label)
        ax.set_ylabel("pixel count", fontsize=20)
        pixel_count_values = [i * 256 for i in range(9)]
        ax.set_yticks(pixel_count_values, pixel_count_values, fontsize=20)
        ax.set_xlabel("pixel values", fontsize=20)
        tick_vals = np.linspace(0, 1024, 9)
        # Integer tick text ("128" rather than "128.0").
        tick_labels = [str(int(v)) for v in tick_vals]
        ax.set_xticks(tick_vals, tick_labels, minor=False, fontsize=20, rotation=45)
        ax.legend()
        if filename is not None:
            fig.savefig(filename, bbox_inches="tight")

generate_interpolation_artifacts_brain_rotation_hgram()

In [None]:
def generate_interpolation_artifacts_brain_rotation_images(iterations=8, show_index=2, zoom_factor=1.2):
    """Save a cropped view of a brain patch after repeated zoom round trips.

    Takes the first entry of ``display_images["forward_images_patch"]`` (with a
    leading axis added) and applies ``Zoom(zoom_factor)`` followed by
    ``Zoom(1 / zoom_factor)`` ``iterations`` times, keeping every intermediate.
    Three snapshots are candidates for display: the original, the result after
    2 round trips, and the result after ``iterations`` round trips.
    ``show_index`` (0, 1 or 2) picks one of them. That snapshot is cropped,
    drawn without axes and saved as ``f"image_deg_{show_index}.svg"``.

    Args:
        iterations: number of zoom / inverse-zoom round trips (default 8).
        show_index: which of the three snapshots to draw (default 2, the last one).
        zoom_factor: forward zoom factor; the inverse step uses ``1 / zoom_factor``.
    """
    with torch.no_grad():
        t4 = torch.tensor(display_images["forward_images_patch"][0])[None, ...]
        print(t4.shape)

        scale = monai.transforms.Zoom(zoom_factor, lazy=False)
        inv_scale = monai.transforms.Zoom(1 / zoom_factor, lazy=False)

        # Snapshot indices into `results`; the middle one is clamped so that
        # small `iterations` values cannot index past the end of `results`.
        to_show = (0, min(2, iterations), iterations)
        results = [t4]
        for _ in range(iterations):
            t4 = inv_scale(scale(t4))
            results.append(t4)

        # Crop of the first channel: rows 30:210, columns 20:200.
        xslice = slice(20, 200)
        yslice = slice(30, 210)

        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.imshow(results[to_show[show_index]][0, yslice, xslice])
        fig.savefig(f"image_deg_{show_index}.svg", bbox_inches='tight', pad_inches=0)

generate_interpolation_artifacts_brain_rotation_images()

# Figure generation for paper

In [None]:
# Paper figure: the "forward_images_patch" panels followed by the "round_trip_whole"
# panels, each paired with its title, passed to plot_datas with cols=3 and saved as SVG.
keys = ("forward_images_patch", "round_trip_whole")
titles = dict(
    forward_images_patch=("Original image", "Patch-first: traditional", "Patch-first: lazy"),
    round_trip_whole=("Original labels", "Round trip labels: traditional", "Round trip labels: lazy"),
)
patch_panels, label_panels = (display_images[name] for name in keys)
values = patch_panels + label_panels
titles = titles[keys[0]] + titles[keys[1]]
plot_datas(values, cols=3, tight=True, axis=False, titles=titles, font='timesnewroman')
plt.savefig("images_trad_vs_lazy.svg", bbox_inches='tight')

In [None]:
# Save one entry of display_images as an SVG with both axes hidden and no padding
# (set _name to "forward_images_patch" to export an image patch instead).
_name = "round_trip_whole"
_index = 2
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
for axis in (ax.xaxis, ax.yaxis):
    axis.set_visible(False)
ax.imshow(display_images[_name][_index])
fig.savefig(f"img_{_name}_{_index}.svg", bbox_inches='tight', pad_inches=0)

# Old experiments

# 3D rotate

In [None]:
# 3D rotation check: paint marker regions into a test volume, rotate it with
# angle (0, 0, pi/4) without keeping the size, and show slice 4 of the last axis.
data = get_img((32, 32, 8))
dmax = data.max()
# (index, offset) pairs: each marker is set to the original maximum plus a distinct
# offset, presumably so the markers stay identifiable after rotation.
markers = (
    ((0, slice(7, 9), slice(7, 9), slice(None)), 32),
    ((0, slice(15, 17), slice(None), slice(None)), 64),
    ((0, 0, slice(None), slice(None)), 96),
    ((0, slice(None), 0, slice(None)), 128),
)
for index, offset in markers:
    data[index] = dmax + offset
print(data.shape)
r = Rotate((0, 0, torch.pi / 4), keep_size=False, padding_mode="zeros")

data1 = r(data)

plt.imshow(data1[0, ..., 4], vmin=data.min(), vmax=data.max())

In [None]:
# Root of the Task07_Pancreas dataset. Set the PANCREAS_DATA_PATH environment
# variable to point elsewhere instead of editing this absolute local default.
data_path = os.environ.get("PANCREAS_DATA_PATH", '/home/ben/data/Task07_Pancreas')

In [None]:
# Gather every file below <data_path>/imagesTr (recursively), sort by full path, and print each one.
images_dir = os.path.join(data_path, 'imagesTr')
entries = sorted(
    os.path.join(root, fname)
    for root, _subdirs, fnames in os.walk(images_dir)
    for fname in fnames
)
for entry in entries:
    print(entry)

In [None]:
def _meta_values_equivalent(a, b):
    """Return True when two metadata values should be treated as equal.

    Tensors are converted to numpy first. If either side is an array:
    - a missing value (``None``) on the other side never matches;
    - arrays with different shapes never match;
    - numeric arrays are compared with ``np.allclose(..., equal_nan=True)``, so
      matching NaNs count as equal;
    - any other arrays are compared exactly.
    Anything else falls back to ``==``. If that comparison raises, or its result
    has no single truth value, the values are reported as different.
    """
    if isinstance(a, torch.Tensor):
        a = a.detach().cpu().numpy()
    if isinstance(b, torch.Tensor):
        b = b.detach().cpu().numpy()
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        if a is None or b is None:
            return False
        a_arr, b_arr = np.asarray(a), np.asarray(b)
        if a_arr.shape != b_arr.shape:
            return False
        if np.issubdtype(a_arr.dtype, np.number) and np.issubdtype(b_arr.dtype, np.number):
            return bool(np.allclose(a_arr, b_arr, equal_nan=True))
        return bool(np.array_equal(a_arr, b_arr))
    try:
        return bool(a == b)
    except (TypeError, ValueError):
        return False


# Load one pancreas volume, add a channel axis, apply Rotate90 and report which
# metadata entries differ before and after the rotation.
# image_only=False is explicit so that `[0]` selects the image from the returned
# (image, metadata) pair regardless of the installed MONAI version's default.
i = LoadImage('nibabelreader', image_only=False)
c = AddChannel()
r = Rotate90(k=1, spatial_axes=(0, 2))

data = os.path.join(data_path, 'imagesTr', 'pancreas_001.nii.gz')
data = i(data)[0]
data = c(data)
md1 = dict(data.meta)
data = r(data)
md2 = dict(data.meta)

keys = sorted(set(md1.keys()).union(md2.keys()))
print(keys)
for k in keys:
    # A key present on only one side compares against None.
    v1 = md1.get(k, None)
    v2 = md2.get(k, None)
    equiv = _meta_values_equivalent(v1, v2)

    print(k, equiv)
    if not equiv:
        print("v1:", v1)
        print("v2:", v2)

In [None]:
# `b or a` evaluates to `a` when `b` is falsy (here None), so this prints 5.
a, b = 5, None
fallback_result = b or a
print(fallback_result)

In [None]:
# Visualise a 2D affine pipeline on a few points in homogeneous coordinates:
# identity (darkest colours), translation by (-8, -8), then that translation
# followed by a pi/4 rotation about the origin (lightest colours).
# The first point of each set is drawn green, the rest red.
rtn = torch.pi / 4  # rotation angle in radians
rsin = math.sin(rtn)
rcos = math.cos(rtn)

# Translation by (-8, -8).
mt = np.asarray([[1.0, 0.0, -8.0],
                 [0.0, 1.0, -8.0],
                 [0.0, 0.0, 1.0]])

mid = np.eye(3)

# NOTE(review): mrpre / mrpost (shift by -16 / +16) are not used below; presumably
# meant for rotating about (16, 16) as mrpost @ mr @ mrpre — confirm or remove.
mrpre = np.asarray([[1.0, 0.0, -16.0],
                    [0.0, 1.0, -16.0],
                    [0.0, 0.0, 1.0]])

mr = np.asarray([[rcos, -rsin, 0.0],
                 [rsin, rcos, 0.0],
                 [0.0, 0.0, 1.0]])

mrpost = np.asarray([[1.0, 0.0, 16.0],
                     [0.0, 1.0, 16.0],
                     [0.0, 0.0, 1.0]])

m1 = mt
m2 = mr @ m1  # translate first, then rotate about the origin

print("m1\n", m1)
print("m2\n", m2)

vs = [np.asarray([8.0, 8.0, 1.0]), np.asarray([0.0, 0.0, 1.0]), np.asarray([0.0, 32.0, 1.0]),
      np.asarray([32.0, 0.0, 1.0]), np.asarray([32.0, 32.0, 1.0])]

cs0 = ['#007f00', '#7f0000', '#7f0000', '#7f0000', '#7f0000']
cs1 = ['#3faf3f', '#af3f3f', '#af3f3f', '#af3f3f', '#af3f3f']
cs2 = ['#7fdf7f', '#df7f7f', '#df7f7f', '#df7f7f', '#df7f7f']

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.set_xlim(-50, 50)
ax.set_ylim(-50, 50)
# Draw every point under each transform, always on `ax`; only the non-identity
# stages print their input/output pairs.
for transform, colours, report in ((mid, cs0, False), (m1, cs1, True), (m2, cs2, True)):
    for c, v in zip(colours, vs):
        mv = transform @ v
        if report:
            print(v, mv)
        ax.scatter(mv[0], mv[1], color=c)


In [None]:
import torch

# Shape check for a transposed 3D convolution with stride (2, 2, 1):
# out = (in - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1,
# so the two 64-sized spatial dims become 128 and the 16-sized dim stays 16.
tc = torch.nn.ConvTranspose3d(32,
                              32,
                              kernel_size=3,
                              stride=(2, 2, 1),
                              padding=(1, 1, 1),
                              output_padding=(1, 1, 0)
                             )

t = torch.zeros((1, 32, 64, 64, 16))
print(t.shape)

# Call the module rather than `.forward` so registered hooks run. no_grad is used
# because only the output shape is of interest, so no autograd graph is needed.
with torch.no_grad():
    t_ = tc(t)
print(t_.shape)

In [None]:
# Checking entropy: compare the entropy of one image slice before and after
# Gaussian smoothing with sigma=1.

import scipy

print(entries_[0])
# NOTE(review): assumes entries_[0][1][0] is a path to a 4D NIfTI volume. The
# slice takes index 80 on axis 2 and index 0 on axis 3; confirm against entries_.
i = nib.load(entries_[0][1][0]).get_fdata()[:, :, 80, 0]
j = scipy.ndimage.gaussian_filter(i, sigma=1)

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(i)
ax[1].imshow(j)
# NOTE(review): `entropy` is not among the visible imports; presumably it is defined in an earlier cell.
print(entropy(i))
print(entropy(j))

In [None]:
import matplotlib.font_manager as fm

# Print each installed font file whose path contains "Times", presumably to locate
# the Times New Roman font requested for the paper figures.
fns = fm.findSystemFonts()
times_font_paths = [path for path in fns if 'Times' in path]
for path in times_font_paths:
    print(path)