In [None]:
import scanpy as sc
import scarches as sca
import numpy as np
from scarches.plotting.terms_scores import plot_abs_bfs_key
import pandas as pd

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sb

# Start from scanpy's figure defaults, then override fonts, tick labels and resolution.
sc.set_figure_params(figsize=(6, 6))

font = dict(family='Arial', size=14)
matplotlib.rcParams.update({f'font.{key}': value for key, value in font.items()})
matplotlib.rcParams.update({
    'ytick.labelsize': 14,
    'xtick.labelsize': 14,
    'figure.dpi': 200,
})

### PBMC data (Kang study)

In [None]:
# Load the AnnData object analysed throughout this notebook.
h5ad_path = 'kang_pbmc_integrated.h5ad'
adata = sc.read(h5ad_path)

In [None]:
# Copy `condition` into `condition_merged`, labelling cells that have no condition
# (missing values, or the literal string 'nan') as 'control'.
# A vectorised mask replaces the chained assignment `col[mask] = ...`, which
# triggers SettingWithCopy warnings and does not write back under pandas Copy-on-Write.
condition = adata.obs['condition'].astype(object)
is_missing = condition.isna() | (condition.astype(str) == 'nan')
adata.obs['condition_merged'] = condition.mask(is_missing, 'control')
adata.strings_to_categoricals()

In [None]:
# Restore the trained expiMap query model, attaching it to `adata`.
model_dir = 'q_intr_cvae_nolog_alpha_kl_0_5_0_1_sd_2020'
intr_cvae = sca.models.EXPIMAP.load(model_dir, adata)

In [None]:
# `latent_directions` writes the per-term direction vector to `adata.uns['directions']`
# and returns None. Assigning its return value would leave `directions = None`,
# so read the result back from `uns` instead.
intr_cvae.latent_directions(method="sum", adata=adata)
directions = adata.uns['directions']

In [None]:
# Read the per-term direction vector from `adata.uns['directions']`; the next cell
# indexes it with `nonzero_terms()` and multiplies it into `X_cvae`.
directions = adata.uns['directions']

In [None]:
# Multiply each active latent dimension by its direction value (scArches documents this
# as making positive latent scores correspond to up-regulation).
# The unsigned latent is kept in `obsm['X_cvae_unsigned']`, so re-running this cell is
# idempotent. The former in-place `*=` flipped the signs back on every second run.
if 'X_cvae_unsigned' not in adata.obsm:
    adata.obsm['X_cvae_unsigned'] = adata.obsm['X_cvae'].copy()
unsigned_latent = adata.obsm['X_cvae_unsigned']
adata.obsm['X_cvae'] = (
    unsigned_latent * directions[intr_cvae.model.decoder.nonzero_terms()]
).astype(unsigned_latent.dtype)  # keep the original dtype, as `*=` did

In [None]:
# Bayes-factor enrichment of each condition vs. 'control' over all cells.
# `latent_enrich` returns None and stores its result in `adata.uns['bf_scores']`.
# It only uses directions when `use_directions=True`, and `directions_key` must be
# the *key* in `uns`. Passing the directions array there had no effect.
intr_cvae.latent_enrich('condition_merged', comparison="control", use_directions=True,
                        directions_key='directions', adata=adata, n_sample=7000)
# Grab the dict now: the next enrichment replaces `adata.uns['bf_scores']`.
scores_cond = adata.uns['bf_scores']

In [None]:
# Bayes-factor enrichment per cell type (default comparison), using the term directions.
# `latent_enrich` returns None and stores its result in `adata.uns['bf_scores']`.
intr_cvae.latent_enrich('cell_type_joint', use_directions=True,
                        directions_key='directions', n_sample=7000, adata=adata)
scores_ct = adata.uns['bf_scores']

In [None]:
# Take the Bayes-factor results stored in `adata.uns['bf_scores']`. At this point in the
# notebook they come from the cell-type enrichment (`latent_enrich` on 'cell_type_joint').
scores_ct = adata.uns['bf_scores']

In [None]:
# Names of the terms the decoder kept active, in `nonzero_terms()` order. Later cells
# look up X_cvae columns by position in this array.
# Names come from `uns['terms']`; `uns['directions']` holds per-term signs, not names.
adata.uns['active_terms'] = np.asarray(adata.uns['terms'])[intr_cvae.model.decoder.nonzero_terms()]

In [None]:
# Materialise the CD14+ monocyte subset as its own AnnData rather than a view.
# Later cells write to its `.uns`/`.obs`. On a view, anndata would first implicitly
# initialise it as a copy and emit an ImplicitModificationWarning.
adata_ct = adata[adata.obs['cell_type_joint'] == 'CD14+ Monocytes'].copy()

In [None]:
# Condition enrichment (vs. 'control') restricted to the CD14+ monocyte subset.
# `latent_enrich` returns None and stores its result in the *passed* AnnData's
# `.uns['bf_scores']`, i.e. on `adata_ct`, not on `adata`.
intr_cvae.latent_enrich('condition_merged', comparison="control", use_directions=True,
                        directions_key='directions', adata=adata_ct, n_sample=10000)
scores_c_ct = adata_ct.uns['bf_scores']

In [None]:
# The monocyte condition scores were stored on `adata_ct` (the AnnData passed to
# `latent_enrich`). `adata.uns['bf_scores']` still holds the whole-dataset cell-type scores.
scores_c_ct = adata_ct.uns['bf_scores']

In [None]:
# Inspect which results are currently stored in `adata.uns`.
print(adata.uns.keys())

In [None]:
# Report whether a 'STIMULATED' key exists: first among the obs columns, then among the uns keys.
for container in (adata.obs.columns, adata.uns):
    print('STIMULATED' in container)


In [None]:
# Absolute Bayes factors per group for the scores currently in `adata.uns['bf_scores']`.
axs = sca.plotting.plot_abs_bfs(
    adata,
    terms=adata.uns['terms'],
    yt_step=1,
    scale_y=2.45,
)

In [None]:
# Store the active term names on the monocyte subset (used by `filter_set_scores`).
active_idx = intr_cvae.model.decoder.nonzero_terms()
adata_ct.uns['active_terms'] = adata.uns['terms'][active_idx]

In [None]:
def filter_set_scores(scores, adata, filter_v=2.31):
    """Report strongly enriched terms per group and copy their latent values to ``adata.obs``.

    For every group ``k`` in ``scores``, the terms whose absolute Bayes factor
    ``|scores[k]['bf']|`` is strictly greater than ``filter_v`` are printed in
    decreasing order of ``|bf|``, followed by their signed Bayes factors. Each such
    term that appears in ``adata.uns['active_terms']`` gets an ``adata.obs[term]``
    column holding the matching column of ``adata.obsm['X_cvae']``.

    Parameters
    ----------
    scores : dict
        Group name -> dict with a ``'bf'`` array aligned with ``adata.uns['terms']``.
    adata : AnnData
        Needs ``uns['terms']``, ``uns['active_terms']`` and ``obsm['X_cvae']``. The
        columns of ``X_cvae`` are assumed to follow the order of ``uns['active_terms']``.
    filter_v : float, default 2.31
        Threshold on ``|bf|``.

    Returns
    -------
    None
        Results are printed and written into ``adata.obs`` in place.
    """
    terms = np.asarray(adata.uns['terms'])
    active_terms = np.asarray(adata.uns['active_terms']).tolist()
    latent = adata.obsm['X_cvae']

    for k in scores:
        print(k)
        bf = np.asarray(scores[k]['bf'])
        mask = np.abs(bf) > filter_v
        if not mask.any():
            continue

        order = np.argsort(np.abs(bf[mask]))[::-1]
        enriched_terms = terms[mask][order]
        print(enriched_terms)
        print(bf[mask][order])

        for term in enriched_terms:
            if term not in active_terms:
                # Inactive terms have no column in X_cvae. A boolean column mask would
                # give an empty (n_obs, 0) slice that cannot become an obs column.
                print(f'  skipping inactive term {term}')
                continue
            # Integer indexing yields a 1-D column; a boolean mask yields (n_obs, 1).
            adata.obs[term] = latent[:, active_terms.index(term)]

In [None]:
# Print every term with |BF| > 1 for the monocyte condition comparison.
filter_set_scores(scores=scores_c_ct, adata=adata_ct, filter_v=1)

In [None]:
# Earlier candidate list of terms to inspect, kept commented out for reference;
# the list actually used is defined in the next cell.
#check_terms = ['SIGNALING_BY_GPCR', 'CLASS_A1_RHODOPSIN_LIKE_RECEPT', 'IMMUNE_SYSTEM',
#               'RNA_POL_I_RNA_POL_III_AND_MITO', 'METABOLISM_OF_CARBOHYDRATES',
#               'CYTOKINE_SIGNALING_IN_IMMUNE_S', 'APOPTOTIC_EXECUTION_PHASE',
#               'METABOLISM_OF_NUCLEOTIDES', 'BIOLOGICAL_OXIDATIONS',
#               'INTERFERON_GAMMA_SIGNALING']

In [None]:
# Terms to inspect in the violin plots. The names appear truncated to 30 characters,
# presumably to match `adata.uns['terms']` (they are looked up there) — confirm.
check_terms = [
    'INTERFERON_SIGNALING',
    'INTERFERON_ALPHA_BETA_SIGNALIN',
    'GPCR_DOWNSTREAM_SIGNALING',
    'IMMUNE_SYSTEM',
    'SIGNALING_BY_GPCR',
    'METABOLISM_OF_CARBOHYDRATES',
    'CYTOKINE_SIGNALING_IN_IMMUNE_S',
    'PLATELET_ACTIVATION_SIGNALING_',
    'METABOLISM_OF_AMINO_ACIDS_AND_',
    'METABOLISM_OF_NUCLEOTIDES',
]

In [None]:
# Show the genes annotated to the IMMUNE_SYSTEM term.
term_name = 'IMMUNE_SYSTEM'
intr_cvae.term_genes(term_name)

In [None]:
# Positions of the selected terms among the subset's active terms; used below as X_cvae column indices.
active_term_list = adata_ct.uns['active_terms'].tolist()
idxs = [active_term_list.index(term) for term in check_terms]

In [None]:
# Display the X_cvae column positions computed for `check_terms`.
idxs

In [None]:
# One variable per selected term, holding its latent value for every cell. Rows reuse
# `adata`'s obs names so the plotting object stays labelled by cell and can be aligned
# with `adata.obs` by label, not only by position.
adata_pl = sc.AnnData(
    X=adata.obsm['X_cvae'][:, idxs],
    obs=pd.DataFrame(index=adata.obs_names.copy()),
)

In [None]:
# Name each variable after its term.
adata_pl.var_names = pd.Index(check_terms)

In [None]:
# Placeholder group label, overwritten per cell in the next cell.
adata_pl.obs = adata_pl.obs.assign(ct_cond='stub')

In [None]:
# Label every cell "<cell type>_<condition>" in one vectorised step. The former per-row
# `obs['ct_cond'][i] = ...` combined chained assignment (silently ineffective under
# pandas Copy-on-Write) with deprecated positional `Series[i]` access, one cell at a time.
# `.to_numpy()` assigns by position, matching the row order of `adata`.
adata_pl.obs['ct_cond'] = (
    adata.obs['cell_type_joint'].astype(str) + '_' + adata.obs['condition_merged'].astype(str)
).to_numpy()

In [None]:
# Selected terms across every cell-type/condition group.
sc.pl.stacked_violin(
    adata_pl,
    var_names=check_terms,
    groupby='ct_cond',
    swap_axes=True,
)

In [None]:
# Cell-type/condition groups that occur among Kang-study cells (order of first appearance).
is_kang = (adata.obs['study'] == 'Kang').to_numpy()
query_ct = adata_pl.obs.loc[is_kang, 'ct_cond'].unique().tolist()

In [None]:
# Keep only cells whose group occurs in the Kang study.
in_query_group = adata_pl.obs['ct_cond'].isin(query_ct)
adata_pl_q = adata_pl[in_query_group]

In [None]:
# First five selected terms, Kang groups only.
first_five_terms = check_terms[:5]
sc.pl.stacked_violin(adata_pl_q, var_names=first_five_terms,
                     groupby='ct_cond', swap_axes=True)

In [None]:
# Hand-picked interferon/GPCR/carbohydrate terms, Kang groups only.
selected_terms = [
    'INTERFERON_SIGNALING',
    'GPCR_DOWNSTREAM_SIGNALING',
    'SIGNALING_BY_GPCR',
    'METABOLISM_OF_CARBOHYDRATES',
]
sc.pl.stacked_violin(adata_pl_q, var_names=selected_terms,
                     groupby='ct_cond', swap_axes=True)

In [None]:
# Cell types present in the Kang query batch (kept as the `unique()` result of the column).
is_kang_query = adata.obs['batch_join'] == 'Kang (query)'
query_ct = adata.obs.loc[is_kang_query, 'cell_type_joint'].unique()

In [None]:
# Restrict the cell-type enrichment results to the query cell types.
scores_ct_q = dict(item for item in scores_ct.items() if item[0] in query_ct)

In [None]:
# List the categories of `query_ct`, then the keys stored in `adata.uns`.
for cell_type in query_ct.categories:
    print(cell_type)
print(adata.uns.keys())

In [None]:
# List-comprehension alternative to the per-cell loop that fills `ct_cond`, left
# commented out. NOTE(review): `Series[i]` with an integer on a string index is
# positional access that newer pandas deprecates; use `.iloc[i]` or a vectorised
# string concatenation if this is revived.
# adata_pl.obs['ct_cond'] = [
#     adata.obs['cell_type_joint'][i] + '_' + adata.obs['condition_merged'][i] 
#     for i in range(adata.n_obs)
# ]

In [None]:
# Inspect obs columns and study labels, then store `study` as a categorical of strings.
print(adata.obs.columns)
study = adata.obs['study']
print(study.unique())
adata.obs['study'] = study.astype(str).astype('category')


In [None]:
# For each query cell type, plot its 10 strongest active terms (largest |Bayes factor|
# in the cell-type enrichment `scores_ct`) as stacked violins over every
# "<cell type>_<condition>" group present in the Kang study.
# Term names come from `uns['terms']`; `uns['directions']` holds per-term signs, not names.
# Only active terms are kept, because only they have a column in `X_cvae`.
all_terms = np.asarray(adata.uns['terms'])
active_terms = all_terms[intr_cvae.model.decoder.nonzero_terms()].tolist()
active_set = set(active_terms)

# Group labels do not depend on the cell type being plotted: build them once, vectorised.
ct_cond = (adata.obs['cell_type_joint'].astype(str) + '_'
           + adata.obs['condition_merged'].astype(str)).to_numpy()
is_kang = (adata.obs['study'] == 'Kang').to_numpy()
query_typs = pd.unique(ct_cond[is_kang]).tolist()

# Iterate over the observed query cell types. `.categories` may also list cell types
# that have no query cells.
for ct in query_ct:
    print(ct)
    if ct not in scores_ct:
        print(f"Warning: '{ct}' not found in scores_ct. Skipping.")
        continue
    scores = np.asarray(scores_ct[ct]['bf'])
    ranked = all_terms[np.argsort(np.abs(scores))[::-1]]
    top_10_terms = [t for t in ranked if t in active_set][:10]
    idxs = [active_terms.index(t) for t in top_10_terms]

    adata_pl = sc.AnnData(
        X=adata.obsm['X_cvae'][:, idxs],
        obs=pd.DataFrame({'ct_cond': ct_cond}, index=adata.obs_names),
    )
    adata_pl.var_names = top_10_terms
    adata_pl_q = adata_pl[adata_pl.obs['ct_cond'].isin(query_typs)]
    sc.pl.stacked_violin(adata_pl_q, var_names=top_10_terms, groupby='ct_cond', swap_axes=True)

In [None]:
# Debug variant of the per-cell-type top-term violins: same plots, printing intermediates.
# Fixes: term names come from `uns['terms']` (not the sign vector `uns['directions']`),
# and the Kang group labels are computed from `ct_cond` (which is not an `adata.obs`
# column, so `adata.obs.query(...)['ct_cond']` raised KeyError).
all_terms = np.asarray(adata.uns['terms'])
active_terms = all_terms[intr_cvae.model.decoder.nonzero_terms()].tolist()
active_set = set(active_terms)

ct_cond = (adata.obs['cell_type_joint'].astype(str) + '_'
           + adata.obs['condition_merged'].astype(str)).to_numpy()
is_kang = (adata.obs['study'] == 'Kang').to_numpy()
query_typs = pd.unique(ct_cond[is_kang]).tolist()
print("Query types:", query_typs)

for ct in query_ct:
    print("Current category:", ct)
    if ct not in scores_ct:
        print(f"Warning: '{ct}' not found in scores_ct. Skipping.")
        continue

    scores = np.asarray(scores_ct[ct]['bf'])
    print("Scores:", scores)

    # Rank all terms by |BF|, then keep the first 10 that have an X_cvae column.
    ranked = all_terms[np.argsort(np.abs(scores))[::-1]]
    top_10_terms = [t for t in ranked if t in active_set][:10]
    print("Top 10 terms:", top_10_terms)

    idxs = [active_terms.index(t) for t in top_10_terms]
    print("Indices in `active_terms`:", idxs)

    adata_pl = sc.AnnData(
        X=adata.obsm['X_cvae'][:, idxs],
        obs=pd.DataFrame({'ct_cond': ct_cond}, index=adata.obs_names),
    )
    adata_pl.var_names = top_10_terms

    adata_pl_q = adata_pl[adata_pl.obs['ct_cond'].isin(query_typs)]
    print("Filtered adata_pl_q observations:", adata_pl_q.obs.shape)

    sc.pl.stacked_violin(adata_pl_q, var_names=top_10_terms, groupby='ct_cond', swap_axes=True)


In [None]:

# Per-cell-type top-term violins over *all* cell-type/condition groups (no study filter).
# Fixes: term names come from `uns['terms']` (the sign vector `uns['directions']` gave
# duplicate names such as '1.0'/'-1.0'). `AnnData.var_names_make_unique()` works in
# place and returns None, so its result must not be plotted.
all_terms = np.asarray(adata.uns['terms'])
active_terms = all_terms[intr_cvae.model.decoder.nonzero_terms()].tolist()
active_set = set(active_terms)

ct_cond = (adata.obs['cell_type_joint'].astype(str) + '_'
           + adata.obs['condition_merged'].astype(str)).to_numpy()

for ct in query_ct:
    print(ct)
    if ct not in scores_ct:
        print(f"Warning: '{ct}' not found in scores_ct. Skipping.")
        continue
    scores = np.asarray(scores_ct[ct]['bf'])
    ranked = all_terms[np.argsort(np.abs(scores))[::-1]]
    top_10_terms = [t for t in ranked if t in active_set][:10]
    idxs = [active_terms.index(t) for t in top_10_terms]

    adata_pl = sc.AnnData(
        X=adata.obsm['X_cvae'][:, idxs],
        obs=pd.DataFrame({'ct_cond': ct_cond}, index=adata.obs_names),
    )
    adata_pl.var_names = top_10_terms
    # Safeguard only: the names should already be unique. Modifies adata_pl in place.
    adata_pl.var_names_make_unique()
    sc.pl.stacked_violin(adata_pl, var_names=adata_pl.var_names.tolist(),
                         groupby='ct_cond', swap_axes=True)

In [None]:
# Store conditions as plain strings and relabel the 'stimulated' group as 'IFN-beta'.
# `Series.replace` avoids the chained assignment `col[mask] = ...`, which does not
# write back under pandas Copy-on-Write.
adata.obs['condition_merged'] = (
    adata.obs['condition_merged'].astype(str).replace('stimulated', 'IFN-beta')
)

In [None]:
# For each query cell type, run an IFN-beta vs. control enrichment within that cell type,
# then plot its 10 strongest active terms across the Kang-study cell-type/condition groups.
all_terms = np.asarray(adata.uns['terms'])
# `full_terms` is used only for the (45-character) axis labels. X_cvae columns are
# looked up through `terms`, the list the active-term names are built from.
full_terms = np.asarray(adata.uns['full_terms'])
active_terms = all_terms[intr_cvae.model.decoder.nonzero_terms()].tolist()
active_set = set(active_terms)

ct_cond = (adata.obs['cell_type_joint'].astype(str) + '_'
           + adata.obs['condition_merged'].astype(str)).to_numpy()
is_kang = (adata.obs['study'] == 'Kang').to_numpy()
query_typs = pd.unique(ct_cond[is_kang]).tolist()

for ct in query_ct:
    print(ct)
    adata_ct = adata[(adata.obs['cell_type_joint'] == ct).to_numpy()].copy()
    present = set(adata_ct.obs['condition_merged'])
    if not {'control', 'IFN-beta'} <= present:
        print(f"Skipping '{ct}': needs both 'control' and 'IFN-beta' cells.")
        continue

    # EXPIMAP.latent_enrich takes `n_sample`/`use_directions`/`directions_key`
    # (not `n_perm`/`directions`), returns None, and stores its result in
    # the passed AnnData's `.uns['bf_scores']`.
    intr_cvae.latent_enrich('condition_merged', comparison="control", use_directions=True,
                            directions_key='directions', adata=adata_ct,
                            n_sample=50000, exact=True)
    scores = np.asarray(adata_ct.uns['bf_scores']['IFN-beta']['bf'])

    # Rank all terms by |BF| and keep the first 10 that have an X_cvae column.
    order = np.argsort(np.abs(scores))[::-1]
    top_idx = [i for i in order if all_terms[i] in active_set][:10]
    idxs = [active_terms.index(all_terms[i]) for i in top_idx]
    terms = [full_terms[i][:45] for i in top_idx]

    adata_pl = sc.AnnData(
        X=adata.obsm['X_cvae'][:, idxs],
        obs=pd.DataFrame({'ct_cond': ct_cond}, index=adata.obs_names),
    )
    adata_pl.var_names = terms
    # Truncated labels can collide: make them unique (in place) and plot the final names.
    adata_pl.var_names_make_unique()
    terms = adata_pl.var_names.tolist()
    adata_pl_q = adata_pl[adata_pl.obs['ct_cond'].isin(query_typs)]
    sc.pl.stacked_violin(adata_pl_q, var_names=terms, groupby='ct_cond', swap_axes=True)