In [8]:
import os
import sys

# Locate the repository root: walk up from the current working directory until a
# directory containing a `src/` package is found, then make it importable so the
# `from src....` imports in later cells resolve.
search_start = os.getcwd()
project_root = search_start
while not os.path.isdir(os.path.join(project_root, "src")):
    parent = os.path.dirname(project_root)
    if parent == project_root:
        # dirname() of the filesystem root is the root itself; without this check
        # the original loop spun forever when no `src/` existed above the cwd.
        raise FileNotFoundError(
            f"Could not find a 'src' directory in {search_start} or any of its parents."
        )
    project_root = parent

# Avoid piling up duplicate entries when the cell is re-run in the same kernel.
if project_root not in sys.path:
    sys.path.append(project_root)

In [9]:
import os
import torch
import torchaudio
import numpy as np
from typing import List, Dict, Any, Tuple, Callable
import matplotlib.pyplot as plt
import random
import librosa

# Global configuration: plotting defaults and RNG seeds for reproducible runs.
plt.rcParams["figure.figsize"] = [20, 6]

# Seed every RNG library the notebook imports (Python, NumPy, PyTorch) so model
# initialisation and any random data handling repeat across kernel restarts.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [10]:
# Processed training split, expected to be created by download_data.sh
# (see the error message below).
data_root = os.path.join(project_root, "data", "processed", "train")

# Use isdir rather than exists: the path is later passed to os.listdir, so a
# plain file here would otherwise pass this guard and fail with a less helpful error.
if not os.path.isdir(data_root):
    raise FileNotFoundError(
        "No se encontró la carpeta data, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

In [11]:
from src.utils.audio import AudioSample
from src.utils.data.dataset import AudioDataset
from src.utils.data.transforms import TransformChain, FreqSplit, STFTTransform

# Load AudioSamples from the .npz archives in the training split. Each archive
# entry is unwrapped with `.item()` (presumably a 0-d object array holding the
# raw sample object — confirm against the preprocessing script).
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the file;
# only load archives produced by this project's own preprocessing.
MAX_NPZ_FILES = 1  # quick-iteration subset; set to None to load every archive

# Sort for a deterministic subset (os.listdir order is arbitrary) and skip
# non-archive entries such as editor/OS metadata files.
npz_files = sorted(f for f in os.listdir(data_root) if f.endswith(".npz"))
if MAX_NPZ_FILES is not None:
    npz_files = npz_files[:MAX_NPZ_FILES]

samples: List[AudioSample] = []
for file in npz_files:
    npz_path = os.path.join(data_root, file)
    # The context manager closes the archive's file handle once it has been read.
    with np.load(npz_path, allow_pickle=True) as data:
        for key in data.files:
            samples.append(AudioSample(data[key].item()))

assert samples, f"No samples were loaded from {data_root}"

In [12]:
def convert_to_tensor(sample: AudioSample) -> Dict[str, torch.Tensor]:
    """Return the tensor representation of ``sample`` via ``AudioSample.to_tensor()``."""
    return sample.to_tensor()


# Per-sample preprocessing pipeline: convert_to_tensor, then STFTTransform.
tc = TransformChain()
tc.add_transform(convert_to_tensor)
tc.add_transform(STFTTransform())

dataset = AudioDataset(samples, tc)

In [13]:
from src.utils.training.loss import sdr_loss, si_sdr_loss, snr_loss
from src.utils.training.trainer import Trainer
from src.models import SimpleUNet

# NOTE(review): output_channels=4 presumably corresponds to the four target stems
# given to the Trainer (drums, bass, vocals, other) — keep the two in sync.
model = SimpleUNet(
    input_channels=1,
    output_channels=4,
    depth=1,
)

In [None]:
# Training hyperparameters, named so they are visible and easy to tune.
# NOTE(review): the number of target keys presumably must equal the model's
# output_channels — confirm against Trainer/SimpleUNet.
TARGET_KEYS = ["drums", "bass", "vocals", "other"]
EPOCHS = 50
BATCH_SIZE = 16
LEARNING_RATE = 1e-3

trainer = Trainer(
    target_keys=TARGET_KEYS,
    model=model,
    loss_fn=si_sdr_loss,
    optimizer=torch.optim.Adam(model.parameters(), lr=LEARNING_RATE),
)

# NOTE: the saved output below ("Epoch x/10") appears to come from an earlier
# 10-epoch run, not from EPOCHS=50; re-run this cell to regenerate it.
trainer.train(dataset, epochs=EPOCHS, batch_size=BATCH_SIZE)

Epoch 1/10 - Loss: 29.8182
Epoch 2/10 - Loss: 13.3872
Epoch 3/10 - Loss: 5.7077
Epoch 4/10 - Loss: 3.9504
Epoch 5/10 - Loss: 2.8950
Epoch 6/10 - Loss: 2.1826
Epoch 7/10 - Loss: 1.6946
Epoch 8/10 - Loss: 1.3638
Epoch 9/10 - Loss: 1.0584
Epoch 10/10 - Loss: 0.7978
