In [1]:
import yaml
import numpy as np
import os
import soundfile as sf
from scipy.signal import stft
from tqdm import tqdm
import torch
# If you have a GPU, put the data on the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
directory = "data/babyslakh_16k"
listdir = os.listdir(directory)
listdir.sort()
format = '.wav'
savedir = 'data/baby/tensor'
sample_rate = 16000
nperseg = 256
max_size = 250

for i, tr in enumerate(tqdm(listdir)):
    tr_dicts_2 = {}
    tr_path = os.path.join(directory, tr)
    with open(os.path.join(tr_path, "metadata.yaml")) as meta:
        metadata = yaml.load(meta, Loader=yaml.Loader)
    # add Mix to the dictionary
    f, t, Zxx = stft(sf.read(os.path.join(directory,tr, 'mix'+format))[0], fs=sample_rate, nperseg=nperseg)
    tr_dicts_2.update({'mix': torch.tensor(Zxx.copy(), dtype=torch.complex128, device=device)})
    file_inst = []
    for stem in metadata['stems'].keys():
        file_inst.append([stem + format, metadata['stems'][stem]['inst_class']])
    file_inst = np.array(file_inst)
    file_inst = file_inst[[file_inst[:,0][i] in os.listdir(tr_path + '/stems') for i in range(len(file_inst))]]
    # combine all stems from the same instrument and track using soundfile
    for inst in np.unique(file_inst[:,1]):
        inst_files = file_inst[file_inst[:,1] == inst][:,0]
        inst_files = [os.path.join(tr_path, 'stems', inst_file) for inst_file in inst_files]
        inst_data = np.array([sf.read(inst_file)[0] for inst_file in inst_files])
        inst_data = np.sum(inst_data, axis=0)
        # compute the STFT of the combined data
        if inst in ['mix','Bass', 'Guitar', 'Drums', 'Piano']:
            f, t, Zxx = stft(inst_data, fs=sample_rate, nperseg=nperseg)
            tr_dicts_2.update({inst: torch.tensor(Zxx.copy(), dtype=torch.complex128, device=device)})
    # save the dictionary as a .pt file
    torch.save(tr_dicts_2, os.path.join(savedir, tr + '.pt'))
    if i > max_size:
        break

100%|██████████| 20/20 [00:15<00:00,  1.33it/s]


In [7]:
tr_dicts_3 = torch.load('data/baby/tensor/Track00002.pt')

tr_dicts_3['mix']

tensor([[ 0.0000e+00+0.0000e+00j,  0.0000e+00+0.0000e+00j,
          0.0000e+00+0.0000e+00j,  ...,
         -3.7538e-05+0.0000e+00j, -3.6747e-05+0.0000e+00j,
         -4.0455e-06+0.0000e+00j],
        [ 0.0000e+00+0.0000e+00j,  0.0000e+00+0.0000e+00j,
          0.0000e+00+0.0000e+00j,  ...,
          1.7981e-05+1.5456e-06j,  2.1901e-05+2.8957e-06j,
         -1.3105e-06+3.5642e-06j],
        [ 0.0000e+00+0.0000e+00j,  0.0000e+00+0.0000e+00j,
          0.0000e+00+0.0000e+00j,  ...,
          8.7858e-07-1.7095e-06j, -1.4241e-06+1.0943e-06j,
          2.4475e-06+1.9497e-06j],
        ...,
        [ 0.0000e+00+0.0000e+00j,  0.0000e+00+0.0000e+00j,
          0.0000e+00+0.0000e+00j,  ...,
         -3.1833e-07+6.4117e-07j,  2.5399e-07-1.4708e-07j,
         -2.5272e-07+4.8756e-08j],
        [ 0.0000e+00+0.0000e+00j,  0.0000e+00+0.0000e+00j,
          0.0000e+00+0.0000e+00j,  ...,
          2.4930e-07-1.1578e-07j, -6.0487e-07-1.5230e-07j,
          2.1178e-08+2.5657e-07j],
        [ 0.0000e+00+0