### **DISCLAIMER**: To run this notebook, you need to install `jax` and `ott` (the latter is distributed as the `ott-jax` package).

These packages are not included in the standard environment because only this notebook uses them.
If you manage your Python environment using Conda (and you have CUDA drivers already installed), you can install the required packages with:

    conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
    conda install -c conda-forge ott-jax

In [1]:
# Limit the fraction of GPU memory XLA preallocates to 20% (JAX's documented default is 75%).
# The variable is read when JAX initialises its GPU backend, so set it before `import jax` (next cell).
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.20

env: XLA_PYTHON_CLIENT_MEM_FRACTION=.20


In [2]:
import jax
import jax.numpy as jnp
import ott
import pandas as pd

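# Compute Wasserstein distance

The 2-Wasserstein distance between a predicted and a reference point cloud is estimated with entropic optimal transport (the Sinkhorn solver from `ott`) under the squared-Euclidean cost. We report the square root of the transport cost of the resulting plan.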

In [3]:
def make_geometry(t0_points, t1_points):
    """Build the OT geometry between the initial (t0) and final (t1) point clouds.

    The two clouds are compared with the squared-Euclidean ground cost.
    Returns an ``ott.geometry.pointcloud.PointCloud``.
    """
    sq_euclidean_cost = ott.geometry.costs.SqEuclidean()
    return ott.geometry.pointcloud.PointCloud(t0_points, t1_points, sq_euclidean_cost)

In [4]:
def compute_ot(t0_points, t1_points):
    """Solve the entropic OT problem between two point clouds with Sinkhorn.

    The geometry comes from ``make_geometry``. The solver's output object is
    returned unchanged.
    """
    geometry = make_geometry(t0_points, t1_points)
    solver = ott.solvers.linear.sinkhorn.Sinkhorn()
    problem = ott.problems.linear.linear_problem.LinearProblem(geometry)
    return solver(problem)

In [5]:
def transport(ot, init_points):
    """Map ``init_points`` with the transport induced by the OT solution's dual potentials.

    ``ot`` is expected to be a solver output exposing ``to_dual_potentials()``
    (e.g. the result of ``compute_ot``).
    """
    dual_potentials = ot.to_dual_potentials()
    return dual_potentials.transport(init_points)

In [6]:
def compute_wasserstein_2(preds, true):
    """Entropic estimate of the 2-Wasserstein distance between two point clouds.

    Solves the Sinkhorn problem from ``preds`` to ``true`` (squared-Euclidean
    cost, see ``make_geometry``). Returns ``sqrt(<P, C>)``, where ``P`` is the
    Sinkhorn transport plan and ``C`` is the cost matrix. This is the bare
    transport cost of the entropic plan, not the regularized objective.

    Returns:
        A Python float.
    """
    ot = compute_ot(preds, true)
    # Score the plan on the geometry the solver actually used instead of
    # rebuilding an identical one.
    return jnp.sqrt(ot.transport_cost_at_geom(ot.geom)).item()

In [7]:
# Reference embeddings of the held-out test set that every method is scored against.
test_true = jnp.load(
    "../reproducibility/statephate/data/statephate_embs_final_test.npy"
)

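##### Perform multiple evaluations

For each method, we load 10 saved repetitions and score the last entry of each trajectory array against the test embeddings. We then report the mean and standard deviation of W2 across repetitions.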

In [10]:
def load_repetitions(prefix, reps_num=10):
    """Load ``reps_num`` saved arrays named ``{prefix}_0.npy`` ... ``{prefix}_{reps_num-1}.npy``.

    Returns them as a list, in repetition order.
    """
    repetitions = []
    for i in range(reps_num):
        repetitions.append(jnp.load(f"{prefix}_{i}.npy"))
    return repetitions

In [11]:
def compute_robust_wasserstein_2(reps, true):
    """Summarise W2 over several repetitions of the same method.

    Args:
        reps: iterable of per-repetition arrays (e.g. from ``load_repetitions``).
            Only the last entry ``rep[-1]`` of each is scored against ``true``.
        true: reference point cloud passed to ``compute_wasserstein_2``.

    Returns:
        A one-column DataFrame (column label ``0``) indexed by ``"mean"`` and
        ``"std"``. ``std`` is the sample standard deviation (ddof=1), so it is
        NaN for a single repetition. Both entries are NaN when ``reps`` is empty.
    """
    # NOTE(review): assumes rep[-1] is the final snapshot of each trajectory; confirm the saved layout.
    w2_scores = pd.Series(
        [compute_wasserstein_2(rep[-1], true) for rep in reps], dtype=float
    )
    # Build the summary table directly instead of round-tripping through
    # DataFrame.apply with a dict-returning function (a pandas special case).
    return pd.DataFrame({0: {"mean": w2_scores.mean(), "std": w2_scores.std()}})

In [12]:
print("SBalign W2")
# Mean/std of W2 between SBalign predictions and the test embeddings, over repetitions.
compute_robust_wasserstein_2(
    reps=load_repetitions(prefix="results/sb-align_test_trajs"),
    true=test_true,
)

SBalign W2


Unnamed: 0,0
mean,11.113329
std,0.017947


In [13]:
print("Baseline W2")
# Mean/std of W2 between baseline predictions and the test embeddings, over repetitions.
compute_robust_wasserstein_2(
    reps=load_repetitions(prefix="results/baseline_test_trajs"),
    true=test_true,
)

Baseline W2


Unnamed: 0,0
mean,12.496513
std,0.037506


In [14]:
print("Baseline with SBalign W2")
# Mean/std of W2 for the SBalign-augmented baseline against the test embeddings, over repetitions.
compute_robust_wasserstein_2(
    reps=load_repetitions(prefix="results/baseline_augmented_test_trajs"),
    true=test_true,
)

Baseline with SBalign W2


Unnamed: 0,0
mean,10.541547
std,0.079424
