In [None]:
import pickle
import os
import numpy as np
import torch.nn.functional as F 
import torch
import pandas as pd
import anndata as ad
from scETM import scETM
from multiprocessing import Pool

# Show the current working directory: the relative paths used in this notebook
# ('IDH-MUT.csv', 'model' and the output directory) are resolved against it.
os.getcwd()

In [None]:
# Output directory that receives every CSV / PDF written by this notebook.
working_dir = 'glioma'
# Create it up front so the first `to_csv` does not fail on a fresh checkout;
# `exist_ok=True` keeps the cell safe to re-run.
os.makedirs(working_dir, exist_ok=True)

In [None]:

# Build one AnnData per input CSV. For each file the first two columns are
# used as per-cell metadata (obs) and every remaining column as the matrix X.
gl_csvs = ['IDH-MUT.csv']
gl_adatas = []
for fpath in gl_csvs:
    df = pd.read_csv(fpath, index_col=0)
    adata = ad.AnnData(X=df.iloc[:, 2:], obs=df.iloc[:, :2])
    gl_adatas.append(adata)
# `label` stores which input each cell came from in obs['batch_indices'].
gl = ad.concat(gl_adatas, label="batch_indices")

adata = gl
model = scETM(adata.n_vars, adata.obs.batch_indices.nunique())
# map_location='cpu' lets a checkpoint that was saved on a GPU be restored on
# a CPU-only machine; load_state_dict then copies the values into the model's
# own parameters, whatever device those live on.
model.load_state_dict(torch.load('model', map_location='cpu'))
# NOTE(review): presumably this call is what populates adata.obsm['delta'],
# adata.uns['alpha'] and adata.varm['rho'] read just below - confirm against scETM.
model.get_all_embeddings_and_nll(adata)

# Wrap the three embeddings in DataFrames and label rows with cell / gene names.
delta, alpha, rho = map(pd.DataFrame, [adata.obsm['delta'], adata.uns['alpha'], adata.varm['rho']])
delta.index = adata.obs_names
rho.index = adata.var_names
delta.shape, alpha.shape, rho.shape

In [None]:
# Display the AnnData summary repr to inspect what the object currently holds.
adata

In [None]:
print('Get top 30 genes per topic (for enrichment analysis)')
beta = rho @ alpha.T  # (gene, topic)
# Row indices of genes ordered from highest to lowest beta, independently per topic column.
genes_by_descending_beta = np.argsort(beta.values, axis=0)[::-1]
# Keep the 30 best-ranked genes per topic and translate indices to gene names.
top_words = pd.DataFrame(adata.var_names.values[genes_by_descending_beta[:30]])  # (n_top, topic)
top30_path = os.path.join(working_dir, 'beta_top30genes.csv')
top_words.to_csv(top30_path)

In [None]:
# Persist the raw topic mixture, the per-cell metadata and the softmax-normalised
# topic mixture as CSV files inside `working_dir`.
print('Saving unnormliazed topic mixture delta')
delta_path = os.path.join(working_dir, 'delta.csv')
delta.to_csv(delta_path)

print('Saving metadata')
## create meta csv (condition, individual_id, cell_type)
meta_path = os.path.join(working_dir, 'meta.csv')
adata.obs.to_csv(meta_path)

print('Saving normalized topic mixture theta')
# Softmax over the last axis turns each cell's row of delta into proportions.
theta = pd.DataFrame(
    torch.softmax(torch.tensor(delta.values), dim=-1).detach().cpu().numpy(),
    index=adata.obs_names,
)
theta_path = os.path.join(working_dir, 'theta.csv')
theta.to_csv(theta_path)

In [None]:
# by default, keep all topics
print('Sampling theta')
delta_sample = delta.sample(10000)
topic_kept = delta_sample.columns[delta_sample.sum(0) >= 1500]  # (topics)
meta_sample = adata.obs.loc[delta_sample.index]
delta_sample.to_csv(os.path.join(working_dir, 'delta_sampled.csv'))
meta_sample.to_csv(os.path.join(working_dir, 'meta_sampled.csv'))

delta_kept = delta[topic_kept]  # (cells, topics)

In [None]:
print("Pathway enrichment analysis")
from pathdip import pathDIP_Http 

# Query pathDIP once per topic with that topic's top genes, collect the returned
# rows, then keep only pathways with BH-adjusted q-value below 0.05.
n_topics = delta.shape[1]
component = "Literature curated (core) pathway memberships"
sources = "ACSN2,BioCarta,EHMN,HumanCyc,INOH,IPAVS,KEGG,NetPath,OntoCancro,Panther_Pathway,PharmGKB,PID,RB-Pathways,REACTOME,stke,systems-biology.org,SignaLink2.0,SIGNOR2.0,SMPDB,Spike,UniProt_Pathways,WikiPathways"
o = pathDIP_Http()
pathway_columns = ['pathway_source', 'pathway_name', 'p_val', 'q_val_BH', 'q_val_Bonf', 'topic']
pathway_rows = []
for topic_idx in range(n_topics):
    gene_symbols = ', '.join(top_words[topic_idx])
    o.searchOnGenesymbols(gene_symbols, component, sources)
    response_lines = o.getPathwayAnalysis().split('\n')
    # The first response line is skipped (presumably a header row - confirm).
    for record in response_lines[1:]:
        # The last tab-separated field of every record is discarded.
        fields = record.split('\t')[:-1]
        if not fields:
            # Nothing left after dropping the last field (e.g. a blank line).
            continue
        pathway_rows.append(fields + [topic_idx])
pathway_df = pd.DataFrame(pathway_rows, columns=pathway_columns)  # (pathways, features)

# Values arrive as text; q_val_BH must be numeric for the significance filter.
pathway_df = pathway_df.astype({'q_val_BH': float})
significant = pathway_df['q_val_BH'] < 0.05
pathway_df = pathway_df[significant]
pathway_df.to_csv(os.path.join(working_dir, 'pathways.csv'))

In [None]:
print('Starting permutation test for cell types')
def simulate_mean_diff_once(data, rng: np.random.Generator):
    """Return one mean difference between two random, equally sized row groups.

    The row indices of ``data`` are shuffled in place with ``rng``; the first
    ``len(data) // 2`` shuffled rows form one group and the next
    ``len(data) // 2`` rows the other (with an odd row count the last shuffled
    row is unused). The result is the first group's mean over axis 0 minus the
    second group's mean over axis 0.
    """
    n_rows = len(data)
    group_size = n_rows // 2
    order = np.arange(n_rows)
    rng.shuffle(order)  # in-place permutation of the row indices
    first_group = data[order[:group_size]]
    second_group = data[order[group_size:2 * group_size]]
    return first_group.mean(0) - second_group.mean(0)

def simulate_mean_diff(data, repeats, seed):
    """Collect ``repeats`` simulated mean differences from one seeded generator.

    A single ``np.random.default_rng(seed)`` generator is created and reused for
    every call to ``simulate_mean_diff_once``; the results are returned as a
    list in draw order.
    """
    generator = np.random.default_rng(seed)
    return [simulate_mean_diff_once(data, generator) for _ in range(repeats)]

# One-sided permutation test: for every cell type, compare the observed
# "this type vs. all other cells" mean difference of each kept topic against a
# null distribution of mean differences between random halves of the cells.
types = adata.obs.cell_types.unique()
mds = []

reps = 10000
n_jobs = 10
# WARNING: Multithreading does not work in notebook, please run multithread_perm.py
# Each worker gets its own seed and simulates reps // n_jobs draws.
with Pool(n_jobs) as p:
    l = [p.apply_async(simulate_mean_diff, (delta_kept.values, reps // n_jobs, seed)) for seed in range(n_jobs)]
    l = [e.get() for e in l]
    mds_simulated = np.concatenate(l, axis=0)  # (repeats, topics)
for t in types:
    test = delta_kept[adata.obs.cell_types == t]  # (cells_test, topics)
    ctrl = delta_kept[adata.obs.cell_types != t]  # (cells_ctrl, topics)
    md = test.mean(0) - ctrl.mean(0)  # (topics)
    mds.append(md)
mds = np.array(mds)  # (cell_types, topics)
mds_simulated = np.array(mds_simulated)

# Empirical p-value = (#simulated > observed + 1) / (#simulated + 1).
# The original expression `count + 1 / (reps + 1)` divided only the literal 1
# because of operator precedence, yielding raw counts instead of p-values.
# The denominator uses the number of draws actually simulated, which differs
# from `reps` whenever reps is not a multiple of n_jobs.
n_simulated = mds_simulated.shape[0]
exceed_counts = (mds_simulated.T[None, ...] > mds[..., None]).sum(-1)  # (cell_types, topics), summed over *repeats*
pvals = (exceed_counts + 1) / (n_simulated + 1)
pval_df = pd.DataFrame(pvals, index=types, columns=topic_kept)  # (cell_types, topics)
# NOTE(review): this scaling looks like a Bonferroni-style correction where 100
# is presumably the number of topics - confirm it matches the model; values are
# not clipped and can exceed 1.
pval_df = pval_df * 100 * len(types)
pval_df.to_csv(os.path.join(working_dir, 'perm_p_onesided_celltype.csv'))

mds = pd.DataFrame(mds, index=types, columns=topic_kept)
mds.to_csv(os.path.join(working_dir, 'perm_mean_celltype.csv'))

In [None]:
print('Selecting interesting topics to plot beta')
from collections import OrderedDict
# INPUT REQUIRED HERE ("interesting" topics can be cell type / condition DE topics, or topics with interesting pathways)
user_selected_topics = [15,18,19,35,49]
print('Get top 10 genes per topic (for figure)')
# Rank genes by descending beta within each selected topic and keep the best 10.
# NOTE: this rebinds `top_words`, replacing the top-30 table built earlier.
selected_beta = beta[user_selected_topics].values
top10_gene_rows = np.argsort(selected_beta, axis=0)[::-1][:10]
top_words = pd.DataFrame(adata.var_names.values[top10_gene_rows])  # (n_top, topic_selected)
# Walk the table topic by topic (column-major) and drop repeated genes while
# keeping the order of first appearance.
genes_topic_major = top_words.values.ravel(order='F')
gene_list = list(OrderedDict.fromkeys(genes_topic_major))

beta_top = beta.loc[gene_list, user_selected_topics]  # (genes_selected, topic_selected)
# Scale every topic column by its largest absolute value.
column_scale = beta_top.abs().max(axis=0)
beta_top = beta_top.div(column_scale, axis='columns')
beta_top.to_csv(os.path.join(working_dir, 'beta_top10genes_selected_topics.csv'))

In [None]:
# Total absolute delta per kept topic, used to rank topics.
delta_sum = delta_kept.abs().sum(axis=0)

# Labels of the 15 topics with the largest total (explicit positional slice
# on the ascending sort).
topk = delta_sum.sort_values().iloc[-15:].index

# Take an explicit copy of the matching rows: the original added columns to a
# slice of `pathway_df`, which triggers SettingWithCopyWarning and leaves it
# ambiguous whether `pathway_df` itself is modified.
pathway_new = pathway_df.loc[pathway_df['topic'].isin(topk)].copy()
pathway_new['neg_log_q_BH'] = -np.log10(pathway_new['q_val_BH'].astype(float).values)
pathway_new['topic'] = pathway_new['topic'].astype(str)
# Plot label "<topic>:<pathway name>", truncated to 50 characters.
pathway_new["name"] = pathway_new['topic'].str.cat(pathway_new['pathway_name'].astype(str), sep=':').str.slice(0, 50)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Alternate black / grey between consecutive topics (in order of first
# appearance) so neighbouring groups of bars can be told apart.
bar_shades = ('black', 'grey')
mypalette = {
    topic: bar_shades[position % 2]
    for position, topic in enumerate(pathway_new.topic.unique())
}

fig, ax = plt.subplots(figsize=(4, 20), dpi=500)
ax = sns.barplot(
    data=pathway_new,
    x="neg_log_q_BH",
    y="name",
    hue='topic',
    palette=mypalette,
    dodge=False,
    ax=ax,
)
ax.set_xlabel('Negative log10 q-value (BH)')
ax.set_ylabel('')
# Replace the hue legend with an empty, frameless one.
ax.legend([], [], frameon=False)
fig.savefig(os.path.join(working_dir, "pathway_15topics.pdf"), bbox_inches="tight")