In [1]:
import torch
import piq

from torch import nn

from src.metrics.basic import *
from src.metrics import ssim_loss
from src.metrics import fid

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
batch_size = 16
channels = 3
height = 128
width = 128

real_images = torch.rand(batch_size, channels, height, width).to(device)
generated_images = torch.rand(batch_size, channels, height, width).to(device)
predictions = torch.sigmoid(torch.rand(batch_size, 1)).to(device)
targets = torch.ones_like(predictions).to(device)

In [4]:
mse_custom = mse_loss(real_images, generated_images)
mse_torch = nn.MSELoss()(real_images, generated_images)
print(f"MSE Custom: {mse_custom}, Torch: {mse_torch}")


MSE Custom: 0.16663974523544312, Torch: 0.16663974523544312


In [5]:
bce_custom = bce_loss(predictions, targets)
bce_torch = nn.BCELoss(reduction="mean")(predictions, targets)
print(f"BCE Custom: {bce_custom}, Torch: {bce_torch}")

BCE Custom: 7.381808757781982, Torch: 0.4613630473613739


In [6]:
mu = torch.rand(batch_size, 128)
log_var = torch.rand(batch_size, 128)
kld_custom = kld_loss(mu, log_var)

print(f"KLD Custom: {kld_custom.item()}")

KLD Custom: 34.64309310913086


In [7]:
ssim_custom = ssim_loss.SSIMLoss()(real_images, generated_images)
ssim_piq = piq.SSIMLoss()(real_images, generated_images)

print(f"SSIM Custom: {ssim_custom.item():.6f}, PIQ: {ssim_piq.item():.6f}")

SSIM Custom: 0.964708, PIQ: 0.994827


In [8]:
real_images = torch.rand(batch_size, channels, height, width).to(device)
generated_images = torch.rand(batch_size, channels, height, width).to(device)

fid_score_custom = fid.calculate_fid(real_images, generated_images, device, impl='custom') # Takes 20 seconds
fid_score_torch = fid.calculate_fid(real_images, generated_images, device, impl='torchmetrics') # Takes 1 second
print(f"FID Custom: {fid_score_custom.item():.6f}, TorchMetrics: {fid_score_torch}")



FID Custom: 6.376643, TorchMetrics: 20.622800827026367
