In [None]:
# Dependency setup. None of the installs is pinned, so a fresh run may resolve
# different versions than the recorded output below (jaxopt 0.5, GPJax 0.4.13).
!pip install jaxopt
!pip install gpjax
!pip install ott-jax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jaxopt
  Downloading jaxopt-0.5-py3-none-any.whl (128 kB)
[K     |████████████████████████████████| 128 kB 4.8 MB/s 
Installing collected packages: jaxopt
Successfully installed jaxopt-0.5
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gpjax
  Downloading GPJax-0.4.13.tar.gz (24 kB)
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 7.1 MB/s 
[?25hCollecting chex
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[K     |████████████████████████████████| 85 kB 4.9 MB/s 
[?25hCollecting distrax>=0.1.2
  Downloading distrax-0.1.2-py3-none-any.whl (272 kB)
[K     |████████████████████████████████| 272 kB 45.1 MB/s 
Collecting ml-collections==0.1.0
  Downloading ml_collections-0.1.0-py3-none-any.whl (88 kB)
[K     |████████████████████████████████|

## GP Jax preliminaries

In [None]:
import jax
from jax.scipy.sparse.linalg import gmres
from jax.lax import custom_linear_solve
from jax.numpy.linalg import slogdet
import jax.numpy as jnp

In [None]:
from jax.config import config
# config.update("jax_enable_x64", True)

In [None]:
import jax.random as jr
import gpjax as gpx
from functools import partial

In [None]:
import seaborn as sns
# sns.set_context("notebook")
sns.set_context("paper")

## GP Sinkhorn

In [None]:
import jax.numpy as jnp

In [None]:
# Reference measures found by running the notebook LearnSinkhornParameters.ipynb.

# Each reference measure is a pair (mu_cloud_*, mu_weight_*): an array of 2-D
# support points and one raw weight per point. mu_cloud_embedding recentres and
# tanh-squashes the points and applies a softmax to the weights, so the values
# stored here are unconstrained parameters, not a normalised distribution.
# NOTE(review): the numeric suffix presumably names the class pair the measure
# was learned on (e.g. 46 -> classes 4 and 6) — confirm against the notebook
# that produced them.

# 12 support points.
mu_cloud_57 = jnp.array([[ 0.21957754,  0.35253927],
              [ 0.09953722,  0.43077431],
              [ 0.9886206 ,  0.11487304],
              [-0.83145712,  0.1669271 ],
              [ 0.27657292,  0.12463248],
              [-0.06175206,  0.21355569],
              [-0.04373814,  0.23610753],
              [ 0.30061063,  0.15764027],
              [ 0.10935807,  0.41247011],
              [-0.01664592,  0.29173418],
              [ 0.87559493, -0.05147166],
              [-0.08324035,  0.21878795]], dtype=jnp.float32)
# NOTE(review): no explicit dtype here (nor on the *_56 arrays below), unlike
# the other measures; the resulting dtype follows JAX's default.
mu_weight_57 = jnp.array([ 9.62563782e-03,  8.53329137e-03, -1.21150131e-01,
               4.27755472e-01, -3.74874932e-04, -2.53500209e-02,
               1.69277536e-03, -3.16603754e-02, -1.62934858e-02,
              -2.89706357e-02, -1.59076041e-01, -6.47316129e-02])

# 12 support points.
mu_cloud_56 = jnp.array([[-0.18145709,  0.20396382],
              [-0.17958481,  0.20788199],
              [-1.44540799, -1.63469062],
              [ 2.0360876 , -0.12869805],
              [ 1.57829722, -0.18106633],
              [-0.26955691,  0.09504965],
              [ 0.34187511, -0.54820849],
              [ 0.30926302, -0.48505812],
              [ 0.04961068,  1.89436538],
              [ 0.05049443,  1.9277267 ],
              [-1.39822343, -1.20600227],
              [ 0.04183291,  1.54766735]])
mu_weight_56 = jnp.array([ 1.11278815,  1.0128134 , -0.35208267, -0.2118818 ,
              -0.13187659, -0.0698886 ,  0.27349365,  0.07536408,
              -0.57247179, -0.55270456, -0.04717465, -0.53637863])

# 6 support points.
mu_cloud_46 = jnp.array([[ 1.3449012 ,  2.13853616],
              [ 0.45771468, -0.16584545],
              [-1.39573938, -1.01686017],
              [-0.19325787,  1.67778046],
              [-0.48856439, -2.10435212],
              [ 0.5307469 , -0.81295445]], dtype=jnp.float32)
mu_weight_46 = jnp.array([-0.16987539,  0.04750267,  0.46313042, -0.18759383,
               0.05566591, -0.20882978], dtype=jnp.float32)

# 5 support points.
mu_cloud_46_5 = jnp.array([[ 0.22206083,  1.71624839],
              [-0.43609571, -0.50998092],
              [-0.19480746,  0.93104592],
              [-0.38481522, -2.075897  ],
              [-0.12979581, -0.1755457 ]], dtype=jnp.float32)
mu_weight_46_5 = jnp.array([-0.13618056,  0.35912072, -0.59738579,  0.60334502,
              -0.2288994 ], dtype=jnp.float32)

# 4 support points.
mu_cloud_46_4 = jnp.array([[-0.19373875, -1.8789795 ],
              [-0.14753972,  1.90018948],
              [ 0.16684833, -0.23762417],
              [ 0.11818121, -0.75299912]], dtype=jnp.float32)
mu_weight_46_4 = jnp.array([ 0.99958116,  0.82718838,  1.90731148, -3.73408101], dtype=jnp.float32)

# Active reference measure, read as a global by run_benchmark_Sinkhorn.
mu = mu_cloud_46_4, mu_weight_46_4

In [None]:
def cloud_coordinates(image_size, pad_size):
  """Return the 2-D coordinates of every pixel of the cropped image grid.

  The grid has ``image_size - 2 * pad_size`` points per side, evenly spaced on
  [-1, 1] with both end points included.

  Args:
    image_size: side length of the (square) image before cropping.
    pad_size: number of pixels removed on each border.

  Returns:
    float32 array of shape (side * side, 2) holding (x, y) pairs; x varies
    fastest along the first axis.
  """
  side = image_size - 2*pad_size
  ticks = jnp.linspace(-1, 1., num=side, endpoint=True)
  xs, ys = jnp.meshgrid(ticks, ticks)
  # Stack the two flattened coordinate planes, then transpose to (points, 2).
  points = jnp.stack([xs.ravel(), ys.ravel()]).T
  return jnp.array(points, dtype=jnp.float32)

In [None]:
from typing import Any, Optional
from dataclasses import dataclass
from jaxopt.tree_util import tree_l2_norm
from jaxopt import LBFGS

In [None]:
def mean_cloud_embedding(cloud, mu_params, init_dual, **kwargs):
  """Embed each cloud by the squared distances from its weighted mean to mu's points.

  Args:
    cloud: pair (coordinates, weights); the indexing used below requires
      coordinates of rank 3 (batch, points, dims) and weights of rank 2
      (batch, points).
    mu_params: pair whose first entry is the (num_mu_points, dims) reference
      cloud; the second entry is ignored.
    init_dual: ignored.
    **kwargs: ignored.

  Returns:
    (squared_distances, None), with squared_distances of shape
    (batch, num_mu_points).
  """
  del init_dual  # unused
  points, masses = cloud
  anchors, _ = mu_params
  # Weighted sum over the points axis; kept as (batch, 1, dims) for broadcasting.
  barycenter = jnp.sum(points * masses[:,:,jnp.newaxis], axis=1, keepdims=True)
  offsets = barycenter - anchors[jnp.newaxis,:,:]
  squared_dist = jnp.sum(offsets**2, axis=-1)
  return squared_dist, None

In [None]:
from ott.geometry import pointcloud
import ott.core.sinkhorn as sinkhorn


def mu_cloud_embedding(cloud, mu,
                       image_size=28,
                       pad_size=0,
                       **kwargs):
  """Embed a batch of image clouds through Sinkhorn against a reference measure.

  Args:
    cloud: batch of per-pixel weights; the leading axis is mapped over with
      vmap, one Sinkhorn problem per entry.
    mu: pair (mu_cloud, mu_weight) describing the reference measure. The points
      are recentred around their mean and squashed with tanh; the weights go
      through a softmax (or stay None when mu_weight is None).
    image_size: image side length, forwarded to cloud_coordinates.
    pad_size: border width, forwarded to cloud_coordinates.
    **kwargs: must contain 'sinkhorn_epsilon' (used as the geometry's epsilon);
      every other entry is forwarded to sinkhorn.sinkhorn.

  Returns:
    (outs.g, (outs.f, outs.g)) from the batched Sinkhorn output: ``g`` is used
    as the embedding and the (f, g) pair is returned for warm starts.

  Raises:
    KeyError: if 'sinkhorn_epsilon' is not supplied.
  """
  batch_weights = cloud
  ref_points, ref_logits = mu

  epsilon = kwargs.pop('sinkhorn_epsilon')
  pixel_coords = cloud_coordinates(image_size=image_size, pad_size=pad_size)

  # Softmax keeps the reference weights a probability distribution.
  if ref_logits is not None:
    mu_w = jax.nn.softmax(ref_logits)
  else:
    mu_w = None

  # Translation invariance: recentre the reference points around their mean,
  # then bound them with tanh.
  centred = ref_points - jnp.mean(ref_points, axis=0, keepdims=True)
  scale = 1.0
  mu_c = scale * jnp.tanh(centred)

  # One geometry shared by every image in the batch.
  geom = pointcloud.PointCloud(pixel_coords, mu_c,
                               epsilon=epsilon)

  def solve_one(single_weights):
    return sinkhorn.sinkhorn(geom, single_weights, mu_w, **kwargs)

  outs = jax.vmap(solve_one, in_axes=0, out_axes=0)(batch_weights)
  warm_start = (outs.f, outs.g)
  return outs.g, warm_start
  
# Default settings for the Sinkhorn embedding. mu_cloud_embedding pops
# `sinkhorn_epsilon` (used as the PointCloud geometry's epsilon) and forwards
# every remaining entry to sinkhorn.sinkhorn as keyword arguments.
kwargs = dict(
    sinkhorn_epsilon               = 1e-1 ,
    lse_mode                       = True ,
    implicit_differentiation       = False,
    implicit_solver_ridge_kernel   = 1e-2 ,  # promote zero sum solutions
    implicit_solver_ridge_identity = 1e-2 ,  # regul for ill-posed problem
)

# mu_cloud_embedding with the settings above pre-bound; callers still supply
# `cloud`, `mu` and optionally `image_size` / `pad_size`.
cloud_embedding_fn = partial(mu_cloud_embedding, **kwargs)

## Dataset

In [None]:
import tensorflow as tf
from jax.image import resize
tf.config.experimental.set_visible_devices([], 'GPU')

def img_to_cloud(image, image_size, pad_size):
  """Turn one image into a flat vector of pixel weights that sums to one.

  The image is resized to (image_size, image_size) with cubic interpolation,
  `pad_size` pixels are cropped from every border, and the remaining pixels
  are flattened and divided by their sum.

  Args:
    image: 2-D array of pixel intensities.
    image_size: target side length of the resize.
    pad_size: number of border pixels dropped on each side after resizing.

  Returns:
    float32 vector with ``(image_size - 2 * pad_size) ** 2`` entries.
    NOTE: the division is unguarded, so an image whose cropped pixels sum to
    zero yields non-finite weights.
  """
  resized = resize(image, (image_size, image_size), "cubic")
  keep = slice(pad_size, image_size-pad_size, None)
  cropped = resized[keep, keep]
  mass = cropped.flatten()
  mass = mass / jnp.sum(mass)
  return jnp.array(mass, dtype=jnp.float32)


# Class pairs to pass as `digits`:
#   4, 6 for the MNIST toy task;
#   with the Fashion-MNIST loader: 5, 7 for sandals vs. sneakers,
#   and 0, 5 for tee-shirts vs. sandals.
def process_mnist(seed, ds_size, image_size=28, pad_size=2, digits=(4, 6)):
  """Build a shuffled two-class subset of the MNIST test split as weight clouds.

  Args:
    seed: integer seed of the JAX PRNG key used to shuffle the subset.
    ds_size: number of examples kept after shuffling (fewer are returned if the
      two classes contain fewer examples).
    image_size: side length images are resized to before cropping.
    pad_size: border width cropped from each resized image.
    digits: the two class labels to keep; ``digits[0]`` becomes target 0 and
      ``digits[1]`` becomes target 1. An immutable tuple is used as the default
      so the default value can never be mutated between calls.

  Returns:
    (sample_cloud, target, sample): the per-image weight vectors produced by
    img_to_cloud, the (n, 1) array of 0/1 targets, and the raw images.
  """
  train_mnist, (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  # train_mnist, (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
  del train_mnist  # unused
  # select two classes
  target_0 = y_test == digits[0]
  target_1 = y_test == digits[1]
  sample_0 = x_test[target_0]
  sample_1 = x_test[target_1]
  # build subset
  sample = jnp.concatenate([sample_0, sample_1])
  target = jnp.concatenate([jnp.zeros(len(sample_0)), jnp.ones(len(sample_1))])
  target = target.reshape((-1, 1))
  # shuffle data
  key = jax.random.PRNGKey(seed)
  indices = jax.random.permutation(key, len(sample))
  sample = sample[indices]
  target = target[indices]
  # keep few points
  sample = sample[:ds_size]
  target = target[:ds_size]
  # make a cloud
  img_to_cloud_fun = partial(img_to_cloud, image_size=image_size, pad_size=pad_size)
  sample_cloud = jax.vmap(img_to_cloud_fun, in_axes=0, out_axes=0)(sample)
  return sample_cloud, target, sample

## Training loop


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pprint import PrettyPrinter
import numpy as onp

pp = PrettyPrinter(indent=4)

def plot_loss(losses, ax):
  """Draw the loss history on `ax`, one point per recorded value.

  Args:
    losses: sequence of loss values, in the order they were recorded.
    ax: matplotlib axes to draw on; its labels and title are overwritten.
  """
  steps = onp.arange(len(losses))
  sns.lineplot(x=steps, y=losses, ax=ax)
  ax.set_xlabel('Timesteps')
  ax.set_ylabel('Loss')
  ax.set_title('Negative Log Marginal Likelihood')

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, explained_variance_score


def evaluate(kernel_params,
             posterior, likelihood, constrainer, unconstrainer,
             sample_train, sample_test, y_train, y_test, prefix=''):
  """Score a fitted GP on `sample_test` and print a one-line summary.

  Args:
    kernel_params: constrained parameters; used as-is for prediction and
      mapped through `unconstrainer` before the marginal log-likelihood call.
    posterior: GPJax posterior; conditioned on (sample_train, y_train).
    likelihood: likelihood mapping the latent distribution to predictions.
    constrainer: transformation passed to posterior.marginal_log_likelihood.
    unconstrainer: transformation applied to `kernel_params` via gpx.transform.
    sample_train: training inputs.
    sample_test: inputs to predict on.
    y_train: training targets.
    y_test: targets compared with the predictions.
    prefix: label inserted before "Acc=" in the printed message.

  Returns:
    (evs, acc, log_likelihood): explained variance of the predictive mean,
    accuracy of the predictive mean thresholded at 0.5, and the marginal
    log-likelihood on the training data. `evs` and `acc` are NaN when the
    metric computation fails.
  """
  X_train = sample_train
  X_test = sample_test
  D = gpx.Dataset(X=X_train, y=y_train)
  posterior_fn = posterior(D, kernel_params)

  latent_dist = posterior_fn(X_test)
  predictive_dist = likelihood(latent_dist, kernel_params)
  predictive_mean = predictive_dist.mean()

  try:
    predictive_mean = predictive_mean.flatten()
    evs = explained_variance_score(y_test.flatten(), predictive_mean)
    label_pred = (predictive_mean >= 0.5).astype(y_test.dtype)
    acc = jnp.mean(label_pred == y_test.flatten())
  except Exception as e:
    # Metrics are best-effort: keep going with NaNs, but report the cause so a
    # broken prediction is not mistaken for a merely poor one.
    print(f"[GPJAX] {prefix}metrics unavailable ({type(e).__name__}: {e})")
    evs = float('nan')
    acc = float('nan')

  log_likelihood = posterior.marginal_log_likelihood(D, constrainer)(gpx.transform(kernel_params, unconstrainer))
  msg = f"[GPJAX] TrainSetSize={len(X_train)} {prefix}Acc={acc*100:.3f}% evs={evs:.5f} log-likelihood={log_likelihood:.3f}"
  print(msg)
  return evs, acc, log_likelihood

In [None]:
from tqdm import tqdm


def learn(opt, opt_update, loss_fn,
          posterior, likelihood, constrainer, unconstrainer,
          sample_train, sample_test,
          y_train, y_test,
          verbose=False):
  """Optimise the GP parameters and report train/test metrics.

  Parameters are initialised with gpx.initialise(posterior, key=None), moved to
  unconstrained space, and updated for exactly ``opt.maxiter`` steps (the loop
  has no early exit). Training metrics are evaluated and shown in the progress
  bar after every step.

  Args:
    opt: solver exposing `init_state` and `maxiter`.
    opt_update: callable (params, opt_state) -> (params, opt_state); the loss
      is read from ``opt_state.value``.
    loss_fn: objective taking ``{'kernel_params': ...}``.
    posterior, likelihood, constrainer, unconstrainer: forwarded to evaluate.
    sample_train, sample_test, y_train, y_test: data forwarded to evaluate.
    verbose: when True, pretty-print the initial parameters in both spaces.

  Returns:
    (kernel_params, (losses, test_metrics, train_metrics)) where kernel_params
    are the final constrained parameters, losses has one entry per step plus
    the initial loss, and train_metrics is None when no step was run.
  """
  ## init GP params
  parameter_state = gpx.initialise(posterior, key=None)
  constrained_kernel_params, _, _, _ = parameter_state.unpack()
  kernel_params = gpx.transform(constrained_kernel_params, unconstrainer)

  if verbose:
    # The labels now match the objects printed; they were swapped before.
    print('Constrained params:', end='')
    pp.pprint(constrained_kernel_params)
    print('Unconstrained params:', end='')
    pp.pprint(kernel_params)

  ## Parameters to be optimized by LBFGS
  params = {'kernel_params':kernel_params}

  opt_state = opt.init_state(params)
  losses = [float(loss_fn(params))]
  log_rate = 1
  # Defined up front so the return below is valid even when opt.maxiter == 0.
  train_metrics = None
  pb = tqdm(range(opt.maxiter))
  for step in range(opt.maxiter):
    params, opt_state = opt_update(params, opt_state)
    loss_val = opt_state.value
    losses.append(float(loss_val))
    if step % log_rate == 0:
      pb.update(log_rate)
      kernel_params = params['kernel_params']
      kernel_params = gpx.transform(kernel_params, constrainer)
      # Evaluate on the training set itself (train data used as "test" data).
      train_metrics = evaluate(kernel_params,
           posterior, likelihood, constrainer, unconstrainer,
           sample_train, sample_train, y_train, y_train, prefix='Train')
      pb.set_postfix({"Objective": f"{loss_val: .2f}",
                      "TrainAcc" : train_metrics[1]*100
                      })
  pb.close()
  print('')

  kernel_params = params['kernel_params']
  kernel_params = gpx.transform(kernel_params, constrainer)
  pp.pprint(kernel_params)

  test_metrics = evaluate(kernel_params,
           posterior, likelihood, constrainer, unconstrainer,
           sample_train, sample_test, y_train, y_test, prefix='Test')

  return kernel_params, (losses, test_metrics, train_metrics)

In [None]:
from typing import Dict
from jaxtyping import Array, Float
from dataclasses import dataclass


@dataclass(repr=False)
class MMD_Mnist(gpx.kernels.Kernel):
  """GP kernel on image weight vectors built from a Gaussian-kernel discrepancy.

  For two weight vectors x and y defined on the same pixel grid, the kernel
  evaluates ``d = sqrt(len(x) * len(y)) * sum_ij (x_i x_j + y_i y_j - 2 x_i y_j) K_ij``
  with K the Gaussian Gram matrix of the pixel coordinates, and returns
  ``variance * exp(-smoothness * d**2)``.
  """
  image_size: int = 28
  pad_size: int = 0
  name: Optional[str] = "Maximum Mean Discrepancy"

  def __post_init__(self, ):
    """Cache the pixel coordinates and their pairwise Gaussian Gram matrix."""
    self.ndims = 1 if not self.active_dims else len(self.active_dims)
    self.coords = cloud_coordinates(image_size=self.image_size, pad_size=self.pad_size)

    # Computed once here so __call__ only has to read self.K.
    pts = self.coords / 1.0  # shape (size_cloud, 2)
    diff = pts[:,jnp.newaxis,:] - pts[jnp.newaxis,:,:]  # shape (size_cloud, size_cloud, 2)
    sq_dist = jnp.sum(diff ** 2, axis=-1)
    self.K = jnp.exp(-0.5 * sq_dist)  # shape (size_cloud, size_cloud)

  def __call__(self, x, y, params):
    """Evaluate the kernel on one pair of weight vectors.

    Args:
      x: 1-D weight vector over the pixel grid.
      y: 1-D weight vector over the same grid.
      params: dict providing "smoothness " (the key ends with a space, matching
        _initialise_params) and "variance".

    Returns:
      The similarity value with singleton dimensions squeezed out.
    """
    x_col, x_row = x[:,jnp.newaxis], x[jnp.newaxis,:]
    y_col, y_row = y[:,jnp.newaxis], y[jnp.newaxis,:]

    # Outer-product coefficients; shape (size_cloud, size_cloud).
    coeffs = x_col * x_row + y_col * y_row - 2*(x_col * y_row)
    coeffs = coeffs * (len(x)*len(y))**0.5
    mmd_dst = jnp.sum(coeffs * self.K)  # shape (,)

    decay = jnp.exp(-params["smoothness "] * mmd_dst**2)
    return jnp.squeeze(params["variance"] * decay)

  def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
    """Return the default parameter dict; `key` is not used."""
    defaults = {}
    defaults["lengthscale"] = jnp.array([1.0])
    defaults["smoothness "] = jnp.array([1.0])
    defaults["variance"] = jnp.array([1.0])
    return defaults

In [None]:
import time

def run_benchmark_MMD(n, m):
  """Time one jitted evaluation of the full MMD Gram matrix and print it.

  Args:
    n: number of images requested from process_mnist.
    m: target number of pixels per image; the image side is ``round(m**0.5)``,
      so the actual pixel count is that side squared.
  """
  side = round(m**0.5)
  border = 0
  clouds, _, _ = process_mnist(seed=42, ds_size=n, image_size=side, pad_size=border)

  mmd_kernel = MMD_Mnist(image_size=side, pad_size=border)
  mmd_params = mmd_kernel._initialise_params(42)

  def pair_value(x, y):
    return mmd_kernel(x, y, mmd_params)

  # Inner vmap runs over the first argument, outer vmap over the second.
  over_first = jax.vmap(pair_value, in_axes=(0, None))
  gram_fn = jax.jit(jax.vmap(over_first, in_axes=(None, 0)))

  # Warm-up call so compilation is excluded from the timed run.
  gram_fn(clouds, clouds).block_until_ready()

  tic = time.perf_counter()
  gram_fn(clouds, clouds).block_until_ready()
  toc = time.perf_counter()

  print(f"{toc - tic:0.6f} seconds")

In [None]:
run_benchmark_MMD(50, 100)

0.001422 seconds


In [None]:
run_benchmark_MMD(100, 100)

0.004920 seconds


In [None]:
run_benchmark_MMD(100, 400)

0.055241 seconds


In [None]:
run_benchmark_MMD(400, 400)

0.682734 seconds


In [None]:
run_benchmark_MMD(400, 625)

1.681392 seconds


In [None]:
run_benchmark_MMD(1000, 625)

10.833395 seconds


In [None]:
run_benchmark_MMD(1000, 1000)

14.206577 seconds


In [None]:
def run_benchmark_Sinkhorn(n, m):
  """Time one jitted Sinkhorn-embedding + RBF Gram matrix evaluation and print it.

  The timed inputs are random weight vectors (normalised to sum to one) with
  the same shape as the clouds process_mnist returns; MNIST is only loaded to
  obtain that shape. The inputs come from a fixed-seed generator so repeated
  runs time the same problem.

  Args:
    n: number of clouds requested from process_mnist.
    m: target number of pixels per cloud; the image side is ``round(m**0.5)``.
  """
  image_size = round(m**0.5)
  pad_size = 0
  sample_ds, _, _ = process_mnist(seed=42, ds_size=n, image_size=image_size, pad_size=pad_size)

  kernel = gpx.kernels.RBF()
  # Fixed parameters; the kernel's own initial parameters are not needed.
  kernel_params = {'lengthscale': jnp.array([0.3], dtype=jnp.float32), 'variance': jnp.array([1.], dtype=jnp.float32)}

  cloud_embedding_fn_closure = partial(cloud_embedding_fn, image_size=image_size, pad_size=pad_size)

  def run_ker(embed_x, embed_y):
    return kernel(embed_x, embed_y, kernel_params)

  def bench_ker(x, y):
    # Embed both batches, then evaluate the kernel on every pair of embeddings.
    embed_x, _ = cloud_embedding_fn_closure(cloud=x, mu=mu)
    embed_y, _ = cloud_embedding_fn_closure(cloud=y, mu=mu)
    run_ker_left = jax.vmap(run_ker, in_axes=(0, None))
    run_ker_right = jax.vmap(run_ker_left, in_axes=(None, 0))
    return run_ker_right(embed_x, embed_y)

  # Strictly positive random weights, normalised per cloud.
  rng = onp.random.default_rng(42)
  sample_ds = rng.uniform(size=sample_ds.shape) + 0.1
  sample_ds = sample_ds / onp.sum(sample_ds, axis=-1, keepdims=True)
  sample_ds = jnp.array(sample_ds, dtype=jnp.float32)

  # jit for speed.
  jitted_run_ker = jax.jit(bench_ker)

  # Warm-up call so compilation is excluded from the timed run.
  jitted_run_ker(sample_ds, sample_ds).block_until_ready()

  tic = time.perf_counter()
  jitted_run_ker(sample_ds, sample_ds).block_until_ready()
  toc = time.perf_counter()

  print(f"{toc - tic:0.6f} seconds")

run_benchmark_Sinkhorn(50, 100)

0.026048 seconds


In [None]:
run_benchmark_Sinkhorn(50, 100)

0.026524 seconds


In [None]:
run_benchmark_Sinkhorn(100, 100)

0.033414 seconds


In [None]:
run_benchmark_Sinkhorn(100, 400)

0.024946 seconds


In [None]:
run_benchmark_Sinkhorn(400, 400)

0.071065 seconds


In [None]:
run_benchmark_Sinkhorn(400, 625)

0.097655 seconds


In [None]:
run_benchmark_Sinkhorn(1000, 625)

0.115044 seconds


In [None]:
run_benchmark_Sinkhorn(1000, 1000)

0.156624 seconds


In [None]:
from jaxopt import OptaxSolver
import optax


def run_experiment(seeds, sample_train, sample_test, y_train, y_test):
  """Fit the MMD-kernel GP classifier with L-BFGS, once per entry of `seeds`.

  NOTE(review): the seed values are not passed on to training — learn()
  initialises parameters with ``key=None`` — so every iteration appears to
  repeat the same run. Confirm whether the seeds were meant to drive
  initialisation.

  Args:
    seeds: non-empty sequence; its length sets the number of repetitions.
    sample_train: training inputs.
    sample_test: test inputs.
    y_train: training targets.
    y_test: test targets.

  Returns:
    ((kernel_params,), metrics) of the last repetition, with metrics being the
    (losses, test_metrics, train_metrics) tuple returned by learn. The loss
    curve of the last repetition is plotted on a new figure.

  Raises:
    ValueError: if `seeds` is empty.
  """
  if len(seeds) == 0:
    # Without this check the return below fails on unbound locals.
    raise ValueError("run_experiment needs at least one seed")

  f, ax = plt.subplots(nrows=1, ncols=1)

  kernel = MMD_Mnist()  # gpx.RBF()
  prior = gpx.Prior(kernel=kernel)
  likelihood = gpx.Bernoulli(num_datapoints=len(sample_train))
  # likelihood = gpx.Gaussian(num_datapoints=len(sample_train))
  posterior = prior * likelihood

  parameter_state = gpx.initialise(posterior, key=None)
  _, trainable, constrainer, unconstrainer = parameter_state.unpack()

  def loss_fn(params):
    # Negative marginal log-likelihood of the training data.
    kernel_params = params['kernel_params']
    X_train = sample_train
    kernel_params = gpx.parameters.trainable_params(kernel_params, trainable)
    D = gpx.Dataset(X=X_train, y=y_train)
    nll = posterior.marginal_log_likelihood(D, constrainer, negative=True)
    return nll(kernel_params)

  opt = LBFGS(fun=loss_fn, maxiter=120, tol=1e-3, maxls=20, has_aux=False)
  # optax_opt = optax.adam(learning_rate=5e-2)
  # opt = OptaxSolver(opt=optax_opt, fun=loss_fn, maxiter=100, has_aux=True)

  @jax.jit
  def opt_update(params, opt_state):
    return opt.update(params, opt_state)

  for i, seed in enumerate(seeds):

    kernel_params, metrics = learn(opt, opt_update, loss_fn,
          posterior, likelihood, constrainer, unconstrainer,
          sample_train, sample_test,
          y_train, y_test)

    losses, test_metrics, train_metrics = metrics

    # Only the final repetition is plotted.
    if i+1 == len(seeds):
      ax.axis('equal')
      plot_loss(losses, ax)

  return (kernel_params,), metrics

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
from sklearn.model_selection import train_test_split
# Driver: train and test the GP on one MNIST subset per entry of `ds_seeds`
# and collect the test accuracies in `tests_accs`.
# NOTE(review): run_experiment builds a PRNG key from each entry of `seeds`
# but never uses it, so `seeds` only sets the number of repetitions.
seeds = [997]  #, 11, 55, 79, 46, 98, 73, 22, 34, 76]
plt.rcParams["figure.figsize"] = (24, 8)
train_size = 200
test_size = 1000  # less stochastic.
ds_size = train_size + test_size
# NOTE(review): 987 appears twice below, so two of the 25 runs use the same
# shuffled subset and that split counts double in the mean/std computed later.
ds_seeds = [615, 31, 987, 156, 987, 29, 68, 648, 21, 94, 49, 165, 1, 64561, 471, 32, 986, 7, 38, 968, 14, 65, 78, 9, 33]
tests_accs = []
for ds_seed in ds_seeds:
  sample_ds, target_ds, sample_naked = process_mnist(seed=ds_seed, ds_size=ds_size)
  # sample_ds, _ = cloud_embedding_fn(sample_ds, mu)
  # The train/test split uses a fixed random_state; only the subset drawn by
  # process_mnist changes with ds_seed.
  sample_train, sample_test, y_train, y_test = train_test_split(sample_ds, target_ds, train_size=train_size, shuffle=True, random_state=89)
  (kernel_params,), metrics = run_experiment(seeds, sample_train, sample_test, y_train, y_test)
  # metrics is (losses, test_metrics, train_metrics); test_metrics is
  # (evs, acc, log_likelihood), so index [1][1] is the test accuracy.
  test_metric = metrics[1]
  test_acc = test_metric[1]
  tests_accs.append(test_acc)
# After the loop, sample_train, sample_test, y_train, y_test, kernel_params,
# metrics and ds_seed hold the values of the last iteration; later cells read them.

In [None]:
def save_metrics(metrics):
  """Write the train and test metric rows to 'toy_score.csv'.

  Args:
    metrics: tuple (losses, test_metrics, train_metrics) as returned by learn;
      `losses` is ignored. The two metric tuples must have the same length:
      three values are labelled (evs, acc, log_likelihood), matching what
      evaluate returns; four values keep the older
      (evs, rmse, mae, log_likelihood) labels.

  Raises:
    ValueError: if the metric tuples have any other length.
  """
  _, test_metrics, train_metrics = metrics
  test_metrics = onp.array([test_metrics])
  train_metrics = onp.array([train_metrics])
  # Train row first, test row second, matching the 'name' column below.
  train_test_metrics = onp.concatenate([train_metrics, test_metrics])

  columns_by_width = {
      3: ['evs', 'acc', 'log_likelihood'],
      4: ['evs', 'rmse', 'mae', 'log_likelihood'],
  }
  width = train_test_metrics.shape[1]
  if width not in columns_by_width:
    raise ValueError(f"expected 3 or 4 metrics per row, got {width}")

  df_metrics = pd.DataFrame(data=train_test_metrics, columns=columns_by_width[width])
  df_metrics['name'] = ['train', 'test']
  df_metrics.to_csv('toy_score.csv')

In [None]:
# Re-evaluate the parameters from the last experiment run.
# This cell depends on globals left by the experiment loop: kernel_params,
# sample_train, sample_test, y_train and y_test all come from its final
# iteration.

# Rebinds the module-level `kwargs` to the same values as before; nothing in
# this cell reads it.
kwargs = dict(
      sinkhorn_epsilon               = 1e-1 ,
      lse_mode                       = True ,
      implicit_differentiation       = False,
      implicit_solver_ridge_kernel   = 1e-2 ,  # promote zero sum solutions
      implicit_solver_ridge_identity = 1e-2 ,  # regul for ill-posed problem
)
# Rebuild the same model as run_experiment to recover the constrainer and
# unconstrainer that evaluate needs.
kernel = MMD_Mnist()  # gpx.RBF()
prior = gpx.Prior(kernel=kernel)
likelihood = gpx.Bernoulli(num_datapoints=len(sample_train))
# likelihood = gpx.Gaussian(num_datapoints=len(sample_train))
posterior = prior * likelihood
parameter_state = gpx.initialise(posterior, key=None)
_, trainable, constrainer, unconstrainer = parameter_state.unpack()
# Prints the summary line and displays the (evs, acc, log_likelihood) tuple.
evaluate(kernel_params,
         posterior, likelihood, constrainer, unconstrainer,
         sample_train, sample_test, y_train, y_test)

In [None]:
tests_accs

In [None]:
onp.mean(sorted(tests_accs))

In [None]:
onp.std(sorted(tests_accs))

In [None]:
plt.imshow(process_mnist(seed=ds_seed, ds_size=ds_size)[0][0].reshape((24,24)))

In [None]:
mu