In [None]:
import os
import glob
import numpy as np
import librosa
from magenta.models.nsynth import utils
from magenta.models.nsynth.wavenet import fastgen

In [None]:
# Checkpoint path handed to fastgen.encode as `checkpoint_path`.
model_path = '/mnt/data/Birdman/models/wavenet-ckpt/model.ckpt-200000'
# Window length in audio samples; also passed to fastgen.encode as `sample_length`.
sample_size = 32000
# Number of windows sent to fastgen.encode in a single call.
batch_size = 50

In [None]:
def rolling_window(a, window, overlap, copy = False):
    """Slice a 1-D array into fixed-length, regularly spaced windows.

    Parameters
    ----------
    a : numpy.ndarray
        One-dimensional input array.
    window : int
        Number of elements in each window.
    overlap : int
        Step, in elements, between the starts of consecutive returned
        windows. Despite its name this is the hop size, not the number of
        shared elements (the two only coincide when it equals half the
        window length).
    copy : bool, optional
        If True, return an independent copy of the windows. Otherwise the
        result is a strided view that shares memory with ``a``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_windows, window)``. If ``a`` holds fewer than
        ``window`` elements, an empty array of shape ``(0, window)`` is
        returned.
    """
    n_positions = a.size - window + 1
    if n_positions <= 0:
        # The input is shorter than a single window. as_strided would be
        # asked for a negative dimension and raise, so report "no windows".
        return np.empty((0, window), dtype=a.dtype)

    # Repeating the element stride for both axes makes row i start at a[i]:
    # one row per possible window position, without copying any data.
    all_windows = np.lib.stride_tricks.as_strided(
        a, shape=(n_positions, window), strides=a.strides * 2
    )
    # Keep every `overlap`-th window position.
    view = all_windows[0::overlap]
    # Rows of the view alias overlapping memory of `a`; callers that intend
    # to modify the windows should request a copy.
    return view.copy() if copy else view
    
def batch_array(a, batch_len):
    """Yield consecutive slices of ``a`` holding at most ``batch_len`` items.

    Parameters
    ----------
    a : sequence or numpy.ndarray
        Anything supporting ``len()`` and slicing.
    batch_len : int
        Maximum number of items per batch; must be at least 1.

    Yields
    ------
    Slices ``a[start:end]`` that together cover ``a`` in order. The final
    slice may be shorter than ``batch_len``. Nothing is yielded when ``a``
    is empty.

    Raises
    ------
    ValueError
        If ``batch_len`` is smaller than 1 (raised when iteration starts).
    """
    if batch_len < 1:
        raise ValueError('batch_len must be a positive integer')
    total = len(a)
    for start in range(0, total, batch_len):
        # Clamp the end so the last, possibly partial, batch stops at len(a).
        yield a[start:min(start + batch_len, total)]
        
def get_name_from_path(path):
    """Return the file name of ``path`` without its directory or final extension."""
    filename = os.path.basename(path)
    stem, _extension = os.path.splitext(filename)
    return stem

In [None]:
# Every source recording that should be encoded.
recordings_paths = glob.glob('/mnt/data/Birdman/full/*.wav')
# Names of recordings that already have an .npy file in the working directory.
completed_names = list(map(get_name_from_path, glob.glob('*.npy')))
completed_names

In [None]:
# Encode every recording that does not already have a saved .npy result.
d = {}  # recording name -> stacked encodings produced during this run
encodings = []

for path in recordings_paths:
    name = get_name_from_path(path)
    if name in completed_names:
        print('Skipping', name)
    else:
        print('Processing', name)
        # NOTE(review): sr=None keeps the file's native sample rate; confirm
        # it matches the rate the checkpoint was trained on.
        y, sr = librosa.load(path, sr=None)
        if y.size < sample_size:
            # Not enough audio for a single window; there is nothing to
            # encode, and windowing would fail on such input.
            print('Skipping', name, '(shorter than one window)')
            continue
        # Windows of sample_size samples, each starting half a window after
        # the previous one.
        samples = rolling_window(y, sample_size, sample_size // 2)
        # Start from an empty list for each recording. Sharing one list
        # across the loop made every saved array also contain the encodings
        # of all previously processed recordings.
        encodings = []
        for sample in batch_array(samples, batch_size):
            encoding = fastgen.encode(wav_data=sample, checkpoint_path=model_path, sample_length=sample_size)
            encodings.append(encoding)
        full_enc = np.vstack(encodings)
        d[name] = full_enc
        # np.save appends ".npy", so the result lands in <name>.npy in the
        # working directory, where the completed-names glob looks for it.
        np.save(name, full_enc)

In [None]:
# Display the recording-name -> encoding-array mapping filled by the encoding loop.
d