In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import matplotlib.pyplot as plt
import numpy as np


In [None]:
def generate_synthetic_wave(batch_size=8, length=16000, sample_rate=16000):
    """Generate a batch of noisy sine waves with random integer frequencies.

    Each waveform is ``sin(2 * pi * f * t)`` where ``f`` is drawn uniformly
    from the integers in ``[100, 1000)`` and ``t`` is the sample time in
    seconds, plus zero-mean Gaussian noise scaled by 0.05.

    Args:
        batch_size: Number of waveforms to generate. ``0`` yields an empty batch.
        length: Number of samples per waveform.
        sample_rate: Sampling rate in Hz. Sample ``k`` is taken at
            ``k / sample_rate`` seconds, so ``f`` is a true frequency in Hz
            regardless of ``length``.

    Returns:
        Floating-point tensor of shape ``[batch_size, 1, length]``.

    Raises:
        ValueError: If ``sample_rate`` is not positive, or if ``batch_size``
            or ``length`` is negative.
    """
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")
    if batch_size < 0 or length < 0:
        raise ValueError(
            f"batch_size and length must be non-negative, got {batch_size} and {length}"
        )

    # Build the time axis from the sample rate. The previous
    # linspace(0, 1, length) ignored `sample_rate` and always squeezed exactly
    # one second (endpoint included) into `length` samples.
    t = torch.arange(length, dtype=torch.get_default_dtype()) / sample_rate

    # One frequency per batch item, shaped [B, 1, 1] so that broadcasting
    # against t ([length]) directly produces the [B, 1, length] layout.
    freqs = torch.randint(100, 1000, (batch_size, 1, 1)).to(t.dtype)
    clean = torch.sin(2 * np.pi * freqs * t)
    noise = 0.05 * torch.randn_like(clean)
    return clean + noise


In [None]:
class TinyAudioAutoencoder(nn.Module):
    """Small 1-D convolutional autoencoder for single-channel waveforms.

    The encoder is two stride-2 convolutions (1 -> 16 -> 32 channels), so it
    downsamples time by 4. The decoder mirrors it with two stride-2 transposed
    convolutions (32 -> 16 -> 1) and ends in ``tanh``, so reconstructions lie
    in ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, 9, stride=2, padding=4),
            nn.ReLU(),
            nn.Conv1d(16, 32, 9, stride=2, padding=4),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(32, 16, 9, stride=2, padding=4, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(16, 1, 9, stride=2, padding=4, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: Waveform tensor whose last dimension is time and whose channel
                dimension has size 1 (e.g. ``[batch, 1, length]``).

        Returns:
            Reconstruction with the same shape as ``x``.
        """
        z = self.encoder(x)
        recon = self.decoder(z)
        # Each encoder layer maps length L to ceil(L / 2) and each decoder
        # layer doubles it, so the raw output has length 4*ceil(ceil(L/2)/2),
        # which exceeds L whenever L is not a multiple of 4. Crop the tail so
        # the reconstruction always lines up with the input for the losses.
        return recon[..., : x.shape[-1]]


In [None]:
class WaveformL2Loss(nn.Module):
    """Mean-squared error measured directly on waveform samples."""

    def forward(self, x, y):
        """Return the mean squared element-wise difference between ``x`` and ``y``."""
        return F.mse_loss(input=x, target=y, reduction="mean")
class FFTLoss(nn.Module):
    """Mean magnitude of the difference between two signals' complex spectra."""

    def forward(self, x, y):
        """Compare ``x`` and ``y`` in the frequency domain.

        Both inputs are transformed with a full (two-sided) FFT along their
        last dimension; the loss is the mean absolute value of the complex
        difference, so it is sensitive to phase as well as magnitude.
        """
        spectral_gap = torch.fft.fft(x) - torch.fft.fft(y)
        return spectral_gap.abs().mean()
class MelSpecL2Loss(nn.Module):
    """Mean-squared error between the mel spectrograms of two waveforms."""

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=64):
        """Build the mel-spectrogram transform shared by both inputs.

        Args:
            sample_rate: Sampling rate passed to the mel transform.
            n_fft: FFT size passed to the mel transform.
            hop_length: Hop between frames passed to the mel transform.
            n_mels: Number of mel bands passed to the mel transform.
        """
        super().__init__()
        mel_config = dict(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )
        self.mel = torchaudio.transforms.MelSpectrogram(**mel_config)

    def forward(self, x, y):
        """Return the MSE between the mel spectrograms of ``x`` and ``y``.

        Dimension 1 of each input is squeezed before the transform, so
        inputs are presumably shaped ``[batch, 1, time]``.
        """
        mel_x, mel_y = (self.mel(signal.squeeze(1)) for signal in (x, y))
        return F.mse_loss(mel_x, mel_y)


In [None]:
def sqrtm_newton_schulz(A, num_iter=10):
    """Approximate a batched matrix square root with Newton-Schulz iterations.

    Each matrix is first divided by its Frobenius norm, the coupled
    iteration ``T = (3I - ZY) / 2; Y <- Y T; Z <- T Z`` is run ``num_iter``
    times starting from ``Y = A / ||A||``, ``Z = I``, and the result is
    rescaled by ``sqrt(||A||)``. Only matrix products are used, so the
    routine is differentiable.

    Args:
        A: Floating-point tensor of shape ``[B, N, N]``.
        num_iter: Number of Newton-Schulz iterations.

    Returns:
        Tensor of shape ``[B, N, N]`` with the same dtype and device as ``A``.
        An all-zero matrix yields an all-zero result.
    """
    _, N, _ = A.shape
    # Clamp the norm so an all-zero matrix does not divide by zero and fill
    # the output with NaN.
    normA = A.norm(dim=(1, 2), keepdim=True).clamp_min(torch.finfo(A.dtype).tiny)
    Y = A / normA
    # The identity must match A's dtype as well as its device; otherwise
    # the matmuls below fail for non-default dtypes such as float64.
    I = torch.eye(N, dtype=A.dtype, device=A.device).expand_as(A)
    Z = I.clone()

    for _ in range(num_iter):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    return Y * torch.sqrt(normA)

class MelFIDLoss(nn.Module):
    """Frechet-style distance between per-utterance mel-frame statistics.

    For each batch item, the mel-spectrogram frames of each signal are
    summarised by their mean vector and covariance matrix over time. The
    loss is ``||mu1 - mu2||^2 + tr(C1 + C2 - 2 * sqrt(C1 C2))`` averaged
    over the batch, with the matrix square root approximated by
    ``sqrtm_newton_schulz``.
    """

    def __init__(self, sample_rate=16000, n_mels=64):
        """Build the mel transform.

        Args:
            sample_rate: Sampling rate passed to the mel transform.
            n_mels: Number of mel bands, which is also the size of the
                mean vectors and covariance matrices. Defaults to 64.
        """
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)

    def _frame_stats(self, wave):
        """Return the per-item mean and unbiased covariance of ``wave``'s mel frames."""
        frames = self.mel(wave.squeeze(1)).transpose(1, 2)  # [B, mel, T] -> [B, T, mel]
        mu = frames.mean(dim=1)
        # Centre once and reuse for both factors of the covariance product.
        centered = frames - mu.unsqueeze(1)
        # Unbiased estimate: divides by T - 1, so at least two frames are needed.
        cov = centered.transpose(1, 2) @ centered / (frames.shape[1] - 1)
        return mu, cov

    def forward(self, x, y):
        """Return the batch-mean Frechet-style distance between ``x`` and ``y``.

        Dimension 1 of each input is squeezed before the mel transform, so
        inputs are presumably shaped ``[batch, 1, time]``.
        """
        mu1, c1 = self._frame_stats(x)
        mu2, c2 = self._frame_stats(y)
        sqrt_cov = sqrtm_newton_schulz(c1 @ c2)
        trace_term = torch.diagonal(c1 + c2 - 2 * sqrt_cov, dim1=1, dim2=2).sum(dim=1)
        mean_term = (mu1 - mu2).pow(2).sum(dim=1)
        return (mean_term + trace_term).mean()


In [None]:
def train(model, loss_fn, epochs=10, device='cuda'):
    """Train ``model`` to reconstruct synthetic waveforms under ``loss_fn``.

    Every "epoch" is a single Adam step (learning rate 1e-3) on one freshly
    generated batch of 16 synthetic waveforms; the loss is printed after
    each step.

    Args:
        model: Module mapping a waveform batch to its reconstruction.
        loss_fn: Module called as ``loss_fn(reconstruction, target)`` and
            returning a scalar loss.
        epochs: Number of optimisation steps to run.
        device: Device to train on. If a CUDA device is requested but CUDA
            is not available, training falls back to the CPU with a warning
            instead of raising.

    Returns:
        List of the per-step loss values as Python floats, in order.
    """
    import warnings

    target = torch.device(device)
    if target.type == "cuda" and not torch.cuda.is_available():
        warnings.warn(
            f"CUDA is not available; training on CPU instead of {device!r}.",
            stacklevel=2,
        )
        target = torch.device("cpu")

    model.to(target)
    loss_fn.to(target)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Nothing in the loop leaves training mode, so set it once up front.
    model.train()
    history = []
    for epoch in range(epochs):
        x = generate_synthetic_wave(batch_size=16).to(target)
        y = model(x)
        loss = loss_fn(y, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        history.append(loss_value)
        print(f"Epoch {epoch+1}: Loss = {loss_value:.4f}")
    return history


In [None]:
# Loss functions to compare, keyed by the label printed for each run.
losses = {
    "Waveform L2": WaveformL2Loss(),
    "FFT Loss": FFTLoss(),
    "Mel L2": MelSpecL2Loss(),
    "Mel FID": MelFIDLoss(),
}

# Seed applied before every run so each loss starts from the same initial
# weights and sees the same stream of synthetic batches.
SEED = 0

# Trained model for each loss, keyed like `losses`.
trained_models = {}

for name, loss_fn in losses.items():
    print(f"\n🧪 Training with {name}")
    # Build a fresh model per loss. Reusing one instance would start each
    # run from the weights left by the previous loss and confound the
    # comparison.
    torch.manual_seed(SEED)
    model = TinyAudioAutoencoder()
    train(model, loss_fn, epochs=5)
    trained_models[name] = model
