In [1]:

import torch

class NEWESTACB3dloss2024(torch.nn.Module):
    def __init__(self, b, m, n):

        """
        # Number of dp is related to timestep size. 1,000 = 3dp, 10,000 = 4dp, 100,000 = 5dp, 1,000,000 = 6dp etc
        
        """
        super().__init__()   
        self.mse_loss = torch.nn.MSELoss(reduction='mean')
        self.zero_weighting = 1
        self.nonzero_weighting = 1


        # create a batches tensor
        batches = torch.arange(b).repeat_interleave(m*n).view(-1,1)

        # Create indices tensor
        indices = torch.stack( torch.meshgrid(torch.arange(m), torch.arange(n), indexing='ij'), dim=-1).view(-1, 2) 
        indices = indices.repeat(b, 1)

        self.batches = batches  
        self.indices = indices

    def transform_to_3d_coordinates(self, input_tensor):
        # Reshape the input tensor and concatenate with indices
        output_tensor = torch.cat((self.batches, self.indices, input_tensor.reshape(-1, 1)), dim=1)

        return output_tensor


    def forward(self, reconstructed_image, target_image):
        reconstructed_image = self.transform_to_3d_coordinates(reconstructed_image)
        print("Reconstructed image: ", reconstructed_image)
        target_image = self.transform_to_3d_coordinates(target_image)
        print("\nTarget image: ", target_image)

        # target_image is a tensor of shape (m, 4)
        # Example:
        # target_image = np.array([[1, 2, 3, 4],
        #                         [5, 6, 7, 0],
        #                         [8, 9, 10, 11]])

        # Identify 0 values in the 4th column of the second dimension

        zero_mask = (target_image[:, 3] == 0)
        nonzero_mask = ~zero_mask

        values_zero = target_image[zero_mask]
        values_nonzero = target_image[nonzero_mask]

        corresponding_values_zero = reconstructed_image[zero_mask]
        corresponding_values_nonzero = reconstructed_image[nonzero_mask]

        zero_loss = self.mse_loss(corresponding_values_zero, values_zero)
        nonzero_loss = self.mse_loss(corresponding_values_nonzero, values_nonzero)

        if torch.isnan(zero_loss):
            zero_loss = 0
        if torch.isnan(nonzero_loss):
            nonzero_loss = 0

        weighted_mse_loss = (self.zero_weighting * zero_loss) + (self.nonzero_weighting * nonzero_loss)

        return weighted_mse_loss


In [5]:
# genrate torch ensor of shap 1, 1, 2, 2 full of zeros
tensor = torch.zeros(1, 1, 2, 2)
print(tensor)

tensor[0, 0, 0, 0] = 0.1
tensor[0, 0, 0, 1] = 0.2
tensor[0, 0, 1, 0] = 0.3

print(tensor)




tensor([[[[0., 0.],
          [0., 0.]]]])
tensor([[[[0.1000, 0.2000],
          [0.3000, 0.0000]]]])


In [6]:
# runm twnsor through loss
loss_fn = NEWESTACB3dloss2024(1, 2, 2)

loss = loss_fn(tensor, tensor)


Reconstructed image:  tensor([[0.0000, 0.0000, 0.0000, 0.1000],
        [0.0000, 0.0000, 1.0000, 0.2000],
        [0.0000, 1.0000, 0.0000, 0.3000],
        [0.0000, 1.0000, 1.0000, 0.0000]])

Target image:  tensor([[0.0000, 0.0000, 0.0000, 0.1000],
        [0.0000, 0.0000, 1.0000, 0.2000],
        [0.0000, 1.0000, 0.0000, 0.3000],
        [0.0000, 1.0000, 1.0000, 0.0000]])
