In [None]:
from pathlib import Path

import yaml
from asteroid.data import MUSDB18Dataset
from scipy.io import wavfile
from torch.utils.data import DataLoader

#####################
##### ARGS ##########
#####################
DATA_DIR = Path("musdb_data")
CFG_PATH = Path("cfg.yaml")

# safe_load is PyYAML's recommended loader for plain settings files: it only
# constructs standard YAML types (mappings, lists, strings, numbers, ...).
with CFG_PATH.open("r", encoding="utf-8") as cfg_file:
    CFG = yaml.safe_load(cfg_file)
if not isinstance(CFG, dict):
    # An empty file loads as None; fail here with a readable message instead of
    # a "'NoneType' object is not subscriptable" on the first lookup below.
    raise ValueError(
        f"{CFG_PATH} must contain a YAML mapping of settings, got {type(CFG).__name__}"
    )

SEGMENT_SIZE = CFG["segment_size"]
RANDOM_TRACK_MIX = CFG["random_track_mix"]
TARGETS = CFG["targets"]
N_SRC = len(TARGETS)  # one output source per configured target

#####################
##### HYPER-PARAMETERS
#####################
SAMPLE_RATE = CFG["sample_rate"]
# A value of -1 in cfg.yaml is translated to None before being passed as
# `size=` to MUSDB18Dataset.
SIZE = None if CFG["size"] == -1 else CFG["size"]
LR = CFG["learning_rate"]
N_EPOCHS = CFG["n_epochs"]
BATCH_SIZE = CFG["batch_size"]

# Model settings, read verbatim from cfg.yaml.
N_BLOCKS = CFG["n_blocks"]
N_REPEATS = CFG["n_repeats"]
BN_CHAN = CFG["bn_chan"]
HID_CHAN = CFG["hid_chan"]
SKIP_CHAN = CFG["skip_chan"]
CONV_KERNEL_SIZE = CFG["conv_kernel_size"]
KERNEL_SIZE = CFG["kernel_size"]
N_FILTERS = CFG["n_filters"]
STRIDE = CFG["stride"]

################
##### DATA #####
################
def _make_musdb_dataset(split, random_segments, random_track_mix):
    """Build a MUSDB18Dataset for one split using the settings from the config cell.

    Args:
        split: split name forwarded to MUSDB18Dataset ("train" or "test").
        random_segments: forwarded unchanged to MUSDB18Dataset.
        random_track_mix: forwarded unchanged to MUSDB18Dataset.

    Returns:
        The constructed MUSDB18Dataset. Everything except the three arguments
        above is shared between splits, so train and test cannot drift apart.
    """
    return MUSDB18Dataset(
        root=str(DATA_DIR),
        targets=TARGETS,
        suffix=".mp4",
        split=split,
        subset=None,
        segment=SEGMENT_SIZE,
        samples_per_track=1,
        random_segments=random_segments,
        random_track_mix=random_track_mix,
        sample_rate=SAMPLE_RATE,
        size=SIZE,
    )


# Training: random excerpts, optionally remixed across tracks (augmentation).
train_dataset = _make_musdb_dataset(
    "train", random_segments=True, random_track_mix=RANDOM_TRACK_MIX
)
# NOTE(review): BATCH_SIZE from cfg.yaml is not used here; loaders yield one
# sample per batch.
train_loader = DataLoader(train_dataset, batch_size=1)
print(">>> Training Dataloader ready")

# Test: keep the real MUSDB18 test mixtures and a fixed excerpt position.
# random_track_mix would combine stems from different songs, and
# random_segments would draw a different excerpt on every pass, making test
# results neither representative nor reproducible.
test_dataset = _make_musdb_dataset(
    "test", random_segments=False, random_track_mix=False
)
test_loader = DataLoader(test_dataset, batch_size=1)

In [None]:
# Two independent, empty accumulators: one for mixtures, one for their sources.
list_mixes, list_sources = [], []

In [None]:
# Materialise one full pass over the training loader in memory. The lists are
# (re)created here so re-running this cell does not append a second copy of
# the dataset. Training items use random_segments=True, so every run of this
# cell draws new excerpts.
# The loop variables `mix` / `sources` stay bound to the last batch and are
# inspected by the next cells.
list_mixes = []
list_sources = []
for mix, sources in train_loader:
    list_mixes.append(mix)
    list_sources.append(sources)

In [None]:
# Shape of the last training mixture; the leading axis is the DataLoader batch (batch_size=1).
mix.size()

In [None]:
# Shape of the last training sources tensor; the leading axis is the DataLoader batch (batch_size=1).
sources.size()

In [None]:
# Export every collected training sample as WAV files for listening:
# ./SON_<j>/mix.wav is the mixture and ./SON_<j>/<i>.wav is the i-th entry
# along the sample's source axis.

# Flatten the DataLoader batches into individual samples. Each stored tensor
# carries a leading batch axis (batch_size=1 in the loaders above); iterating
# over it yields per-sample tensors, and larger batches are handled too.
samples = (
    (mix_item, sources_item)
    for mix_batch, sources_batch in zip(list_mixes, list_sources)
    for mix_item, sources_item in zip(mix_batch, sources_batch)
)

for j, (mix_item, sources_item) in enumerate(samples):
    out_dir = Path(f"./SON_{j}")
    # exist_ok=True keeps the cell re-runnable (existing files are overwritten).
    out_dir.mkdir(exist_ok=True)
    # asteroid's MUSDB18Dataset yields channel-first audio, (channels, samples);
    # scipy's wavfile.write expects (samples, channels), hence the transpose.
    wavfile.write(str(out_dir / "mix.wav"), SAMPLE_RATE, mix_item.detach().numpy().T)
    for i, source in enumerate(sources_item):
        wavfile.write(str(out_dir / f"{i}.wav"), SAMPLE_RATE, source.detach().numpy().T)