In [None]:
import json
import math
import os
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union

import itk
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import nibabel as nib
import numpy as np
import pandas as pd
from scipy.ndimage import distance_transform_edt
import wandb

%matplotlib inline

# Public W&B API client; used below to look up runs (name, id, history) by their path.
api = wandb.Api()

In [None]:
def load_table(run_id: str, entity_name: str, project_name: str, table_name: str, iterations: Union[int, List[int]]) -> pd.DataFrame:
    """
    Loads specified table artifacts of a given Weights and Biases run.

    The table of active learning iteration ``i`` is read from artifact version ``v{i-1}`` of
    ``run-<run_id>-<table_name>``. The artifacts are consumed from a newly started W&B run (``wandb.init``).

    Args:
        run_id (string): ID of the Weights and Biases run.
        entity_name (string): Name of the Weights and Biases entity to which the run belongs.
        project_name (string): Name of the Weights and Biases project to which the run belongs.
        table_name (string): Name of the table artifact to be loaded.
        iterations (Union[int, List[int]]): Active learning iterations (starting at 1) for which the tables are to be loaded.

    Returns:
        pandas.DataFrame: Concatenated tables for the specified active learning iterations, with a fresh
        ``0..n-1`` index. An empty DataFrame if ``iterations`` is empty.

    Raises:
        ValueError: If any iteration is smaller than 1 (there is no artifact version ``v-1``).
    """

    if isinstance(iterations, int):
        iterations = [iterations]
    iterations = list(iterations)

    invalid_iterations = [iteration for iteration in iterations if iteration < 1]
    if invalid_iterations:
        raise ValueError(f"Active learning iterations start at 1, got {invalid_iterations}.")

    # Nothing to load: avoid starting a W&B run and pd.concat([]) raising "No objects to concatenate".
    if not iterations:
        return pd.DataFrame()

    tables = []
    with wandb.init(entity=entity_name, project=project_name) as run:
        for iteration in iterations:
            artifact = run.use_artifact(f"run-{run_id}-{table_name}:v{iteration - 1}")
            table = artifact.get(table_name)
            tables.append(pd.DataFrame(table.data, columns=table.columns))

    # ignore_index: each table starts its own 0-based index, which would otherwise yield duplicate labels.
    return pd.concat(tables, ignore_index=True)

def download_all_tables(run_id: str, entity_name: str, project_name: str, table_name: str) -> pd.DataFrame:
    """
    Loads all table artifacts of a given Weights and Biases run and caches them as a CSV file.

    The number of active learning iterations is the maximum of the ``trainer/iteration`` metric in the run
    history. The concatenated table is written to ``<run name>_<run id>.csv`` in the working directory.

    Args:
        run_id (string): ID of the Weights and Biases run.
        entity_name (string): Name of the Weights and Biases entity to which the run belongs.
        project_name (string): Name of the Weights and Biases project to which the run belongs.
        table_name (string): Name of the table artifact to be loaded.

    Returns:
        pandas.DataFrame: Concatenated tables for all active learning iterations (the data written to the CSV file).
    """

    # Use the given entity instead of a hardcoded one so runs of other entities can be loaded.
    run = api.run(f"/{entity_name}/{project_name}/runs/{run_id}")
    # NOTE(review): `Run.history()` returns a sampled history by default; confirm the sample always
    # contains the last active learning iteration for long runs.
    run_metrics = run.history()
    max_iteration = int(run_metrics["trainer/iteration"].max())

    selected_items = load_table(run.id, entity_name, project_name, table_name, list(range(1, max_iteration + 1)))

    selected_items.to_csv(f"{run.name}_{run.id}.csv", index=False)

    return selected_items

In [None]:
def plot_image(image: np.ndarray, mask: Optional[np.ndarray] = None) -> None:
    """
    Draws a single grayscale image on the current axes, optionally with a mask overlay.

    Args:
        image (numpy.ndarray): Image to be plotted. Must have shape `(h, w)` where `h = image height` and `w = image width`.
        mask (numpy.ndarray, optional): Mask to be plotted onto the image. Must have the same shape as `image`.
            Values below 0.9 are drawn fully transparent. Defaults to `None`.
    """

    plt.axis('off')
    plt.imshow(image, cmap="gray")

    if mask is None:
        return

    overlay_cmap = cm.Accent.copy()
    # Values under the lower colour limit (0.9, i.e. background 0) are rendered transparent.
    overlay_cmap.set_under('k', alpha=0)
    plt.imshow(mask, cmap=overlay_cmap, alpha=0.7, clim=[0.9, 10])


def plot_images(images: List[np.array], masks: Optional[List[np.array]] = None, max_images: int = 10) -> None:
    """
    Plots several images side by side in a single row of a new figure.

    Args:
        images (List[numpy.ndarray]): Images to be plotted, each of shape `(h, w)` where `h = image height`
            and `w = image width`. An array of shape `(N, h, w)` works as well.
        masks (List[numpy.ndarray], optional): Masks to be plotted onto the images; `masks[i]` must have the
            same shape as `images[i]`. Defaults to `None`.
        max_images (int): Maximum number of images to be plotted; only the first `max_images` images are shown.
    """

    images = images[:max_images]
    # Nothing to draw; avoids creating a zero-width figure.
    if len(images) == 0:
        return

    plt.figure(figsize=(32 * len(images), 32))
    for idx, image in enumerate(images):
        plt.subplot(1, len(images), idx + 1)
        # Reuse the single-image helper so both functions render images and mask overlays identically.
        plot_image(image, None if masks is None else masks[idx])

In [None]:
def signed_distance_interpolation(
    top: np.array,
    bottom: np.array,
    block_thickness: int,
) -> np.array:
    """
    Interpolates the slices strictly between top and bottom using signed distance functions.

    Both masks are converted into signed distance maps, which are blended linearly for each intermediate
    slice; a pixel belongs to an interpolated slice if the blended value is non-negative.

    Args:
        top (np.array): The top slice of the block (binary mask).
        bottom (np.array): The bottom slice of the block (binary mask).
        block_thickness (int): The thickness of the block, including the top and bottom slice.
    Returns:
        np.array: Boolean array of shape `(block_thickness - 2, *top.shape)`, ordered from the slice next to
        `top` to the slice next to `bottom`. Empty (zero slices) if `block_thickness <= 2`.
    """

    def signed_dist(mask):
        # Pixels inside the mask get values >= 1.5, pixels outside get values <= -0.5.
        inverse_mask = np.ones(mask.shape) - mask
        return distance_transform_edt(mask) - distance_transform_edt(inverse_mask) + 0.5

    # No intermediate slices; also avoids a division by zero for block_thickness == 1.
    if block_thickness <= 2:
        return np.zeros((0, *top.shape), dtype=bool)

    # The distance maps do not depend on the interpolation weight, so compute them only once.
    dist_top = signed_dist(top)
    dist_bottom = signed_dist(bottom)

    step = 1 / (block_thickness - 1)
    weights = [i * step for i in range(1, block_thickness - 1)]

    return np.array([(dist_top * (1 - weight)) + (dist_bottom * weight) >= 0 for weight in weights])


def morphological_contour_interpolation(
    top: np.array,
    bottom: np.array,
    block_thickness: int,
) -> np.array:
    """
    Interpolates between top and bottom slices using the `morphological_contour_interpolator
    <https://www.researchgate.net/publication/307942551_ND_morphological_contour_interpolation>`_ from ITK.

    Args:
        top (np.array): The top slice of the block (binary mask).
        bottom (np.array): The bottom slice of the block (binary mask).
        block_thickness (int): The thickness of the block, including the top and bottom slice.
    Returns:
        np.array: Boolean array of shape `(block_thickness - 2, *top.shape)`, ordered from the slice next to
        `top` to the slice next to `bottom` (the same order as `signed_distance_interpolation`).
    """

    # Place `top` at index 0 so the interpolated slices come out ordered from top to bottom, matching
    # `signed_distance_interpolation`; `interpolate_slices` uses both functions interchangeably.
    block = np.zeros((block_thickness, *top.shape), dtype=np.uint8)
    block[0, :, :] = top
    block[-1, :, :] = bottom
    image_type = itk.Image[itk.UC, 3]
    itk_img = itk.image_from_array(block, ttype=(image_type,))
    image = itk.morphological_contour_interpolator(itk_img)
    interpolated_block = itk.GetArrayFromImage(image)
    # Drop the two given slices, keep only the interpolated ones in between.
    interpolated_slices = interpolated_block[1:-1, :, :]

    return interpolated_slices.astype(bool)

    
def interpolate_slices(
    top: np.array,
    bottom: np.array,
    class_ids: Iterable[int],
    block_thickness: int,
    interpolation_type: Literal["signed-distance", "morph-contour"],
) -> Optional[np.array]:
    """
    Interpolates the label slices between a top and a bottom label slice, class by class.

    For every class ID, binary masks of the class in the top and bottom slice are interpolated with the
    selected method. Classes absent from both slices do not appear in the result. If a class occurs in only
    one of the slices, or its masks in the two slices do not overlap, no interpolation is done and `None`
    is returned.

    Args:
        top (np.array): The top slice of the block.
        bottom (np.array): The bottom slice of the block.
        class_ids (Iterable[int]): The class ids.
        block_thickness (int): The thickness of the block, including the top and bottom slice.
        interpolation_type (Literal): The type of interpolation to use. One of ["signed-distance", "morph-contour"]
    Returns:
        Optional[np.array]: Array of shape `(block_thickness - 2, *top.shape)` with a class ID per pixel
        (0 where no class was assigned; classes later in `class_ids` overwrite earlier ones), or `None`
        if the slices cannot be interpolated.
    Raises:
        ValueError: If `interpolation_type` is not one of the supported types.
    """

    # Validate up front so an invalid type is reported even when no class needs interpolating.
    if interpolation_type not in ("signed-distance", "morph-contour"):
        raise ValueError(
            f"Invalid interpolation type {interpolation_type}."
        )

    interpolation_thickness = block_thickness - 2
    result = np.zeros((interpolation_thickness, *top.shape))

    for class_id in class_ids:
        class_top = top == class_id
        class_bottom = bottom == class_id

        if not np.any(class_top) and not np.any(class_bottom):
            # Class absent from both slices: it stays absent from the interpolated slices.
            continue
        if not np.any(np.logical_and(class_top, class_bottom)):
            # Only classes whose masks overlap between the two slices are interpolated.
            return None

        if interpolation_type == "signed-distance":
            interpolation_fn = signed_distance_interpolation
        else:
            interpolation_fn = morphological_contour_interpolation

        class_slices = interpolation_fn(class_top, class_bottom, block_thickness)
        result[class_slices] = class_id

    return result

In [None]:
def get_selected_blocks(run_id: str,
                        entity_name: str,
                        project_name: str,
                        table_name: str,
                        iteration: int,
                        class_ids: Dict[int, str],
                        interpolation_type: Literal["signed-distance", "morph-contour"] = "signed-distance") -> List[Tuple[List[str], List[bool], List[np.ndarray], List[np.ndarray], List[np.ndarray]]]:
    """
    Assembles the blocks selected in the given active learning run and iteration.

    The selected-items table of the run is cached as `<run name>_<run id>.csv` in the working directory and
    is only downloaded if that file does not exist yet.

    Args:
        run_id (string): ID of the Weights and Biases run.
        entity_name (string): Name of the Weights and Biases entity to which the run belongs.
        project_name (string): Name of the Weights and Biases project to which the run belongs.
        table_name (string): Name of the table artifact to which the selected slices are logged.
        iteration (int): Active learning iteration for which the selected blocks are to be assembled.
        class_ids (Dict[int, str]): Mapping of class IDs to class names. Only the class IDs (keys) are used.
        interpolation_type (Literal): The type of interpolation to use. One of ["signed-distance", "morph-contour"]

    Returns:
        List[Tuple[List[str], List[bool], List[np.ndarray], List[np.ndarray], List[np.ndarray]]]: For each selected
        block a tuple with the following elements is returned, each ordered from the block's bottom slice to its
        top slice:
            - List[str]: Case IDs of all slices belonging to the block.
            - List[bool]: Indicates whether a slice is a pseudo label or not.
            - List[np.ndarray]: Image slices of the block.
            - List[np.ndarray]: Labels of the slices (true labels for the top and bottom slice, interpolated
              pseudo-labels for intermediate slices).
            - List[np.ndarray]: True labels of all slices.
        If the labels of a block cannot be interpolated, the block only contains its bottom and top slice.
    """

    run = api.run(f"/{entity_name}/{project_name}/runs/{run_id}")
    table_file = f"{run.name}_{run.id}.csv"

    if not os.path.exists(table_file):
        download_all_tables(run_id, entity_name, project_name, table_name)

    selected_slices = pd.read_csv(table_file)
    selected_slices = selected_slices[selected_slices["iteration"] == iteration]

    # Reset the index so rows can be paired by position; after filtering, the frame would otherwise keep
    # the CSV row labels, which are not positions within `true_labels`.
    true_labels = selected_slices[selected_slices["pseudo_label"] == False].reset_index(drop=True)
    pseudo_labels = selected_slices[selected_slices["pseudo_label"] == True]

    selected_blocks = []

    # Consecutive pairs of true-label rows form one block: the row at an even position is the block's top
    # slice, the row directly after it the block's bottom slice.
    for position in range(1, len(true_labels), 2):
        top_slice = true_labels.iloc[position - 1]
        bottom_slice = true_labels.iloc[position]
        top_slice_idx = top_slice["slice_index"]
        bottom_slice_idx = bottom_slice["slice_index"]

        image_path = bottom_slice["image_path"]
        label_path = image_path.replace("imagesTr", "labelsTr")

        # Move the third axis to the front so that volume[i] is the i-th slice.
        image = np.moveaxis(nib.load(image_path).get_fdata(), 2, 0)
        label = np.moveaxis(nib.load(label_path).get_fdata(), 2, 0)

        case_ids = [bottom_slice["case_id"]]
        is_pseudo_label = [False]
        slices = [image[bottom_slice_idx]]
        labels = [label[bottom_slice_idx]]
        original_labels = [label[bottom_slice_idx]]

        interpolated_slices = pseudo_labels[
            (pseudo_labels["image_id"] == bottom_slice["image_id"]) &
            (pseudo_labels["slice_index"] > bottom_slice_idx) &
            (pseudo_labels["slice_index"] < top_slice_idx)
        ].sort_values(by=['slice_index'])

        if len(interpolated_slices) > 0:
            # The lower slice is passed as `top`; this assumes the interpolation helpers return slices ordered
            # from their `top` towards their `bottom` argument, i.e. by ascending slice index here.
            interpolated_labels = interpolate_slices(
                label[bottom_slice_idx],
                label[top_slice_idx],
                class_ids,
                len(interpolated_slices) + 2,
                interpolation_type,
            )

            # interpolate_slices returns None if the labels cannot be interpolated; such a block keeps only
            # its two true-label slices.
            if interpolated_labels is not None:
                for inner_idx, (_, interpolated_slice) in enumerate(interpolated_slices.iterrows()):
                    slice_idx = interpolated_slice["slice_index"]
                    case_ids.append(interpolated_slice["case_id"])
                    is_pseudo_label.append(True)
                    slices.append(image[slice_idx])
                    labels.append(interpolated_labels[inner_idx])
                    original_labels.append(label[slice_idx])

        case_ids.append(top_slice["case_id"])
        is_pseudo_label.append(False)
        slices.append(image[top_slice_idx])
        labels.append(label[top_slice_idx])
        original_labels.append(label[top_slice_idx])

        selected_blocks.append((case_ids, is_pseudo_label, slices, labels, original_labels))

    return selected_blocks

def visualize_selected_blocks(selected_blocks: List[Tuple[List[str], List[bool], np.ndarray, np.ndarray, np.ndarray]],
                                image_size: int = 5,
                                max_images_per_row: int = 3) -> None:
    """
    Draws one figure per selected block, showing all slices of the block in a grid.

    Each slice is shown in grayscale with its label overlaid (Accent colormap, values below 0.9 transparent)
    and titled with its case ID and whether the label is interpolated or true. For interpolated slices, the
    true label is additionally drawn (plasma colormap) wherever it differs from the interpolated label.

    Args:
        selected_blocks (List[Tuple[List[str], List[bool], np.ndarray, np.ndarray, np.ndarray]]): Selected blocks as returned by `get_selected_blocks`.
        image_size (int, optional): Size in which the images should be displayed in inches. Defaults to 5.
        max_images_per_row (int, optional): Maximum number of images to be displayed in one row. Defaults to 3.
    """
    for case_ids, is_pseudo_label, slices, labels, true_labels in selected_blocks:
        grid_rows = math.ceil(len(slices) / max_images_per_row)
        plt.figure(figsize=(image_size * max_images_per_row, image_size * grid_rows))

        for subplot_position, slice_image in enumerate(slices, start=1):
            i = subplot_position - 1
            plt.subplot(grid_rows, max_images_per_row, subplot_position)
            plt.axis('off')
            plt.imshow(slice_image, cmap='gray')

            if labels is None:
                continue

            label_cmap = cm.Accent.copy()
            label_cmap.set_under('k', alpha=0)
            plt.axis('off')
            plt.imshow(labels[i], cmap=label_cmap, alpha=0.7, clim=[0.9, 10])
            label_kind = 'interpolated label' if is_pseudo_label[i] else 'true label'
            plt.title(f"{case_ids[i]} - {label_kind}")

            if not is_pseudo_label[i]:
                continue

            diff_cmap = cm.plasma.copy()
            diff_cmap.set_under('k', alpha=0)
            # Keep the true label only where it disagrees with the interpolated label.
            label_mismatch = true_labels[i].copy()
            label_mismatch[true_labels[i] == labels[i]] = 0
            plt.axis('off')
            plt.imshow(label_mismatch, cmap=diff_cmap, alpha=1, clim=[0.9, 10])

In [None]:
# Path of the Decathlon dataset description. Override it with the DATASET_FILE_NAME environment variable
# to run the notebook outside the original cluster.
dataset_file_name = os.environ.get(
    "DATASET_FILE_NAME", "/dhc/groups/mpws2021cl1/Data/Decathlon/Task02_Heart/dataset.json"
)
entity_name = "active-segmentation"
project_name = "task-02-heart-paper"
table_name = "selected_items"

with open(dataset_file_name, encoding="utf-8") as dataset_file:
    dataset_info = json.load(dataset_file)
    labels = dataset_info["labels"]
# JSON object keys are strings, so convert the class IDs to int.
class_ids = {int(key): name for key, name in labels.items()}

In [None]:
run_id = "2fns4wu0"

# Assemble the blocks selected in active learning iteration 1 of this run and plot them.
selected_blocks = get_selected_blocks(
    run_id,
    entity_name,
    project_name,
    table_name,
    iteration=1,
    class_ids=class_ids,
    interpolation_type="signed-distance",
)
visualize_selected_blocks(selected_blocks)