In [None]:
import numpy as np
import scanpy as sc
import pickle
from scib_metrics.benchmark import Benchmarker
import scib

In [None]:
import CINEMAOT as cnm

In [None]:
# Input AnnData for this notebook; edit this one constant to point at another copy of the data.
INTEGRATED_H5AD_PATH = '/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/cinemaot_data/Integrated_subset.h5ad'
adata = sc.read_h5ad(INTEGRATED_H5AD_PATH)
if adata.raw is None:
    raise ValueError(
        f"{INTEGRATED_H5AD_PATH} has no .raw layer; the analysis below is built from adata.raw.X."
    )
# Rebuild an AnnData from the .raw layer (its genes/values) while keeping the cell metadata.
# obs/var are copied so later in-place edits to adata_raw (e.g. the HVG columns added in the
# next cells) cannot alias adata.obs / adata.raw.var.
adata_raw = sc.AnnData(adata.raw.X, obs = adata.obs.copy(), var = adata.raw.var.copy())

In [None]:
# Perturbation label compared against 'No stimulation' in the cells below;
# modify it for different perturbation cases.
pert_cond = "IFNg"

In [None]:
# Keep the top highly variable genes, then only the cells of the two conditions being compared.
# NOTE(review): with its default flavor ('seurat'), sc.pp.highly_variable_genes expects
# log-normalised values — confirm adata.raw.X is log-normalised rather than raw counts
# (for raw counts, flavor='seurat_v3' is the usual choice).
N_TOP_GENES = 500  # users can modify the number of genes here
sc.pp.highly_variable_genes(adata_raw, n_top_genes=N_TOP_GENES)
# .copy() materialises each subset: later cells write to .obsm, which on an AnnData view
# would trigger an implicit copy with an ImplicitModificationWarning.
adata_raw = adata_raw[:, adata_raw.var['highly_variable']].copy()

CONDITIONS = ['No stimulation', pert_cond]
adata_ = adata_raw[adata_raw.obs['perturbation'].isin(CONDITIONS)].copy()
# Fail early on a mistyped condition instead of inside CINEMA-OT.
missing_conditions = set(CONDITIONS) - set(adata_.obs['perturbation'])
if missing_conditions:
    raise ValueError(
        f"No cells found for perturbation label(s) {sorted(missing_conditions)}; check pert_cond."
    )

# GenePT-w: each cell is the expression-weighted sum of GPT-3.5 gene embeddings over the
# selected genes, divided by the number of selected genes.
# NOTE(review): pickle.load can execute arbitrary code — only load this file from a trusted source.
GENEPT_EMBEDDINGS_PATH = "/gpfs/gibbs/pi/zhao/tl688/GenePT/data_embedding/GPT_3_5_gene_embeddings.pickle"
with open(GENEPT_EMBEDDINGS_PATH, "rb") as fp:
    # Mapping gene name -> embedding vector (looked up by gene name below).
    GPT_3_5_gene_embeddings = pickle.load(fp)

gene_names = list(adata_.var.index)
EMBED_DIM = 1536 # embedding dim from GPT-3.5
lookup_embed = np.zeros(shape=(len(gene_names), EMBED_DIM))
missing_genes = []
for i, gene in enumerate(gene_names):
    if gene in GPT_3_5_gene_embeddings:
        # asarray/ravel accepts both ndarray and list-valued embeddings.
        lookup_embed[i, :] = np.asarray(GPT_3_5_gene_embeddings[gene]).ravel()
    else:
        # Unmatched genes keep an all-zero row, so they contribute nothing to the sum.
        missing_genes.append(gene)
count_missing = len(missing_genes)
# For a random-embedding control, replace lookup_embed with
# np.random.rand(*lookup_embed.shape) before computing the embedding.

print(f"Unable to match {count_missing} out of {len(gene_names)} genes in the GenePT-w embedding")
if count_missing == len(gene_names):
    # Covers both "no gene matched" (all-zero embedding) and "no genes at all" (division by zero).
    raise ValueError(
        "No gene in adata_.var matched the GenePT embedding dictionary; "
        "check that gene identifiers use the same naming scheme."
    )
# (cells x genes) @ (genes x EMBED_DIM) -> (cells x EMBED_DIM); np.asarray guards against np.matrix.
genePT_w_emebed = np.asarray(adata_.X @ lookup_embed) / len(gene_names)


In [None]:
# Swap the PCA slot for the GenePT-w cell embedding (replace the PCs using GPT-3.5 embeddings).
# NOTE(review): presumably CINEMA-OT in the next cell reads .obsm['X_pca'] — confirm.
adata_.obsm["X_pca"] = genePT_w_emebed


In [None]:
# CINEMA-OT between the pert_cond cells (ref_label) and the 'No stimulation' cells
# (expr_label), with cell_type0528 passed as the pre-weighting label.
cf, ot, de = cnm.cinemaot.cinemaot_unweighted(
    adata_,
    obs_label='perturbation',
    ref_label=pert_cond,
    expr_label='No stimulation',
    mode='parametric',
    thres=0.5,
    smoothness=1e-5,
    eps=1e-3,
    preweight_label='cell_type0528',
)

# Replace each pert_cond cell's cf row with a weighted average of 'No stimulation' cf rows:
# the OT matrix is row-normalised so every row sums to 1, then multiplied into those rows.
perturbation = adata_.obs['perturbation']
is_stim = (perturbation == pert_cond).to_numpy()
is_ctrl = (perturbation == 'No stimulation').to_numpy()
ot_rownorm = ot / np.sum(ot, axis=1, keepdims=True)
cf_matched = cf.copy()
cf_matched[is_stim, :] = ot_rownorm @ cf[is_ctrl, :]
adata_.obsm['cf'] = cf_matched

# Neighbour graph and UMAP on the mapped cf representation; the figure is saved to disk.
sc.pp.neighbors(adata_, use_rep='cf')
sc.tl.umap(adata_, random_state=1)
sc.pl.umap(
    adata_,
    color=['perturbation', 'cell_type0528'],
    wspace=0.5,
    save=f'cinemaot_pbmc_cf_{pert_cond}_genept.pdf',
    palette='tab20c',
)

In [None]:
# scib metric switches (True = compute). The batch key is 'perturbation', the label key is
# 'cell_type0528', and the embedding is the CINEMA-OT output stored in .obsm['cf'].
# NOTE(review): the same AnnData is passed as both the unintegrated and the integrated input;
# confirm that is intended for baseline-dependent metrics such as pcr_.
scib_metric_flags = dict(
    # enabled
    silhouette_=True,
    isolated_labels_asw_=True,
    graph_conn_=True,
    pcr_=True,
    nmi_=True,  # use the clustering, bias to the best matching
    ari_=True,  # use the clustering, bias to the best matching
    ilisi_=True,
    clisi_=True,
    # disabled
    hvg_score_=False,
    isolated_labels_f1_=False,
    trajectory_=False,
    cell_cycle_=False,
    kBET_=False,  # kBET return nan sometimes, need to examine
)
results = scib.metrics.metrics(
    adata_,
    adata_int=adata_,
    batch_key="perturbation",
    label_key="cell_type0528",
    embed="cf",
    **scib_metric_flags,
)

In [None]:
# Display the metric table returned by scib.metrics.metrics in the previous cell.
results