In [1]:
dataset = "new_sc"

In [2]:
import ott
import jax
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.geometry import pointcloud, geometry, costs, graph
from ott.solvers.quadratic import gromov_wasserstein
from typing import Optional, Any
import numpy as np
import jax.numpy as jnp
from tqdm import tqdm
from ott.neural import datasets
from ott.neural.methods.flows import dynamics, otfm
from ott.neural.networks.layers import time_encoder
from ott.neural.networks.velocity_field import VelocityField
from ott.solvers import utils as solver_utils
from torch.utils.data import DataLoader
import jax.numpy as jnp
from typing import Literal, Optional
import scanpy as sc
import functools
import optax
from functools import partial
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Tuple
from tqdm import tqdm
import jax.tree_util as jtu

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

import diffrax
from flax.training import train_state

from ott import utils
from ott.neural.methods.flows import dynamics
from ott.neural.networks import velocity_field
from ott.solvers import utils as solver_utils
from sklearn import preprocessing as pp
from moscot import datasets

LinTerm = Tuple[jnp.ndarray, jnp.ndarray]
QuadTerm = Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray],
                 Optional[jnp.ndarray]]
DataMatchFn = Union[Callable[[LinTerm], jnp.ndarray], Callable[[QuadTerm],
                                                               jnp.ndarray]]

# Fixed GENOT class

In [3]:
class GENOT:
  """Generative Entropic Neural Optimal Transport :cite:`klein_uscidda:23`.

  GENOT is a framework for learning neural optimal transport plans between
  two distributions. It allows for learning linear and quadratic
  (Fused) Gromov-Wasserstein couplings, in both the balanced and
  the unbalanced setting.

  Args:
    vf: Vector field parameterized by a neural network.
    flow: Flow between the latent and the target distributions.
    data_match_fn: Function to match samples from the source and the target
      distributions. Depending on the data passed in :meth:`__call__`, it has
      the following signature:

      - ``(src_lin, tgt_lin) -> matching`` - linear matching.
      - ``(src_quad, tgt_quad, src_lin, tgt_lin) -> matching`` -
        quadratic (fused) GW matching. In the pure GW setting, both ``src_lin``
        and ``tgt_lin`` will be set to :obj:`None`.

    source_dim: Dimensionality of the source distribution.
    target_dim: Dimensionality of the target distribution.
    condition_dim: Dimension of the conditions. If :obj:`None`, the underlying
      velocity field has no conditions.
    time_sampler: Time sampler with a ``(rng, n_samples) -> time`` signature.
    latent_noise_fn: Function to sample from the latent distribution in the
      target space with a ``(rng, shape) -> noise`` signature.
      If :obj:`None`, multivariate normal distribution is used.
    latent_match_fn: Function to match samples from the latent distribution
      and the samples from the conditional distribution with a
      ``(latent, samples) -> matching`` signature. If :obj:`None`, no matching
      is performed.
    n_samples_per_src: Number of samples drawn from the conditional distribution
      per one source sample.
    kwargs: Keyword arguments for
      :meth:`~ott.neural.networks.velocity_field.VelocityField.create_train_state`.
  """  # noqa: E501

  def __init__(
      self,
      vf: velocity_field.VelocityField,
      flow: dynamics.BaseFlow,
      data_match_fn: DataMatchFn,
      *,
      source_dim: int,
      target_dim: int,
      condition_dim: Optional[int] = None,
      time_sampler: Callable[[jax.Array, int],
                             jnp.ndarray] = solver_utils.uniform_sampler,
      latent_noise_fn: Optional[Callable[[jax.Array, Tuple[int, ...]],
                                         jnp.ndarray]] = None,
      latent_match_fn: Optional[Callable[[jnp.ndarray, jnp.ndarray],
                                         jnp.ndarray]] = None,
      n_samples_per_src: int = 1,
      **kwargs: Any,
  ):
    self.vf = vf
    self.flow = flow
    self.data_match_fn = data_match_fn
    self.time_sampler = time_sampler
    if latent_noise_fn is None:
      latent_noise_fn = functools.partial(_multivariate_normal, dim=target_dim)
    self.latent_noise_fn = latent_noise_fn
    self.latent_match_fn = latent_match_fn
    self.n_samples_per_src = n_samples_per_src

    self.vf_state = self.vf.create_train_state(
        input_dim=target_dim,
        condition_dim=source_dim + (condition_dim or 0),
        **kwargs
    )
    self.step_fn = self._get_step_fn()

  def _get_step_fn(self) -> Callable:

    @jax.jit
    def step_fn(
        rng: jax.Array,
        vf_state: train_state.TrainState,
        time: jnp.ndarray,
        source: jnp.ndarray,
        target: jnp.ndarray,
        latent: jnp.ndarray,
        source_conditions: Optional[jnp.ndarray],
    ):

      def loss_fn(
          params: jnp.ndarray, time: jnp.ndarray, source: jnp.ndarray,
          target: jnp.ndarray, latent: jnp.ndarray,
          source_conditions: Optional[jnp.ndarray], rng: jax.Array
      ) -> jnp.ndarray:
        rng_flow, rng_dropout = jax.random.split(rng, 2)
        x_t = self.flow.compute_xt(rng_flow, time, latent, target)
        if source_conditions is None:
          cond = source
        else:
          cond = jnp.concatenate([source, source_conditions], axis=-1)

        v_t = vf_state.apply_fn({"params": params},
                                time,
                                x_t,
                                cond,
                                rngs={"dropout": rng_dropout})
        u_t = self.flow.compute_ut(time, latent, target)

        return jnp.mean((v_t - u_t) ** 2)

      grad_fn = jax.value_and_grad(loss_fn)
      loss, grads = grad_fn(
          vf_state.params, time, source, target, latent, source_conditions, rng
      )

      return loss, vf_state.apply_gradients(grads=grads)

    return step_fn

  def __call__(
      self,
      loader: Iterable[Dict[str, np.ndarray]],
      n_iters: int,
      rng: Optional[jax.Array] = None
  ) -> Dict[str, List[float]]:
    """Train the GENOT model.

    Args:
      loader: Data loader returning a dictionary with possible keys
        `src_lin`, `tgt_lin`, `src_quad`, `tgt_quad`, `src_conditions`.
      n_iters: Number of iterations to train the model.
      rng: Random key for seeding.

    Returns:
      Training logs.
    """


    rng = utils.default_prng_key(rng)
    training_logs = {"loss": []}
    for batch in loader:
      rng = jax.random.split(rng, 5)
      rng, rng_resample, rng_noise, rng_time, rng_step_fn = rng

      batch = jtu.tree_map(jnp.asarray, batch)
      (src, src_cond, tgt), matching_data = prepare_data(batch)

      n = src.shape[0]
      time = self.time_sampler(rng_time, n * self.n_samples_per_src)
      latent = self.latent_noise_fn(rng_noise, (n, self.n_samples_per_src))

      tmat = self.data_match_fn(*matching_data)  # (n, m)
      src_ixs, tgt_ixs = solver_utils.sample_conditional(  # (n, k), (m, k)
          rng_resample,
          tmat,
          k=self.n_samples_per_src,
      )

      src, tgt = src[src_ixs], tgt[tgt_ixs]  # (n, k, ...),  # (m, k, ...)
      if src_cond is not None:
        src_cond = src_cond[src_ixs]

      if self.latent_match_fn is not None:
        src, src_cond, tgt = self._match_latent(rng, src, src_cond, latent, tgt)

      src = src.reshape(-1, *src.shape[2:])  # (n * k, ...)
      tgt = tgt.reshape(-1, *tgt.shape[2:])  # (m * k, ...)
      latent = latent.reshape(-1, *latent.shape[2:])
      if src_cond is not None:
        src_cond = src_cond.reshape(-1, *src_cond.shape[2:])

      loss, self.vf_state = self.step_fn(
          rng_step_fn, self.vf_state, time, src, tgt, latent, src_cond
      )

      training_logs["loss"].append(float(loss))
      if len(training_logs["loss"]) >= n_iters:
        break

    return training_logs

  def _match_latent(
      self, rng: jax.Array, src: jnp.ndarray, src_cond: Optional[jnp.ndarray],
      latent: jnp.ndarray, tgt: jnp.ndarray
  ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:

    def resample(
        rng: jax.Array, src: jnp.ndarray, src_cond: Optional[jnp.ndarray],
        tgt: jnp.ndarray, latent: jnp.ndarray
    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]:
      tmat = self.latent_match_fn(latent, tgt)  # (n, k)

      src_ixs, tgt_ixs = solver_utils.sample_joint(rng, tmat)  # (n,), (m,)
      src, tgt = src[src_ixs], tgt[tgt_ixs]
      if src_cond is not None:
        src_cond = src_cond[src_ixs]

      return src, src_cond, tgt

    cond_axis = None if src_cond is None else 1
    in_axes, out_axes = (0, 1, cond_axis, 1, 1), (1, cond_axis, 1)
    resample_fn = jax.jit(jax.vmap(resample, in_axes, out_axes))

    rngs = jax.random.split(rng, self.n_samples_per_src)
    return resample_fn(rngs, src, src_cond, tgt, latent)

  def transport(
      self,
      source: jnp.ndarray,
      condition: Optional[jnp.ndarray] = None,
      t0: float = 0.0,
      t1: float = 1.0,
      rng: Optional[jax.Array] = None,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Transport data with the learned plan.

    This function pushes forward the source distribution to its conditional
    distribution by solving the neural ODE.

    Args:
      source: Data to transport.
      condition: Condition of the input data.
      t0: Starting time of integration of neural ODE.
      t1: End time of integration of neural ODE.
      rng: Random generate used to sample from the latent distribution.
      kwargs: Keyword arguments for :func:`~diffrax.odesolve`.

    Returns:
      The push-forward defined by the learned transport plan.
    """

    def vf(t: jnp.ndarray, x: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray:
      params = self.vf_state.params
      return self.vf_state.apply_fn({"params": params}, t, x, cond, train=False)

    def solve_ode(x: jnp.ndarray, cond: jnp.ndarray) -> jnp.ndarray:
      ode_term = diffrax.ODETerm(vf)
      sol = diffrax.diffeqsolve(
          ode_term,
          t0=t0,
          t1=t1,
          y0=x,
          args=cond,
          **kwargs,
      )
      return sol.ys[0]

    kwargs.setdefault("dt0", None)
    kwargs.setdefault("solver", diffrax.Tsit5())
    kwargs.setdefault(
        "stepsize_controller", diffrax.PIDController(rtol=1e-5, atol=1e-5)
    )

    rng = utils.default_prng_key(rng)
    latent = self.latent_noise_fn(rng, (len(source),))

    if condition is not None:
      source = jnp.concatenate([source, condition], axis=-1)

    return jax.jit(jax.vmap(solve_ode))(latent, source)


def _multivariate_normal(
    rng: jax.Array,
    shape: Tuple[int, ...],
    dim: int,
    mean: float = 0.0,
    cov: float = 1.0
) -> jnp.ndarray:
  mean = jnp.full(dim, fill_value=mean)
  cov = jnp.diag(jnp.full(dim, fill_value=cov))
  return jax.random.multivariate_normal(rng, mean=mean, cov=cov, shape=shape)

def prepare_data(
    batch: Dict[str, jnp.ndarray]
) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray],
           Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray],
                 Optional[jnp.ndarray]]]:
  src_lin, src_quad = batch.get("src_lin"), batch.get("src_quad")
  tgt_lin, tgt_quad = batch.get("tgt_lin"), batch.get("tgt_quad")

  if src_quad is None and tgt_quad is None:  # lin
    src, tgt = src_lin, tgt_lin
    arrs = src_lin, tgt_lin, None, None
  elif src_lin is None and tgt_lin is None:  # quad
    src, tgt = src_quad, tgt_quad
    arrs = None, None, src_quad, tgt_quad
  elif all(
      arr is not None for arr in (src_lin, tgt_lin, src_quad, tgt_quad)
  ):  # fused quad
    src = jnp.concatenate([src_lin, src_quad], axis=1)
    tgt = jnp.concatenate([tgt_lin, tgt_quad], axis=1)
    arrs = src_lin, tgt_lin, src_quad, tgt_quad
  else:
    raise RuntimeError("Cannot infer OT problem type from data.")

  return (src, batch.get("src_condition"), tgt), arrs

# Different solve functions

In [4]:
@jax.jit
def solve_gw(epsilon: float, xx: jax.Array, yy: jax.Array, cost_fn : Any):
    ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    geom_xx = pointcloud.PointCloud(
        x=xx, y=xx, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_yy = pointcloud.PointCloud(
        x=yy, y=yy, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_xy = None
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, geom_xy,
    )
    return ot_solver(prob)

In [5]:
@jax.jit
def solve_fgw(epsilon: float, xx: jax.Array, yy: jax.Array, xy_x: jax.Array, xy_y: jax.Array, fused_penalty: float, cost_fn : str):
    ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    geom_xx = pointcloud.PointCloud(
        x=xx, y=xx, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_yy = pointcloud.PointCloud(
        x=yy, y=yy, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_xy = pointcloud.PointCloud(
                    x=xy_x, y=xy_y, cost_fn=cost_fn, scale_cost="mean"
                )
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, geom_xy, fused_penalty=fused_penalty,
    )
    return ot_solver(prob)

In [6]:
def get_nearest_neighbors(
    X: jnp.ndarray, Y: Optional[jnp.ndarray], k: int = 30  # type: ignore[name-defined]
) -> Tuple[jnp.ndarray, jnp.ndarray]:  # type: ignore[name-defined]
    concat = X if Y is None else jnp.concatenate((X, Y), axis=0) 
    pairwise_euclidean_distances = pointcloud.PointCloud(concat, concat).cost_matrix
    distances, indices = jax.lax.approx_min_k(
        pairwise_euclidean_distances, k=k, recall_target=0.95, aggregate_to_topk=True
    )
    connectivities = jnp.multiply(jnp.exp(-distances),  (distances>0))
    return connectivities/jnp.sum(connectivities), indices


def create_cost_matrix_quad(X: jnp.array, k_neighbors: int, **kwargs: Any) -> jnp.array:
    distances, indices = get_nearest_neighbors(X, None, k_neighbors)
    a = jnp.zeros((len(X), len(X)))
    adj_matrix = a.at[
        jnp.repeat(jnp.arange(len(X)), repeats=k_neighbors).flatten(), indices.flatten()
    ].set(distances.flatten())
    return graph.Graph.from_graph(adj_matrix, normalize=kwargs.pop("normalize", True), **kwargs).cost_matrix

@jax.jit
def solve_gw_geodesic(epsilon: float, xx: jax.Array, yy: jax.Array, k_neighbors=1024):
    ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    cm_xx = create_cost_matrix_quad(xx, k_neighbors)
    cm_yy = create_cost_matrix_quad(yy, k_neighbors)
    geom_xx = geometry.Geometry(cost_matrix=cm_xx, epsilon=epsilon, scale_cost="mean")
    geom_yy = geometry.Geometry(cost_matrix=cm_yy, epsilon=epsilon, scale_cost="mean")
    geom_xy = None
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, geom_xy,
    )
    
    return ot_solver(prob)


In [7]:
@jax.jit
def get_quad_initializer(push_forward: jnp.ndarray, target: jnp.ndarray):
    pc = pointcloud.PointCloud(push_forward, target)
    geom = geometry.Geometry(pc.cost_matrix, epsilon=1e-2)
    lp = linear_problem.LinearProblem(geom, a=jnp.ones(n) / n, b=jnp.ones(n) / n)
    return lp
    
def solve_gw_with_init(epsilon: float, xx: jax.Array, yy: jax.Array, cost_fn : Any, model: GENOT):
    push_forward = model.transport(xx)
    initializer = get_quad_initializer(push_forward, yy)
    ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=epsilon)
    geom_xx = pointcloud.PointCloud(
        x=xx, y=xx, cost_fn=cost_fn, scale_cost="mean"
    )
    geom_yy = pointcloud.PointCloud(
        x=yy, y=yy, cost_fn=cost_fn, scale_cost="mean"
    )
    prob = quadratic_problem.QuadraticProblem(
        geom_xx, geom_yy, 
    )
    return ot_solver(prob, init=initializer)

# Set hyperparameter

In [8]:
epsilon=0.01
n = 1024 # batch size
N_POINTS = 10
N_DRAWS = 100

# Generate the data

In [9]:
adata_atac = datasets.bone_marrow(rna=False)
adata_rna = datasets.bone_marrow(rna=True)
adata_atac.obsm["ATAC_lsi_l2_norm"] = pp.normalize(
    adata_atac.obsm["ATAC_lsi_red"], norm="l2"
)
x_quad = sc.pp.pca(adata_atac.obsm["ATAC_lsi_l2_norm"], n_comps=20)
y_quad = sc.pp.pca(adata_rna.obsm['GEX_X_pca'], n_comps=20)

x_lin_tmp = adata_atac.obsm['geneactivity_scvi']
y_lin_tmp = adata_rna.obsm['geneactivity_scvi']

xy = sc.pp.pca(np.concatenate((x_lin_tmp, y_lin_tmp), axis=0), n_comps=10)
x_lin = xy[:len(x_lin_tmp)]
y_lin = xy[len(x_lin_tmp):]

x_all = np.concatenate((x_lin, x_quad) , axis=1)
y_all = np.concatenate((y_lin, y_quad), axis=1)

DIM_FUSED = 10

# GW with graph cost

In [10]:
rng = np.random.default_rng(12345)

epsilon_graph = epsilon
vars_gw_graph = [None] * N_POINTS
for it in tqdm(range(N_POINTS)):
    minibatch_match = [None] * N_DRAWS
    x_fixed = rng.choice(x_all, size=(1,))
    for i in range(N_DRAWS):
        x = rng.choice(x_all, size=(n-1,))
        x = np.concatenate((x_fixed, x), axis=0)
        y = rng.choice(y_all, size=(n,))
        xx = x[:,DIM_FUSED:]
        yy = y[:,DIM_FUSED:]
        out = solve_gw_geodesic(epsilon_graph, xx, yy)
        likelihood = np.asarray(out.matrix[0]).astype('float64')
        minibatch_match[i] = rng.choice(yy, p=likelihood/np.sum(likelihood), shuffle=False)
    vars_gw_graph[it] = np.var(minibatch_match, axis=0)
        

  0%|                                                                                                            | 0/10 [00:00<?, ?it/s]2024-08-04 19:24:29.258312: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:10<00:00, 13.08s/it]


In [11]:
with open(f'{dataset}_graph_{epsilon_graph}.npy', 'wb') as f:
    np.save(f, vars_gw_graph)

In [12]:
rng = np.random.default_rng(12345)

epsilon_graph = 1e-4
vars_gw_graph = [None] * N_POINTS
for it in tqdm(range(N_POINTS)):
    minibatch_match = [None] * N_DRAWS
    x_fixed = rng.choice(x, size=(1,))
    for i in range(N_DRAWS):
        x = rng.choice(x, size=(n-1,))
        x = np.concatenate((x_fixed, x), axis=0)
        y = rng.choice(y, size=(n,))
        xx = x[:,DIM_FUSED:]
        yy = y[:,DIM_FUSED:]
        out = solve_gw_geodesic(epsilon_graph, xx, yy)
        likelihood = np.asarray(out.matrix[0]).astype('float64')
        minibatch_match[i] = rng.choice(yy, p=likelihood/np.sum(likelihood), shuffle=False)
    vars_gw_graph[it] = np.var(minibatch_match, axis=0)
        

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [07:38<00:00, 45.80s/it]


In [13]:
with open(f'{dataset}_graph_{epsilon_graph}.npy', 'wb') as f:
    np.save(f, vars_gw_graph)

In [14]:
res_g1 = np.load(f"{dataset}_graph_0.01.npy")
res_g2 = np.load(f"{dataset}_graph_0.0001.npy")

In [15]:
res_g1_mean = np.mean(res_g1, axis=0)
res_g2_mean = np.mean(res_g2, axis=0)

In [16]:
res_g1_mean

array([61.97741  , 21.803251 , 11.612076 ,  8.697136 ,  5.3289843,
        2.9707897,  2.5559046,  1.9997063,  1.931324 ,  1.972929 ,
        1.5381061,  1.4914898,  1.4200451,  1.3044698,  1.1463658,
        1.3592931,  0.9774721,  0.8938856,  1.0035172,  0.8482791],
      dtype=float32)

In [17]:
res_g2_mean

array([17.576633  , 10.672904  ,  4.5450497 , 22.352678  ,  1.7365946 ,
        1.2827942 ,  1.5664084 ,  0.58604085,  0.66009676,  0.70645356,
        0.9827353 ,  0.54497397,  0.64035094,  0.5654494 ,  0.69637895,
        0.39773476,  0.57704294,  1.1055791 ,  0.2774175 ,  0.5372273 ],
      dtype=float32)

In [18]:
np.mean(res_g1_mean)

6.641622

In [19]:
np.mean(res_g2_mean)

3.4005268