In [1]:
import torchvision.datasets as datasets
import numpy as np
import os
from functools import partial
from jax import jit
import jax
import jax.numpy as jnp
import torchvision
import jax.random as random
import matplotlib.pyplot as plt
import ott

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Any
from jax.numpy import ndarray
from ott.geometry.costs import TICost

def sigma(z) :
    return jnp.arcsinh(z)


#maybe regti should herit from ticost

@jax.tree_util.register_pytree_node_class
class RegTICost(TICost) :
    def __init__(self, cost = lambda x : 0.5*jnp.sum(x**2),reg_name ="l1",gamma=0.1 ) :
        super().__init__()
        self.reg_name = reg_name
        self.cost = cost
        if self.reg_name not in [None,'l1','stvs' ]:
            raise Exception("Norm name must be one of l1, stvs, None")
        self.gamma = gamma
    def reg(self,z) :
        if self.reg_name == "l1" :
            return self.gamma * jnp.sum(abs(z))
        elif self.reg_name == "stvs" :
            return self.gamma**2 * jnp.sum((sigma(z) + 0.5 -0.5 *jnp.exp(-2*sigma(z))))
        else :
            return 0 #prox of 0 function is identity
    def h(self, z) :
        return self.cost(z) + self.reg(z)

    def prox_reg(self,z) :
        if self.reg_name == "l1" :
            return jnp.where((1-self.gamma/abs(z))>0,(1-self.gamma/abs(z)),0) * z
        elif self.reg_name == "stvs" :
            return jnp.where((1-self.gamma**2/(abs(z)**2))>0,(1-self.gamma**2/(abs(z)**2)),0) * z
        else :
            return z
    def h_legendre(self, z: jnp.ndarray) -> float:
        q = self.prox_reg(jax.lax.stop_gradient(z))
        return jnp.sum(q * z) - self.h(q) #JUSTIFIER !!!!!!



##Experiments on synthetic data : analysis of the feature sparsity

In [3]:
n=  1000
s= 5
def sample_mu(n ,d, key) :
    return jnp.asarray(jax.random.uniform(key, (n,d))) #returns the whole sampled data
def compute_transport(x,s) : #takes a matrix as input
    return jnp.concatenate(( jnp.exp(x[:,:s]), x[:,s:]),-1)

def sample_data(d):
    #Sample xi :
    rng = jax.random.PRNGKey(0)
    rng1, rng2 = jax.random.split(rng)
    x = sample_mu(n,d ,rng1)
    #Sample yi
    x_tilde = sample_mu(n,d ,rng2)
    y = compute_transport(x_tilde,s)
    return x ,y
#Define normalized MSE and support error
def normalized_mse(y, y_pred) :
    return (1/ jnp.prod(np.array(y.shape))) * jnp.sum((y-y_pred)**2)
def support_error(y, y_pred,x) :
    delta = y_pred -x
    return jnp.sum(delta[:,s:]**2)/ (jnp.sum(delta**2) * y.shape[0]) #à vérifier !!!!

result ={}
from tqdm import tqdm
for gamma in  tqdm(jnp.linspace(1e-2,1e1, 10)) : #peut etre trop de points
    for d in [8,20,100]:
        for reg_name in [None,'l1', "stvs"]:
            x ,y = sample_data(d)
            cost = RegTICost(reg_name=reg_name, gamma=gamma)
            epsilon=0.1
            geom = ott.geometry.pointcloud.PointCloud(x, y, cost_fn=cost,scale_cost = "mean")
            problem = ott.problems.linear.linear_problem.LinearProblem(geom)
            solver = ott.solvers.linear.sinkhorn.Sinkhorn(threshold=epsilon)
            out = solver(problem)
            y_pred = out.to_dual_potentials().transport(x)
            nmse = normalized_mse(y , y_pred)
            s_error = support_error(y, y_pred, x)
            if reg_name is None :
              reg_name = "none"
            result[str(gamma)+"|"+str(d)+"|"+reg_name] = (nmse, s_error)
gammas = []
ds = []
reg_names = []
nmses = []
s_errors = []

# Parse result dictionary
for k, v in result.items():
    gamma, d, reg_name = k.split('|')
    nmse, s_error = v
    gammas.append(float(gamma))
    ds.append(int(d))
    reg_names.append(reg_name)
    nmses.append(nmse)
    s_errors.append(s_error)
import pandas as pd
# Create DataFrame
df = pd.DataFrame({
    'Gamma': gammas,
    'D': ds,
    'Reg_Name': reg_names,
    'NMSE': nmses,
    'S_Error': s_errors
})
df.to_csv("df1.csv")

2024-04-24 15:02:50.627614: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|██████████| 10/10 [00:43<00:00,  4.37s/it]


#x dependent sparsity patterns

In [4]:
n = 100
s= 2
def F(x,s) :
    n1 = jnp.sum(x[:s]**2)
    n2 = jnp.sum(x[s:2*s]**2)
    if n1> n2 :
        return jnp.concatenate(( jnp.exp(x[:,:s]), x[:,s:]),-1)

    else :
        return  jnp.concatenate(( x[:,:s], jnp.exp(x[:,s:2*s]), x[:,2*s:]),-1)
def sample_data(d):
    #Sample xi :
    rng = jax.random.PRNGKey(0)
    rng1, rng2 = jax.random.split(rng)
    x = sample_mu(n,d ,rng1)
    #Sample yi
    x_tilde = sample_mu(n,d ,rng2)
    y = F(x_tilde,s)
    return x ,y
result = {}
for d in tqdm(jnp.linspace(1e1, 1e3, 10)):
    d = int(d)
    for reg_name in ['l1', "stvs", None]:
        for gamma in  jnp.linspace(1e-2,1e1, 10) : #We'll keep the value of lowest NMSE
            x ,y = sample_data(d)
            cost = RegTICost(reg_name=reg_name, gamma=gamma)
            epsilon=0.1
            geom = ott.geometry.pointcloud.PointCloud(x, y, cost_fn=cost,scale_cost = "mean")
            problem = ott.problems.linear.linear_problem.LinearProblem(geom)
            solver = ott.solvers.linear.sinkhorn.Sinkhorn(threshold=epsilon)
            out = solver(problem)
            y_pred = out.to_dual_potentials().transport(x)
            nmse = normalized_mse(y , y_pred)
            if reg_name is None :
              reg_name = "none"
            if str(d)+"|"+reg_name not in result.keys() or result[str(d)+"|"+reg_name]>nmse:
                result[str(d)+"|"+reg_name] = nmse
            if reg_name == "none" :
              reg_name = None
#process the data and plot
ds = []
reg_names = []
nmses = []

# Parse result dictionary
for k, v in result.items():
    d, reg_name = k.split('|')
    nmse = v
    ds.append(int(d))
    reg_names.append(reg_name)
    nmses.append(nmse)
import pandas as pd
# Create DataFrame
df = pd.DataFrame({
    'D': ds,
    'Reg_Name': reg_names,
    'NMSE': nmses,
})
df.to_csv("df2.csv")

100%|██████████| 10/10 [02:00<00:00, 12.08s/it]


##Aplication to single cell genomics

the data can be downloaded here : https://cellxgene.cziscience.com/datasets

In [5]:
import scanpy
path_file = "Massively multiplex chemical transcriptomics at single-cell resolution - K562.h5ad"
data= scanpy.read_h5ad(path_file)
#Remove cells with less than 200 expressed genes
df = data.obs[data.obs['num_genes_expressed']>200]
#Remove genes expressed in less than 20 cells
import numpy as np
ann_data = data[data.obs.index.isin(df.index)]
column_sums = ann_data.X.sum(axis=0)
column_sums_dense = np.asarray(column_sums).flatten()
mask = column_sums_dense > 20
filtered_ann_data = ann_data[:, mask]
#Select the 5 drugs Belinostat, Dacinostat, Givinostat, Hesperadin, and Quisinostat, and the control
medocs = ["Belinostat (PXD101)",
"Hesperadin",
"Givinostat (ITF2357)",
"Dacinostat (LAQ824)",
"Quisinostat (JNJ-26481585) 2HCl", "Vehicle"]
scanpy.pp.normalize_total(filtered_ann_data)
scanpy.pp.log1p(filtered_ann_data)
scanpy.pp.highly_variable_genes(filtered_ann_data, n_top_genes=5000)
#scanpy.tl.rank_genes_groups(filtered_ann_data, groupby='product_name', method='t-test')#
#perturbation_marker_names = { k : filtered_ann_data.uns['rank_genes_groups']['names'][k][:50] for k in medocs}
#Whole set
cell_names = {k : filtered_ann_data.obs[k == df["product_name"]].index for k in medocs}
cell_data_full = { k : filtered_ann_data[filtered_ann_data.obs.index.isin(cell_names[k])] for k in medocs}
#cell_marker_index = {k : [list(cell_data_full[k].to_df().columns).index(item) for item in perturbation_marker_names[k]] for k in medocs}
cell_data_full = { k : np.asarray(cell_data_full[k].X.todense()) for k in medocs}
#cell_data_marker = {k : cell_data_full[k][:, cell_marker_index[k]] for k in medocs}
#Top 5k
top5k_filtered_ann_data = filtered_ann_data[:, filtered_ann_data.var['highly_variable']]
cell_names = {k : top5k_filtered_ann_data.obs[k == df["product_name"]].index for k in medocs}
cell_data_5k = { k : top5k_filtered_ann_data[top5k_filtered_ann_data.obs.index.isin(cell_names[k])] for k in medocs}
cell_data_5k = { k : np.asarray(cell_data_5k[k].X.todense()) for k in medocs}

  view_to_actual(adata)


In [6]:
from sklearn.model_selection import KFold
from sklearn.decomposition import PCA
from ott.tools.sinkhorn_divergence import sinkhorn_divergence

def experiment(medoc, cell_data,reg_name, pca_bool, gamma):
    X = cell_data['Vehicle']
    y = cell_data[medoc]
    Xfolds = KFold(n_splits=10, shuffle=True, random_state=42).split(X)
    yfolds=  KFold(n_splits=10, shuffle=True, random_state=42).split(y)
    result = []
    cost = RegTICost(reg_name=reg_name,gamma = gamma)
    for fold, ((train_index_X, test_index_X), (train_index_y, test_index_y)) in enumerate(zip(Xfolds, yfolds)):
        X_train, X_test = X[train_index_X], X[test_index_X]
        y_train, y_test = y[train_index_y], y[test_index_y]
        if pca_bool :
            pca = PCA(n_components=50).fit(X_train)
            X_train,y_train = pca.transform(X_train), pca.transform(y_train)
            X_test,y_test = pca.transform(X_test), pca.transform(y_test)
        epsilon=0.1
        geom = ott.geometry.pointcloud.PointCloud(X_train, y_train, cost_fn=cost,scale_cost = "mean")
        problem = ott.problems.linear.linear_problem.LinearProblem(geom)
        solver = ott.solvers.linear.sinkhorn.Sinkhorn(threshold=epsilon)
        out = solver(problem)
        y_test_pred = out.to_dual_potentials().transport(X_test)
        ##Metrics
        #Sinkhorn divergence
        result.append(sinkhorn_divergence(ott.geometry.pointcloud.PointCloud, y_test_pred, y_test).divergence)
        #Dans l'article ils parlent de deux autres metriques, le R2 et le ranked biased overlap, le probleme est qu'ils le font entre des "marqueurs de perturbation" et ce qu'on prédit, et je ne comprends pas ce qui est attendu. Surtout que il se trouve que les marqueurs de perturbations n'appartiennent pas au sous ensemble des 5k genes, donc ça voudrait dire qu'on ne peut pas le calculer pour un des sous ensemble : j'y comprends que dalle
    return result

In [7]:
result = {}
medocs = ["Belinostat (PXD101)",
"Hesperadin",
"Givinostat (ITF2357)",
"Dacinostat (LAQ824)",
"Quisinostat (JNJ-26481585) 2HCl"]
from tqdm import tqdm
for gamma in tqdm([2**(-i) for i in range(6)]):
    for medoc in medocs :
        for reg_name in ["l1", "stvs" ,None] :
            for data_name  in ["cell_data_5k", "cell_data_full"] :
                for pca_bool in [True, False] :
                    if data_name == "cell_data_5k" :
                        data = cell_data_5k
                    else :
                        data = cell_data_full
                    result[medoc+"|"+reg_name+"|"+data_name+"|"+ str(pca_bool)+"|"+str(gamma)] = experiment(medoc, cell_data_5k, reg_name, pca_bool=pca_bool, gamma=gamma)

#process data, make the plots(mega chiant)

  0%|          | 0/6 [00:00<?, ?it/s]2024-04-24 15:07:05.113537: W external/xla/xla/service/hlo_rematerialization.cc:2946] Can't reduce memory use below -23.25GiB (-24960156968 bytes) by rematerialization; only reduced to 33.32GiB (35781620000 bytes), down from 33.32GiB (35781620000 bytes) originally
2024-04-24 15:07:15.137051: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 33.32GiB (rounded to 35781620224)requested by op 
2024-04-24 15:07:15.137156: W external/tsl/tsl/framework/bfc_allocator.cc:494] *___________________________________________________________________________________________________
E0424 15:07:15.137189   30588 pjrt_stream_executor_client.cc:2809] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 35781620000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   68.85MiB
              constant allocation:         0B
      

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 35781620000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   68.85MiB
              constant allocation:         0B
        maybe_live_out allocation:   33.32GiB
     preallocated temp allocation:         0B
                 total allocation:   33.39GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 33.32GiB
		Operator: op_name="jit(<lambda>)/jit(main)/sub" source_file="/tmp/ipykernel_30588/2357495081.py" source_line=23
		XLA Label: fusion
		Shape: f32[3017,593,5000]
		==========================

	Buffer 2:
		Size: 57.54MiB
		Entry Parameter Subshape: f32[3017,5000]
		==========================

	Buffer 3:
		Size: 11.31MiB
		Entry Parameter Subshape: f32[593,5000]
		==========================



In [9]:
#process the data and plot
medocs = []
regs = []
datas =  []
pcas = []
gammas = []
values = []
# Parse result dictionary
for k, v in result.items():
    medoc,reg_name,data, pca, gamma = k.split('|')
    value = v
    medocs.append(medoc)
    regs.append(reg_name)
    datas.append(data)
    pcas.append(pca)
    gammas.append(gamma)
    values.append(value)
import pandas as pd
# Create DataFrame
df = pd.DataFrame({
    'medocs': medocs,
    'regs': regs,
    'datas': datas,
    'pcas': pcas,
    'gammas': gammas,
    'values': values,
})
df
df.to_csv("df3.csv")
df

Unnamed: 0,medocs,regs,datas,pcas,gammas,values
0,Belinostat (PXD101),l1,cell_data_5k,True,1,"[13.824425, 12.557515, 11.976587, 12.064277, 1..."
