In [1]:
import os, sys
sys.path.append("../..")

import attr
import math
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tqdm

In [2]:
tf.config.set_visible_devices([], 'GPU') 

In [3]:
from filterflow.base import State
from filterflow.resampling.standard import SystematicResampler, StratifiedResampler, MultinomialResampler
from filterflow.resampling.base import NoResampling
from filterflow.resampling.differentiable import RegularisedTransform, CorrectedRegularizedTransform
from filterflow.resampling.differentiable.optimized import OptimizedPointCloud
from filterflow.resampling.differentiable.optimizer.sgd import SGD
from filterflow.resampling.differentiable.ricatti.solver import PetkovSolver
from filterflow.resampling.differentiable.loss.sliced_wasserstein.swd import SlicedWassersteinDistance
from filterflow.resampling.differentiable.loss.regularized import SinkhornLoss
from filterflow.resampling.differentiable.loss.sliced_wasserstein.utils import sqeuclidean

This notebook aims to demonstrate the differentiability issue encountered at resampling time. To do this we will compare functionals of the point cloud whilst changing a parameter.

In [4]:
tf.random.set_seed(42)

B = 1
N = 50
D = 10

x = tf.random.normal([B, N, D], 0., 1.)
y = tf.zeros(D)

weights = tf.random.uniform([B, N], 0., 1.)
weights = weights / tf.reduce_sum(weights, axis=1, keepdims=True)
log_weights = tf.zeros([B, N]) - math.log(N)



In [5]:
@tf.function
def log_likelihood(state, observation, resampler, seed=666):
    tf.random.set_seed(seed)
    rv = tfp.distributions.MultivariateNormalDiag(tf.zeros(D), tf.ones(D))
    flags = tf.constant([True])
    log_prob = rv.log_prob(observation-state.particles)
    log_weights = log_prob - tf.reduce_logsumexp(log_prob, 1, keepdims=True)
    state = attr.evolve(state, log_weights=log_weights, weights=tf.math.exp(log_weights))
    state = resampler.apply(state, flags)
    log_prob = rv.log_prob(observation-state.particles) + state.log_weights
    return tf.reduce_logsumexp(log_prob)
    

In [6]:
linspace = np.linspace(-0.5, 0.5, 150).astype(np.float32)

In [7]:
# do not decorate this. seed is being set
def get_data(linspace, resampler, x, y):
    res = []
    grads = []
    for z_val in tqdm.tqdm(linspace):
        z = z_val + tf.zeros(D)
        tf.random.set_seed(666)
        with tf.GradientTape() as tape:
            tape.watch(z)
            state = State(x + z, log_weights, tf.math.exp(log_weights), tf.constant([0.]))
            ll = log_likelihood(state, y, resampler)
        ll_grad = tape.gradient(ll, z)
        res.append(ll.numpy().sum())
        grads.append(ll_grad.numpy().sum())
    return res, grads
        

In [8]:
systematic = SystematicResampler()
multinomial = MultinomialResampler()
stratified = StratifiedResampler()
no_resampling = NoResampling()

epsilon = tf.constant(0.1)
scaling = tf.constant(0.5)
convergence_threshold = tf.constant(1e-3)
max_iter = tf.constant(500)

regularized = RegularisedTransform(epsilon, scaling, max_iter, convergence_threshold)

step_size = tf.constant(0.25)
horizon = tf.constant(5.)
threshold = tf.constant(1e-2)

solver = PetkovSolver(n_iter=tf.constant(30))
corrected_no_grad = CorrectedRegularizedTransform(epsilon, scaling, max_iter, convergence_threshold, ricatti_solver=solver, propagate_correction_gradient=False)
corrected = CorrectedRegularizedTransform(epsilon, scaling, max_iter, convergence_threshold, ricatti_solver=solver, propagate_correction_gradient=True)

sinkhorn_loss = SinkhornLoss(epsilon, symmetric=True, scaling=scaling, max_iter=tf.constant(100), convergence_threshold=convergence_threshold)
sinkhorn_optimizer = SGD(sinkhorn_loss, 50., 50, 0.95)
sinkhorn_optimized_cloud = OptimizedPointCloud(sinkhorn_optimizer, regularized)

sliced_loss = SlicedWassersteinDistance(10, sqeuclidean)
sliced_optimizer = SGD(sliced_loss, 1., 20, 0.75)
sliced_optimized_cloud = OptimizedPointCloud(sliced_optimizer, regularized)

In [None]:
# no_resampling_data, no_resampling_grad = get_data(linspace, no_resampling, x, y)
# systematic_data, systematic_grad = get_data(linspace, systematic, x, y)
# multinomial_data, multinomial_grad = get_data(linspace, multinomial, x, y)
# stratified_data, stratified_grad = get_data(linspace, stratified, x, y)

# regularized_data, regularized_grad = get_data(linspace, regularized, x, y)
# corrected_no_grad_data, corrected_no_grad_grad = get_data(linspace, corrected_no_grad, x, y)
# corrected_data, corrected_grad = get_data(linspace, corrected, x, y)
# sliced_optimized_data, sliced_optimized_grad = get_data(linspace, sliced_optimized_cloud, x, y)
sinkhorn_optimized_data, sinkhorn_optimized_grad = get_data(linspace, sinkhorn_optimized_cloud, x, y)

  0%|                                                                                          | 0/150 [00:00<?, ?it/s]

Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.
loss [0.329190254]
0.0161165837
loss [0.225150913]
0.00572164729
loss [0.212497786]
0.00145895965
loss [0.211328596]
0.00125881936
loss [0.210395068]
0.00171431154
loss [0.209821343]
0.00139287626
loss [0.209500343]
0.000550209545
loss [0.209342316]
0.000309508061
loss [0.209246486]
0.000342474552
loss [0.209162921]
0.000414597569
loss [0.209064573]
0.000552246231
loss [0.208922938]
0.000793768966
loss [0.208697632]
0.00115333148
loss [0.208362475]
0.00145966606
loss [0.208000809]
0.00135326479
loss [0.207761198]
0.000925853499
loss [0.207650214]
0.000553482852
loss [0.207602859]
0.000324870634
loss [0.207581252]
0.000195009285
loss [0.207570061]
0.000124006649
loss [0.207563341]
8.71028751e-05
loss [0.207558811]
6.24395907e-05
loss [0.207554981]
4.97037545e-05
loss [0.207551703]
4.07951884e-05
loss [0.207548559]
3.73837538e-05
loss [0.207545757]
3.66657041e-05
loss [0.207543075]
3.59762926e-05
loss 

  1%|▌                                                                                 | 1/150 [00:07<18:36,  7.50s/it]

loss [0.326616228]
0.0159585867
loss [0.222948432]
0.00593814533
loss [0.212132573]
0.00153439678
loss [0.210941717]
0.00158034312
loss [0.209898725]
0.0019737347
loss [0.209324956]
0.000388602493
loss [0.209225804]
0.000385869294
loss [0.209170327]
0.000388691202
loss [0.209124044]
0.000356021337
loss [0.209083065]
0.000305550173
loss [0.209044352]
0.000253140926
loss [0.209003359]
0.000218093395
loss [0.208950222]
0.000381959137
loss [0.208859757]
0.000793555519
loss [0.208636343]
0.00190612208
loss [0.207938403]
0.00324318279
loss [0.206851]
0.00230120495
loss [0.206365272]
0.00130687188
loss [0.206194401]
0.000763300806
loss [0.206130296]
0.000461468473
loss [0.206106246]
0.000288170762
loss [0.206096396]
0.000185397454
loss [0.206093609]
0.000122600235
loss [0.206092775]
8.31531361e-05
loss [0.206093505]
5.7737343e-05
loss [0.206093788]
4.58396971e-05
loss [0.206094533]
4.3834094e-05
loss [0.206094906]
4.20091674e-05
loss [0.206095368]
4.03451268e-05
loss [0.206095487]
3.8828468e-

  1%|█                                                                                 | 2/150 [00:11<15:33,  6.31s/it]

loss [0.323320031]
0.0156765636
loss [0.203639567]
0.00535042
loss [0.189563513]
0.00251387479
loss [0.187036961]
0.00136436895
loss [0.186484069]
0.000426042825
loss [0.186361477]
0.000351348892
loss [0.186287493]
0.000275040045
loss [0.186231077]
0.0002106959
loss [0.186170369]
0.000345557695
loss [0.186073482]
0.000767429359
loss [0.185809731]
0.00208260212
loss [0.184852749]
0.00330899656
loss [0.183668539]
0.00173340272
loss [0.183366865]
0.00186791271
loss [0.183016986]
0.00322722644
loss [0.181934416]
0.00209486857
loss [0.181387872]
0.00156252657
loss [0.181218565]
0.00147142354
loss [0.181130692]
0.00211894698
loss [0.181043252]
0.00157474168
loss [0.181025207]
0.000996947289
loss [0.181016907]
0.000649705529
loss [0.181013525]
0.000496985391
loss [0.181017756]
0.000367961824
loss [0.181024864]
0.000274132937
loss [0.18103157]
0.000207652018
loss [0.181037053]
0.000160236377
loss [0.181041181]
0.000125966966
loss [0.181044579]
0.000100774545
loss [0.181047082]
8.19452107e-05
l

  2%|█▋                                                                                | 3/150 [00:14<13:23,  5.46s/it]

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(15, 5), sharex=True, sharey=True)
axes[0].plot(linspace, no_resampling_data, label='no resampling', linestyle='--', color='k')
axes[1].plot(linspace, no_resampling_data, label='no resampling', linestyle='--', color='k')
axes[0].step(linspace, systematic_data, label='systematic', alpha=0.75)
axes[0].step(linspace, multinomial_data, label='multinomial', alpha=0.75)
axes[0].step(linspace, stratified_data, label='stratified', alpha=0.75)
# axes[1].plot(linspace, regularized_data, label='regularized')
axes[1].plot(linspace, corrected_data, label='corrected')
axes[1].plot(linspace, corrected_no_grad_data, label='corrected_no_grad')
axes[1].plot(linspace, optimized_data, label='optimized_data')
_ = axes[0].legend(), axes[1].legend()
fig.savefig(os.path.join('./charts/', 'differentiability_illustration_likelihood.png'))

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(15, 5), sharex=True, sharey=False)
axes[0].step(linspace, no_resampling_grad, label='no resampling', linestyle='--', color='k')
axes[1].step(linspace, no_resampling_grad, label='no resampling', linestyle='--', color='k')
axes[0].step(linspace, systematic_grad, label='systematic', alpha=0.75)
axes[0].step(linspace, multinomial_grad, label='multinomial', alpha=0.75)
axes[0].step(linspace, stratified_grad, label='stratified', alpha=0.75)
# axes[1].plot(linspace, regularized_grad, label='regularized')
axes[1].plot(linspace, corrected_grad, label='corrected')
axes[1].plot(linspace, corrected_no_grad_grad, label='corrected_no_grad')
axes[1].plot(linspace, optimized_grad, label='optimized_data')
_ = axes[0].legend(), axes[1].legend()
fig.savefig(os.path.join('./charts/', 'differentiability_illustration_gradient.png'))