In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import tsms
from dataset import get_dataset
from model import create_vae
from localconfig import LocalConfig
from IPython.display import Audio

# Size of the VAE latent space; presumably has to match the architecture of
# the checkpoint loaded below — confirm before changing.
LATENT_DIM = 256

conf = LocalConfig()
conf.latent_dim = LATENT_DIM



In [2]:
# get_dataset returns three splits; only the last one (the test split) is
# needed to listen to reconstructions.
_, _, test = get_dataset(conf)

CHECKPOINT_PATH = "checkpoints/VAE_0.0016.h5"
vae = create_vae(conf)
vae.load_weights(CHECKPOINT_PATH)

In [3]:
test_iterable = iter(test)


def reconstruct_audio(freq, mag):
    """Synthesize a waveform from per-frame harmonic frequencies and magnitudes.

    A phase track is derived from ``freq`` with ``tsms.core.generate_phase``,
    then frequencies, magnitudes and phases are rendered with
    ``tsms.core.harmonic_synthesis``. ``conf.frame_size`` is used as the hop
    between frames in both calls.

    Args:
        freq: harmonic frequencies in Hz (as built from ``midi_to_hz`` in this
            notebook); layout presumably (batch, frames, harmonics) — confirm.
        mag: harmonic magnitudes with the same layout as ``freq``.

    Returns:
        The synthesized audio returned by ``harmonic_synthesis``.
    """
    sample_rate = conf.sample_rate
    hop = conf.frame_size

    phase = tsms.core.generate_phase(freq, sample_rate, hop)

    return tsms.core.harmonic_synthesis(
        h_freq=freq,
        h_mag=mag,
        h_phase=phase,
        sample_rate=sample_rate,
        frame_step=hop,
    )

In [22]:
# Pull the next test example and unpack the model inputs.
batch = next(test_iterable)
h = batch["h"]
mask = batch["mask"]
note_number = batch["note_number"]
velocity = batch["velocity"]
instrument_id = batch["instrument_id"]

# Optionally condition on a different instrument than the batch's own:
# set to an integer id (e.g. 5). None keeps the batch's instrument.
INSTRUMENT_OVERRIDE = None
if INSTRUMENT_OVERRIDE is not None:
    # NOTE(review): builds a (1, 1) id like the earlier `[[5]]` suggestion,
    # which assumes a batch size of 1 — confirm against get_dataset.
    instrument_id = tf.constant([[INSTRUMENT_OVERRIDE]], dtype=instrument_id.dtype)

pred, z_mean, z_log_variance = vae.predict([h, note_number, instrument_id, velocity])

# NOTE(review): assumes note_number is one-hot with index 0 corresponding to
# MIDI note 40 — confirm against the dataset. Only the first example's pitch
# is used for the rest of the notebook.
MIDI_PITCH_OFFSET = 40
pitch = MIDI_PITCH_OFFSET + tf.argmax(note_number[0])

In [23]:
# Fundamental frequency implied by the conditioning MIDI pitch.
f0_from_note = tsms.core.midi_to_hz(
    tf.cast(pitch, dtype=tf.float32)
)

# Number of harmonics tsms reports for this f0 at the configured sample rate,
# clamped to the harmonic axis actually present in the prediction and the
# ground truth. Without the clamp, the slices below would be silently
# truncated while `harmonic_indices` kept its full length, and the broadcast
# in the next cell would fail.
harmonics = int(tsms.core.get_number_harmonics(f0_from_note, conf.sample_rate))
harmonics = min(harmonics, pred.shape[2], h.shape[2])

# Number of frames kept for synthesis.
N_FRAMES = 1001
# NOTE(review): fixed gain applied to the predicted magnitudes only; its
# origin is not visible here (presumably a loudness correction) — confirm.
PRED_MAG_SCALE = 0.5

# Channel 0 holds the (normalized) frequencies, channel 1 the magnitudes.
h_freq_pred = pred[:, :N_FRAMES, :harmonics, 0]
h_mag_pred = pred[:, :N_FRAMES, :harmonics, 1] * PRED_MAG_SCALE

h_freq_gt = h[:, :N_FRAMES, :harmonics, 0]
h_mag_gt = h[:, :N_FRAMES, :harmonics, 1]

# Harmonic numbers 1..harmonics, used to rebuild the ideal harmonic series.
harmonic_indices = tf.range(1, harmonics + 1, dtype=tf.float32)

In [24]:
# The frequency channel is converted to Hz as
#   freq_hz = channel * f0 + k * f0   (k = harmonic number),
# i.e. it stores each harmonic's deviation from k * f0, scaled by 1 / f0.
h_freq_centered = h_freq_pred * f0_from_note
h_freq = h_freq_centered + f0_from_note * harmonic_indices

# Re-slice the normalized ground truth from the raw batch (same slice as the
# previous cell) instead of reading `h_freq_gt` and then overwriting it. The
# original read-then-overwrite made this cell non-idempotent: re-running it
# applied the Hz conversion twice.
h_freq_centered_gt = h[:, :1001, :harmonics, 0] * f0_from_note
h_freq_gt = h_freq_centered_gt + f0_from_note * harmonic_indices

In [25]:
# Render the model's reconstruction and the ground truth through the same
# synthesis path so the two can be compared by ear.
audio_pred = reconstruct_audio(freq=h_freq, mag=h_mag_pred)
audio_gt = reconstruct_audio(freq=h_freq_gt, mag=h_mag_gt)

In [26]:
# Play the VAE reconstruction. NOTE(review): IPython treats a 2-D array as
# (channels, samples), so a batch larger than 1 would play as multichannel.
Audio(data=audio_pred, rate=conf.sample_rate)

In [27]:
# Play the ground-truth resynthesis for comparison. NOTE(review): IPython
# treats a 2-D array as (channels, samples), so a batch larger than 1 would
# play as multichannel.
Audio(data=audio_gt, rate=conf.sample_rate)