In [1]:
import numpy as np
import nibabel as nib
import SimpleITK as sitk

In [4]:
def apply_transform_to_label(image_path, transform_path, output_path, label_value=2):
    """
    Apply a transform to a specific label within a NIfTI image.

    The voxels equal to ``label_value`` are extracted as a binary mask, the
    mask is resampled through the transform onto the grid of the input image,
    the original label is erased, and the resampled mask is written back with
    ``label_value``. All other labels are left where they are, except where
    the moved label lands on them: those voxels are overwritten.

    Args:
        image_path (str): Path to the input NIfTI image (.nii.gz file).
        transform_path (str): Path to the transform file (.tfm).
        output_path (str): Path to save the transformed NIfTI image.
        label_value (int) - Default = 2 (heart): The label value to which the transform will be applied.

    Returns:
        None. The result is written to ``output_path`` and a confirmation
        line is printed.
    """

    image = sitk.ReadImage(image_path)
    transform = sitk.ReadTransform(transform_path)

    image_array = sitk.GetArrayFromImage(image)

    # Binary mask of the label to move, carrying the same spatial metadata
    # as the input so the resampling works on the same grid.
    label_mask = (image_array == label_value).astype(np.uint8)
    label_image = sitk.GetImageFromArray(label_mask)
    label_image.CopyInformation(image)

    resampled_label = sitk.Resample(
        label_image,
        image,  # Reference image defines the output space
        transform,
        sitk.sitkNearestNeighbor,  # Keeps the mask binary (no blended values)
        0,  # Default pixel value for areas outside the original image
        label_image.GetPixelID()
    )

    moved_mask = sitk.GetArrayFromImage(resampled_label) > 0

    # Remove original label from image
    image_array_without_label = np.where(image_array == label_value, 0, image_array)

    # Add the transformed label back into the image array
    final_image_array = np.where(moved_mask, label_value, image_array_without_label)

    # np.where derives its result dtype by type promotion, so a NumPy-typed
    # label_value (e.g. np.int64) could silently widen the label map. Cast
    # back so the file is saved with the same pixel type it was read with.
    final_image_array = final_image_array.astype(image_array.dtype, copy=False)

    final_image = sitk.GetImageFromArray(final_image_array)
    final_image.CopyInformation(image)

    sitk.WriteImage(final_image, output_path)

    print(f"Transformed image saved to {output_path}")

# Root of the training data, relative to the directory the notebook is run
# from. The transform file and the per-patient folders are all derived from
# this single value so they cannot point at different locations.
path = 'data/segthor_train/train/'

# For WSL, use the absolute location instead:
# path = '/home/{user}/ai4mi_project/data/segthor_train/train/'
transform_path = path + 'Transform_fix.tfm'

# Patients are numbered 01..40; each folder holds GT.nii.gz and receives GT_fixed.nii.gz.
for i in range(1, 41):
    patient_dir = f"{path}Patient_{i:02d}/"
    image_path = patient_dir + 'GT.nii.gz'
    output_path = patient_dir + 'GT_fixed.nii.gz'

    apply_transform_to_label(image_path, transform_path, output_path)

Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_01/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_02/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_03/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_04/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_05/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_06/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_07/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_08/GT_fixed.nii.gz
Transformed image saved to /home/lucasp/ai4mi/ai4mi_project/data/segthor_train/train/Patient_09/GT_fixed

In [5]:
# Check how well the transformation worked on Patient_27 (IoU of GT_fixed vs. GT2):
def calculate_iou(file1_path, file2_path, label_value=2):
    """
    Calculate Intersection over Union (IoU) between two NIfTI files for a specific label.

    Args:
        file1_path (str): Path to the first NIfTI file.
        file2_path (str): Path to the second NIfTI file.
        label_value (int) - Default = 2 (heart): The label value of the region for which IoU should be calculated.

    Returns:
        float: The IoU score between the two volumes. 0.0 is returned when the
        label is absent from both volumes (empty union).

    Raises:
        ValueError: If the two volumes do not have the same shape.
    """
    data1 = nib.load(file1_path).get_fdata()
    data2 = nib.load(file2_path).get_fdata()

    # Without this check, shapes that merely broadcast against each other
    # (e.g. a singleton axis) would be compared silently and yield a
    # meaningless score.
    if data1.shape != data2.shape:
        raise ValueError(
            f"Volume shapes differ: {data1.shape} ({file1_path}) vs "
            f"{data2.shape} ({file2_path})"
        )

    # Boolean masks of the voxels carrying the requested label
    mask1 = data1 == label_value
    mask2 = data2 == label_value

    union = np.count_nonzero(mask1 | mask2)
    if union == 0:
        return 0.0

    intersection = np.count_nonzero(mask1 & mask2)
    # Both counts are plain Python ints, so this is a plain Python float.
    return intersection / union

# Heart-label overlap between the fixed ground truth and GT2 for Patient_27
fixed_gt_27 = "data/segthor_train/train/Patient_27/GT_fixed.nii.gz"
second_gt_27 = "data/segthor_train/train/Patient_27/GT2.nii.gz"
iou_27 = calculate_iou(fixed_gt_27, second_gt_27)
print("IoU of heart in Patient_27: ", iou_27)

IoU of heart in Patient_27:  0.9902309695521799
