In [182]:
import sys
print(sys.executable)

import selenium
print(selenium.__version__)

import numpy
print(numpy.__version__)
print(numpy.__file__)

/Users/andrewjenkinsvandusen/Downloads/mount_sinai_internship/new_umap_env/bin/python
4.33.0
2.2.0
/Users/andrewjenkinsvandusen/Downloads/mount_sinai_internship/new_umap_env/lib/python3.11/site-packages/numpy/__init__.py


In [183]:
import pandas as pd
from collections import defaultdict
import numpy as np
import tqdm
import chromedriver_binary
import random
from sklearn.feature_extraction.text import TfidfVectorizer
import scanpy as sc
import anndata
from collections import OrderedDict
from bokeh.io import output_notebook, export_png, export_svg
from bokeh.plotting import figure, show
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.palettes import Category20
import glasbey
output_notebook()

from IPython.display import display, HTML, Markdown
import matplotlib.colors as mcolors

In [184]:
def load_network(networkdir):
    up = pd.read_csv(f"{networkdir}/Transcription Factor.upregulates.Transcription Factor.edges.csv", usecols=["source_label","target_label"])
    down = pd.read_csv(f"{networkdir}/Transcription Factor.downregulates.Transcription Factor.edges.csv",usecols=["source_label","target_label"])
    nodes = pd.read_csv(f"{networkdir}/Transcription Factor.nodes.csv",usecols=["label"])
    return up,down,nodes

def find_neighbors(edge_list):
    neighbors = defaultdict(list)
    for (source, target) in edge_list:
        neighbors[source].append(target)
    return neighbors

up,down,nodes = load_network("/Users/andrewjenkinsvandusen/Downloads/mount_sinai_internship/tf_umap_visualization")

# Network UMAP

Step 1. Convert the network to a GMT.

In [185]:
network_dir = "/Users/andrewjenkinsvandusen/Downloads/mount_sinai_internship/tf_umap_visualization"
up = pd.read_csv(f"{network_dir}/Transcription Factor.upregulates.Transcription Factor.edges.csv", usecols=["source_label","target_label"])
down = pd.read_csv(f"{network_dir}/Transcription Factor.downregulates.Transcription Factor.edges.csv",usecols=["source_label","target_label"])
edges = pd.concat([up, down], ignore_index=True)

edgelist = list(zip(edges["source_label"], edges["target_label"])) # generates list of tuples
gmt = {}
for (source, target) in edgelist:
    if source in gmt.keys():
        gmt[source].append(target)
    else:
        gmt[source] = [target]

with open("./network_gmt.gmt", "w") as file:
    for s,t in gmt.items():
        file.write(str(s) + "\t\t" + "\t".join(t) + "\n")

Step 2. Build the UMAP using code from the Enrichr processing libraries

In [186]:
libname = 'network_gmt'
libdir = '/Users/andrewjenkinsvandusen/Downloads/mount_sinai_internship/tf_umap_visualization' # directory where library is

def get_scatter_library(libname, local, augmented):
    '''
    Processes the GMT file for the input Enrichr library {lib} and returns a
    dictionary where the keys correspond to gene set names, and the value for
    each key is a space-delimited string containing all genes belonging to
    the gene set:
    {
        "gene set name": "gene_1 gene_2 gene_3 ... gene_n",
        ...
    }
    In addition, this function can augment each gene set library using ARCHS4
    gene-gene co-expression data. For each gene set, the most co-expressed genes
    (determined by summing the coexpression coefficients across all genes)
    are added to the gene set before visualization.
    '''
    ### open local file or from Enrichr
    if local:
        print(f"\tOpening library locally from '{libdir}'...")
        with open(f"{libdir}/{libname}.gmt", 'r') as f:
            lines = f.readlines()

    ### variables to store gene set data
    lib_dict = OrderedDict()

    if augmented:
        print("\tProcessing gene sets and augmenting with ARCHS4...")
    else:
        print("\tProcessing gene sets without augmentation...")

    for line in lines:
        tokens = line.split("\t\t")
        term = tokens[0]
        genes = [x.split(',')[0].strip() for x in tokens[1].split('\t')]
        lib_dict[term] = ' '.join(genes)

    return lib_dict

def process_scatterplot(libdict, nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1):
    print("\tTF-IDF vectorizing gene set data...")
    vec = TfidfVectorizer(max_df=maxdf, min_df=mindf)
    X = vec.fit_transform(libdict.values())
    print(X.shape)
    adata = anndata.AnnData(X)
    adata.obs.index = libdict.keys()

    print("\tPerforming Leiden clustering...")
    ### the nneighbors and min_dist parameters can be altered
    sc.pp.neighbors(adata, n_neighbors=nneighbors)
    sc.tl.leiden(adata, resolution=1.0)
    sc.tl.umap(adata, min_dist=mindist, spread=spread, random_state=42)

    new_order = adata.obs.sort_values(by='leiden').index.tolist()
    adata = adata[new_order, :]
    adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')

    df = pd.DataFrame(adata.obsm['X_umap'])
    df.columns = ['x', 'y']

    df['cluster'] = adata.obs['leiden'].values
    df['term'] = adata.obs.index
    df['genes'] = [libdict[l] for l in df['term']]

    return df

# def get_scatter_colors(df):
#     clusters = pd.unique(df['cluster']).tolist()
#     colors = glasbey.create_palette(palette_size=len(clusters), lightness_bounds=(0,100), chroma_bounds=(50,100), as_hex=True)
#     color_mapper = {clusters[i]: colors[i % 20] for i in range(len(clusters))}
#     return color_mapper

# def get_scatter_colors(df):
#     clusters = pd.unique(df['cluster']).tolist()
#     gray = '#808080'  # standard medium gray
#     color_mapper = {cluster: gray for cluster in clusters}
#     return color_mapper

def get_scatter_colors(df):
    clusters = pd.unique(df['cluster']).tolist()
    n_clusters = len(clusters)
    gray_shades = [f'#{int(v):02x}{int(v):02x}{int(v):02x}' for v in np.linspace(50, 230, n_clusters)]
    color_mapper = {clusters[i]: gray_shades[i] for i in range(n_clusters)}
    return color_mapper

def blend_colors(color_hex, factor):
    """
    Outputs a slightly modified hex color given a hex color input.
    """
    rgb = mcolors.to_rgb(color_hex)
    adjusted = tuple(min(1, c * factor) for c in rgb)
    return mcolors.to_hex(adjusted)

def generate_df_for_comparison(base_df, tf_pair, comparison_label, comparison_idx):
    """
    Generates a new df for each time point.
    """
    up_tfs, down_tfs = tf_pair[0], tf_pair[1]
    df = base_df.copy()
    color_mapper = get_scatter_colors(df)
    df['color'] = df['cluster'].apply(lambda x: color_mapper[x])
    df['size'] = 6
    df['time_point'] = "Not enriched"

    for idx, term in df['term'].items():
        if (term in up_tfs) and (term not in down_tfs):
            # df.at[idx, 'color'] = blend_colors('#1f77b4', 1 + 0.2 * comparison_idx)
            df.at[idx, 'color'] = "#1595f0"
            df.at[idx, 'size'] = 12
            df.at[idx, 'time_point'] = comparison_label
        if (term in down_tfs) and (term not in up_tfs):
            # df.at[idx, 'color'] = blend_colors('#b41f29', 1 + 0.2 * comparison_idx)
            df.at[idx, 'color'] = "#f30a1a"
            df.at[idx, 'size'] = 12
            df.at[idx, 'time_point'] = comparison_label
        if (term in up_tfs) and (term in down_tfs):
            df.at[idx, 'color'] = "#26e411"
            df.at[idx, 'size'] = 12
            df.at[idx, 'time_point'] = comparison_label
    return df

from bokeh.plotting import figure
from bokeh.io.export import export_png
from bokeh.models import ColumnDataSource, HoverTool, Slider, CustomJS, Title, Label
from bokeh.layouts import column
from bokeh.palettes import Greys
from bokeh.io import show
from bokeh.plotting import output_file, save
import os
from PIL import Image

def get_scatterplot(scatterdf, tf_time_dict=None, comparisons=None, legend_description=None):
    """
    Generates images navigable via a slider, as well as all the images separately.
    """
    df = scatterdf.copy()
    df['cluster_number'] = df['cluster'].apply(lambda x: int(x.split(" ")[-1]))
    print(df['cluster_number'])
    df.sort_values(by=['cluster_number'], inplace=True)
    df.drop(columns = ['cluster_number'], inplace=True)

    sources = []
    for i, label in enumerate(comparisons):
        df_comp = generate_df_for_comparison(df, tf_time_dict[i], label, i)
        source = ColumnDataSource(data=dict(x = df_comp['x'], y = df_comp['y'],
                                            gene_set = df_comp['term'], colors = df_comp['color'],
                                            label = df_comp['cluster'], size = df_comp['size'],
                                            time_point = df_comp['time_point']))
        sources.append(source)

    source = sources[0]
    tooltips = [
        ("Gene Set", "@gene_set"),
        ("Cluster", "@label"),
        ("Time point", "@time_point")
    ]

    hover_emb = HoverTool(tooltips=tooltips)
    tools_emb = [hover_emb, 'pan', 'wheel_zoom', 'reset', 'save']

    plot_emb = figure(
        width=500*2,
        height=400*2,
        tools=tools_emb,
        output_backend='canvas'
    )

    plot_emb.scatter(
        'x',
        'y',
        size = 'size',
        source = source,
        marker='circle',
        fill_color = 'colors',
        color='colors',
        legend_group = 'label',
    )

    # hide axis labels and grid lines
    plot_title = Title(text=comparisons[0], align='center')
    plot_title.text_font_size = '20pt'
    plot_title.text_font_style = 'bold'
    plot_emb.add_layout(plot_title, 'above')

    plot_emb.xaxis.major_tick_line_color = None
    plot_emb.xaxis.minor_tick_line_color = None
    plot_emb.yaxis.major_tick_line_color = None
    plot_emb.yaxis.minor_tick_line_color = None
    plot_emb.grid.grid_line_color = None
    plot_emb.xaxis.major_label_text_font_size = '0pt'
    plot_emb.yaxis.major_label_text_font_size = '0pt'

    plot_emb.xaxis.axis_label = "UMAP-1"
    plot_emb.yaxis.axis_label = "UMAP-2"
    plot_emb.xaxis.axis_label_text_font_size = '20pt'
    plot_emb.yaxis.axis_label_text_font_size = '20pt'
    plot_emb.xaxis.axis_label_text_font_style = "normal"
    plot_emb.yaxis.axis_label_text_font_style = "normal"

    plot_emb.legend.label_text_font_size = '18pt'
    plot_emb.legend.glyph_height = 20
    plot_emb.legend.glyph_width = 20

    print("legend", plot_emb.legend[0])
    plot_emb.add_layout(plot_emb.legend[0], 'right')

    plot_emb.min_border_bottom = 168

    description_label = Label(x=0, y=-7, x_units='screen', y_units='screen',
                          text=legend_description,
                          text_font_size='12pt', text_align='left')

    plot_emb.add_layout(description_label, 'below')

    ### adding a slider ###
    slider = Slider(start=0, end=len(sources) - 1, value=0, step=1, title="Comparison")
    comparison_source = ColumnDataSource(data=dict(comparisons=[str(c) for c in comparisons]))
    callback = CustomJS(args=dict(source=source, slider=slider, sources=sources, plot=plot_emb,
                                  comparison_source=comparison_source, title_obj=plot_title), code="""
        const i = slider.value;
        const new_data = sources[i].data;
        const copied_data = {};
        for (const key in new_data) {
            copied_data[key] = [...new_data[key]];  // deep copy each column
        }
        source.data = copied_data;

        const comp_labels = comparison_source.data['comparisons'];
        title_obj.text = comp_labels[i];

        source.change.emit();
    """)
    slider.js_on_change('value', callback)
    # show(column(slider, plot_emb))

    # output_file("top_10_tfs_deseq2_adjacent_time_pts_umap_plot.html")
    output_file("top_10_tfs_deseq2_compare_w_time_pt_0_umap_plot.html")
    save(column(slider, plot_emb))

    ### for isolated individual time point images ###
    # frame_dir = "umap_png_frames_deseq2_adjacent_time_pts_top_10_tfs"
    frame_dir = "umap_png_frames_deseq2_compare_w_time_pt_0_top_10_tfs"
    os.makedirs(frame_dir, exist_ok=True)
    for i, label in enumerate(comparisons):
        source.data = dict(sources[i].data)
        plot_title.text = label
        export_png(plot_emb, filename=os.path.join(frame_dir, f"frame_{i:02d}_{label}.png"))

    return plot_emb, source

def create_umap_gif(frame_dir, gif_filename, duration):
    """
    Creates GIF given individual input images found in frame_dir.
    """
    frame_paths = sorted([os.path.join(frame_dir, f) for f in os.listdir(frame_dir) if f.endswith(".png")])
    images = [Image.open(frame) for frame in frame_paths]
    images[0].save(gif_filename, save_all=True, append_images=images[1:], duration=duration, loop=0)
    return "CREATED GIF"

In [None]:
# for i in range(5, 35, 5):
#     for j in [k * 0.05 for k in range(1, 11)]:

l_dict = get_scatter_library(libname, local=True, augmented=False)
print(f"Now processing {libname}") # print(f"Now processing {libname} with nneighbors = {i} and mindist = {j}")
## defaults: nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1
scatter_df = process_scatterplot(
    l_dict,
    nneighbors=20,
    mindist=0.15,
)
print(f"\tDone!")

# Display Scatter Plots
caption1 = f"**Figure 1. Scatterplot of all terms in the {libname} gene set library.** Each point represents a term in the library. \
    Term frequency-inverse document frequency (TF-IDF) values were computed for the gene set corresponding to each term, and UMAP was  \
    applied to the resulting values. The terms are plotted based on the first two UMAP dimensions. Generally, terms with more similar \
    gene sets are positioned closer together. Terms are colored by automatically identified clusters computed with the Leiden algorithm \
    applied to the TF-IDF values. Hovering over points will display the term and the automatically assigned cluster."

# deseq2, comparing adjacent time pts
# tf_time_dict_1 = {0: (['ATF3', 'KLF6', 'FOSB', 'JUN', 'SNAI1', 'NFIL3', 'FOS', 'NR4A3', 'FOSL1', 'EGR2'], ['NR4A1', 'BHLHE40', 'ZNF395', 'ATF3', 'FOSB', 'ZNF740', 'JUN', 'ZNF594', 'PRDM8', 'ZNF324']), 1: (['BHLHE40', 'ATF3', 'MYC', 'JUNB', 'FOSB', 'FOSL2', 'NFKB1', 'SNAI1', 'RELB', 'FOSL1'], ['BHLHE40', 'ATF3', 'CREBL2', 'KLF7', 'ZBED3', 'FOXA1', 'JUN', 'ZNF608', 'PPARG', 'NR2F2']), 2: (['HMGA2', 'TEAD1', 'SP3', 'PRDM4', 'ZBED4', 'TCF20', 'ZBTB38', 'FOXM1', 'UBP1', 'GLYR1'], ['MAFF', 'HMGN3', 'ATF3', 'FOSB', 'ZNF581', 'EGR1', 'JUN', 'GTF3A', 'ZNF207', 'ZNF511']), 3: (['STAT2', 'CREB3', 'JUN', 'SP100', 'BATF2', 'PPARG', 'TRAFD1', 'IRF1', 'IRF9', 'STAT3'], ['GATAD2A', 'E2F1', 'MYC', 'FOXK2', 'ZBED4', 'ZNF598', 'TCF3', 'SRCAP', 'FOXM1', 'FOSL1']), 4: (['TEAD1', 'STAT2', 'HIF1A', 'NFKB2', 'CREB3L2', 'NFKB1', 'STAT1', 'MAFK', 'ZNF697', 'STAT3'], ['MYC', 'ZNF239', 'ZNF146', 'TFDP1', 'PRMT3', 'CEBPG', 'ETV4', 'FOSL1', 'HMGA1', 'CEBPZ'])}

# deseq2, comparing to time pt 0
tf_time_dict_2 = {0: (['ATF3', 'KLF6', 'FOSB', 'JUN', 'SNAI1', 'NFIL3', 'FOS', 'NR4A3', 'FOSL1', 'EGR2'], ['NR4A1', 'BHLHE40', 'ZNF395', 'ATF3', 'FOSB', 'ZNF740', 'JUN', 'ZNF594', 'PRDM8', 'ZNF324']), 1: (['ATF3', 'MYC', 'JUNB', 'FOSB', 'FOSL2', 'JUN', 'SNAI1', 'NFIL3', 'NR4A3', 'FOSL1'], ['TEAD1', 'BHLHE40', 'CREBL2', 'ZNF436', 'TFCP2L1', 'TCF7L2', 'CREB3L2', 'JUN', 'KLF9', 'SOX13']), 2: (['ADNP2', 'PRDM4', 'FOXK2', 'ZBED4', 'HIVEP1', 'TCF20', 'BAZ2A', 'ZNF697', 'SRCAP', 'FOSL1'], ['CREB3', 'ELF3', 'CREB3L4', 'JUN', 'ZNF580', 'PPARG', 'NR1H3', 'KLF2', 'IRF9', 'ZNF524']), 3: (['HMGA2', 'ZNF267', 'HIVEP1', 'ZBTB11', 'NFKB2', 'MGA', 'ZNF134', 'NFKB1', 'RLF', 'RELB'], ['E2F1', 'E2F7', 'MBD3', 'THAP4', 'THAP7', 'ZNF837', 'ZNF580', 'TFDP1', 'ZNF511', 'HMGA1']), 4: (['TEAD1', 'ZNF407', 'HIVEP1', 'MGA', 'TCF20', 'SMAD3', 'RFX7', 'ASH1L', 'NFAT5', 'NCOA2'], ['E2F1', 'DRAP1', 'THAP4', 'THAP7', 'MYC', 'ZNF598', 'ZNF692', 'GTF3A', 'TFDP1', 'HMGA1'])}

# comparisons_1 = ["Hour 1 vs Hour 0", "Hour 3 vs Hour 1", "Hour 6 vs Hour 3", "Hour 12 vs Hour 6", "Hour 24 vs Hour 12"]
comparisons_2 = ["Hour 1 vs Hour 0", "Hour 3 vs Hour 0", "Hour 6 vs Hour 0", "Hour 12 vs Hour 0", "Hour 24 vs Hour 0"]

legend_description = ("Blue dots are TFs enriched for upregulated DEGs.\n"
    "Red dots are TFs enriched for downregulated DEGs.\n"
    "Green dots are TFs enriched for both up- and downregulated DEGs.\n"
    "This study looked at the response of a triple-negative breast cancer cell line to TRAIL.\n"
    "RNA-seq samples were taken at six time points (0, 1, 3, 6, 12, & 24 hours).")

# plot, source = get_scatterplot(scatter_df, tf_time_dict_1, comparisons_1, legend_description)
plot, source = get_scatterplot(scatter_df, tf_time_dict_2, comparisons_2, legend_description)
# print(plot)
# display(HTML(f"<div style='font-size:1.5rem;'>Scatter plot visualization for {libname}.</div>"))
# show(plot)

# create_umap_gif("umap_png_frames_deseq2_adjacent_time_pts_top_10_tfs", "top_10_tfs_deseq2_adjacent_time_pts_umap.gif", 1500)
create_umap_gif("umap_png_frames_deseq2_compare_w_time_pt_0_top_10_tfs", "top_10_tfs_deseq2_compare_w_time_pt_0_umap.gif", 1500)

	Opening library locally from '/Users/andrewjenkinsvandusen/Downloads/mount_sinai_internship/tf_umap_visualization'...
	Processing gene sets without augmentation...
Now processing network_gmt
	TF-IDF vectorizing gene set data...
(700, 1550)
	Performing Leiden clustering...


         Falling back to preprocessing with `sc.pp.pca` and default params.
  X = _choose_representation(self._adata, use_rep=use_rep, n_pcs=n_pcs)
  adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')


	Done!
0       0
1       0
2       0
3       0
4       0
       ..
695    15
696    15
697    15
698    15
699    15
Name: cluster_number, Length: 700, dtype: int64
legend Legend(id='p3945', ...)


'CREATED GIF'