# Evaluation at the end

In [None]:
import SimpleITK as sitk

In [None]:
def compute_mutual_information(fixed_image, moving_image):
    """Return the Mattes mutual information between two SimpleITK images.

    The metric is evaluated once through ``sitk.ImageRegistrationMethod``
    with an identity transform, i.e. the images are compared exactly as they
    are given; no optimisation is run.

    Args:
        fixed_image: ``sitk.Image`` that defines the sampling domain.
        moving_image: ``sitk.Image`` compared against ``fixed_image``
            (typically the already registered image).

    Returns:
        The mutual information as a float; larger means the intensities of
        the two images are more strongly related.
    """
    # The registration framework only accepts floating point pixel types.
    fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)
    moving_image = sitk.Cast(moving_image, sitk.sitkFloat32)

    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    # Use every voxel (inside the mask) so the value is deterministic.
    registration_method.SetMetricSamplingStrategy(registration_method.NONE)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetInitialTransform(
        sitk.Transform(fixed_image.GetDimension(), sitk.sitkIdentity)
    )

    # Restrict the metric to voxels brighter than 10% of the fixed image's
    # maximum intensity, which leaves out the dark background.
    intensity_range = sitk.MinimumMaximumImageFilter()
    intensity_range.Execute(fixed_image)
    threshold = 0.1 * intensity_range.GetMaximum()
    if threshold > 0:
        # A non-positive threshold would not separate any background, so the
        # mask is only installed when it is meaningful.
        fixed_mask = sitk.Cast(fixed_image > threshold, sitk.sitkUInt8)
        registration_method.SetMetricFixedMask(fixed_mask)

    # The Mattes metric is reported negated (it is meant to be minimised by
    # an optimiser), so flip the sign to return the mutual information itself.
    return -registration_method.MetricEvaluate(fixed_image, moving_image)

# Read the fixed image and the previously saved registered image from disk,
# then print the mutual information between the two.
fixed_image_path = "/path_to_fixed_image.nii"
registered_image_path = "/path_where_you_saved_registered_image.nii"

fixed_image = sitk.ReadImage(fixed_image_path)
registered_moving_image = sitk.ReadImage(registered_image_path)

mi_value = compute_mutual_information(
    fixed_image=fixed_image,
    moving_image=registered_moving_image,
)
print(f"Mutual Information: {mi_value}")


In [None]:
import nibabel as nib
import torch
from torchir.transformers import BsplineTransformer

# Register the SPECT volume with the trained network and save the result.
#
# `model` is not defined in this cell; it must already exist in the session.
model = model.cpu()
model.eval()

# Load Images using SimpleITK
fixed_ct_image = sitk.ReadImage("/90days/s4692034/RBWH_data/NIFTI_CT/PET/resampled/resampled_1_PET_normalized.nii")
moving_ct_image = sitk.ReadImage("/90days/s4692034/RBWH_data/NIFTI_CT/SPECT/resampled/resampled_1_SPECT_normalized.nii")
# NOTE(review): this path sits under NIFTI_SPECT but the file name says PET —
# confirm it points at the intended SPECT volume.
moving_spect_image = sitk.ReadImage("/90days/s4692034/RBWH_data/NIFTI_SPECT/resampled/resampled_1_PET_normalized.nii")

# Convert SimpleITK images to PyTorch tensors for processing. The two
# unsqueeze calls add leading batch and channel axes.
# NOTE(review): GetArrayFromImage yields arrays indexed (z, y, x) — confirm
# this matches the axis order the model was trained with.
fixed_ct_tensor = torch.tensor(sitk.GetArrayFromImage(fixed_ct_image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
moving_ct_tensor = torch.tensor(sitk.GetArrayFromImage(moving_ct_image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
moving_spect_tensor = torch.tensor(sitk.GetArrayFromImage(moving_spect_image), dtype=torch.float32).unsqueeze(0).unsqueeze(0)

bspline_transformer = BsplineTransformer(ndim=3, upsampling_factors=(8, 8, 8))  # Assuming 3D images and some upsampling factors

# Predict the B-spline coefficients once and warp the SPECT tensor with them.
# The transformer works on tensors, so the SimpleITK images are not passed in.
with torch.no_grad():
    bspline_coefficients = model(fixed_ct_tensor, moving_ct_tensor)
    # NOTE(review): confirm that apply_transform returns the warped image and
    # not a coordinate grid; calling the transformer directly, i.e.
    # bspline_transformer(bspline_coefficients, fixed_ct_tensor, moving_spect_tensor),
    # may be the intended entry point.
    registered_spect_tensor = bspline_transformer.apply_transform(bspline_coefficients, fixed_ct_tensor, moving_spect_tensor)

# Earlier revisions of this cell stored the network output under this name.
predicted_dvf = bspline_coefficients

# Convert the registered PyTorch tensor back to a SimpleITK image
registered_spect_array = registered_spect_tensor.squeeze(0).squeeze(0).numpy()
registered_spect_image = sitk.GetImageFromArray(registered_spect_array)
# NOTE(review): geometry is copied from the moving image; if the warped volume
# is sampled on the fixed image's grid, the fixed image's metadata would be
# the right source — confirm.
registered_spect_image.SetSpacing(moving_spect_image.GetSpacing())
registered_spect_image.SetOrigin(moving_spect_image.GetOrigin())
registered_spect_image.SetDirection(moving_spect_image.GetDirection())

# Save the registered image
sitk.WriteImage(registered_spect_image, "path_where_you_want_to_save_registered_image.nii")

In [None]:
def compute_mutual_information(fixed_image, moving_image):
    """Return the Mattes mutual information between two volumes.

    Args:
        fixed_image: PyTorch tensor (or ``sitk.Image``) defining the sampling
            domain. Up to two leading singleton axes (batch, channel) are
            dropped before conversion.
        moving_image: PyTorch tensor (or ``sitk.Image``) compared against
            ``fixed_image``.

    Returns:
        The mutual information as a float; larger means the intensities are
        more strongly related. Tensors carry no geometry, so images built
        from them use SimpleITK's default spacing and origin.
    """
    def to_sitk_image(volume):
        # SimpleITK images are accepted as-is so that code written for the
        # image-based variant of this function keeps working.
        if isinstance(volume, sitk.Image):
            return sitk.Cast(volume, sitk.sitkFloat32)
        # detach() allows tensors that require grad; float() gives the
        # floating point pixel type the registration framework needs.
        data = volume.detach().cpu().float()
        # Drop the batch/channel axes added with unsqueeze(0).unsqueeze(0);
        # the registration framework only handles 2-D and 3-D images.
        for _ in range(2):
            if data.dim() > 2 and data.shape[0] == 1:
                data = data[0]
        return sitk.GetImageFromArray(data.numpy())

    fixed_image = to_sitk_image(fixed_image)
    moving_image = to_sitk_image(moving_image)

    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    # A fixed seed keeps the 1% random sample, and therefore the reported
    # value, identical between calls on the same data.
    registration_method.SetMetricSamplingPercentage(0.01, 42)
    # Evaluate the images as given, without any spatial transform.
    registration_method.SetInitialTransform(
        sitk.Transform(fixed_image.GetDimension(), sitk.sitkIdentity)
    )

    # The metric value is negated here so that the function returns the
    # mutual information itself rather than the quantity an optimiser minimises.
    return -registration_method.MetricEvaluate(fixed_image, moving_image)