In [1]:
import SimpleITK as sitk

import numpy as np

import nibabel as nib

import math

In [2]:
mha_path = "ABO_SHx1.mha"
transform_template = "Seq_Frame{:04d}_ImageToReferenceTransform"

In [3]:
def create_linear_transform(x_translation, y_translation, z_translation, x_rotation, y_rotation, z_rotation):
    """Creates a linear transformation from the given parameters.

    Args:
        x_translation (float): The x translation.
        y_translation (float): The y translation.
        z_translation (float): The z translation.
        x_rotation (float): The x rotation, in degrees.
        y_rotation (float): The y rotation, in degrees.
        z_rotation (float): The z rotation, in degrees.

    Returns:
        np.ndarray: 4x4 standard homogenous transformation matrix.
    """

    # Convert degrees to radians.
    x_rotation = math.radians(x_rotation)
    y_rotation = math.radians(y_rotation)
    z_rotation = math.radians(z_rotation)

    # Create the rotation matrix.
    rotation_matrix_x = np.array([
        [1, 0, 0, 0],
        [0, math.cos(x_rotation), -math.sin(x_rotation), 0],
        [0, math.sin(x_rotation), math.cos(x_rotation), 0],
        [0, 0, 0, 1]
    ])
    rotation_matrix_y = np.array([
        [math.cos(y_rotation), 0, math.sin(y_rotation), 0],
        [0, 1, 0, 0],
        [-math.sin(y_rotation), 0, math.cos(y_rotation), 0],
        [0, 0, 0, 1]
    ])
    rotation_matrix_z = np.array([
        [math.cos(z_rotation), -math.sin(z_rotation), 0, 0],
        [math.sin(z_rotation), math.cos(z_rotation), 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    rotation_matrix = rotation_matrix_x @ rotation_matrix_y @ rotation_matrix_z

    # Create the translation matrix.
    translation_matrix = np.array([
        [1, 0, 0, x_translation],
        [0, 1, 0, y_translation],
        [0, 0, 1, z_translation],
        [0, 0, 0, 1]
    ])

    # Combine the rotation and translation matrices.
    return translation_matrix @ rotation_matrix


In [4]:
# Load the MHA file
image = sitk.ReadImage(mha_path)

In [5]:
# Get image array
image_array = sitk.GetArrayFromImage(image)

print(image_array.shape)

(1636, 544, 227)


In [6]:
image_spacing = image.GetSpacing()
print(image_spacing)
print(image.GetDepth())

(0.17250000000000001, 0.17278000000000002, 1.0)
1636


In [9]:
image.GetMetaDataKeys()

('ITK_InputFilterName',
 'ITK_original_direction',
 'ITK_original_spacing',
 'Kinds',
 'Seq_Frame0000_ImageStatus',
 'Seq_Frame0000_ImageToReferenceTransform',
 'Seq_Frame0000_ImageToReferenceTransformStatus',
 'Seq_Frame0000_Timestamp',
 'Seq_Frame0001_ImageStatus',
 'Seq_Frame0001_ImageToReferenceTransform',
 'Seq_Frame0001_ImageToReferenceTransformStatus',
 'Seq_Frame0001_Timestamp',
 'Seq_Frame0002_ImageStatus',
 'Seq_Frame0002_ImageToReferenceTransform',
 'Seq_Frame0002_ImageToReferenceTransformStatus',
 'Seq_Frame0002_Timestamp',
 'Seq_Frame0003_ImageStatus',
 'Seq_Frame0003_ImageToReferenceTransform',
 'Seq_Frame0003_ImageToReferenceTransformStatus',
 'Seq_Frame0003_Timestamp',
 'Seq_Frame0004_ImageStatus',
 'Seq_Frame0004_ImageToReferenceTransform',
 'Seq_Frame0004_ImageToReferenceTransformStatus',
 'Seq_Frame0004_Timestamp',
 'Seq_Frame0005_ImageStatus',
 'Seq_Frame0005_ImageToReferenceTransform',
 'Seq_Frame0005_ImageToReferenceTransformStatus',
 'Seq_Frame0005_Timestamp',
 '

In [6]:
transform_lines = [ image.GetMetaData(transform_template.format(i)) for i in range(image_array.shape[0]) ]

In [7]:
transform_matrix = np.stack( [ np.array(list(map(float, line.split()))).reshape(4, 4) for line in transform_lines ], axis=0 )

In [8]:
transform = create_linear_transform(0, 0, 0, 0, 30, 180)

In [9]:
transform_matrix = (transform @ transform_matrix.transpose(0, 2, 1)).transpose(0, 2, 1)

In [10]:
transposed_image_array = image_array.transpose(2, 1, 0)

nii_image = nib.Nifti1Image(transposed_image_array, np.eye(4))
# Save as test.nii
nib.save(nii_image, "test.nii")

with open("test.csv", "w") as f:
    for i in range(transform_matrix.shape[0]):

        m = transform_matrix[i].T
        line = ",".join(map(str, m.flatten()))
        f.write(line + "\n")

print(transform_matrix.shape)

(1636, 4, 4)
