In [1]:
import os
import sys

# Walk upward from the current working directory until a directory that
# contains a "src" entry is found; that directory is used as the project root.
project_root = os.getcwd()
while "src" not in os.listdir(project_root):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # dirname() of the filesystem root is the root itself, so without this
        # guard the loop would never terminate when no ancestor contains "src".
        raise FileNotFoundError(
            'Could not locate the project root: no "src" entry found in '
            f"{os.getcwd()!r} or any of its parent directories."
        )
    project_root = parent_dir

# Make ``src`` importable. Skip the append when the path is already registered
# so that re-running this cell does not pile up duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import os
import torch
import numpy as np
from typing import List, Dict, Callable
import matplotlib.pyplot as plt

# Notebook-wide plotting default: every figure is created at 20 x 6
# (width x height, as matplotlib's ``figure.figsize`` expects).
FIGURE_SIZE = [20, 6]
plt.rcParams["figure.figsize"] = FIGURE_SIZE

In [3]:
# Directory the notebook reads its training archives from:
# <project_root>/data/processed/train
data_root = os.path.join(project_root, "data", "processed", "train")

# Fail fast when the data directory is missing. isdir() (rather than exists())
# also rejects a stray regular file at that path, which would otherwise only
# surface later as a NotADirectoryError when the folder is listed.
if not os.path.isdir(data_root):
    raise FileNotFoundError(
        "No se encontró la carpeta data, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

In [4]:
from src.utils.audio import AudioSample
from src.utils.data.dataset import AudioDataset
from src.utils.data.transforms import TransformChain, STFTTransform

def convert_to_tensor(x: AudioSample) -> Dict[str, torch.Tensor]:
    """Convert an ``AudioSample`` to its tensor form by calling ``to_tensor()``."""
    return x.to_tensor()


# Build the preprocessing chain: tensor conversion is registered first,
# followed by the STFT transform.
tc = TransformChain()
tc.add_transform(convert_to_tensor)
tc.add_transform(STFTTransform())

# Number of archives to load. The notebook has always stopped after the first
# file; raise this value (or set it to None for "all files") to use more data.
MAX_NPZ_FILES = 1

# os.listdir() returns entries in arbitrary order, so sort them to make the
# selected archive(s) reproducible, and ignore anything that is not a .npz
# file (hidden files, placeholders, ...), which np.load could not handle here.
npz_names = sorted(name for name in os.listdir(data_root) if name.endswith(".npz"))
if not npz_names:
    raise FileNotFoundError(f"No .npz archives found in {data_root!r}.")

samples: List[AudioSample] = []
for file in npz_names[:MAX_NPZ_FILES]:
    npz_file = os.path.join(data_root, file)
    # NOTE(security): allow_pickle=True unpickles arbitrary objects, which can
    # execute code. Only load archives produced by a trusted pipeline.
    # The context manager closes the underlying file once the entries are read.
    with np.load(npz_file, allow_pickle=True) as data:
        for key in data.files:
            # .item() unwraps the single Python object stored under this key
            # (presumably the sample payload AudioSample expects - TODO confirm).
            samples.append(AudioSample(data[key].item()))

In [5]:
from sklearn.model_selection import train_test_split

# Fixed seed so the train/validation partition is identical on every run;
# without random_state the split changes each time the cell is executed.
SPLIT_SEED = 42

# Hold out 30% of the loaded samples for validation.
train_samples, val_samples = train_test_split(
    samples, test_size=0.3, random_state=SPLIT_SEED
)

# Both datasets are built with the same transform chain.
train_dataset = AudioDataset(train_samples, tc)
val_dataset = AudioDataset(val_samples, tc)

In [6]:
from src.utils.training.loss import si_sdr_loss
from src.utils.training.trainer import Trainer
from src.models import SimpleUNet

# Seed PyTorch's global RNG before the model is constructed so that any random
# weight initialisation (and later torch-driven randomness) starts from a
# known state and a Restart & Run All reproduces the same run.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

# One input channel, four output channels - presumably one per target key
# handed to the Trainer ("drums", "bass", "vocals", "other"); TODO confirm.
model = SimpleUNet(input_channels=1, output_channels=4, depth=6)

In [7]:
# Adam over all model parameters with a learning rate of 1e-3.
adam = torch.optim.Adam(model.parameters(), lr=1e-3)

# Wire model, loss and optimiser together; target_keys names the four
# entries the trainer is asked to work with.
trainer = Trainer(
    model=model,
    optimizer=adam,
    loss_fn=si_sdr_loss,
    target_keys=["drums", "bass", "vocals", "other"],
)

# Train on the training split only: 20 epochs, batches of a single sample.
trainer.train(train_dataset, batch_size=1, epochs=20)

2025-02-12 02:38:30,003 - INFO - Starting training for 20 epochs.
                                                                 

Epoch 1/20 - Loss: 30.1201


                                                                 

Epoch 2/20 - Loss: 18.5486


                                                                 

Epoch 3/20 - Loss: 8.9731


                                                                 

Epoch 4/20 - Loss: 4.9817


                                                                 

Epoch 5/20 - Loss: 3.4406


                                                                 

Epoch 6/20 - Loss: 2.6078


                                                                 

Epoch 7/20 - Loss: 1.9613


                                                                 

Epoch 8/20 - Loss: 1.5070


                                                                 

Epoch 9/20 - Loss: 1.2953


                                                                 

Epoch 10/20 - Loss: 1.1137


                                                                 

Epoch 11/20 - Loss: 0.8518


                                                                 

Epoch 12/20 - Loss: 0.7952


                                                                 

Epoch 13/20 - Loss: 0.5955


                                                                 

Epoch 14/20 - Loss: 0.4768


                                                                 

Epoch 15/20 - Loss: 0.4731


                                                                 

Epoch 16/20 - Loss: 0.1971


                                                                 

Epoch 17/20 - Loss: 0.1802


                                                                 

Epoch 18/20 - Loss: 0.1153


                                                                 

Epoch 19/20 - Loss: 0.1150


                                                                 

Epoch 20/20 - Loss: 0.0175




In [8]:
# Run the trainer's validation pass on the held-out split, one sample per batch.
trainer.validate(val_dataset, batch_size=1)

2025-02-12 02:45:18,901 - INFO - Validation Loss: -0.6647          


Validation Loss: -0.6647


In [11]:
# Save the trained weights as a new, uniquely named checkpoint under
# <project_root>/experiments/checkpoints.
saves_path = os.path.join(project_root, "experiments", "checkpoints")

# exist_ok=True creates the folder when needed and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(saves_path, exist_ok=True)

# Draw random 8-digit suffixes until the resulting file name is unused, so an
# earlier checkpoint is never silently overwritten.
while True:
    # randint's upper bound is exclusive; 100000000 keeps 99999999 reachable.
    random_8digit_number = np.random.randint(10000000, 100000000)
    model_name = "simple_unet" + str(random_8digit_number)
    checkpoint_path = os.path.join(saves_path, model_name + ".pth")
    if not os.path.exists(checkpoint_path):
        break

# Only the state_dict (parameters/buffers) is stored, not the model object.
torch.save(model.state_dict(), checkpoint_path)