# Questions
- Which TFs make it into a regulon, and which do not?

- TFBS database
- level of expression 
- additional criteria? 

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP
from statsmodels.stats.multitest import multipletests

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu
from scroutines.gene_modules import GeneModules  


In [None]:
# Working directory with the eRegulon outputs (SCENIC, per the variable name).
ddir = '/u/home/f/f7xiesnm/v1_multiome/juyoun/'
# Filtered eRegulon metadata table; first CSV column becomes the index.
f = f"{ddir}L23alltime_eReg_metadata_filtered.csv"
df_scenic = pd.read_csv(f, index_col=0)
df_scenic

In [None]:
# Regulon overview: one row per (TF, is_extended, Consensus_name) group, keeping
# the first non-null region/gene signature name of each group; sorted by TF
# and written to CSV.
output = f"{ddir}regulon_overview.csv"

regulon_keys = ['TF', 'is_extended', 'Consensus_name']
signature_cols = ['Region_signature_name', 'Gene_signature_name']
first_per_regulon = df_scenic.groupby(regulon_keys).first()
df_reg = first_per_regulon[signature_cols].sort_values('TF')
df_reg.to_csv(output)
df_reg

In [None]:
# TFs heading at least one regulon (in first-seen order), plus summary counts.
tf_list = df_reg.reset_index()['TF'].unique()
num_tf, num_reg = len(tf_list), len(df_reg)
print(num_reg, num_tf)

# Profile one regulon
- TF expression
- region accessibility
- target gene expression


- Get A/B/C pseudobulk profiles (sample × A/B/C × gene)
    - log of the mean
    - sum -> log -> mean

In [None]:
# Inspect all regulons whose TF (first index level of df_reg) is Meis2.
df_reg.loc['Meis2']

In [None]:
# Rows of the eRegulon table for the 'Meis2_+_+' regulon, plus its distinct
# target genes and regions.
is_meis2 = df_scenic['Consensus_name'] == 'Meis2_+_+'
df_this_reg = df_scenic.loc[is_meis2]
reg_genes, reg_regions = (df_this_reg[col].unique() for col in ('Gene', 'Region'))
reg_genes.shape, reg_regions.shape

# prep RNA data

In [None]:
# Per-cell A/B/C scores for L2/3 cells (first CSV column used as the index).
scores_path = "/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/scores_l23abc.csv"
scores_abc = pd.read_csv(scores_path, index_col=0)
# C-minus-A score; later ranked to order cells from A-like (low) to C-like (high).
scores_abc['scores_c-a'] = scores_abc['scores_c'].sub(scores_abc['scores_a'])
scores_abc

In [None]:
# Load the L2/3 multiome RNA AnnData (the file name indicates it stores a .raw slot).
# `anndata.read` is a deprecated alias of `anndata.read_h5ad`; call the latter directly.
adata = anndata.read_h5ad("/u/home/f/f7xiesnm/project-zipursky/v1-bb/v1/data/v1_multiome/superdupermegaRNA_hasraw_multiome_l23.h5ad")
adata

In [None]:
# Replace the main matrix with the one stored in adata.raw, so the normalization
# below (which treats adata.X as counts) starts from the raw values.
# NOTE(review): assumes adata.raw.X holds raw counts over the same genes as adata.var — confirm.
adata.X = adata.raw.X

In [None]:
# Copy the per-cell A/B/C scores and the C-A difference into adata.obs,
# looked up by adata's cell index.
for score_col in ['scores_a', 'scores_b', 'scores_c', 'scores_c-a']:
    adata.obs[score_col] = scores_abc.loc[adata.obs.index, score_col].copy()

In [None]:
# Per-cell sample label (e.g. 'P12DRa') and time label: drop the trailing
# replicate letter and any 'DR' tag, so 'P12DRa' -> 'P12'.
sample_labels = adata.obs['Sample'].values
time_labels = [label[:-1].replace('DR', '') for label in sample_labels]
adata.obs['sample'] = sample_labels
adata.obs['time'] = time_labels

# Unique samples in natural sort order, split by whether the label contains 'DR'.
uniq_samples = natsorted(np.unique(sample_labels))
nr_samples = [name for name in uniq_samples if "DR" not in name]
dr_samples = [name for name in uniq_samples if "DR" in name]

# Unique condition labels, naturally sorted.
uniq_conds = np.array(natsorted(np.unique(adata.obs['cond'].values)))
print(uniq_conds)

In [None]:
# Hard-coded positions of the non-DR (nr) and DR conditions within an
# 11-condition list, and the postnatal days they correspond to.
# Not referenced by the other cells in this section.
# NOTE(review): whether nr_idx/dr_idx match the order of `uniq_conds` printed
# above depends on the exact 'cond' labels — confirm before using them.
nr_idx = np.array([0,1,2,4,6,8,10])
dr_idx = np.array([3,5,7,9])

nr_times = np.array([6,8,10,12,14,17,21])
dr_times = np.array(       [12,14,17,21])

In [None]:
# Drop mitochondrial genes (names starting with 'mt-') and the sex gene Xist.
drop_genes = (adata.var.index.str.contains(r'^mt-')
              | adata.var.index.str.contains(r'^Xist$'))
adata = adata[:, ~drop_genes]

# Keep genes with a non-zero count in more than 10 cells.
cond = np.ravel((adata.X > 0).sum(axis=0)) > 10
adata = adata[:, cond].copy()

In [None]:
# Raw counts (cells x genes) and per-cell total counts (library size).
x = adata.X
cov = np.ravel(np.sum(x, axis=1))
genes = adata.var.index.values

# Per-cell scale factors. Cells with zero total counts get 0 instead of
# 1/0 = inf, so there is no divide-by-zero warning and no NaN rows for dense input.
inv_cov = np.divide(1.0, cov, out=np.zeros(cov.shape, dtype=float), where=cov > 0)

# CP10k: counts per 10,000 within each cell.
xn = sparse.diags(inv_cov).dot(x) * 1e4

# log2(CP10k + 1). For sparse input, only the stored entries need transforming
# because log2(0 + 1) == 0. Dense input has no `.data` array of values, so it
# is transformed directly.
if sparse.issparse(xn):
    xln = xn.copy()
    xln.data = np.log2(xln.data + 1)
    norm_dense, lognorm_dense = xn.toarray(), xln.toarray()
else:
    xln = np.log2(np.asarray(xn) + 1)
    norm_dense, lognorm_dense = np.asarray(xn), xln

adata.layers['norm'] = norm_dense
adata.layers['lognorm'] = lognorm_dense

In [None]:
import re

# Conditions to process: the four DR conditions first, then the non-DR ones.
todo_conds = [
    'P12DR', 'P14DR', 'P17DR', 'P21DR',
    'P6', 'P8', 'P10', 'P12', 'P14', 'P17', 'P21',
]
# Replicate samples (values of adata.obs['sample']), grouped contiguously in
# the same condition order as `todo_conds`.
todo_samps = [
    'P12DRa', 'P12DRb',
    'P14DRa', 'P14DRb',
    'P17DRa', 'P17DRb',
    'P21DRa', 'P21DRb',
    'P6a', 'P6b', 'P6c',
    'P8a', 'P8b', 'P8c',
    'P10a', 'P10b',
    'P12a', 'P12b', 'P12c',
    'P14a', 'P14b',
    'P17a', 'P17b',
    'P21a', 'P21b',
]


def _postnatal_day(label):
    """Drop every letter from `label` and return the remaining digits as an int.

    'P12DRa' -> 12, 'P6' -> 6.
    """
    return int(re.sub(r'[a-zA-Z]', '', label))


todo_conds_t = np.array(list(map(_postnatal_day, todo_conds)))
todo_samps_t = np.array(list(map(_postnatal_day, todo_samps)))
print(todo_conds_t)
print(todo_samps_t)

def mean_over_samples(mmat_res_samp, group_sizes=(2, 2, 2, 2, 3, 3, 2, 3, 2, 2, 2)):
    """Average consecutive replicate samples into per-condition values.

    With the default ``group_sizes``, this collapses the 25 samples of
    ``todo_samps`` into the 11 conditions of ``todo_conds``: P12DR, P14DR,
    P17DR and P21DR (2 replicates each), then P6 (3), P8 (3), P10 (2),
    P12 (3), P14 (2), P17 (2) and P21 (2).

    Parameters
    ----------
    mmat_res_samp : array-like, shape (n_samples, ...)
        Per-sample values, with samples on axis 0 and grouped contiguously in
        the order described by ``group_sizes``.
    group_sizes : sequence of int, optional
        Number of consecutive samples in each group. Every size must be >= 1,
        and the sizes must sum to ``mmat_res_samp.shape[0]``.

    Returns
    -------
    numpy.ndarray, shape (len(group_sizes), ...)
        float64 mean over each group along axis 0.

    Raises
    ------
    ValueError
        If the input has no sample axis, ``group_sizes`` is empty or contains
        a size < 1, or the sizes do not sum to the number of samples.
    """
    mmat_res_samp = np.asarray(mmat_res_samp)
    sizes = [int(s) for s in group_sizes]

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if mmat_res_samp.ndim == 0:
        raise ValueError("mmat_res_samp must have a sample axis (axis 0)")
    if not sizes or any(s < 1 for s in sizes):
        raise ValueError(f"group_sizes must be non-empty positive ints, got {sizes}")
    if mmat_res_samp.shape[0] != sum(sizes):
        raise ValueError(
            f"expected {sum(sizes)} samples on axis 0, got {mmat_res_samp.shape[0]}"
        )

    # Split points between consecutive groups, e.g. [2, 4, 6, ...].
    boundaries = np.cumsum(sizes)[:-1]
    groups = np.split(mmat_res_samp, boundaries, axis=0)
    # Cast to float64 so the output dtype matches a float64 result buffer for any input dtype.
    return np.stack([np.mean(g, axis=0) for g in groups]).astype(np.float64, copy=False)

In [None]:
%%time


# Build log-scale pseudobulk profiles for each sample in `todo_samps`:
#   bigmat_nfd[i, j]: cells of sample i split into `n_type` quantile bins of
#                     their C-A score rank (bin 0 = lowest C-A);
#   bigmat_abc[i, j]: "archetypal" A (j=0), B (j=1) and C (j=2) cells of sample i.
# Each profile is log2(mean CP10k * 1e2 + offset), i.e. log2(CPM + 1) when offset = 1.
offset = 1                       # pseudocount added before log2
n_type = 5                       # number of quantile bins along the C-A axis
frac_archetypal_cells_viz = 0.2  # fraction of a sample's cells taken from each score extreme

# mat = adata.layers['norm'][...]
# gexp_l23baseline = np.log2(np.mean(mat, axis=0)*1e2+offset) # CP10k -> CPM

# Shapes: (n_samples, n_type, n_genes) and (n_samples, 3, n_genes).
bigmat_nfd = np.zeros((len(todo_samps), n_type, adata.shape[1]))
bigmat_abc = np.zeros((len(todo_samps),      3, adata.shape[1]))

for i, samp in enumerate(todo_samps):
    print(samp)
    
    # cells belonging to this sample
    adatasub = adata[adata.obs['sample']==samp]
    n_cells = adatasub.shape[0]
    
    # 1-based ranks of the C-A and B scores (pandas averages tied ranks by default)
    ranks_ac = adatasub.obs['scores_c-a'].rank()
    ranks_b  = adatasub.obs['scores_b'].rank()
    
    # per type: assign each cell to one of n_type quantile bins of its C-A rank
    # NOTE(review): pd.qcut raises on non-unique bin edges (e.g. a sample with very few cells) — confirm samples are large enough.
    cells_type_nfd = pd.qcut(ranks_ac, n_type, labels=False)
    for j in range(n_type):
        mat_j = adatasub[cells_type_nfd==j].layers['norm'][...]
        mmat_j = np.log2(np.mean(mat_j, axis=0)*1e2+offset) # CP10k -> CPM
        bigmat_nfd[i,j] = mmat_j
    
    # A, B, C: number of cells taken from each score extreme
    num_archetypal_cells_viz = int(n_cells*frac_archetypal_cells_viz)
    
    # A = lowest C-A ranks, C = highest C-A ranks, B = highest B-score ranks
    precond_a = ranks_ac <= num_archetypal_cells_viz
    precond_c = ranks_ac > adatasub.shape[0] - num_archetypal_cells_viz
    precond_b = ranks_b  > adatasub.shape[0] - num_archetypal_cells_viz
    
    # keep only cells that satisfy exactly one of the three preconditions
    cond_a = np.all([ precond_a, ~precond_b, ~precond_c], axis=0)
    cond_b = np.all([~precond_a,  precond_b, ~precond_c], axis=0)
    cond_c = np.all([~precond_a, ~precond_b,  precond_c], axis=0)
    
    # NOTE: `cond` here rebinds the module-level gene mask of the same name from the gene-filtering cell.
    # NOTE(review): if a selection is empty, np.mean returns NaN (with a RuntimeWarning) for that profile.
    for j, cond in enumerate([cond_a, cond_b, cond_c]):
        mat_j = adatasub[cond].layers['norm'][...]
        mmat_j = np.log2(np.mean(mat_j, axis=0)*1e2+offset)# -gexp_l23baseline # CP10k -> CPM
        bigmat_abc[i,j] = mmat_j


In [None]:
# Average replicate samples into conditions: (25, 3, n_genes) -> (11, 3, n_genes),
# in the condition order of `todo_conds`.
redmat_abc = mean_over_samples(bigmat_abc)

In [None]:
# Check shapes: per-sample profiles, per-condition profiles, gene names.
bigmat_abc.shape, redmat_abc.shape, genes.shape

# plot

In [None]:
# Genes to plot (earlier alternatives: 'Meis2', 'Foxp1', 'Rfx3').
query = ['Fos', 'Nr4a2', 'Egr1', 'Bdnf', 'Npas4', 'Nptx2']
# Positions of the query genes within `genes` (helper from scroutines.basicu).
query_idx = basicu.get_index_from_array(genes, query)
# Query genes that do not head any regulon in df_reg.
not_tfs = np.setdiff1d(query, tf_list)
print(not_tfs)
print(query_idx)

bigmat_abc[:, :, query_idx].shape

In [None]:
n = len(query)
# squeeze=False keeps `axs` array-shaped even for a single gene, so `axs[i]`
# and `axs[0]` below also work when n == 1.
fig, axs = plt.subplots(1, n, figsize=(3*n, 1*3), sharex=True, sharey=False, squeeze=False)
axs = axs[0]

# In todo_samps / todo_conds, the first 8 samples (4 conditions) are DR;
# the remaining ones are non-DR.
n_dr_samps, n_dr_conds = 8, 4
for i, (gidx, gname) in enumerate(zip(query_idx, query)):
    ax = axs[i]
    bigmat_ig = bigmat_abc[:, :, gidx]  # (n_samples, 3): A, B, C per sample
    redmat_ig = redmat_abc[:, :, gidx]  # (n_conditions, 3): A, B, C per condition
    ax.set_title(gname)

    # Colours C0/C1/C2 = A/B/C. The series are drawn in the same order as before.
    # Non-DR samples: open circles.
    for k in range(3):
        ax.plot(todo_samps_t[n_dr_samps:], bigmat_ig[n_dr_samps:, k], 'o', markersize=5, fillstyle='none', color=f'C{k}')
    # DR samples: faded open squares.
    for k in range(3):
        ax.plot(todo_samps_t[:n_dr_samps], bigmat_ig[:n_dr_samps, k], 's', markersize=5, fillstyle='none', color=f'C{k}', alpha=0.5)
    # Non-DR condition means: solid lines.
    for k in range(3):
        ax.plot(todo_conds_t[n_dr_conds:], redmat_ig[n_dr_conds:, k], '-', color=f'C{k}')
    # DR condition means: faded lines.
    for k in range(3):
        ax.plot(todo_conds_t[:n_dr_conds], redmat_ig[:n_dr_conds, k], '-', color=f'C{k}', alpha=0.5)

    ax.grid(False)
    ax.set_xticks([6, 10, 14, 17, 21])
    sns.despine(ax=ax)

axs[0].set_xlabel('Postnatal day (P)')
axs[0].set_ylabel('Gene expr.\nlog2(CPM+1)')
# output = os.path.join(outfigdir, 'gene_plot.pdf')
# powerplots.savefig_autodate(fig, output)
plt.show()

# prep region set - count