# Questions
- Which TFs make it into a regulon, and which do not?

- TFBS database
- level of expression 
- additional criteria? 

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata 
import seaborn as sns
from scipy.stats import zscore
import matplotlib.pyplot as plt
import collections
from natsort import natsorted

from scipy import stats
from scipy import sparse
from sklearn.decomposition import PCA
from umap import UMAP
from statsmodels.stats.multitest import multipletests

from matplotlib.colors import LinearSegmentedColormap

from scroutines.config_plots import *
from scroutines import powerplots # .config_plots import *
from scroutines import pnmf
from scroutines import basicu
from scroutines.gene_modules import GeneModules  

import atac_utils

In [None]:
# Load the regulon -> cluster table and re-index it by TF name.
f = '/u/home/f/f7xiesnm/v1_multiome/res/regulon_clusters_250515.csv'
df_reg_clst = pd.read_csv(f, index_col=0)

# The TF name is the part of the regulon label before the first underscore.
tf_names = [regulon.split('_')[0] for regulon in df_reg_clst['regulon']]
df_reg_clst = df_reg_clst.assign(tf=tf_names).set_index('tf')
# NOTE(review): if one TF owns several regulon rows, this index will contain
# duplicate labels -- confirm uniqueness before any label-based lookup.
df_reg_clst

In [None]:
# Read the filtered SCENIC metadata table; the first CSV column is the index.
ddir = '/u/home/f/f7xiesnm/v1_multiome/juyoun/'
f = os.path.join(ddir, 'L23alltime_eReg_metadata_filtered.csv')
df_scenic = pd.read_csv(f, index_col=0)
df_scenic


In [None]:
# One row per (TF, is_extended, Consensus_name) group, keeping the first
# region/gene signature name seen in each group, ordered by TF.
regulon_keys = ['TF', 'is_extended', 'Consensus_name']
signature_cols = ['Region_signature_name', 'Gene_signature_name']
df_reg = (
    df_scenic
    .groupby(regulon_keys)
    .first()[signature_cols]
    .sort_values('TF')
)

# Sorted unique regions, genes and TFs present in the metadata.
scenic_regions = np.sort(df_scenic['Region'].unique())
scenic_genes = np.sort(df_scenic['Gene'].unique())
scenic_tfs = np.sort(df_scenic['TF'].unique())

# Summary counts: regulon groups, TFs, genes, regions.
num_reg = len(df_reg)
num_tf = len(scenic_tfs)
num_gene = len(scenic_genes)
num_region = len(scenic_regions)
print(num_reg, num_tf, num_gene, num_region)
df_reg

# Put all regulons together - network plot

In [None]:
# Keep only rows whose target gene is itself one of the TFs and whose
# TF2G_regulation flag equals 1 (presumably positive TF-to-gene regulation --
# confirm against the SCENIC+ metadata definition).
target_is_tf = df_scenic['Gene'].isin(scenic_tfs)
tf2g_flag_is_one = df_scenic['TF2G_regulation'] == 1
df_scenic_tfp = df_scenic[target_is_tf & tf2g_flag_is_one]

# Collapse to the unique (TF, Gene) pairs: the TF -> TF edge list.
tf_gene_groups = df_scenic_tfp.groupby(['TF', 'Gene']).first()
df_tfpi = tf_gene_groups.reset_index()[['TF', 'Gene']]
df_tfpi

In [None]:
import plotly.graph_objects as go
import networkx as nx
from node2vec import Node2Vec
from sklearn.manifold import TSNE

# Directed graph with one edge per (TF, Gene) row of df_tfpi.
connections = df_tfpi.to_numpy()
G = nx.DiGraph()
G.add_edges_from(connections)
G.number_of_nodes(), G.number_of_edges()

In [None]:
# Layout v1: force-directed (spring) layout.
pos_v1 = nx.spring_layout(G)

# Layout v2: node2vec embedding of the graph, reduced to 2-D with t-SNE.
# NOTE(review): only TSNE is seeded here; the spring layout and the node2vec
# walks/training are not, so both layouts can differ between runs.
node2vec = Node2Vec(G, dimensions=64, walk_length=30, num_walks=200)
model = node2vec.fit(window=10, min_count=1)

node_order = list(G.nodes)
X = model.wv[node_order]  # node IDs must be strings
X_2d = TSNE(n_components=2, random_state=0).fit_transform(X)
pos_v2 = dict(zip(node_order, X_2d))

In [None]:
def rgb_to_hex(r, g, b):
    """Return the uppercase ``#RRGGBB`` hex string for integer channels r, g, b.

    Each channel is rendered as at least two zero-padded hexadecimal digits.
    """
    return f"#{r:02X}{g:02X}{b:02X}"

# Nine colours requested from seaborn's 'tab20' palette.
palette = sns.color_palette(palette='tab20', n_colors=9)
palette

In [None]:
# Convert each colour of `palette` (assumed to be float channels in [0, 1],
# since they are scaled by 255) to a '#RRGGBB' string.
# Channels are rounded to the nearest integer before the int cast: a bare
# astype(int) truncates, so a product such as 30.999999 would become 30
# instead of 31 and shift the colour by one step.
hex_palette = np.array([
    rgb_to_hex(*np.rint(np.asarray(color) * 255).astype(int))
    for color in palette
])
hex_palette

In [None]:
# Draw the TF-TF network with plotly: nodes as a labelled scatter trace,
# directed edges as arrow annotations running from source to target.
pos = pos_v2

# Per-node colour, looked up from the cluster id ('clst') of each TF.
# Nodes absent from df_reg_clst get index -1, which numpy resolves to the
# LAST entry of hex_palette.
# NOTE(review): if the clusters already use every palette index, unassigned
# nodes share a colour with the highest cluster -- confirm this is intended.
node_order = list(G.nodes())
cluster_idx = df_reg_clst.reindex(node_order)['clst'].fillna(-1).astype(int)
colors = hex_palette[cluster_idx.to_numpy()]

# Node scatter trace: one marker plus text label per node.
node_trace = go.Scatter(
    x=[pos[n][0] for n in node_order],
    y=[pos[n][1] for n in node_order],
    mode='markers+text',
    text=node_order,
    textposition='top center',
    marker=dict(size=10, color=colors),
    hoverinfo='text',
)

# One arrow annotation per edge: tail at the source node (ax, ay), head at
# the target node (x, y), both expressed in data coordinates.
annotations = [
    dict(
        ax=pos[u][0],
        ay=pos[u][1],
        x=pos[v][0],
        y=pos[v][1],
        xref='x', yref='y',
        axref='x', ayref='y',
        showarrow=True,
        arrowhead=1,
        arrowsize=1,
        arrowwidth=1,
        arrowcolor="gray",
    )
    for u, v in G.edges()
]

# Assemble the figure; the edges are carried entirely by the annotations.
fig = go.Figure(
    data=[node_trace],
    layout=go.Layout(
        title="TF-TF network",
        showlegend=False,
        hovermode='closest',
        margin=dict(b=20, l=5, r=5, t=40),
        annotations=annotations,
        xaxis=dict(showgrid=False, zeroline=False),
        yaxis=dict(showgrid=False, zeroline=False),
        width=1000,
        height=1000,
    )
)

fig.show()