In [1]:
import jax
import os

In [2]:
from sklearn import datasets
import scanpy as sc
import optax
import matplotlib.pyplot as plt
from sklearn import preprocessing as pp
import numpy as np
import scipy

import ott
import sklearn
import matplotlib.pyplot as plt
from ott.geometry import geometry, pointcloud
import jax
from typing import Mapping, Any, Optional, Union, Callable, Tuple
from types import MappingProxyType
import jax.numpy as jnp
from functools import partial
from ott.solvers.linear import sinkhorn
from ott.problems.linear import linear_problem
from entot.models.model import OTFlowMatching
from entot.nets.nets import MLP_vector_field, MLP_bridge, MLP_marginal,MLP_fused_vector_field
import sklearn.preprocessing as pp
import scanpy as sc
from ott.solvers.linear import sinkhorn, acceleration
from sklearn import preprocessing as pp


from ott.geometry.pointcloud import PointCloud
from ott.tools.sinkhorn_divergence import sinkhorn_divergence

  from .autonotebook import tqdm as notebook_tqdm
- the fwd and bwd functions take an extra `perturbed` argument, which     indicates which primals actually need a gradient. You can use this     to skip computing the gradient for any unperturbed value. (You can     also safely just ignore this if you wish.)
- `None` was previously passed to indicate a symbolic zero gradient for     all objects that weren't inexact arrays, but all inexact arrays     always had an array-valued gradient. Now, `None` may also be passed     to indicate that an inexact array has a symbolic zero gradient.


In [3]:
def foscttm(
    x: np.ndarray,
    y: np.ndarray,
) -> float:
    """Fraction Of Samples Closer Than the True Match (FOSCTTM).

    Row ``i`` of ``x`` is treated as the true match of row ``i`` of ``y``.
    For every pair, the fraction of candidates that lie strictly closer
    (Euclidean distance) than the true match is computed from both sides
    (x -> y and y -> x); the two fractions are averaged per pair and then
    averaged over all pairs. Lower is better; 0 means every true match is
    the nearest neighbour.

    Args:
        x: Array of shape ``(n, d)``.
        y: Array of shape ``(n, d)``; must have the same number of rows as ``x``.

    Returns:
        The mean FOSCTTM score, rounded to 4 decimals.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same number of rows.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape[0] != y.shape[0]:
        # The metric reads the diagonal of the distance matrix as the
        # true-match distances, so the inputs must be paired row by row.
        raise ValueError(
            f"x and y must contain the same number of samples, got {x.shape[0]} and {y.shape[0]}."
        )

    d = scipy.spatial.distance_matrix(x, y)
    true_match = np.diag(d)
    # Fraction of y-candidates closer to x_i than its true match y_i.
    foscttm_x = (d < true_match[:, None]).mean(axis=1)
    # Fraction of x-candidates closer to y_j than its true match x_j.
    foscttm_y = (d < true_match[None, :]).mean(axis=0)
    fracs = (foscttm_x + foscttm_y) / 2
    return np.mean(fracs).round(4)

In [4]:
# Load the ATAC and RNA AnnData objects (in that order) from disk.
adata_atac, adata_rna = (
    sc.read(h5ad_path)
    for h5ad_path in (
        "../../data/bone_marrow_atac.h5ad",
        "../../data/bone_marrow_rna.h5ad",
    )
)

In [5]:
# Source modality is ATAC, target modality is RNA.
adata_source = adata_atac.copy()
adata_target = adata_rna.copy()

n_cells_source = len(adata_atac)

# 60/40 train/test split over cell indices.
n_samples_train = int(n_cells_source * 0.6)
# Fixed: the held-out size is whatever remains after the training split
# (previously computed as n_cells_source - n_cells_source, i.e. always 0).
n_samples_test = n_cells_source - n_samples_train

inds_train = np.asarray(
    jax.random.choice(jax.random.PRNGKey(0), n_cells_source, (n_samples_train,), replace=False)
)
# Sorted complement of the training indices; deterministic, unlike set iteration order.
inds_test = np.setdiff1d(np.arange(n_cells_source), inds_train).tolist()
assert len(inds_test) == n_samples_test, "train/test indices do not partition the cells"

# NOTE(review): the same indices are applied to source and target below, which
# assumes the two AnnData objects hold the same cells in the same order — confirm.

# Shared ("fused") representation: joint PCA of the gene-activity embedding of both modalities.
fused = np.concatenate((adata_atac.obsm["geneactivity_scvi"], adata_rna.obsm["geneactivity_scvi"]), axis=0)
fused = sc.pp.pca(fused, n_comps=25)

source_fused = fused[:len(adata_source), :]
# Fixed: the target rows start right after the source rows in the concatenated
# matrix, so the offset is len(adata_source), not len(adata_target).
target_fused = fused[len(adata_source):, :]

# Modality-specific features: L2-normalised ATAC LSI for the source, RNA PCA for the target.
source_q = pp.normalize(
    adata_source.obsm["ATAC_lsi_red"], norm="l2"
)
target_q = adata_target.obsm["GEX_X_pca"]

source_train_q = source_q[inds_train, :]
source_test_q = source_q[inds_test, :]
target_train_q = target_q[inds_train, :]
target_test_q = target_q[inds_test, :]
source_train_fused = source_fused[inds_train, :]
source_test_fused = source_fused[inds_test, :]
target_train_fused = target_fused[inds_train, :]
target_test_fused = target_fused[inds_test, :]

# Final inputs: fused features first, modality-specific features after.
source_train = np.concatenate((source_train_fused, source_train_q), axis=1)
source_test = np.concatenate((source_test_fused, source_test_q), axis=1)
target_train = np.concatenate((target_train_fused, target_train_q), axis=1)
target_test = np.concatenate((target_test_fused, target_test_q), axis=1)


In [6]:
# Batch sizes to sweep and the total sample budget used to derive iteration counts.
batch_sizes = [1024, 512, 256, 128, 64]
tot_samples_seen = 5_000 * 1024

# One result slot per batch size for every metric; each metric gets its own list.
(
    foscttms_one_sample,
    foscttms_cond_mean,
    sinkhorn_divs_one_sample,
    sinkhorn_divs_cond_mean,
) = ([None for _ in batch_sizes] for _ in range(4))

In [7]:
# Vector-field and bridge networks, both sized by the target feature dimension.
neural_net = MLP_vector_field(
    target_train.shape[1],
    latent_embed_dim=256,
    num_layers=8,
    n_frequencies=128,
)
bridge_net = MLP_bridge(target_train.shape[1], 10)

# Gromov-Wasserstein solver whose inner linear problems are solved by Sinkhorn with momentum.
linear_ot_solver = sinkhorn.Sinkhorn(
    momentum=acceleration.Momentum(value=1.0, start=25),
)
solver = ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein(
    epsilon=0.01,
    linear_ot_solver=linear_ot_solver,
)


In [8]:
for j, bs in enumerate(batch_sizes):
    # Iteration count is the sample budget divided by the batch size, capped at 100k steps.
    num_iter = min(tot_samples_seen // bs, 100_000)
    otfm = OTFlowMatching(
        neural_net,
        bridge_net,
        epsilon=None,
        scale_cost="mean",
        input_dim=source_train.shape[1],
        output_dim=target_train.shape[1],
        iterations=num_iter,
        ot_solver=solver,
        k_noise_per_x=1,
        fused_penalty=1.0,
        split_dim=fused.shape[1],
    )
    otfm(source_train, target_train, bs, bs)

    # Five transported samples of the test set, one per seed.
    res_test = [otfm.transport(source_test, seed=i)[0][0, ...] for i in range(5)]
    one_sample_test = res_test[0]
    cond_mean_test = jnp.mean(jnp.asarray(res_test), axis=0)

    # Score both the single draw and the mean over the five draws against the held-out targets.
    foscttms_one_sample[j] = foscttm(one_sample_test, target_test)
    foscttms_cond_mean[j] = foscttm(cond_mean_test, target_test)
    sinkhorn_divs_one_sample[j] = float(
        sinkhorn_divergence(PointCloud, one_sample_test, target_test, epsilon=1e-2).divergence
    )
    sinkhorn_divs_cond_mean[j] = float(
        sinkhorn_divergence(PointCloud, cond_mean_test, target_test, epsilon=1e-2).divergence
    )
    


100%|██████████| 5000/5000 [06:19<00:00, 13.18it/s]
100%|██████████| 10000/10000 [11:25<00:00, 14.59it/s]
100%|██████████| 20000/20000 [20:48<00:00, 16.02it/s] 
100%|██████████| 40000/40000 [49:29<00:00, 13.47it/s] 
100%|██████████| 80000/80000 [2:28:44<00:00,  8.96it/s]  


In [9]:
# Persist every metric list as an array in the working directory.
for _stem, _values in (
    ("foscttms_cond_mean_2", foscttms_cond_mean),
    ("foscttms_one_sample_2", foscttms_one_sample),
    ("sinkhorn_divs_one_sample_2", sinkhorn_divs_one_sample),
    ("sinkhorn_divs_cond_mean_2", sinkhorn_divs_cond_mean),
):
    np.save(_stem, np.asarray(_values))

In [10]:
# Sinkhorn divergence between a single transported draw of the test set and
# target_test; one entry per batch size, in the order of batch_sizes.
sinkhorn_divs_one_sample

[88.82067108154297,
 88.9491958618164,
 88.72113037109375,
 88.02635955810547,
 88.17122650146484]

In [12]:
# FOSCTTM of a single transported draw of the test set against target_test;
# one entry per batch size, in the order of batch_sizes.
foscttms_one_sample

[0.1133, 0.1068, 0.1045, 0.1042, 0.1138]

In [13]:
# FOSCTTM of the mean over five transported draws of the test set against
# target_test; one entry per batch size, in the order of batch_sizes.
foscttms_cond_mean

[0.0754, 0.0685, 0.0652, 0.0608, 0.0642]