In [1]:
import torchvision.datasets as datasets
import numpy as np 
import os
from functools import partial
from jax import jit
import jax 
import jax.numpy as jnp
import torchvision
import jax.random as random
import matplotlib.pyplot as plt
import ott
import rbo

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
from typing import Any
from jax.numpy import ndarray
from ott.geometry.costs import TICost

def sigma(z) :
    jnp.arcsinh(z)


#maybe regti should herit from ticost

@jax.tree_util.register_pytree_node_class
class RegTICost(TICost) : 
    def __init__(self, cost = lambda x : 0.5*jnp.sum(x**2),reg_name ="l1",gamma=0.1 ) :
        super().__init__()
        self.reg_name = reg_name
        self.cost = cost
        if self.reg_name not in ['l1','stvs' ,None]:
            raise Exception("Norm name must be one of l1, stvs, None")
        self.gamma = gamma
    def reg(self,z) :
        if self.reg_name == "l1" :
            return self.gamma * jnp.sum(abs(z))
        elif self.reg_name == "stvs" : 
            return self.gamma * jnp.sum(self.gamma ** 2 * (sigma(z) + 0.5 -0.5 *jnp.exp(-2*sigma(z))))
        else :
            return 0
    def h(self, z) :
        return self.cost(z) + self.reg(z)

    def prox_reg(self,z) : 
        if self.reg_name == "l1" :
            return jnp.where((1-self.gamma/abs(z))>0,(1-self.gamma/abs(z)),0) * z
        elif self.reg_name == "stvs" :
            return jnp.where((1-self.gamma**2/(abs(z)**2))>0,(1-self.gamma**2/(abs(z)**2)),0) * z
        else :
            return 0
    def h_legendre(self, z: jnp.ndarray) -> float: 
        q = self.prox_reg(jax.lax.stop_gradient(z))
        return jnp.sum(q * z) - self.h(q) #JUSTIFIER !!!!!!



##Experiments on synthetic data : analysis of the feature sparsity

In [None]:
n=  1000
s= 5
def sample_mu(n ,d) :
    return jnp.random.random(d,n) #returns the whole sampled data
def compute_transport(x,s) : #takes a matrix as input
    x[:,:s] = jnp.exp(x[:,:s])
    return x
def sample_data(d):
    #Sample xi : 
    rng = ...
    rng1, rng2 = rng.split...
    x = sample_mu(n,d ,rng1)
    #Sample yi
    x_tilde = sample_mu(n,d ,rng2)
    y = compute_transport(x_tilde,s)
    return x ,y
#Define normalized MSE and support error
def normalized_mse(y, y_pred) :
    return (1/ jnp.prod(y.shape)) * jnp.sum((y-y_pred)**2)
def support_error(y, y_pred,x) :
    delta = y_pred -x 
    return jnp.sum(delta[:,s:]**2)/ (jnp.sum(delta**2) * y.shape[0]) #à vérifier !!!!

result ={}
for gamma in  jnp.linspace(1e-2,1e1, 10) : #peut etre trop de points
    for d in [8,20,100]:
        for reg_name in ['l1', "stvs", None]:
            x ,y = sample_data(d)
            cost = RegTICost(reg_name=reg_name, gamma=gamma)
            epsilon=0.1
            geom = ott.geometry.pointcloud.PointCloud(x, y, cost_fn=cost,scale_cost = "mean")
            problem = ott.problems.linear.linear_problem.LinearProblem(geom)
            solver = ott.solvers.linear.sinkhorn.Sinkhorn(threshold=epsilon)
            out = solver(problem)
            y_pred = out.to_dual_potentials().transport(x)
            nmse = normalized_mse(y , y_pred)
            s_error = support_error(y, y_pred, x)
            result[str(gamma)+"|"+str(d)+"|"+reg_name] = (nmse, s_error)

#process data, make the plots

#x dependent sparsity patterns

In [None]:
n = 100
s= 2
def F(x,s) : #x is a vector
    n1 = jnp.sum(x[:s]**2)
    n2 = jnp.sum(x[s:2*s]**2)
    if n1> n2 : 
        x[:s] = jnp.exp(x[:s])
        return x
    else :
        x[s:2 *s] = jnp.exp(x[s:2 *s])
        return x
    
for d in jnp.linspace(1e1, 1e3, 10):
    for reg_name in ['l1', "stvs", None]:
        for gamma in  jnp.linspace(1e-2,1e1, 10) : #We'll keep the value of lowest NMSE
            x ,y = sample_data(d)
            cost = RegTICost(reg_name=reg_name, gamma=gamma)
            epsilon=0.1
            geom = ott.geometry.pointcloud.PointCloud(x, y, cost_fn=cost,scale_cost = "mean")
            problem = ott.problems.linear.linear_problem.LinearProblem(geom)
            solver = ott.solvers.linear.sinkhorn.Sinkhorn(threshold=epsilon)
            out = solver(problem)
            y_pred = out.to_dual_potentials().transport(x)
            nmse = normalized_mse(y , y_pred)
            if str(d)+"|"+reg_name not in result.keys() or result[str(d)+"|"+reg_name]>nmse:
                result[str(d)+"|"+reg_name] = nmse
#process the data and plot


##Aplication to single cell genomics

In [2]:
import scanpy
path_file = "/home/tordjx/Downloads/Massively multiplex chemical transcriptomics at single-cell resolution - K562.h5ad"
data= scanpy.read_h5ad(path_file)
#Remove cells with less than 200 expressed genes
df = data.obs[data.obs['num_genes_expressed']>200]
#Remove genes expressed in less than 20 cells
import numpy as np
ann_data = data[data.obs.index.isin(df.index)]
column_sums = ann_data.X.sum(axis=0)
column_sums_dense = np.asarray(column_sums).flatten()
mask = column_sums_dense > 20
filtered_ann_data = ann_data[:, mask]
#Select the 5 drugs Belinostat, Dacinostat, Givinostat, Hesperadin, and Quisinostat, and the control
medocs = ["Belinostat (PXD101)",
"Hesperadin",
"Givinostat (ITF2357)",
"Dacinostat (LAQ824)",
"Quisinostat (JNJ-26481585) 2HCl", "Vehicle"]
scanpy.pp.normalize_total(filtered_ann_data)
scanpy.pp.log1p(filtered_ann_data)
scanpy.pp.highly_variable_genes(filtered_ann_data, n_top_genes=5000)
#scanpy.tl.rank_genes_groups(filtered_ann_data, groupby='product_name', method='t-test')#
#perturbation_marker_names = { k : filtered_ann_data.uns['rank_genes_groups']['names'][k][:50] for k in medocs}
#Whole set 
cell_names = {k : filtered_ann_data.obs[k == df["product_name"]].index for k in medocs}
cell_data_full = { k : filtered_ann_data[filtered_ann_data.obs.index.isin(cell_names[k])] for k in medocs}
#cell_marker_index = {k : [list(cell_data_full[k].to_df().columns).index(item) for item in perturbation_marker_names[k]] for k in medocs}
cell_data_full = { k : np.asarray(cell_data_full[k].X.todense()) for k in medocs}
#cell_data_marker = {k : cell_data_full[k][:, cell_marker_index[k]] for k in medocs}
#Top 5k
top5k_filtered_ann_data = filtered_ann_data[:, filtered_ann_data.var['highly_variable']]
cell_names = {k : top5k_filtered_ann_data.obs[k == df["product_name"]].index for k in medocs}
cell_data_5k = { k : top5k_filtered_ann_data[top5k_filtered_ann_data.obs.index.isin(cell_names[k])] for k in medocs}
cell_data_5k = { k : np.asarray(cell_data_5k[k].X.todense()) for k in medocs}

  view_to_actual(adata)
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_indices]
  self.stats[group_name, "scores"] = scores[global_indices]
  self.stats[group_name, "pvals"] = pvals[global_indices]
  self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices]
  self.stats[group_name, "logfoldchanges"] = np.log2(
  self.stats[group_name, "names"] = self.var_names[global_ind

In [19]:
from sklearn.model_selection import KFold
from sklearn.decomposition import PCA
from ott.tools.sinkhorn_divergence import sinkhorn_divergence

def experiment(medoc, cell_data,reg_name, pca_bool = False, gamma):
    X = cell_data['Vehicle']
    y = cell_data[medoc]
    Xfolds = KFold(n_splits=10, shuffle=True, random_state=42).split(X)
    yfolds=  KFold(n_splits=10, shuffle=True, random_state=42).split(y)
    result = []
    cost = RegTICost(reg_name=reg_name,gamma = gamma)
    for fold, ((train_index_X, test_index_X), (train_index_y, test_index_y)) in enumerate(zip(Xfolds, yfolds)):
        X_train, X_test = X[train_index_X], X[test_index_X]
        y_train, y_test = y[train_index_y], y[test_index_y]
        if pca_bool :
            pca = PCA(n_components=50).fit(X_train)
            X_train,y_train = pca.transform(X_train), pca.transform(y_train)
            X_test,y_test = pca.transform(X_test), pca.transform(y_test)
            epsilon=0.1
            geom = ott.geometry.pointcloud.PointCloud(X_train, y_train, cost_fn=cost,scale_cost = "mean")
            problem = ott.problems.linear.linear_problem.LinearProblem(geom)
            solver = ott.solvers.linear.sinkhorn.Sinkhorn(threshold=epsilon)
            out = solver(problem)
            y_test_pred = out.to_dual_potentials().transport(X_test)
            ##Metrics 
            #Sinkhorn divergence
            result.append(sinkhorn_divergence(ott.geometry.pointcloud.PointCloud, y_test_pred, y_test).divergence)
            #Dans l'article ils parlent de deux autres metriques, le R2 et le ranked biased overlap, le probleme est qu'ils le font entre des "marqueurs de perturbation" et ce qu'on prédit, et je ne comprends pas ce qui est attendu. Surtout que il se trouve que les marqueurs de perturbations n'appartiennent pas au sous ensemble des 5k genes, donc ça voudrait dire qu'on ne peut pas le calculer pour un des sous ensemble : j'y comprends que dalle
    return result

In [20]:
result = {}
medocs = ["Belinostat (PXD101)",
"Hesperadin",
"Givinostat (ITF2357)",
"Dacinostat (LAQ824)",
"Quisinostat (JNJ-26481585) 2HCl"]
for gamma in [2**(-i) for i in range(6)]:
    for medoc in medocs : 
        for reg_name in ["l1", "stvs" ,None] :
            for data_name  in ["cell_data_5k", "cell_data_full"] :
                for pca_bool in [True, False] :
                    if data_name == "cell_data_5k" :
                        data = cell_data_5k
                    else :
                        data = cell_data_full
                    result[medocs+"|"+reg_name+"|"+data_name+"|"+ str(pca_bool)+"|"+str(gamma)] = experiment(medoc, cell_data_5k, reg_name, pca_bool=pca_bool)

#process data, make the plots(mega chiant)

[Array(16.46344, dtype=float32),
 Array(15.376888, dtype=float32),
 Array(14.863703, dtype=float32),
 Array(15.006366, dtype=float32),
 Array(15.618315, dtype=float32),
 Array(15.302305, dtype=float32),
 Array(15.599243, dtype=float32),
 Array(15.5090885, dtype=float32),
 Array(15.460455, dtype=float32),
 Array(16.38553, dtype=float32)]

False

In [206]:
X_p.shape

(1024, 2)