In [None]:
import librosa
import numpy as np
import tensorflow as tf
from magenta.models.nsynth import utils
from magenta.models.nsynth.wavenet import fastgen
from magenta.models.nsynth.wavenet.h512_bo16 import Config

In [None]:
# Checkpoint restored into the NSynth encoder graph by the cells below.
model_path = '/mnt/data/Birdman/models/wavenet-ckpt/model.ckpt-200000'
# Samples per analysis window, and number of windows encoded per session run.
sample_size, batch_size = 32000, 50

In [None]:
def rolling_window(a, window, overlap, copy = False):
    """Split a 1-D array into fixed-length windows that start every `overlap` elements.

    Despite its name, `overlap` is the step (hop) between consecutive window
    starts; neighbouring windows share ``window - overlap`` elements when
    ``overlap < window``.

    Args:
        a: 1-D array-like.
        window: Number of elements per window (must be > 0).
        overlap: Step between window starts (must be > 0).
        copy: If True, return an independent copy; otherwise return a strided
            view that shares memory with `a` (writing to it writes into `a`).

    Returns:
        Array of shape ``(n_windows, window)`` with
        ``n_windows = (len(a) - window) // overlap + 1``, or an empty
        ``(0, window)`` array when `a` is shorter than `window`.

    Raises:
        ValueError: If `a` is not 1-D, or `window` / `overlap` is not positive.
    """
    a = np.asarray(a)
    if a.ndim != 1:
        # `a.strides * 2` below only describes a valid 2-D view for 1-D input.
        raise ValueError("rolling_window expects a 1-D array, got ndim=%d" % a.ndim)
    if window <= 0 or overlap <= 0:
        raise ValueError("window and overlap must be positive, got %r and %r" % (window, overlap))
    if a.size < window:
        # Not enough data for a single window; as_strided would get a negative dimension.
        return np.empty((0, window), dtype=a.dtype)
    # Every possible window start (step 1), sharing a's buffer; then keep every `overlap`-th.
    shape = (a.size - window + 1, window)
    strides = a.strides * 2
    view = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)[::overlap]
    return view.copy() if copy else view
    
def batch_array(a, batch_len):
    """Yield consecutive slices of `a` holding at most `batch_len` items each.

    The final slice is shorter when ``len(a)`` is not a multiple of
    `batch_len`. An empty `a` yields nothing (previously this raised
    IndexError).

    Args:
        a: Any sliceable sequence with ``len`` (list, numpy array, ...).
        batch_len: Maximum number of items per slice (must be > 0).

    Yields:
        ``a[start:start + batch_len]`` for ``start = 0, batch_len, 2*batch_len, ...``.

    Raises:
        ValueError: If `batch_len` is not positive (raised on first iteration).
    """
    if batch_len <= 0:
        raise ValueError("batch_len must be positive, got %r" % (batch_len,))
    for start in range(0, len(a), batch_len):
        yield a[start:start + batch_len]
        
def load_nsynth(batch_size=1, sample_length=64000):
    """Build the NSynth autoencoder network for inference (is_training=False).

    Args:
        batch_size: Fixed batch dimension of the input placeholder. [1]
        sample_length: Number of audio samples per example in the input. [64000]

    Returns:
        The mapping produced by ``Config().build``, with the float32 input
        placeholder of shape ``[batch_size, sample_length]`` added under "X".
    """
    model_config = Config()
    input_shape = [batch_size, sample_length]
    # Ops are pinned to the first GPU; sessions elsewhere in this notebook use
    # allow_soft_placement=True.
    with tf.device("/gpu:0"):
        wav_input = tf.placeholder(tf.float32, shape=input_shape)
        network = model_config.build({"wav": wav_input}, is_training=False)
        network.update({"X": wav_input})
    return network

def encode(wav_data, checkpoint_path, sample_length=64000):
    """Encode audio with the NSynth autoencoder restored from `checkpoint_path`.

    A new graph and session are created (and closed) on every call, so this is
    suited to one-off use; for many batches build the network once and reuse it.

    Args:
        wav_data: 1-D array (one clip) or 2-D array (one clip per row).
        checkpoint_path: Checkpoint passed to ``tf.train.Saver.restore``.
        sample_length: Requested samples per clip. ``utils.trim_for_encoding``
            may change it; the network is built with the returned value.

    Returns:
        The evaluated ``net["encoding"]`` for the (possibly trimmed) batch.

    Raises:
        ValueError: If `wav_data` is neither 1-D nor 2-D.
    """
    wav_data = np.asarray(wav_data)
    if wav_data.ndim == 1:
        # The network always takes a batch dimension; a single clip is a batch of one.
        wav_data = np.expand_dims(wav_data, 0)
    elif wav_data.ndim != 2:
        # Previously fell through with `batch_size` unbound (UnboundLocalError later on).
        raise ValueError(
            "wav_data must be 1-D or 2-D, got an array with ndim=%d" % wav_data.ndim)
    batch_size = wav_data.shape[0]

    # Load up the model for encoding and find the encoding of "wav_data".
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True
    with tf.Graph().as_default(), tf.Session(config=session_config) as sess:
        hop_length = Config().ae_hop_length
        # Trim the clips according to the encoder hop length; the network must be
        # built for the trimmed length, so sample_length is taken from the result.
        wav_data, sample_length = utils.trim_for_encoding(wav_data, sample_length,
                                                          hop_length)
        net = load_nsynth(batch_size=batch_size, sample_length=sample_length)
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
        encodings = sess.run(net["encoding"], feed_dict={net["X"]: wav_data})
    return encodings

In [None]:
# Build the encoder graph once and keep this session open so the batch loop can reuse it.
session_config = tf.ConfigProto(allow_soft_placement=True)
session_config.gpu_options.allow_growth = True
sess = tf.Session(config=session_config)
# The placeholder's shape is fixed, so it must match the batches fed later: use the
# configured batch_size, and a width derived from sample_size and the hop length
# rather than a hard-coded 31744.
# NOTE(review): assumes utils.trim_for_encoding rounds the width down to a multiple of
# ae_hop_length — verify against the magenta version in use.
hop_length = Config().ae_hop_length
encode_length = sample_size // hop_length * hop_length
net = load_nsynth(batch_size=batch_size, sample_length=encode_length)
saver = tf.train.Saver()
saver.restore(sess, model_path)

In [None]:
encodings = []

y, sr = librosa.load('/mnt/data/Birdman/samples/recordings/STHELENA-02_20140605_200000_1.wav', sr=None)
# NOTE(review): sr=None keeps the file's native sampling rate; confirm it matches the
# rate the checkpoint expects, otherwise pass that rate to librosa.load to resample.
# Drop the tail so the recording is a whole number of `sample_size` windows.
multiple_sampsize = len(y) // sample_size * sample_size
y = y[:multiple_sampsize]
# Windows of sample_size samples, a new one starting every sample_size // 2 samples.
samples = rolling_window(y, sample_size, sample_size // 2)
hop_length = Config().ae_hop_length
for batch in batch_array(samples, batch_size):
    batch, sample_length = utils.trim_for_encoding(batch, sample_size, hop_length)
    n_windows = batch.shape[0]
    if n_windows < batch_size:
        # `net` was built with a fixed batch dimension, so the final short batch is
        # zero-padded here and the padded rows are dropped from the output below.
        padding = np.zeros((batch_size - n_windows,) + batch.shape[1:], dtype=batch.dtype)
        batch = np.concatenate([batch, padding], axis=0)
    encoding = sess.run(net["encoding"], feed_dict={net["X"]: batch})
    # Assumes the encoding's first axis is the batch axis (one row per input window).
    encodings.append(encoding[:n_windows])

In [None]:
# Per-batch encodings collected by the encoding loop (one array appended per batch).
encodings

In [None]:
# (number of windows, samples per window)
np.shape(samples)

In [None]:
# One batch worth of windows, sized from the config rather than a literal 50.
samples_subset = samples[:batch_size]
samples_subset.shape

In [None]:
# Encode the subset with magenta's `fastgen.encode`; this overwrites the `encoding`
# variable left over from the batch loop.
encoding = fastgen.encode(wav_data=samples_subset, checkpoint_path=model_path, sample_length=sample_size)

In [None]:
# Encoder hop length from the model config (the value passed to utils.trim_for_encoding).
Config().ae_hop_length

In [None]:
# Reload the recording and window it again, this time without first trimming `y`
# to a whole number of `sample_size` windows.
y, sr = librosa.load('/mnt/data/Birdman/samples/recordings/STHELENA-02_20140605_200000_1.wav', sr=None)
samples = rolling_window(y, sample_size, sample_size // 2)

In [None]:
# Trim one batch of windows with the encoder hop length (batch size taken from the config).
wav_data, sample_length = utils.trim_for_encoding(samples[:batch_size], sample_size, Config().ae_hop_length)

In [None]:
# (windows in the batch, trimmed samples per window)
np.shape(wav_data)

In [None]:
# Length of `y` rounded down to whole windows, via float division then truncation.
int(len(y) / sample_size) * sample_size

In [None]:
# Length of `y` rounded down to whole windows, using exact integer floor division.
len(y) // sample_size * sample_size