# CellOracle summarise perturbation scores

In [None]:
import re
import logging as log
from pathlib import Path
import yaml

import numpy as np
import scipy as sp
import pandas as pd
import scanpy as sc
import celloracle as co
from celloracle.applications import Oracle_development_module

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

In [None]:
# Root logger. The module-level `log.info(...)` calls in this notebook go
# through it; at its default WARNING level those INFO messages would be
# silently dropped, so raise it to INFO.
logger = log.getLogger()
logger.setLevel(log.INFO)

In [None]:
# Record which CellOracle release is installed in this environment.
log.info("CellOracle version: %s", co.__version__)

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## Params

In [None]:
## input
# Perturbation name -> path of its perturbation-score CSV.
# The name is the stem of the directory two levels above each CSV
# (Path(...).parents[1]). Paths that share that directory produce the same
# key, and the later path silently replaces the earlier one.
perturbation_scores = {
    Path(csv_path).parents[1].stem: csv_path
    for csv_path in (
        "/path/to/perturbation_scores_1.csv",
        "/path/to/perturbation_scores_2.csv",
    )
}

## 1) Load

In [None]:
# Progress message for the loading step below.
log.log(log.INFO, "load perturbation scores")

In [None]:
# Combine the per-perturbation score tables into a single frame with one row
# per perturbation (named by the keys of `perturbation_scores`) and one column
# per index value of the CSVs (e.g. the "Osteoblast" column used below).
#
# Each CSV is read with its first column as the index. Only its "PS" column is
# kept, renamed to the perturbation name; a CSV without "PS" raises KeyError
# instead of silently contributing unrelated columns.
ps_series = [
    pd.read_csv(ps_path, index_col=0)["PS"].rename(ps_name)
    for ps_name, ps_path in perturbation_scores.items()
]

if ps_series:
    # One outer-aligned concat (missing entries -> NaN) instead of re-merging
    # the growing frame once per file. Sorting the index makes the resulting
    # column order independent of the file order.
    ps_df_all = pd.concat(ps_series, axis=1, join="outer").sort_index()
else:
    ps_df_all = pd.DataFrame()

ps_df_all = ps_df_all.T

## 2) Visualise perturbation scores

### data frame

In [None]:
# Show the 15 perturbations with the highest "Osteoblast" score, highest first.
ps_df_all.sort_values(by="Osteoblast", ascending=False).head(15)

### line plot

In [None]:
# Line plot of perturbation scores along a hand-picked sequence of cell types,
# one facet per perturbation. Best-effort: any error (e.g. a KeyError when one
# of the labels below is absent) is logged and the notebook continues.
try:
    # z-score each perturbation (row) across all cell types, then subset.
    ps_z = ps_df_all.sub(ps_df_all.mean(axis=1), axis=0).div(
        ps_df_all.std(axis=1), axis=0
    )
    plt_df = (
        ps_z.loc[
            ['KLF4_OE', 'SP7_OE', 'DLX3_OE'],
            ["Suture Mes2", "Suture Mes1", "CrnOsteoPro1", "CrnOsteoPro4", "Osteoblast"],
        ]
        .rename_axis("condition")
        .reset_index()
        .melt(var_name="cell type", value_name="PS", id_vars="condition")
    )

    # FacetGrid sizes its own figure from height/aspect, so a figure.figsize
    # rc setting would have no effect here.
    g = sns.FacetGrid(data=plt_df, col='condition', col_wrap=1, height=1.5, aspect=5/1.5)
    g.map(sns.lineplot, 'cell type', 'PS')

    # Rotate the x tick labels on every facet, not only the current axes.
    for ax in g.axes.flat:
        ax.tick_params(axis="x", labelrotation=45)
except Exception:
    # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
    log.exception("could not plot line plot")

In [None]:
# Row-wise (per-perturbation) standard deviation, mean and maximum of the
# scores across all columns (cell types).
pstd, pmean, pmax = (
    ps_df_all.std(axis="columns"),
    ps_df_all.mean(axis="columns"),
    ps_df_all.max(axis="columns"),
)

def entropy(x):
    """Return the Shannon entropy (natural log) of a score profile.

    The scores are shifted so that the minimum becomes 1e-12, which keeps
    every value strictly positive so the log stays finite. They are then
    normalised to a probability vector ``p`` and scored as ``-sum(p * log(p))``.
    Low values mean the scores are concentrated on few entries; an even
    profile gives the maximum, ``log(len(x))``.

    Parameters
    ----------
    x : pandas.Series or numpy.ndarray
        Scores of one perturbation, e.g. one row of ``ps_df_all``. For a
        pandas Series, NaN entries are skipped by ``min``/``sum``.

    Returns
    -------
    float
        The entropy of the normalised profile.
    """
    shifted = x - x.min() + 1e-12
    p = shifted / shifted.sum()
    # Shannon entropy weights log(p), not the log of the unnormalised scores.
    return -(p * np.log(p)).sum()

# Apply `entropy` to each perturbation's row of scores (one value per row).
pent = ps_df_all.apply(entropy, axis="columns")

### heatmap

#### overexpression

In [None]:
# Keep only the over-expression perturbations (row labels ending in "_OE").
plt_df = ps_df_all[ps_df_all.index.str.endswith("_OE")]

In [None]:
# For every column (cell type) of `plt_df`, collect the `top_n` perturbations
# with the highest score; their union is plotted in the heatmap below.
top_n = 5
top_TFs = []

for c in plt_df:
    # nlargest skips NaN, so a perturbation without a score for this cell type
    # is never counted among its top hits.
    top_TFs.extend(plt_df.nlargest(top_n, c).index.tolist())

# De-duplicate while keeping first-seen order; list(set(...)) would give a
# hash-randomised order that changes from run to run.
top_TFs = list(dict.fromkeys(top_TFs))

", ".join(top_TFs)

In [None]:
# Standardise each column (cell type) over the selected top TFs, so the
# heatmap compares TFs within each cell type.
plt_df = ps_df_all.loc[top_TFs, :]
plt_df = (plt_df - plt_df.mean()) / plt_df.std()
# Label rows by TF name: drop only the trailing "_<suffix>" (e.g. "_OE"),
# keeping any underscore that is part of the name and avoiding duplicate labels.
plt_df.index = [c.rsplit('_', 1)[0] for c in plt_df.index]

from scipy.spatial import distance
from scipy.cluster import hierarchy

# Cluster the columns (cell types): Ward linkage on Euclidean distances, then
# optimal leaf ordering of the columns. The distance matrix is computed once.
col_dist = distance.pdist(plt_df.T)
col_linkage = hierarchy.optimal_leaf_ordering(
    hierarchy.linkage(col_dist, method='ward'),
    col_dist,
)

# Columns are already z-scored above, so clustermap's z_score=1 (per-column
# standardisation) would be a no-op and is omitted. Assigning the result
# keeps the ClusterGrid repr out of the cell output.
cg = sns.clustermap(
    plt_df,
    col_linkage=col_linkage,
    cmap='magma',
    linewidths=0.005,
    linecolor='white',
    figsize=(5, 8),
)

#### knockout

In [None]:
# Keep only the knockout perturbations (row labels ending in "_KO").
plt_df = ps_df_all[ps_df_all.index.str.endswith("_KO")]

In [None]:
# For every column (cell type) of `plt_df`, collect the `top_n` perturbations
# with the highest score; their union is plotted in the heatmap below.
top_n = 5
top_TFs = []

for c in plt_df:
    # nlargest skips NaN, so a perturbation without a score for this cell type
    # is never counted among its top hits.
    top_TFs.extend(plt_df.nlargest(top_n, c).index.tolist())

# De-duplicate while keeping first-seen order; list(set(...)) would give a
# hash-randomised order that changes from run to run.
top_TFs = list(dict.fromkeys(top_TFs))

", ".join(top_TFs)

In [None]:
# Standardise each column (cell type) over the selected top TFs, so the
# heatmap compares TFs within each cell type.
plt_df = ps_df_all.loc[top_TFs, :]
plt_df = (plt_df - plt_df.mean()) / plt_df.std()
# Label rows by TF name: drop only the trailing "_<suffix>" (e.g. "_KO"),
# keeping any underscore that is part of the name and avoiding duplicate labels.
plt_df.index = [c.rsplit('_', 1)[0] for c in plt_df.index]

from scipy.spatial import distance
from scipy.cluster import hierarchy

# Cluster the columns (cell types): Ward linkage on Euclidean distances, then
# optimal leaf ordering of the columns. The distance matrix is computed once.
col_dist = distance.pdist(plt_df.T)
col_linkage = hierarchy.optimal_leaf_ordering(
    hierarchy.linkage(col_dist, method='ward'),
    col_dist,
)

# Columns are already z-scored above, so clustermap's z_score=1 (per-column
# standardisation) would be a no-op and is omitted. Assigning the result
# keeps the ClusterGrid repr out of the cell output.
cg = sns.clustermap(
    plt_df,
    col_linkage=col_linkage,
    cmap='magma',
    linewidths=0.005,
    linecolor='white',
    figsize=(5, 8),
)