In [1]:
import pandas as pd
from collections import defaultdict
import numpy as np
import tqdm
import chromedriver_binary
import random
from sklearn.feature_extraction.text import TfidfVectorizer
import scanpy as sc
import anndata
from collections import OrderedDict
from bokeh.io import output_notebook, export_png, export_svg
from bokeh.plotting import figure, show
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.palettes import Category20
import glasbey
output_notebook()

from IPython.display import display, HTML, Markdown

In [2]:
def load_network(networkdir):
    up = pd.read_csv(f"{networkdir}/Transcription Factor.upregulates.Transcription Factor.edges.csv", usecols=["source_label","target_label"])
    down = pd.read_csv(f"{networkdir}/Transcription Factor.downregulates.Transcription Factor.edges.csv",usecols=["source_label","target_label"])
    nodes = pd.read_csv(f"{networkdir}/Transcription Factor.nodes.csv",usecols=["label"])
    return up,down,nodes

def find_neighbors(edge_list):
    neighbors = defaultdict(list)
    for (source, target) in edge_list:
        neighbors[source].append(target) 
    return neighbors

up,down,nodes = load_network("../kg_assertions_for_neo4j")

# Network UMAP

Step 1. Convert the network to a GMT.

In [3]:
up = pd.read_csv("../kg_assertions_for_neo4j/Transcription Factor.upregulates.Transcription Factor.edges.csv", usecols=["source_label","target_label"])
down = pd.read_csv("../kg_assertions_for_neo4j/Transcription Factor.downregulates.Transcription Factor.edges.csv",usecols=["source_label","target_label"])
edges = pd.concat([up, down], ignore_index=True)

edgelist = list(zip(edges["source_label"], edges["target_label"]))
gmt = {}
for (source, target) in edgelist:
    if source in gmt.keys():
        gmt[source].append(target)
    else:
        gmt[source] = [target]

with open("./network_gmt.gmt", "w") as file:
    for s,t in gmt.items():
        file.write(str(s) + "\t\t" + "\t".join(t) + "\n")

Step 2. Build the UMAP using code from the Enrichr processing libraries

In [4]:
libname = 'network_gmt'
libdir = './fig_data' # directory where library is

def get_scatter_library(libname, local, augmented):
    '''
    Processes the GMT file for the input Enrichr library {lib} and returns a 
    dictionary where the keys correspond to gene set names, and the value for
    each key is a space-delimited string containing all genes belonging to
    the gene set: 
    {
        "gene set name": "gene_1 gene_2 gene_3 ... gene_n", 
        ...
    }
    In addition, this function can augment each gene set library using ARCHS4 
    gene-gene co-expression data. For each gene set, the most co-expressed genes 
    (determined by summing the coexpression coefficients across all genes)
    are added to the gene set before visualization. 
    '''
    ### open local file or from Enrichr
    if local: 
        print(f"\tOpening library locally from '{libdir}'...")
        with open(f"{libdir}/{libname}.gmt", 'r') as f:
            lines = f.readlines()

    ### variables to store gene set data
    lib_dict = OrderedDict()

    if augmented:
        print("\tProcessing gene sets and augmenting with ARCHS4...")
    else:
        print("\tProcessing gene sets without augmentation...")

    for line in lines:
        tokens = line.split("\t\t")
        term = tokens[0]
        genes = [x.split(',')[0].strip() for x in tokens[1].split('\t')]
        lib_dict[term] = ' '.join(genes)
    
    return lib_dict

def process_scatterplot(libdict, nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1):
    print("\tTF-IDF vectorizing gene set data...")
    vec = TfidfVectorizer(max_df=maxdf, min_df=mindf)
    X = vec.fit_transform(libdict.values())
    print(X.shape)
    adata = anndata.AnnData(X)
    adata.obs.index = libdict.keys()

    print("\tPerforming Leiden clustering...")
    ### the n_neighbors and min_dist parameters can be altered
    sc.pp.neighbors(adata, n_neighbors=nneighbors)
    sc.tl.leiden(adata, resolution=1.0)
    sc.tl.umap(adata, min_dist=mindist, spread=spread, random_state=42)

    new_order = adata.obs.sort_values(by='leiden').index.tolist()
    adata = adata[new_order, :]
    adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')

    df = pd.DataFrame(adata.obsm['X_umap'])
    df.columns = ['x', 'y']

    df['cluster'] = adata.obs['leiden'].values
    df['term'] = adata.obs.index
    df['genes'] = [libdict[l] for l in df['term']]

    return df

def get_scatter_colors(df):
    clusters = pd.unique(df['cluster']).tolist()
    colors = glasbey.create_palette(palette_size=len(clusters), lightness_bounds=(0,100), chroma_bounds=(50,100), as_hex=True)
    color_mapper = {clusters[i]: colors[i % 20] for i in range(len(clusters))}
    return color_mapper

def get_scatterplot(scatterdf):
    df = scatterdf.copy()
    color_mapper = get_scatter_colors(df)
    df['color'] = df['cluster'].apply(lambda x: color_mapper[x])
    df['cluster_number'] = df['cluster'].apply(lambda x: int(x.split(" ")[-1]))
    print(df['cluster_number'])
    df.sort_values(by=['cluster_number'], inplace=True)
    df.drop(columns = ['cluster_number'])
    tooltips = [
        ("Gene Set", "@gene_set"),
        ("Cluster", "@label")
    ]
        
    hover_emb = HoverTool(tooltips=tooltips)
    tools_emb = [hover_emb, 'pan', 'wheel_zoom', 'reset', 'save']

    plot_emb =  figure(
        width=500*2, 
        height=400*2, 
        tools=tools_emb,
        output_backend='svg'
    )

    source = ColumnDataSource(
        data=dict(
            x = df['x'],
            y = df['y'],
            gene_set = df['term'],
            colors = df['color'],
            label = df['cluster']
        )
    )

    # hide axis labels and grid lines
    plot_emb.xaxis.major_tick_line_color = None
    plot_emb.xaxis.minor_tick_line_color = None
    plot_emb.yaxis.major_tick_line_color = None
    plot_emb.yaxis.minor_tick_line_color = None
    plot_emb.grid.grid_line_color = None
    plot_emb.xaxis.major_label_text_font_size = '0pt'
    plot_emb.yaxis.major_label_text_font_size = '0pt' 
    
    plot_emb.xaxis.axis_label = "UMAP-1"
    plot_emb.yaxis.axis_label = "UMAP-2"
    plot_emb.xaxis.axis_label_text_font_size = '40pt'
    plot_emb.yaxis.axis_label_text_font_size = '40pt'
    plot_emb.xaxis.axis_label_text_font_style = "normal"
    plot_emb.yaxis.axis_label_text_font_style = "normal"


    s = plot_emb.scatter(
        'x', 
        'y', 
        size = 3 *2, 
        source = source, 
        marker='circle',
        fill_color = 'colors', 
        color='colors',
        legend_group = 'label'
    )
    plot_emb.legend.label_text_font_size = '18pt'
    plot_emb.legend.glyph_height = 20
    plot_emb.legend.glyph_width = 20

    print("legend", plot_emb.legend[0])
    plot_emb.add_layout(plot_emb.legend[0], 'right')
    
    return plot_emb, source

In [5]:
l_dict = get_scatter_library(libname, augmented=False, local=True)
print(f"Now processing {libname}")
## defaults: nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1
scatter_df = process_scatterplot(
    l_dict, 
    nneighbors=5,
    mindist=.05,
)
print(f"\tDone!")

# Display Scatter Plots
caption1 = f"**Figure 1. Scatterplot of all terms in the {libname} gene set library.** Each point represents a term in the library. \
    Term frequency-inverse document frequency (TF-IDF) values were computed for the gene set corresponding to each term, and UMAP was  \
    applied to the resulting values. The terms are plotted based on the first two UMAP dimensions. Generally, terms with more similar \
    gene sets are positioned closer together. Terms are colored by automatically identified clusters computed with the Leiden algorithm \
    applied to the TF-IDF values. Hovering over points will display the term and the automatically assigned cluster."

plot, src = get_scatterplot(scatter_df)
print(plot)
display(HTML(f"<div style='font-size:1.5rem;'>Scatter plot visualization for {libname}.</div>"))
show(plot)

	Opening library locally from './fig_data'...
	Processing gene sets without augmentation...
Now processing network_gmt
	TF-IDF vectorizing gene set data...
(700, 1550)
	Performing Leiden clustering...


         Falling back to preprocessing with `sc.pp.pca` and default params.

 To achieve the future defaults please pass: flavor="igraph" and n_iterations=2.  directed must also be False to work with igraph's implementation.
  sc.tl.leiden(adata, resolution=1.0)


	Done!
0       0
1       0
2       0
3       0
4       0
       ..
695    26
696    26
697    26
698    26
699    26
Name: cluster_number, Length: 700, dtype: int64
legend Legend(id='p1039', ...)
figure(id='p1002', ...)


  adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')
