# Fixing the ground truth heart data

In [197]:
from pathlib import Path
import nibabel as nib
import numpy as np
from numpy import pi
from scipy import ndimage
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

In [198]:
# Every file for Patient 27 lives in the same directory; derive each path from it.
_patient_dir = Path("../data/segthor_train/train/Patient_27")
gt_path1 = _patient_dir / "GT.nii.gz"
gt_path2 = _patient_dir / "GT2.nii.gz"
img_path = _patient_dir / "Patient_27.nii.gz"

In [199]:
# Load both label volumes and the image volume as floating-point arrays
# (nibabel's get_fdata), in the order gt1, gt2, img.
gt1, gt2, img = (nib.load(path).get_fdata() for path in (gt_path1, gt_path2, img_path))

In [None]:
# Display the three volume shapes side by side to confirm they agree.
(gt1.shape, gt2.shape, img.shape)

## Identify the issue

In [204]:
def show_slices(indices, img, gt1, gt2, z_translate=0):
    """Plot slices of ``img`` with ``gt1`` (left) and ``gt2`` (right) overlaid.

    Each row shows one slice index. Labels 1-4 are drawn as semi-transparent
    red, green, blue and yellow (Esophagus, Heart, Trachea, Aorta as in the
    legend); label 0 is left transparent.

    Args:
        indices: Slice indices along the last axis of ``img`` and ``gt2``.
        img: Image volume, indexed as ``img[..., idx]``.
        gt1: Label volume shown on the left, read at ``idx + z_translate``.
        gt2: Label volume shown on the right, read at ``idx``.
        z_translate: Offset added to each index when reading ``gt1`` only,
            used to compensate a z-shift between the two label volumes.

    Raises:
        IndexError: If a non-negative ``idx`` plus ``z_translate`` drops below
            zero. Without this check numpy would silently wrap the negative
            index around to the end of ``gt1``.
    """
    # Reject the silent-wrap case before any figure is created.
    wrapped = [idx for idx in indices if idx >= 0 and idx + z_translate < 0]
    if wrapped:
        raise IndexError(
            f"slices {wrapped} shifted by z_translate={z_translate} fall below slice 0"
        )

    # Define colors for each class (excluding background)
    colors = ['red', 'green', 'blue', 'yellow']
    # Loop-invariant: convert each class colour to RGBA once.
    rgba_colors = [plt.cm.colors.to_rgba(color, alpha=0.5) for color in colors]

    def colorize(labels):
        """Map a 2D label slice to an RGBA overlay; background stays transparent."""
        overlay = np.zeros((*labels.shape, 4))
        for class_idx, rgba in enumerate(rgba_colors, start=1):
            overlay[labels == class_idx] = rgba
        return overlay

    # squeeze=False keeps ``axes`` 2D even when a single slice is requested.
    fig, axes = plt.subplots(
        len(indices), 2, figsize=(4, 2 * len(indices)), squeeze=False
    )

    for row, idx in zip(axes, indices):
        row[0].imshow(img[..., idx], cmap='gray')
        row[1].imshow(img[..., idx], cmap='gray')

        row[0].imshow(colorize(gt1[..., idx + z_translate]))
        row[1].imshow(colorize(gt2[..., idx]))

        for ax in row:
            ax.axis('off')

    # Create legend: white box for background, then one swatch per class.
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black'),
        *(plt.Rectangle((0, 0), 1, 1, facecolor=color, alpha=0.5) for color in colors),
    ]
    fig.legend(
        legend_elements,
        ["Background", "Esophagus", "Heart", "Trachea", "Aorta"],
        loc="lower center",
        ncol=5
    )
    plt.tight_layout()
    plt.show()

In [None]:
# Slices 50-79 with both label volumes read at the same slice index.
show_slices(list(range(50, 80)), img, gt1, gt2)

## Method 1: Without the provided matrix, using a clever trick

We use Patient 27 to align the images extracted from the NIfTI files.

In a 3D visualization tool we saw that the rotation only happens in the x and y directions (within each axial slice), and that the only transformation along the z axis is a translation. After displaying the slices from both NIfTI files side by side, we could easily read off the z-translation by hand: 15 slices. After that, the problem reduced to finding a 2D affine transformation.

Not shown here, but we used an MSE loss to determine the exact parameter values for `transforms.functional.affine` (after eyeballing approximate parameters first). Compared with the actual affine matrix, our rotation is off by 2 degrees (25 degrees instead of the actual 27). This can be attributed to the various interpolations caused by applying and then reversing the affine transformation.

In [None]:
# Same slices, but gt1 is read from slice idx - 15 to undo the z-translation.
show_slices(list(range(50, 80)), img, gt1, gt2, z_translate=-15)

In [None]:
# These parameters were found by counting the unmatched pixels + grid search.
# Note that we apply the transformation after resizing, so it is a bit more complicated.
from skimage.transform import resize
from functools import partial
from typing import Callable

def show_slices(indices, img, gt1, gt2, z_translate=-15, angle=25,
                translate=(7, 45), size=(256, 256)):
    """Compare the repaired ``gt1`` (left) with ``gt2`` (right) on resized slices.

    For every slice index the image and both label volumes are resized to
    ``size`` with the module-level ``resize_`` helper. In ``gt1`` the heart
    (label 2) is then replaced in three steps:

    1. The heart voxels are taken from slice ``idx + z_translate``.
    2. They are moved in-plane with
       ``torchvision.transforms.functional.affine(angle=angle, translate=translate)``.
    3. They are pasted over whatever labels are already there.

    The defaults are the hand-fitted values for Patient 27. The translation is
    expressed in pixels of the resized ``size`` grid.

    Args:
        indices: Slice indices along the last axis of the volumes.
        img: Image volume, indexed as ``img[..., idx]``.
        gt1: Label volume whose heart is repaired (left column).
        gt2: Label volume shown unchanged (right column).
        z_translate: Slice offset from which the heart is taken in ``gt1``.
        angle: Rotation angle in degrees passed to torchvision ``affine``.
        translate: In-plane pixel translation passed to torchvision ``affine``.
        size: Target 2D shape for resizing every slice.

    Raises:
        IndexError: If a non-negative ``idx`` plus ``z_translate`` drops below
            zero. Without this check numpy would silently wrap the negative
            index around to the end of ``gt1``.
    """
    heart = 2  # label whose placement in gt1 is being repaired

    # Reject the silent-wrap case before any figure is created.
    wrapped = [idx for idx in indices if idx >= 0 and idx + z_translate < 0]
    if wrapped:
        raise IndexError(
            f"slices {wrapped} shifted by z_translate={z_translate} fall below slice 0"
        )

    # Define colors for each class (excluding background)
    colors = ['red', 'green', 'blue', 'yellow']
    rgba_colors = [plt.cm.colors.to_rgba(color, alpha=0.5) for color in colors]

    def colorize(labels):
        """Map a 2D label slice to an RGBA overlay; background stays transparent."""
        overlay = np.zeros((*labels.shape, 4))
        for class_idx, rgba in enumerate(rgba_colors, start=1):
            overlay[labels == class_idx] = rgba
        return overlay

    # squeeze=False keeps ``axes`` 2D even when a single slice is requested.
    fig, axes = plt.subplots(
        len(indices), 2, figsize=(10, 5 * len(indices)), squeeze=False
    )

    for row, idx in zip(axes, indices):
        # Preprocess the slices (order=0 keeps label values intact)
        img_slice = resize_(img[..., idx], size)
        gt1_slice = resize_(gt1[..., idx], size, order=0).astype(np.uint8)
        gt2_slice = resize_(gt2[..., idx], size, order=0).astype(np.uint8)

        # Heart-only label map taken from the z-shifted slice.
        shifted = resize_(gt1[..., idx + z_translate], size, order=0).astype(np.uint8)
        heart_only = np.where(shifted == heart, heart, 0).astype(np.uint8)

        # Remove the misplaced heart, then move the shifted heart in-plane.
        gt1_slice[gt1_slice == heart] = 0
        moved = transforms.functional.affine(
            torch.from_numpy(heart_only[None, ...]),
            angle=angle, translate=translate, scale=1.0, shear=0,
        )[0].numpy()  # [0] drops only the channel axis, unlike squeeze()
        # Paste the moved heart over whatever label is underneath.
        gt1_slice = np.where(moved > 0, moved, gt1_slice)

        row[0].imshow(img_slice.astype(np.float32), cmap='gray')
        row[1].imshow(img_slice.astype(np.float32), cmap='gray')

        row[0].imshow(colorize(gt1_slice))
        row[1].imshow(colorize(gt2_slice))

        for ax in row:
            ax.axis('off')

        # Add slice index to each subplot in the row
        row[0].set_title(f'Modified GT - Slice {idx}', fontsize=10)
        row[1].set_title(f'Original GT - Slice {idx}', fontsize=10)

    # Create legend: white box for background, then one swatch per class.
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black'),
        *(plt.Rectangle((0, 0), 1, 1, facecolor=color, alpha=0.5) for color in colors),
    ]
    fig.legend(
        legend_elements,
        ["Background", "Esophagus", "Heart", "Trachea", "Aorta"],
        loc="lower center",
        ncol=5
    )

    # Add main title for each column
    fig.text(0.25, 0.98, 'Modified Ground Truth', ha='center', va='center', fontsize=12, fontweight='bold')
    fig.text(0.75, 0.98, 'Original Ground Truth', ha='center', va='center', fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.subplots_adjust(top=0.95, bottom=0.05)  # Adjust top and bottom margins
    plt.show()

# Resize helper: keep the original value range and skip anti-aliasing, so
# that label maps resized with order=0 keep their integer class values.
resize_: Callable = partial(resize, anti_aliasing=False, preserve_range=True, mode="constant")

# Slices 50 to 79 of the volume
idxs = [*range(50, 80)]

show_slices(idxs, img, gt1, gt2)

## Method 2: with the affine matrix

The announcement said that we still need to show that we can work with the given affine matrix and NIfTI files, so here it is:

In [179]:
# Homogeneous 4x4 translation by (+50, +40, +15) along (x, y, z).
TR = np.eye(4, dtype=int)
TR[:3, 3] = (50, 40, 15)

# Rotation by -DEG degrees in the x-y plane; the z row/column stay identity.
DEG: int = 27
phi: float = - DEG / 180 * pi
RO = np.eye(4)
RO[:2, :2] = [[np.cos(phi), -np.sin(phi)],
              [np.sin(phi), np.cos(phi)]]

# Rotation centre: C2 moves it to the origin and C1 moves it back.
X_bar: float = 275
Y_bar: float = 200
Z_bar: float = 0
C1 = np.eye(4, dtype=int)
C1[:3, 3] = (X_bar, Y_bar, Z_bar)
C2 = np.linalg.inv(C1)

# Translate first, then rotate about (X_bar, Y_bar, Z_bar); INV undoes it.
AFF = C1 @ RO @ C2 @ TR
INV = np.linalg.inv(AFF)

In [180]:
def apply_ground_truth_transform(INV, gt1, label=2, threshold=0.2, order=1):
    """Return a copy of ``gt1`` in which one class is resampled by an affine map.

    ``scipy.ndimage.affine_transform`` uses a pull (output -> input) mapping:
    each output voxel ``o`` is sampled from the input at
    ``INV[:3, :3] @ o + INV[:3, 3]``. The class voxels are therefore moved by
    the *inverse* of the matrix passed in.

    Args:
        INV: 4x4 homogeneous matrix, interpreted as the output -> input mapping.
        gt1: 3D label volume; it is not modified.
        label: Class value to resample (default 2). All other labels stay in place.
        threshold: Voxels whose interpolated mask value exceeds this become
            ``label`` (default 0.2).
        order: Spline interpolation order for the mask (default 1, linear).

    Returns:
        A copy of ``gt1`` in which the original ``label`` voxels are cleared
        and the resampled ``label`` voxels are written. They overwrite any
        other label they land on.

    Raises:
        ValueError: If ``INV`` is not a 4x4 matrix.
    """
    INV = np.asarray(INV, dtype=float)
    if INV.shape != (4, 4):
        raise ValueError(f"expected a 4x4 homogeneous matrix, got shape {INV.shape}")

    # Split the homogeneous matrix into its linear part and its translation.
    linear = INV[:3, :3]
    offset = INV[:3, 3]

    is_label = gt1 == label

    # Interpolate a float mask so the threshold controls how much the
    # resampled region grows or shrinks.
    moved_mask = ndimage.affine_transform(
        is_label.astype(float),
        linear,
        offset=offset,
        order=order
    )

    result = gt1.copy()
    # Clear the misplaced class, then paint it at its resampled position.
    result[is_label] = 0
    result[moved_mask > threshold] = label

    return result

## Visualize

In [181]:
def show_slices(indices, img, gt1, gt2, inv=None):
    """Plot, per slice: the image, ``gt1``, ``gt2`` and ``gt1`` after the affine fix.

    The transformed volume is computed once, before plotting, with
    ``apply_ground_truth_transform(inv, gt1)``. The four columns then show
    the bare image followed by the three label overlays. Labels 1-4 use
    semi-transparent red, green, blue and yellow (Esophagus, Heart, Trachea,
    Aorta as in the legend).

    Args:
        indices: Slice indices along the last axis of the volumes.
        img: Image volume, indexed as ``img[..., idx]``.
        gt1: Label volume to be repaired.
        gt2: Reference label volume.
        inv: 4x4 matrix passed to ``apply_ground_truth_transform``. Defaults
            to the module-level ``INV``.
    """
    if inv is None:
        inv = INV
    tgt1 = apply_ground_truth_transform(inv, gt1)

    # Define colors for each class (excluding background)
    colors = ['red', 'green', 'blue', 'yellow']
    rgba_colors = [plt.cm.colors.to_rgba(color, alpha=0.5) for color in colors]

    def colorize(labels):
        """Map a 2D label slice to an RGBA overlay; background stays transparent."""
        overlay = np.zeros((*labels.shape, 4))
        for class_idx, rgba in enumerate(rgba_colors, start=1):
            overlay[labels == class_idx] = rgba
        return overlay

    # squeeze=False keeps ``axes`` 2D even when a single slice is requested.
    fig, axes = plt.subplots(
        len(indices), 4, figsize=(20, 5 * len(indices)), squeeze=False
    )

    # Columns 1..3 receive these overlays; column 0 shows the bare image.
    overlays = (gt1, gt2, tgt1)
    for row, idx in zip(axes, indices):
        for ax in row:
            ax.imshow(img[..., idx], cmap='gray')

        for ax, volume in zip(row[1:], overlays):
            ax.imshow(colorize(volume[..., idx]))

        for ax in row:
            ax.axis('off')

    # Create legend: white box for background, then one swatch per class.
    legend_elements = [
        plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black'),
        *(plt.Rectangle((0, 0), 1, 1, facecolor=color, alpha=0.5) for color in colors),
    ]
    fig.legend(
        legend_elements,
        ["Background", "Esophagus", "Heart", "Trachea", "Aorta"],
        loc="lower center",
        ncol=5
    )
    plt.tight_layout()
    plt.show()

In [None]:
# Slices 77 through 86 inclusive.
indices = list(range(77, 87))
show_slices(indices, img, gt1, gt2)