In [1]:
import os, sys
# Add the parent directory to the import path — presumably where the local `filterflow`
# and `data` packages imported below live (TODO confirm; a relative path depends on the CWD).
sys.path.append('..')

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp

In [3]:
import sonnet

In [4]:
from filterflow.models import NNBernoulliDistribution, NNNormalDistribution
from filterflow.models.vrnn import make_filter, VRNNState
from data import create_pianoroll_dataset

from filterflow.resampling.criterion import NeffCriterion
from filterflow.resampling import RegularisedTransform, MultinomialResampler

In [5]:
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [6]:
# --- VRNN model sizes (passed to make_filter in make_smc) ---
latent_size = 16
observation_size = 88          # presumably one unit per piano key — TODO confirm
rnn_hidden_size = latent_size // 2
latent_encoder_layers = [16]
latent_encoded_size = 16

# --- Resampling: RegularisedTransform settings and NeffCriterion threshold (see make_smc) ---
epsilon = 0.01
scaling = 0.9
convergence_threshold = 1e-3
max_iter = 500
additional_variables_are_state = True
neff = 0.5

# --- Particle-filter tensor sizes (used by init_state_fn) ---
batch_size = 1
n_particles = 16
dimension = latent_size

In [7]:
#@tf.function
def make_optimizer(initial_learning_rate=0.01, decay_steps=100, decay_rate=0.75, staircase=True):
    """Create an Adam optimizer driven by an exponentially decaying learning rate.

    The Keras ExponentialDecay schedule scales `initial_learning_rate` by `decay_rate`
    once per `decay_steps` optimizer steps. It decays in discrete jumps when `staircase`
    is True and continuously otherwise.
    """
    schedule_kwargs = dict(decay_steps=decay_steps, decay_rate=decay_rate, staircase=staircase)
    decayed_lr = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, **schedule_kwargs)
    return tf.compat.v2.keras.optimizers.Adam(learning_rate=decayed_lr)

In [8]:
#@tf.function
def input_fn(data_fp, split='train'):
    """Load one split of the piano-roll dataset as an (inputs, targets) pair of tensors.

    Args:
        data_fp: path of the dataset file handed to `create_pianoroll_dataset`.
        split: which split to load (default 'train', matching the previous hard-coded value).

    Returns:
        (inputs_tensor, targets_tensor), each with a new size-1 axis inserted at position 1
        (presumably the batch axis expected by the filter — TODO confirm shapes).
    """
    # The sequence lengths and mean returned by the loader are not used here.
    inputs_tensor, targets_tensor, _lens, _mean = create_pianoroll_dataset(data_fp, split=split, batch_size=1)
    inputs_tensor = tf.expand_dims(inputs_tensor, 1)
    targets_tensor = tf.expand_dims(targets_tensor, 1)
    return inputs_tensor, targets_tensor

In [9]:
#@tf.function
def make_smc():
    """Build the VRNN particle filter from the notebook's configuration constants.

    Resampling is triggered by an ESS criterion (threshold `neff`) and performed with a
    RegularisedTransform configured by `epsilon`, `scaling`, `max_iter`,
    `convergence_threshold` and `additional_variables_are_state`.
    """
    criterion = NeffCriterion(tf.constant(neff), tf.constant(True))

    resampler = RegularisedTransform(
        epsilon,
        scaling=scaling,
        max_iter=max_iter,
        convergence_threshold=convergence_threshold,
        additional_variables_are_state=additional_variables_are_state,
    )

    return make_filter(
        latent_size, observation_size, rnn_hidden_size, latent_encoder_layers,
        latent_encoded_size, resampler, criterion,
    )

In [10]:
#@tf.function
def init_state_fn(smc):
    """Build the initial `VRNNState` for `smc` from the module-level size constants.

    The state contains:
    - particles at zero and their encoding;
    - an RNN state drawn from N(0, 1) with seed 0;
    - a zero RNN output;
    - uniform particle weights and zero log-likelihoods.
    """
    # Latent particles start at the origin; run them through the latent encoder.
    particles = tf.cast(tf.zeros([batch_size, n_particles, dimension]), dtype=float)
    latent_encoded = smc._transition_model.latent_encoder(particles)

    # NOTE(review): the SAME standard-normal draw fills both halves of the RNN state
    # (they are identical after the concat) — confirm this is intended.
    noise = tfp.distributions.Normal(0., 1.).sample([batch_size, n_particles, rnn_hidden_size], seed=0)
    rnn_state = tf.concat([noise, noise], axis=-1)

    rnn_out = tf.zeros([batch_size, n_particles, rnn_hidden_size])

    # Uniform weights 1 / n_particles; the log-likelihood accumulators start at zero.
    uniform_weights = tf.ones((batch_size, n_particles), dtype=float) / tf.cast(n_particles, float)
    zero_ll = tf.zeros(batch_size, dtype=float)

    return VRNNState(particles=particles,
                     log_weights=tf.math.log(uniform_weights),
                     weights=uniform_weights,
                     obs_likelihood=zero_ll,
                     log_likelihoods=zero_ll,
                     rnn_state=rnn_state,
                     rnn_out=rnn_out,
                     latent_encoded=latent_encoded)

In [11]:

#@tf.function
def train_step(smc, init_state, T, observation_series, inputs_series, optimizer, seed=0, use_correction_term=False,
               clip_value=500.):
    """Filter the series once and apply one clipped gradient step on the resulting loss.

    Args:
        smc: filter built by `make_smc`; the variables of its transition and observation
            models are the ones trained.
        init_state: initial filter state (see `init_state_fn`).
        T: number of observations to process (`n_observations`).
        observation_series, inputs_series: series passed straight through to `smc`.
        optimizer: Keras optimizer used to apply the gradients.
        seed: seed forwarded to the filter call.
        use_correction_term: if True, add the batch-mean `resampling_correction` of the
            final state to the log-likelihood before negating it.
        clip_value: gradients are clipped element-wise to [-clip_value, clip_value].

    Returns:
        (loss, grads, ess):
        - loss: the scalar -(mean final log-likelihood + correction);
        - grads: the raw, unclipped gradients, `None` for variables the loss does not depend on;
        - ess: the final state's `ess`.
    """
    trainable_variables = smc._transition_model.variables + smc._observation_model.variables

    with tf.GradientTape() as tape:
        tape.watch(trainable_variables)
        final_state = smc(init_state, observation_series, n_observations=T, inputs_series=inputs_series,
                          return_final=True, seed=seed)
        real_ll = tf.reduce_mean(final_state.log_likelihoods)
        ess = final_state.ess

        if use_correction_term:
            correction = tf.reduce_mean(final_state.resampling_correction)
        else:
            correction = tf.constant(0.)

        loss = -(real_ll + correction)

    grads = tape.gradient(loss, trainable_variables)

    # tape.gradient yields None for variables unconnected to the loss. tf.clip_by_value
    # would raise on None, so such variables are skipped instead of crashing the step.
    grads_and_vars = [(tf.clip_by_value(grad, -clip_value, clip_value), var)
                      for grad, var in zip(grads, trainable_variables)
                      if grad is not None]
    optimizer.apply_gradients(grads_and_vars)
    return loss, grads, ess

In [12]:
import tensorflow.compat.v1 as tf_compat

@tf.function
def train(smc, num_steps, init_state, T, observation_series, inputs_series, optimizer, seed=0, use_correction_term=False,
          checkpoint_dir='./tf_ckpts', checkpoint_every=50):
    """Run `num_steps` training steps inside one tf.function, recording loss and ESS per step.

    Args:
        smc, init_state, T, observation_series, inputs_series, optimizer: forwarded to `train_step`.
        num_steps: number of optimisation steps (may be a scalar tensor).
        seed: seed forwarded to every `train_step` call.
        use_correction_term: forwarded to every `train_step` call.
        checkpoint_dir: directory in which a CheckpointManager (max_to_keep=3) saves
            the optimizer and `smc`.
        checkpoint_every: save a checkpoint every this many steps.

    Returns:
        (losses, ess_values): float32 tensors of length `num_steps` holding each step's loss
        and element 0 of that step's `ess`.
    """
    losses = tf.TensorArray(dtype=tf.float32, size=num_steps, dynamic_size=False, element_shape=[])
    ess_values = tf.TensorArray(dtype=tf.float32, size=num_steps, dynamic_size=False, element_shape=[])

    ckpt = tf.train.Checkpoint(optimizer=optimizer, smc=smc)
    manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=3)

    for step in tf.range(1, num_steps + 1):
        # Forward the caller's seed and correction flag. They were previously hard-coded
        # to 0 / False, which silently ignored these arguments.
        loss, _, ess = train_step(smc, init_state, T, observation_series,
                                  inputs_series, optimizer, seed=seed, use_correction_term=use_correction_term)
        # ess[0]: presumably the first batch element's ESS — TODO confirm ess shape.
        ess_values = ess_values.write(step - 1, ess[0])
        losses = losses.write(step - 1, loss)
        if step % checkpoint_every == 0:
            # manager.save performs Python-side file I/O, so it runs eagerly via py_function.
            tf.py_function(manager.save, [], [tf.string])

    return losses.stack(), ess_values.stack()

In [13]:
smc = make_smc()
init_state = init_state_fn(smc)

# Dataset location. Set the PIANO_DATA_DIR environment variable to override it
# instead of editing the hard-coded default path.
data_dir = os.environ.get('PIANO_DATA_DIR', '/data/hylia/thornton/filterflow/data/piano_data')
file_name = 'jsb.pkl'
data_fp = os.path.join(data_dir, file_name)
inputs_series, observation_series = input_fn(data_fp)

# Sonnet networks are initialised on their first call: sample once from each model
# so that their variables exist.
t_samp = smc._transition_model.sample(init_state, inputs_series[0], seed=0)
obs_samp = smc._observation_model.sample(init_state, seed=0)

# Series length (leading axis), read before the tensors are wrapped in Datasets.
T = observation_series.shape.as_list()[0]
observation_series = tf.data.Dataset.from_tensor_slices(observation_series)
inputs_series = tf.data.Dataset.from_tensor_slices(inputs_series)

optimizer = make_optimizer()

In [14]:
# Resume from the newest checkpoint in ./tf_ckpts. When none exists, latest_checkpoint
# is None and the restore call leaves the freshly initialised variables as they are.
ckpt = tf.train.Checkpoint(optimizer=optimizer, smc=smc)
manager = tf.train.CheckpointManager(ckpt, directory='./tf_ckpts', max_to_keep=3)
ckpt.restore(save_path=manager.latest_checkpoint)

smc.variables

(<tf.Variable 'conditional_bernoulli_distribution_fcnet/linear_0/b:0' shape=(88,) dtype=float32, numpy=
 array([ 0.02791241, -0.04618699,  0.08131835, -0.00614074,  0.12447699,
         0.11066904,  0.07472951,  0.10863438, -0.0312203 ,  0.1259065 ,
         0.12236042,  0.09303114,  0.13466662,  0.07671011, -0.05925879,
         0.09951378,  0.01375258,  0.09570523, -0.11664508, -0.00288441,
         0.13362534,  0.08650601,  0.05181648,  0.03424622,  0.0478554 ,
         0.10710999,  0.07780091, -0.01259601,  0.1479202 ,  0.12158426,
         0.08254941,  0.09259136,  0.09832896,  0.10499346,  0.0388561 ,
         0.11600867,  0.11638453,  0.00092365, -0.02155752,  0.12775402,
        -0.01566025, -0.09990635,  0.12005492,  0.10758364,  0.11856676,
         0.23092723,  0.15447454,  0.10761043, -0.09613475,  0.10213464,
        -0.04312116,  0.10486731,  0.13379356,  0.11235483,  0.12202025,
        -0.08131943,  0.15162776,  0.2213026 , -0.04668504,  0.07900737,
        -0.05303629,

In [23]:
# Filter the full series once (seed 0) and average the final log-likelihood over the batch.
final_state = smc(init_state,
                  observation_series,
                  n_observations=T,
                  inputs_series=inputs_series,
                  return_final=True,
                  seed=0)
real_ll = tf.reduce_mean(final_state.log_likelihoods)

In [24]:
# Show the batch-mean final log-likelihood computed in the evaluation cell.
real_ll

<tf.Tensor: shape=(), dtype=float32, numpy=-348.23755>

In [25]:
num_steps = tf.constant(100)

# perf_counter is a monotonic high-resolution clock, unlike time.time, so it is the right
# tool for measuring elapsed time. The first call of the tf.function `train` also
# includes graph tracing, so this figure overstates the steady-state per-step cost.
start = time.perf_counter()
loss_arr, ess_arr = train(smc, num_steps, init_state, T, observation_series, inputs_series, optimizer,
                          seed=0, use_correction_term=False)
stop = time.perf_counter()
print(stop - start)



110.2130298614502


In [18]:
# Per-step training losses returned by `train`: loss = -(mean final log-likelihood + correction).
# NOTE(review): the saved output below has 10 entries, while the training cell uses
# num_steps=100. It looks like output from an earlier run; re-run to refresh.
loss_arr.numpy()

array([356.37897, 355.29608, 354.28668, 353.34296, 352.46326, 351.64172,
       350.86884, 350.14185, 349.46582, 348.83185], dtype=float32)

In [21]:
# Projected wall-clock hours for a longer run, extrapolated from the timed training cell.
# Divide by num_steps rather than a hard-coded 100, so the projection stays correct
# when num_steps changes.
projection_steps = 50000
(stop - start) / int(num_steps) * projection_steps / 3600

1.955129537317488