In [47]:
import ott
import jax
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.geometry import pointcloud, geometry, costs, graph
from ott.solvers.quadratic import gromov_wasserstein
from typing import Optional, Any
import numpy as np
import jax.numpy as jnp
from tqdm import tqdm
import moscot
from sklearn import preprocessing as pp
from moscot import datasets
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks.layers import time_encoder
from ott.neural.networks.velocity_field import VelocityField
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
import jax.numpy as jnp
from typing import Literal, Optional
import scanpy as sc
import functools
import optax
from functools import partial
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

import diffrax
from flax.training import train_state

from ott import utils
from ott.neural.methods.flows import dynamics

from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils


In [18]:
adata_atac = datasets.bone_marrow(rna=False)
adata_rna = datasets.bone_marrow(rna=True)

In [21]:
adata_atac.obsm["ATAC_lsi_l2_norm"] = pp.normalize(
    adata_atac.obsm["ATAC_lsi_red"], norm="l2"
)

In [29]:
x_quad = adata_atac.obsm["ATAC_lsi_l2_norm"]
x_lin = adata_atac.obsm['geneactivity_scvi']
y_quad = adata_rna.obsm['GEX_X_pca']
y_lin = adata_rna.obsm['geneactivity_scvi']

In [31]:
@jax.jit
def solve_gw(epsilon: float, xx: jax.Array, yy: jax.Array, cost_fn : Any):
    ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    geom_xx = pointcloud.PointCloud(
        x=xx, y=xx, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_yy = pointcloud.PointCloud(
        x=yy, y=yy, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_xy = None
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, geom_xy,
    )
    return ot_solver(prob)

In [32]:
@jax.jit
def solve_fgw(epsilon: float, xx: jax.Array, yy: jax.Array, xy_x: jax.Array, xy_y: jax.Array, fused_penalty: float, cost_fn : str):
    ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    geom_xx = pointcloud.PointCloud(
        x=xx, y=xx, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_yy = pointcloud.PointCloud(
        x=yy, y=yy, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_xy = pointcloud.PointCloud(
                    x=xy_x, y=xy_y, cost_fn=cost_fn, scale_cost="mean"
                )
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty,
    )
    return ot_solver(prob)

In [33]:
epsilon=0.01
n = m = 1024
n_cells = adata_rna.n_obs
N_POINTS = 10
N_DRAWS = 100

In [34]:
x_all = np.concatenate((x_lin, x_quad), axis=1)
y_all = np.concatenate((y_lin, y_quad), axis=1)

# GW coupling

In [36]:
rng = np.random.default_rng(12345)

vars_gw = [None] * N_POINTS
for it in tqdm(range(N_POINTS)):
    minibatch_match = [None] * N_DRAWS
    x_fixed = rng.choice(x_quad, size=(1,))
    for i in range(N_DRAWS):
        xx = rng.choice(x_quad, size=(n-1,))
        yy = rng.choice(y_quad, size=(n,))
        xx = np.concatenate((x_fixed, xx), axis=0)
        out = solve_gw(epsilon, xx, yy, costs.SqEuclidean())
        minibatch_match[i] = yy[out.matrix[0].argmax()]
    vars_gw[it] = np.var(minibatch_match, axis=0)
        

  0%|          | 0/10 [00:00<?, ?it/s]2024-08-02 11:20:11.562140: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|██████████| 10/10 [01:40<00:00, 10.02s/it]


In [37]:
np.mean(vars_gw, axis=0)

array([ 5.341936  ,  9.443396  ,  3.4130402 , 18.488602  , 11.778585  ,
        2.2195802 ,  1.6063015 ,  1.7096084 ,  2.4169955 ,  1.0565075 ,
        1.3675957 ,  1.4125022 ,  1.6875887 ,  0.915408  ,  1.4444823 ,
        1.6825899 ,  1.7936014 ,  1.5318329 ,  1.9978215 ,  1.189678  ,
        1.148832  ,  0.8153895 ,  0.8552225 ,  0.5410267 ,  0.63023376,
        0.63284415,  0.6548734 ,  0.6377155 ,  0.569869  ,  0.46078387,
        0.36792287,  0.49002132,  0.50971776,  0.4240129 ,  0.48965827,
        0.35645932,  0.44659686,  0.39116758,  0.47028342,  0.49672455,
        0.3743592 ,  0.4045298 ,  0.4306882 ,  0.30464956,  0.41981164,
        0.38552368,  0.3943028 ,  0.3405323 ,  0.38330466,  0.36002928],
      dtype=float32)

# FGW coupling

In [38]:
fused_dim = adata_rna.obsm['geneactivity_scvi'].shape[1]

In [41]:
fused_penalty = 5.0

vars_fgw = [None] * N_POINTS
for it in tqdm(range(N_POINTS)):
    minibatch_match = [None] * N_DRAWS
    x_fixed = rng.choice(x_all, size=(1,))
    for i in range(N_DRAWS):
        x = rng.choice(x_all, size=(n-1,))
        y = rng.choice(y_all, size=(n,))
        x = np.concatenate((x_fixed, x), axis=0)
        xy_x = x[:, :fused_dim]
        xx = x[:, fused_dim:]
        xy_y = y[:, :fused_dim]
        yy = y[:, fused_dim:]
        out = solve_fgw(epsilon, xx, yy, xy_x, xy_y, fused_penalty, costs.SqEuclidean())
        minibatch_match[i] = yy[out.matrix[0].argmax()]
    vars_fgw[it] = np.var(minibatch_match, axis=0)
        

100%|██████████| 10/10 [11:22<00:00, 68.23s/it]


In [42]:
np.mean(vars_fgw, axis=0)

array([1.6058352 , 0.71844906, 1.5548128 , 1.7208456 , 1.6152325 ,
       1.444094  , 0.9143399 , 0.8446681 , 1.0240452 , 0.7637906 ,
       0.6131214 , 0.6254627 , 0.6617712 , 0.6505443 , 0.84522295,
       0.5436702 , 0.6355413 , 0.56149304, 0.40460578, 0.49956036,
       0.5260143 , 0.43565196, 0.49143115, 0.4675246 , 0.42419046,
       0.45044756, 0.53047365, 0.47891778, 0.49142194, 0.45686287,
       0.34254217, 0.445892  , 0.43057093, 0.3878977 , 0.56677836,
       0.41601315, 0.48215023, 0.40180627, 0.30989015, 0.51922053,
       0.3438385 , 0.37571892, 0.4225463 , 0.4530788 , 0.38479957,
       0.31896955, 0.30284244, 0.30423164, 0.47299033, 0.35403764],
      dtype=float32)

# Outer coupling

In [43]:
rng = np.random.default_rng(12345)

vars_outer = [None] * N_POINTS
for it in tqdm(range(N_POINTS)):
    minibatch_match = [None] * N_DRAWS
    x_fixed = rng.choice(x_quad, size=(1,))
    for i in range(N_DRAWS):
        xx = rng.choice(x_quad, size=(n-1,))
        yy = rng.choice(y_quad, size=(n,))
        xx = np.concatenate((x_fixed, xx), axis=0)
        minibatch_match[i] = yy[rng.choice(n)]
    vars_outer[it] = np.var(minibatch_match, axis=0)
        

100%|██████████| 10/10 [00:00<00:00, 17.52it/s]


In [44]:
np.mean(vars_outer, axis=0)

array([63.85145   , 20.996494  , 10.8110895 ,  9.172075  ,  4.8902535 ,
        2.973329  ,  2.418523  ,  1.9865059 ,  1.9095405 ,  1.7250545 ,
        1.4696648 ,  1.589028  ,  1.4415276 ,  1.4233559 ,  1.2567345 ,
        1.1282915 ,  0.98204535,  0.8703457 ,  0.9029452 ,  0.8260309 ,
        0.8104348 ,  0.7899305 ,  0.72356737,  0.704216  ,  0.6943593 ,
        0.6676404 ,  0.6392656 ,  0.6633253 ,  0.6335087 ,  0.65308446,
        0.6496994 ,  0.6102631 ,  0.56338954,  0.5937041 ,  0.6269791 ,
        0.5373762 ,  0.57377195,  0.5638007 ,  0.5281928 ,  0.58500785,
        0.5747962 ,  0.61202157,  0.5183285 ,  0.62495965,  0.57612765,
        0.5513538 ,  0.66942555,  0.5211011 ,  0.5400032 ,  0.52233875],
      dtype=float32)

# GW coupling with graph cost

In [48]:
def get_nearest_neighbors(
    X: jnp.ndarray, Y: Optional[jnp.ndarray], k: int = 30  # type: ignore[name-defined]
) -> Tuple[jnp.ndarray, jnp.ndarray]:  # type: ignore[name-defined]
    concat = X if Y is None else jnp.concatenate((X, Y), axis=0) 
    pairwise_euclidean_distances = pointcloud.PointCloud(concat, concat).cost_matrix
    distances, indices = jax.lax.approx_min_k(
        pairwise_euclidean_distances, k=k, recall_target=0.95, aggregate_to_topk=True
    )
    connectivities = jnp.multiply(jnp.exp(-distances),  (distances>0))
    return connectivities/jnp.sum(connectivities), indices


def create_cost_matrix_quad(X: jnp.array, k_neighbors: int, **kwargs: Any) -> jnp.array:
    distances, indices = get_nearest_neighbors(X, None, k_neighbors)
    a = jnp.zeros((len(X), len(X)))
    adj_matrix = a.at[
        jnp.repeat(jnp.arange(len(X)), repeats=k_neighbors).flatten(), indices.flatten()
    ].set(distances.flatten())
    return graph.Graph.from_graph(adj_matrix, normalize=kwargs.pop("normalize", True), **kwargs).cost_matrix

@jax.jit
def solve_gw_geodesic(epsilon: float, xx: jax.Array, yy: jax.Array, k_neighbors=1024):
    ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    cm_xx = create_cost_matrix_quad(xx, k_neighbors)
    cm_yy = create_cost_matrix_quad(yy, k_neighbors)
    geom_xx = geometry.Geometry(cost_matrix=cm_xx, epsilon=epsilon, scale_cost="mean")
    geom_yy = geometry.Geometry(cost_matrix=cm_yy, epsilon=epsilon, scale_cost="mean")
    geom_xy = None
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, geom_xy,
    )
    
    return ot_solver(prob)


In [51]:
rng = np.random.default_rng(12345)

vars_gw_with_cost = [None] * N_POINTS
for it in tqdm(range(N_POINTS)):
    minibatch_match = [None] * N_DRAWS
    x_fixed = rng.choice(x_quad, size=(1,))
    for i in range(N_DRAWS):
        xx = rng.choice(x_quad, size=(n-1,))
        yy = rng.choice(y_quad, size=(n,))
        xx = np.concatenate((x_fixed, xx), axis=0)
        out = solve_gw_geodesic(1e-4, xx, yy)
        minibatch_match[i] = yy[out.matrix[0].argmax()]
    vars_gw_with_cost[it] = np.var(minibatch_match, axis=0)
        

100%|██████████| 10/10 [21:44<00:00, 130.48s/it]


In [50]:
np.mean(vars_gw_with_cost, axis=0)

array([26.144226  ,  5.2994776 ,  4.110175  ,  3.3605618 ,  2.1154704 ,
        2.2987833 ,  1.8305962 ,  1.4375892 ,  1.5837132 ,  1.3724697 ,
        1.1490113 ,  1.4541018 ,  1.2324321 ,  0.6696581 ,  1.1929839 ,
        1.0228605 ,  1.3315439 ,  0.579856  ,  1.5750699 ,  0.9631475 ,
        0.94529456,  0.6166785 ,  0.8490292 ,  0.5294205 ,  0.62870467,
        0.55993813,  0.7111932 ,  0.4478217 ,  0.6927124 ,  0.4607266 ,
        0.49126992,  0.5105599 ,  0.34721106,  0.33905137,  0.6510705 ,
        0.31404868,  0.4170688 ,  0.34568197,  0.34320244,  0.3746317 ,
        0.31040493,  0.43551064,  0.44723076,  0.2847697 ,  0.51012313,
        0.26826063,  0.33211073,  0.29835093,  0.3791681 ,  0.3026885 ],
      dtype=float32)

In [52]:
np.mean(vars_gw_with_cost, axis=0)

array([30.287369  , 13.795336  ,  3.25704   ,  8.520558  ,  3.118068  ,
        2.4210818 ,  2.2824252 ,  1.4459902 ,  1.5132645 ,  1.7648203 ,
        1.337975  ,  1.1572365 ,  0.95345414,  1.0153188 ,  0.85283965,
        0.9153244 ,  0.81427366,  0.67998827,  0.72692597,  0.73937255,
        0.7375054 ,  0.81443584,  0.6848441 ,  0.6463245 ,  0.67465246,
        0.51106584,  0.6099758 ,  0.50120467,  0.59036005,  0.4672802 ,
        0.38954112,  0.46826306,  0.47254533,  0.46390995,  0.46100408,
        0.41751584,  0.37316605,  0.372652  ,  0.44961944,  0.44628996,
        0.40203506,  0.40338048,  0.40815306,  0.44192058,  0.39213556,
        0.3654736 ,  0.3348021 ,  0.40522486,  0.36673665,  0.35686824],
      dtype=float32)

In [55]:
np.median(vars_gw_with_cost)

0.49824405

In [56]:
np.median(vars_outer)

0.73889935

In [57]:
vars_gw_with_cost

[array([56.36818   ,  0.31370762,  4.848273  ,  8.252392  ,  0.44863963,
         2.7913184 ,  4.34932   ,  1.5811455 ,  1.1619692 ,  2.8873563 ,
         0.84533864,  1.3445387 ,  0.29838789,  0.96899104,  0.6282118 ,
         0.8479105 ,  0.38007632,  0.3448022 ,  0.9715004 ,  0.4719854 ,
         0.3180483 ,  0.80603766,  0.5998819 ,  0.49886167,  0.74294114,
         0.21643727,  0.62291235,  0.3186474 ,  0.29721567,  0.38680282,
         0.436132  ,  0.35899487,  0.26891625,  0.29621464,  0.30375597,
         0.2549983 ,  0.3000758 ,  0.28349128,  0.34830874,  0.39077517,
         0.2113358 ,  0.34109733,  0.28902444,  0.24463736,  0.2960405 ,
         0.3071557 ,  0.23506777,  0.2744166 ,  0.21660623,  0.24227029],
       dtype=float32),
 array([26.71894   , 33.838123  ,  3.1722476 ,  9.732677  ,  2.870113  ,
         2.7596786 ,  1.6742207 ,  1.4702929 ,  2.8194113 ,  1.5336356 ,
         3.7468905 ,  1.1574912 ,  0.9973824 ,  2.0079722 ,  1.0125808 ,
         1.0160599 ,  0.929

In [58]:
import seaborn as sns

In [59]:
np.median(vars_gw_with_cost, axis=0)

array([34.305176  ,  3.5097694 ,  3.6789331 ,  9.337931  ,  3.3062851 ,
        2.143671  ,  1.8235841 ,  1.4599905 ,  1.3220536 ,  1.6588954 ,
        1.0521667 ,  1.2148612 ,  0.7936456 ,  0.95127463,  0.7075889 ,
        0.9319852 ,  0.6431408 ,  0.6159438 ,  0.7700067 ,  0.65434134,
        0.6550793 ,  0.82781446,  0.59966195,  0.5872264 ,  0.7679812 ,
        0.47623324,  0.5941071 ,  0.4398327 ,  0.46886104,  0.39886877,
        0.41193587,  0.4148548 ,  0.4483583 ,  0.43510646,  0.4017622 ,
        0.39528024,  0.37533396,  0.37243003,  0.40887046,  0.42422244,
        0.4001484 ,  0.41025156,  0.39641637,  0.38122386,  0.3307941 ,
        0.33384362,  0.31292623,  0.38175353,  0.36926997,  0.3568952 ],
      dtype=float32)

In [60]:
np.median(vars_outer, axis=0)

array([63.328766  , 21.162588  , 10.765812  ,  9.44256   ,  4.937848  ,
        2.8993225 ,  2.287239  ,  1.9870181 ,  1.9609971 ,  1.673708  ,
        1.426938  ,  1.5883273 ,  1.3660821 ,  1.4497459 ,  1.2387083 ,
        1.2089312 ,  0.99777865,  0.8463193 ,  0.8859196 ,  0.8473345 ,
        0.7960615 ,  0.7768669 ,  0.7335659 ,  0.70843804,  0.6631929 ,
        0.6274885 ,  0.6457504 ,  0.67355883,  0.58294845,  0.62390095,
        0.7044693 ,  0.62885535,  0.5419507 ,  0.54576534,  0.6245214 ,
        0.5611466 ,  0.56516886,  0.52852005,  0.5235366 ,  0.5400578 ,
        0.5140086 ,  0.54182863,  0.5070927 ,  0.56044716,  0.6100757 ,
        0.48293784,  0.6195171 ,  0.52509487,  0.5004153 ,  0.52955616],
      dtype=float32)

In [61]:
np.median(vars_gw, axis=0)

array([ 0.8820608 ,  9.290772  ,  4.1896243 , 14.393162  , 12.516366  ,
        1.5015056 ,  1.2207592 ,  0.9511055 ,  1.6950252 ,  0.623165  ,
        0.77666414,  0.84694695,  1.8101006 ,  0.8568433 ,  0.96417236,
        1.5446256 ,  1.1234238 ,  1.3443781 ,  0.899596  ,  0.92511   ,
        0.7018045 ,  0.66213584,  0.7191014 ,  0.421175  ,  0.43247208,
        0.47894132,  0.48229694,  0.66355455,  0.48039237,  0.44007665,
        0.36243406,  0.37708473,  0.4890284 ,  0.3879767 ,  0.45313394,
        0.32599285,  0.4414451 ,  0.38105157,  0.4406883 ,  0.45656043,
        0.3951171 ,  0.41522592,  0.4067306 ,  0.29068375,  0.39024383,
        0.35033363,  0.37701407,  0.32212588,  0.36680806,  0.378597  ],
      dtype=float32)

In [62]:
np.median(vars_fgw, axis=0)

array([0.5221194 , 0.41073254, 0.6073319 , 0.43629837, 1.031556  ,
       1.0521399 , 0.84350705, 0.42736018, 1.1203581 , 0.32484668,
       0.43899712, 0.5201279 , 0.5546199 , 0.67939895, 0.734656  ,
       0.37453765, 0.6764116 , 0.41707307, 0.32706374, 0.5439963 ,
       0.47210515, 0.43207037, 0.3628828 , 0.37915373, 0.31847644,
       0.3506788 , 0.46502542, 0.44106048, 0.4119529 , 0.41103598,
       0.29024774, 0.37666082, 0.39988178, 0.35549098, 0.53449893,
       0.3768517 , 0.47108155, 0.2848975 , 0.28934675, 0.49999887,
       0.25623393, 0.3727044 , 0.4218763 , 0.39999914, 0.372264  ,
       0.28622118, 0.2898785 , 0.2647023 , 0.4104831 , 0.30325148],
      dtype=float32)

In [63]:
np.mean(vars_gw_with_cost, axis=0)

array([30.287369  , 13.795336  ,  3.25704   ,  8.520558  ,  3.118068  ,
        2.4210818 ,  2.2824252 ,  1.4459902 ,  1.5132645 ,  1.7648203 ,
        1.337975  ,  1.1572365 ,  0.95345414,  1.0153188 ,  0.85283965,
        0.9153244 ,  0.81427366,  0.67998827,  0.72692597,  0.73937255,
        0.7375054 ,  0.81443584,  0.6848441 ,  0.6463245 ,  0.67465246,
        0.51106584,  0.6099758 ,  0.50120467,  0.59036005,  0.4672802 ,
        0.38954112,  0.46826306,  0.47254533,  0.46390995,  0.46100408,
        0.41751584,  0.37316605,  0.372652  ,  0.44961944,  0.44628996,
        0.40203506,  0.40338048,  0.40815306,  0.44192058,  0.39213556,
        0.3654736 ,  0.3348021 ,  0.40522486,  0.36673665,  0.35686824],
      dtype=float32)

In [64]:
np.mean(vars_outer, axis=0)

array([63.85145   , 20.996494  , 10.8110895 ,  9.172075  ,  4.8902535 ,
        2.973329  ,  2.418523  ,  1.9865059 ,  1.9095405 ,  1.7250545 ,
        1.4696648 ,  1.589028  ,  1.4415276 ,  1.4233559 ,  1.2567345 ,
        1.1282915 ,  0.98204535,  0.8703457 ,  0.9029452 ,  0.8260309 ,
        0.8104348 ,  0.7899305 ,  0.72356737,  0.704216  ,  0.6943593 ,
        0.6676404 ,  0.6392656 ,  0.6633253 ,  0.6335087 ,  0.65308446,
        0.6496994 ,  0.6102631 ,  0.56338954,  0.5937041 ,  0.6269791 ,
        0.5373762 ,  0.57377195,  0.5638007 ,  0.5281928 ,  0.58500785,
        0.5747962 ,  0.61202157,  0.5183285 ,  0.62495965,  0.57612765,
        0.5513538 ,  0.66942555,  0.5211011 ,  0.5400032 ,  0.52233875],
      dtype=float32)

In [65]:
np.mean(vars_gw, axis=0)

array([ 5.341936  ,  9.443396  ,  3.4130402 , 18.488602  , 11.778585  ,
        2.2195802 ,  1.6063015 ,  1.7096084 ,  2.4169955 ,  1.0565075 ,
        1.3675957 ,  1.4125022 ,  1.6875887 ,  0.915408  ,  1.4444823 ,
        1.6825899 ,  1.7936014 ,  1.5318329 ,  1.9978215 ,  1.189678  ,
        1.148832  ,  0.8153895 ,  0.8552225 ,  0.5410267 ,  0.63023376,
        0.63284415,  0.6548734 ,  0.6377155 ,  0.569869  ,  0.46078387,
        0.36792287,  0.49002132,  0.50971776,  0.4240129 ,  0.48965827,
        0.35645932,  0.44659686,  0.39116758,  0.47028342,  0.49672455,
        0.3743592 ,  0.4045298 ,  0.4306882 ,  0.30464956,  0.41981164,
        0.38552368,  0.3943028 ,  0.3405323 ,  0.38330466,  0.36002928],
      dtype=float32)