# Inspect data

date: 08 Sept, 2021 <br>

contents: <br>
* implementation of visualization methods
* inspection of a dataset folder

In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import copy
import cv2
import sys
import os

In [None]:
# IPython: disable jedi-based tab completion in this kernel
%config Completer.use_jedi = False
# IPython: render matplotlib figures inline in the notebook output
%matplotlib inline

## Utility methods

In [None]:
def load_data_from_folder(data_folder, subfiles=None):
    """
    Read NIfTI volumes from a case folder and return their voxel data.
    :param data_folder: path to data folder
    :param subfiles: file names inside data_folder to load; if None, the T1, T2 and
                     segmentation volumes are loaded (in that order)
    :return: list of np.arrays, one per file, in the same order as subfiles
    """
    names = subfiles if subfiles is not None else [
        "vs_gk_t1_refT1.nii",
        "vs_gk_t2_refT1.nii",
        "vs_gk_struc1_TV_refT1.nii",
    ]

    file_paths = [os.path.join(data_folder, name) for name in names]
    # open every file first, then materialise the voxel arrays
    proxies = [nib.load(file_path) for file_path in file_paths]
    return [proxy.get_fdata() for proxy in proxies]

def get_non_zero_slices_segmentation(segmentation):
    """
    Extract the indices of all slices (along axis 2) that contain at least one non-zero voxel.
    :param segmentation: np.array segmentation mask with at least 3 dimensions; slices are taken along axis 2
    :return: list with non-zero slice indices (ascending Python ints)
    :raises ValueError: if segmentation has fewer than 3 dimensions
    """
    seg = np.asarray(segmentation)
    if seg.ndim < 3:
        raise ValueError("segmentation must have at least 3 dimensions, got shape {}".format(seg.shape))

    # Reduce over every axis except the slice axis (2). Testing "any voxel != 0"
    # instead of "sum != 0" keeps slices whose non-zero values would cancel out
    # in a sum (e.g. signed values) instead of silently dropping them.
    other_axes = tuple(axis for axis in range(seg.ndim) if axis != 2)
    has_label = np.any(seg != 0, axis=other_axes)
    return np.flatnonzero(has_label).tolist()

In [None]:
def tile_plot_timepoint(data_folder, slice_to_vis=None, figsize=(15, 15), cmap=None):
    """
    Tile plot of one timepoint with T1, T2 and segmentation mask side by side.
    :param data_folder: path to data folder
    :param slice_to_vis: index (axis 2) of the slice to be visualized; if None - the median
                         (upper median for an even count) of the slices containing segmentation is taken
    :param figsize: figure size
    :param cmap: color map
    :raises ValueError: if slice_to_vis is None and the segmentation mask is empty
    """
    # load NIFTI images: [T1, T2, segmentation]
    data = load_data_from_folder(data_folder)

    # determine median slice that has segmentation mask
    if slice_to_vis is None:
        seg_slices = get_non_zero_slices_segmentation(data[2])
        if not seg_slices:
            raise ValueError(
                "Segmentation mask in {} is empty; pass slice_to_vis explicitly.".format(data_folder))
        # middle element of the ascending indices: unlike the mean of the indices,
        # it is always a slice that actually contains segmentation
        slice_to_vis = seg_slices[len(seg_slices) // 2]

    # plot
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    for axis, volume, title in zip(ax, data, ("T1", "T2", "SEG")):
        axis.imshow(volume[:, :, slice_to_vis], cmap=cmap)
        axis.set_title(title)
    fig.suptitle("Data slice {}".format(slice_to_vis), fontsize=16)
    # lay out after all artists exist so titles are taken into account
    fig.tight_layout()
    plt.show()


def tile_plot(data_folder, nrows=5, ncols=6, skip=2, cmap=None, figsize=(15, 15)):
    """
    Tile plot of T1 slices in a grid, starting at slice 0 and advancing by `skip` slices per tile.
    :param data_folder: path to data folder
    :param nrows: number of rows in the grid
    :param ncols: number of columns in the grid
    :param skip: step between consecutive shown slices (default 2: every second slice is shown)
    :param cmap: color map
    :param figsize: figure size
    :raises ValueError: if skip < 1 or the grid would need a slice beyond the last one
    """
    if skip < 1:
        raise ValueError("skip must be >= 1, got {}".format(skip))

    # only the T1 volume is displayed, so don't load T2 and segmentation
    t1 = load_data_from_folder(data_folder, subfiles=["vs_gk_t1_refT1.nii"])[0]

    slices = t1.shape[2]
    # the last tile shows slice (nrows*ncols - 1) * skip, so that index must exist
    last_slice = (nrows * ncols - 1) * skip
    if last_slice >= slices:
        raise ValueError(
            "A {}x{} grid with skip={} needs slice {}, but the volume only has {} slices".format(
                nrows, ncols, skip, last_slice, slices))

    # plot; squeeze=False keeps axes 2-D so a single row/column also works
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
    for tile, axis in enumerate(axes.flat):  # row-major order
        slice_idx = tile * skip
        axis.imshow(t1[:, :, slice_idx], cmap=cmap)
        axis.set_title(str(slice_idx))
        axis.axis("off")
    plt.show()


def overlap_plot(data_folder, slice_to_vis=None, figsize=(15, 15), modality="T1", cmap="gray"):
    """
    Plot segmentation mask on top of either the T1 or the T2 image for one time point.
    :param data_folder: path to data folder
    :param slice_to_vis: index (axis 2) of the slice to be visualized; if None - the median
                         (upper median for an even count) of the slices containing segmentation is taken
    :param figsize: figure size
    :param modality: modality to be shown; either T1 or T2 (case-insensitive substring match)
    :param cmap: "gray" draws the image in grayscale with the mask (scaled by 200) at 30% opacity on top;
                 any other value shows image/255 + mask blended into one array using this color map
    :raises ValueError: if modality is neither T1 nor T2, or slice_to_vis is None and the mask is empty
    """
    # load NIFTI images: [image, segmentation]
    if "t1" in modality.lower():
        data = load_data_from_folder(data_folder, subfiles=["vs_gk_t1_refT1.nii", "vs_gk_struc1_TV_refT1.nii"])
    elif "t2" in modality.lower():
        data = load_data_from_folder(data_folder, subfiles=["vs_gk_t2_refT1.nii", "vs_gk_struc1_TV_refT1.nii"])
    else:
        raise ValueError("Modality needs to be either T1 or T2.")

    # determine median slice that has segmentation mask
    if slice_to_vis is None:
        seg_slices = get_non_zero_slices_segmentation(data[1])
        if not seg_slices:
            raise ValueError(
                "Segmentation mask in {} is empty; pass slice_to_vis explicitly.".format(data_folder))
        # middle element: always a slice that actually contains segmentation
        slice_to_vis = seg_slices[len(seg_slices) // 2]

    # process data: min-max rescale the image slice to [0, 255]
    image = cv2.normalize(copy.deepcopy(data[0][:, :, slice_to_vis]), None,
                          alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
    mask = data[1][:, :, slice_to_vis]

    # plot (figsize is honoured in both branches)
    fig, ax = plt.subplots(figsize=figsize)
    if cmap == "gray":
        ax.imshow(image, cmap="gray")
        ax.imshow(mask * 200, alpha=0.3)
    else:
        blended = cv2.addWeighted(image / 255., 1.0, mask, 1.0, 0.0)
        ax.imshow(blended, cmap=cmap)
    fig.suptitle("{}, data slice {}".format(modality, slice_to_vis), fontsize=16)
    plt.show()


def overlap_plot_comparison(data_folder, slice_to_vis=None, figsize=(15, 15), cmap="gray"):
    """
    Plot the segmentation mask on top of both the T1 and the T2 image for one time point, side by side.
    :param data_folder: path to data folder
    :param slice_to_vis: index (axis 2) of the slice to be visualized; if None - the median
                         (upper median for an even count) of the slices containing segmentation is taken
    :param figsize: figure size
    :param cmap: "gray" draws each image in grayscale with the mask (scaled by 200) at 30% opacity on top;
                 any other value shows image/255 + 0.5*mask blended into one array using this color map
    :raises ValueError: if slice_to_vis is None and the segmentation mask is empty
    """
    # load NIFTI images: [T1, T2, segmentation]
    data = load_data_from_folder(data_folder)

    # determine median slice that has segmentation mask
    if slice_to_vis is None:
        seg_slices = get_non_zero_slices_segmentation(data[2])
        if not seg_slices:
            raise ValueError(
                "Segmentation mask in {} is empty; pass slice_to_vis explicitly.".format(data_folder))
        # middle element: always a slice that actually contains segmentation
        slice_to_vis = seg_slices[len(seg_slices) // 2]

    # process data: min-max rescale each image slice to [0, 255]
    t1_img = cv2.normalize(copy.deepcopy(data[0][:, :, slice_to_vis]), None,
                           alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
    t2_img = cv2.normalize(copy.deepcopy(data[1][:, :, slice_to_vis]), None,
                           alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
    mask = data[2][:, :, slice_to_vis]

    # plot
    fig, ax = plt.subplots(1, 2, figsize=figsize)
    for axis, image, title in ((ax[0], t1_img, "T1"), (ax[1], t2_img, "T2")):
        if cmap == "gray":
            axis.imshow(image, cmap="gray")
            axis.imshow(mask * 200, alpha=0.3)
        else:
            axis.imshow(cv2.addWeighted(image / 255., 1.0, mask, 0.5, 0.0), cmap=cmap)
        axis.set_title(title)
        axis.axis("off")
    fig.suptitle("Data slice {}".format(slice_to_vis), fontsize=16)
    plt.show()

## Inspect data folder

In [None]:
# Case folder passed to the plotting helpers above; it must contain the
# vs_gk_*_refT1.nii files they load. NOTE(review): machine-specific path — adjust locally.
folder_path = "/tf/workdir/data/VS_segm/VS_registered/training/vs_gk_1"

In [None]:
tile_plot(data_folder=folder_path)

In [None]:
tile_plot_timepoint(data_folder=folder_path)

In [None]:
overlap_plot_comparison(data_folder=folder_path)

In [None]:
overlap_plot(data_folder=folder_path, figsize=(5, 5))

In [None]:
overlap_plot(data_folder=folder_path, modality="T2", figsize=(5, 5))