In [8]:



# Custom weighted signal/noise MSE loss function
def ada_weighted_mse_loss(reconstructed_image, target_image, zero_weighting=1, nonzero_weighting=1):
    """
    Calculates the weighted mean squared error (MSE) loss between target_image and reconstructed_image.
    The loss for zero pixels in the target_image is weighted by zero_weighting, and the loss for non-zero
    pixels is weighted by nonzero_weighting.

    Args:
    - target_image: a tensor of shape (B, C, H, W) containing the target image
    - reconstructed_image: a tensor of shape (B, C, H, W) containing the reconstructed image
    - zero_weighting: a scalar weighting coefficient for the MSE loss of zero pixels
    - nonzero_weighting: a scalar weighting coefficient for the MSE loss of non-zero pixels

    Returns:
    - weighted_mse_loss: a scalar tensor containing the weighted MSE loss
    """
    
    # Get the indices of 0 and non 0 values in target_image as a mask for speed
    zero_mask = (target_image == 0)
    nonzero_mask = ~zero_mask         # Invert mask
    
    # Get the values in target_image
    values_zero = target_image[zero_mask]
    values_nonzero = target_image[nonzero_mask]
    
    # Get the corresponding values in reconstructed_image
    corresponding_values_zero = reconstructed_image[zero_mask]
    corresponding_values_nonzero = reconstructed_image[nonzero_mask]
    
    # Create an instance of MSELoss class
    mse_loss = torch.nn.MSELoss(reduction='mean')
    
    # Compute the MSE losses
    zero_loss = mse_loss(corresponding_values_zero, values_zero)
    nonzero_loss = mse_loss(corresponding_values_nonzero, values_nonzero)

    # Protection from there being no 0 vals or no non zero vals, which then retunrs nan for MSE and creates a nan overall MSE return (which is error)
    if torch.isnan(zero_loss):
        zero_loss = 0
    if torch.isnan(nonzero_loss):
        nonzero_loss = 0
    
    # Sum losses with weighting coefficiants 
    weighted_mse_loss = (zero_weighting * zero_loss) + (nonzero_weighting * nonzero_loss) 
    
    return weighted_mse_loss



def weighted_perfect_recovery_loss(reconstructed_image, target_image, zero_weighting=1, nonzero_weighting=1):

    # Get the indices of 0 and non 0 values in target_image as a mask for speed
    zero_mask = (target_image == 0)
    nonzero_mask = ~zero_mask         # Invert mask
    
    # Get the values in target_image
    values_zero = target_image[zero_mask]
    values_nonzero = target_image[nonzero_mask]

    #Calualte the number of value sin each of values_zero and values_nonzero for use in the class balancing
    zero_n = len(values_zero)
    nonzero_n = len(values_nonzero)
    
    # Get the corresponding values in reconstructed_image
    corresponding_values_zero = reconstructed_image[zero_mask]
    corresponding_values_nonzero = reconstructed_image[nonzero_mask]

    # Calculate the loss for zero values
    loss_value_zero = (values_zero != corresponding_values_zero).float().sum() 
    print(loss_value_zero)
    # Calculate the loss for non-zero values
    loss_value_nonzero = (values_nonzero != corresponding_values_nonzero).float().sum() 
    print(loss_value_nonzero)


    if zero_n == 0:
        zero_loss = 0
    else:
        zero_loss = zero_weighting*( (1/zero_n) * loss_value_zero)

    if nonzero_n == 0:
        nonzero_loss = 0
    else:
        nonzero_loss = nonzero_weighting*( (1/nonzero_n) * loss_value_nonzero) 

    # Calculate the total loss with automatic class balancing and user class weighting
    loss_value = zero_loss + nonzero_loss

    return loss_value



# Mean Squared Error (MSE):
def MSE(clean_input, noised_target):
    """
    Mean Squared Error (MSE)

    Args:
    clean_input (torch.Tensor): The original image.
    noised_target (torch.Tensor): The recovered image.

    Returns:
    The calculated Mean Squared Error value.
    """

    mse = torch.mean((torch.pow(clean_input - noised_target, 2)))
    return (float(mse.numpy()))




# create a torch tensor of shape 5 x 5 and fill it with 0's

import torch
import numpy as np

x = torch.zeros(5,5)
y = torch.zeros(5,5)

x[0,0] = 1
y[0,0] = 0.2
x[1,1] = 0
y[1,1] = 0
print(x)
print(y)
print(MSE(x, y))
print(ada_weighted_mse_loss(x,y))
print(torch.nn.functional.binary_cross_entropy(x,y))
print(weighted_perfect_recovery_loss(x,y))


tensor([[1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
tensor([[0.2000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
0.025600001215934753
tensor(0.6400)
tensor(3.2000)
tensor(0.)
tensor(1.)
tensor(1.)
