In [None]:
import torch as t
import monai
from monai.transforms import (
    AddChanneld,
    LoadImaged,
    ToTensord,
)
import torchio as tio
import torchmetrics as tm

In [None]:
# Load the CycleGAN output and the SVRTK volume as torchio scalar images.
# NOTE(review): the SVRTK file is named "*_label*" but is loaded as a ScalarImage and later
# resampled with windowed-sinc interpolation; confirm it holds intensities, not a label map.
tio_cyle = tio.ScalarImage(path="CycleGAN/2_CycleGAN_i.nii.gz")
tio_svrtk = tio.ScalarImage(path="SVRTK/2_label.nii.gz")

In [None]:
# Resample the SVRTK volume onto the CycleGAN image's grid (tio_cyle is the target),
# using Welch windowed-sinc interpolation for intensities.
resampler = tio.Resample(target=tio_cyle, image_interpolation="welch")
tio_svrtk_resampled = resampler(tio_svrtk)

In [None]:
# Sanity-check that both volumes now share shape and affine after resampling.
cycle_shape, cycle_affine = tio_cyle.data.shape, tio_cyle.affine
print(f'Cycle shape:{cycle_shape} Cycle affine:{cycle_affine}')
svrtk_shape, svrtk_affine = tio_svrtk_resampled.data.shape, tio_svrtk_resampled.affine
print(f'SVRTK shape:{svrtk_shape} SVRTK affine: {svrtk_affine}')
# Element-wise exact comparison of the two affine matrices.
affines_equal = t.all(t.tensor(tio_cyle.affine == tio_svrtk_resampled.affine))
print(f'Affines match: {affines_equal}')

In [None]:
def normalize_to_unit_interval(tio_image: "tio.ScalarImage") -> "tio.ScalarImage":
    """Min-max rescale an image's intensities to [0, 1], in place.

    The normalized tensor is written back with ``tio_image.set_data``, so the
    object passed in is modified; the same object is returned for convenience.

    Args:
        tio_image: image whose ``.data`` tensor is rescaled.

    Returns:
        ``tio_image`` itself, now holding ``(data - min) / (max - min)``.
        A constant image (zero intensity range) becomes all zeros instead of NaN.
    """
    data = tio_image.data
    min_val, max_val = t.min(data), t.max(data)
    range_val = max_val - min_val
    shifted = data - min_val
    if range_val > 0:
        norm_data = t.div(shifted, range_val)
    else:
        # Constant image: dividing by the zero range would produce NaN everywhere.
        # `shifted` is all zeros here; dividing by 1 keeps the same dtype promotion
        # (integer input -> floating output) as the normal path.
        norm_data = t.div(shifted, t.ones_like(range_val))
    tio_image.set_data(norm_data)
    return tio_image

In [None]:
# normalize_to_unit_interval rescales each image in place and returns the same object,
# so these assignments simply rebind the names to the now-normalized images.
tio_cyle = normalize_to_unit_interval(tio_cyle)
tio_svrtk_resampled = normalize_to_unit_interval(tio_svrtk_resampled)

In [None]:
# Persist the normalized volumes so they can be reloaded through MONAI further down.
tio_cyle.save(path="CycleGAN.nii.gz")
tio_svrtk_resampled.save(path="SVRTK_resampled.nii.gz")

In [None]:
def load_monai(file):
    """Load an image file through MONAI's dictionary-based loader.

    Args:
        file: path to the image; it is placed under the "image" key.

    Returns:
        The dictionary produced by ``LoadImaged(keys=["image"])``, where the
        "image" entry holds the loaded image instead of the path (plus any
        metadata keys the installed MONAI version adds).
    """
    return LoadImaged(keys=["image"])({"image": file})

In [None]:
# Reload the saved volumes through MONAI.
# NOTE(review): monai_cycle / monai_svrtk are not used by the metric cells that follow
# (they read the torchio tensors) — confirm whether this cell is still needed.
monai_cycle = load_monai(file="CycleGAN.nii.gz")
monai_svrtk = load_monai(file="SVRTK_resampled.nii.gz")

In [None]:
# Both volumes were min-max normalized to [0, 1], so pin SSIM's data_range to 1.0 rather
# than letting torchmetrics infer it from each call's data; this also matches PSNR's max_val.
# NOTE(review): in recent torchmetrics releases kernel_size only sizes the *uniform* window
# (gaussian_kernel=False). With the default Gaussian window the size is derived from sigma
# (default 1.5), so kernel_size=99 likely has no effect — confirm the intended window.
tm_ssim = tm.StructuralSimilarityIndexMeasure(kernel_size=99, reduction='sum', data_range=1.0)
# PSNR with a peak value of 1.0, consistent with the [0, 1] intensity range.
monai_psnr = monai.metrics.PSNRMetric(max_val=1.0)

In [None]:
# Cast both torchio tensors to float32 for the metric calls.
cycle_tens = tio_cyle.data.float()
svrtk_tens = tio_svrtk_resampled.data.float()
cycle_tens.shape


In [None]:
# torchio stores a volume as (C, X, Y, Z). torchmetrics reads a 4-D tensor as (N, C, H, W),
# which would turn the X axis into channels and average slice-wise 2-D SSIM. Adding a batch
# axis gives (N, C, D, H, W), so a volumetric 3-D SSIM is computed instead.
ssim_pred = cycle_tens.unsqueeze(0) if cycle_tens.ndim == 4 else cycle_tens
ssim_target = svrtk_tens.unsqueeze(0) if svrtk_tens.ndim == 4 else svrtk_tens
# Prediction = CycleGAN, reference = SVRTK, matching the PSNR cell's argument order
# (SSIM is symmetric, so this ordering does not change the value).
ssim = tm_ssim(ssim_pred, ssim_target)

In [None]:
# PSNR of the CycleGAN volume (prediction) against the resampled SVRTK volume (reference).
monai_psnr(y_pred=cycle_tens, y=svrtk_tens)

In [None]:
# Display the SSIM value produced by the tm_ssim call in the SSIM cell.
ssim