In [14]:
from pathlib import Path
import pandas as pd
import numpy as np
import nibabel as nib
from typing import Union, Sequence
import torch
import matplotlib.pyplot as plt
from scipy.interpolate import RegularGridInterpolator
import h5py
from skimage.color import label2rgb
import json



In [15]:
# Helper functions for scaling and plotting MRI scans


def get_scaled_image(
        x: Union[torch.Tensor, np.ndarray], percentile=0.99, clip=False
):
    """Scales image by intensity percentile (and optionally clips to [0, 1]).

    The scale factor is the ``percentile`` quantile of the image values. For
    complex-valued images the quantile is taken over the magnitude, and images
    that are not ``float32``/``float64`` (e.g. integer or boolean arrays) are
    converted to ``float32`` for the quantile computation only, because
    ``torch.quantile`` accepts only float/double tensors.

    Args:
      x (torch.Tensor | np.ndarray): The image to process.
      percentile (float): The quantile (in [0, 1]) of the intensities
        (magnitudes for complex input) to scale by.
      clip (bool): If True, clip the scaled values to [0, 1] with
        ``torch.clip``. This is only meaningful for real-valued images.

    Returns:
      torch.Tensor | np.ndarray: The scaled image, returned as the same
      container type (tensor or ndarray) that was passed in.
    """
    is_numpy = isinstance(x, np.ndarray)
    if is_numpy:
        x = torch.as_tensor(x)

    # Reference values for the quantile: magnitude for complex input,
    # otherwise the raw intensities (unchanged behaviour for float images).
    reference = x.abs() if x.is_complex() else x
    if reference.dtype not in (torch.float32, torch.float64):
        reference = reference.to(torch.float32)

    scale_factor = torch.quantile(reference, percentile)
    x = x / scale_factor
    if clip:
        x = torch.clip(x, 0, 1)

    if is_numpy:
        x = x.numpy()

    return x


def plot_images(
        images, processor=None, disable_ticks=True, titles: Sequence[str] = None,
        ylabel: str = None, xlabels: Sequence[str] = None, cmap: str = "gray",
        show_cbar: bool = False, overlay=None, opacity: float = 0.3,
        hsize=5, wsize=5, axs=None, fontsize=20
):
    """Plot multiple images in a single row.

    Add an overlay with the `overlay=` argument.
    Add a colorbar with `show_cbar=True`.

    Args:
        images (Sequence): The images to show, one per axes.
        processor (Callable, optional): Applied to every image before plotting.
        disable_ticks (bool): If True, remove x/y ticks from all axes.
        titles (Sequence[str], optional): One title per image.
        ylabel (str, optional): Label for the y-axis of the first axes.
        xlabels (Sequence[str], optional): One x-axis label per image.
        cmap (str | Sequence): A single colormap for all images, or a
            list/tuple with one colormap per image.
        show_cbar (bool | Sequence[bool]): ``True`` attaches a colorbar to
            every image's own axes. A list/tuple selects images individually;
            in that case each colorbar is drawn in a dedicated axes at the
            right edge of the figure.
        overlay (Sequence, optional): List/tuple with one overlay image (or
            ``None``) per axes, drawn on top with transparency ``opacity``.
        opacity (float): Alpha value used for the overlays.
        hsize (float): Figure height when the figure is created here.
        wsize (float): Width per image when the figure is created here.
        axs (optional): Existing axes to draw into (a single Axes, a sequence
            or an array of Axes). A new figure is created if ``None``.
        fontsize (int): Font size of the titles.

    Returns:
        The axes that were drawn into: ``axs`` as passed in, or a 1-D array
        of the newly created axes.
    """
    N = len(images)

    def get_default_values(x, default=""):
        if x is None:
            return [default] * N
        return x

    titles = get_default_values(titles)
    xlabels = get_default_values(xlabels)

    if axs is None:
        # squeeze=False always yields a 2-D array, so a single image no longer
        # produces a bare Axes object that cannot be iterated.
        _, axs = plt.subplots(1, N, figsize=(wsize * N, hsize), squeeze=False)
        axs = axs[0]
    # Flat view used for all iteration; handles a single Axes, lists and
    # arrays of any dimensionality.
    flat_axes = np.atleast_1d(np.asarray(axs, dtype=object)).ravel()
    assert len(flat_axes) >= N

    cmaps = list(cmap) if isinstance(cmap, (list, tuple)) else [cmap] * N
    per_image_cbar = isinstance(show_cbar, (list, tuple))
    cbar_flags = list(show_cbar) if per_image_cbar else [bool(show_cbar)] * N

    for k, (ax, img, title, xlabel) in enumerate(zip(flat_axes, images, titles, xlabels)):
        if processor is not None:
            img = processor(img)
        im = ax.imshow(img, cmap=cmaps[k])
        if cbar_flags[k]:
            fig = ax.get_figure()
            if per_image_cbar:
                # Placement used for per-image flags: a dedicated colorbar
                # axes at the right edge of the figure.
                fig.subplots_adjust(right=0.8)
                cbar_ax = fig.add_axes([0.85, 0.2, 0.01, 0.60])
                fig.colorbar(im, cax=cbar_ax)
            else:
                fig.colorbar(im, ax=ax)
        ax.set_title(title, fontsize=fontsize)
        ax.set_xlabel(xlabel)

    if ylabel is not None and len(flat_axes) > 0:
        flat_axes[0].set_ylabel(ylabel)

    if isinstance(overlay, (list, tuple)):
        # zip stops at the shorter sequence, so spare axes without an overlay
        # entry are simply skipped.
        for ax, layer in zip(flat_axes, overlay):
            if layer is not None:
                ax.imshow(layer, alpha=opacity)

    if disable_ticks:
        for ax in flat_axes:
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])

    return axs


# Function for converting categorical segmentation labels into a one-hot encoding
def categorical_to_one_hot(x, channel_dim: int = 1, background=0, num_categories=None, dtype=None):
    """Turn an array of integer class labels into a stack of per-class masks.

    Args:
        x (torch.Tensor | np.ndarray): Categorical array or tensor.
        channel_dim (int, optional): Position of the class axis in the output.
            Negative values count from the end of the output dimensions.
        background (int | NoneType, optional): Label whose channel is removed
            from the output. If ``None``, no channel is removed, i.e. every
            label (including the background) gets its own channel.
        num_categories (int, optional): Largest label value that receives a
            channel; ``num_categories + 1`` channels are allocated before the
            background channel is dropped. Defaults to ``max(x)``.
        dtype (torch.dtype, optional): Data type of the output.
            Defaults to ``torch.bool``.

    Returns:
        torch.Tensor | np.ndarray: One-hot encoded labels, of the same
        container type (tensor or ndarray) as the input.
    """
    came_from_numpy = isinstance(x, np.ndarray)
    labels = torch.from_numpy(x) if came_from_numpy else x

    # Labels run from 0 to `num_categories`, hence one extra channel.
    if num_categories is None:
        num_categories = torch.max(labels).type(torch.long).cpu().item()
    num_channels = num_categories + 1

    out_dtype = torch.bool if dtype is None else dtype
    hot_value = True if out_dtype == torch.bool else 1

    # scatter_ needs int64 indices.
    if labels.dtype != torch.long:
        labels = labels.type(torch.long)

    # Build the encoding with the class axis first: (C, *labels.shape).
    one_hot = torch.zeros(
        (num_channels,) + tuple(labels.shape), dtype=out_dtype, device=labels.device
    )
    one_hot.scatter_(0, labels.unsqueeze(0), hot_value)

    if background is not None:
        kept = [one_hot[:background], one_hot[background + 1:]]
        one_hot = torch.cat(kept, dim=0)

    if channel_dim != 0:
        target = channel_dim + one_hot.ndim if channel_dim < 0 else channel_dim
        # Move the leading class axis to position `target`, keeping the order
        # of all other axes.
        permutation = list(range(1, target + 1)) + [0] + list(range(target + 1, one_hot.ndim))
        one_hot = one_hot.permute(permutation).contiguous()

    return one_hot.numpy() if came_from_numpy else one_hot


# Function for combining one hot encoded classes
def collect_mask(
        mask: np.ndarray,
        index: Sequence[Union[int, Sequence[int], int]],
        out_channel_first: bool = True,
):
    """Select (and optionally merge) channels of a channel-last mask.

    Entries of ``index`` that are themselves sequences are merged by summing
    the selected channels. For example, `index=(1,(3,4))` returns
    `np.stack([mask[...,1], mask[...,3]+mask[...,4]], axis=-1)` (before the
    optional reordering to channel-first).

    TODO: Add support for adding background.

    Args:
        mask (ndarray): A (...)xC array.
        index (int | Sequence[int | Sequence[int]]): The channel(s) to select
            from the last axis of ``mask``. Nested sequences are summed into
            a single output channel.
        out_channel_first (bool, optional): If True, move the channel axis of
            the result to the front, giving Cx(...).

    Returns:
        ndarray: The collected mask.
    """
    selection = (index,) if isinstance(index, int) else index

    has_groups = any(isinstance(entry, Sequence) for entry in selection)
    if has_groups:
        channels = []
        for entry in selection:
            picked = mask[..., entry]
            # A grouped entry selects several channels; collapse them into one.
            channels.append(np.sum(picked, axis=-1) if isinstance(entry, Sequence) else picked)
        collected = np.stack(channels, axis=-1)
    else:
        # Plain integer indices: one fancy-indexing step is enough.
        collected = mask[..., selection]

    if not out_channel_first:
        return collected

    channel_axis = collected.ndim - 1
    return np.transpose(collected, [channel_axis, *range(channel_axis)])


# Function for converting one-hot segmentation masks back into categorical labels
def one_hot_to_categorical(x, channel_dim: int = 1, background=False):
    """Converts one-hot encoded predictions to categorical predictions.

    Args:
        x (torch.Tensor | np.ndarray): One-hot encoded predictions.
        channel_dim (int, optional): Channel dimension.
            Defaults to ``1`` (i.e. ``(B,C,...)``).
        background (bool, optional): If ``True``, assumes index 0 in the
            channel dimension is the background. Any value other than
            ``None``/``False`` is treated like ``True``.

    Returns:
        torch.Tensor | np.ndarray: Categorical array or tensor. If ``background=False``,
        the output will be 1-indexed such that ``0`` corresponds to the background
        (positions where every channel is zero).
    """
    is_ndarray = isinstance(x, np.ndarray)
    if is_ndarray:
        x = torch.as_tensor(x)

    if background is not None and background is not False:
        # Boolean masks are cast to integers first, mirroring the cast in the
        # branch below; torch.argmax is not guaranteed to accept bool tensors.
        # Other dtypes are passed through untouched.
        scores = x.type(torch.long) if x.dtype == torch.bool else x
        out = torch.argmax(scores, channel_dim)
    else:
        out = torch.argmax(x.type(torch.long), dim=channel_dim) + 1
        # Positions with no active channel belong to the implicit background.
        # A 0-dim zero broadcasts without changing the shape of `out`
        # (a shape-(1,) zero would turn a scalar result into shape (1,)).
        zero = torch.zeros((), dtype=out.dtype, device=out.device)
        out = torch.where(x.sum(channel_dim) == 0, zero, out)

    if is_ndarray:
        out = out.numpy()
    return out

In [16]:
# Load the raw data (k-space, coil maps, target image) and the segmentation of one scan.
dataset_dir = '/data/projects/recon/data/public/multitask/skm-tea/v1-release/'
meta_dir = Path(dataset_dir).joinpath('all_metadata.csv')  # path of the metadata CSV file
meta_data = pd.read_csv(meta_dir)

# Directories holding the per-scan `.h5` recon files and `.nii.gz` segmentation masks.
recon_file = Path(dataset_dir).joinpath("files_recon_calib-24/")
seg_file = Path(dataset_dir).joinpath(str("segmentation_masks/raw-data-track/"))

patient_id = 'MTR_013'  # alternative: meta_data['MTR_ID'][60]
print(patient_id)
# Keep only the metadata rows of the selected scan.
meta_data = meta_data.loc[meta_data['MTR_ID'] == patient_id]

with h5py.File(f"{recon_file / patient_id}.h5", "r") as f:
    print(f.keys())
    # Shape: (x, ky, kz, #echos, #coils)
    kspace_org = f["kspace"][:, :, :, :, :]
    # Shape: (x, ky, kz, #coils, #maps) - maps are the same for both echos
    maps_org = f["maps"][:, :, :, :, :]
    image_org = f['target'][()]

# Categorical label volume -> channel-last one-hot masks, then merge the
# channel pairs (2, 3) and (4, 5) into single channels.
segmentation = nib.load(f"{seg_file / patient_id}.nii.gz").get_fdata()
segmentation_one = categorical_to_one_hot(segmentation, channel_dim=-1)
segmentation_one_org = collect_mask(
    segmentation_one, (0, 1, (2, 3), (4, 5)), out_channel_first=False
)
crop_scale = [1, 1, 1]  # Option to crop
print(kspace_org.shape, segmentation_one_org.shape, maps_org.shape)

MTR_013
<KeysViewHDF5 ['kspace', 'maps', 'masks', 'target']>
(512, 512, 160, 2, 8) (512, 512, 160, 4) (512, 512, 160, 8, 1)


In [36]:
# Load the v1.0.0 training annotation file of the dataset and show its top-level
# keys, the category definitions and the PatientID of the selected scan.
# The path is derived from `dataset_dir` instead of repeating the absolute path.
annotations_path = Path(dataset_dir) / 'annotations' / 'v1.0.0' / 'train.json'
with open(annotations_path, 'r') as f:
    # Deliberately not named `object`: that would shadow the builtin `object`
    # type for every cell executed afterwards in this kernel.
    train_annotations = json.load(f)

print(train_annotations.keys())
print(train_annotations['categories'])
print(meta_data['PatientID'])

dict_keys(['info', 'categories', 'tissues', 'images', 'annotations'])
[{'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 1, 'name': 'Meniscal Tear (Myxoid)'}, {'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 2, 'name': 'Meniscal Tear (Horizontal)'}, {'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 3, 'name': 'Meniscal Tear (Radial)'}, {'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 4, 'name': 'Meniscal Tear (Vertical/Longitudinal)'}, {'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 5, 'name': 'Meniscal Tear (Oblique)'}, {'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 6, 'name': 'Meniscal Tear (Complex)'}, {'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 7, 'name': 'Meniscal Tear (Flap)'}, {'supercategory': 'Meniscal Tear', 'supercategory_id': 1, 'id': 8, 'name': 'Meniscal Tear (Extrusion)'}, {'supercategory': 'Ligament Tear', 'supercategory_id': 2, 'id': 9, 'name': 'Ligament Tear (

In [13]:
# Instantiate the `yolov5s` entry point of the `ultralytics/yolov5` hub repository
# (called with pretrained=False).
# NOTE(review): the output recorded for this cell is
# "ModuleNotFoundError: No module named 'cv2'" - presumably the hub code imports
# OpenCV; confirm the dependency is installed before re-running.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=False)

Using cache found in /home/tmpaquaij/.cache/torch/hub/ultralytics_yolov5_master


ModuleNotFoundError: No module named 'cv2'