In [47]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import torch
from typing import  Union

In [48]:
# coding=utf-8
__author__ = "Dimitris Karkalousos"

import torch
from torch import nn
from typing import Optional
from atommic.core.classes.loss import Loss

class FocalLoss(Loss):
    """Focal loss built on top of PyTorch's :class:`torch.nn.CrossEntropyLoss`, supporting 2D and 3D inputs.

    The cross-entropy values ``ce`` returned by the criterion are re-weighted as
    ``alpha * (1 - exp(-ce)) ** gamma * ce`` and averaged, which down-weights well-classified elements.
    When a prediction log-variance is given (and ``num_samples > 1``), the loss is additionally averaged over
    Monte Carlo samples of Gaussian-perturbed logits.
    """

    def __init__(
        self,
        num_samples: int = 50,
        ignore_index: int = -100,
        reduction: str = "none",
        label_smoothing: float = 0.0,
        weight: Optional[torch.Tensor] = None,
        gamma: float = 2,
        alpha: float = 0.25,
    ):
        """Inits :class:`FocalLoss`.

        Parameters
        ----------
        num_samples : int, optional
            Number of Monte Carlo samples used when ``pred_log_var`` is given, by default 50.
        ignore_index : int, optional
            Index to ignore, forwarded to :class:`torch.nn.CrossEntropyLoss`, by default -100.
        reduction : str, optional
            Reduction forwarded to :class:`torch.nn.CrossEntropyLoss`, by default "none". Keep "none" so that the
            focal term is applied per element; with "mean"/"sum" it is applied to the already-reduced cross-entropy.
        label_smoothing : float, optional
            Label smoothing, by default 0.0.
        weight : torch.Tensor, optional
            Weight for each class, forwarded to :class:`torch.nn.CrossEntropyLoss`, by default None.
        gamma : float, optional
            Focusing exponent of the modulating factor ``(1 - pt) ** gamma``, by default 2.
        alpha : float, optional
            Scalar factor multiplying the focal term, by default 0.25.
        """
        super().__init__()
        self.mc_samples = num_samples
        self.ignore_index = ignore_index
        self.weight = weight
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.gamma = gamma
        self.alpha = alpha

    def _focal_term(self, ce_loss: torch.Tensor) -> torch.Tensor:
        """Re-weight cross-entropy values with the focal modulating factor ``alpha * (1 - pt) ** gamma``."""
        # pt = exp(-ce) equals the predicted probability of the target class only for unweighted, unsmoothed
        # one-hot targets; otherwise it is a proxy for it.
        pt = torch.exp(-ce_loss)
        return self.alpha * (1 - pt) ** self.gamma * ce_loss

    def forward(self, target: torch.Tensor, _input: torch.Tensor, pred_log_var: torch.Tensor = None) -> torch.Tensor:
        """Forward pass of :class:`FocalLoss`.

        Parameters
        ----------
        target : torch.Tensor
            Target tensor (one-hot / class probabilities; it is cast to float).
            Shape: (batch_size, num_classes, *spatial_dims)
        _input : torch.Tensor
            Prediction tensor (logits). Shape: (batch_size, num_classes, *spatial_dims)
        pred_log_var : torch.Tensor, optional
            Prediction log variance tensor. Shape: (batch_size, num_classes, *spatial_dims). Default is ``None``.
            When given and ``num_samples > 1``, the logits are perturbed with Gaussian noise of standard deviation
            ``exp(0.5 * pred_log_var)`` and the focal loss is averaged over the samples.

        Returns
        -------
        torch.Tensor
            Scalar loss: the mean focal loss over all elements (and Monte Carlo samples, if used).
        """
        # In case we do not have a batch dimension, add it
        if _input.dim() == 3:
            _input = _input.unsqueeze(0)
        if target.dim() == 3:
            target = target.unsqueeze(0)

        # Class weights used to be accepted but silently ignored; forward them when provided. They are cast to
        # float32 on the input's device to match the float32 logits passed to the criterion below.
        weight = self.weight.to(device=_input.device, dtype=torch.float32) if self.weight is not None else None
        cross_entropy = torch.nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            label_smoothing=self.label_smoothing,
        )

        if self.mc_samples == 1 or pred_log_var is None:
            ce_loss = cross_entropy(_input.float(), target.float())
            return self._focal_term(ce_loss).mean()

        # Monte Carlo sampling: perturb the logits with noise scaled by std = exp(0.5 * log_var) = sqrt(exp(log_var)).
        noise = torch.randn([self.mc_samples, *_input.shape], device=_input.device)
        noisy_pred = _input.unsqueeze(0) + torch.exp(0.5 * pred_log_var).unsqueeze(0) * noise
        noisy_pred = noisy_pred.reshape(-1, *_input.shape[1:])
        # Repeat the target once per sample along a NEW leading dim so it lines up with `noisy_pred`.
        # (`tile((mc,))` would repeat along the last dim instead and scramble the correspondence after reshaping.)
        tiled_target = target.unsqueeze(0).expand(self.mc_samples, *target.shape).reshape(-1, *target.shape[1:])
        ce_loss = cross_entropy(noisy_pred.float(), tiled_target.float())
        # Apply the focal weighting in this branch too. The mean over all samples and elements equals the mean over
        # samples followed by the mean over elements, since every sample has the same size. The result stays
        # floating point rather than being cast to the target's (possibly integer) dtype.
        return self._focal_term(ce_loss).mean()



In [110]:
def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
    """Convert integer class labels to a one-hot representation.

    Parameters
    ----------
    labels: torch.Tensor
        Integer class labels of shape [BNHW[D]], where the ``dim`` axis is expected to be a singleton (N == 1).
        If ``labels`` has fewer than ``dim + 1`` dimensions, trailing singleton dimensions are appended first.
    num_classes: int
        Number of classes; size of the ``dim`` axis in the output.
    dtype: torch.dtype
        The data type of the returned tensor.
    dim: int
        The dimension along which the one-hot encoding is written.

    Returns
    -------
    torch.Tensor
        Tensor with the same shape as ``labels`` except that axis ``dim`` has size ``num_classes``. It holds 1 at
        each label's class index along ``dim`` and 0 elsewhere.

    Examples
    --------
    >>> labels = torch.tensor([[[[0, 1, 2]]]])
    >>> one_hot(labels, num_classes=3).shape
    torch.Size([1, 3, 1, 3])
    >>> one_hot(labels, num_classes=3)[0, :, 0, :]
    tensor([[1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]])
    """
    # If `dim` does not exist yet, append trailing singleton dims so the class axis can be created there.
    if labels.ndim < dim + 1:
        labels = torch.reshape(labels, [*labels.shape, *([1] * (dim + 1 - labels.ndim))])

    out_shape = list(labels.shape)
    out_shape[dim] = num_classes
    encoded = torch.zeros(size=out_shape, dtype=dtype, device=labels.device)
    # Write a 1 at each label's class index along `dim`.
    return encoded.scatter_(dim=dim, index=labels.long(), value=1)

In [49]:
def get_scaled_image(
    x: Union[torch.Tensor, np.ndarray], percentile=0.99, clip=False
):
    """Scale an image by an intensity quantile and optionally clip it to [0, 1].

    Parameters
    ----------
    x : torch.Tensor or np.ndarray
        The image to process. Inputs that are not float32/float64 (e.g. integer arrays) are converted to float32,
        because :func:`torch.quantile` only supports those dtypes.
    percentile : float, optional
        Quantile of the intensities to scale by, as a fraction in [0, 1] (e.g. 0.99), by default 0.99.
    clip : bool, optional
        If True, clip the scaled values to [0, 1], by default False.

    Returns
    -------
    torch.Tensor or np.ndarray
        The scaled image, of the same container type as ``x``. If the quantile is zero, the image is returned
        unscaled (a copy) instead of being divided by zero.
    """
    is_numpy = isinstance(x, np.ndarray)
    x = torch.as_tensor(x)
    if x.dtype not in (torch.float32, torch.float64):
        x = x.float()

    scale_factor = torch.quantile(x, percentile)
    # A (mostly) zero image has a zero quantile; dividing would fill the result with NaN/inf.
    x = x / scale_factor if scale_factor != 0 else x.clone()
    if clip:
        x = torch.clip(x, 0, 1)

    return x.numpy() if is_numpy else x

In [50]:
# Roots of the two copies of the SKM-TEA v1 release that are compared in this notebook.
SKM_TEA_SMALL_ROOT = "/data/projects/utwente/recon/SKM-TEA_small/v1-release"
SKM_TEA_ORG_ROOT = "/data/projects/recon/data/public/multitask/skm-tea/v1-release"
SUBJECT_ID = "MTR_135"


def load_kspace_and_maps(root: str, subject_id: str):
    """Read the full ``kspace`` and ``maps`` datasets of one subject from ``<root>/files_recon_calib-24``."""
    with h5py.File(f"{root}/files_recon_calib-24/{subject_id}.h5", "r") as f:
        # `[()]` reads the whole dataset into a NumPy array (same as slicing every axis with `:`).
        return f["kspace"][()], f["maps"][()]


def load_segmentation(root: str, subject_id: str):
    """Load one subject's raw-data-track segmentation mask via nibabel's ``get_fdata()``."""
    return nib.load(f"{root}/segmentation_masks/raw-data-track/{subject_id}.nii.gz").get_fdata()


# Small copy. The printed shapes are kspace (88, 256, 208, 2, 8) and maps (88, 256, 208, 1, 8). The last two
# maps axes therefore appear to be (#maps, #coils), swapped relative to the original copy. TODO confirm axis order.
kspace_small, maps_small = load_kspace_and_maps(SKM_TEA_SMALL_ROOT, SUBJECT_ID)
# Original copy: kspace (x, ky, kz, #echos, #coils); maps (x, ky, kz, #coils, #maps) - maps shared by both echos.
kspace_org, maps_org = load_kspace_and_maps(SKM_TEA_ORG_ROOT, SUBJECT_ID)

segmentation_one_small = load_segmentation(SKM_TEA_SMALL_ROOT, SUBJECT_ID)
segmentation_one_org = load_segmentation(SKM_TEA_ORG_ROOT, SUBJECT_ID)

In [51]:
# Compare array layouts of the original and the small copy (order: org k-space, org maps, small k-space,
# small maps, org segmentation, small segmentation).
for array in (kspace_org, maps_org, kspace_small, maps_small, segmentation_one_org, segmentation_one_small):
    print(array.shape)

(512, 512, 168, 2, 8)
(512, 512, 168, 8, 1)
(88, 256, 208, 2, 8)
(88, 256, 208, 1, 8)
(512, 512, 168)
(88, 256, 208, 4)


In [52]:
# Magnitude images (intensity-scaled to the 99th percentile and clipped to [0, 1]) of the slices being compared.
panels = [
    ("kspace_org[0, 48:-48, 40:-40, 0, 0]", kspace_org[0, 48:-48, 40:-40, 0, 0]),
    ("kspace_small[87, :, :, 0, 0]", kspace_small[87, :, :, 0, 0]),
    # NOTE(review): slice 84 here vs 87 for kspace_small above — confirm whether these were meant to match.
    ("maps_small[84, :, :, 0, 0]", maps_small[84, :, :, 0, 0]),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(6, 4 * len(panels)))
for ax, (label, data) in zip(axes, panels):
    ax.imshow(get_scaled_image(np.abs(data), 0.99, clip=True), cmap="gray")
    ax.set_title(f"|{label}|")
    ax.axis("off")
fig.tight_layout()
plt.show()

<matplotlib.image.AxesImage at 0x7ff1889d1ae0>

In [120]:
# Check that F.cross_entropy gives the same value for one-hot (probability) targets and class-index targets.
seed = 1
torch.manual_seed(seed)  # `seed` was previously defined but never applied, so results were not reproducible

# Dim 1 is the class axis. Renamed from `input` to avoid shadowing the builtin.
logits = torch.randn((5, 5, 256, 216))
target = torch.randn((5, 5, 256, 216))
target_indices = torch.argmax(torch.abs(target), dim=1)
target_one_hot = one_hot(target_indices.unsqueeze(1), num_classes=target.shape[1]).float()
print(target_one_hot.shape)

# cross_entropy already averages by default (reduction="mean"), so no extra .mean() is needed.
loss = torch.nn.functional.cross_entropy(logits, target_one_hot)
loss_indices = torch.nn.functional.cross_entropy(logits, target_indices)
# Both formulations should agree up to float32 accumulation-order differences.
torch.testing.assert_close(loss, loss_indices, rtol=1e-4, atol=1e-4)
print(loss, loss_indices)

torch.Size([5, 5, 256, 216])
tensor(3.8275) tensor(3.8275)
