In [1]:
import os
import sys

project_root = os.getcwd()
while "src" not in os.listdir(project_root):
    project_root = os.path.dirname(project_root)
sys.path.append(project_root)

In [None]:
import os
import torch
import numpy as np
from typing import List, Dict, Callable
import matplotlib.pyplot as plt


# Global configuration for plotting
plt.rcParams["figure.figsize"] = [20, 6]

In [3]:
data_root = os.path.join(project_root, "data", "processed", "train")

if not os.path.exists(data_root):
    raise FileNotFoundError(
        "No se encontró la carpeta data, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

In [None]:
from src.utils.audio import AudioSample
from src.utils.data.dataset import AudioDataset
from src.utils.data.transforms import TransformChain, FreqSplit, STFTTransform

tc = TransformChain()
convert_to_tensor: Callable[[AudioSample], Dict[str, torch.Tensor]] = (
    lambda x: x.to_tensor()
)
tc.add_transform(convert_to_tensor)
tc.add_transform(STFTTransform())

samples: List[AudioSample] = []
for file in os.listdir(data_root):
    npz_file = os.path.join(data_root, file)
    data = np.load(npz_file, allow_pickle=True)
    for key in data.files:
        samples.append(AudioSample(data[key].item()))
    break

In [None]:
from sklearn.model_selection import train_test_split

train_samples, val_samples = train_test_split(samples, test_size=0.3)

train_dataset = AudioDataset(train_samples, tc)
val_dataset = AudioDataset(val_samples, tc)

In [6]:
from src.utils.training.loss import sdr_loss, si_sdr_loss, snr_loss
from src.utils.training.trainer import Trainer
from src.models import SimpleUNet

model = SimpleUNet(input_channels=1, output_channels=4, depth=1)

In [7]:
trainer = Trainer(
    target_keys=["drums", "bass", "vocals", "other"],
    model=model,
    loss_fn=si_sdr_loss,
    optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
)

trainer.train(train_dataset, epochs=20, batch_size=4)

2025-02-10 02:48:20,223 - INFO - Starting training for 50 epochs.
2025-02-10 02:48:29,119 - INFO - Epoch 1/50 - Loss: 36.7422      


Epoch 1/50 - Loss: 36.7422


2025-02-10 02:48:37,621 - INFO - Epoch 2/50 - Loss: 25.3931      


Epoch 2/50 - Loss: 25.3931


2025-02-10 02:48:46,182 - INFO - Epoch 3/50 - Loss: 17.4122      


Epoch 3/50 - Loss: 17.4122


2025-02-10 02:48:54,704 - INFO - Epoch 4/50 - Loss: 9.6595       


Epoch 4/50 - Loss: 9.6595


2025-02-10 02:49:03,311 - INFO - Epoch 5/50 - Loss: 5.5685       


Epoch 5/50 - Loss: 5.5685


2025-02-10 02:49:11,940 - INFO - Epoch 6/50 - Loss: 4.2400       


Epoch 6/50 - Loss: 4.2400


2025-02-10 02:49:20,627 - INFO - Epoch 7/50 - Loss: 3.5173       


Epoch 7/50 - Loss: 3.5173


2025-02-10 02:49:29,266 - INFO - Epoch 8/50 - Loss: 3.1239       


Epoch 8/50 - Loss: 3.1239


2025-02-10 02:49:37,827 - INFO - Epoch 9/50 - Loss: 2.9304       


Epoch 9/50 - Loss: 2.9304


2025-02-10 02:49:46,357 - INFO - Epoch 10/50 - Loss: 2.8041      


Epoch 10/50 - Loss: 2.8041


2025-02-10 02:49:54,882 - INFO - Epoch 11/50 - Loss: 2.5099      


Epoch 11/50 - Loss: 2.5099


2025-02-10 02:50:03,406 - INFO - Epoch 12/50 - Loss: 2.3962      


Epoch 12/50 - Loss: 2.3962


2025-02-10 02:50:11,880 - INFO - Epoch 13/50 - Loss: 2.2786      


Epoch 13/50 - Loss: 2.2786


2025-02-10 02:50:20,434 - INFO - Epoch 14/50 - Loss: 2.1590      


Epoch 14/50 - Loss: 2.1590


2025-02-10 02:50:28,959 - INFO - Epoch 15/50 - Loss: 2.0419      


Epoch 15/50 - Loss: 2.0419


2025-02-10 02:50:37,491 - INFO - Epoch 16/50 - Loss: 1.9757      


Epoch 16/50 - Loss: 1.9757


2025-02-10 02:50:46,045 - INFO - Epoch 17/50 - Loss: 1.8615      


Epoch 17/50 - Loss: 1.8615


2025-02-10 02:50:54,561 - INFO - Epoch 18/50 - Loss: 2.0045      


Epoch 18/50 - Loss: 2.0045


2025-02-10 02:51:03,048 - INFO - Epoch 19/50 - Loss: 1.8889      


Epoch 19/50 - Loss: 1.8889


2025-02-10 02:51:11,581 - INFO - Epoch 20/50 - Loss: 1.7218      


Epoch 20/50 - Loss: 1.7218


2025-02-10 02:51:20,110 - INFO - Epoch 21/50 - Loss: 1.6773      


Epoch 21/50 - Loss: 1.6773


2025-02-10 02:51:28,584 - INFO - Epoch 22/50 - Loss: 1.6583      


Epoch 22/50 - Loss: 1.6583


2025-02-10 02:51:37,125 - INFO - Epoch 23/50 - Loss: 1.5430      


Epoch 23/50 - Loss: 1.5430


2025-02-10 02:51:45,658 - INFO - Epoch 24/50 - Loss: 1.6886      


Epoch 24/50 - Loss: 1.6886


2025-02-10 02:51:54,208 - INFO - Epoch 25/50 - Loss: 1.4410      


Epoch 25/50 - Loss: 1.4410


2025-02-10 02:52:02,753 - INFO - Epoch 26/50 - Loss: 1.4030      


Epoch 26/50 - Loss: 1.4030


2025-02-10 02:52:11,255 - INFO - Epoch 27/50 - Loss: 1.4254      


Epoch 27/50 - Loss: 1.4254


2025-02-10 02:52:19,804 - INFO - Epoch 28/50 - Loss: 1.3431      


Epoch 28/50 - Loss: 1.3431


2025-02-10 02:52:28,437 - INFO - Epoch 29/50 - Loss: 1.3079      


Epoch 29/50 - Loss: 1.3079


2025-02-10 02:52:37,077 - INFO - Epoch 30/50 - Loss: 1.2268      


Epoch 30/50 - Loss: 1.2268


2025-02-10 02:52:45,645 - INFO - Epoch 31/50 - Loss: 1.4447      


Epoch 31/50 - Loss: 1.4447


2025-02-10 02:52:54,237 - INFO - Epoch 32/50 - Loss: 1.2872      


Epoch 32/50 - Loss: 1.2872


2025-02-10 02:53:02,792 - INFO - Epoch 33/50 - Loss: 1.2596      


Epoch 33/50 - Loss: 1.2596


2025-02-10 02:53:11,337 - INFO - Epoch 34/50 - Loss: 1.0743      


Epoch 34/50 - Loss: 1.0743


2025-02-10 02:53:19,926 - INFO - Epoch 35/50 - Loss: 1.2588      


Epoch 35/50 - Loss: 1.2588


2025-02-10 02:53:28,544 - INFO - Epoch 36/50 - Loss: 1.2217      


Epoch 36/50 - Loss: 1.2217


2025-02-10 02:53:37,157 - INFO - Epoch 37/50 - Loss: 1.1670      


Epoch 37/50 - Loss: 1.1670


2025-02-10 02:53:45,770 - INFO - Epoch 38/50 - Loss: 1.1253      


Epoch 38/50 - Loss: 1.1253


2025-02-10 02:53:54,379 - INFO - Epoch 39/50 - Loss: 1.0642      


Epoch 39/50 - Loss: 1.0642


2025-02-10 02:54:02,968 - INFO - Epoch 40/50 - Loss: 1.0225      


Epoch 40/50 - Loss: 1.0225


2025-02-10 02:54:11,508 - INFO - Epoch 41/50 - Loss: 1.0175      


Epoch 41/50 - Loss: 1.0175


2025-02-10 02:54:20,055 - INFO - Epoch 42/50 - Loss: 1.0358      


Epoch 42/50 - Loss: 1.0358


2025-02-10 02:54:28,662 - INFO - Epoch 43/50 - Loss: 0.9819      


Epoch 43/50 - Loss: 0.9819


2025-02-10 02:54:37,247 - INFO - Epoch 44/50 - Loss: 0.9792      


Epoch 44/50 - Loss: 0.9792


2025-02-10 02:54:45,794 - INFO - Epoch 45/50 - Loss: 0.9610      


Epoch 45/50 - Loss: 0.9610


2025-02-10 02:54:54,372 - INFO - Epoch 46/50 - Loss: 0.9288      


Epoch 46/50 - Loss: 0.9288


2025-02-10 02:55:02,905 - INFO - Epoch 47/50 - Loss: 0.8947      


Epoch 47/50 - Loss: 0.8947


2025-02-10 02:55:11,484 - INFO - Epoch 48/50 - Loss: 0.8570      


Epoch 48/50 - Loss: 0.8570


2025-02-10 02:55:20,046 - INFO - Epoch 49/50 - Loss: 0.8704      


Epoch 49/50 - Loss: 0.8704


2025-02-10 02:55:28,406 - INFO - Epoch 50/50 - Loss: 0.8297      


Epoch 50/50 - Loss: 0.8297


In [None]:
trainer.validate(val_dataset, batch_size=4)

In [None]:
model_path = os.path.join(project_root, "experiments/results", "simple_unet.pth")

# Save the model
torch.save(model.state_dict(), model_path)

RuntimeError: Parent directory /home/nhrot/Programming/Python/DeepLearning/DeepSampler/models does not exist.