In [42]:
import os
import glob
import torch
import torchaudio
import numpy as np
from torch.utils.data import Dataset

# Feature-extraction configuration shared by the dataset and model cells.
SR: int = 22050        # target sample rate (Hz); every clip is resampled to this
N_MELS: int = 128      # number of mel bands (height of the spectrogram)
N_FFT: int = 1024      # STFT window length, in samples
HOP: int = 512         # STFT hop length, in samples
SEG_DUR: float = 5.0   # length of each training crop, in seconds
SEG_SAMPLES: int = int(SEG_DUR * SR)  # crop length expressed in samples

class FMADataset(Dataset):
    """Fixed-length log-mel spectrogram crops from every ``.mp3`` under ``root``.

    Each item is ``(logmel, path)``:

    * ``logmel`` is ``log(mel + 1e-6)`` of a ``SEG_SAMPLES``-long window of the
      mono, ``SR``-resampled track. The window is randomly positioned, or the
      track is zero-padded at the end when it is shorter. Shape is
      ``[1, N_MELS, T]``.
    * ``path`` is the file that was actually decoded.

    Files that fail to decode are skipped: the item falls through to the next
    file in the (sorted) list, wrapping around at the end.
    """

    def __init__(self, root):
        # Recursively find all mp3s. Sorted so the index -> file mapping does not
        # depend on filesystem listing order (glob's order is unspecified).
        self.files = sorted(glob.glob(os.path.join(root, "**/*.mp3"), recursive=True))
        print("Found", len(self.files), "files")

        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=SR,
            n_fft=N_FFT,
            hop_length=HOP,
            n_mels=N_MELS
        )

    def __len__(self):
        return len(self.files)

    def safe_load(self, path):
        """Decode ``path`` to a mono 1-D waveform tensor at ``SR`` Hz.

        Returns:
            ``(wav, True)`` on success, where ``wav`` is a 1-D ``torch.Tensor``.
            ``(None, False)`` if loading or resampling raised; the error is printed.
        """
        try:
            wav, file_sr = torchaudio.load(path)
            # convert to mono by averaging channels
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            # resample if needed
            if file_sr != SR:
                wav = torchaudio.functional.resample(wav, file_sr, SR)
            return wav.squeeze(0), True
        except Exception as e:
            # Deliberately best-effort: an undecodable file is reported and
            # skipped; __getitem__ falls back to the next file.
            print("Skipping bad file:", path, e)
            return None, False

    def _load_first_decodable(self, idx):
        """Return ``(wav, path)`` for the first file at or after ``idx`` that decodes.

        Each file is tried at most once, wrapping around the end of the list.
        An out-of-range ``idx`` raises ``IndexError`` exactly like list indexing,
        which keeps the sequence-iteration protocol working.

        Raises:
            RuntimeError: if none of the files can be decoded.
        """
        n_files = len(self.files)
        path = self.files[idx]
        for _ in range(n_files):
            wav, ok = self.safe_load(path)
            if ok:
                return wav, path
            idx = (idx + 1) % n_files
            path = self.files[idx]
        raise RuntimeError(f"None of the {n_files} audio files could be decoded")

    def __getitem__(self, idx):
        # Iterative, bounded fallback:
        # - a long run of undecodable files cannot exceed the recursion limit;
        # - a directory where nothing decodes raises instead of recursing forever.
        wav, path = self._load_first_decodable(idx)

        # pad or crop to fixed length
        if len(wav) < SEG_SAMPLES:
            # Zero-pad short clips at the end.
            wav = torch.nn.functional.pad(wav, (0, SEG_SAMPLES - len(wav)))
        else:
            # Uniformly random window; randint's upper bound is exclusive, so the
            # last valid start (len - SEG_SAMPLES) is included.
            start = np.random.randint(0, len(wav) - SEG_SAMPLES + 1)
            wav = wav[start:start + SEG_SAMPLES]

        mel = self.mel(wav.unsqueeze(0))  # [1, n_mels, T]
        # Epsilon keeps the log finite on silent (e.g. zero-padded) frames.
        logmel = torch.log(mel + 1e-6)
        return logmel, path

In [43]:
from torch.utils.data import DataLoader

dataset = FMADataset("data")
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Smoke-test one batch: assert the shape rather than eyeballing a print.
batch, paths = next(iter(loader))
# MelSpectrogram pads the signal (center=True by default), so a clip of
# SEG_SAMPLES samples yields 1 + SEG_SAMPLES // HOP frames.
expected_shape = (len(paths), 1, N_MELS, 1 + SEG_SAMPLES // HOP)
assert batch.shape == expected_shape, f"unexpected batch shape {tuple(batch.shape)}"
assert torch.isfinite(batch).all(), "non-finite values in log-mel batch"
print(batch.shape)   # [B, 1, 128, T]
print(paths[0])

Found 7994 files
torch.Size([16, 1, 128, 216])
data/060/060753.mp3


In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel spectrogram "images".

    Architecture:

    * Encoder: four stride-2 convs, each mapping a spatial size n to ceil(n / 2).
    * Bottleneck: a linear layer maps the flattened feature map to
      ``latent_dim`` values, and a second linear layer maps it back.
    * Decoder: four stride-2 transposed convs, each doubling the spatial size.

    The decoder output is resized to the exact input size with bilinear
    interpolation.

    Args:
        latent_dim: size of the bottleneck vector returned as ``latent``.
        input_shape: ``(height, width)`` of the inputs the model will be fed.
            The bottleneck Linear layers are sized from it. The default
            ``(128, 216)`` gives the encoder grid 256x8x14, matching
            128 mel bins x 216 frames.

    ``forward(x)`` expects ``x`` of shape ``[B, 1, *input_shape]``. Another
    spatial size makes ``fc_enc`` fail with a shape mismatch. It returns
    ``(reconstruction, latent)`` with shapes ``[B, 1, H, W]`` and
    ``[B, latent_dim]``.
    """

    def __init__(self, latent_dim=128, input_shape=(128, 216)):
        super().__init__()
        # Encoder: each conv (k=3, s=2, p=1) maps a spatial size n -> ceil(n / 2).
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
        )
        # Derive the encoder's output grid from a dummy pass instead of
        # hard-coding it, so other spectrogram sizes work too.
        with torch.no_grad():
            enc_out = self.enc(torch.zeros(1, 1, *input_shape))
        self.enc_shape = tuple(enc_out.shape[1:])  # (channels, height, width)
        enc_numel = enc_out.numel()
        self.fc_enc = nn.Linear(enc_numel, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, enc_numel)

        # Decoder: each transposed conv (k=4, s=2, p=1) exactly doubles the spatial size.
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),
        )

    def forward(self, x):
        B, _, H, W = x.shape
        z = self.enc(x)
        latent = self.fc_enc(z.flatten(1))

        out = self.fc_dec(latent).view(B, *self.enc_shape)
        out = self.dec(out)

        # The decoder yields 16x the encoder grid (e.g. 128x224 for a 128x216
        # input); resize back to the input size so the loss compares like with like.
        out = F.interpolate(out, size=(H, W), mode="bilinear", align_corners=False)
        return out, latent

In [45]:
from torch.utils.data import DataLoader
import torch.optim as optim

# hyperparams
LATENT_DIM = 128
EPOCHS = 10
LR = 1e-3
BATCH_SIZE = 16
SEED = 0

# Seed every RNG the pipeline draws from:
# - torch: weight init and DataLoader shuffling;
# - numpy: the random crop offset in FMADataset.__getitem__.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConvAutoencoder(latent_dim=LATENT_DIM).to(device)
opt = optim.Adam(model.parameters(), lr=LR)
criterion = nn.L1Loss()

# dataset + loader (rebuilt here so this cell does not depend on the smoke-test cell)
dataset = FMADataset("data")
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    n_seen = 0
    for batch, _ in loader:
        batch = batch.to(device)
        out, _ = model(batch)
        loss = criterion(out, batch)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # L1Loss is a per-batch mean; weight it by batch size so the final,
        # smaller batch does not skew the epoch average.
        total_loss += loss.item() * batch.size(0)
        n_seen += batch.size(0)
    print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {total_loss/max(n_seen, 1):.4f}")

Found 7994 files


[src/libmpg123/layer3.c:INT123_do_layer3():1804] error: dequantization failed!
[src/libmpg123/layer3.c:INT123_do_layer3():1804] error: dequantization failed!
[src/libmpg123/layer3.c:INT123_do_layer3():1776] error: part2_3_length (3360) too large for available bit count (3240)
[src/libmpg123/layer3.c:INT123_do_layer3():1776] error: part2_3_length (3328) too large for available bit count (3240)
[src/libmpg123/layer3.c:INT123_do_layer3():1844] error: dequantization failed!


Epoch 1/10 - Loss: 1.6870


KeyboardInterrupt: 