In [None]:
import os
import sys

# Walk up from the current working directory until we reach the directory that
# contains the project's `src` package, then make that directory importable.
project_root = os.getcwd()
while not os.path.isdir(os.path.join(project_root, "src")):
    parent = os.path.dirname(project_root)
    if parent == project_root:
        # dirname() of the filesystem root is the root itself: stop instead of
        # looping forever when no `src` directory exists above the cwd.
        raise FileNotFoundError(
            f"Could not find a 'src' directory in {os.getcwd()} or any of its parents"
        )
    project_root = parent

# Guard against duplicates so re-running this cell does not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import librosa

from IPython.display import Audio
from src.utils.audio import load_audio
from src.pipelines.data import process_audio_folder
from src.pipelines import musdb_pipeline, train_pipeline, infer_pipeline
from src.models import DeepSampler
from src.utils.training.loss import MultiSourceLoss

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# --- Audio / STFT configuration (shared by data processing, training and inference) ---
SR = 44100  # sample rate in Hz
NFFT = 2048  # FFT size passed as `n_fft`
HOP = 512  # STFT hop length in samples
CHUNK_DUR = 2  # chunk length in seconds (passed as `chunk_seconds` at inference)
OVERLAP = 0  # overlap between consecutive chunks; 0 = none (units per the pipelines — TODO confirm)
SEED = 42  # seeds numpy (random track picks) and torch's global RNG (e.g. weight init, shuffling)

np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Paths, all derived from the detected project root ---
data_path = os.path.join(project_root, "data")
musdb_path = os.path.join(data_path, "musdb18hq")
train_path = os.path.join(musdb_path, "train")
test_path = os.path.join(musdb_path, "test")
output_path = os.path.join(data_path, "processed")

# Keep only track directories (skips stray files such as .DS_Store) and sort them
# so that random picks are reproducible under SEED.
train_files = sorted(
    name for name in os.listdir(train_path) if os.path.isdir(os.path.join(train_path, name))
)

In [None]:
# Pick one training track at random, load its full mixture for listening, and
# split the track's folder into fixed-length chunks with the shared settings.
# NOTE(review): the playback cell assumes `load_audio` returns audio at SR — confirm.
random_sample = np.random.choice(train_files)
sample_dir = os.path.join(train_path, random_sample)

mixture = load_audio(os.path.join(sample_dir, "mixture.wav"))
chunks = process_audio_folder(
    audio_folder=sample_dir,
    sample_rate=SR,
    chunk_duration=CHUNK_DUR,
    overlap=OVERLAP,
)

In [None]:
# Play the full mixture of the selected training track.
# NOTE(review): assumes `mixture` is a waveform sampled at SR — confirm load_audio's output.
Audio(data=mixture, rate=SR)

In [None]:
# Play a single chunk's mixture to sanity-check the chunking. Short tracks may
# yield fewer chunks than PREVIEW_CHUNK + 1, so fall back to the last chunk.
PREVIEW_CHUNK = 3
assert len(chunks) > 0, "process_audio_folder returned no chunks"
preview_index = min(PREVIEW_CHUNK, len(chunks) - 1)
Audio(chunks[preview_index]["mixture"], rate=SR)

In [None]:
# Build the training dataset from the tracks under `train_path` using the shared
# audio/STFT settings.
# NOTE(review): `save_dir` suggests processed chunks are written to `output_path`,
# and `max_chunks` presumably caps the dataset size — confirm in musdb_pipeline.
MAX_TRAIN_CHUNKS = 5000
train_dataset = musdb_pipeline(
    musdb_path=train_path,
    save_dir=output_path,
    max_chunks=MAX_TRAIN_CHUNKS,
    sample_rate=SR,
    chunk_duration=CHUNK_DUR,
    overlap=OVERLAP,
    n_fft=NFFT,
    hop_length=HOP,
)

In [None]:
# Model, loss, optimizer and shuffled data loader used for training.
BATCH_SIZE = 8
LOSS_WEIGHTS = [1, 1, 1, 1]  # one weight per separated source (four sources)

model = DeepSampler()
model.to(device)
criterion = MultiSourceLoss(weights=LOSS_WEIGHTS)
optimizer = torch.optim.Adam(model.parameters())  # PyTorch default hyper-parameters (lr=1e-3)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# Inspect one batch from the loader. Layout of each batch:
#   example_item[0][batch][channel] -> mixture spectrogram
#   example_item[1][batch][source]  -> source spectrograms, ordered [vocals, drums, bass, other]
example_item = next(iter(train_loader))
mixture_spectrogram = example_item[0][0][0]  # first item in the batch, first channel
source_spectrograms = example_item[1][0]  # first item in the batch: [Vocals, Drums, Bass, Other]
source_names = ["Vocals", "Drums", "Bass", "Other"]

# One panel per spectrogram; zip() stops early if fewer sources are returned
# instead of raising IndexError.
panels = [("Mixture", mixture_spectrogram)] + list(zip(source_names, source_spectrograms))

# squeeze=False keeps `axes` indexable even when there is only one panel.
fig, axes = plt.subplots(1, len(panels), figsize=(20, 5), squeeze=False)
axes = axes[0]
for ax, (title, spectrogram) in zip(axes, panels):
    ax.imshow(spectrogram, aspect="auto", origin="lower")
    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_xticks([])
    ax.set_yticks([])
axes[0].set_ylabel("Frequency")  # panels share the frequency axis; label it once
fig.tight_layout()
plt.show()

In [None]:
# Train the model for EPOCHS passes over the training loader; `history` keeps the
# records returned by train_pipeline for plotting.
EPOCHS = 1
history = train_pipeline(
    model=model,
    dataloader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=EPOCHS,
)

In [None]:
# Plot the training curves recorded by `train_pipeline`. The history dict is
# expected to look like {"epoch_loss": [], "learning_rate": [], "batch_losses": []}
# — TODO confirm against train_pipeline. Missing or empty entries are skipped.
history_panels = [
    ("epoch_loss", "Loss", "Epoch"),
    ("learning_rate", "Learning Rate", None),  # recording granularity unknown
    ("batch_losses", "Batch Losses", "Batch"),
]
for key, title, xlabel in history_panels:
    values = history.get(key) if isinstance(history, dict) else None
    if values is None or len(values) == 0:
        continue
    if key == "batch_losses":
        # batch_losses may be nested per epoch, and epochs can have different batch
        # counts; concatenating avoids np.array() failing on ragged lists.
        values = np.concatenate([np.ravel(v) for v in values])
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    plt.show()

In [None]:
# Pick one test track at random. Keeping only directories avoids landing on a
# stray file (e.g. .DS_Store), and sorting makes the pick reproducible under the seed.
test_tracks = sorted(
    name for name in os.listdir(test_path) if os.path.isdir(os.path.join(test_path, name))
)
random_folder = np.random.choice(test_tracks)
audio_mixture = os.path.join(test_path, random_folder, "mixture.wav")

In [None]:
# Separate the selected test mixture with the trained model, reusing the same
# audio/STFT settings as training. `sources` is indexed by instrument name below.
sources = infer_pipeline(
    model=model,
    device=device,
    mixture_path=audio_mixture,
    sample_rate=SR,
    n_fft=NFFT,
    hop_length=HOP,
    chunk_seconds=CHUNK_DUR,
    overlap=OVERLAP,
)

In [None]:
# Ground-truth stems of the selected test track, one figure per instrument.
inst = ["vocals", "drums", "bass", "other"]

for instrument in inst:
    file_path = os.path.join(test_path, random_folder, f"{instrument}.wav")
    wav, _ = librosa.load(file_path, sr=SR)  # librosa downmixes to mono by default
    time_s = np.arange(len(wav)) / SR
    fig, ax = plt.subplots(figsize=(20, 6))
    ax.plot(time_s, wav)
    ax.set(title=instrument, xlabel="Time (s)", ylabel="Amplitude")
    plt.show()

In [None]:
# Model estimates for each stem, one figure per instrument, x-axis in seconds.
for instrument in inst:
    estimate = np.asarray(sources[instrument])
    # NOTE(review): if infer_pipeline returns channel-first audio (channels, samples),
    # matplotlib would draw one line per sample; transpose so each channel becomes
    # one curve. TODO confirm the output layout of infer_pipeline.
    if estimate.ndim == 2 and estimate.shape[0] < estimate.shape[1]:
        estimate = estimate.T
    time_s = np.arange(estimate.shape[0]) / SR
    fig, ax = plt.subplots(figsize=(20, 6))
    ax.plot(time_s, estimate)
    ax.set(title=instrument, xlabel="Time (s)", ylabel="Amplitude")
    plt.show()