In [None]:
# Google Colab setup: uncomment the lines below when running on Colab.
# !pip install torchmetrics
# !git clone https://github.com/Nhrot22230/DeepSampler
# !cd DeepSampler && make init

# import os
# import sys

# # Assuming the repository was cloned into /content/DeepSampler
# project_root = os.path.join(os.getcwd(), "DeepSampler")
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)

In [None]:
import os
import sys

# Walk upwards from the working directory until a folder containing "src" is
# found, so the notebook can be started from any sub-directory of the project.
project_root = os.getcwd()
while not os.path.isdir(os.path.join(project_root, "src")):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # dirname() no longer changes the path once the filesystem root is
        # reached; fail loudly instead of looping forever.
        raise FileNotFoundError(
            "Could not find a 'src' directory in the current working directory "
            "or any of its parents; run this notebook from inside the project."
        )
    project_root = parent_dir

# Make the `src` package importable; skip the append when the cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

import torch
import numpy as np
import matplotlib.pyplot as plt
import librosa

from src.pipelines.data import musdb_pipeline
from src.pipelines.train import train_pipeline
from src.pipelines.infer import infer_pipeline
from src.pipelines.eval import eval_pipeline
from src.models import DeepSampler
from src.utils.data.dataset import MUSDBDataset
from src.utils.train.losses import MultiSourceLoss

# Use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Audio parameters (forwarded as sample_rate / n_fft / hop_length / chunk
# duration / overlap to the pipelines below) ---
SR, NFFT, HOP = 44100, 2048, 512
CHUNK_DUR, OVERLAP = 2, 0

# --- Model parameters ---
N_SOURCES, BASE_CHANNELS, DEPTH = 4, 64, 4
DROP_RATE = 0.3
T_HEADS, T_LAYERS = 4, 4

# --- Dataset parameters: isolated-instrument stage ---
ISOLATED_BATCH_SIZE, ISOLATED_MAX_SAMPLES, ISOLATED_EPOCHS = 2, 400, 5

# --- Dataset parameters: mixed stage ---
MIXED_BATCH_SIZE, MIXED_MAX_SAMPLES, MIXED_EPOCHS = 8, 4000, 100

# --- Inference parameters ---
N_ITER = 128

# --- Data locations, all relative to the project root ---
data_path = os.path.join(project_root, "data")
musdb_path = os.path.join(data_path, "musdb18hq")
train_path, test_path = (
    os.path.join(musdb_path, split) for split in ("train", "test")
)
output_path = os.path.join(data_path, "processed")

# --- Experiment artefact locations ---
experiments_path = os.path.join(project_root, "experiments")
checkpoint_path, results_path, log_path = (
    os.path.join(experiments_path, sub) for sub in ("checkpoints", "results", "logs")
)

# Alphabetically ordered entries of the training split.
train_files = sorted(os.listdir(train_path))

instruments = ["vocals", "drums", "bass", "other"]

# Shape derived from the settings above; presumably the shape of one batch of
# target spectrograms (batch, sources, bins, frames) — TODO confirm.
n_bins = NFFT // 2 + 1
n_frames = CHUNK_DUR * SR // HOP + 1
calculated_shape = (MIXED_BATCH_SIZE, N_SOURCES, n_bins, n_frames)
calculated_shape

In [None]:
# Network hyper-parameters, all taken from the configuration cell.
model_kwargs = {
    "out_ch": N_SOURCES,
    "base_ch": BASE_CHANNELS,
    "depth": DEPTH,
    "dropout": DROP_RATE,
    "t_heads": T_HEADS,
    "t_layers": T_LAYERS,
}
model = DeepSampler(**model_kwargs)
model.to(device)

# Loss built with a weight of 1 for each of the four entries.
criterion = MultiSourceLoss(weights=[1, 1, 1, 1])

## Training

In [None]:
# Build one dataset per instrument and pool its chunks immediately, so only a
# single per-instrument dataset is alive at any time before it is released.
combined_data = []
for inst in instruments:
    inst_dataset = musdb_pipeline(
        musdb_path=train_path,
        isolated=[inst],
        sample_rate=SR,
        n_fft=NFFT,
        hop_length=HOP,
        chunk_duration=CHUNK_DUR,
        overlap=OVERLAP,
        max_chunks=ISOLATED_MAX_SAMPLES,
    )
    combined_data.extend(inst_dataset.data)
    del inst_dataset

# Wrap the pooled chunks in a single dataset, then drop the temporary list.
combined_dataset = MUSDBDataset(data=combined_data, n_fft=NFFT, hop_length=HOP)
del combined_data

In [None]:
# Dataset for the mixed stage: same audio settings as the isolated stage, but
# without an `isolated` filter and with max_chunks=MIXED_MAX_SAMPLES.
mixed_pipeline_args = dict(
    musdb_path=train_path,
    sample_rate=SR,
    n_fft=NFFT,
    hop_length=HOP,
    chunk_duration=CHUNK_DUR,
    overlap=OVERLAP,
    max_chunks=MIXED_MAX_SAMPLES,
)
train_dataset = musdb_pipeline(**mixed_pipeline_args)

In [None]:
# Shuffled loader over the pooled isolated-instrument chunks.
# Uses the batch size configured for the isolated stage: the original passed
# MIXED_BATCH_SIZE here, which left ISOLATED_BATCH_SIZE unused.
# NOTE(review): confirm ISOLATED_BATCH_SIZE holds the intended value.
isolated_dataloader = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=ISOLATED_BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Stage 1: train on the isolated-instrument chunks with AdamW at lr=3e-3.
isolated_optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
history = train_pipeline(
    model=model,
    criterion=criterion,
    optimizer=isolated_optimizer,
    dataloader=isolated_dataloader,
    epochs=ISOLATED_EPOCHS,
    device=device,
)

# The isolated data is not used after this stage; drop the references.
del isolated_dataloader, combined_dataset

In [None]:
# Shuffled loader over the mixed training dataset.
loader_options = {"batch_size": MIXED_BATCH_SIZE, "shuffle": True}
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

In [None]:
# Stage 2: continue training the same model on the mixed dataset, using AdamW
# with its default settings and checkpointing every MIXED_EPOCHS // 10 epochs.
# This overwrites the `history` returned by the isolated stage.
mixed_optimizer = torch.optim.AdamW(model.parameters())
checkpoint_interval = MIXED_EPOCHS // 10
history = train_pipeline(
    model=model,
    criterion=criterion,
    optimizer=mixed_optimizer,
    dataloader=train_loader,
    epochs=MIXED_EPOCHS,
    checkpoint_name="dino",
    checkpoint_dir=checkpoint_path,
    checkpoint_every=checkpoint_interval,
    device=device,
)

# Training data is not needed beyond this point; drop the references.
del train_loader, train_dataset

In [None]:
# `history` is the object returned by the most recent train_pipeline call.
# This cell reads its "loss" and "learning_rate" entries and plots each one
# against its index (presumably one value per epoch — TODO confirm against
# train_pipeline).
fig, axs = plt.subplots(2, 1, figsize=(10, 8), constrained_layout=True)
loss_ax, lr_ax = axs

# Top panel: loss curve.
loss_ax.plot(history["loss"], marker="o", color="blue", label="Epoch Loss")
loss_ax.set_title("Loss por Época", fontsize=14)
loss_ax.set_ylabel("Loss", fontsize=12)

# Bottom panel: learning-rate curve.
lr_ax.plot(
    history["learning_rate"],
    marker="s",
    linestyle="--",
    color="green",
    label="Learning Rate",
)
lr_ax.set_title("Tasa de Aprendizaje", fontsize=14)
lr_ax.set_ylabel("Learning Rate", fontsize=12)

# Decorations shared by both panels.
for panel in axs:
    panel.set_xlabel("Época", fontsize=12)
    panel.legend()
    panel.grid(True)
plt.show()

## Testing

In [None]:
import torchaudio

# Checkpoint to evaluate.
# NOTE(review): the training cell saves with checkpoint_name="dino" and runs
# MIXED_EPOCHS epochs, while this cell loads "DINOs" at epoch 130 — confirm the
# file below exists, or align the two cells.
epoch = 130
model_name = "DINOs"
checkpoint_file = os.path.join(checkpoint_path, f"{model_name}_epoch_{epoch:03d}.pth")

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_file, map_location=device)

# Strip only a *leading* "module." (presumably added by a DataParallel-style
# wrapper at save time). A plain str.replace would also mangle keys that merely
# contain the substring, e.g. "submodule.weight".
wrapper_prefix = "module."
new_state_dict = {
    (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
    for key, value in checkpoint.items()
}
model.load_state_dict(new_state_dict)

In [None]:
# Candidate tracks: only entries of the test split that actually contain a
# mixture.wav, so stray files (e.g. hidden OS files) can never be selected.
# Sorted so a seeded run picks the same track regardless of listing order.
test_folders = sorted(
    name
    for name in os.listdir(test_path)
    if os.path.isfile(os.path.join(test_path, name, "mixture.wav"))
)
if not test_folders:
    raise FileNotFoundError(f"No track folder with a mixture.wav found in {test_path}")

# Convert the NumPy string returned by np.random.choice to a plain str.
random_folder = str(np.random.choice(test_folders))
audio_mixture = os.path.join(test_path, random_folder, "mixture.wav")

In [None]:
# Separate the selected mixture with the loaded model. The result is indexed
# by instrument name in the cells that follow.
separation_args = dict(
    sample_rate=SR,
    chunk_seconds=CHUNK_DUR,
    overlap=OVERLAP,
    n_iter=N_ITER,
    n_fft=NFFT,
    hop_length=HOP,
    device=device,
)
extracted_sources = infer_pipeline(model=model, mixture=audio_mixture, **separation_args)

In [None]:
# Overlay each reference stem with the separated estimate, one panel per
# instrument, and save the whole figure as a single PNG.
# Relies on instruments, test_path, random_folder, SR, extracted_sources,
# results_path, model_name and epoch from earlier cells.
fig, axs = plt.subplots(len(instruments), 1, figsize=(20, 5 * len(instruments)), constrained_layout=True)
# plt.subplots returns a bare Axes (not an array) when there is a single row.
axs = np.atleast_1d(axs)
for i, inst in enumerate(instruments):
    # Load the reference stem and down-mix it to mono by averaging channels.
    # The mean over dim 0 already yields a 1-D (samples,) tensor; indexing it
    # with [0] again (as the original did) would leave a single scalar.
    orig_file_path = os.path.join(test_path, random_folder, f"{inst}.wav")
    wav, _ = torchaudio.load(orig_file_path, normalize=True)
    wav = wav.mean(dim=0).cpu().numpy()

    # Separated estimate; keep the first channel if it has more than one.
    pred = extracted_sources[inst]
    if pred.ndim > 1:
        pred = pred[0]
    pred_np = pred.detach().cpu().numpy()

    # Time axes in seconds: sample k sits at k / SR.
    t_orig = np.arange(len(wav)) / SR
    t_pred = np.arange(len(pred_np)) / SR

    # Draw both signals on the same axes to ease comparison.
    axs[i].plot(t_orig, wav, label='Original', color="tab:blue", linewidth=1)
    axs[i].plot(t_pred, pred_np, label='Extraída', color="tab:orange", linewidth=1, alpha=0.8)
    axs[i].set_title(f"{inst.capitalize()} - Comparación de Waveforms", fontsize=20)
    axs[i].set_xlabel("Tiempo (s)", fontsize=16)
    axs[i].set_ylabel("Amplitud", fontsize=16)
    # Without this call the line labels above are never shown.
    axs[i].legend(loc="upper right")
    axs[i].grid(True)

    # Release the per-instrument arrays before the next iteration.
    del wav, pred, pred_np

# Save the complete figure as one PNG, creating the results folder if needed.
os.makedirs(results_path, exist_ok=True)
output_file = os.path.join(results_path, f"{model_name}_results_{epoch:03d}.png")
fig.savefig(output_file)
plt.close(fig)

In [None]:
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import SignalNoiseRatio, ScaleInvariantSignalNoiseRatio

sisdr = ScaleInvariantSignalDistortionRatio().to(device)
sdr = SignalDistortionRatio().to(device)
snr = SignalNoiseRatio().to(device)
sisnr = ScaleInvariantSignalNoiseRatio().to(device)

# Printed label -> metric, in reporting order. Pairing them here keeps every
# label next to the metric that produces its value.
metrics = (("SISDR", sisdr), ("SDR", sdr), ("SNR", snr), ("SISNR", sisnr))

for inst in instruments:
    orig_file_path = os.path.join(test_path, random_folder, f"{inst}.wav")
    wav, _ = torchaudio.load(orig_file_path, normalize=True)

    # Down-mix to mono but keep the channel axis, so the reference is
    # (1, samples) and has the same rank as the prediction below.
    wav = wav.mean(dim=0, keepdim=True).to(device)

    # Bring the prediction to (1, samples) whatever its rank: flatten leading
    # axes and keep the first row (i.e. the first channel).
    pred = extracted_sources[inst].to(device)
    pred = pred.reshape(-1, pred.shape[-1])[:1]

    # Trim both signals to a common length; the metrics need equal shapes.
    min_len = min(wav.shape[-1], pred.shape[-1])
    wav = wav[..., :min_len]
    pred = pred[..., :min_len]

    print(f"Metrics for {inst.capitalize()}:")
    for label, metric in metrics:
        print(f"{label}: {metric(pred, wav).item():.4f}")


In [None]:
# Write every separated source to <results_path>/<instrument>.wav.
# Create the output folder first: nothing earlier in the notebook is shown to
# create it, and saving into a missing folder fails.
os.makedirs(results_path, exist_ok=True)

for inst in instruments:
    file_path = os.path.join(results_path, f"{inst}.wav")
    print(f"Saving {inst} to {file_path}")

    waveform = extracted_sources[inst]
    # Ensure the waveform is 2-D: a 1-D signal gets a leading channel axis.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    # Detach from any autograd graph, cast to float32 and move to the CPU
    # before handing the tensor to torchaudio.
    waveform = waveform.detach().to(torch.float32).cpu()
    torchaudio.save(file_path, waveform, SR)

In [None]:
# Run the evaluation pipeline over the whole test split with the same audio
# settings used for training and inference. The call is the last expression of
# the cell, so its return value is displayed.
evaluation_args = {
    "dataset_path": test_path,
    "sample_rate": SR,
    "chunk_seconds": CHUNK_DUR,
    "overlap": OVERLAP,
    "n_fft": NFFT,
    "hop_length": HOP,
    "device": device,
}
eval_pipeline(model=model, **evaluation_args)