In [None]:
import os
import sys

# Walk upward from the notebook's working directory until we reach the repository
# root, recognised by containing a ``src`` directory, so that ``from src...`` imports
# below resolve regardless of where the notebook was launched from.
project_root = os.getcwd()
while not os.path.isdir(os.path.join(project_root, "src")):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # os.path.dirname is a fixed point at the filesystem root; without this guard
        # the search would loop forever when no ``src`` directory exists above cwd.
        raise FileNotFoundError(
            f"No se encontró una carpeta 'src' en {os.getcwd()} ni en ninguno de sus directorios padre."
        )
    project_root = parent_dir

# Only add the root once so re-running this cell does not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import matplotlib.pyplot as plt
from src.pipelines import musdb_pipeline, train_pipeline, eval_pipeline, infer_pipeline
from src.models import DeepSampler, SCUNet, SimpleUNet
from src.utils.training import MultiSourceLoss, VGGFeatureLoss, MultiScaleLoss


plt.rcParams["figure.figsize"] = [20, 6]

# Seed torch's global RNG (all devices) so torch-driven randomness — DataLoader
# shuffling, dropout and weight initialisation — repeats across Restart & Run All.
# NOTE(review): if musdb_pipeline uses numpy/random internally, seed those too.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# STFT / chunking settings shared by the train and test dataset pipelines.
nfft = 2048  # FFT size; also the Hann window length
hop_length = 512  # samples between successive STFT frames
window = torch.hann_window(nfft)
chunk_seconds = 2  # chunk length handed to musdb_pipeline, in seconds
overlap = 0  # chunk overlap — units/semantics defined by musdb_pipeline; TODO confirm
sr = 44100  # sample rate (Hz) passed as sample_rate

data_root = os.path.join(project_root, "data")
musdb_root = os.path.join(data_root, "musdb18hq")

if not os.path.exists(data_root):
    raise FileNotFoundError(
        "No se encontró la carpeta data, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

# data/ existing does not guarantee the dataset splits read below are present;
# fail here with a clear message instead of deep inside musdb_pipeline.
missing_splits = [
    split_name
    for split_name in ("train", "test")
    if not os.path.isdir(os.path.join(musdb_root, split_name))
]
if missing_splits:
    raise FileNotFoundError(
        f"No se encontraron las carpetas {missing_splits} en {musdb_root}, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

In [None]:
def build_musdb_dataset(split: str):
    """Build the MUSDB18-HQ dataset for one split with the shared STFT/chunk config.

    Using a single helper guarantees the train and test splits are preprocessed
    with identical settings (nfft, hop_length, window, chunking, sample rate).

    Args:
        split: sub-directory of ``musdb_root`` to load, e.g. ``"train"`` or ``"test"``.

    Returns:
        Whatever ``musdb_pipeline`` returns; it is wrapped in a DataLoader below.
    """
    return musdb_pipeline(
        musdb_path=os.path.join(musdb_root, split),
        nfft=nfft,
        hop_length=hop_length,
        window=window,
        chunk_seconds=chunk_seconds,
        overlap=overlap,
        sample_rate=sr,
    )


train_dataset = build_musdb_dataset("train")
test_dataset = build_musdb_dataset("test")

In [None]:
# Wrap both splits in batches of 16; only the training split is shuffled so the
# evaluation order stays fixed.
train_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(split_dataset, batch_size=16, shuffle=is_training)
    for split_dataset, is_training in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# Separation model: 1 input channel -> 4 output channels
# (presumably one per MUSDB stem — confirm against DeepSampler).
# Move it to the target device *before* building the optimizer, as the
# torch.optim documentation recommends.
deep_sampler = DeepSampler(
    input_channels=1,
    output_channels=4,
    base_channels=32,
    depth=5,
    dropout=0.1,
    transformer_heads=4,
    transformer_layers=4,
).to(device)

optimizer = optim.Adam(deep_sampler.parameters(), lr=1e-3, weight_decay=1e-5)
# Multiplies the LR by 0.1 every 10 scheduler steps. If train_pipeline steps once per
# epoch (TODO confirm), LR goes 1e-3 -> 1e-7 by epoch 40 of 50.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
# Equal weight for each of the four outputs; "l1" distance as named by MultiSourceLoss.
criterion = MultiSourceLoss(
    weights=[1, 1, 1, 1],
    distance="l1",
)

epochs = 50  # total training epochs
p1_epochs = 20  # length of "phase 1"; its meaning is defined inside train_pipeline

In [None]:
# Train the model on the training loader: model + optimisation objects first, then
# data/device, then the epoch schedule (phase-1 length and total length).
train_pipeline(
    model=deep_sampler,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    dataloader=train_dataloader,
    device=device,
    phase1_epochs=p1_epochs,
    total_epochs=epochs,
)

In [None]:
# Evaluate the model on the held-out test split loader (unshuffled).
eval_pipeline(
    device=device,
    dataloader=test_dataloader,
    model=deep_sampler,
)