In [None]:
import os
import sys

# Walk upward from the current working directory until we reach the folder
# that contains "src", so that `import src...` works regardless of which
# sub-directory the notebook was launched from.
project_root = os.getcwd()
while not os.path.isdir(os.path.join(project_root, "src")):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # dirname() of the filesystem root is the root itself; without this
        # guard the loop would spin forever when no ancestor contains "src".
        raise FileNotFoundError(
            f'Could not find a "src" directory in {os.getcwd()} or any of its parents.'
        )
    project_root = parent_dir

# Re-running the cell must not pile up duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
import librosa
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import matplotlib.pyplot as plt
from src.pipelines import musdb_pipeline, train_pipeline, eval_pipeline, infer_pipeline
from src.models import DeepSampler, SCUNet, SimpleUNet
from src.utils.training import MultiSourceLoss, VGGFeatureLoss, MultiScaleLoss
import numpy as np

# Default figure size applied to every matplotlib figure created below.
plt.rcParams.update({"figure.figsize": [20, 6]})

# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
# --- Audio / STFT settings forwarded to the dataset and inference pipelines ---
sr = 44100                        # passed as `sample_rate`
nfft = 2048                       # passed as `nfft` / `n_fft`
hop_length = 512
window = torch.hann_window(nfft)  # Hann window with `nfft` points
chunk_seconds = 2
overlap = 0

# --- Dataset locations, resolved relative to the project root ---
data_root = os.path.join(project_root, "data")
musdb_root = os.path.join(data_root, "musdb18hq")

# Fail early when the data folder is missing. The message (in Spanish) tells
# the user to run download_data.sh before executing this notebook.
data_root_present = os.path.exists(data_root)
if not data_root_present:
    raise FileNotFoundError(
        "No se encontró la carpeta data, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

In [None]:
# Build the training dataset from the "train" split using the shared
# STFT/chunking settings; `max_samples=100` is forwarded unchanged.
train_pipeline_args = dict(
    musdb_path=os.path.join(musdb_root, "train"),
    nfft=nfft,
    hop_length=hop_length,
    window=window,
    chunk_seconds=chunk_seconds,
    overlap=overlap,
    sample_rate=sr,
    max_samples=100,
)
train_dataset = musdb_pipeline(**train_pipeline_args)

# Shuffled mini-batches of 4 items for training.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
)

In [None]:
# Build the evaluation dataset from the "test" split using the shared
# STFT/chunking settings; `max_samples=300` is forwarded unchanged.
test_split_dir = os.path.join(musdb_root, "test")
test_dataset = musdb_pipeline(
    musdb_path=test_split_dir,
    sample_rate=sr,
    chunk_seconds=chunk_seconds,
    overlap=overlap,
    nfft=nfft,
    hop_length=hop_length,
    window=window,
    max_samples=300,
)

# Fixed order (no shuffling), mini-batches of 4 items.
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=4,
    shuffle=False,
)

In [None]:
# Instantiate the model and move it to the target device *before* creating the
# optimizer, so the optimizer is built from the parameters that will actually
# be trained (PyTorch recommends this ordering).
deep_sampler = DeepSampler()
deep_sampler.to(device)

optimizer = optim.Adam(deep_sampler.parameters(), lr=1e-3, weight_decay=1e-5)

# NOTE(review): with step_size=10 and only 2 total epochs the learning rate
# never decays if the scheduler is stepped once per epoch — confirm intent.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# L1 distance; presumably one weight per separated source — TODO confirm
# against MultiSourceLoss.
criterion = MultiSourceLoss(
    weights=[1, 1, 1, 1],
    distance="l1",
)

# `factor` scales both the total number of epochs and the phase-1 epochs.
factor = 1
epochs = 2 * factor
p1_epochs = 1 * factor

# Keep the model as the cell's last expression so its repr is still displayed.
deep_sampler

In [None]:
# Train the model. The `dataloader` argument receives the DataLoader built
# above (batch_size=4, shuffle=True); previously the raw `train_dataset` was
# passed, so the loader's batching and shuffling were never used.
# NOTE(review): assumes train_pipeline simply iterates over `dataloader` —
# confirm against src.pipelines.train_pipeline.
history = train_pipeline(
    model=deep_sampler,
    dataloader=train_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    total_epochs=epochs,
    phase1_epochs=p1_epochs,
    device=device,
)

In [None]:
# Plot the training curves stored in `history`, which is read through the keys
# "epoch_loss", "learning_rate" and "batch_losses". The batch losses are
# flattened into a single 1-D series before plotting.
training_curves = (
    (history["epoch_loss"], "Loss"),
    (history["learning_rate"], "Learning Rate"),
    (np.asarray(history["batch_losses"]).ravel(), "Batch Losses"),
)
for series, title in training_curves:
    plt.plot(series)
    plt.title(title)
    plt.show()

In [None]:
# Pick one random test track to run inference on.
test_dir = os.path.join(musdb_root, "test")

# Keep only entries that actually contain a mixture file: os.listdir() also
# returns stray files (e.g. .DS_Store), which would produce a non-existent
# mixture path. Sorting makes the candidate order independent of the
# filesystem's listing order.
test_folders = sorted(
    entry
    for entry in os.listdir(test_dir)
    if os.path.isfile(os.path.join(test_dir, entry, "mixture.wav"))
)
if not test_folders:
    raise FileNotFoundError(
        f"No track folder containing mixture.wav was found in {test_dir}"
    )

# Unseeded on purpose: a different track may be chosen on every run.
random_folder = np.random.choice(test_folders)

audio_mixture = os.path.join(test_dir, random_folder, "mixture.wav")

In [None]:
# Separate the selected mixture with the trained model. All audio settings
# come from the shared configuration constants; in particular `sr` replaces a
# hard-coded 44100 so inference cannot drift from the rate used for training.
sources = infer_pipeline(
    model=deep_sampler,
    mixture_path=audio_mixture,
    sample_rate=sr,
    chunk_seconds=chunk_seconds,
    overlap=overlap,
    n_fft=nfft,
    hop_length=hop_length,
    device=device,
)

In [None]:
# Stem names: used below both as the reference file names ("<name>.wav") and
# as the keys looked up in `sources`.
inst = [
    "vocals",
    "drums",
    "bass",
    "other",
]

In [None]:
# Plot the reference waveform of every stem of the selected test track.
for instrument in inst:
    file_path = os.path.join(musdb_root, "test", random_folder, f"{instrument}.wav")
    # Load at the shared sample rate `sr` (instead of a hard-coded 44100) so
    # the reference stays consistent with the rate used for inference.
    wav, _ = librosa.load(file_path, sr=sr)
    plt.figure(figsize=(20, 6))
    plt.plot(wav)
    plt.title(instrument)
    plt.show()

In [None]:
# Plot the waveform estimated by the model for every stem. `sources` is
# indexed by stem name; the enumerate() index was never used, so it is dropped.
for stem_name in inst:
    plt.figure(figsize=(20, 6))
    plt.plot(sources[stem_name])
    plt.title(stem_name)
    plt.show()