In [82]:
import torch


def _stft_magnitude(audio, n_fft, hop_length):
    """Return the magnitude STFT of 2-D ``audio`` (each row is one signal)."""
    # An all-ones window is what torch.stft applies when window=None, so the
    # values are unchanged; passing it explicitly makes the choice visible and
    # avoids the "window was not provided" warning of recent torch versions.
    # NOTE(review): the DDSP loss this is adapted from may use a Hann window;
    # switching would change loss values — confirm before changing.
    window = torch.ones(n_fft, dtype=audio.dtype, device=audio.device)
    spectrum = torch.stft(
        audio, n_fft, hop_length=hop_length, window=window, return_complex=True
    )
    return torch.abs(spectrum)


def spectral_loss(
    target_audio,
    estimate_audio,
    fft_sizes=(2048, 1024, 512, 256, 128, 64),
    loss_type="L1",
    overlap=0.5,
    mag_weight=1.0,
    delta_time_weight=0.0,
    delta_freq_weight=0.0,
    cumsum_freq_weight=0.0,
    logmag_weight=0.0,
    loudness_weight=0.0,
    weights=None,
):
    """Multi-scale spectral loss adapted from https://github.com/magenta/ddsp/blob/main/ddsp/losses.py

    For every FFT size, magnitude spectrograms of both signals are compared with
    ``mean_difference``; each term whose weight is > 0 contributes
    ``weight * mean_difference(...)`` to the total.

    Args:
        target_audio (torch.tensor): audio target. Reshaped to the same
            ``(n_signals, n_samples)`` layout as ``estimate_audio``, so it must
            hold the same number of elements.
        estimate_audio (torch.tensor): audio estimate, shape ``(..., n_samples)``
            e.g. ``(batch, n_samples)`` or ``(batch, channels, n_samples)``; all
            leading dims are folded into independent signals.
        fft_sizes (tuple, optional): fft sizes for multi-scale spectrogram comparison. Defaults to (2048, 1024, 512, 256, 128, 64).
        loss_type (str, optional): Can be "L1", "L2" or "COSINE". Defaults to "L1".
        overlap (float, optional): Fractional frame overlap; hop length is
            ``max(1, int((1 - overlap) * fft_size))``. Defaults to 0.5.
        mag_weight (float, optional): Weight of the raw magnitude term. Defaults to 1.0.
        delta_time_weight (float, optional): Weight of the term comparing
            frame-to-frame differences (along time). Defaults to 0.0.
        delta_freq_weight (float, optional): Weight of the term comparing
            bin-to-bin differences (along frequency). Defaults to 0.0.
        cumsum_freq_weight (float, optional): Weight of the term comparing
            cumulative sums along frequency. Defaults to 0.0.
        logmag_weight (float, optional): Weight of the term comparing
            ``log(magnitude + 1e-6)``. Defaults to 0.0.
        loudness_weight (float, optional): Weight of the loudness term computed
            with ``compute_loudness(..., n_fft=2048)``. If that function raises
            ``NotImplementedError`` the term is skipped with a RuntimeWarning.
            Defaults to 0.0.
        weights (optional): Factor multiplied into every per-element loss
            before averaging (see ``mean_difference``). It must broadcast against
            the spectrograms of every FFT size, so a scalar is the safe choice.
            Defaults to None.

    Returns:
        torch.tensor: Loss tensor (the Python float ``0.0`` if no term is enabled).
    """
    loss = 0.0

    # Fold every leading dim into rows so torch.stft sees (n_signals, n_samples).
    # For (batch, 1, n_samples) input this is the same reshape as before.
    n_samples = estimate_audio.shape[-1]
    estimate_audio = estimate_audio.reshape(-1, n_samples)
    target_audio = target_audio.reshape(estimate_audio.shape)

    for fft_size in fft_sizes:
        # Clamp so overlap=1.0 cannot produce an invalid hop length of 0.
        hop_length = max(1, int((1 - overlap) * fft_size))
        # Spectrograms are (n_signals, n_freq_bins, n_frames):
        # dim 1 is frequency, dim 2 is time.
        target_mag = _stft_magnitude(target_audio, fft_size, hop_length)
        estimate_mag = _stft_magnitude(estimate_audio, fft_size, hop_length)

        if mag_weight > 0:
            loss += mag_weight * mean_difference(
                target_mag, estimate_mag, loss_type, weights
            )

        if delta_time_weight > 0:
            target = torch.diff(target_mag, dim=2)
            value = torch.diff(estimate_mag, dim=2)
            loss += delta_time_weight * mean_difference(
                target, value, loss_type, weights
            )

        if delta_freq_weight > 0:
            target = torch.diff(target_mag, dim=1)
            value = torch.diff(estimate_mag, dim=1)
            loss += delta_freq_weight * mean_difference(
                target, value, loss_type, weights
            )

        if cumsum_freq_weight > 0:
            target = torch.cumsum(target_mag, dim=1)
            value = torch.cumsum(estimate_mag, dim=1)
            loss += cumsum_freq_weight * mean_difference(
                target, value, loss_type, weights
            )

        if logmag_weight > 0:
            # Small offset keeps log() finite on silent bins.
            target = torch.log(target_mag + 1e-6)
            value = torch.log(estimate_mag + 1e-6)
            loss += logmag_weight * mean_difference(target, value, loss_type, weights)

    # Loudness is computed once on the time-domain signals, not per FFT size.
    if loudness_weight > 0:
        try:
            target_loudness = compute_loudness(target_audio, n_fft=2048)
            estimate_loudness = compute_loudness(estimate_audio, n_fft=2048)
        except NotImplementedError:
            # Best-effort: keep training running, but do not hide the skip.
            import warnings

            warnings.warn(
                "loudness_weight > 0 but compute_loudness is not implemented; "
                "the loudness term is skipped.",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            loss += loudness_weight * mean_difference(
                target_loudness, estimate_loudness, loss_type, weights
            )

    return loss


def mean_difference(target, value, loss_type, weights=None, eps=1e-8):
    """Return the mean discrepancy between ``target`` and ``value``.

    Args:
        target (torch.Tensor): reference tensor.
        value (torch.Tensor): tensor compared against ``target``; must
            broadcast with it.
        loss_type (str): "L1" (absolute error), "L2" (squared error) or
            "COSINE" (``1 - cosine similarity`` along the last dimension).
        weights (optional): factor multiplied into the per-element loss before
            averaging. For "COSINE" the loss is already reduced over the last
            dimension, so ``weights`` must broadcast to that reduced shape.
        eps (float): lower bound on the product of norms for "COSINE", so
            all-zero vectors give a finite loss instead of NaN.

    Returns:
        torch.Tensor: scalar mean of the (optionally weighted) loss.

    Raises:
        ValueError: if ``loss_type`` is not one of the supported values.
    """
    if loss_type == "L1":
        loss = torch.abs(target - value)
    elif loss_type == "L2":
        loss = (target - value).pow(2)
    elif loss_type == "COSINE":
        # Reduce the dot product and both norms over the same axis WITHOUT
        # keepdim so shapes match element-for-element; keeping the dim on only
        # the norms made the division broadcast into an outer product.
        dot = torch.sum(target * value, dim=-1)
        norm_product = target.norm(dim=-1) * value.norm(dim=-1)
        similarity = dot / norm_product.clamp_min(eps)
        loss = 1 - similarity
    else:
        raise ValueError("Invalid loss_type. Use 'L1', 'L2', or 'COSINE'.")

    if weights is not None:
        loss = loss * weights

    return loss.mean()


def compute_loudness(audio, n_fft):
    """Placeholder for a loudness feature extractor.

    Not implemented: every call raises ``NotImplementedError`` regardless of
    the ``audio`` and ``n_fft`` arguments.
    """
    raise NotImplementedError()


In [86]:
# Demo: score two independent Gaussian-noise clips of shape (1, 1, 160000)
# against each other with every spectral term enabled and an L2 distance.
target = torch.randn(1, 1, 160000)  # Example target audio
estimate_audio = torch.randn(1, 1, 160000)  # Example audio
loss = spectral_loss(
    target,
    estimate_audio,
    loss_type="L2",
    mag_weight=1.0,
    delta_time_weight=1.0,
    delta_freq_weight=1.0,
    cumsum_freq_weight=1.0,
    logmag_weight=1.0,
)
print(loss)

tensor(590028.6250)
