In [1]:
import os
import sys

# Locate the project root by walking upward from the current working directory
# until a directory that contains an entry named "src" is found, then make the
# root importable so that `src.*` modules can be imported in later cells.
project_root = os.getcwd()
while "src" not in os.listdir(project_root):
    parent_dir = os.path.dirname(project_root)
    # os.path.dirname() of the filesystem root returns the root itself; without
    # this guard the loop would spin forever when no ancestor contains "src".
    if parent_dir == project_root:
        raise FileNotFoundError(
            f"Could not find a directory containing 'src' at or above {os.getcwd()}"
        )
    project_root = parent_dir

# Re-running this cell must not keep appending the same entry to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch
from src.utils.training import (
    MultiSourceLoss,
    MultiScaleLoss,
    VGGFeatureLoss,
)

In [3]:
# One weight per channel (the tensors below have 4 channels on their C axis).
channel_weights = [3.0, 1.0, 1.0, 1.0]

# Draw the random test tensors from a dedicated, seeded generator so that
# Restart & Run All reproduces the same loss values, without touching the
# global torch RNG state used by other code.
SEED = 0
rng = torch.Generator().manual_seed(SEED)

spec_shape = (4, 1025, 173)  # [C, H, W]
batch_size = 4

# Unbatched spectrograms (shape: [C, H, W])
output_spec = torch.randn(*spec_shape, generator=rng, requires_grad=True)
target_spec = torch.randn(*spec_shape, generator=rng)

# Batched spectrograms (shape: [B, C, H, W])
batched_output_spec = torch.randn(
    batch_size, *spec_shape, generator=rng, requires_grad=True
)
batched_target_spec = torch.randn(batch_size, *spec_shape, generator=rng)

In [4]:
# Smoke-test every loss on the unbatched [C, H, W] tensors: build the loss,
# run a forward pass, backpropagate, and print the resulting scalar.
# NOTE: output_spec is a leaf tensor, so its .grad accumulates across the four
# backward() calls below; the gradients are never read here, so they are not
# reset between losses.

# MultiSourceLoss with the L1 distance.
l1_loss = MultiSourceLoss(
    weights=channel_weights,
    distance="l1",
)
loss = l1_loss(output_spec, target_spec)
loss.backward()
print("L1 loss:", loss.item())

# MultiSourceLoss with the L2 distance.
l2_loss = MultiSourceLoss(
    weights=channel_weights,
    distance="l2",
)
loss = l2_loss(output_spec, target_spec)
loss.backward()
# Label fixed: this value comes from the L2 loss, not the L1 loss.
print("L2 loss:", loss.item())

# MultiScaleLoss with its remaining constructor arguments left at defaults.
multi_scale_loss = MultiScaleLoss(
    weights=channel_weights,
)
loss = multi_scale_loss(output_spec, target_spec)
loss.backward()
print("Multi-scale loss:", loss.item())

# VGGFeatureLoss, reported under the "Composite loss" label.
composite_loss = VGGFeatureLoss(
    weights=channel_weights,
)
loss = composite_loss(output_spec, target_spec)
loss.backward()
print("Composite loss:", loss.item())

L1 loss: 1.1282662153244019
L1 loss: 1.9990620613098145
Multi-scale loss: 0.6574965715408325
Composite loss: 35.571136474609375


In [5]:
# Smoke-test every loss on the batched [B, C, H, W] tensors: build the loss,
# run a forward pass, backpropagate, and print the resulting scalar.
# NOTE: batched_output_spec is a leaf tensor, so its .grad accumulates across
# the four backward() calls below; the gradients are never read here.

# MultiSourceLoss with its default distance.
# NOTE(review): presumably the default is "l1", matching the printed label --
# confirm against MultiSourceLoss.
l1_loss = MultiSourceLoss(weights=channel_weights)
loss = l1_loss(batched_output_spec, batched_target_spec)
loss.backward()
print("L1 loss:", loss.item())

# MultiSourceLoss with the L2 distance.
l2_loss = MultiSourceLoss(
    weights=channel_weights,
    distance="l2",
)
loss = l2_loss(batched_output_spec, batched_target_spec)
loss.backward()
# Label fixed: name the value after the distance actually used ("l2").
print("L2 loss:", loss.item())

# MultiScaleLoss with its remaining constructor arguments left at defaults.
multi_scale_loss = MultiScaleLoss(
    weights=channel_weights,
)
loss = multi_scale_loss(batched_output_spec, batched_target_spec)
loss.backward()
print("Multi-scale loss:", loss.item())

# VGGFeatureLoss, reported under the "Composite loss" label.
composite_loss = VGGFeatureLoss(
    weights=channel_weights,
)
loss = composite_loss(batched_output_spec, batched_target_spec)
loss.backward()
print("Composite loss:", loss.item())

L1 loss: 1.128129005432129
Spectral loss: 1.9971257448196411
Multi-scale loss: 0.6578590273857117
Composite loss: 35.648277282714844
