In [1]:
"""
This implementation of the TLoss class is based on the original work provided by:
University of Basel and Lucerne University of Applied Sciences and Arts Authors
Link: https://github.com/Digital-Dermatology/t-loss
Original Author: Alvaro Gonzalez-Jimenez

Modifications:
- Adjusted to handle 3D volumetric data (B x C x D x H x W).
- Initialized the lambdas tensor from the configured image size; it is resampled with trilinear interpolation when the input's spatial dimensions differ.
- Ensured compatibility with nnUNet framework by adding configuration dictionary handling.

Original TLoss Description:
This is an implementation of the T-Loss function, which is formulated to handle arbitrary diagonal matrices.
For detailed explanation, please refer to the original implementation.

Args:
    config: Configuration dictionary for the loss; must provide config['data']['image_size'] and config['device'].
    nu (float): Value of nu.
    epsilon (float): Value of epsilon.
    reduction (str): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
                     'none': no reduction will be applied,
                     'mean': the sum of the output will be divided by the number of elements in the output,
                     'sum': the output will be summed.
"""
import numpy as np
import torch
import torch.nn as nn

class TLoss(nn.Module):
    """Student-t negative log-likelihood loss with a learnable diagonal scale.

    Both learnable parameters are stored in log-space and exponentiated in
    ``forward``: ``lambdas`` (one value per element of the configured image
    size) and ``nu`` (a scalar).

    Args:
        config: Dictionary providing ``config['data']['image_size']`` (shape of
            the ``lambdas`` tensor) and ``config['device']`` (device on which
            the parameters and constants are created).
        nu: Initial raw value of ``nu``; ``exp(nu)`` is used in the loss.
        epsilon: Small constant added to the log-space parameters.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``, applied to the
            per-sample losses.
    """

    def __init__(
        self,
        config: dict,
        nu: float = 1.0,
        epsilon: float = 1e-8,
        reduction: str = "mean",
    ):
        super().__init__()
        self.config = config
        image_size = self.config['data']['image_size']
        device = config['device']

        # Constants are registered as non-persistent buffers so that they
        # follow the module through .to()/.cuda() while leaving the
        # state_dict layout (lambdas, nu) untouched.
        self.register_buffer(
            "D",
            torch.tensor(np.prod(image_size), dtype=torch.float, device=device),
            persistent=False,
        )
        self.register_buffer(
            "epsilon",
            torch.tensor(epsilon, dtype=torch.float, device=device),
            persistent=False,
        )

        self.lambdas = nn.Parameter(
            # Initialize to a small value to prevent large initial loss
            torch.ones(image_size, dtype=torch.float, device=device) * 0.01
        )
        self.nu = nn.Parameter(
            torch.tensor(nu, dtype=torch.float, device=device)
        )
        self.reduction = reduction

    def forward(
        self, input_tensor: torch.Tensor, target_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss between a prediction and its target.

        Args:
            input_tensor: Prediction; the residual is summed over dims 1-4, so
                a 5D tensor is required (B x C x D x H x W per the module
                header).
            target_tensor: Target, subtracted element-wise from the input.

        Returns:
            A scalar for ``'mean'`` / ``'sum'``, or one value per batch entry
            for ``'none'``.

        Raises:
            ValueError: If ``self.reduction`` is not a supported mode.
        """
        # Resample lambdas when the input's spatial shape differs from the
        # configured image size.
        # NOTE(review): self.D still reflects the configured size in that
        # case — confirm this is intended.
        spatial_shape = input_tensor.shape[2:]
        if self.lambdas.shape != spatial_shape:
            lambdas_resized = torch.nn.functional.interpolate(
                self.lambdas.unsqueeze(0).unsqueeze(0),
                size=spatial_shape,
                mode='trilinear',
                align_corners=False,
            ).squeeze(0).squeeze(0)
        else:
            lambdas_resized = self.lambdas

        delta_i = input_tensor - target_tensor
        sum_nu_epsilon = torch.exp(self.nu) + self.epsilon

        # First term: -log Gamma((nu + D) / 2)
        first_term = -torch.lgamma((sum_nu_epsilon + self.D) / 2)
        # Second term: log Gamma(nu / 2)
        second_term = torch.lgamma(sum_nu_epsilon / 2)
        # Third term: -1/2 * sum_d log(lambda_d). `lambdas` already holds
        # log(lambda_d) (it is exponentiated for the quadratic term below), so
        # it is summed directly; taking log() of it again would be
        # inconsistent and yields NaN once a learned value becomes negative.
        third_term = -0.5 * torch.sum(lambdas_resized + self.epsilon)
        # Fourth term: (D / 2) * log(pi); a plain float avoids building a
        # fresh CPU tensor on every call.
        fourth_term = (self.D / 2) * float(np.log(np.pi))
        # Fifth term: (D / 2) * log(nu), with log(nu) stored as self.nu
        fifth_term = (self.D / 2) * (self.nu + self.epsilon)

        # Sixth term: ((nu + D) / 2) * log(1 + sum_d lambda_d * delta_d^2 / nu)
        delta_squared = torch.pow(delta_i, 2)
        lambdas_exp = torch.exp(lambdas_resized + self.epsilon)
        numerator = torch.sum(delta_squared * lambdas_exp, dim=(1, 2, 3, 4))
        fraction = numerator / sum_nu_epsilon
        sixth_term = ((sum_nu_epsilon + self.D) / 2) * torch.log(1 + fraction)

        total_losses = (
            first_term
            + second_term
            + third_term
            + fourth_term
            + fifth_term
            + sixth_term
        )

        if self.reduction == "mean":
            return total_losses.mean()
        elif self.reduction == "sum":
            return total_losses.sum()
        elif self.reduction == "none":
            return total_losses
        else:
            raise ValueError(
                f"The reduction method '{self.reduction}' is not implemented."
            )



In [2]:
import monai
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from torchinfo import summary

# Build a 3D residual U-Net: one input channel, two output channels.
model = UNet(
    norm=Norm.BATCH,                                        # batch normalisation in every block
    num_res_units=2,                                        # residual units per level
    strides=(2,) * 4,                                       # stride 2 at each of the four down-sampling steps
    channels=tuple(16 * 2 ** level for level in range(5)),  # (16, 32, 64, 128, 256) filters per level
    out_channels=2,                                         # e.g. number of segmentation classes
    in_channels=1,                                          # e.g. 1 for grayscale volumes
    spatial_dims=3,
)

# Dump the module tree for inspection
print(model)


UNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
          (adn): ADN(
            (N): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (adn): ADN(
            (N): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Sequential(


In [3]:
# Tabulate the per-layer parameter counts of `model` with torchinfo.
summary(model)

Layer (type:depth-idx)                                                                Param #
UNet                                                                                  --
├─Sequential: 1-1                                                                     --
│    └─ResidualUnit: 2-1                                                              --
│    │    └─Sequential: 3-1                                                           7,442
│    │    └─Conv3d: 3-2                                                               448
│    └─SkipConnection: 2-2                                                            --
│    │    └─Sequential: 3-3                                                           4,799,182
│    └─Sequential: 2-3                                                                --
│    │    └─Convolution: 3-4                                                          1,735
│    │    └─ResidualUnit: 3-5                                                         110
T

In [9]:
# Scratch check of how small exp() gets for a strongly negative argument: exp(-10) ~ 4.54e-05.
# NOTE(review): presumably probing the scale produced by the exp(lambdas) / exp(nu)
# parameterisation in TLoss — confirm, or drop this cell from the final notebook.
torch.exp(torch.tensor(-10))

tensor(4.5400e-05)