In [1]:
from sklearn import datasets
import scanpy as sc
import optax
import matplotlib.pyplot as plt
from sklearn import preprocessing as pp
import numpy as np

import ott
import sklearn
import matplotlib.pyplot as plt
from ott.geometry import geometry, pointcloud
import jax
from typing import Mapping, Any, Optional, Union, Callable, Tuple
from types import MappingProxyType
import jax.numpy as jnp
from functools import partial
from ott.solvers.linear import sinkhorn
from ott.problems.linear import linear_problem
from entot.models.model import OTFlowMatching
from entot.nets.nets import MLP_vector_field, MLP_bridge, MLP_marginal,MLP_fused_vector_field
import sklearn.preprocessing as pp
import scanpy as sc
from ott.solvers.linear import sinkhorn, acceleration
from sklearn import preprocessing as pp
from ott.geometry.pointcloud import PointCloud
from ott.tools.sinkhorn_divergence import sinkhorn_divergence

- the fwd and bwd functions take an extra `perturbed` argument, which     indicates which primals actually need a gradient. You can use this     to skip computing the gradient for any unperturbed value. (You can     also safely just ignore this if you wish.)
- `None` was previously passed to indicate a symbolic zero gradient for     all objects that weren't inexact arrays, but all inexact arrays     always had an array-valued gradient. Now, `None` may also be passed     to indicate that an inexact array has a symbolic zero gradient.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
adata = sc.read("../../data/adata_pancreas_2019.h5ad")

In [3]:
adata = adata[adata.obs["celltype"]!="Multipotent"].copy()
adata.obs["lineage"] = adata.obs.apply(lambda x: "A" if x["celltype"] in ["Acinar", "Tip"] else "ED", axis=1).astype("category")

In [4]:
sc.pp.pca(adata, n_comps=30)
adata.obsm["X_pca_scaled"] = pp.StandardScaler().fit_transform(adata.obsm["X_pca"])

In [5]:
from typing import Tuple, Callable, Union, List, Optional
import scipy.sparse as sp
import jax.numpy as jnp
import jax
import pandas as pd

def get_nearest_neighbors(
    input_batch: jnp.ndarray, target: jnp.ndarray, k: int = 30
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Get the k nearest neighbors of the input batch in the target."""
    if target.shape[0] < k:
        raise ValueError(f"k is {k}, but must be smaller or equal than {target.shape[0]}.")
    pairwise_euclidean_distances = jnp.sqrt(jnp.sum((input_batch - target) ** 2, axis=-1))
    negative_distances, indices = jax.lax.top_k(-1 * pairwise_euclidean_distances, k=k)
    return -1 * negative_distances, indices

def project_transport_matrix(  
        predicted_tgt_cells: jnp.ndarray,
        tgt_cells: jnp.ndarray,
        batch_size: int = 1024,
        k: int = 1,
    ) -> sp.csr_matrix:
        """Project Neural OT map onto cells."""

        get_knn_fn = jax.vmap(get_nearest_neighbors, in_axes=(0, None, None))
        row_indices: Union[jnp.ndarray, List[jnp.ndarray]] = []
        column_indices: Union[jnp.ndarray, List[jnp.ndarray]] = []
        distances_list: Union[jnp.ndarray, List[jnp.ndarray]] = []
        for index in range(0, len(predicted_tgt_cells), batch_size):
            _, indices = get_knn_fn(predicted_tgt_cells[index : index + batch_size], tgt_cells, k)
            column_indices.append(indices.flatten())
            row_indices.append(
                jnp.repeat(jnp.arange(index, index + min(batch_size, len(predicted_tgt_cells) - index)), min(k, len(tgt_cells)))
            )
        ri = jnp.concatenate(row_indices)
        ci = jnp.concatenate(column_indices)
        mat = np.zeros((len(ri), len(tgt_cells)))
        mat[ri,ci] = 1.0
        return mat

def aggregate_transport_matrix(adata_source, adata_target, tmat, aggregation_key = "celltype", forward = True):
    df_source = adata_source.obs[aggregation_key]
    df_target = adata_target.obs[aggregation_key]

    annotations_source = adata_source.obs[aggregation_key].cat.categories
    annotations_target = adata_target.obs[aggregation_key].cat.categories

    tm = pd.DataFrame(
        np.zeros((len(annotations_source), len(annotations_target))),
        index=annotations_source,
        columns=annotations_target,
    )
    
    for annotation_src in annotations_source:
        for annotation_tgt in annotations_target:
            tm.loc[annotation_src, annotation_tgt] = tmat[
                np.ix_((df_source == annotation_src).squeeze(), (df_target == annotation_tgt).squeeze())
            ].sum()
    return tm#.div(tm.sum(axis=int(forward)), axis=int(not forward))

In [6]:
sc.pp.pca(adata, n_comps=30)
adata.obsm["X_pca_scaled"] = pp.StandardScaler().fit_transform(adata.obsm["X_pca"])

In [7]:
n_cells_source=len(adata[adata.obs["day"]=="14.5"])
n_cells_target=len(adata[adata.obs["day"]=="15.5"])

n_samples_train_source = int(n_cells_source * 0.6)
n_samples_test_source = n_cells_source - n_samples_train_source

n_samples_train_target = int(n_cells_target * 0.6)
n_samples_test_target = n_cells_target - n_samples_train_target

inds_source_train = np.asarray(jax.random.choice(jax.random.PRNGKey(0), n_cells_source, (n_samples_train_source,), replace=False))
inds_source_test = list(set(list(range(n_samples_train_source))) - set(inds_source_train))

inds_target_train = np.asarray(list(jax.random.choice(jax.random.PRNGKey(1), n_cells_target, (n_samples_train_target,), replace=False)))
inds_target_test = list(set(list(range(n_samples_train_target))) - set(inds_target_train))

adata_source_train = adata[adata.obs["day"]=="14.5"][inds_source_train,:]
adata_source_test = adata[adata.obs["day"]=="14.5"][inds_source_test,:]

adata_target_train = adata[adata.obs["day"]=="15.5"][inds_target_train,:]
adata_target_test = adata[adata.obs["day"]=="15.5"][inds_target_test,:]

source_train = adata_source_train.obsm["X_pca_scaled"]
source_test = adata_source_test.obsm["X_pca_scaled"]
target_train = adata_target_train.obsm["X_pca_scaled"]
target_test = adata_target_test.obsm["X_pca_scaled"]

2023-09-18 11:09:05.593117: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.2.128). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [8]:
adata_test = sc.concat(
    [adata_source_test, adata_target_test],
    join="outer",
    label="day",
    keys=["14.5", "15.5"],
)

adata_train = sc.concat(
    [adata_source_train, adata_target_train],
    join="outer",
    label="day",
    keys=[14.5, 15.5],
)

In [9]:
sinkhorn_divs = [None] * 3
scores = [None] * 3

for i in range(3):
    ot_solver=ott.solvers.linear.sinkhorn.Sinkhorn()
    neural_net = MLP_vector_field(target_train.shape[1], latent_embed_dim = 256, num_layers=8, n_frequencies=128)
    bridge_net = MLP_bridge(target_train.shape[1], 10)
    
    
    mlp_eta = MLP_marginal(256, 5)
    mlp_xi = MLP_marginal(256, 5)
    
    
    otfm = OTFlowMatching(neural_net, scale_cost="mean", bridge_net=bridge_net, ot_solver=ot_solver, epsilon=1e-2, mlp_eta=mlp_eta, mlp_xi=mlp_xi, tau_a=0.99, tau_b=0.99, input_dim=30, output_dim=30, iterations=10_000, k_noise_per_x=1, seed=i)
    otfm(source_train, target_train, 1024, 1024)
    gex_predicted = otfm.transport(source_test, seed=0)[0][0,...]
    sinkhorn_divs[i] = sinkhorn_divergence(PointCloud, gex_predicted, target_test, epsilon=1e-3).divergence
    

    tm = project_transport_matrix(gex_predicted, target_test)
    agg_tm = aggregate_transport_matrix(adata_source_test, adata_target_test, tm, aggregation_key="lineage")
    scores[i] = (agg_tm.iloc[0,0]+agg_tm.iloc[1,1])/agg_tm.sum().sum()

    full_test = np.concatenate((source_test, target_test), axis=0)
    adata_test.obs[f"left_rescaling_{i}"] = otfm.state_eta.apply_fn({"params": otfm.state_eta.params}, x=full_test)
    adata_test.obs[f"right_rescaling_{i}"] = otfm.state_eta.apply_fn({"params": otfm.state_xi.params}, x=full_test)

    full_train = np.concatenate((source_train, target_train), axis=0)
    adata_train.obs[f"left_rescaling_{i}"] = otfm.state_eta.apply_fn({"params": otfm.state_eta.params}, x=full_train)
    adata_train.obs[f"right_rescaling_{i}"] = otfm.state_eta.apply_fn({"params": otfm.state_xi.params}, x=full_train)

100%|██████████| 10000/10000 [07:32<00:00, 22.12it/s]
100%|██████████| 10000/10000 [07:32<00:00, 22.12it/s]
100%|██████████| 10000/10000 [07:31<00:00, 22.14it/s]


In [10]:
adata_train.obs.to_csv("pancreas_unbalanced_99_genot_train.csv")
adata_test.obs.to_csv("pancreas_unbalanced_99_genot_test.csv")

In [11]:
print(np.mean(scores), np.var(scores), np.mean(sinkhorn_divs), np.var(sinkhorn_divs))

0.7137886037407567 0.0006542549940865413 19.631609 0.09984531
