In [5]:
import sys
sys.path.append("../..")

In [6]:
import math
import numpy as np
import bqplot.pyplot as bplt
from ipywidgets import HBox

import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_probability as tfp

from filterflow.resampling.differentiable.loss.regularized import SinkhornLoss
from filterflow.resampling.differentiable.loss.regularized import SinkhornLoss
from filterflow.resampling.differentiable.regularized_transport.utils import cost
from filterflow.resampling.differentiable.biased import RegularisedTransform
from filterflow.resampling.standard.systematic import SystematicResampler
from filterflow.base import State


In [7]:
class OptimizedPointCloud(tf.Module):
    def __init__(self, n_particles, dimension, name="OptimizedPointCloud"):
        super(OptimizedPointCloud, self).__init__(name=name)
        self._n_particles = n_particles
        self._dimension = dimension
                
        self.cost_and_weight_repr = tf.Variable(tf.random.uniform([2,], -0.5, 0.5), name='mix_weights')
        self._contribution_weight_1 = tf.Variable(tf.random.uniform([self._n_particles, self._n_particles//2], -1./n_particles, 1./n_particles), name='contribution_weights_1')
        self._contribution_weight_2 = tf.Variable(tf.random.uniform([self._n_particles//2, self._n_particles], -1./n_particles, 1./n_particles), name='contribution_weights_2')
    
    @tf.function
    def _mix_cost_and_weight(self, cost, log_weights):
        float_n_particles = tf.cast(self._n_particles, float)
        weighted_cost = self.cost_and_weight_repr[0] * cost
        weighted_log_weights = self.cost_and_weight_repr[1] * (log_weights - tf.math.log(float_n_particles))
        
        temp = weighted_cost + tf.expand_dims(weighted_log_weights, 1)
        return 2. * tf.math.sigmoid(temp) -1.
    
    @tf.function
    def _make_contribution(self, transformed_input):
        temp = tf.matmul(transformed_input, self._contribution_weight_1)
        temp = 2. * tf.math.sigmoid(temp) - 1.
        temp = tf.matmul(temp, self._contribution_weight_2)
        return temp
        
    @tf.function
    def _input_transform(self, log_w, x):
        return self._mix_cost_and_weight(cost(x, x), log_w)
        
    @tf.function
    def _normalize(self, x):
        mean = tf.reduce_mean(x, 1, keepdims=True)
        x_ = x - mean
        std = tf.math.reduce_std(x_, 1, keepdims=True)
        return mean, std
        
    def __call__(self, log_w, x):
        mean, std = self._normalize(x)
        
        transformed_input = self._input_transform(log_w, (x - mean)/std)

        float_n_particles = tf.cast(self._n_particles, float)

        log_contribution_weights = self._make_contribution(transformed_input)
        log_contribution_weights = log_contribution_weights - tf.reduce_logsumexp(log_contribution_weights, 2, keepdims=True)
        log_contribution_weights = tf.math.log(float_n_particles) + log_contribution_weights + tf.expand_dims(log_w, 1)
        contribution_weights = tf.exp(log_contribution_weights)
#         contribution_weights = float_n_particles * contribution_weights
        z =  tf.linalg.matmul(contribution_weights, (x - mean)/std)
        
        reg_lines = tf.reduce_sum(tf.abs(tf.reduce_sum(contribution_weights, 1) - float_n_particles* tf.math.exp(log_w)))
        reg_cols = tf.reduce_sum(tf.abs(tf.reduce_sum(contribution_weights, 2) - 1.))
        return std*(z + mean), reg_lines + reg_cols
        

In [8]:
n_iter = 10000

In [9]:
B = 10
N = 25
D = 2

epsilon = tf.constant(0.25)
loss = SinkhornLoss(epsilon, True)

In [10]:
regularized_transform = RegularisedTransform(epsilon)
optimized_point_cloud = OptimizedPointCloud(N, D)
systematic_resampler = SystematicResampler()


In [11]:
scatter_fig = bplt.figure(animation_duration=0)
scatter_fig.layout.height = '500px'
scatter_fig.layout.width = '500px'
learnt_scatter = bplt.scatter([], [], size=[], colors = ['blue'])
regularized_scatter = bplt.scatter([], [], size=[], colors = ['green'], alpha=0.75)
initial_scatter = bplt.scatter([], [], size=[], colors = ['red'])
bplt.set_lim(-5., 5., 'y')
_ = bplt.set_lim(-5., 5., 'x')




In [12]:
def fill_na(tensor):
    mask = tf.math.is_finite(tensor)
    return tf.where(mask, tensor, tf.zeros_like(tensor))

In [13]:
losses_fig = bplt.figure(animation_duration=0)
losses_fig.layout.height = '500px'
losses_fig.layout.width = '500px'
systematic_loss_plot = bplt.plot([], [], colors = ['red'])
learnt_loss_plot = bplt.plot([], [], colors = ['blue'])
bplt.set_lim(0., 3., 'y')
# _ = bplt.set_lim(0., n_iter, 'x')

LinearScale(max=3.0, min=0.0)

In [14]:
HBox([scatter_fig, losses_fig])

HBox(children=(Figure(axes=[Axis(scale=LinearScale(max=5.0, min=-5.0)), Axis(orientation='vertical', scale=Lin…

In [None]:
tf.random.set_seed(666)
patience = 25
has_waited = 0
lr_decay = 0.9
current_loss = 1e5
n_descent_per_batch = 1

lr = tf.Variable(1e-2)

optimizer = tf.keras.optimizers.Adam(lr=lr)

trainable_variables = optimized_point_cloud.trainable_variables

uniform_log_weights = tf.zeros([B, N]) - math.log(N)
uniform_weights = tf.zeros([B, N]) + 1/N

systematic_losses = []
learnt_losses = []

for i in range(n_iter):
    random_x = tf.random.normal([B, N, D], -1., 1.)
    random_w = tf.random.uniform([B, N], 0., 1.) ** 2
    random_w /= tf.reduce_sum(random_w, 1, keepdims=True)
    random_log_w = tf.math.log(random_w)
    for _ in range(n_descent_per_batch):
        with tf.GradientTape() as tape:
            tape.watch(trainable_variables)
            generated_particules, reg = optimized_point_cloud(random_log_w, random_x)
            batch_loss = loss(random_log_w, random_w, random_x, uniform_log_weights, uniform_weights, generated_particules)
            loss_value = tf.reduce_mean(batch_loss + reg)
            gradients = tape.gradient(loss_value, trainable_variables)
        optimizer.apply_gradients([(fill_na(grad), var) for grad, var in zip(gradients, trainable_variables)])
    
        
    if loss_value.numpy().sum() < current_loss:
        has_waited = 0
        current_loss = loss_value.numpy().sum()
    else:
        has_waited += 1
        
    if has_waited >= patience:
        current_loss = 1e5
        lr.assign(lr * lr_decay)
        has_waited = 0
        patience = min( patience + 10, 200)
        
    state = State(random_x, random_log_w, random_w, tf.zeros([B,]))
    flags = tf.ones([B,], dtype=bool)
    systematic_state = systematic_resampler.apply(state, flags)
    systematic_loss = tf.reduce_mean(loss(random_log_w, random_w, random_x, uniform_log_weights, uniform_weights, systematic_state.particles))
    
    systematic_loss_plot.x = np.arange(i)
    learnt_loss_plot.x = np.arange(i)
    
    learnt_losses.append(tf.reduce_mean(batch_loss).numpy().sum())
    systematic_losses.append(systematic_loss.numpy().sum())
    
    
    systematic_loss_plot.y = systematic_losses
    learnt_loss_plot.y = learnt_losses
    
    if i % 50 == 0:

        regularized_state = regularized_transform.apply(state, flags)
        regularized_loss = tf.reduce_mean(loss(random_log_w, random_w, random_x, uniform_log_weights, uniform_weights, regularized_state.particles))

        
        learnt_scatter.x = generated_particules[0, :, 0]
        learnt_scatter.y = generated_particules[0, :, 1]
        
        initial_scatter.x = random_x[0, :, 0]
        initial_scatter.y = random_x[0, :, 1]
        initial_scatter.size = random_w[0] * 100
        
        regularized_scatter.x = regularized_state.particles[0, :, 0]
        regularized_scatter.y = regularized_state.particles[0, :, 1]
        
        print(f'Step {i} loss: {tf.reduce_mean(batch_loss)}, systematic comparison:{systematic_loss}, regularized comparison: {regularized_loss}, lr: {lr.numpy().sum()}, reg: {reg.numpy().sum()}')


Step 0 loss: 0.7448053956031799, systematic comparison:0.03441043943166733, regularized comparison: 0.029747450724244118, lr: 0.009999999776482582, reg: 0.3794863820075989
Step 50 loss: 0.8223794102668762, systematic comparison:0.03613913804292679, regularized comparison: 0.0281593706458807, lr: 0.009999999776482582, reg: 0.019738195464015007
Step 100 loss: 0.7418352365493774, systematic comparison:0.0219185259193182, regularized comparison: 0.028830811381340027, lr: 0.008999999612569809, reg: 0.0032449164427816868
Step 150 loss: 0.7918053865432739, systematic comparison:0.03258129954338074, regularized comparison: 0.06700434535741806, lr: 0.008099999278783798, reg: 0.0007443295326083899
Step 200 loss: 0.7546194791793823, systematic comparison:0.03899209946393967, regularized comparison: 0.02625749073922634, lr: 0.008099999278783798, reg: 0.0003857287229038775
Step 250 loss: 0.9041797518730164, systematic comparison:0.0411408506333828, regularized comparison: 0.0335797555744648, lr: 0.