In [None]:
import sys
# Add the directory two levels up to the import path so the `filterflow`
# imports below resolve (presumably the repository root - confirm against
# where this notebook lives).
sys.path.append("../..")

In [2]:
import math
import numpy as np
import bqplot.pyplot as bplt
from ipywidgets import HBox

import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_probability as tfp

from filterflow.resampling.differentiable.loss.regularized import SinkhornLoss
from filterflow.resampling.differentiable.regularized_transport.utils import cost
from filterflow.resampling.differentiable.biased import RegularisedTransform
from filterflow.resampling.standard.systematic import SystematicResampler
from filterflow.base import State


In [3]:
class OptimizedPointCloud(tf.Module):
    """Learnable map from a weighted particle cloud to a new particle cloud.

    The module owns two trainable variables:

    * ``cost_and_weight_repr`` (shape ``[2]``): scalar gains applied to the
      pairwise cost matrix and to the log-weights before they are mixed.
    * ``_contribution_weight`` (shape ``[n_particles, n_particles]``): a matrix
      applied to the mixed features; a softmax of the product gives the
      coefficients used to recombine the centred input particles.

    Shapes are not enforced here; the notebook feeds ``x`` as
    ``[batch, n_particles, dimension]`` and ``log_w`` as ``[batch, n_particles]``.
    """

    def __init__(self, n_particles, dimension, name="OptimizedPointCloud"):
        """Create the trainable variables.

        :param n_particles: number of particles per cloud; fixes the size of the
            square contribution matrix.
        :param dimension: stored on the instance but not otherwise used here.
        :param name: name passed to ``tf.Module``.
        """
        super().__init__(name=name)
        self._n_particles = n_particles
        self._dimension = dimension

        # Creation order is kept (gains first, then the square matrix) so the
        # random initial values are drawn in the same sequence.
        gains_init = tf.random.normal([2])
        self.cost_and_weight_repr = tf.Variable(gains_init, name='mix_weights')

        square_init = tf.random.normal([self._n_particles, self._n_particles])
        self._contribution_weight = tf.Variable(square_init, name='contribution_weights')

    @tf.function
    def _mix_cost_and_weight(self, cost_matrix, log_weights):
        """Blend the cost matrix and the log-weights, squashed into (-1, 1)."""
        cost_gain = self.cost_and_weight_repr[0]
        weight_gain = self.cost_and_weight_repr[1]

        scaled_cost = cost_gain * cost_matrix
        # A new axis is inserted at position 1 so the scaled log-weights
        # broadcast against the cost matrix.
        scaled_log_weights = tf.expand_dims(weight_gain * log_weights, 1)

        return 2. * tf.math.sigmoid(scaled_cost + scaled_log_weights) - 1.

    @tf.function
    def _make_contribution(self, transformed_input):
        """Left-multiply by the learnt matrix and softmax over the last axis."""
        logits = tf.matmul(self._contribution_weight, transformed_input)
        return tf.nn.softmax(logits)

    @tf.function
    def _input_transform(self, log_w, x):
        """Build the mixed feature tensor from ``cost(x, x)`` and ``log_w``."""
        self_cost = cost(x, x)
        return self._mix_cost_and_weight(self_cost, log_w)

    @tf.function
    def _normalize(self, x):
        """Return the mean and standard deviation of ``x`` along axis 1 (dims kept)."""
        mean = tf.reduce_mean(x, 1, keepdims=True)
        std = tf.math.reduce_std(x - mean, 1, keepdims=True)
        return mean, std

    def __call__(self, log_w, x):
        """Transform a weighted cloud.

        :param log_w: log-weights of the input particles.
        :param x: input particles.
        :return: tuple ``(particles, reg)`` where ``particles`` is the
            transformed cloud (re-centred on the input mean) and ``reg`` is a
            scalar penalty on the axis-1 sums of the contribution matrix.
        """
        # Only the mean is used; the standard deviation is computed but ignored.
        mean, _unused_std = self._normalize(x)
        centred = x - mean

        features = self._input_transform(log_w, centred)
        contribution_weights = self._make_contribution(features)

        n_float = tf.cast(self._n_particles, float)
        moved = n_float * tf.linalg.matmul(contribution_weights, centred)

        # NOTE(review): the softmax rows sum to 1, so `axis_sums` totals 1 per
        # batch element while `target` totals n_particles when the weights are
        # normalised - the two sides look to be on different scales. Confirm
        # this is intended; the arithmetic is left exactly as it was.
        axis_sums = tf.reduce_sum(contribution_weights / n_float, 1)
        target = n_float * tf.math.exp(log_w)
        reg_lines = tf.reduce_mean(tf.abs(axis_sums - target))

        return moved + mean, reg_lines
        

In [4]:
# Problem sizes used throughout the notebook:
# B = clouds per batch, N = particles per cloud, D = particle dimension.
B = 50
N = 25
D = 2

# Regularisation strength shared by the Sinkhorn loss and the regularised transform.
epsilon = tf.constant(0.25)
# NOTE(review): the meaning of the positional `True` is defined by SinkhornLoss,
# which is not visible here - confirm against its signature.
loss = SinkhornLoss(epsilon, True)

In [5]:
# The learnt module plus the two reference resamplers it is compared against
# in the training loop.
regularized_transform = RegularisedTransform(epsilon)
optimized_point_cloud = OptimizedPointCloud(N, D)
systematic_resampler = SystematicResampler()


In [6]:
# Live scatter plot of the first cloud of the current batch. The marks start
# empty and are filled in by the training loop every 50 iterations.
scatter_fig = bplt.figure(animation_duration=0)
scatter_fig.layout.height = '500px'
scatter_fig.layout.width = '500px'
# blue: particles produced by the learnt module
learnt_scatter = bplt.scatter([], [], size=[], colors = ['blue'])
# green: particles produced by the regularised transform
regularized_scatter = bplt.scatter([], [], size=[], colors = ['green'])
# red: input particles, marker size driven by their weights
initial_scatter = bplt.scatter([], [], size=[], colors = ['red'])
bplt.set_lim(-2., 2., 'y')
# Assigned to `_` so the returned scale is not echoed as the cell output.
_ = bplt.set_lim(-2., 2., 'x')




In [7]:
def fill_na(tensor):
    """Return ``tensor`` with every non-finite entry replaced by zero.

    Despite the name this covers ``+inf`` and ``-inf`` as well as ``NaN``,
    because the selection is driven by ``tf.math.is_finite``.
    """
    zeros = tf.zeros_like(tensor)
    is_valid = tf.math.is_finite(tensor)
    return tf.where(is_valid, tensor, zeros)

In [8]:
# Number of iterations of the training loop; a fresh random batch is drawn on each one.
n_iter = 10000

In [9]:
# Live loss curves, appended to on every training iteration.
losses_fig = bplt.figure(animation_duration=0)
losses_fig.layout.height = '500px'
losses_fig.layout.width = '500px'
# red: loss of the systematic resampler on the same batch (baseline)
systematic_loss_plot = bplt.plot([], [], colors = ['red'])
# blue: loss of the learnt module
learnt_loss_plot = bplt.plot([], [], colors = ['blue'])
bplt.set_lim(0., 3., 'y')
# The fixed x-range below is disabled, so no explicit x limits are set on this figure.
# _ = bplt.set_lim(0., n_iter, 'x')




LinearScale(max=3.0, min=0.0)

In [10]:
# Show the particle scatter and the loss curves side by side; both are updated
# in place while the training cell below runs.
HBox([scatter_fig, losses_fig])

HBox(children=(Figure(axes=[Axis(scale=LinearScale(max=2.0, min=-2.0)), Axis(orientation='vertical', scale=Lin…

In [None]:
# Train `optimized_point_cloud` on freshly sampled weighted clouds and compare
# it, batch by batch, with systematic resampling and the regularised transform.
tf.random.set_seed(666)

# Plateau-based learning-rate decay: when the loss has not improved for
# `patience` consecutive iterations the rate is multiplied by `lr_decay`.
patience = 25
has_waited = 0
lr_decay = 0.9
current_loss = 1e5
n_descent_per_batch = 1

# Kept in a tf.Variable so the schedule can change it in place with `assign`.
lr = tf.Variable(1e-2)

# `learning_rate` is the documented argument name; `lr` is a deprecated alias.
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

trainable_variables = optimized_point_cloud.trainable_variables

# Uniform target weights for the output cloud, in log and linear scale.
uniform_log_weights = tf.zeros([B, N]) - math.log(N)
uniform_weights = tf.zeros([B, N]) + 1/N

# Loop-invariant: every batch element is flagged for resampling.
flags = tf.ones([B,], dtype=bool)

systematic_losses = []
learnt_losses = []

for i in range(n_iter):
    # Draw a batch of clouds uniformly in [-1, 1]^D; squaring the uniform draws
    # before normalising makes the weights less even.
    random_x = tf.random.uniform([B, N, D], -1., 1.)
    random_w = tf.random.uniform([B, N], 0., 1.) ** 2
    random_w /= tf.reduce_sum(random_w, 1, keepdims=True)
    random_log_w = tf.math.log(random_w)

    for _ in range(n_descent_per_batch):
        # Trainable variables are watched automatically, so no explicit
        # `tape.watch` is needed.
        with tf.GradientTape() as tape:
            generated_particules, reg = optimized_point_cloud(random_log_w, random_x)
            batch_loss = loss(random_log_w, random_w, random_x, uniform_log_weights, uniform_weights, generated_particules)
            loss_value = tf.reduce_mean(batch_loss + reg)
        # Differentiate outside the context so the backward pass is not itself
        # recorded on the tape.
        gradients = tape.gradient(loss_value, trainable_variables)
        # Non-finite gradient entries are zeroed rather than applied.
        optimizer.apply_gradients([(fill_na(grad), var) for grad, var in zip(gradients, trainable_variables)])

    step_loss = loss_value.numpy().sum()
    if step_loss < current_loss:
        has_waited = 0
        current_loss = step_loss
    else:
        has_waited += 1

    if has_waited >= patience:
        # Decay the rate, reset the reference loss and wait longer next time
        # (patience grows by 10, capped at 200).
        current_loss = 1e5
        lr.assign(lr * lr_decay)
        has_waited = 0
        patience = min(patience + 10, 200)

    # Baseline: systematic resampling of the same batch, scored with the same loss.
    state = State(random_x, random_log_w, random_w, tf.zeros([B,]))
    systematic_state = systematic_resampler.apply(state, flags)
    systematic_loss = tf.reduce_mean(loss(random_log_w, random_w, random_x, uniform_log_weights, uniform_weights, systematic_state.particles))

    learnt_losses.append(tf.reduce_mean(batch_loss).numpy().sum())
    systematic_losses.append(systematic_loss.numpy().sum())

    # Build x from the history length *after* appending so x and y always have
    # the same number of points (previously x was one element short).
    steps = np.arange(len(learnt_losses))
    systematic_loss_plot.x = steps
    learnt_loss_plot.x = steps
    systematic_loss_plot.y = systematic_losses
    learnt_loss_plot.y = learnt_losses

    if i % 50 == 0:
        # The regularised transform is only evaluated on reporting steps.
        regularized_state = regularized_transform.apply(state, flags)
        regularized_loss = tf.reduce_mean(loss(random_log_w, random_w, random_x, uniform_log_weights, uniform_weights, regularized_state.particles))

        # Refresh the scatter plot with the first cloud of the batch.
        learnt_scatter.x = generated_particules[0, :, 0]
        learnt_scatter.y = generated_particules[0, :, 1]

        initial_scatter.x = random_x[0, :, 0]
        initial_scatter.y = random_x[0, :, 1]
        initial_scatter.size = random_w[0] * 100

        regularized_scatter.x = regularized_state.particles[0, :, 0]
        regularized_scatter.y = regularized_state.particles[0, :, 1]

        print(f'Step {i} loss: {tf.reduce_mean(batch_loss)}, systematic comparison:{systematic_loss}, regularized comparison: {regularized_loss}, lr: {lr.numpy().sum()}')


Step 0 loss: 10.261614799499512, systematic comparison:0.000904199609067291, regularized comparison: -0.0265206266194582, lr: 0.009999999776482582
Step 50 loss: 0.11571108549833298, systematic comparison:0.004766217898577452, regularized comparison: -0.030225159600377083, lr: 0.009999999776482582
Step 100 loss: 0.027278423309326172, systematic comparison:0.003684626892209053, regularized comparison: -0.02957852929830551, lr: 0.009999999776482582
Step 150 loss: 0.030860750004649162, systematic comparison:0.004141117446124554, regularized comparison: -0.02832578495144844, lr: 0.008999999612569809
Step 200 loss: 0.014565377496182919, systematic comparison:-0.000919794081710279, regularized comparison: -0.02932380884885788, lr: 0.008099999278783798
Step 250 loss: 0.03676600381731987, systematic comparison:0.005799758248031139, regularized comparison: -0.024094391614198685, lr: 0.008099999278783798
Step 300 loss: 0.025295810773968697, systematic comparison:9.612459507479798e-06, regularized

In [None]:
# Evaluation batch drawn from a different distribution than training:
# positions uniform in [-4, 1) instead of [-1, 1), and weights taken as the
# cube (rather than the square) of uniform draws before normalisation.
random_x = tf.random.uniform(shape=[B, N, D], minval=-4., maxval=1.)
random_w = tf.random.uniform(shape=[B, N], minval=0., maxval=1.) ** 3

# Normalise each cloud's weights to sum to one, then move to log scale.
random_w = random_w / tf.reduce_sum(random_w, axis=1, keepdims=True)
random_log_w = tf.math.log(random_w)

In [None]:
# Effective sample size of each cloud, 1 / sum_i w_i^2, computed from the
# normalised weights (one value per batch element).
1 / tf.reduce_sum(random_w ** 2, 1)

In [None]:
# Run the module on the evaluation batch; index 0 keeps the transformed
# particles and drops the regularisation term.
new_particles = optimized_point_cloud(random_log_w, random_x)[0]

In [None]:
# Display the transformed particles for inspection.
new_particles