In [22]:
import torch
from torch import nn
from hcpdiff.loss import EDMLoss, SSIMLoss, GWLoss  # Replace 'your_module_name' with the actual module name where MinSNRLoss is defined

class CombinedLoss(nn.Module):
    def __init__(self, edm_weight=1.0, ssim_weight=0.1, gw_weight=0.1, gamma=1.0,**kwargs):
        super(CombinedLoss, self).__init__()

        self.edm_loss = EDMLoss(gamma=gamma, **kwargs)
        
        self.ssim_loss = SSIMLoss( **kwargs)
        self.gw_loss = GWLoss(**kwargs)

        self.edm_weight = edm_weight
        self.ssim_weight = ssim_weight
        self.gw_weight = gw_weight

    def forward(self, input: torch.Tensor, target: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: # input: latent target: latent sigma [b,64,64,64]
        edm_loss = self.edm_loss(input, target, sigma)
        ssim_loss = self.ssim_loss(input, target)
        gw_loss = self.gw_loss(input, target)

        # Combine losses with weights
        combined_loss = (
            self.edm_weight * edm_loss +  # [20, 64, 64, 64]
            self.ssim_weight * ssim_loss + 
            self.gw_weight * gw_loss
        )

        return combined_loss

import torch

In [23]:
edm_loss = EDMLoss(gamma=0.5)

ssim_loss = SSIMLoss(reduction=None)
gw_loss = GWLoss(reduction=None)
inputs, target, sigma = torch.Tensor(2,64,64,64), torch.Tensor(2,64,64,64), torch.Tensor(2,1,1,1)

edm_loss = edm_loss(inputs, target, sigma)
ssim_loss = ssim_loss(inputs, target)
gw_loss = gw_loss(inputs, target)


In [24]:
gw_loss.shape

torch.Size([2, 64, 64, 64])

In [28]:
from pytorch_msssim import SSIM, MS_SSIM
from torch.nn.modules.loss import _Loss
import torch
ssim = SSIM(size_average=False, channel=64) # channel -> 64
ssim(inputs, target)

tensor([nan, nan])