In [None]:
import torchio as tio
import monai
from monai.transforms import (
    AddChanneld,
    LoadImage,
    LoadImaged,
    Orientationd,
    Rand3DElasticd,
    RandAffined,
    Spacingd,
    ToTensord,
    RandAffine
)
import torchvision as tv
import os
import numpy as np
import custom_models
import torch as t
from copy import deepcopy
import loss_module
import time
import matplotlib.pyplot as plt
from SVR_Preprocessor import Preprocesser
import SimpleITK as sitk

In [None]:
# Target voxel spacing handed to the preprocessor (presumably mm — confirm).
pixdim = (2.0, 2.0, 2.0)

src_folder = "sample_data"      # raw input stacks
prep_folder = "cropped_images"  # presumably where the preprocessor writes cropped stacks
result_folder = os.path.join("results", "two_stacks_first_outlier_removal")

stack_filenames = [
    "10_3T_nody_001.nii.gz",
    "14_3T_nody_001.nii.gz",
    "21_3T_nody_001.nii.gz",
    "23_3T_nody_001.nii.gz",
]
device = "cpu"
mask_filename = "mask_10_3T_brain_smooth.nii.gz"
mode = "bilinear"   # interpolation mode passed to the preprocessor (MONAI-style name)
tio_mode = "welch"  # TorchIO interpolation mode passed to the preprocessor

svr_preprocessor = Preprocesser(
    src_folder,
    prep_folder,
    result_folder,
    stack_filenames,
    mask_filename,
    device,
    mode,
    tio_mode=tio_mode,
)

fixed_images, stacks = svr_preprocessor.preprocess_stacks_and_common_vol(pixdim)

In [None]:
# Reload the stacks through the preprocessor (overwrites `stacks` from the cell above).
stacks = svr_preprocessor.load_stacks()
# Presumably the same stack as stacks[0], read with TorchIO. The path is built from
# the config so it follows changes to prep_folder / stack_filenames.
tio_image = tio.Image(os.path.join(prep_folder, stack_filenames[0]))
monai_image = stacks[0]

In [None]:
# Use the configured device instead of a hard-coded 'cpu'
# (assumes the second argument is the device — confirm).
model = custom_models.Volume_to_Slice(5, device)

# Three test poses. `transaltions` (sic) keeps its name because later cells use it.
# Rotations are Euler angles in radians for monai.transforms.utils.create_rotate.
transaltions = [[0, 0, 0], [1, 1, 0], [2, 1, 2]]
rotations = [[t.pi / 8, 0, 0], [0, t.pi / 2, 0], [0, -t.pi / 4, t.pi / 8]]

# Homogeneous 4x4 matrices T @ R: rotate first, then translate.
affines = [
    t.matmul(
        monai.transforms.utils.create_translate(3, shift, backend="torch"),
        monai.transforms.utils.create_rotate(3, angles, backend="torch"),
    )
    for angles, shift in zip(rotations, transaltions)
]

In [None]:
# Preview the first pose's rotation in degrees (the unit TorchIO's Affine expects).
# The previous `* 2 * np.pi` produced neither degrees nor radians.
tuple(np.degrees(rotations[0]))

In [None]:
# TorchIO's Affine takes rotation angles in degrees; convert the first pose from radians.
rot_deg = np.array(rotations[0]) * 180 / np.pi
tio_rotations = tuple(rot_deg)

tio_aff = tio.Affine(
    scales=1,
    degrees=tio_rotations,
    translation=tuple(np.array(transaltions[0])),
    center="image",
    default_pad_value=0.0,
)

# NOTE(review): normalized=True makes MONAI read the affine in normalized [-1, 1]
# grid coordinates, while the TorchIO translation above is in physical units.
# Confirm that the two results are meant to be directly comparable.
monai_aff = monai.networks.layers.AffineTransform(
    mode="bilinear",
    normalized=True,
    align_corners=False,
    padding_mode="zeros",
)

In [None]:
# Shape of the batched affine fed to AffineTransform; expected (1, 4, 4).
# as_tensor avoids the copy-construct UserWarning that t.tensor(<tensor>) raises.
t.as_tensor(affines[0]).unsqueeze(0).shape

In [None]:
tio_trans = tio_aff(tio_image)

# Work on a copy. Plain assignment aliased `monai_image` (and `stacks[0]`), so
# writing "image" mutated them and re-running the cell transformed the data twice.
monai_trans = deepcopy(monai_image)
batched_image = monai_image["image"].unsqueeze(0)      # add batch dim
batched_affine = t.as_tensor(affines[0]).unsqueeze(0)  # (1, 4, 4); no copy warning
# Drop only the batch dimension. squeeze() would also drop any size-1 channel/spatial axis.
monai_trans["image"] = monai_aff(batched_image, batched_affine)[0]

In [None]:
# Make sure the output folder exists so the cell also works on a fresh checkout.
os.makedirs("tests", exist_ok=True)
tio_trans.save(os.path.join("tests", "tio_transformed.nii.gz"))

In [None]:
 nifti_saver = monai.data.NiftiSaver(output_dir="tests",
                                            resample=False, padding_mode="zeros",
                                            separate_folder=False)
nifti_saver.save(monai_trans["image"],monai_trans["image_meta_dict"])

In [None]:
# Size of the MONAI-transformed image tensor.
monai_trans["image"].size()

In [None]:
# TorchIO's Colin27 sample subject; keep only its T1 image and plot it.
colin_subject = tio.datasets.Colin27()
colin = colin_subject.t1
colin.plot()

In [None]:
# Inspect the affine stored in the first MONAI stack's metadata.
monai_meta = monai_image["image_meta_dict"]
monai_meta["affine"]

In [None]:
# Rotate Colin27 by 45 degrees about the first axis and shift it 10 units along
# the third, using Welch windowed-sinc interpolation.
look_up = tio.Affine(
    scales=1,
    degrees=(45, 0, 0),
    translation=(0, 0, 10),
    image_interpolation="welch",
)
colin_looking_up = look_up(colin)
colin_looking_up.plot()


In [None]:
# Display the Colin27 T1 image object.
colin

In [None]:
# Convert the TorchIO stack to a SimpleITK image for inspection.
sitk_img = tio_image.as_sitk()

In [None]:
affine_matr = t.eye(4)
sitk_transformed = sitk_affine_transform(tio_image,affine_matr)

In [None]:
rotation_center = (100, 100, 100)
axis = (0,0,1)
angle = np.pi/2.0
translation = (1,2,3)
scale_factor = 2.0
similarity = sitk.Similarity3DTransform(scale_factor, axis, angle, translation, rotation_center)

affine = sitk.AffineTransform(3)
affine.SetMatrix(similarity.GetMatrix())
affine.SetTranslation(similarity.GetTranslation())
affine.SetCenter(similarity.GetCenter())

# Apply the transformations to the same set of random points and compare the results.

In [None]:
def sitk_affine_transform(tio_image:tio.Image, affine_matr:t.tensor)->t.tensor:
    sitk_image = tio_image.as_sitk()

    rotation = affine_matr[:3,:3].ravel().tolist()
    translation = affine_matr[:3,3].tolist()
    affine = sitk.AffineTransform(rotation,translation)

    reference_image = sitk_image
    interpolator = sitk.sitkWelchWindowedSinc
    default_value = 0

    resampled =  sitk.Resample(sitk_image,reference_image,affine,interpolator,default_value)

    tensor = t.permute(t.tensor(sitk.GetArrayFromImage(resampled)),(2,1,0))

    tensor = tensor.unsqueeze(0)
    return tensor