In [None]:
# working with colab
# !pip install torchmetrics
# !git clone https://github.com/Nhrot22230/DeepSampler
#  !cd DeepSampler && make init

# import os
# import sys

# # Assuming you cloned the repository into /content/DeepSampler
# project_root = os.path.join(os.getcwd(), "DeepSampler")
# if project_root not in sys.path:
#     sys.path.insert(0, project_root)

In [None]:
import os
import sys

# Locate the project root: walk upward from the current working directory
# until a folder that contains "src" is found, then make it importable so
# the `src.*` imports below resolve.
project_root = os.getcwd()
while "src" not in os.listdir(project_root):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # dirname() of the filesystem root returns the root itself; without
        # this guard the loop would spin forever when no "src" folder exists.
        raise FileNotFoundError(
            "Could not find a directory containing 'src' at or above "
            + os.getcwd()
        )
    project_root = parent_dir

# Avoid stacking duplicate entries when the cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

import torch
import numpy as np
import matplotlib.pyplot as plt
import librosa

from src.pipelines.data import musdb_pipeline
from src.pipelines.train import train_pipeline
from src.pipelines.infer import infer_pipeline
from src.pipelines.eval import eval_pipeline
from src.models import DeepSampler
from src.utils.data.dataset import MUSDBDataset
from src.utils.train.losses import MultiSourceLoss

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Audio parameters (forwarded to the data, eval and inference pipelines)
SR = 44100  # passed as sample_rate
NFFT = 512  # passed as n_fft; calculated_shape uses NFFT // 2 + 1 frequency bins
# NOTE(review): HOP is larger than NFFT, so consecutive analysis frames would
# not overlap. This looks unusual for an STFT setup -- confirm it is intended.
HOP = 1024  # passed as hop_length
CHUNK_DUR = 2  # passed as chunk_duration / chunk_seconds (presumably seconds)
OVERLAP = 0  # passed as overlap

# Model parameters (forwarded to the DeepSampler constructor)
N_SOURCES = 4  # out_ch; also the source axis of calculated_shape
BASE_CHANNELS = 32  # base_ch
DEPTH = 4  # depth
DROP_RATE = 0.3  # used for both dropout and t_dropout
T_HEADS = 4  # t_heads
T_LAYERS = 2  # t_layers

# Dataset parameters for the isolated-instrument phase
ISOLATED_BATCH_SIZE = 2  # presumably the batch size for this phase -- confirm
ISOLATED_MAX_SAMPLES = 300  # passed as max_chunks for each instrument
ISOLATED_EPOCHS = 5  # passed as epochs

# Dataset parameters for the mixed (full-mixture) phase
MIXED_BATCH_SIZE = 4  # DataLoader batch size; also the first axis of calculated_shape
MIXED_MAX_SAMPLES = 1000  # passed as max_chunks
MIXED_EPOCHS = 100  # passed as epochs; also drives the checkpoint interval

# Inference parameters
N_ITER = 128  # passed as n_iter to infer_pipeline

# Data directories, all anchored at the project root.
data_path = os.path.join(project_root, "data")
musdb_path = os.path.join(data_path, "musdb18hq")
train_path = os.path.join(musdb_path, "train")
test_path = os.path.join(musdb_path, "test")
output_path = os.path.join(data_path, "processed")

# Experiment artefacts: checkpoints, results and logs live side by side.
experiments_path = os.path.join(project_root, "experiments")
checkpoint_path, results_path, log_path = (
    os.path.join(experiments_path, sub_dir)
    for sub_dir in ("checkpoints", "results", "logs")
)

# Alphabetically ordered entries of the training split.
train_files = sorted(os.listdir(train_path))

instruments = ["vocals", "drums", "bass", "other"]

# Shape computed from the configuration above:
# (batch, sources, frequency bins, time frames).
calculated_shape = (
    MIXED_BATCH_SIZE,
    N_SOURCES,
    1 + NFFT // 2,
    1 + (CHUNK_DUR * SR) // HOP,
)
calculated_shape

In [None]:
# Build the separation network from the configuration constants, then move
# it to the selected device.
model = DeepSampler(
    in_ch=1,
    out_ch=N_SOURCES,
    base_ch=BASE_CHANNELS,
    depth=DEPTH,
    dropout=DROP_RATE,
    t_dropout=DROP_RATE,
    t_heads=T_HEADS,
    t_layers=T_LAYERS,
)
model = model.to(device)

# Weighted multi-source loss with one weight per source.
# NOTE(review): the weights presumably follow the order of `instruments` -- confirm.
criterion = MultiSourceLoss(weights=[5, 4, 2, 3])
criterion = criterion.to(device)

## Training

In [None]:
# Build one dataset out of isolated-instrument chunks: run the MUSDB pipeline
# once per instrument and pool the resulting samples in a single list.
combined_data = []
for inst in instruments:
    inst_dataset = musdb_pipeline(
        musdb_path=train_path,
        isolated=[inst],
        sample_rate=SR,
        n_fft=NFFT,
        hop_length=HOP,
        chunk_duration=CHUNK_DUR,
        overlap=OVERLAP,
        max_chunks=ISOLATED_MAX_SAMPLES,
    )
    combined_data.extend(inst_dataset.data)
    # Drop the per-instrument wrapper right away; its samples are now
    # referenced from combined_data.
    del inst_dataset

combined_dataset = MUSDBDataset(data=combined_data, n_fft=NFFT, hop_length=HOP)
del combined_data

In [None]:
# Loader for the isolated-instrument phase. Uses the batch size configured
# for that phase (ISOLATED_BATCH_SIZE) rather than the mixed-phase one.
isolated_dataloader = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=ISOLATED_BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Phase 1: train on the isolated-instrument dataset for a few epochs.
isolated_optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
history = train_pipeline(
    model=model,
    criterion=criterion,
    optimizer=isolated_optimizer,
    dataloader=isolated_dataloader,
    epochs=ISOLATED_EPOCHS,
    device=device,
)

# Release everything that belongs only to this phase before the mixed
# dataset is loaded.
del isolated_optimizer
del isolated_dataloader, combined_dataset, history

In [None]:
# Phase 2 data: chunks from the training split (no `isolated` filter this
# time), with `save_dir` pointing at the processed-data folder.
train_dataset = musdb_pipeline(
    musdb_path=train_path,
    save_dir=output_path,
    max_chunks=MIXED_MAX_SAMPLES,
    sample_rate=SR,
    chunk_duration=CHUNK_DUR,
    overlap=OVERLAP,
    n_fft=NFFT,
    hop_length=HOP,
)

In [None]:
# Shuffled mini-batches over the mixed training dataset.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=MIXED_BATCH_SIZE,
)

In [None]:
# sample = next(iter(train_loader))
# if 'instruments' not in globals():
#     instruments = [f"Instrument {i+1}" for i in range(4)]
#
# fig = plt.figure(constrained_layout=True, figsize=(15, 12))
# gs = fig.add_gridspec(3, 2)
#
# ax_mixed = fig.add_subplot(gs[0, :])
# mixed_db = librosa.amplitude_to_db(sample[0][0][0].numpy(), ref=np.max)
# img_mixed = ax_mixed.imshow(mixed_db, aspect="auto", origin="lower", cmap='magma')
# ax_mixed.set_title("Mixed Audio", fontsize=16, fontweight='bold', color='navy')
# fig.colorbar(img_mixed, ax=ax_mixed, fraction=0.046, pad=0.04)
#
# for i in range(4):
#     row = 1 if i < 2 else 2
#     col = i % 2
#     ax_source = fig.add_subplot(gs[row, col])
#
#     source_db = librosa.amplitude_to_db(sample[1][0][i].numpy(), ref=np.max)
#     img_source = ax_source.imshow(source_db, aspect="auto", origin="lower", cmap='magma')
#     ax_source.set_title(f"{instruments[i]} Source", fontsize=14, color='darkred')
#     fig.colorbar(img_source, ax=ax_source, fraction=0.046, pad=0.04)
#
# plt.suptitle("Audio Sources Visualization", fontsize=18, fontweight='bold', color='teal')
#
# save_path = os.path.join(results_path, "audio_sources.png")
# plt.savefig(save_path)
# plt.show()

In [None]:
# Phase 2: train on the mixed dataset, writing periodic checkpoints.
# Make sure the checkpoint folder exists before training starts.
os.makedirs(checkpoint_path, exist_ok=True)

try:
    history = train_pipeline(
        model=model,
        criterion=criterion,
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3),
        dataloader=train_loader,
        epochs=MIXED_EPOCHS,
        checkpoint_name="dino",
        checkpoint_dir=checkpoint_path,
        # Roughly ten checkpoints per run; clamp to 1 so a short run
        # (MIXED_EPOCHS < 10) does not end up with an interval of 0.
        checkpoint_every=max(1, MIXED_EPOCHS // 10),
        device=device,
    )
except RuntimeError as error:
    # Report the failure (with a dedicated message and a cache flush for GPU
    # out-of-memory errors), then re-raise so the cell still fails visibly.
    if "out of memory" in str(error):
        print("Error: No se pudo asignar memoria en la GPU. Liberando memoria...")
        torch.cuda.empty_cache()
    else:
        print("Se produjo un error:", error)
    raise

In [None]:
# Plot the per-epoch training loss and save the figure in the results folder.
# The pyplot module has no set_title/set_xlabel/set_ylabel functions; those
# are Axes methods, so the plot is drawn through an explicit Figure/Axes pair.
os.makedirs(results_path, exist_ok=True)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(history["loss"], marker="o", color="blue", label="Epoch Loss")
ax.set_title("Loss por Época", fontsize=14)
ax.set_xlabel("Época", fontsize=12)
ax.set_ylabel("Loss", fontsize=12)
ax.legend()
ax.grid(True)

fig.suptitle("Entrenamiento del Modelo", fontsize=16, fontweight="bold")
fig.savefig(os.path.join(results_path, "training.png"))

## Testing

In [None]:
import torchaudio

# Restore the weights saved at the chosen epoch of the "dino" run.
epoch = 60
checkpoint = torch.load(
    os.path.join(checkpoint_path, f"dino_epoch_{epoch:03d}.pth"),
    # Remap tensors to the active device so a checkpoint written on a GPU
    # machine can still be loaded on a CPU-only host.
    map_location=device,
)

# Strip a leading "module." from parameter names (the prefix added by wrappers
# such as torch.nn.DataParallel). Only the prefix is removed, so keys that
# merely contain "module." elsewhere (e.g. "submodule.weight") stay intact.
# NOTE(review): assumes the file holds a plain state dict -- confirm.
dp_prefix = "module."
new_state_dict = {
    (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
    for key, value in checkpoint.items()
}

# Switch to evaluation mode (disables dropout) for the evaluation and
# inference cells below; harmless if the pipelines already do this.
model.eval()
model.load_state_dict(new_state_dict)

In [None]:
# Evaluate the restored model on the test split; the value returned by the
# pipeline is shown as the cell output.
eval_pipeline(
    model=model,
    device=device,
    dataset_path=test_path,
    sample_rate=SR,
    n_fft=NFFT,
    hop_length=HOP,
    chunk_seconds=CHUNK_DUR,
    overlap=OVERLAP,
)

In [None]:
# Pick a random test track and separate its mixture into the four sources.
# Only entries that really contain a mixture.wav are candidates, so stray
# files in the split folder (hidden files, archives, ...) cannot be selected.
track_dirs = sorted(
    entry
    for entry in os.listdir(test_path)
    if os.path.isfile(os.path.join(test_path, entry, "mixture.wav"))
)
if not track_dirs:
    raise FileNotFoundError(
        f"No track folder containing mixture.wav was found in {test_path}"
    )

# Convert the NumPy string scalar to a plain str before reusing it in paths.
random_folder = str(np.random.choice(track_dirs))
audio_mixture = os.path.join(test_path, random_folder, "mixture.wav")

extracted_sources = infer_pipeline(
    model=model,
    mixture=audio_mixture,
    sample_rate=SR,
    chunk_seconds=CHUNK_DUR,
    overlap=OVERLAP,
    n_iter=N_ITER,
    n_fft=NFFT,
    hop_length=HOP,
    device=device,
)

In [None]:
# Side-by-side comparison: one row per instrument, with the reference
# waveform on the left and the model's extracted waveform on the right.
fig, axs = plt.subplots(len(instruments), 2, figsize=(20, 16), constrained_layout=True)

for i, inst in enumerate(instruments):
    # Load the original waveform from file.
    orig_file_path = os.path.join(test_path, random_folder, f"{inst}.wav")
    wav, _ = librosa.load(orig_file_path, sr=SR)

    # Plot the original waveform.
    axs[i, 0].plot(wav, color="tab:blue")
    axs[i, 0].set_title(f"{inst.capitalize()} (Original)", fontsize=16)
    axs[i, 0].set_xlabel("Samples", fontsize=14)
    axs[i, 0].set_ylabel("Amplitude", fontsize=14)
    axs[i, 0].grid(True)

    # Get the extracted waveform.
    pred = extracted_sources[inst]
    # If waveform has more than one channel, use the first channel for plotting.
    if pred.ndim > 1:
        pred = pred[0]
    # detach() first: .numpy() raises on tensors that still track gradients.
    pred_np = pred.detach().cpu().numpy()

    # Plot the extracted waveform.
    axs[i, 1].plot(pred_np, color="tab:orange")
    axs[i, 1].set_title(f"{inst.capitalize()} (Extracted)", fontsize=16)
    axs[i, 1].set_xlabel("Samples", fontsize=14)
    axs[i, 1].set_ylabel("Amplitude", fontsize=14)
    axs[i, 1].grid(True)

    # Free memory for this loop iteration.
    del wav, pred, pred_np

# Save the complete figure as one large PNG file under the results folder.
output_file = os.path.join(results_path, "dino_sampler", "combined_results.png")
# savefig does not create missing folders, so make sure the target exists.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
fig.savefig(output_file, bbox_inches="tight", dpi=300)
plt.close(fig)

In [None]:
# Write every extracted source to <results>/dino_sampler/<instrument>.wav.
sources_dir = os.path.join(results_path, "dino_sampler")
# Create the output folder up front; saving into a missing folder would fail.
os.makedirs(sources_dir, exist_ok=True)

for inst in instruments:
    # Save waveform
    file_path = os.path.join(sources_dir, f"{inst}.wav")
    print(f"Saving {inst} to {file_path}")

    waveform = extracted_sources[inst]
    # Ensure waveform is 2D: if it's 1D, add a channel dimension.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)
    # Convert to float32 on the CPU, detached from any autograd graph,
    # before handing the tensor to torchaudio.
    waveform = waveform.detach().to(torch.float32).cpu()
    torchaudio.save(file_path, waveform, SR)