In [1]:
import os
from glob import glob

import torch
from natsort import natsorted
import nibabel as nib
import pandas as pd
import numpy as np
from monai.transforms import AffineGrid
from math import pi
import matplotlib.pyplot as plt

import sys
sys.path.append('/home/aaron/Dropbox/KCL/Tools/CustomScripts')
from calculate_dice import get_dice

from labelmergeandsplit.labelreg_utils import get_coms, get_covs, affine_transform_image, optimize_affine_labelreg, \
    affine_transform_influence_regions
from labelmergeandsplit.merging_utils import get_merged_label_dataframe, merge_label_volumes, map_labels_in_volume
from labelmergeandsplit.splitting_utils import get_fuzzy_prior_fudged, split_merged_labels_paths, split_merged_label, \
    get_influence_regions
from utils.plot_matrix_slices import plot_matrix_slices

from config.config import PROJ_ROOT
from skimage.measure import label

In [2]:
# set the random seed
np.random.seed(0)


def get_random_affine(img_shape):
    shape = img_shape

    # example affine matrix
    # shift to image center
    R_origin_to_center = np.array([
        [1, 0, 0, -shape[0] / 2],
        [0, 1, 0, -shape[1] / 2],
        [0, 0, 1, -shape[2] / 2],
        [0, 0, 0, 1],
    ])

    affineGrid = AffineGrid(
        rotate_params=[pi / 3, pi / 5, pi / 2],  # Sequence[float] | float | None = None,
        shear_params=[0.1] * 6,  # Sequence[float] | float | None = None,
        translate_params=[-0, -0, 0],  # Sequence[float] | float | None = None,
        scale_params=None,  # Sequence[float] | float | None = None,
        device='cpu',  # np.device | None = None,
        dtype=np.float32,  # DtypeLike = np.float32,
        align_corners=False,  # bool = False,
        affine=None,  # NdarrayOrTensor | None = None,
        lazy=False,  # bool = False,
    )
    _, R_rot = affineGrid(spatial_size=(64, 64, 64))
    R_rot = R_rot.cpu().numpy()

    # shift back to original position
    R_origin_to_corner = np.array([
        [1, 0, 0, shape[0] / 2],
        [0, 1, 0, shape[1] / 2],
        [0, 0, 1, shape[2] / 2],
        [0, 0, 0, 1],
    ])

    R = R_origin_to_corner @ R_rot @ R_origin_to_center

    return R


def get_network_output_from_registered_image(label_path_unmerged, pad):
    # load an original label image
    label_nii_unmerged = nib.load(label_path_unmerged)
    label_data_unmerged = label_nii_unmerged.get_fdata()

    # pad the label image to avoid cropping of foreground after transformation
    label_data_unmerged = np.pad(label_data_unmerged, pad, mode='constant', )

    # merge the labels
    label_data_merged = map_labels_in_volume(label_data_unmerged, label_to_merged_label_mapping)

    # get an example affine that represents the ground truth transformation (to be estimated)
    R_grtr = get_random_affine(label_data_unmerged.shape)
    # R_grtr = np.diag([1, 1, 1, 1])

    # transform the label image
    transformed_label_unmerged = torch.tensor(
        affine_transform_image(torch.tensor(label_data_unmerged), torch.tensor(np.linalg.inv(R_grtr))))

    # put on GPU
    transformed_label_unmerged = transformed_label_unmerged.to("cuda")

    # merge the labels
    transformed_label_merged = map_labels_in_volume(transformed_label_unmerged, label_to_merged_label_mapping)

    return transformed_label_merged, transformed_label_unmerged, label_data_merged, R_grtr


def plot_results(image_title_pairs):
    # plot the images
    fig, axes = plt.subplots(1, len(image_title_pairs), figsize=(3 * len(image_title_pairs), 6))

    for i, (image, title) in enumerate(image_title_pairs):
        image = image.cpu().numpy() if isinstance(image, torch.Tensor) else image
        axes[i].imshow(image[:, :, image.shape[2] // 2])
        axes[i].set_title(title)
    plt.show()

In [3]:
# label support is needed to get the influence regions
label_support_path = '/mnt/dgx-server/projects2023/dynunet_pipeline_label_merging_from_label_support/data/tasks/task3061_mindaomic/models/merged_model/fold0/label_merging/label_support.pt.npz'
# reference label image
label_path_ref = '/mnt/dgx-server/projects2023/dynunet_pipeline_label_merging_from_label_support/data/tasks/task3061_mindaomic/input/dataset/labelsTr/mindaomic_0002.nii.gz'
#label_path_ref = '/mnt/dgx-server/projects2023/dynunet_pipeline_label_merging_from_label_support/data/tasks/task3061_mindaomic/results/inference/merged_model/fold0/mindaomic_0005.nii.gz'
# ground truth label image
#label_path_unmerged = '/mnt/dgx-server/projects2023/dynunet_pipeline_label_merging_from_label_support/data/tasks/task3061_mindaomic/input/dataset/labelsTr/mindaomic_0002.nii.gz'
label_path_unmerged = '/mnt/dgx-server/projects2023/dynunet_pipeline_label_merging_from_label_support/data/tasks/task3061_mindaomic/results/inference/merged_model/fold0/mindaomic_0005.nii.gz'
# read label to merged label mapping
merged_labels_csv_path = '/mnt/dgx-server/projects2023/dynunet_pipeline_label_merging_from_label_support/data/tasks/task3061_mindaomic/models/merged_model/fold0/label_merging/merged_labels.csv'


# # label support is needed to get the influence regions
# label_support_path = os.path.join(PROJ_ROOT, 'data', 'task2153_mind', 'output', 'label_support.pt.npz')
# # reference label image
# label_path_ref = os.path.join(PROJ_ROOT, 'data', 'task2153_mind', 'input', 'dataset', 'labelsTr', 'mind_000.nii.gz')
# # ground truth label image
# label_path_unmerged = os.path.join(PROJ_ROOT, 'data', 'task2153_mind', 'input', 'dataset', 'labelsTs', 'mind_038.nii.gz')
# label_path_unmerged = os.path.join(PROJ_ROOT, 'data', 'task2153_mind', 'input', 'dataset', 'labelsTr', 'mind_002.nii.gz')
# # read label to merged label mapping
# merged_labels_csv_path = os.path.join(PROJ_ROOT, 'data', 'task2153_mind', 'output', 'merged_labels.csv')


channel_to_label_mapping = pd.read_csv(merged_labels_csv_path, index_col='channel').to_dict()['merged_label']
merged_labels_df = pd.read_csv(merged_labels_csv_path, index_col='label')
label_to_merged_label_mapping = merged_labels_df['merged_label'].to_dict()


pad = 10

# perturb the ground truth label image with an affine transformation
transformed_label_merged, split_label_grtr, label_data_merged_mni, R = (
    get_network_output_from_registered_image(label_path_unmerged, pad))

  label_data = torch.tensor(label_data, device="cuda")


In [5]:
transformed_label_merged.shape

torch.Size([213, 249, 213])

In [14]:
label(transformed_label_merged.cpu().numpy(), return_num=True, connectivity=3, background=0)

(array([[[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        ...,
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],