In [None]:
import glob
import numpy as np
import os
import SimpleITK as sitk
import torch

from ct import Ct

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Study folder for patient AMC-027 of the NSCLC Radiogenomics download.
data_path = '../data/manifest-1668678461097/NSCLC Radiogenomics/AMC-027/04-28-1994-NA-VascularGATEDCHESTCTA Adult-45663/'

# Project helper: loads the series under `data_path`; the result is indexed
# positionally (cts[2], cts[3]) in the cells below.
cts = Ct.load_all_series(data_path)

In [None]:
# Series 2 is the reference (fixed) volume, series 3 the one to align (moving).
fixed_image, moving_image = (cts[series].img.numpy() for series in (2, 3))

(fixed_image.shape, moving_image.shape)

In [None]:
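# Requires SimpleElastix, which fails to compile on Windows 10, so this approach is kept for reference only.
# (SetFixedImage/SetMovingImage expect sitk.Image objects; the numpy arrays would need sitk.GetImageFromArray first.)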
# elastixImageFilter = sitk.ElastixImageFilter()
# elastixImageFilter.SetFixedImage(fixed_image)
# elastixImageFilter.SetMovingImage(moving_image)
# elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("affine"))
# elastixImageFilter.Execute()
# resultImage = elastixImageFilter.GetResultImage()
# sitk.WriteImage(resultImage, "result_image")

In [None]:
# Uses the prebuilt elastix binaries from
# https://github.com/SuperElastix/elastix/releases/tag/5.0.1, located via ELASTIX_PATH.
import pyelastix

# Keep an ELASTIX_PATH that is already set in the environment; only fall back to
# this machine-specific install location when it is missing.
os.environ.setdefault('ELASTIX_PATH', 'C:\\Users\\rabdo\\git\\thesis\\3rd-party\\elastix-5.0.1-win64')

# Affine registration driven by normalized mutual information.
params = pyelastix.get_default_params(type='AFFINE')
params.Metric = 'NormalizedMutualInformation'
params.MaximumNumberOfIterations = 50

# pyelastix.register(im1, im2, params) deforms im1 (the MOVING image) onto im2
# (the FIXED/reference image) and returns (im1_deformed, field), so the moving
# volume must be passed first.
moving_deformed, field = pyelastix.register(moving_image, fixed_image, params)

In [None]:
from matplotlib import pyplot as plt
# Index along the first array axis of the slice shown side by side.
idx = 250

# One row of three panels: size the figure for a 1x3 layout. The previous
# (30, 90) figure was three times taller than the row of images needed.
fig, axes = plt.subplots(1, 3, figsize=(30, 10))

panels = [
    ('fixed', fixed_image),
    ('moving', moving_image),
    ('transformed', moving_deformed),
]

# Share one grey-level window across all panels so that intensity differences
# are visible instead of being hidden by per-panel autoscaling.
vmin = min(volume[idx].min() for _, volume in panels)
vmax = max(volume[idx].max() for _, volume in panels)

for ax, (title, volume) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(volume[idx], cmap='gray', vmin=vmin, vmax=vmax)

plt.show()

In [None]:
from sklearn.metrics.cluster import normalized_mutual_info_score

# Baseline: NMI between the unregistered fixed and moving volumes, flattened to 1-D.
# NOTE(review): sklearn's NMI is a clustering score that treats every distinct voxel
# value as its own label; on continuous/resampled intensities consider binning first.
normalized_mutual_info_score(labels_true=fixed_image.reshape(-1), labels_pred=moving_image.reshape(-1))

In [None]:
# NMI between the fixed volume and the registration output, flattened to 1-D.
# NOTE(review): the registered volume is resampled, so it likely has many more distinct
# values than the input; label-based NMI may not be comparable to the baseline — verify.
normalized_mutual_info_score(labels_true=fixed_image.reshape(-1), labels_pred=moving_deformed.reshape(-1))

In [None]:
# Sanity check: a labelling compared with itself scores exactly 1.0.
normalized_mutual_info_score(labels_true=fixed_image.reshape(-1), labels_pred=fixed_image.reshape(-1))

In [None]:
# torch_similarity from https://github.com/yuta-hi/pytorch_similarity
from torch_similarity.modules import NormalizedCrossCorrelation

# torch_similarity's NCC treats dim 0 as the batch axis: it flattens everything
# after dim 0 and averages the per-item scores. An unbatched 3-D volume would
# therefore be scored as a mean of per-slice NCCs. Adding a leading batch dim of
# size 1 yields a single whole-volume NCC, matching the whole-volume NMI above.
# The .float() casts are no-ops for float32 data and let mean/std run on integer
# volumes too (the stored dtype is not visible here — assumption to confirm).
ncc = NormalizedCrossCorrelation()

fixed_volume = cts[2].img.float().unsqueeze(0)
moving_volume = cts[3].img.float().unsqueeze(0)
# .copy() gives torch a writable, owned buffer for the pyelastix output.
deformed_volume = torch.from_numpy(moving_deformed.copy()).float().unsqueeze(0)

print(ncc(fixed_volume, moving_volume))
print(ncc(fixed_volume, deformed_volume))