In [None]:
import os
import sys

# Locate the project root: start at the current working directory and walk
# upwards until a directory that contains a "src" folder is found, so that
# `from src...` imports work no matter where the notebook is launched from.
project_root = os.getcwd()
while not os.path.isdir(os.path.join(project_root, "src")):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # dirname() of the filesystem root is the root itself; without this
        # guard the loop would spin forever when no ancestor contains "src".
        raise FileNotFoundError(
            f'Could not find a "src" directory in {os.getcwd()!r} or any of its parents.'
        )
    project_root = parent_dir

# Re-running this cell must not keep appending the same entry to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import librosa

from IPython.display import Audio
from src.utils.audio import load_audio
from src.pipelines.data import process_audio_folder
from src.pipelines import musdb_pipeline, train_pipeline, infer_pipeline
from src.models import DeepSampler
from src.utils.training.loss import MultiSourceLoss

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# --- Reproducibility ---------------------------------------------------------
# The notebook draws random tracks/chunks with NumPy and relies on PyTorch for
# weight initialisation and DataLoader shuffling; seed both global generators
# so that "Restart Kernel -> Run All" reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Audio / batching parameters ---------------------------------------------
SR = 44100  # sampling rate, forwarded as `sample_rate` / `rate`
NFFT = 2048  # forwarded as `n_fft`
HOP = 512  # forwarded as `hop_length`
CHUNK_DUR = 2  # forwarded as `chunk_duration` / `chunk_seconds`
OVERLAP = 0  # forwarded as `overlap`
BATCH_SIZE = 12

# --- Paths -------------------------------------------------------------------
# Every dataset location derives from `musdb_path` so the folder name is
# spelled in exactly one place.
data_path = os.path.join(project_root, "data")
musdb_path = os.path.join(data_path, "musdb18hq")
train_path = os.path.join(musdb_path, "train")
test_path = os.path.join(musdb_path, "test")
output_path = os.path.join(data_path, "processed")

# Keep only sub-directories (one per track): stray files such as ".DS_Store"
# would otherwise be eligible when a track is sampled at random.
train_files = sorted(
    entry
    for entry in os.listdir(train_path)
    if os.path.isdir(os.path.join(train_path, entry))
)

In [None]:
# Pick one training track at random, load its mixture and split the track
# folder into chunks.
random_sample = np.random.choice(train_files)
track_dir = os.path.join(train_path, random_sample)
mixture_file = os.path.join(track_dir, "mixture.wav")

mixture = load_audio(mixture_file)
chunks = process_audio_folder(
    audio_folder=track_dir,
    overlap=OVERLAP,
    chunk_duration=CHUNK_DUR,
    sample_rate=SR,
)

In [None]:
# Audio player for the full mixture of the sampled track.
Audio(data=mixture, rate=SR)

In [None]:
# Draw one chunk at random and listen to its "mixture" entry.
rand_chunk = np.random.choice(chunks)
chunk_mixture = rand_chunk["mixture"]
Audio(data=chunk_mixture, rate=SR)

In [None]:
# Build the training dataset from the MUSDB train split.
# NOTE(review): `max_chunks` presumably caps how many chunks are kept and
# `save_dir` is presumably where processed data is written - confirm against
# `musdb_pipeline`.
musdb_pipeline_kwargs = {
    "musdb_path": train_path,
    "save_dir": output_path,
    "sample_rate": SR,
    "n_fft": NFFT,
    "hop_length": HOP,
    "chunk_duration": CHUNK_DUR,
    "overlap": OVERLAP,
    "max_chunks": 500,
}
train_dataset = musdb_pipeline(**musdb_pipeline_kwargs)

In [None]:
# Model, moved to the selected device.
model = DeepSampler()
model.to(device)

# Loss with one weight per source (all equal) and an Adam optimizer with its
# default hyper-parameters.
source_weights = [1, 1, 1, 1]
criterion = MultiSourceLoss(weights=source_weights)
optimizer = torch.optim.Adam(params=model.parameters())

# Shuffled mini-batches over the training dataset.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# Layout of one batch from the loader:
#   example_item[INPUT][BATCH][CHANNEL]  = mixture spectrogram
#   example_item[OUTPUT][BATCH][CHANNEL] = [vocals, drums, bass, other] spectrograms
example_item = next(iter(train_loader))
mixture_spectrogram = example_item[0][0][0]  # first item of the batch, first channel
source_spectrograms = example_item[1][0]  # separated sources of that same item
source_names = ["Vocals", "Drums", "Bass", "Other"]

# One panel for the mixture followed by one panel per source.
fig, axes = plt.subplots(1, 5, figsize=(20, 5))

mixture_ax = axes[0]
mixture_ax.imshow(mixture_spectrogram, aspect="auto", origin="lower")
mixture_ax.set(
    title="Mixture",
    xlabel="Time",
    ylabel="Frequency",
    xticks=[],
    yticks=[],
)

for idx, name in enumerate(source_names):
    source_ax = axes[idx + 1]
    source_ax.imshow(source_spectrograms[idx], aspect="auto", origin="lower")
    source_ax.set(title=name, xlabel="Time", yticks=[], xticks=[])

plt.tight_layout()
plt.show()

In [None]:
# Train the model for three epochs; the returned history is plotted in the
# next cell.
training_setup = {
    "model": model,
    "dataloader": train_loader,
    "criterion": criterion,
    "optimizer": optimizer,
    "device": device,
    "epochs": 3,
}
history = train_pipeline(**training_setup)

In [None]:
# `history` is expected to be a dictionary with the keys
# "epoch_loss", "learning_rate" and "batch_losses", for example:
# history = {
#     "epoch_loss": [0.9, 0.8, 0.7, ...],
#     "learning_rate": [0.001, 0.001, 0.0009, ...],
#     "batch_losses": [[1.0, 0.95, ...], [0.9, 0.85, ...], ...]
# }

# Figure with three vertically stacked subplots.
fig, axs = plt.subplots(3, 1, figsize=(10, 8), constrained_layout=True)

# Loss per epoch.
axs[0].plot(history["epoch_loss"], marker="o", color="blue", label="Epoch Loss")
axs[0].set_title("Loss por Época", fontsize=14)
axs[0].set_xlabel("Época", fontsize=12)
axs[0].set_ylabel("Loss", fontsize=12)
axs[0].legend()
axs[0].grid(True)

# Learning rate.
axs[1].plot(
    history["learning_rate"],
    marker="s",
    linestyle="--",
    color="green",
    label="Learning Rate",
)
axs[1].set_title("Tasa de Aprendizaje", fontsize=14)
axs[1].set_xlabel("Época", fontsize=12)
axs[1].set_ylabel("Learning Rate", fontsize=12)
axs[1].legend()
axs[1].grid(True)

# Loss per batch.
# Flatten entry by entry instead of via `np.array(...).flatten()`: epochs with
# a different number of batches form a ragged nested list, which NumPy refuses
# to convert (or turns into an unflattened object array on old versions).
# `np.atleast_1d` also lets an already-flat list of scalars through unchanged.
batch_losses_flat = np.asarray(
    [
        loss
        for epoch_losses in history["batch_losses"]
        for loss in np.atleast_1d(epoch_losses)
    ],
    dtype=float,
)
axs[2].plot(batch_losses_flat, marker="^", color="red", label="Batch Losses")
axs[2].set_title("Pérdida por Batch", fontsize=14)
axs[2].set_xlabel("Batch", fontsize=12)
axs[2].set_ylabel("Loss", fontsize=12)
axs[2].legend()
axs[2].grid(True)

plt.show()

In [None]:
# Candidate test tracks: sub-directories of the test split only (stray files
# such as ".DS_Store" are skipped), sorted so the random draw does not depend
# on the arbitrary order returned by os.listdir.
test_tracks = sorted(
    entry
    for entry in os.listdir(test_path)
    if os.path.isdir(os.path.join(test_path, entry))
)
if not test_tracks:
    raise FileNotFoundError(f"No track folders found in {test_path!r}")

# Pick one test track at random and point at its mixture file.
random_folder = np.random.choice(test_tracks)
audio_mixture = os.path.join(test_path, random_folder, "mixture.wav")

In [None]:
# Separate the selected test mixture with the trained model, using the same
# chunking and STFT settings as for training. The result is indexed by source
# name ("vocals", "drums", "bass", "other") in the plotting cell below.
inference_setup = {
    "model": model,
    "device": device,
    "mixture_path": audio_mixture,
    "sample_rate": SR,
    "n_fft": NFFT,
    "hop_length": HOP,
    "chunk_seconds": CHUNK_DUR,
    "overlap": OVERLAP,
}
extracted_sources = infer_pipeline(**inference_setup)

In [None]:
# Compare, source by source, the reference stem of the selected test track
# (left column) with the waveform produced by the model (right column).
inst = ["vocals", "drums", "bass", "other"]
fig, axs = plt.subplots(len(inst), 2, figsize=(20, 16))

for row, instrument in zip(axs, inst):
    # Reference stem, loaded at the notebook's sampling rate.
    stem_path = os.path.join(musdb_path, "test", random_folder, f"{instrument}.wav")
    reference, _ = librosa.load(stem_path, sr=SR)
    row[0].plot(reference, color="tab:blue")

    # Model output, moved to the CPU and converted to NumPy for plotting.
    estimate = extracted_sources[instrument].cpu().numpy()
    row[1].plot(estimate, color="tab:orange")

    # Both panels share the same labelling; only the title suffix differs.
    for panel, suffix in zip(row, ("Original", "Extraído")):
        panel.set_title(f"{instrument.capitalize()} ({suffix})", fontsize=16)
        panel.set_xlabel("Samples", fontsize=14)
        panel.set_ylabel("Amplitude", fontsize=14)
        panel.grid(True)

plt.tight_layout()
plt.show()