In [None]:
import os
import sys

# Walk upward from the current working directory until a folder that contains
# a "src" directory is found, so the `src` package can be imported regardless
# of which subdirectory the notebook is launched from.
project_root = os.getcwd()
while not os.path.isdir(os.path.join(project_root, "src")):
    parent_dir = os.path.dirname(project_root)
    # dirname() of the filesystem root returns the root itself; without this
    # check the loop would spin forever when no "src" directory exists.
    if parent_dir == project_root:
        raise FileNotFoundError(
            f"No 'src' directory found in {os.getcwd()!r} or any of its parents."
        )
    project_root = parent_dir

# Guard the append so re-running this cell does not pile up duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from IPython.display import Audio
from src.utils.audio import load_audio
from src.pipelines.data import process_audio_folder
from src.pipelines import musdb_pipeline, train_pipeline
from src.models import DeepSampler
from src.utils.training.loss import MultiSourceLoss

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Signal-processing settings; later cells pass them on as `sample_rate`,
# `n_fft` and `hop_length` respectively.
SR = 44100
NFFT = 2048
HOP = 512

# Dataset locations, all resolved relative to the project root.
data_path = os.path.join(project_root, "data")
output_path = os.path.join(data_path, "processed")
musdb_path = os.path.join(data_path, "musdb18hq", "train")

# Alphabetically ordered entries of the MUSDB training folder.
musdb_files = sorted(os.listdir(musdb_path))

In [None]:
# Pick one entry of the training folder at random, load its mixture file and
# split the same folder into chunks (chunk_duration=2, overlap=0).
random_sample = np.random.choice(musdb_files)
sample_folder = os.path.join(musdb_path, random_sample)

mixture = load_audio(os.path.join(sample_folder, "mixture.wav"))
chunks = process_audio_folder(
    audio_folder=sample_folder,
    sample_rate=SR,
    chunk_duration=2,
    overlap=0,
)

In [None]:
# Audio player for the full mixture of the randomly selected sample.
# NOTE(review): assumes load_audio returned audio sampled at SR — confirm.
Audio(data=mixture, rate=SR)

In [None]:
# Audio player for the mixture stored in the fourth chunk (index 3).
fourth_chunk = chunks[3]
Audio(fourth_chunk['mixture'], rate=SR)

In [None]:
# Options for building the training dataset from the MUSDB training folder,
# gathered in one place so they are easy to inspect and tweak.
dataset_options = dict(
    musdb_path=musdb_path,
    sample_rate=SR,
    n_fft=NFFT,
    hop_length=HOP,
    chunk_duration=2,
    overlap=0,
    save_dir=output_path,
    max_chunks=400,
)
train_dataset = musdb_pipeline(**dataset_options)

In [None]:
# Instantiate the network and move it to the selected device.
model = DeepSampler()
model.to(device)

# Loss with four equal weights, and Adam with its default hyper-parameters.
criterion = MultiSourceLoss(weights=[1] * 4)
optimizer = torch.optim.Adam(params=model.parameters())

# Shuffled batches of 8 dataset items, loaded with 4 worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Batch layout, as indexed below:
#   example_item[0][BATCH][CHANNEL] -> mixture spectrogram
#   example_item[1][BATCH]          -> [Vocals, Drums, Bass, Other] spectrograms
example_item = next(iter(train_loader))
mixture_spectrogram = example_item[0][0][0]  # first item, first channel
source_spectrograms = example_item[1][0]  # separated sources of the first item
source_names = ['Vocals', 'Drums', 'Bass', 'Other']

# One row of five panels: the mixture followed by the four sources.
fig, axes = plt.subplots(1, 5, figsize=(20, 5))

mixture_ax = axes[0]
mixture_ax.imshow(mixture_spectrogram, aspect='auto', origin='lower')
mixture_ax.set(title='Mixture', xlabel='Time', ylabel='Frequency', xticks=[], yticks=[])

for i, source_name in enumerate(source_names):
    source_ax = axes[i + 1]
    source_ax.imshow(source_spectrograms[i], aspect='auto', origin='lower')
    # Only the mixture panel carries the y-axis label.
    source_ax.set(title=source_name, xlabel='Time', xticks=[], yticks=[])

plt.tight_layout()
plt.show()

In [None]:
# Training settings: 2 epochs in total, with `phase1_epochs` set to 1.
training_options = dict(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    dataloader=train_loader,
    epochs=2,
    phase1_epochs=1,
    device=device,
)
train_pipeline(**training_options)