In [None]:
!pip install jaxopt
!pip install ott-jax
!pip install gpjax

## GP Jax preliminaries

In [None]:
import jax
from jax.scipy.sparse.linalg import gmres
from jax.lax import custom_linear_solve
from jax.numpy.linalg import slogdet
import jax.numpy as jnp

In [None]:
from jax.config import config
# config.update("jax_enable_x64", True)

In [None]:
import jax.random as jr
import gpjax as gpx
from functools import partial

In [None]:
import seaborn as sns
# sns.set_context("notebook")
sns.set_context("paper")

## GP Sinkhorn

In [None]:
import jax.numpy as jnp

pad_size = 2  # pixels cropped from every border of the 28x28 image


def cloud_coordinates():
  """Return the 2D coordinates of every pixel of the cropped image.

  The cropped image has ``28 - 2*pad_size`` pixels per side. Each axis is
  mapped linearly onto [-1, 1], so the result is a regular grid.

  Returns:
    Array of shape (side*side, 2): column 0 is x (varies fastest), column 1
    is y, enumerated in row-major order of the image.
  """
  side = 28 - 2*pad_size
  ticks = jnp.linspace(-1, 1., num=side, endpoint=True)
  xs, ys = jnp.meshgrid(ticks, ticks)
  # one row per pixel: (x, y)
  return jnp.stack([xs.flatten(), ys.flatten()], axis=1)

In [None]:
from typing import Any, Optional
from dataclasses import dataclass

from jaxopt.tree_util import tree_l2_norm

from ott.geometry import pointcloud
import ott.core.sinkhorn as sinkhorn

from jaxopt import LBFGS


def mu_cloud_embedding(cloud, mu,
                       init_dual=(None, None),
                       **kwargs):
  """Embed a batch of pixel clouds by solving Sinkhorn against a reference mu.

  Args:
    cloud: batch of per-pixel weights, one row per image; every image shares
      the pixel grid returned by `cloud_coordinates()`.
    mu: pair `(mu_cloud, mu_weight)`. `mu_cloud` holds the support points of
      the reference measure; `mu_weight` holds pre-softmax weights, or None.
    init_dual: pair of batched dual variables used to warm start the solver
      (first entry goes to `init_dual_a`, second to `init_dual_b`).
    **kwargs: must contain `sinkhorn_epsilon`; everything else is forwarded
      untouched to `sinkhorn.sinkhorn`.

  Returns:
    `(outs.g, (outs.f, outs.g))`: the `g` field of the batched solver output
    (presumably the dual potential on mu's support -- confirm against OTT),
    and the pair of duals to feed back as `init_dual` on the next call.
  """
  warm_start_a, warm_start_b = init_dual  # for warm start
  source_weights = cloud  # unpack distribution
  support, logits = mu  # unpack distribution

  epsilon = kwargs.pop('sinkhorn_epsilon')
  pixel_grid = cloud_coordinates()

  # softmax guarantees a probability vector; None is passed through as is
  target_weights = None if logits is None else jax.nn.softmax(logits)

  # invariance by translation: recenter mu around its mean, then squash
  # every coordinate into (-scale, scale) with tanh
  centered = support - jnp.mean(support, axis=0, keepdims=True)
  scale = 1.0
  target_points = scale * jnp.tanh(centered)

  # a single geometry shared by all images of the batch
  geometry = pointcloud.PointCloud(pixel_grid, target_points,
                                   epsilon=epsilon)

  def solve_one(weights_one, dual_a, dual_b):
    return sinkhorn.sinkhorn(geometry, weights_one, target_weights,
                             init_dual_a=dual_a,
                             init_dual_b=dual_b,
                             **kwargs)

  batched_solve = jax.vmap(solve_one, in_axes=(0, 0, 0), out_axes=0)
  solved = batched_solve(source_weights, warm_start_a, warm_start_b)

  next_duals = solved.f, solved.g  # for warm start
  return solved.g, next_duals

In [None]:
def mean_cloud_embedding(cloud, mu_params, init_dual, **kwargs):
  """Embed each cloud by the squared distances from its barycenter to mu.

  Args:
    cloud: pair `(coordinates, weights)`; coordinates are indexed as
      (batch, point, dim) and weights as (batch, point).
    mu_params: pair whose first entry holds the reference points, indexed as
      (mu_point, dim); the second entry is ignored.
    init_dual: ignored, accepted for signature compatibility.
    **kwargs: ignored.

  Returns:
    `(distances, None)` where `distances[b, m]` is the squared Euclidean
    distance between the weighted mean of cloud `b` and reference point `m`.
  """
  del init_dual  # unused
  points, masses = cloud
  anchors, _ = mu_params

  weighted_points = points * masses[:, :, jnp.newaxis]
  barycenter = jnp.sum(weighted_points, axis=1, keepdims=True)

  offsets = barycenter - anchors[jnp.newaxis, :, :]
  squared_dist = jnp.sum(offsets**2, axis=-1)
  return squared_dist, None

In [None]:
def mu_uniform_ball(sample_train, key, mu_size, radius=0.5, with_weight=False):
  """Draw the initial support of mu inside a ball around the data centroid.

  Directions are isotropic (normalised Gaussian draws) and lengths are
  `radius * U` with `U ~ Uniform[0, 1)`, so points lie within `radius` of the
  centre. Note that a length uniform in [0, 1) concentrates points near the
  centre rather than spreading them uniformly over the ball's volume.

  Args:
    sample_train: batch of per-pixel weights, one row per image, used to
      locate the average centroid of the training images.
    key: JAX PRNG key.
    mu_size: number of support points to draw.
    radius: radius of the ball.
    with_weight: when True also return zero-initialised pre-softmax weights.

  Returns:
    `(mu_cloud, mu_weight)`; `mu_weight` is None unless `with_weight` is set.
  """
  pixel_grid = cloud_coordinates()
  ndim = pixel_grid.shape[-1]
  direction_key, length_key = jax.random.split(key)

  # random unit vectors
  directions = jax.random.normal(direction_key, shape=(mu_size, ndim))
  lengths = jnp.sqrt(jnp.sum(directions**2, axis=1, keepdims=True))
  directions = directions / lengths

  # random length in [0, radius)
  fractions = jax.random.uniform(length_key, shape=(mu_size, 1))
  offsets = directions * radius * fractions

  # centroid of each image, then the mean centroid over the training set
  per_image = jnp.sum(sample_train[:,:,jnp.newaxis] * pixel_grid[jnp.newaxis,:,:], axis=1)
  center = jnp.mean(per_image, axis=0, keepdims=True)
  mu_cloud = offsets + center  # OT is invariant by translation

  mu_weight = jnp.zeros(len(mu_cloud)) if with_weight else None  # before softmax
  return mu_cloud, mu_weight

## Dataset

In [None]:
import tensorflow as tf
tf.config.experimental.set_visible_devices([], 'GPU')

def img_to_cloud(image):
  """Turn a 28x28 image into a vector of normalised pixel weights.

  `pad_size` pixels are cropped from every border, the remaining pixels are
  flattened in row-major order and divided by their sum so they add up to 1.

  Args:
    image: 2D array indexed as (row, column), expected to be 28x28.

  Returns:
    1D array of weights, one per pixel of the cropped image.
  """
  img_size = 28
  inner = slice(pad_size, img_size-pad_size, None)
  pixels = image[inner, inner].flatten()
  return pixels / jnp.sum(pixels)


# 4,6 for mnist toy
# 5,7 for sandals, sneakers
# 0,5 for tee-shirt, sandals
def process_mnist(seed, ds_size, digits=[4, 6]):  
  """Build a shuffled two-class subset of the MNIST test split.

  Args:
    seed: integer seed of the PRNG key used to shuffle the subset.
    ds_size: number of examples kept after shuffling.
    digits: the two labels to keep; the first maps to target 0, the second
      to target 1.

  Returns:
    `(sample_cloud, target, sample)`: the normalised pixel weights of each
    kept image, the binary targets as a column of shape (n, 1), and the raw
    images in the same order.
  """
  # only the test split is used; the training split is discarded right away
  # (swap in tf.keras.datasets.fashion_mnist.load_data() for Fashion-MNIST)
  _, (images, labels) = tf.keras.datasets.mnist.load_data()

  # select two classes
  first_class = images[labels == digits[0]]
  second_class = images[labels == digits[1]]

  # build subset
  sample = jnp.concatenate([first_class, second_class])
  target = jnp.concatenate([jnp.zeros(len(first_class)),
                            jnp.ones(len(second_class))]).reshape((-1, 1))

  # shuffle data, then keep few points
  order = jax.random.permutation(jax.random.PRNGKey(seed), len(sample))
  sample = sample[order][:ds_size]
  target = target[order][:ds_size]

  # make a cloud
  to_clouds = jax.vmap(img_to_cloud, in_axes=0, out_axes=0)
  return to_clouds(sample), target, sample

## Training loop


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pprint import PrettyPrinter
import numpy as onp

# Shared pretty-printer used to dump nested parameter dictionaries.
pp = PrettyPrinter(indent=4)

def plot_mu_points(mu_points, ax):
  """Plot the trajectory followed by every support point of mu.

  Args:
    mu_points: sequence of T arrays of shape (num_pts, 2), one per optimisation
      step, holding the raw (unconstrained) coordinates of mu.
    ax: matplotlib axes to draw on.

  The same recentering + tanh squashing as in `mu_cloud_embedding` is applied
  before plotting, so the figure shows the points actually fed to Sinkhorn.
  """
  mu_points = onp.array(mu_points)
  T = mu_points.shape[0]
  num_pts = mu_points.shape[1]

  # long-format labels: point ids cycle fastest, timesteps slowest
  points = [str(i) for i in range(1, num_pts+1)] * T
  timesteps = onp.repeat(onp.arange(T), num_pts)

  mu_points = mu_points - jnp.mean(mu_points, axis=1, keepdims=True)  # invariant by translation
  scale = 1.0
  # materialise as NumPy so pandas/seaborn receive plain arrays
  mu_points = onp.asarray(scale * jnp.tanh(mu_points)).reshape((-1, 2))

  df = pd.DataFrame({'x':mu_points[:,0], 'y':mu_points[:,1],
                     'points':points, 'timesteps':timesteps})
  sns.lineplot(data=df, x='x', y='y', hue='points', markers='x', sort=False, marker='o', linestyle='-', ms=15., ax=ax)
  # raw string: '\m' is not a valid escape sequence in a normal literal
  ax.set_title(r'$\mu$ coordinates')

def plot_loss(losses, ax):
  """Draw the loss history as a line, one value per optimisation step.

  Args:
    losses: sequence of scalar loss values, in chronological order.
    ax: matplotlib axes to draw on.
  """
  steps = onp.arange(len(losses))
  sns.lineplot(x=steps, y=losses, ax=ax)
  ax.set(xlabel='Timesteps',
         ylabel='Loss',
         title='Negative Log Marginal Likelihood')

def plot_mu_weights(mu_weights, ax):
  """Plot how the (softmax-normalised) weights of mu evolve over time.

  Args:
    mu_weights: sequence of T arrays of shape (num_pts,) holding pre-softmax
      weights, one per optimisation step. May be None, or contain None
      entries, when mu was trained without learnable weights.
    ax: matplotlib axes to draw on.

  Returns:
    None. Nothing is drawn when no weights are available.
  """
  # The history is a list with one entry per step, so "no weights" shows up
  # as a list of None rather than as None itself; handle both.
  if mu_weights is None or any(w is None for w in mu_weights):
    return None
  mu_weights = onp.array(mu_weights)
  T = mu_weights.shape[0]
  num_pts = mu_weights.shape[1]

  # long-format labels: point ids cycle fastest, timesteps slowest
  points = [str(i) for i in range(1, num_pts+1)] * T
  timesteps = onp.repeat(onp.arange(T), num_pts)

  mu_weights = jax.nn.softmax(mu_weights.reshape((T, num_pts)), axis=-1)
  # materialise as NumPy so pandas/seaborn receive a plain array
  mu_weights = onp.asarray(mu_weights).flatten()

  # raw strings: '\m' is not a valid escape sequence in a normal literal
  weight_col = r'$\mu$ weights'
  df = pd.DataFrame({'timesteps':timesteps, 'points':points, weight_col:mu_weights})
  sns.lineplot(data=df, x='timesteps', y=weight_col, hue='points', markers='x', sort=False, marker='.', linestyle='-', ax=ax)
  ax.set_xlabel('Timesteps')
  ax.set_ylabel(weight_col)
  ax.set_title(weight_col)
  sns.move_legend(ax, "lower left")

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, explained_variance_score


def evaluate(cloud_embedding_fn, kernel_params, mu_params,
             posterior, likelihood, constrainer, unconstrainer,
             sample_train, sample_test, y_train, y_test, prefix=''):
  """Score a fitted GP on `sample_test` and print a one-line summary.

  Args:
    cloud_embedding_fn: callable `(samples, mu_params) -> (embedding, aux)`.
    kernel_params: GP parameters in constrained space.
    mu_params: parameters of the reference measure mu.
    posterior, likelihood: GPJax posterior and likelihood objects.
    constrainer, unconstrainer: GPJax parameter transforms.
    sample_train, y_train: data the GP is conditioned on.
    sample_test, y_test: data the predictions are scored against.
    prefix: text inserted before "Acc=" in the printed message.

  Returns:
    `(evs, acc, log_likelihood)`: explained variance of the predictive mean,
    accuracy of the predictive mean thresholded at 0.5, and the marginal
    log-likelihood on the training data. `evs` and `acc` are NaN if scoring
    raised.
  """
  train_embedding, _ = cloud_embedding_fn(sample_train, mu_params)
  test_embedding, _ = cloud_embedding_fn(sample_test, mu_params)
  train_data = gpx.Dataset(X=train_embedding, y=y_train)

  predict_latent = posterior(train_data, kernel_params)
  predictive_dist = likelihood(predict_latent(test_embedding), kernel_params)
  predictive_mean = predictive_dist.mean()
  predictive_std = predictive_dist.stddev()  # computed but not reported

  try:
    predictive_mean = predictive_mean.flatten()
    flat_targets = y_test.flatten()
    evs = explained_variance_score(flat_targets, predictive_mean)
    label_pred = (predictive_mean >= 0.5).astype(y_test.dtype)
    acc = jnp.mean(label_pred == flat_targets)
  except Exception:
    # best effort: keep reporting the likelihood even if scoring failed
    evs = float('nan')
    acc = float('nan')

  # the marginal likelihood expects parameters in unconstrained space
  mll_fn = posterior.marginal_log_likelihood(train_data, constrainer)
  log_likelihood = mll_fn(gpx.transform(kernel_params, unconstrainer))

  print(f"[GPJAX] TrainSetSize={len(train_embedding)} {prefix}Acc={acc*100:.3f}% evs={evs:.5f} log-likelihood={log_likelihood:.3f}")
  return evs, acc, log_likelihood

In [None]:
from tqdm import tqdm


def prettify_mu_str(mu):
  """Format the first support points of mu as "[(x,y), (x,y), ...]".

  Args:
    mu: pair whose first entry is a sequence of 2D points.

  Returns:
    A string listing at most 12 points, each coordinate with 3 decimals.
  """
  max_len = 12
  formatted = []
  for point in mu[0][:max_len]:
    formatted.append(f"({point[0]:.3f},{point[1]:.3f})")
  return "[" + ", ".join(formatted) + "]"


def learn(opt, opt_update, loss_fn,
          init_mu, cloud_embedding_fn,
          posterior, likelihood, constrainer, unconstrainer,
          sample_train, sample_test,
          y_train, y_test,
          verbose=False):
  """Jointly optimise the GP kernel parameters and the reference measure mu.

  Args:
    opt: solver exposing `maxiter` and `init_state(params, init_dual=...)`.
    opt_update: callable `(params, opt_state, init_dual) -> (params, opt_state)`
      performing one solver step.
    loss_fn: callable `(params, init_dual) -> (loss, new_init_dual)`.
    init_mu: callable `sample_train -> mu_params` producing the initial mu.
    cloud_embedding_fn: callable `(samples, mu_params) -> (embedding, duals)`.
    posterior, likelihood: GPJax posterior and likelihood objects.
    constrainer, unconstrainer: GPJax parameter transforms.
    sample_train, y_train: training data.
    sample_test, y_test: held-out data, scored once after the last step.
    verbose: when True, print the initial kernel parameters.

  Returns:
    `(kernel_params, mu_params, (mu_hist, losses, test_metrics, train_metrics))`
    where `kernel_params` is in constrained space, `mu_hist` is a pair of lists
    (support points, weights) with one entry per step including the initial
    state, `losses` is the matching loss history, and the metrics are the
    tuples returned by `evaluate`. `train_metrics` is None if no step was run.
  """
  ## init GP params
  parameter_state = gpx.initialise(posterior, key=None)
  constrained_kernel_params, _, _, _ = parameter_state.unpack()
  kernel_params = gpx.transform(constrained_kernel_params, unconstrainer)

  if verbose:
    # each label now matches the object printed after it
    print('Constrained params:', end='')
    pp.pprint(constrained_kernel_params)
    print('Unconstrained params:', end='')
    pp.pprint(kernel_params)

  ## init Mu Sinkhorn
  mu_params = init_mu(sample_train)
  print("μ: " + prettify_mu_str(mu_params))

  ## Parameters to be optimized by LBFGS
  params = {'kernel_params':kernel_params, 'mu_params':mu_params}

  ## precomputation for speed-up: a first solve provides warm-start duals
  init_dual = cloud_embedding_fn(sample_train, mu_params)[1]

  opt_state = opt.init_state(params, init_dual=init_dual)
  mu_hist = [params['mu_params'][0]], [params['mu_params'][1]]
  losses = [float(loss_fn(params, init_dual)[0])]
  log_rate = 1
  train_metrics = None  # stays None when opt.maxiter == 0
  pb = tqdm(total=opt.maxiter)
  for step in range(opt.maxiter):
    params, opt_state = opt_update(params, opt_state, init_dual)
    init_dual = opt_state.aux  # duals of this step warm-start the next one
    loss_val = opt_state.value
    mu_hist[0].append(params['mu_params'][0])
    mu_hist[1].append(params['mu_params'][1])
    losses.append(float(loss_val))
    if step % log_rate == 0:
      pb.update(log_rate)
      mu_params = params['mu_params']
      # evaluate() expects constrained parameters
      kernel_params = gpx.transform(params['kernel_params'], constrainer)
      train_metrics = evaluate(cloud_embedding_fn, kernel_params, mu_params,
           posterior, likelihood, constrainer, unconstrainer,
           sample_train, sample_train, y_train, y_train, prefix='Train')
      pb.set_postfix({"Objective": f"{loss_val: .2f}",
                      "TrainAcc" : train_metrics[1]*100,
                      })
  pb.close()
  print('')

  mu_params = params['mu_params']
  kernel_params = gpx.transform(params['kernel_params'], constrainer)
  pp.pprint(kernel_params)

  test_metrics = evaluate(cloud_embedding_fn, kernel_params, mu_params,
           posterior, likelihood, constrainer, unconstrainer,
           sample_train, sample_test, y_train, y_test, prefix='Test')

  return kernel_params, mu_params, (mu_hist, losses, test_metrics, train_metrics)

In [None]:
from jaxopt import OptaxSolver
import optax


def run_experiment(mu_sizes, seeds, sample_train, sample_test, y_train, y_test, *,
                   with_weight):
  """Train one GP-on-Sinkhorn-embedding model per (mu size, seed) pair.

  Args:
    mu_sizes: iterable of numbers of support points to try for mu.
    seeds: iterable of integer seeds used to initialise mu.
    sample_train, y_train: training clouds and targets.
    sample_test, y_test: held-out clouds and targets.
    with_weight: keyword-only; forwarded to `mu_uniform_ball` to decide whether
      mu carries learnable weights.

  Returns:
    `((kernel_params, mu_params), df_stats, metrics)`. `df_stats` has one row
    per (mu size, seed) with the test metrics and a `mu_size` column. The
    parameters and `metrics` are those of the LAST run only.

  Side effects:
    Creates a figure with one row per mu size and three columns (mu
    trajectories, mu weights, loss), filled from the last seed of each size.
  """
  ncols = 3
  f, axes = plt.subplots(nrows=len(mu_sizes), ncols=ncols)
  # plt.subplots squeezes singleton dimensions; force a 2D (row, col) layout
  axes = onp.array(axes).reshape((len(mu_sizes), ncols))

  kernel = gpx.RBF()
  prior = gpx.Prior(kernel=kernel)
  likelihood = gpx.Bernoulli(num_datapoints=len(sample_train))
  # likelihood = gpx.Gaussian(num_datapoints=len(sample_train))
  posterior = prior * likelihood

  # only the transforms and the trainable mask are needed here; the initial
  # parameter values are re-created inside learn()
  parameter_state = gpx.initialise(posterior, key=None)
  _, trainable, constrainer, unconstrainer = parameter_state.unpack()

  # `sinkhorn_epsilon` is consumed by mu_cloud_embedding; the remaining
  # entries are forwarded to the Sinkhorn solver
  kwargs = dict(
      sinkhorn_epsilon               = 1e-1 ,
      lse_mode                       = True ,
      implicit_differentiation       = False,
      implicit_solver_ridge_kernel   = 1e-2 ,  # promote zero sum solutions
      implicit_solver_ridge_identity = 1e-2 ,  # regul for ill-posed problem
  )

  cloud_embedding_fn = partial(mu_cloud_embedding, **kwargs)

  def loss_fn(params, init_dual):
    # Negative marginal log-likelihood of the GP fitted on the embedding of
    # the training clouds; the updated duals are returned as auxiliary output.
    kernel_params = params['kernel_params']
    mu_params = params['mu_params']
    X_train, init_dual = cloud_embedding_fn(sample_train, mu_params, init_dual)
    kernel_params = gpx.parameters.trainable_params(kernel_params, trainable)
    D = gpx.Dataset(X=X_train, y=y_train)
    nll = posterior.marginal_log_likelihood(D, constrainer, negative=True)
    return nll(kernel_params), init_dual

  # has_aux=True: loss_fn returns (value, aux) and aux ends up in opt_state.aux
  opt = LBFGS(fun=loss_fn, maxiter=120, tol=1e-3, maxls=20, has_aux=True)
  # optax_opt = optax.adam(learning_rate=5e-2)
  # opt = OptaxSolver(opt=optax_opt, fun=loss_fn, maxiter=100, has_aux=True)

  @jax.jit
  def opt_update(params, opt_state, init_dual):
    # one jitted solver step
    return opt.update(params, opt_state, init_dual=init_dual)

  df_stats = []

  for mu_size, ax in zip(mu_sizes, axes):

    print( "##########################################")
    print(f"########### |μ|={mu_size:9d} ################")
    print( "##########################################")

    test_metrics_avg = []
    for i, seed in enumerate(seeds):

      key = jax.random.PRNGKey(seed)
      init_mu = partial(mu_uniform_ball, key=key,
                        mu_size=mu_size, with_weight=with_weight)

      kernel_params, mu_params, metrics = learn(opt, opt_update, loss_fn,
            init_mu, cloud_embedding_fn,
            posterior, likelihood, constrainer, unconstrainer,
            sample_train, sample_test,
            y_train, y_test)
      
      mu_hist, losses, test_metrics, train_metrics = metrics
      test_metrics_avg.append(test_metrics)

      # only the last seed of each mu size is plotted
      if i+1 == len(seeds):
        plot_mu_points(mu_hist[0], ax[0])
        ax[0].axis('equal')
        plot_mu_weights(mu_hist[1], ax[1])
        plot_loss(losses, ax[2])
        
    # one row per seed: (evs, acc, log_likelihood) as returned by evaluate()
    test_metrics_avg = onp.array(test_metrics_avg)
    df = pd.DataFrame(data=test_metrics_avg, columns=['EVS', 'TestAcc', 'log_likelihood'])
    df['mu_size'] = mu_size
    df_stats.append(df)

  df_stats = pd.concat(df_stats, axis=0)
  return (kernel_params, mu_params), df_stats, metrics

In [None]:
from sklearn.model_selection import train_test_split

# Experiment configuration.
mu_sizes = [4]  # [1, 2, 3, 4, 5, 7, 10, 15, 20]
seeds = [113]  #, 11, 55, 79, 46, 98, 73, 22, 34, 76]
train_size = 200
test_size = 1000  # less stochastic.
ds_size = train_size + test_size
ds_seeds = [911]

# Test accuracy of the last run of each dataset seed.
tests_accs = []
for ds_seed in ds_seeds:
  sample_ds, target_ds, sample_naked = process_mnist(seed=ds_seed, ds_size=ds_size)
  sample_train, sample_test, y_train, y_test = train_test_split(sample_ds, target_ds, train_size=train_size, shuffle=True, random_state=89)
  plt.rcParams["figure.figsize"] = (24, 8*len(mu_sizes))
  (kernel_params, mu_params), df_stats, metrics = run_experiment(mu_sizes, seeds, sample_train, sample_test, y_train, y_test, with_weight=True)
  # metrics is (mu_hist, losses, test_metrics, train_metrics): the test
  # metrics sit at index 2 (index 1 is the loss history), and accuracy is the
  # second entry of the (evs, acc, log_likelihood) tuple.
  test_metric = metrics[2]
  test_acc = test_metric[1]
  tests_accs.append(test_acc)

In [None]:
# Average the test metrics over seeds, separately for every size of mu.
df_stats.groupby('mu_size').mean()
# df_stats.to_csv('toy2d.csv')

In [None]:
# Summary figure for the last run: mu trajectories drawn over one training
# image (left) and the evolution of the mu weights (right).
plt.rcParams["figure.figsize"] = (16+8, 8*len(mu_sizes))
sns.set(font_scale=2)
ncols = 2
f, axes = plt.subplots(nrows=1, ncols=ncols)
# force a 2D (row, col) layout so that `ax` below is a row of axes
axes = onp.array(axes).reshape((1, ncols))
mu_hist, losses, test_metrics, train_metrics = metrics
for ax in axes:
  # 24 = 28 - 2*pad_size, the side of the cropped image; the extent matches
  # the [-1, 1] range used by cloud_coordinates().
  # NOTE(review): imshow defaults to origin='upper', which puts image row 0 at
  # the top, whereas cloud_coordinates() gives row 0 the coordinate y=-1 --
  # the overlay may be vertically mirrored; confirm.
  ax[0].imshow(sample_train[2].reshape((24, 24)), extent=onp.array([-1, 1, -1, 1]), cmap='plasma')
  plot_mu_points(mu_hist[0], ax[0])
  ax[0].axis('equal')
  plot_mu_weights(mu_hist[1], ax[1])

In [None]:
def save_mu(mu_hist, losses):
  """Write the history of mu and of the loss to 'toy_metric.csv'.

  Args:
    mu_hist: pair `(mu_points, mu_weights)`. `mu_points` is a sequence of T
      arrays of shape (num_pts, 2); `mu_weights` is a sequence of T arrays of
      shape (num_pts,) holding pre-softmax weights, or of T `None` entries
      when mu was trained without learnable weights.
    losses: sequence of T loss values.

  The CSV is in long format, one row per (timestep, point), with the
  softmax-normalised weight, the recentered + tanh-squashed coordinates and
  the loss of the timestep. The weight column is NaN when no weights were
  learned.
  """
  mu_points, mu_weights = mu_hist
  mu_points = onp.array(mu_points)

  T = mu_points.shape[0]
  num_pts = mu_points.shape[1]

  # long-format labels: point ids cycle fastest, timesteps slowest
  points = [str(i) for i in range(1, num_pts+1)] * T
  timesteps = onp.repeat(onp.arange(T), num_pts)

  if any(w is None for w in mu_weights):
    # no learnable weights were recorded: mark the column as missing
    mu_weights = onp.full((T, num_pts), onp.nan)
  else:
    mu_weights = onp.asarray(jax.nn.softmax(onp.array(mu_weights), axis=-1))
  mu_weights = mu_weights.flatten()

  # invariant by translation
  mu_points = mu_points - jnp.mean(mu_points, axis=1, keepdims=True)
  scale = 1.
  # materialise as NumPy so pandas receives plain arrays
  mu_points = onp.asarray(scale * jnp.tanh(mu_points)).reshape((-1, 2))
  x_coord = mu_points[:,0]
  y_coord = mu_points[:,1]

  # repeat each loss once per point to match the long format
  losses = onp.repeat(onp.asarray(losses), num_pts)

  # raw string: '\m' is not a valid escape sequence in a normal literal
  df = pd.DataFrame({'timesteps':timesteps, 'points':points,
                     r'$\mu$ weights':mu_weights,
                     'x':x_coord, 'y':y_coord,
                     'losses':losses})
  df.to_csv('toy_metric.csv')


def save_metrics(metrics):
  """Persist the outcome of a run to 'toy_metric.csv' and 'toy_score.csv'.

  Args:
    metrics: tuple `(mu_hist, losses, test_metrics, train_metrics)` as returned
      by `learn`; both metric entries are `(evs, acc, log_likelihood)` tuples
      produced by `evaluate`.

  The history of mu and of the loss is delegated to `save_mu`; the scores are
  written with one row for the training set and one for the test set.
  """
  mu_hist, losses, test_metrics, train_metrics = metrics
  save_mu(mu_hist, losses)
  # rows: train first, then test (matches the 'name' column below)
  train_test_metrics = onp.array([train_metrics, test_metrics])
  # evaluate() returns exactly three values: explained variance, accuracy and
  # log-likelihood, so the frame needs three matching column names
  df_metrics = pd.DataFrame(data=train_test_metrics, columns=['evs', 'acc', 'log_likelihood'])
  df_metrics['name'] = ['train', 'test']
  df_metrics.to_csv('toy_score.csv')

In [None]:
# Re-evaluate the last trained model on the held-out set. The GP objects and
# the Sinkhorn settings are rebuilt here with the same values as inside
# run_experiment(), because that function does not return them.
kwargs = dict(
      sinkhorn_epsilon               = 1e-1 ,
      lse_mode                       = True ,
      implicit_differentiation       = False,
      implicit_solver_ridge_kernel   = 1e-2 ,  # promote zero sum solutions
      implicit_solver_ridge_identity = 1e-2 ,  # regul for ill-posed problem
)
cloud_embedding_fn = partial(mu_cloud_embedding, **kwargs)
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
likelihood = gpx.Bernoulli(num_datapoints=len(sample_train))
# likelihood = gpx.Gaussian(num_datapoints=len(sample_train))
posterior = prior * likelihood
# only the parameter transforms are needed; `trainable` is unused in this cell
parameter_state = gpx.initialise(posterior, key=None)
_, trainable, constrainer, unconstrainer = parameter_state.unpack()
# kernel_params / mu_params come from the experiment cell above; the returned
# (evs, acc, log_likelihood) tuple is displayed as the cell output
evaluate(cloud_embedding_fn, kernel_params, mu_params,
           posterior, likelihood, constrainer, unconstrainer,
           sample_train, sample_test, y_train, y_test)

In [None]:
# Display the learned (support points, pre-softmax weights) of mu.
mu_params