In [1]:
import torch
import torch.nn.functional as F
import torchmetrics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
y_pred = torch.randn(4, 3, 128, 128)
y_true = torch.randn(4, 3, 128, 128)

In [3]:
def KL_DivLoss(y_pred, y_true):
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
    log_input = F.log_softmax(y_pred, dim=1)
    log_target = F.log_softmax(y_true, dim=1)
    output = kl_loss(log_input, log_target)
    return output

def RMSELoss(y_pred, y_true):
    mse_loss = torch.nn.MSELoss(reduction="mean")
    output = torch.sqrt(mse_loss(y_true, y_pred))
    return output

def MAELoss(y_pred, y_true):
    mae_loss = torch.nn.L1Loss(reduction="mean")
    output = torch.sqrt(mae_loss(y_true, y_pred))
    return output  

def PSNR(y_pred, y_true):
    psnr = torchmetrics.PeakSignalNoiseRatio()
    output = psnr(y_pred, y_true)
    return output   

def SSIM(y_pred, y_true):
    ssim = torchmetrics.StructuralSimilarityIndexMeasure()
    output = ssim(y_pred, y_true)
    return output

def FID(y_pred, y_true):
    from torchmetrics.image.fid import FrechetInceptionDistance
    
    fid = FrechetInceptionDistance(feature=64, normalize=True)
    fid.update(y_true, real=True)
    fid.update(y_pred, real=False)
    output = fid.compute()
    return output

In [4]:
fn_list = [
    ("kl_div", KL_DivLoss), 
    ("rmse", RMSELoss), 
    ("mae", MAELoss),
    ("psnr", PSNR),
    ("ssim", SSIM),
    ("fid", FID)
]

loss_dict = {}
for fn_name, fn in fn_list:
    loss_dict[fn_name] = fn(y_pred, y_true)

Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /homes/zr523/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████████████████████████████████| 91.2M/91.2M [00:01<00:00, 54.6MB/s]


In [5]:
loss_dict

{'kl_div': tensor(8479.2217),
 'rmse': tensor(1.4155),
 'mae': tensor(1.0630),
 'psnr': tensor(16.1908),
 'ssim': tensor(0.0155),
 'fid': tensor(0.0801)}