In [1]:
import os, sys
sys.path.append('../..')

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp

In [3]:
import sonnet

In [4]:
from filterflow.models import NNBernoulliDistribution, NNNormalDistribution
from filterflow.models.vrnn import make_filter, VRNNState
from data import create_pianoroll_dataset

from filterflow.resampling.criterion import NeffCriterion
from filterflow.resampling import RegularisedTransform, MultinomialResampler

In [5]:
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [6]:
# --- Model sizes: all five values below are forwarded to make_filter in make_smc ---
latent_size = 16
observation_size = 88  # presumably one entry per piano key in the piano-roll data -- TODO confirm
rnn_hidden_size = latent_size // 2
latent_encoder_layers = [16]
latent_encoded_size = 16

# --- Resampling settings ---
# epsilon, scaling, convergence_threshold, max_iter and
# additional_variables_are_state are forwarded to RegularisedTransform in
# make_smc; neff is forwarded to NeffCriterion.
epsilon = 0.01
additional_variables_are_state = True
scaling = 0.9
convergence_threshold = 1e-3
max_iter = 500
neff = 0.5



# --- Shapes of the initial filter state built by init_state_fn ---
batch_size = 1
n_particles = 16
dimension = latent_size  # size of the last axis of the initial particles

In [7]:
#@tf.function
def make_optimizer(initial_learning_rate=0.01, decay_steps=100, decay_rate=0.75, staircase=True):
    """Create an Adam optimizer driven by an exponentially decaying learning rate.

    Args:
        initial_learning_rate: Learning rate handed to the schedule as its starting value.
        decay_steps: Forwarded to ``ExponentialDecay`` as ``decay_steps``.
        decay_rate: Forwarded to ``ExponentialDecay`` as ``decay_rate``.
        staircase: Forwarded to ``ExponentialDecay`` as ``staircase``.

    Returns:
        A Keras ``Adam`` optimizer whose ``learning_rate`` is the schedule above.
    """
    schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_learning_rate,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=staircase,
    )
    return tf.compat.v2.keras.optimizers.Adam(learning_rate=schedule)

In [8]:
#@tf.function
def input_fn(data_fp):
    """Load the training split of the piano-roll data and add a size-1 axis.

    Args:
        data_fp: Path of the dataset file handed to ``create_pianoroll_dataset``.

    Returns:
        A ``(inputs, targets)`` tuple; each tensor has a new axis of size 1
        inserted at position 1.
    """
    # The loader returns four values; the last two are not needed here.
    raw_inputs, raw_targets, _lens, _mean = create_pianoroll_dataset(
        data_fp, split='train', batch_size=1)
    return tuple(tf.expand_dims(tensor, 1) for tensor in (raw_inputs, raw_targets))

In [9]:
#@tf.function
def make_smc():
    """Assemble the VRNN particle filter from the notebook-level settings.

    Reads the globals ``neff``, ``epsilon``, ``scaling``, ``max_iter``,
    ``convergence_threshold``, ``additional_variables_are_state`` and the model
    size constants.

    Returns:
        The filter object produced by ``make_filter``.
    """
    # NOTE(review): the second NeffCriterion argument is presumably an
    # "is relative" flag -- confirm against filterflow.
    criterion = NeffCriterion(tf.constant(neff), tf.constant(True))

    transform = RegularisedTransform(
        epsilon,
        scaling=scaling,
        max_iter=max_iter,
        convergence_threshold=convergence_threshold,
        additional_variables_are_state=additional_variables_are_state,
    )

    return make_filter(
        latent_size,
        observation_size,
        rnn_hidden_size,
        latent_encoder_layers,
        latent_encoded_size,
        transform,
        criterion,
    )

In [10]:
#@tf.function
def init_state_fn(smc):
    """Build the initial ``VRNNState`` that is fed to the filter.

    Shapes are taken from the notebook-level globals ``batch_size``,
    ``n_particles``, ``dimension`` and ``rnn_hidden_size``.

    Args:
        smc: Filter whose ``_transition_model.latent_encoder`` is applied to the
            initial particles.

    Returns:
        A ``VRNNState`` with zero particles, uniform weights, zero
        log-likelihoods, a random RNN state and a zero RNN output.
    """
    leading_shape = [batch_size, n_particles]
    standard_normal = tfp.distributions.Normal(0., 1.)

    # Particles start at zero and are pushed through the latent encoder.
    zero_particles = tf.cast(tf.zeros(leading_shape + [dimension]), dtype=float)
    encoded_particles = smc._transition_model.latent_encoder(zero_particles)

    # A single normal draw is concatenated with itself along the last axis,
    # so both halves of the RNN state hold the same values.
    rnn_draw = standard_normal.sample(leading_shape + [rnn_hidden_size], seed=0)
    rnn_state = tf.concat([rnn_draw, rnn_draw], axis=-1)

    rnn_output = tf.zeros(leading_shape + [rnn_hidden_size])

    # Every particle gets weight 1 / n_particles.
    uniform_weights = tf.ones((batch_size, n_particles), dtype=float) / tf.cast(n_particles, float)
    zero_log_likelihoods = tf.zeros(batch_size, dtype=float)

    return VRNNState(
        particles=zero_particles,
        log_weights=tf.math.log(uniform_weights),
        weights=uniform_weights,
        obs_likelihood=zero_log_likelihoods,
        log_likelihoods=zero_log_likelihoods,
        rnn_state=rnn_state,
        rnn_out=rnn_output,
        latent_encoded=encoded_particles,
    )

In [11]:

#@tf.function
def train_step(smc, init_state, T, observation_series, inputs_series, optimizer, seed=0, use_correction_term=False,
               grad_clip=500.):
    """Run the filter once, then take one clipped-gradient optimizer step.

    Args:
        smc: Filter to run; the variables of its ``_transition_model`` and
            ``_observation_model`` are the ones being optimized.
        init_state: Initial filter state passed to ``smc``.
        T: Number of observations, forwarded as ``n_observations``.
        observation_series: Observations passed to ``smc``.
        inputs_series: Inputs passed to ``smc`` as ``inputs_series``.
        optimizer: Optimizer whose ``apply_gradients`` performs the update.
        seed: Seed forwarded to ``smc``.
        use_correction_term: When true, the mean of
            ``final_state.resampling_correction`` is added to the objective.
        grad_clip: Each gradient is clipped element-wise to
            ``[-grad_clip, grad_clip]``.

    Returns:
        ``(loss, grads, ess)``: the scalar loss, the *unclipped* gradients and
        ``final_state.ess``.
    """
    # NOTE(review): `.variables` may also include non-trainable variables;
    # `.trainable_variables` might be what is intended -- confirm.
    trainable_variables = smc._transition_model.variables + smc._observation_model.variables

    with tf.GradientTape() as tape:
        tape.watch(trainable_variables)
        final_state = smc(init_state, observation_series, n_observations=T, inputs_series=inputs_series,
                          return_final=True, seed=seed)
        real_ll = tf.reduce_mean(final_state.log_likelihoods)
        ess = final_state.ess

        if use_correction_term:
            correction = tf.reduce_mean(final_state.resampling_correction)
        else:
            correction = tf.constant(0.)

        # Minimizing the loss maximizes the (optionally corrected) mean log-likelihood.
        loss = -(real_ll + correction)

    grads = tape.gradient(loss, trainable_variables)

    # tape.gradient yields None for variables the loss does not depend on and
    # tf.clip_by_value cannot take None, so those entries are passed through
    # unchanged and left for apply_gradients to handle.
    capped_gvs = [None if grad is None else tf.clip_by_value(grad, -grad_clip, grad_clip)
                  for grad in grads]
    optimizer.apply_gradients(zip(capped_gvs, trainable_variables))
    return loss, grads, ess

In [12]:
import tensorflow.compat.v1 as tf_compat

@tf.function
def train(smc, num_steps, init_state, T, observation_series, inputs_series, optimizer, seed=0, use_correction_term=False):
    """Run ``num_steps`` training steps and record the loss and ESS of each.

    Args:
        smc: Filter being trained.
        num_steps: Number of calls to ``train_step``.
        init_state: Initial filter state used for every step.
        T: Number of observations, forwarded to ``train_step``.
        observation_series: Observations forwarded to ``train_step``.
        inputs_series: Inputs forwarded to ``train_step``.
        optimizer: Optimizer forwarded to ``train_step`` and stored in the checkpoint.
        seed: Seed forwarded to ``train_step`` on every step.
        use_correction_term: Forwarded to ``train_step``.

    Returns:
        ``(losses, ess)``: two stacked float32 tensors with one entry per step.
    """
    # One scalar slot per training step.
    loss_tensor_array = tf.TensorArray(dtype=tf.float32, size=num_steps, dynamic_size=False, element_shape=[])
    ess_tensor_array = tf.TensorArray(dtype=tf.float32, size=num_steps, dynamic_size=False, element_shape=[])

    # Checkpoints go to ./tf_ckpts; only the three most recent are kept.
    ckpt = tf.train.Checkpoint(optimizer=optimizer, smc=smc)
    manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

    for step in tf.range(1, num_steps + 1):
        # Forward the caller's seed and use_correction_term; they were
        # previously replaced by hard-coded values, so the arguments of
        # `train` had no effect.
        loss, _, ess = train_step(smc, init_state, T, observation_series,
                                  inputs_series, optimizer, seed=seed,
                                  use_correction_term=use_correction_term)
        # `step` is 1-based, the arrays are 0-based; only ess[0] is recorded.
        ess_tensor_array = ess_tensor_array.write(step - 1, ess[0])
        loss_tensor_array = loss_tensor_array.write(step - 1, loss)
        if step % 50 == 0:
            # manager.save is plain Python, so it is wrapped in py_function
            # to be callable from inside the traced loop.
            tf.py_function(manager.save, [], [tf.string])

    return loss_tensor_array.stack(), ess_tensor_array.stack()

In [13]:
# Build the filter and its initial state from the notebook-level settings.
smc = make_smc()
init_state = init_state_fn(smc)

# NOTE(review): absolute, machine-specific path -- the notebook will not run
# elsewhere unless data_dir is changed.
data_dir = '/data/hylia/thornton/filterflow/data/piano_data'
file_name = 'jsb.pkl'
data_fp = os.path.join(data_dir, file_name)
inputs_series, observation_series = input_fn(data_fp)

# snt networks initiated on first call
# (the sampled values themselves are not used afterwards)
t_samp = smc._transition_model.sample(init_state, inputs_series[0], seed=0)
obs_samp = smc._observation_model.sample(init_state, seed=0)

# T is the size of axis 0 of the observation tensor; it must be read before
# the tensors are replaced by tf.data.Dataset objects sliced along that axis.
T = observation_series.shape.as_list()[0]
observation_series = tf.data.Dataset.from_tensor_slices(observation_series)
inputs_series = tf.data.Dataset.from_tensor_slices(inputs_series)

optimizer = make_optimizer()

In [14]:
# Restore optimizer and filter variables from the newest checkpoint in
# ./tf_ckpts. NOTE(review): the result of this cell depends on whatever
# checkpoints already exist on disk -- confirm that is intended on a fresh run.
ckpt = tf.train.Checkpoint(optimizer=optimizer, smc=smc)  
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)

# Display the filter's variables after the restore call.
smc.variables

(<tf.Variable 'conditional_bernoulli_distribution_fcnet/linear_0/b:0' shape=(88,) dtype=float32, numpy=
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.], dtype=float32)>,
 <tf.Variable 'conditional_bernoulli_distribution_fcnet/linear_0/w:0' shape=(24, 88) dtype=float32, numpy=
 array([[-0.2824334 , -0.33880448,  0.2363793 , ..., -0.18314494,
          0.1297835 ,  0.28344813],
        [ 0.30638048,  0.19645168, -0.12270907, ..., -0.00490296,
         -0.0914169 , -0.0524416 ],
        [ 0.10982814,  0.1727829 ,  0.04427597, ...,  0.2899695 ,
         -0.17524444, -0.15036178],
        ...,
        [-0.05506197, -0.03092252, -0.0818035 , ...,

In [15]:
# Run the filter once over the whole series before training and take the mean
# of the final log-likelihoods as a baseline.
final_state = smc(init_state, observation_series, n_observations=T, inputs_series=inputs_series, return_final=True,
                          seed=0)
real_ll = tf.reduce_mean(final_state.log_likelihoods)

In [16]:
# Mean final log-likelihood computed in the previous cell (pre-training baseline).
real_ll

<tf.Tensor: shape=(), dtype=float32, numpy=-7921.6885>

In [17]:
# Train for num_steps steps and measure the wall-clock time of the call.
# `train` is a tf.function, so the measured time presumably includes its
# one-off tracing cost on this first call.
num_steps = tf.constant(100)
start = time.time()

loss_arr, ess_arr = train(smc, num_steps, init_state, T, observation_series, inputs_series, optimizer, seed=0, use_correction_term=False)
stop = time.time()
print(stop-start)  # elapsed seconds



155.58456087112427


In [18]:
# Loss recorded at each training step (negative mean log-likelihood, as
# computed in train_step).
loss_arr.numpy()

array([7921.6885 , 6678.5693 , 5287.273  , 3867.9644 , 2685.331  ,
       1979.1832 , 1678.0592 , 1601.3253 , 1607.1075 , 1577.1888 ,
       1527.7871 , 1483.6029 , 1444.8535 , 1406.782  , 1370.1868 ,
       1337.3705 , 1314.1073 , 1296.9287 , 1281.8344 , 1264.365  ,
       1246.8557 , 1227.2219 , 1208.6218 , 1191.224  , 1173.5375 ,
       1157.8962 , 1143.0536 , 1130.5546 , 1118.7002 , 1107.0151 ,
       1094.7749 , 1081.4761 , 1066.9366 , 1051.4716 , 1035.6051 ,
       1019.9667 , 1005.10913,  991.2947 ,  977.9688 ,  964.8777 ,
        951.4277 ,  937.2925 ,  922.7064 ,  907.4172 ,  891.6844 ,
        875.42444,  859.3975 ,  844.1927 ,  829.3064 ,  814.4013 ,
        799.3071 ,  784.04645,  768.562  ,  752.70715,  736.6278 ,
        721.0114 ,  705.5345 ,  689.8006 ,  673.26904,  655.9581 ,
        638.05615,  620.0008 ,  601.9278 ,  583.6063 ,  563.9363 ,
        543.6247 ,  523.1898 ,  503.35318,  484.45355,  466.34497,
        448.7949 ,  431.7517 ,  415.64554,  401.39746,  389.49

In [19]:
# Extrapolate the measured run to 50,000 training steps, in hours:
# seconds per step (elapsed / num_steps) * 50 * 10**3 steps / 3600 s per hour.
# Uses num_steps rather than a literal 100 so the estimate stays correct when
# the number of timed steps changes.
(stop - start) / int(num_steps) * 50 * 10**3 / 60 / 60

21.608966787656147