In [1]:
import os
import sys

# Walk upward from the current working directory until a directory that
# contains an entry named "src" is found; that directory is used as the
# project root and made importable.
project_root = os.getcwd()
while "src" not in os.listdir(project_root):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # dirname() of the filesystem root is the root itself; without this
        # guard the loop would spin forever when no ancestor contains "src".
        raise FileNotFoundError(
            "Could not find a directory containing 'src' at or above "
            + os.getcwd()
        )
    project_root = parent_dir

# Re-running the cell must not stack duplicate entries on sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import os
import torch
import numpy as np
from typing import List, Dict, Callable
import matplotlib.pyplot as plt

# Notebook-wide default figure size for matplotlib (width, height).
plt.rcParams.update({"figure.figsize": [20, 6]})

In [3]:
# Location of the processed training data, relative to the project root.
data_root = os.path.join(project_root, "data", "processed", "train")

# isdir() rather than exists(): a regular file sitting at this path would pass
# an existence check and only fail later, when the directory is listed.
if not os.path.isdir(data_root):
    raise FileNotFoundError(
        "No se encontró la carpeta data, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

In [None]:
from src.utils.audio import AudioSample
from src.utils.data.dataset import AudioDataset
from src.utils.data.transforms import TransformChain, STFTTransform

# Preprocessing chain handed to the datasets: AudioSample.to_tensor() first,
# then STFTTransform.
tc = TransformChain()


def convert_to_tensor(sample: AudioSample) -> Dict[str, torch.Tensor]:
    """Adapter so ``AudioSample.to_tensor`` can be used as a chain step."""
    return sample.to_tensor()


tc.add_transform(convert_to_tensor)
tc.add_transform(STFTTransform())

# os.listdir() returns entries in arbitrary order, so sort them to make the
# choice of archive reproducible. Non-.npz entries (hidden files, READMEs, ...)
# are skipped so they cannot be picked up as "the first file".
# NOTE(review): assumes the archives carry a ".npz" extension — confirm.
npz_names = sorted(
    name for name in os.listdir(data_root) if name.endswith(".npz")
)
if not npz_names:
    raise FileNotFoundError(f"No .npz archives found in {data_root}")

samples: List[AudioSample] = []
# Only the first archive is loaded (presumably to keep this experiment small —
# widen the slice to train on more data).
for npz_name in npz_names[:1]:
    npz_file = os.path.join(data_root, npz_name)
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only use it
    # on archives from a trusted source.
    # The context manager closes the underlying zip file once it has been read.
    with np.load(npz_file, allow_pickle=True) as data:
        for key in data.files:
            # Each entry is a 0-d object array; .item() unwraps the stored
            # Python object that AudioSample is built from.
            samples.append(AudioSample(data[key].item()))

In [5]:
from sklearn.model_selection import train_test_split

# Hold out 30% of the samples for validation. The fixed random_state makes the
# split (and therefore the reported validation loss) reproducible across runs.
train_samples, val_samples = train_test_split(
    samples, test_size=0.3, random_state=42
)

# Both datasets share the same preprocessing chain.
train_dataset = AudioDataset(train_samples, tc)
val_dataset = AudioDataset(val_samples, tc)

In [6]:
from src.utils.training.loss import sdr_loss, si_sdr_loss, snr_loss
from src.utils.training.trainer import Trainer
from src.models import SimpleUNet

# Network under training.
# NOTE(review): output_channels=4 presumably corresponds to the four target
# keys passed to the Trainer below — confirm.
model = SimpleUNet(
    input_channels=1,
    output_channels=4,
    depth=1,
)

In [7]:
# Build the optimizer on its own line so the Trainer call stays compact.
adam = torch.optim.Adam(model.parameters(), lr=0.001)

# Trainer configured with the SI-SDR loss and the four target keys.
trainer = Trainer(
    target_keys=["drums", "bass", "vocals", "other"],
    model=model,
    loss_fn=si_sdr_loss,
    optimizer=adam,
)

# Fit on the training split.
trainer.train(train_dataset, epochs=15, batch_size=4)

2025-02-10 02:59:40,593 - INFO - Starting training for 15 epochs.
                                                                 

Epoch 1/15 - Loss: 40.6267


                                                                 

Epoch 2/15 - Loss: 30.7089


                                                                 

Epoch 3/15 - Loss: 22.7713


                                                                 

Epoch 4/15 - Loss: 17.6424


                                                                 

Epoch 5/15 - Loss: 13.0182


                                                                 

Epoch 6/15 - Loss: 9.3262


                                                                 

Epoch 7/15 - Loss: 7.1134


                                                                 

Epoch 8/15 - Loss: 5.9228


                                                                 

Epoch 9/15 - Loss: 5.1030


                                                                 

Epoch 10/15 - Loss: 4.3355


                                                                 

Epoch 11/15 - Loss: 3.7703


                                                                 

Epoch 12/15 - Loss: 3.3382


                                                                 

Epoch 13/15 - Loss: 3.1382


                                                                 

Epoch 14/15 - Loss: 2.8993


                                                                 

Epoch 15/15 - Loss: 2.8944




In [8]:
# Evaluate the trained model on the held-out split.
trainer.validate(
    val_dataset,
    batch_size=4,
)

2025-02-10 03:01:10,886 - INFO - Validation Loss: 2.0215         


Validation Loss: 2.0215


In [9]:
# Checkpoint destination; each path component is passed separately so the
# platform's separator is used throughout.
model_path = os.path.join(project_root, "experiments", "results", "simple_unet.pth")

# torch.save does not create missing parent directories, so make sure the
# results folder exists (no-op when it is already there).
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Save the model's state_dict (parameters and buffers), not the module object.
torch.save(model.state_dict(), model_path)