In [5]:
from moscot.problems.cross_modality import TranslationProblem
import numpy as np
import sys
import jax.numpy as jnp
import scanpy as sc
from typing import Any, Tuple
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import NearestNeighbors
import scipy
from sklearn import preprocessing as pp
import os
from ott.geometry import costs, geometry, graph, pointcloud
import jax
import pandas as pd


In [2]:
def get_nearest_neighbors(
    X: jnp.ndarray, Y: jnp.ndarray, k: int = 30
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Approximate k-nearest neighbours over the stacked point set ``[X; Y]``.

    ``X`` and ``Y`` are concatenated along axis 0, the pairwise cost between all
    stacked points is computed with an ott ``PointCloud``, and the ``k`` smallest
    costs per row are selected with ``jax.lax.approx_min_k``
    (``recall_target=0.95``, so the result is approximate).

    Args:
        X: first point set; occupies rows ``0 .. len(X) - 1`` of the stacked set.
        Y: second point set; occupies rows ``len(X) ..`` of the stacked set.
        k: number of neighbours kept per stacked point.

    Returns:
        ``(distances, indices)``: for every stacked point, the ``k`` smallest costs
        and the stacked-row indices they belong to.
    """
    stacked_points = jnp.concatenate((X, Y), axis=0)
    # NOTE(review): ott's PointCloud defaults to a squared-Euclidean cost_fn, so
    # these are most likely *squared* distances — confirm for the installed ott.
    pairwise_costs = pointcloud.PointCloud(stacked_points, stacked_points).cost_matrix
    nearest_costs, nearest_idx = jax.lax.approx_min_k(
        pairwise_costs,
        k=k,
        recall_target=0.95,
        aggregate_to_topk=True,
    )
    return nearest_costs, nearest_idx

def create_cost_matrix(
    X: jnp.ndarray,
    Y: jnp.ndarray,
    k_neighbors: int,
    *,
    normalize: bool = True,
    **kwargs: Any,
) -> jnp.ndarray:
    """Graph-based cost matrix built from a kNN graph over the stacked points ``[X; Y]``.

    The approximate kNN of every stacked point is computed with
    ``get_nearest_neighbors`` and written into a dense adjacency matrix whose
    entries are the neighbour costs. The block with ``X`` points as rows and
    ``Y`` points as columns is then handed to ``ott.geometry.graph.Graph.from_graph``.
    The ``cost_matrix`` of that graph geometry is returned.

    Args:
        X: first point set (rows are points).
        Y: second point set. It must have as many rows as ``X``, because the
            X-rows / Y-columns block is used as a square adjacency matrix.
        k_neighbors: neighbours kept per stacked point, in ``[1, len(X) + len(Y)]``.
        normalize: forwarded to ``Graph.from_graph`` (default ``True``).
        **kwargs: any further keyword arguments for ``Graph.from_graph``.

    Returns:
        The ``cost_matrix`` of the resulting graph geometry.

    Raises:
        ValueError: if ``len(X) != len(Y)`` or ``k_neighbors`` is out of range.
    """
    n_x, n_y = len(X), len(Y)
    n_total = n_x + n_y
    if n_x != n_y:
        raise ValueError(
            f"X and Y must have the same number of rows (got {n_x} and {n_y}): "
            "the X/Y adjacency block passed to Graph.from_graph must be square."
        )
    if not 1 <= k_neighbors <= n_total:
        raise ValueError(
            f"k_neighbors must be in [1, {n_total}] (len(X) + len(Y)), got {k_neighbors}."
        )

    distances, indices = get_nearest_neighbors(X, Y, k_neighbors)
    # Row i of `indices` holds the neighbours of stacked point i, so each flattened
    # neighbour index is paired with its row index repeated k_neighbors times.
    rows = jnp.repeat(jnp.arange(n_total), repeats=k_neighbors)
    adj_matrix = (
        jnp.zeros((n_total, n_total))
        .at[rows, indices.flatten()]
        .set(distances.flatten())
    )
    # Keep only the edges from X points (rows) to Y points (columns).
    cross_block = adj_matrix[:n_x, n_x:]
    return graph.Graph.from_graph(cross_block, normalize=normalize, **kwargs).cost_matrix


In [3]:
# Directory holding the input AnnData files; edit this one constant to point the
# notebook at another data location instead of editing each read call.
DATA_DIR = os.path.join("..", "..", "data")

adata_atac = sc.read(os.path.join(DATA_DIR, "bone_marrow_atac.h5ad"))
adata_rna = sc.read(os.path.join(DATA_DIR, "bone_marrow_rna.h5ad"))

In [6]:
# L2-normalise the reduced ATAC LSI embedding with sklearn's `normalize`, and
# store it under its own obsm key so the raw LSI embedding stays untouched.
adata_atac.obsm["ATAC_lsi_l2_norm"] = pp.normalize(
    X=adata_atac.obsm["ATAC_lsi_red"],
    norm="l2",
)

In [10]:
# Translation problem from ATAC (source, L2-normalised LSI) to RNA (target, PCA).
# `joint_attr` plus `alpha` configure the fused setting.
# NOTE(review): confirm the exact meaning of alpha against the moscot docs.
ftp = (
    TranslationProblem(adata_src=adata_atac, adata_tgt=adata_rna)
    .prepare(
        src_attr="ATAC_lsi_l2_norm",
        tgt_attr="GEX_X_pca",
        joint_attr="geneactivity_scvi",
    )
    .solve(epsilon=5e-3, alpha=0.7)
)
# Presumably maps each source (ATAC) cell into the target representation — confirm.
translated_fused = ftp.translate(source="src", target="tgt", forward=True)

  if not is_categorical_dtype(df_full[k]):


INFO     Solving `1` problems
INFO     Solving problem OTProblem[stage='prepared', shape=(6224, 6224)].


In [11]:
import scipy
def foscttm(
    x: np.ndarray,
    y: np.ndarray,
) -> float:
    """Mean FOSCTTM (Fraction Of Samples Closer Than the True Match).

    Row ``i`` of ``x`` and row ``i`` of ``y`` are treated as a true match. Using the
    Euclidean distance matrix ``d[i, j] = ||x_i - y_j||``:

    - For each ``x_i``, it takes the fraction of ``y`` rows strictly closer than ``y_i``.
    - For each ``y_j``, it takes the fraction of ``x`` rows strictly closer than ``x_j``.

    The two per-sample fractions are averaged, then averaged over all samples.
    A score of 0 means every true match is (jointly) the nearest candidate.

    Args:
        x: array of shape ``(n_samples, n_features)``.
        y: array of shape ``(n_samples, n_features)``, row-aligned with ``x``.

    Returns:
        The mean FOSCTTM rounded to 4 decimals, as a Python ``float``.

    Raises:
        ValueError: if ``x`` and ``y`` have different numbers of rows, or no rows.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.spatial` is loaded.
    from scipy.spatial import distance_matrix

    x = np.asarray(x)
    y = np.asarray(y)
    if x.shape[0] != y.shape[0]:
        # Without this, a 1-vs-n input broadcasts silently into a meaningless score.
        raise ValueError(
            f"x and y must have the same number of samples, got {x.shape[0]} and {y.shape[0]}"
        )
    if x.shape[0] == 0:
        raise ValueError("foscttm needs at least one sample")

    d = distance_matrix(x, y)
    true_match = np.diag(d)
    foscttm_x = (d < true_match[:, np.newaxis]).mean(axis=1)
    foscttm_y = (d < true_match[np.newaxis, :]).mean(axis=0)
    return float(np.mean((foscttm_x + foscttm_y) / 2).round(4))
    
# NOTE(review): FOSCTTM treats row i of both arrays as the same cell, so this
# assumes the RNA cells and the translated ATAC cells share one ordering — confirm.
foscttm(x=adata_rna.obsm["GEX_X_pca"], y=translated_fused)

0.4342

In [None]:
# NOTE(review): this cell uses names that are not defined anywhere in this notebook:
# `adata_src`, `adata_tgt`, `source`, `target`, `cost` and `epsilon`. It raises
# NameError on a fresh kernel until they are defined. TODO: define them in an
# earlier cell.


def _graph_cost_frame(points, obs_names):
    """Graph cost matrix of ``points`` against themselves, labelled by ``obs_names``.

    Wraps ``create_cost_matrix(points, points, k_neighbors=len(points) + 1)`` in a
    DataFrame whose index and columns are both ``obs_names``.
    """
    cost_matrix = create_cost_matrix(points, points, k_neighbors=len(points) + 1)
    return pd.DataFrame(cost_matrix, index=obs_names, columns=obs_names)


tp = TranslationProblem(adata_src, adata_tgt)
tp = tp.prepare(src_attr="source_train", tgt_attr="target_train")
if cost == "graph":
    # Override the source (x) and target (y) terms of the single src->tgt
    # subproblem with precomputed graph-based cost matrices.
    tp["src", "tgt"].set_x(_graph_cost_frame(source, adata_src.obs_names), tag="cost_matrix")
    tp["src", "tgt"].set_y(_graph_cost_frame(target, adata_tgt.obs_names), tag="cost_matrix")
tp = tp.solve(epsilon=epsilon, scale_cost="mean")