In [None]:
import os
import sys

# Walk up from the notebook's working directory until we reach the folder that
# contains the project's `src` package, then make that folder importable.
project_root = os.getcwd()
while not os.path.isdir(os.path.join(project_root, "src")):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # os.path.dirname() is a fixed point at the filesystem root; without this
        # check the search would spin forever when no `src` directory exists.
        raise FileNotFoundError(
            f"Could not find a 'src' directory in {os.getcwd()} or any of its parents."
        )
    project_root = parent_dir

# Avoid piling up duplicate sys.path entries when this cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
from torch import optim
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from src.pipelines import train_pipeline, infer_pipeline
from src.models import SimpleUNet, SCUNet, DeepSampler
from src.utils.data.dataset import MUSDB18Dataset
from src.utils.training import (
    MultiSourceLoss,
    MultiScaleLoss,
)
import numpy as np
import matplotlib.pyplot as plt
import librosa

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Project directory layout; every path is derived from project_root.
experiments_path = os.path.join(project_root, "experiments")
data_root = os.path.join(project_root, "data")
checkpoint_path = os.path.join(experiments_path, "checkpoints")
scunet_path = os.path.join(checkpoint_path, "scunet.pth")
results_path = os.path.join(experiments_path, "results")

# MUSDB18-HQ test split: each entry is a track directory holding mixture/stem wavs.
musdb_path = os.path.join(data_root, "musdb18hq", "test")
# Keep only directories so stray files (e.g. .DS_Store) can never be drawn as a
# "track" when a sample is picked at random later on.
musdb_files = sorted(
    entry
    for entry in os.listdir(musdb_path)
    if os.path.isdir(os.path.join(musdb_path, entry))
)

In [None]:
# Smoke test: push a random spectrogram-shaped tensor through the model to check
# that the forward pass runs and to inspect the output shape.
deep_sampler = DeepSampler()
x = torch.randn(1, 1, 1025, 173)  # (batch, channels, height, width)
# Gradients are not needed for a shape check; disabling autograd avoids keeping
# the activation graph alive in the notebook namespace.
with torch.no_grad():
    output = deep_sampler(x)  # Expected output: (1, 4, 1025, 173)
output.shape

In [None]:
# Preprocessed MUSDB18 training split, served in shuffled mini-batches.
train_dataset = MUSDB18Dataset(os.path.join(data_root, "processed", "train"))
train_loader = DataLoader(
    train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True
)

# Inspect one example; each item unpacks into (input, target).
# A distinct name keeps it from being confused with the raw `mixture`
# waveform loaded later in the notebook.
sample_mixture, _ = train_dataset[0]
print("Sample input shape:", sample_mixture.shape)

In [None]:
deep_sampler.to(device)

# Per-source loss weights, normalized so they sum to 1. Source index 0 gets
# 3x the weight of the others (which stem index 0 is depends on the target order
# used by MUSDB18Dataset / MultiSourceLoss — TODO confirm).
raw_source_weights = [3.0, 1.0, 1.0, 1.0]
weight_sum = sum(raw_source_weights)
weights = [w / weight_sum for w in raw_source_weights]

criterion = MultiSourceLoss(weights=weights, distance="l1")

# Optimizer: Adam. weight_decay was raised slightly to help prevent overfitting.
optimizer = optim.Adam(deep_sampler.parameters(), lr=1e-3, weight_decay=1e-5)

# Scheduler, option 1: StepLR multiplies the learning rate by 0.1 every 10
# scheduler steps (every 10 epochs if train_pipeline steps once per epoch —
# TODO confirm).
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Option 2 (alternative): ReduceLROnPlateau lowers the learning rate when the
# validation loss stops improving.
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

# Training schedule. Increasing the number of epochs is recommended to make
# sure training converges.
total_epochs = 50
phase1_epochs = 10

In [None]:
# Train the model. train_pipeline returns the trained model plus a history dict
# (its "epoch_loss" entry is plotted in the next cell). phase1_epochs is passed
# through as-is; its exact meaning is defined inside train_pipeline.
training_config = dict(
    model=deep_sampler,
    dataloader=train_loader,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    total_epochs=total_epochs,
    phase1_epochs=phase1_epochs,
)
trained_model, history = train_pipeline(**training_config)

In [None]:
# Training-loss curve.
# history layout: {"epoch_loss": [...], "learning_rate": [...]}
epoch_losses = history["epoch_loss"]
# Label epochs 1..N instead of the implicit list index 0..N-1.
epoch_numbers = list(range(1, len(epoch_losses) + 1))

fig, ax = plt.subplots()
ax.plot(epoch_numbers, epoch_losses)
ax.set(xlabel="Epoch", ylabel="Loss", title="Training loss")
plt.show()

In [None]:
# Excerpt length (seconds) and STFT settings for the spectrogram comparison.
# NOTE(review): the smoke test feeds the model 1025 frequency bins (what
# n_fft=2048 produces), while these plots use n_fft=1024 (513 bins) — confirm
# this mismatch is intended; it only affects the visualizations below.
seconds = 30
N_FFT = 1024
HOP_LENGTH = 512

# Pick one test track at random. A fixed seed makes Restart & Run All select the
# same track every time; change SEED (or set it to None) to draw a different one.
SEED = 42
rng = np.random.default_rng(SEED)
random_sample = str(rng.choice(musdb_files))
sample_path = os.path.join(musdb_path, random_sample)

In [None]:
# Ground-truth files for the selected track: the mixture plus one wav per stem.
mixture_path, vocals_path, bass_path, drums_path, other_path = (
    os.path.join(sample_path, filename)
    for filename in ("mixture.wav", "vocals.wav", "bass.wav", "drums.wav", "other.wav")
)

# Load everything at its native sample rate (sr=None disables resampling).
# The mixture's rate is kept as `sr` and defines the excerpt length for all signals.
mixture, sr = librosa.load(mixture_path, sr=None)
vocals, bass, drums, other = (
    librosa.load(stem_path, sr=None)[0]
    for stem_path in (vocals_path, bass_path, drums_path, other_path)
)

# Keep only the first `seconds` of every signal.
n_samples = int(seconds * sr)
mixture, vocals, bass, drums, other = (
    signal[:n_samples] for signal in (mixture, vocals, bass, drums, other)
)

# Complex STFTs, used only for the magnitude plots below.
mixture_stft, vocals_stft, bass_stft, drums_stft, other_stft = (
    librosa.stft(signal, n_fft=N_FFT, hop_length=HOP_LENGTH)
    for signal in (mixture, vocals, bass, drums, other)
)

In [None]:
# Separate the selected track with the trained model. The next cell reads the
# estimated stems (vocals/bass/drums/other .wav) back from this directory.
deep_sampler_output_path = os.path.join(results_path, "deep_sampler")
# Make sure the target directory exists (a no-op if it is already there).
os.makedirs(deep_sampler_output_path, exist_ok=True)
infer_pipeline(
    model=trained_model,
    mixture_path=mixture_path,
    output_path=deep_sampler_output_path,
    device=device,
)

In [None]:
# Estimated stems written by infer_pipeline for the selected track.
res_dir = os.path.join(results_path, "deep_sampler")
res_vocals_path, res_bass_path, res_drums_path, res_other_path = (
    os.path.join(res_dir, filename)
    for filename in ("vocals.wav", "bass.wav", "drums.wav", "other.wav")
)

# Load at native rate. NOTE: this rebinds the notebook-level `sr` to the rate of
# the estimated vocals file, and that rate is used to cut every estimated stem.
res_vocals, sr = librosa.load(res_vocals_path, sr=None)
res_bass, res_drums, res_other = (
    librosa.load(estimate_path, sr=None)[0]
    for estimate_path in (res_bass_path, res_drums_path, res_other_path)
)

# Keep only the first `seconds` of each estimate.
res_n_samples = int(seconds * sr)
res_vocals, res_bass, res_drums, res_other = (
    estimate[:res_n_samples]
    for estimate in (res_vocals, res_bass, res_drums, res_other)
)

# Complex STFTs with the same settings as the ground truth, for side-by-side plots.
res_vocals_stft, res_bass_stft, res_drums_stft, res_other_stft = (
    librosa.stft(estimate, n_fft=N_FFT, hop_length=HOP_LENGTH)
    for estimate in (res_vocals, res_bass, res_drums, res_other)
)

In [None]:
# Ground-truth ("expected") log-magnitude spectrograms of the selected track;
# log1p compresses the dynamic range so quiet components stay visible.
ground_truth_panels = [
    ("Mixture", mixture_stft),
    ("Vocals", vocals_stft),
    ("Bass", bass_stft),
    ("Drums", drums_stft),
    ("Other", other_stft),
]

fig, axes = plt.subplots(2, 3, figsize=(20, 10))
fig.suptitle(f"Expectativa: {os.path.basename(sample_path)}")
for ax, (panel_title, panel_stft) in zip(axes.flat, ground_truth_panels):
    ax.imshow(np.log1p(np.abs(panel_stft)), aspect="auto", origin="lower")
    ax.set(title=panel_title, xlabel="Frame", ylabel="Frequency bin")
# Five panels on a 2x3 grid: hide the unused sixth axis.
axes.flat[-1].axis("off")
plt.show()

In [None]:
# Model estimates ("reality") as log-magnitude spectrograms. The "Mixture" panel
# is the sum of the four estimated STFTs, i.e. the mixture rebuilt from the
# model's output, for comparison with the true mixture.
estimate_panels = [
    ("Mixture", res_bass_stft + res_drums_stft + res_other_stft + res_vocals_stft),
    ("Vocals", res_vocals_stft),
    ("Bass", res_bass_stft),
    ("Drums", res_drums_stft),
    ("Other", res_other_stft),
]

fig, axes = plt.subplots(2, 3, figsize=(20, 10))
fig.suptitle("Realidad")
for ax, (panel_title, panel_stft) in zip(axes.flat, estimate_panels):
    ax.imshow(np.log1p(np.abs(panel_stft)), aspect="auto", origin="lower")
    ax.set(title=panel_title, xlabel="Frame", ylabel="Frequency bin")
# Five panels on a 2x3 grid: hide the unused sixth axis.
axes.flat[-1].axis("off")
plt.show()