# Post-process after running Scenic plus

In [None]:
import os
import sys
import pickle
import dill
import tempfile
import logging as log
import warnings
from pathlib import Path

import numpy as np
import scanpy as sc
import pandas
import pyranges
import pybiomart as pbm
import networkx as nx
import json

from scenicplus.scenicplus_class import create_SCENICPLUS_object
from scenicplus.wrappers.run_scenicplus import run_scenicplus
from scenicplus.preprocessing.filtering import apply_std_filtering_to_eRegulons
from scenicplus.eregulon_enrichment import score_eRegulons
from scenicplus.dimensionality_reduction import run_eRegulons_tsne, run_eRegulons_umap
from scenicplus.cistromes import TF_cistrome_correlation, generate_pseudobulks
from scenicplus.RSS import *
from scenicplus.networks import create_nx_tables, create_nx_graph, plot_networkx, export_to_cytoscape
from pycisTopic.diff_features import find_highly_variable_features

from IPython.display import display, Markdown

In [None]:
# Python's root logger defaults to the WARNING level, which silently drops the
# `log.info(...)` progress messages used throughout this notebook. Enable INFO
# explicitly so they are shown.
log.basicConfig(level=log.INFO)
logger = log.getLogger()
# basicConfig() does nothing when the root logger already has handlers, so set
# the level on the logger itself as well.
logger.setLevel(log.INFO)
# Suppress all Python warnings for the rest of the session.
warnings.simplefilter(action = 'ignore')

In [None]:
# Set scanpy's global figure parameters (resolution, frame, size, background colour).
sc.settings.set_figure_params(dpi=150, frameon=False, figsize=(10, 10), facecolor='white')

## Parameters

In [None]:
# User-editable settings for this notebook.
# thresholds['rho'] holds the [lower, upper] correlation cut-offs used when
# selecting cistromes; thresholds['n_targets'] is only displayed below.
thresholds = dict(
    rho=[-0.4, 0.4],
    n_targets=0,
)

# Column of scplus_obj.metadata_cell that holds the cell type labels.
cell_type_col = "GEX_celltype"
# Base directory; the "scenicplus" sub-directory is derived from it.
work_dir = Path("/path/to/work_dir")
# Number of CPUs passed to score_eRegulons.
n_cores = 8

# Fresh temporary directory (created immediately on disk).
tmp_dir = tempfile.mkdtemp()

In [None]:
# Sub-directory of work_dir from which this notebook reads the SCENIC+ pickles
# and into which it writes network.csv.
scenicplus_path = work_dir / "scenicplus"

In [None]:
# Render the chosen parameters as Markdown in the cell output so that the
# executed notebook documents the settings it was run with.
display(Markdown(f"""
**parameters:**
- **thresholds:**
  - **rho:** *{thresholds['rho']}*
  - **n_targets:** *{thresholds['n_targets']}*
- **scenic plus object cell type column:** *{cell_type_col}*
- **working directory for output files:** *{work_dir.resolve()}*
"""))

## 1) Load

In [None]:
log.info("Loading scenic object...")

# NOTE: dill (like pickle) can execute arbitrary code while loading --
# only open files that come from a trusted source.
with (scenicplus_path / "scplus_obj.pkl").open("rb") as f:
    scplus_obj = dill.load(f)

update annotation (uncomment if required)

In [None]:
# annot_update = pd.read_csv("updated_cell_metadata.csv", index_col=0)["updated_annotation"]

# scplus_obj.metadata_cell["updated_annotation"] = scplus_obj.metadata_cell.index.map(annot_update)

# scplus_obj.metadata_cell[cell_type_col] = [
#     row["updated_annotation"] if not pd.isna(row["updated_annotation"]) else row[cell_type_col] 
#     for _, row in scplus_obj.metadata_cell.iterrows()
# ]

subsetting (uncomment if required)

In [None]:
# Cells to keep. While this is None (falsy) the ranking-subsetting cell further
# down is skipped and all cells are used.
sel_cells = None
# NOTE(review): `new_annot` below is not defined anywhere in this notebook; the
# (commented) annotation cell above names its Series `annot_update` -- confirm
# and adapt before uncommenting.
# sel_cells = list(set(scplus_obj.cell_names[scplus_obj.metadata_cell[cell_type_col].isin(new_annot['updated_annotation'].unique())]))

# scplus_obj.subset(
#     cells = sel_cells,
#     regions = None,
#     genes = None,
#     return_copy = False,
# )

In [None]:
# Number of cells per cell type label; the smallest count later caps the
# pseudobulk size (see `nr_cells` in the pseudobulk cell).
cell_type_vc = scplus_obj.metadata_cell[cell_type_col].value_counts()
# Last expression of the cell: shown as the cell output.
cell_type_vc

## 2) Standard filtering

In [None]:
log.info("Apply standard filtering to eRegulons...")
# The return value is not used; presumably this populates the '*_filtered'
# entries of scplus_obj.uns that later cells read -- TODO confirm.
apply_std_filtering_to_eRegulons(scplus_obj)

## 3) Re-calculate AUCell scores

In [None]:
log.info("Recalculate AUCell after filtering...")

# NOTE: dill (like pickle) can execute arbitrary code while loading --
# only open ranking files that come from a trusted source.
# Each load gets its own progress message; previously a single
# "...load gene rankings" message preceded the *region* ranking load.
log.info("...load region ranking")
with open(scenicplus_path / "region_ranking.pkl", "rb") as f:
    region_ranking = dill.load(f)

log.info("...load gene ranking")
with open(scenicplus_path / "gene_ranking.pkl", "rb") as f:
    gene_ranking = dill.load(f)

In [None]:
# Only restrict the rankings when a non-empty cell selection was made above;
# both rankings are subset to the same cells.
if sel_cells:
    gene_ranking.subset(cells = sel_cells)
    region_ranking.subset(cells = sel_cells)

In [None]:
# Score the filtered eRegulon signatures twice with identical settings:
# once on the region ranking and once on the gene ranking. Both calls use the
# same `key_added`.
log.info("...score regions")
score_eRegulons(
    scplus_obj,
    region_ranking,
    enrichment_type='region',
    eRegulon_signatures_key='eRegulon_signatures_filtered',
    key_added='eRegulon_AUC_filtered',
    auc_threshold=0.05,
    normalize=False,
    n_cpu=n_cores,
)

log.info("...score genes")
score_eRegulons(
    scplus_obj,
    gene_ranking,
    enrichment_type='gene',
    eRegulon_signatures_key='eRegulon_signatures_filtered',
    key_added='eRegulon_AUC_filtered',
    auc_threshold=0.05,
    normalize=False,
    n_cpu=n_cores,
)

### embedding

In [None]:
# Compute two embeddings from the filtered eRegulon AUC values. The reduction
# names given here are the ones the plotting cell below refers to.
log.info("Calculate embeddings...")

log.info("...UMAP")
run_eRegulons_umap(
    scplus_obj = scplus_obj,
    auc_key = 'eRegulon_AUC_filtered',
    reduction_name = 'eRegulons_UMAP', 
)

log.info("...tSNE")
run_eRegulons_tsne(
    scplus_obj = scplus_obj,
    auc_key = 'eRegulon_AUC_filtered',
    reduction_name = 'eRegulons_tSNE',
)

In [None]:
# `plt` is imported here explicitly: elsewhere the notebook imports it only in
# the coverage-plot section further down, so this cell used to rely on the
# name leaking in through a star import.
import matplotlib.pyplot as plt
import seaborn as sns
from scenicplus.dimensionality_reduction import plot_metadata_given_ax

# Side-by-side view of the two eRegulon embeddings, coloured by cell type.
fig, axs = plt.subplots(ncols=2, figsize = (16, 8))
for ax, reduction_name in zip(axs, ('eRegulons_UMAP', 'eRegulons_tSNE')):
    plot_metadata_given_ax(
        scplus_obj = scplus_obj,
        ax = ax,
        reduction_name = reduction_name,
        variable = cell_type_col,
    )
fig.tight_layout()
for ax in axs:
    sns.despine(ax = ax)  # remove top and right edge of axis border
plt.show()

## 4) Find high QC regulons

In [None]:
log.info("Determine high QC regulons...")

log.info("...generate pseudobulks")
# Same settings for both signature types, gene-based first. The number of
# cells per pseudobulk is capped at 10 and at the size of the smallest cell type.
for signature_key in ('Gene_based', 'Region_based'):
    generate_pseudobulks(
        scplus_obj = scplus_obj,
        variable = cell_type_col,
        auc_key = 'eRegulon_AUC_filtered',
        signature_key = signature_key,
        nr_cells = min(cell_type_vc.min(), 10),
        nr_pseudobulks = 100,
    )

In [None]:
log.info("...cistrome correlation")
# Correlate on the pseudobulks for each signature type, gene-based first;
# results are stored under the paired output key.
for signature_key, out_key in (
    ('Gene_based', 'filtered_gene_based'),
    ('Region_based', 'filtered_region_based'),
):
    TF_cistrome_correlation(
        scplus_obj,
        use_pseudobulk = True,
        variable = cell_type_col,
        auc_key = 'eRegulon_AUC_filtered',
        signature_key = signature_key,
        out_key = out_key,
    )

### plot corr

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Region-based TF/cistrome correlation table computed in the previous cell.
corr_df = scplus_obj.uns['TF_cistrome_correlation']['filtered_region_based']

# Cistrome names end in "(<n>r)"; extract <n> as the number of target regions.
n_targets = [int(x.split('(')[1].replace('r)', '')) for x in corr_df['Cistrome']]
rho = corr_df['Rho'].to_list()
adj_pval = corr_df['Adjusted_p-value'].to_list()

# Clip p-values away from 0 so that -log10 stays finite; an infinite colour
# value cannot be mapped onto the colour scale.
neg_log_pval = -np.log10(np.clip(adj_pval, np.finfo(float).tiny, None))

fig, ax = plt.subplots(figsize = (10, 5))
# Do NOT call this handle `sc`: that would overwrite the `scanpy as sc` module
# alias for every cell executed afterwards.
scatter = ax.scatter(rho, n_targets, c = neg_log_pval, s = 5)
ax.set_xlabel('Correlation coefficient')
ax.set_ylabel('nr. target regions')
#ax.hlines(y = thresholds['n_targets'], xmin = min(rho), xmax = max(rho), color = 'black', ls = 'dashed', lw = 1)
# Dashed vertical lines mark the lower/upper rho thresholds, labelled at the top.
ax.vlines(x = thresholds['rho'], ymin = 0, ymax = max(n_targets), color = 'black', ls = 'dashed', lw = 1)
ax.text(x = thresholds['rho'][0], y = max(n_targets), s = str(thresholds['rho'][0]))
ax.text(x = thresholds['rho'][1], y = max(n_targets), s = str(thresholds['rho'][1]))
sns.despine(ax = ax)
fig.colorbar(scatter, label = '-log10(adjusted_pvalue)', ax = ax)
plt.show()

### select

In [None]:
log.info("...select")
# Keep the cistromes whose correlation lies strictly outside the
# [thresholds['rho'][0], thresholds['rho'][1]] band.
selected_cistromes = [
    cistrome
    for cistrome, rho_value in zip(
        scplus_obj.uns['TF_cistrome_correlation']['filtered_region_based']['Cistrome'],
        scplus_obj.uns['TF_cistrome_correlation']['filtered_region_based']['Rho'],
    )
    if rho_value > thresholds['rho'][1] or rho_value < thresholds['rho'][0]
]

In [None]:
# Strip the trailing "_(<n>...)" size suffix so that cistrome names and
# signature names can be compared on their common prefix.
selected_eRegulons = [cistrome.partition('_(')[0] for cistrome in selected_cistromes]

# Gene-based signatures whose prefix matches a selected cistrome.
selected_eRegulons_gene_sig = list(filter(
    lambda signature: signature.partition('_(')[0] in selected_eRegulons,
    scplus_obj.uns['eRegulon_signatures_filtered']['Gene_based'].keys(),
))
# Region-based signatures whose prefix matches a selected cistrome.
selected_eRegulons_region_sig = list(filter(
    lambda signature: signature.partition('_(')[0] in selected_eRegulons,
    scplus_obj.uns['eRegulon_signatures_filtered']['Region_based'].keys(),
))

In [None]:
log.info("...store in scenicplus object")
# Keep the selection on the object so that later cells (dotplot, RSS,
# correlation heatmap) can read it from scplus_obj.uns['selected_eRegulon'].
scplus_obj.uns['selected_eRegulon'] = dict(
    Gene_based=selected_eRegulons_gene_sig,
    Region_based=selected_eRegulons_region_sig,
)
log.info(f'selected: {len(selected_eRegulons_gene_sig)} eRegulons')

### plot dotplot-heatmap

In [None]:
# Expression matrix used for the tile colours of the dotplot-heatmap.
dgem = scplus_obj.to_df('EXP')
# Divide each column of the transposed matrix by its column total and scale to
# one million (assumes rows of 'EXP' are cells, so this is counts-per-million
# per cell -- TODO confirm).
dgem = dgem.T / dgem.T.sum(0) * 10**6
# log(1 + x) transform, then transpose back to the original orientation.
dgem = np.log1p(dgem).T

In [None]:
# Region-based AUC columns whose leading name token (text before the first
# '_') is not a column of the expression matrix cannot be coloured in the
# dotplot, so they are dropped.
gex_genes = set(dgem.columns)
regs_not_in_gex_genes = [
    regulon
    for regulon in scplus_obj.uns['eRegulon_AUC_filtered']['Region_based']
    if regulon.split('_')[0] not in gex_genes
]

scplus_obj.uns['eRegulon_AUC_filtered']['Region_based'] = (
    scplus_obj.uns['eRegulon_AUC_filtered']['Region_based']
    .drop(columns=regs_not_in_gex_genes)
)

In [None]:
# Same clean-up for the gene-based AUC table: drop every column whose leading
# name token is missing from the expression matrix columns.
gex_genes = set(dgem.columns)
regs_not_in_gex_genes = [
    regulon
    for regulon in scplus_obj.uns['eRegulon_AUC_filtered']['Gene_based']
    if regulon.partition('_')[0] not in gex_genes
]

scplus_obj.uns['eRegulon_AUC_filtered']['Gene_based'] = (
    scplus_obj.uns['eRegulon_AUC_filtered']['Gene_based']
    .drop(columns=regs_not_in_gex_genes)
)

In [None]:
import pandas as pd
import plotnine
from plotnine import ggplot, geom_point, aes, scale_fill_distiller, theme_bw, geom_tile, theme, element_text, element_blank, labs, theme_minimal
from plotnine.facets import facet_grid
from scenicplus.plotting.dotplot import generate_dotplot_df


def _regulon_direction(name):
    """Return 'activator' or 'repressor' for an eRegulon name.

    The first sign token of the '_'-separated name decides; names containing
    'extended' carry one extra token before the signs.
    """
    tokens = name.split('_')
    sign = tokens[2] if 'extended' in name else tokens[1]
    return 'activator' if '+' in sign else 'repressor'


def mod_heatmap_dotplot(
    scplus_obj: "SCENICPLUS",
    size_matrix: pd.DataFrame,
    dot_color_matrix: pd.DataFrame,
    color_matrix: pd.DataFrame,
    scale_size_matrix: bool = True,
    scale_color_matrix: bool = True,
    group_variable: str = None,
    subset_eRegulons: list = None,
    sort_by: str = 'color_val',
    index_order: list = None,
    save: str = None,
    figsize: tuple = (5, 8),
    split_repressor_activator: bool = True,
    plot_regulons = ["activator", "repressor"],
    orientation: str = 'vertical'
):
    """
    Draw a combined heatmap / dotplot of eRegulons against groups.

    Tiles are filled by `color_matrix`, dots are sized by `size_matrix` and
    their transparency (alpha) is taken from `dot_color_matrix`.

    Parameters
    ----------
    scplus_obj: `class::SCENICPLUS`
        A :class:`SCENICPLUS` object.
    size_matrix: pd.DataFrame
        A pd.DataFrame containing values to plot using size scale.
    dot_color_matrix: pd.DataFrame
        A pd.DataFrame containing values to plot using the dot alpha scale.
        It is processed like `size_matrix` (including `scale_size_matrix`).
    color_matrix
        A pd.DataFrame containing values to plot using color scale.
    scale_size_matrix: bool
        Scale size matrix between 0 and 1 along index.
    scale_color_matrix: bool
        Scale color matrix between 0 and 1 along index.
    group_variable: str:
        Variable by which to group cell barcodes by (needed if the index of size or color matrix are cells.)
    subset_eRegulons: List
        List of eRegulons to plot.
    sort_by: str
        Column used to order the eRegulons: 'color_val' or 'size_val'.
    index_order: list
        Order of index to plot. Ignored (with a logged warning) when it does
        not contain every index value present in the data.
    save: str
        If given, the plot is saved to this path and nothing is returned.
    figsize: tuple
        size of the figure (x, y). Note: sets the global
        `plotnine.options.figure_size`.
    split_repressor_activator: bool
        Wether to split the plot on repressors/activators. Forced to False
        when fewer than two kinds are requested in `plot_regulons`.
    plot_regulons: list
        Which kinds of eRegulons to keep: 'activator' and/or 'repressor'.
    orientation: str
        Plot in 'horizontal' or 'vertical' orientation.

    Returns
    -------
    The plotnine plot when `save` is None, otherwise None.

    Raises
    ------
    ValueError
        If `orientation` is neither 'vertical' nor 'horizontal'.
    """
    # Fail early with a clear message instead of an UnboundLocalError at the end.
    if orientation not in ('vertical', 'horizontal'):
        raise ValueError(
            f"orientation must be 'vertical' or 'horizontal', got {orientation!r}")

    plotting_df = generate_dotplot_df(
        scplus_obj = scplus_obj,
        size_matrix = size_matrix,
        color_matrix = color_matrix,
        scale_size_matrix = scale_size_matrix,
        scale_color_matrix = scale_color_matrix,
        group_variable = group_variable,
        subset_eRegulons = subset_eRegulons)

    # Second pass with dot_color_matrix in the "size" slot; only its size_val
    # column is kept, renamed to dot_color_val.
    add_df = generate_dotplot_df(
        scplus_obj = scplus_obj,
        size_matrix = dot_color_matrix,
        color_matrix = color_matrix,
        scale_size_matrix = scale_size_matrix,
        scale_color_matrix = scale_color_matrix,
        group_variable = group_variable,
        subset_eRegulons = subset_eRegulons)

    plotting_df = plotting_df.merge(
        add_df[["size_val"]].rename(columns={"size_val": "dot_color_val"}),
        how="left",
        left_index=True,
        right_index=True)

    if index_order is not None:
        present = set(plotting_df['index'])
        if len(set(index_order) & present) != len(present):
            # Previously a Warning object was created but never emitted, and
            # the incomplete order was still applied further down.
            log.warning('not all indices are provided in index_order, order will not be changed!')
            index_order = None
        else:
            plotting_df['index'] = pd.Categorical(plotting_df['index'], categories = index_order)

    if len(plot_regulons) < 2:
        split_repressor_activator = False

    # Order eRegulons by the index (group) in which their `sort_by` value peaks.
    tmp = plotting_df[['index', 'eRegulon_name', sort_by]
        ].pivot_table(index = 'index', columns = 'eRegulon_name'
        ).fillna(0)[sort_by]
    if index_order is not None:
        # Ignore requested labels that are absent from the pivot table.
        tmp = tmp.loc[[idx for idx in index_order if idx in tmp.index]]
    idx_max = tmp.idxmax(axis = 0)
    order = pd.concat([idx_max[idx_max == x] for x in tmp.index.tolist()]).index.tolist()
    plotting_df['eRegulon_name'] = pd.Categorical(plotting_df['eRegulon_name'], categories = order)
    plotting_df['repressor_activator'] = [_regulon_direction(n) for n in plotting_df['eRegulon_name']]
    plotting_df = plotting_df[plotting_df["repressor_activator"].isin(plot_regulons)]

    plotnine.options.figure_size = figsize

    # Axis mapping and facet layout depend only on the orientation.
    if orientation == 'vertical':
        x_var, y_var = 'index', 'eRegulon_name'
        facet_formula = 'repressor_activator ~ .'
        fixed_axis, split_axis = 'x', 'y'
    else:
        x_var, y_var = 'eRegulon_name', 'index'
        facet_formula = '. ~ repressor_activator'
        fixed_axis, split_axis = 'y', 'x'

    plot = ggplot(plotting_df, aes(x_var, y_var))
    if split_repressor_activator:
        # Panel sizes are proportional to the number of rows of each kind.
        plot = plot + facet_grid(
            facet_formula,
            scales = "free",
            space = {
                fixed_axis: [1],
                split_axis: [
                    sum(plotting_df['repressor_activator'] == 'activator'),
                    sum(plotting_df['repressor_activator'] == 'repressor')]})
    plot = (
        plot
        + geom_tile(mapping = aes(fill = 'color_val'))
        + scale_fill_distiller(type = 'div', palette = 'RdYlBu', limits = [-0.1, 1.1])
        + geom_point(
                mapping = aes(size = 'size_val', alpha = 'dot_color_val'),
                colour = "black")
        + theme(axis_text_x=element_text(rotation=60, hjust=1))
        + theme(axis_title_x = element_blank(), axis_title_y = element_blank()))

    if save is not None:
        plot.save(save)
    else:
        return plot

In [None]:
from scenicplus.plotting.dotplot import heatmap_dotplot
from plotnine import labs, theme_minimal, scale_y_discrete

plotnine.options.dpi = 100

# Draw the dotplot-heatmap for the selected activator eRegulons. Any error is
# logged (with traceback) instead of raised, so the following cells still run.
try:
    p = mod_heatmap_dotplot(
        scplus_obj = scplus_obj,
        size_matrix = scplus_obj.uns['eRegulon_AUC_filtered']['Region_based'], #specify what to plot as dot sizes, target region enrichment in this case
        dot_color_matrix = scplus_obj.uns['eRegulon_AUC_filtered']['Gene_based'],
        color_matrix = dgem, #specify  what to plot as colors, TF expression in this case
        scale_size_matrix = True,
        scale_color_matrix = True,
        group_variable = cell_type_col,
        subset_eRegulons = scplus_obj.uns['selected_eRegulon']['Gene_based'],
        # index_order = [],
        plot_regulons = ["activator"],
        figsize = (5, 20),
        orientation = 'vertical'
    ) 

    # Relabel the legends, restyle the axes and shorten the y labels to the
    # text before the first '_' (labels containing '__' are kept unchanged).
    p = (
        p
        + labs(fill = 'TF expression', size = 'Region accessibility', alpha = 'TF activity')
        + theme_minimal()
        + theme(
            axis_text_x = element_text(rotation=60, hjust=1, face="bold",colour = "black"),
            axis_text_y = element_text(face="bold",colour = "black"),
        )
        + theme(axis_title_x = element_blank(), axis_title_y = element_blank())
        + scale_y_discrete(labels = lambda l: [(v.split("_")[0] if "__" not in v else v) for v in l])
    )
    
    display(p)
    # p.save()
except Exception:
    log.exception("error when plotting dotplot-heatmap")

## 5) Calculate Regulon Specificity Scores

In [None]:
log.info("Compute RSS...")

# Keep only eRegulons with no '-' among their sign tokens. The signs are the
# last two '_'-separated tokens before the trailing "_(<n>...)" suffix.
# Testing the whole name for '-' would also discard TFs whose gene symbol
# contains a hyphen (e.g. "NKX2-1_+_+_(25r)").
regulon_specificity_scores(
    scplus_obj,
    variable = cell_type_col,
    auc_key = 'eRegulon_AUC_filtered',
    signature_keys = ['Region_based'],
    selected_regulons = [
        x for x in scplus_obj.uns['selected_eRegulon']['Region_based']
        if '-' not in x.split('_(')[0].rsplit('_', 2)[1:]
    ],
    out_key_suffix = '_filtered'
)

### plot RSS scores

In [None]:
# Plot the top 5 regulons per group from the RSS results computed above
# (the key is the cell type column name plus the '_filtered' suffix).
# Errors are logged with traceback instead of raised.
try:
    plot_rss(scplus_obj, cell_type_col + '_filtered', num_columns=3, top_n=5, figsize = (10, 20), fontsize=10)
except Exception:
    log.exception("error when plotting RSS scores")

## 6) Plot correlation heatmap

In [None]:
from scenicplus.plotting.correlation_plot import *

# Correlation heatmap of the selected gene-based eRegulons, using the filtered
# AUC values.
correlation_heatmap(
    scplus_obj,
    auc_key = 'eRegulon_AUC_filtered',
    signature_keys = ['Gene_based'],
    selected_regulons = scplus_obj.uns['selected_eRegulon']['Gene_based'],
    fcluster_threshold = 0.1,
    figsize = (5, 5),
    fontsize = 12
)

## 7) Plot coverage plot

In [None]:
# Settings for the (currently commented-out) coverage plot below.

# Gene annotation file read by the commented `pr.read_gtf` cell.
gtf_file = "/lustre/scratch126/cellgen/team205/jp30/bone_atlas/data/download/gencode.v43.basic.annotation.gtf"

# Explicit region string; when left empty the commented cell below derives it
# from `plot_gene` and `region_padding`.
plot_region = ''

# Gene to centre the plot on, and the padding added on both sides of it.
plot_gene = "RUNX2"
region_padding = 0.5e5

# genes_violin_plot = [plot_gene]
# Passed as `color_dict` to the commented coverage_plot call.
col_dct = None

In [None]:
import matplotlib.pyplot as plt
import os
from scenicplus.utils import get_interaction_pr
import pyranges as pr
from scenicplus.plotting.coverageplot import *

In [None]:
import logging
import sys

class LoggingContext:
    """Context manager that temporarily changes a logger's level and/or handler.

    On entry the logger's level is set to `level` (if given) and `handler`
    (if given) is attached. On exit the previous level is restored and the
    handler is detached and, when `close` is true, closed. Exceptions raised
    inside the block are not suppressed.

    Parameters
    ----------
    logger
        The logger to modify.
    level
        Level to apply inside the block; None leaves the level untouched.
    handler
        Handler to attach inside the block; None attaches nothing.
    close
        Whether to close `handler` on exit.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close
        # Level to restore on exit; filled in by __enter__.
        self.old_level = None

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)
        # Return the context so `with LoggingContext(...) as ctx:` binds it
        # (previously this bound None).
        return self

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()
        # Implicitly returns None: exceptions from the block propagate.

In [None]:
# cell type selection
cts = scplus_obj.metadata_cell[cell_type_col].unique().tolist()  # or subset to specific cell types

### load track info

load bigwig files

In [None]:
# bigwig_dir = str(work_dir / 'scATAC' / 'consensus_peak_calling' / 'pseudobulk_bw_files')
# bw_dict = {x.replace('.bw', ''): os.path.join(bigwig_dir, x) for x in os.listdir(bigwig_dir) if '.bw' in x}

# with open(work_dir / 'scATAC' / 'consensus_peak_calling' / 'cellt_id_mapping.pkl', "rb") as f:
#     obs_df = pickle.load(f)
#     cellt_map = obs_df[["celltype", "celltype_id"]].drop_duplicates().set_index("celltype_id").to_dict()["celltype"]
#     bw_dict = {cellt_map[k]: v for k, v in bw_dict.items()}

In [None]:
# bw_dict = {k:v for k,v in bw_dict.items() if pyBigWig.open(v).header()}

# bw_dict = {k:v for k,v in bw_dict.items() if k in cts}

load interactions from scenicplus object

In [None]:
# pr_interact = get_interaction_pr(scplus_obj, 'hsapiens', 'hg38', inplace = False, subset_for_eRegulons_regions = True, eRegulons_key = 'eRegulons')

load gene annotations from gtf

In [None]:
# pr_gtf = pr.read_gtf(gtf_file)

### plot coverage plot

In [None]:
# gene_info = pr_gtf.df.query(f"gene_name == '{plot_gene}'")

# if not plot_region:
#     chrom = gene_info.Chromosome.tolist()[0]

#     y_min = int(gene_info.Start.min() - region_padding)
#     y_max = int(gene_info.End.max() + region_padding)

#     plot_region = f"{chrom}:{y_min}-{y_max}"

#     print(plot_region)

add annotation of Regulon binding sites below coverage plots

In [None]:
# reg_sub = scplus_obj.uns["eRegulon_metadata_filtered"].query(f"Region.str.startswith('{plot_region.split(':')[0]}')")

# add_bed = pr.PyRanges(reg_sub.Region.str.extract("(?P<Chromosome>chr[0-9]+):(?P<Start>[0-9]+)-(?P<End>[0-9]+)", expand=True).assign(Name=",".join(cts)))

filter shown interactions to selected gene TSS with other regions

In [None]:
# strand = gene_info.Strand.tolist()[0]
# print(f"strand: {strand}")

In [None]:
# if strand == "+":
#     reg_interact = pr_interact.intersect(pr.PyRanges(pd.Series(f"{chrom}:{gene_info.Start.min()}-{gene_info.Start.max()}").str.extract("(?P<Chromosome>chr[0-9]+):(?P<Start>[0-9]+)-(?P<End>[0-9]+)", expand=True)))
# elif strand == "-":
#     reg_interact = pr_interact.intersect(pr.PyRanges(pd.Series(f"{chrom}:{gene_info.End.min()}-{gene_info.End.max()}").str.extract("(?P<Chromosome>chr[0-9]+):(?P<Start>[0-9]+)-(?P<End>[0-9]+)", expand=True)))

make plot

In [None]:
# # %%debug --breakpoint /nfs/team205/jp30/scenicplus/src/scenicplus/plotting/coverageplot.py:257
# # NOTE: violonplots buggy, axis not well aligned and axis limits chosen in a strange way

# with LoggingContext(logger, level=logging.WARNING):
#     fig = coverage_plot(
#         SCENICPLUS_obj = scplus_obj,
#         bw_dict = bw_dict,
#         region = plot_region,
#         figsize = (7.5,5),
#         pr_gtf = pr_gtf,
#         color_dict = col_dct,
#         plot_order = cts,
#         pr_interact = reg_interact,
#         # genes_violin_plot = [plot_gene],
#         # sort_vln_plots = False,
#         meta_data_key = cell_type_col,
#         pr_consensus_bed = add_bed,
#         arc_rad = 0.1,
#         gene_label_offset=5,
#         bw_ymax = 5,
#         fontsize_dict = {'bigwig_label': 8, 'gene_label': 6, 'violinplots_xlabel': 0, 'title': 10, 'bigwig_tick_label': 2, 'violinplots_ylabel': 0},
#         height_ratios_dict = {'bigwig_violin': 0.5, 'genes': 0.1, 'arcs': 0.4, 'custom_ax': 0.0}
#     )

#     plt.tight_layout()
    
#     plt.savefig(f'coverage_plot_{plot_gene}.pdf', transparent=True)
#     plt.savefig(f'coverage_plot_{plot_gene}.png', transparent=True, dpi=400)

## 8) Save object

In [None]:
# log.info("Save scenicplus object...")

# with open(scenicplus_path / "scplus_obj.pkl", "wb") as f:
#     dill.dump(scplus_obj, f, protocol=-1)

## 9) Export network for Cytoscape

In [None]:
log.info("...compute HVGs and HVRs")
# Highly variable regions: rows of the 'ACC' matrix restricted to the regions
# listed in the filtered eRegulon metadata.
hvr = find_highly_variable_features(
    scplus_obj.to_df('ACC').loc[
        list(set(scplus_obj.uns['eRegulon_metadata_filtered']['Region']))
    ],
    n_top_features=1000, 
    plot = False
)
# Highly variable genes: columns of the 'EXP' matrix restricted to the genes
# listed in the filtered eRegulon metadata, transposed so that the genes
# become rows (as for the regions above).
hvg = find_highly_variable_features(
    scplus_obj.to_df('EXP')[
        list(set(scplus_obj.uns['eRegulon_metadata_filtered']['Gene']))
    ].T,
    n_top_features=1000,
    plot = False
)

In [None]:
# Build the tables for the network export, restricted to the highly variable
# regions/genes computed above. Errors are logged with traceback instead of
# raised; in that case `nx_tables` is not (re)defined.
# NOTE(review): subset_eRegulons lists NEUROD1/NEUROG3/FOXA2, whereas the edge
# colour map in the graph cell is keyed by DLX3/DLX5/RUNX2 -- confirm which
# set of TFs is intended for this dataset.
try:
    log.info("...create nx_tables")
    nx_tables = create_nx_tables(
        scplus_obj = scplus_obj,
        eRegulon_metadata_key ='eRegulon_metadata_filtered',
        subset_regions = hvr,
        subset_genes = hvg,
        subset_eRegulons = ['NEUROD1', 'NEUROG3', 'FOXA2'],
        add_differential_gene_expression = True,
        add_differential_region_accessibility = True,
        differential_variable = [cell_type_col]
    )
except Exception:
    log.exception("error when creating networkx tables")

In [None]:
# Turn the tables into a graph with fixed shapes, sizes and label sizes per
# node type and with edge styling driven by the R2G_importance / R2G_rho / TF
# variables. Errors are logged with traceback instead of raised; in that case
# G, pos, edge_tables and node_tables are not (re)defined.
try:
    log.info("...create nx_graph")
    G, pos, edge_tables, node_tables = create_nx_graph(
        nx_tables,
        use_edge_tables = ['TF2R','R2G'],
        shape_node_by = {
            'TF': {'variable': 'fixed_shape', 'fixed_shape': 'ellipse'},
            'Gene': {'variable': 'fixed_shape', 'fixed_shape': 'ellipse'},
            'Region': {'variable': 'fixed_shape', 'fixed_shape': 'diamond'}
        },
        size_node_by = {
            'TF': {'variable': 'fixed_size', 'fixed_size': 30},
            'Gene': {'variable': 'fixed_size', 'fixed_size': 15},
            'Region': {'variable': 'fixed_size', 'fixed_size': 10}
        },
        label_size_by = {
            'TF': {'variable': 'fixed_label_size', 'fixed_label_size': 18.0},
            'Gene': {'variable': 'fixed_label_size', 'fixed_label_size': 12.0},
            'Region': {'variable': 'fixed_label_size', 'fixed_label_size': 6.0}
        },
        width_edge_by = {'R2G': {'variable' : 'R2G_importance', 'max_size' :  1.5, 'min_size' : 1}},
        transparency_edge_by =  {'R2G': {'variable' : 'R2G_importance', 'min_alpha': 0.1, 'v_min': 0}},
        color_edge_by = {
            # 'R2G': {'variable' : 'TF', 'category_color' : {}},
            # NOTE(review): these TF names differ from the subset_eRegulons
            # list used when building nx_tables -- confirm they match the data.
            'TF2R': {'variable' : 'TF', 'category_color' : {"DLX3": "Red", "DLX5": "Green", "RUNX2": "Blue"}},
            'R2G': {'variable' : 'R2G_rho', 'continuous_color' : 'viridis', 'v_min': -1, 'v_max': 1}
        },
        layout = 'kamada_kawai_layout',
        scale_position_by = 250,
    )
except Exception:
    log.exception("error when creating networkx graph")

In [None]:
# plt.figure(figsize=(10,10))
# plot_networkx(
#     G, 
#     pos,
# )

In [None]:
# Write the graph's edge list to network.csv inside scenicplus_path, with the
# two endpoint columns named "TF" and "gene". Any error (including a missing
# graph G from a failed earlier cell) is logged with traceback, not raised.
try:
    log.info("...export")

    nx.to_pandas_edgelist(G, source="TF", target="gene").to_csv(scenicplus_path / "network.csv")
except Exception:
    log.exception("error when exporting pandas dataframe to csv")