In [1]:
# Author: A. Wenteler

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad 
import pickle as pkl 
import scanpy as sc 
from tqdm import tqdm

from scipy.sparse import csr_matrix

In [3]:
# Read the raw AnnData object from disk
raw_h5ad_path = '../../data/norman_2019_raw.h5ad'
sc_data_raw = ad.read_h5ad(raw_h5ad_path)

In [4]:
def preprocess_adata(adata, min_gene_counts=None, min_cell_counts=None, no_highly_var=2000):
    """
    Filter, normalise and log-transform an AnnData object, then keep its highly variable genes.

    The caller's object is not modified; all steps run on a copy.

    Parameters
    ----------
    adata : AnnData
        Must have an ``obs['guide_ids']`` column holding the guide label of each cell.
        Cells whose label contains a comma (more than one guide) are dropped, so only
        single-guide perturbations and cells without a comma-separated label remain.
    min_gene_counts : int or None
        Forwarded to ``sc.pp.filter_genes(min_counts=...)``; the step is skipped when None.
    min_cell_counts : int or None
        Forwarded to ``sc.pp.filter_cells(min_counts=...)``; the step is skipped when None.
    no_highly_var : int
        Forwarded to ``sc.pp.highly_variable_genes(n_top_genes=...)``.

    Returns
    -------
    AnnData
        A standalone (non-view) object restricted to the genes flagged ``highly_variable``.
    """
    adata = adata.copy()

    # filter genes
    if min_gene_counts is not None:
        sc.pp.filter_genes(adata, min_counts=min_gene_counts)

    # filter cells
    if min_cell_counts is not None:
        sc.pp.filter_cells(adata, min_counts=min_cell_counts)

    # keep only cells whose guide label names a single guide (no comma)
    guide_ids = adata.obs['guide_ids'].astype(str)
    is_single_guide = ~guide_ids.str.contains(',', regex=False)
    # Materialise the subset: running the in-place transforms below on a view
    # forces anndata to copy implicitly and emit a warning.
    adata = adata[is_single_guide.to_numpy(), :].copy()

    # apply preprocessing transformation
    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=no_highly_var)

    # Return a real object rather than a view onto the full-width local one.
    is_highly_variable = adata.var['highly_variable'].to_numpy()
    return adata[:, is_highly_variable].copy()

In [5]:
# Preprocess the raw data: gene filter at min_counts=5, no cell filter,
# 2000 highly variable genes requested.
adata_pp = preprocess_adata(
    sc_data_raw,
    min_gene_counts=5,
    min_cell_counts=None,
    no_highly_var=2000,
)

  view_to_actual(adata)


In [6]:
# Load differentially expressed genes (iterated below as perturbation -> DE genes).
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
# The context manager closes the file handle, which the bare open() call never did.
with open('../../data/splits/perturb/norman_1/de_test/deg_pert_dict.pkl', 'rb') as deg_file:
    diff_genes = pkl.load(deg_file)

In [8]:
# Dense cells-by-genes table with gene symbols as the column labels
expr_matrix = pd.DataFrame(adata_pp.X.todense(), columns=adata_pp.var.gene_symbols)

# Guide label of each cell, in the same row order as the table
perts = adata_pp.obs['guide_ids'].tolist()
expr_matrix['perturbations'] = perts

# Keep a handle on the per-cell annotations
adata_obs = adata_pp.obs

In [9]:
# Rows with an empty guide label go to the control table, all others to the perturbed table
has_empty_label = expr_matrix['perturbations'] == ''
expr_matrix_ctrl = expr_matrix.loc[has_empty_label]
expr_matrix_pert = expr_matrix.loc[~has_empty_label]

In [10]:
# (rows, columns) of the control table; the column count includes the added 'perturbations' column
expr_matrix_ctrl.shape

(11855, 2001)

In [None]:
# Empty frame with the perturbation / gene / expression column layout.
# NOTE(review): `df` is not referenced by any later cell visible here — confirm it is still needed.
df = pd.DataFrame(columns=['perturbations', 'gene', 'expression'])

In [None]:
# One row per (perturbation, DE gene): the gene's expression in every cell carrying that perturbation.
new_rows = []

# Gene id -> gene symbol, built once instead of scanning `var` with a boolean mask per gene.
# A gene id missing from `var` raises KeyError here (the mask-based lookup raised IndexError).
symbol_by_gene = adata_pp.var['gene_symbols'].to_dict()

for pert, de_genes in diff_genes.items():
    # Select this perturbation's cells once; the selection does not depend on the gene.
    pert_cells = expr_matrix_pert[expr_matrix_pert['perturbations'] == pert]

    for gene in de_genes:
        gene_symbol = symbol_by_gene[gene]
        expression_values = pert_cells[gene_symbol].tolist()

        new_rows.append([pert, gene_symbol, expression_values])

top20_pert_df = pd.DataFrame(new_rows, columns=['perturbations', 'gene', 'expression'])

In [None]:
# Precompute the gene index to symbol mapping outside the loop
gene_symbol_mapping = adata_pp.var['gene_symbols'].to_dict()

# Collect the tail-gene rows in their own list. Appending to `new_rows` would carry the
# DE-gene rows built for the top-20 table into this one, and would keep growing the
# list every time the cell is re-run.
nontop20_rows = []

# Loop over the perturbations and differentially expressed genes
for pert, de_genes in tqdm(diff_genes.items()):
    de_gene_set = set(de_genes)
    # Walk the genes in `var` order (not set order) so the row order is reproducible.
    remaining_genes = [gene for gene in adata_pp.var.index if gene not in de_gene_set]

    # Get the expression matrix for the current perturbation
    expression_values_pert = expr_matrix_pert[expr_matrix_pert['perturbations'] == pert]

    for gene in remaining_genes:
        gene_symbol = gene_symbol_mapping[gene]
        expression_values = expression_values_pert[gene_symbol].tolist()
        nontop20_rows.append([pert, gene_symbol, expression_values])

nontop20_pert_df = pd.DataFrame(nontop20_rows, columns=['perturbations', 'gene', 'expression'])

In [None]:
def _rows_for_perturbation(frame, pert):
    """Return the rows of `frame` whose 'perturbations' value equals `pert`."""
    return frame[frame['perturbations'] == pert]


# Top-20 and remaining-gene rows for each of the three perturbations of interest
ikzf3_top20 = _rows_for_perturbation(top20_pert_df, 'IKZF3')
ikzf3_nontop20 = _rows_for_perturbation(nontop20_pert_df, 'IKZF3')

glb1l2_top20 = _rows_for_perturbation(top20_pert_df, 'GLB1L2')
glb1l2_nontop20 = _rows_for_perturbation(nontop20_pert_df, 'GLB1L2')

set_top20 = _rows_for_perturbation(top20_pert_df, 'SET')
set_nontop20 = _rows_for_perturbation(nontop20_pert_df, 'SET')

In [None]:
def _mean_over_genes(frame):
    """Stack the per-gene 'expression' lists as rows and average down the rows (axis 0)."""
    stacked = np.vstack(frame['expression'].values)
    return stacked.mean(axis=0)


# One averaged profile per perturbation and gene group
ikzf3_top20_expr = _mean_over_genes(ikzf3_top20)
ikzf3_nontop20_expr = _mean_over_genes(ikzf3_nontop20)

glb1l2_top20_expr = _mean_over_genes(glb1l2_top20)
glb1l2_nontop20_expr = _mean_over_genes(glb1l2_nontop20)

set_top20_expr = _mean_over_genes(set_top20)
set_nontop20_expr = _mean_over_genes(set_nontop20)

In [None]:
# (perturbation, gene group, averaged profile) triples, in display order
pert_comp_records = [
    ("IKZF3", "Top 20 DEGs", ikzf3_top20_expr),
    ("IKZF3", "Tail genes", ikzf3_nontop20_expr),
    ("GLB1L2", "Top 20 DEGs", glb1l2_top20_expr),
    ("GLB1L2", "Tail genes", glb1l2_nontop20_expr),
    ("SET", "Top 20 DEGs", set_top20_expr),
    ("SET", "Tail genes", set_nontop20_expr),
]

# Column-oriented dict with the same three keys, in the same order
pert_comp = {
    "Gene": [record[0] for record in pert_comp_records],
    "Group": [record[1] for record in pert_comp_records],
    "Expression": [record[2] for record in pert_comp_records],
}

In [None]:
# One row per (perturbation, gene group); the 'Expression' cell holds the whole averaged profile
pert_comp_df = pd.DataFrame(pert_comp)
# Display the table as the cell output
pert_comp_df

In [None]:
# Long format: one row per value of each profile in 'Expression'
expression_data = pert_comp_df.explode('Expression')

# Split the long table by gene group
group_labels = expression_data['Group']
expression_data_tail = expression_data.loc[group_labels == 'Tail genes']
expression_data_top20 = expression_data.loc[group_labels == 'Top 20 DEGs']

In [None]:
# Mean of the exploded tail-gene values for each perturbation ("Gene" column).
# NOTE(review): the plotting cell draws a hard-coded 0.1606 reference line instead of
# reading this Series — confirm the two agree.
avg_expression_data_tail = expression_data_tail.groupby("Gene")["Expression"].mean()

In [None]:
# Violin plot of the top-20 DEG values per perturbation, rendered at 300 dpi.
# NOTE(review): the reference-line height is hard-coded — confirm it matches the
# tail-gene mean computed from the data.
TAIL_MEAN_EXPRESSION = 0.1606

fig, ax = plt.subplots(dpi=300)
sns.violinplot(data=expression_data_top20, x='Gene', y='Expression', hue='Group', ax=ax)
ax.axhline(y=TAIL_MEAN_EXPRESSION, color='C1', linestyle='--', label='Mean expression of tail genes')
ax.set_xlabel('Perturbation')

# Collapse repeated legend labels to a single entry each
handles, labels = ax.get_legend_handles_labels()
unique_labels = dict(zip(labels, handles))
ax.legend(unique_labels.values(), unique_labels.keys(), loc='upper right')
fig.savefig('paper_figs/top20_vs_tail_genes.pdf')

In [None]:
# Long format: one row per single value of each 'expression' list.
# Per the pandas docs, explode leaves the exploded column with object dtype.
top20_pert_expl = top20_pert_df.explode('expression')
nontop20_pert_expl = nontop20_pert_df.explode('expression')
# Display the exploded top-20 table as the cell output
top20_pert_expl

In [None]:
# Summary statistics over every exploded value, pooled across perturbations and genes
top20_values = top20_pert_expl['expression']
nontop20_values = nontop20_pert_expl['expression']

# Mean and standard deviation for the top-20 group
top20_avg, top20_std = top20_values.mean(), top20_values.std()

# Mean, standard deviation and minimum for the non-top-20 group
nontop20_avg, nontop20_std = nontop20_values.mean(), nontop20_values.std()
nontop20_min = nontop20_values.min()

print(f"Top 20 average expression: {top20_avg}, Top 20 standard deviation: {top20_std}")
print(f"Non top 20 average expression: {nontop20_avg}, Non top 20 standard deviation: {nontop20_std}")
print(f"Non top 20 minimum expression: {nontop20_min}")