In [33]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
import torch
import torch.nn as nn
from ml.models.building_blocks import VG_discriminator
from ml.extra_libraries.pytorch_ssim import SSIM3D
from ml.extra_libraries.cldice import soft_cldice

In [35]:
class CycleLoss:
    """Cycle-consistency loss between a real tensor and its cycled prediction.

    Args:
        type: Name of the element-wise criterion (case-insensitive).
            ``'mse'`` (the default) selects ``nn.MSELoss``; ``'l1'`` or
            ``'mae'`` selects ``nn.L1Loss``. ``None`` is treated as the
            default. Previously this argument was accepted but ignored.

    Raises:
        ValueError: If ``type`` does not name a supported criterion.
    """

    # Criterion classes keyed by the accepted (lower-cased) names.
    _CRITERIA = {'mse': nn.MSELoss, 'l1': nn.L1Loss, 'mae': nn.L1Loss}

    # `type` shadows the builtin; the name is kept so keyword callers still work.
    def __init__(self, type='mse'):
        key = 'mse' if type is None else str(type).lower()
        if key not in self._CRITERIA:
            raise ValueError(
                f"Unsupported cycle loss type {type!r}; "
                f"expected one of {sorted(self._CRITERIA)}"
            )
        self.loss_fn = self._CRITERIA[key]()

    def __call__(self, real, pred):
        """Return the configured criterion evaluated on ``(real, pred)``."""
        return self.loss_fn(real, pred)
    
class ReconstructionLoss:
    """Compare a real tensor with a prediction using ``SSIM3D``.

    Each input is independently min-max rescaled over the whole tensor
    before being handed to the metric.

    NOTE(review): the value returned is whatever ``SSIM3D`` yields. If that
    is a similarity score (higher = more alike), callers must invert it
    before minimising - confirm against the SSIM3D implementation.
    """

    def __init__(self):
        self.loss_fn = SSIM3D()

    @staticmethod
    def _min_max(x):
        """Rescale ``x`` to [0, 1] using its global minimum and maximum.

        A constant tensor has zero range; dividing by it would yield NaNs.
        In that case the divisor falls back to 1, so the result is all
        zeros. Non-constant inputs are computed exactly as before.
        """
        lo = x.min()
        span = x.max() - lo
        safe_span = torch.where(span > 0, span, torch.ones_like(span))
        return (x - lo) / safe_span

    def __call__(self, real, pred):
        """Return ``SSIM3D`` evaluated on the rescaled ``real`` and ``pred``."""
        return self.loss_fn(self._min_max(real), self._min_max(pred))
    
class SegmentationLoss:
    """Score a prediction against a reference tensor via ``SSIM3D``.

    Inputs are forwarded to the metric untouched; no min-max rescaling is
    applied in this class.
    """

    def __init__(self):
        self.loss_fn = SSIM3D()

    def __call__(self, real, pred):
        """Return the metric evaluated on ``(real, pred)`` as given."""
        metric = self.loss_fn
        return metric(real, pred)
    
class DiscriminatorLoss:
    """Compare two tensors with ``SSIM3D`` after min-max rescaling each one.

    Args:
        type: Accepted for interface compatibility; this class does not
            read it and always uses ``SSIM3D``.

    NOTE(review): the value returned is whatever ``SSIM3D`` yields. If that
    is a similarity score (higher = more alike), callers must invert it
    before minimising - confirm against the SSIM3D implementation.
    """

    # `type` shadows the builtin; the name is kept so keyword callers still work.
    def __init__(self, type=None):
        self.loss_fn = SSIM3D()

    @staticmethod
    def _min_max(x):
        """Rescale ``x`` to [0, 1] using its global minimum and maximum.

        A constant tensor has zero range; dividing by it would yield NaNs.
        In that case the divisor falls back to 1, so the result is all
        zeros. Non-constant inputs are computed exactly as before.
        """
        lo = x.min()
        span = x.max() - lo
        safe_span = torch.where(span > 0, span, torch.ones_like(span))
        return (x - lo) / safe_span

    def __call__(self, real, pred):
        """Return ``SSIM3D`` evaluated on the rescaled ``real`` and ``pred``."""
        return self.loss_fn(self._min_max(real), self._min_max(pred))

In [28]:
# Two random 5-D test tensors of shape (1, 1, 64, 64, 64), drawn from a
# standard normal and scaled by 2 and 3 respectively.
# No RNG seed is set in this cell, so the values differ on every run.
img1 = 2*torch.randn(1, 1, 64, 64, 64)
img2 = 3*torch.randn(1, 1, 64, 64, 64)

tensor(0.2324)
tensor(0.9960)


In [None]:
def discriminator_loss_fn(self, real_image, fake_image, typ=None, from_logits=True):
    """
    Calculates the loss for the discriminator.

    The discriminator output for the real image is compared against a
    target of ones and the output for the fake image against a target of
    zeros; the two terms are averaged with weight 0.5 each.

    Args:
        self (object): Owning model instance; forwarded to the ``MSE`` and
            ``reduce_mean`` helpers, which are defined outside this cell.
        real_image (tf.Tensor): Discriminator output for the real image.
        fake_image (tf.Tensor): Discriminator output for the fake image.
        typ (str): Type of loss. If None, MSE is used. Otherwise one of
            "bce" (binary cross-entropy) or "bfce" (binary focal
            cross-entropy). Default: None.
        from_logits (bool): Passed to the Keras loss. When falsy, both
            inputs are first passed through ``min_max_norm_tf``
            (presumably rescaling to [0, 1] - confirm against the helper).
            Default: True.

    Returns:
        tf.Tensor: The discriminator loss.

    Raises:
        ValueError: If ``typ`` is neither None, "bce" nor "bfce".
            Previously this case failed later with an UnboundLocalError.
    """
    if typ is None:
        real_term = MSE(self, tf.ones_like(real_image), real_image)
        fake_term = MSE(self, tf.zeros_like(fake_image), fake_image)
        return 0.5 * (real_term + fake_term)

    if typ == "bce":
        loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=from_logits, reduction=tf.keras.losses.Reduction.NONE)
    elif typ == "bfce":
        loss_obj = tf.keras.losses.BinaryFocalCrossentropy(from_logits=from_logits, reduction=tf.keras.losses.Reduction.NONE)
    else:
        raise ValueError(f"Unsupported discriminator loss type {typ!r}; expected None, 'bce' or 'bfce'")

    real = real_image
    fake = fake_image
    if not from_logits:
        real = min_max_norm_tf(real)
        fake = min_max_norm_tf(fake)
    loss = (loss_obj(tf.ones_like(real), real) + loss_obj(tf.zeros_like(fake), fake)) * 0.5
    return reduce_mean(self, loss)

In [None]:

@tf.function
def cycle_seg_loss(self, real_image, cycled_image):
    """
    Weighted soft-Dice/clDice loss between a real image and its cycled version.

    Both images go through ``min_max_norm_tf`` before the loss object from
    ``soft_dice_cldice_loss()`` is applied; the result is reduced with
    ``reduce_mean`` and scaled by ``self.lambda_identity``.

    Args:
    - real_image: tensor holding the real image
      (earlier notes give (batch_size, image_size, image_size, channels) - TODO confirm)
    - cycled_image: tensor holding the cycled image, same layout as real_image

    Returns:
    - the reduced loss multiplied by ``self.lambda_identity``

    NOTE(review): a cycle loss scaled by ``lambda_identity`` looks unusual;
    confirm this is the intended weight.
    """
    real_norm = min_max_norm_tf(real_image)
    cycled_norm = min_max_norm_tf(cycled_image)
    criterion = soft_dice_cldice_loss()
    unweighted = reduce_mean(self, criterion(real_norm, cycled_norm))
    return unweighted * self.lambda_identity

@tf.function
def identity_loss(self, real_image, same_image, typ=None):
    """
    Compute the identity loss between the real image and the same image.

    Args:
        real_image: the real image
        same_image: the generated same image
        typ: the type of loss to use. "cldice" selects the soft-Dice/clDice
            loss on inputs passed through ``min_max_norm_tf``; any other
            value (including None) uses MAE.

    Returns:
        The identity loss between the real image and the same image,
        scaled by ``self.lambda_identity``.
    """
    if typ == "cldice":
        real = min_max_norm_tf(real_image)
        same = min_max_norm_tf(same_image)
        loss_obj = soft_dice_cldice_loss()
        return reduce_mean(self, loss_obj(real, same)) * self.lambda_identity

    # Every other `typ` falls back to MAE, as the documented contract states.
    # Previously an unrecognised non-None `typ` fell through and returned None.
    return self.lambda_identity * MAE(self, real_image, same_image)

@tf.function
def generator_loss_fn(self, fake_image, typ=None, from_logits=True):
    """
    Calculates the loss for the generator.

    The fake image tensor is compared against a target of ones.

    Args:
        self (object): Instance of the VANGAN class.
        fake_image (tf.Tensor): Generated fake image tensor.
        typ (str): Type of loss. If None, default MSE is used.
                   Else, the valid types are: "bce" - Binary cross-entropy,
                   "bfce" - Binary focal cross-entropy.
                   Default: None.
        from_logits (bool): Whether to use logits or probabilities. When
                            falsy, the input is first passed through
                            ``min_max_norm_tf``.
                            Default: True.

    Returns:
        tf.Tensor: The generator loss.

    Raises:
        ValueError: If ``typ`` is neither None, "bce" nor "bfce".
            Previously this case failed later with an UnboundLocalError.
    """
    if typ is None:
        return MSE(self, tf.ones_like(fake_image), fake_image)

    if typ == "bce":
        loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=from_logits, reduction=tf.keras.losses.Reduction.NONE)
    elif typ == "bfce":
        loss_obj = tf.keras.losses.BinaryFocalCrossentropy(from_logits=from_logits, reduction=tf.keras.losses.Reduction.NONE)
    else:
        raise ValueError(f"Unsupported generator loss type {typ!r}; expected None, 'bce' or 'bfce'")

    fake = fake_image
    if not from_logits:
        fake = min_max_norm_tf(fake)
    loss = loss_obj(tf.ones_like(fake_image), fake)
    return reduce_mean(self, loss)

@tf.function
