# Visualizing the Enriched Transcription Factors from Time Series RNA-seq Experiments
### The Appyter below generates a regulatory subnetwork and a UMAP visualization of the enriched TFs at different time points from time series RNA-seq data. 

## Loading the necessary packages

In [1]:
import pandas as pd
import numpy as np
import json
import urllib.parse
from urllib.parse import quote
import requests
from IPython.display import display, HTML, Markdown
from IPython.display import Image as IPyImage, Javascript, display, FileLink
import os
import time

# time series TF subnetwork visualization
import ipycytoscape as ipc
from ipycytoscape import CytoscapeWidget, Node, Edge
import py4cytoscape as p4c
import networkx as nx
from dash import Dash, html

# enriched TF UMAP visualization
from collections import defaultdict
import tqdm
import random
from sklearn.feature_extraction.text import TfidfVectorizer
import scanpy as sc
import anndata
from collections import OrderedDict
from bokeh.io import output_notebook, export_png, export_svg
from bokeh.plotting import figure, show
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.palettes import Category20
import glasbey
output_notebook()

from maayanlab_bioinformatics.dge import deseq2_differential_expression
from maayanlab_bioinformatics.dge import characteristic_direction
from maayanlab_bioinformatics.dge import up_down_from_characteristic_direction
from maayanlab_bioinformatics.dge import limma_voom_differential_expression
from maayanlab_bioinformatics.dge import up_down_from_limma_voom

## 1. Finding DEGs given a raw gene counts matrix

### Upload fields for raw gene counts matrix and experiment metadata.

In [2]:
ds1 = ''
ds1
ds2 = ''
ds2
compute_degs = False
if ds1 != '' and ds2 != '':
    compute_degs = True

### Method 1: PyDESeq2

In [3]:
if compute_degs:
    raw_counts = pd.read_csv(ds1)
    sample_metadata = pd.read_csv(ds2)

    time_pt_list = []
    for time_pt in sample_metadata["time_pt_annotation"].tolist():
        if time_pt not in time_pt_list:
            time_pt_list.append(time_pt)

    time_pt_dict = {}
    for i, time_pt in enumerate(time_pt_list):
        samples_at_time_pt = sample_metadata.loc[sample_metadata["time_pt_annotation"] == time_pt, "sample_name"].tolist()
        subset_counts = raw_counts[["gene_id"] + samples_at_time_pt]
        rev_subset_counts = subset_counts.set_index("gene_id")
        time_pt_dict[i] = (time_pt, rev_subset_counts)
    # print(time_pt_dict)

In [4]:
### [DESeq2] finding the DEGs from adjacent time pt comparisons
if compute_degs:
    adj_time_pt_comparisons = []
    adj_time_pt_degs = []
    for i in range(len(time_pt_list) - 1):
        controls, cases = time_pt_dict[i][1], time_pt_dict[i+1][1]
        results_df = deseq2_differential_expression(controls, cases)

        p_vals = [0.05, 0.01, 0.001, 0.0001, 0.00001]
        significant_genes = results_df[results_df["padj"] < p_vals[0]]
        up_count = (significant_genes["log2FoldChange"] > 0).sum()
        down_count = (significant_genes["log2FoldChange"] < 0).sum()

        idx = 1
        while (up_count + down_count) > 2000:
            significant_genes = results_df[results_df["padj"] < p_vals[idx]]
            up_count = (significant_genes["log2FoldChange"] > 0).sum()
            down_count = (significant_genes["log2FoldChange"] < 0).sum()
            idx += 1
        print("total DEGs:", (up_count + down_count), "up:", up_count, "down:", down_count, "padj:", p_vals[idx-1])
        print(significant_genes.head())

        ctrl_time_pt, case_time_pt = time_pt_dict[i][0], time_pt_dict[i+1][0]
        adj_time_pt_comparisons.append(f"{case_time_pt} v {ctrl_time_pt}")

        file = f"deseq2_{case_time_pt}_v_{ctrl_time_pt}.csv"
        significant_genes.to_csv(file)
        adj_time_pt_degs.append(file)

In [5]:
### [DESeq2] finding the DEGs from time pt 0 comparisons
if compute_degs:
    time_pt_0_comparisons = []
    time_pt_0_degs = []
    for i in range(1, len(time_pt_list)):
        controls, cases = time_pt_dict[0][1], time_pt_dict[i][1]
        results_df = deseq2_differential_expression(controls, cases)

        p_vals = [0.05, 0.01, 0.001, 0.0001, 0.00001]
        significant_genes = results_df[results_df["padj"] < p_vals[0]]
        up_count = (significant_genes["log2FoldChange"] > 0).sum()
        down_count = (significant_genes["log2FoldChange"] < 0).sum()

        idx = 1
        while (up_count + down_count) > 2000:
            significant_genes = results_df[results_df["padj"] < p_vals[idx]]
            up_count = (significant_genes["log2FoldChange"] > 0).sum()
            down_count = (significant_genes["log2FoldChange"] < 0).sum()
            idx += 1
        print("total DEGs:", (up_count + down_count), "up:", up_count, "down:", down_count, "padj:", p_vals[idx-1])
        print(significant_genes.head())

        ctrl_time_pt, case_time_pt = time_pt_dict[0][0], time_pt_dict[i][0]
        time_pt_0_comparisons.append(f"{case_time_pt} v {ctrl_time_pt}")

        file = f"deseq2_{case_time_pt}_v_{ctrl_time_pt}.csv"
        significant_genes.to_csv(file)
        time_pt_0_degs.append(file)

In [6]:
def up_gene_list(input_csv, filename=None):
    """
    Outputs list of upregulated DEGs from DESeq2 results CSV.
    """
    df = pd.read_csv(input_csv)
    row_filter = df["log2FoldChange"] > 0
    filtered = df.loc[row_filter, df.columns[0]]
    gene_ids = list(filtered)

    up_list = []
    for gene in gene_ids:
        if "_" in gene:
            up_list.append(gene.split("_", 1)[1])
        else:
            up_list.append(gene)

    print(len(up_list))
    return up_list


def down_gene_list(input_csv, filename=None):
    """
    Outputs list of downregulated DEGs from DESeq2 results CSV.
    """
    df = pd.read_csv(input_csv)
    row_filter = df["log2FoldChange"] < 0
    filtered = df.loc[row_filter, df.columns[0]]
    gene_ids = list(filtered)

    down_list = []
    for gene in gene_ids:
        if "_" in gene:
            down_list.append(gene.split("_", 1)[1])
        else:
            down_list.append(gene)

    print(len(down_list))
    return down_list


def csv_to_gmt(input_csv_list, comparisons, filename):
    gmt_dict = {}
    for i, file in enumerate(input_csv_list):
        up_genes = up_gene_list(file)
        print(len(up_genes))
        down_genes = down_gene_list(file)
        print(len(down_genes))
        gmt_dict[f"{comparisons[i]} up genes"] = up_genes
        gmt_dict[f"{comparisons[i]} down genes"] = down_genes

    with open(filename, "w") as file:
        for s,t in gmt_dict.items():
            file.write(str(s) + "\t\t" + "\t".join(t) + "\n")
    print("FINISHED CONVERTING TO GMT")
    return filename

if compute_degs:
    deseq2_degs_gmt_1 = csv_to_gmt(adj_time_pt_degs, adj_time_pt_comparisons, "appyter_deseq2_adj_time_pt_degs.gmt")
    deseq2_degs_gmt_2 = csv_to_gmt(time_pt_0_degs, time_pt_0_comparisons, "appyter_deseq2_compare_w_time_pt_0_degs.gmt")

### Method 2: Characteristic Direction

In [7]:
### [CD] finding the DEGs from adjacent time pt comparisons
def run_chea_kg(gene_list, num_tfs):
    """
    Outputs JSON of TF subnetwork best corresponding to input gene list.
    """
    CHEA_KG = 'https://chea-kg.maayanlab.cloud/api/enrichment'

    description = "insert description here"
    payload = {
        'list': (None, "\n".join(gene_list)),
        'description': (None, description)
    }
    response=requests.post(f"{CHEA_KG}/addList", files=payload)
    time.sleep(0.2)
    data = json.loads(response.text)

    q = {
        'min_lib': 3, # minimum number of libraries that a TF must be ranked in
        'libraries': [
            {'library': "Integrated--meanRank", 'term_limit': num_tfs} # edit term_limit to change number of top-ranked TFs
        ],
        'limit':50, # controls number of edges returned - may cause issues with visualization if too large
        'userListId': data['userListId']
    }
    query_json=json.dumps(q)
    res = requests.post(CHEA_KG, data=query_json)
    if res.ok:
        data = json.loads(res.text)
        return data
    else:
        data = None
        return res.text


def top_tfs(gene_list, num_tfs=5):
    """
    Returns a list of the top N most enriched TFs corresponding to an input gene list.
    """
    enriched_tfs = run_chea_kg(gene_list, num_tfs)
    tfs_list = []
    for node in enriched_tfs["nodes"]:
        tfs_list.append(node["data"]["label"])
    return tfs_list

if compute_degs:
    cd_adj_time_pt_comparisons = []
    cd_tf_time_dict_1 = {}
    for i in range(len(time_pt_list) - 1):
        ctrl_time_pt, case_time_pt = time_pt_dict[i][0], time_pt_dict[i+1][0]
        cd_adj_time_pt_comparisons.append(f"{case_time_pt} v {ctrl_time_pt}")

        controls, cases = time_pt_dict[i][1], time_pt_dict[i+1][1]
        results_df = characteristic_direction(controls, cases)
        # print(results_df.head())

        up_genes = up_down_from_characteristic_direction(results_df).up
        up_list = []
        for gene in up_genes:
            if "_" in gene:
                up_list.append(gene.split("_", 1)[1])
            else:
                up_list.append(gene)

        down_genes = up_down_from_characteristic_direction(results_df).down
        down_list = []
        for gene in down_genes:
            if "_" in gene:
                down_list.append(gene.split("_", 1)[1])
            else:
                down_list.append(gene)

        # print(up_list)
        # print(down_list)
        print(len(up_list), len(down_list))

        cd_tf_time_dict_1[i] = (top_tfs(up_list), top_tfs(down_list))
    print(cd_tf_time_dict_1)

In [8]:
### [CD] finding the DEGs from time pt 0 comparisons
if compute_degs:
    cd_time_pt_0_comparisons = []
    cd_tf_time_dict_2 = {}
    for i in range(1, len(time_pt_list)):
        ctrl_time_pt, case_time_pt = time_pt_dict[0][0], time_pt_dict[i][0]
        cd_time_pt_0_comparisons.append(f"{case_time_pt} v {ctrl_time_pt}")

        controls, cases = time_pt_dict[0][1], time_pt_dict[i][1]
        results_df = characteristic_direction(controls, cases)

        up_genes = up_down_from_characteristic_direction(results_df).up
        up_list = []
        for gene in up_genes:
            if "_" in gene:
                up_list.append(gene.split("_", 1)[1])
            else:
                up_list.append(gene)

        down_genes = up_down_from_characteristic_direction(results_df).down
        down_list = []
        for gene in down_genes:
            if "_" in gene:
                down_list.append(gene.split("_", 1)[1])
            else:
                down_list.append(gene)

        # print(up_list)
        # print(down_list)
        print(len(up_list), len(down_list))

        cd_tf_time_dict_2[i-1] = (top_tfs(up_list), top_tfs(down_list))
    print(cd_tf_time_dict_2)

## 2. Defining functions that output the enriched TFs given an input gene set

In [9]:
def run_chea_kg(gene_list, num_tfs):
    """
    Outputs JSON of TF subnetwork best corresponding to input gene list.
    """
    CHEA_KG = 'https://chea-kg.maayanlab.cloud/api/enrichment'

    description = "insert description here"
    payload = {
        'list': (None, "\n".join(gene_list)),
        'description': (None, description)
    }
    response=requests.post(f"{CHEA_KG}/addList", files=payload)
    time.sleep(0.2)
    data = json.loads(response.text)

    q = {
        'min_lib': 3, # minimum number of libraries that a TF must be ranked in
        'libraries': [
            {'library': "Integrated--meanRank", 'term_limit': num_tfs} # edit term_limit to change number of top-ranked TFs
        ],
        'limit':50, # controls number of edges returned - may cause issues with visualization if too large
        'userListId': data['userListId']
    }
    query_json=json.dumps(q)

    res = requests.post(CHEA_KG, data=query_json)
    if res.ok:
        data = json.loads(res.text)
        return data
    else:
        data = None
        return res.text


def top_tfs(gene_list, num_tfs=5):
    """
    Returns a list of the top N most enriched TFs corresponding to an input gene list.
    """
    enriched_tfs = run_chea_kg(gene_list, num_tfs)
    tfs_list = []
    for node in enriched_tfs["nodes"]:
        tfs_list.append(node["data"]["label"])
    return tfs_list

## 3. Defining subnetwork visualization functions

In [10]:
def fetch_chea_kg_data(start_tf, end_tf):
    """
    Outputs JSON data for shortest path connecting two TFs.
    """
    base_url = "https://chea-kg.maayanlab.cloud/api/knowledge_graph"

    query_filter = {
        "start": "Transcription Factor",
        "start_field": "label",
        "start_term": start_tf,
        "end": "Transcription Factor",
        "end_field": "label",
        "end_term": end_tf
    }

    encoded_filter = urllib.parse.quote(str(query_filter).replace("'", '"'))
    full_url = f"{base_url}?filter={encoded_filter}"

    response = requests.get(full_url)
    response.raise_for_status()
    return response.json()


def get_tf_node_info(tf_label):
    """
    Gets node information associated with single TF.
    """
    base_url = "https://chea-kg.maayanlab.cloud/api/knowledge_graph"

    query_filter = {
        "start": "Transcription Factor",
        "start_field": "label",
        "start_term": tf_label
    }
    encoded_filter = urllib.parse.quote(str(query_filter).replace("'", '"'))
    full_url = f"{base_url}?filter={encoded_filter}"

    response = requests.get(full_url)
    response.raise_for_status()
    data = response.json()

    for node in data["nodes"]:
        if node["data"]["label"] == tf_label:
            return node
    raise ValueError(f"Node for TF '{tf_label}' not found in response.")


def get_tf_edge_info(tf_label):
    """
    Returns edge info if TF is autoregulatory.
    """
    base_url = "https://chea-kg.maayanlab.cloud/api/knowledge_graph"

    query_filter = {
        "start": "Transcription Factor",
        "start_field": "label",
        "start_term": tf_label
    }
    encoded_filter = urllib.parse.quote(str(query_filter).replace("'", '"'))
    full_url = f"{base_url}?filter={encoded_filter}"

    response = requests.get(full_url)
    response.raise_for_status()
    data = response.json()

    for edge in data["edges"]:
        if edge["data"]["source_label"] == tf_label and edge["data"]["target_label"] == tf_label:
            return edge
    raise ValueError(f"Self-edge for TF '{tf_label}' not found in response.")


def create_tf_time_series_graph(tf_time_dict, num_comparisons, filename):
    """
    Creates network of TFs given differentially expressed genes at each time point.
    """
    subnetwork = {"nodes": [], "edges": []}
    for time in tf_time_dict.keys():
        up_tfs = tf_time_dict[time][0]
        down_tfs = tf_time_dict[time][1]

        for tf in up_tfs:
            node_info = get_tf_node_info(tf)
            node_info["data"]["color"] = "#80eaff"
            node_info["data"]["id"] = f"{node_info['data']['label']}_up_{time}"
            subnetwork["nodes"].append(node_info)

        for tf in down_tfs:
            node_info = get_tf_node_info(tf)
            node_info["data"]["color"] = "#ff8a80"
            node_info["data"]["id"] = f"{node_info['data']['label']}_down_{time}"
            subnetwork["nodes"].append(node_info)

    for i in range(num_comparisons-1):
        for j, source_tf_list in enumerate(tf_time_dict[i]):
            for k, target_tf_list in enumerate(tf_time_dict[i+1]):
                for source_tf in source_tf_list:
                    for target_tf in target_tf_list:
                        if source_tf != target_tf:
                            data = fetch_chea_kg_data(source_tf, target_tf)
                            if len(data["nodes"]) == 2 and len(data["edges"]) == 1:
                                if data["edges"][0]["data"]["source_label"] == source_tf and \
                                    data["edges"][0]["data"]["target_label"] == target_tf:
                                    if j == 0 and k == 0:
                                        data["edges"][0]["data"]["source"] = f"{source_tf}_up_{i}"
                                        data["edges"][0]["data"]["target"] = f"{target_tf}_up_{i+1}"
                                    elif j == 0 and k == 1:
                                        data["edges"][0]["data"]["source"] = f"{source_tf}_up_{i}"
                                        data["edges"][0]["data"]["target"] = f"{target_tf}_down_{i+1}"
                                    elif j == 1 and k == 0:
                                        data["edges"][0]["data"]["source"] = f"{source_tf}_down_{i}"
                                        data["edges"][0]["data"]["target"] = f"{target_tf}_up_{i+1}"
                                    elif j == 1 and k == 1:
                                        data["edges"][0]["data"]["source"] = f"{source_tf}_down_{i}"
                                        data["edges"][0]["data"]["target"] = f"{target_tf}_down_{i+1}"
                                    subnetwork["edges"].append(data["edges"][0])
                                    print("edge added, case 1")
                        else:
                            if len(source_tf_list) != 0:
                                try:
                                    edge = get_tf_edge_info(source_tf)
                                    if j == 0 and k == 0:
                                        edge["data"]["source"] = f"{source_tf}_up_{i}"
                                        edge["data"]["target"] = f"{target_tf}_up_{i+1}"
                                    elif j == 0 and k == 1:
                                        edge["data"]["source"] = f"{source_tf}_up_{i}"
                                        edge["data"]["target"] = f"{target_tf}_down_{i+1}"
                                    elif j == 1 and k == 0:
                                        edge["data"]["source"] = f"{source_tf}_down_{i}"
                                        edge["data"]["target"] = f"{target_tf}_up_{i+1}"
                                    elif j == 1 and k == 1:
                                        edge["data"]["source"] = f"{source_tf}_down_{i}"
                                        edge["data"]["target"] = f"{target_tf}_down_{i+1}"
                                    subnetwork["edges"].append(edge)
                                    print("edge added, case 2")
                                except ValueError:
                                    continue
    output_path = f"{filename}.json"
    with open(output_path, "w") as outfile:
        json.dump(subnetwork, outfile, indent=4)
    print("FINISHED")
    return output_path


def gmt_to_tf_time_dict(gmt_file):
    """
    Converts GMT file containing DEGs to tf_time_dict.
    """
    with open(gmt_file, 'r') as f:
        lines = f.readlines()

    temp_dict = {}
    for line in lines:
        tokens = line.split("\t\t")
        term = tokens[0]
        genes = [x.split(',')[0].strip() for x in tokens[1].split('\t')]
        temp_dict[term] = top_tfs(genes)
        print("enriched TFs found")

    comparisons = list(temp_dict.keys())
    new_comparisons = []
    for item in comparisons:
        comp = item.rsplit(' ', 2)[0]
        if comp not in new_comparisons:
            new_comparisons.append(comp)
    print("Time point comparisons:", new_comparisons)

    j = 0
    tf_time_dict = {}
    for i in range(len(comparisons) // 2):
        tf_time_dict[i] = (temp_dict[comparisons[j]], temp_dict[comparisons[j+1]])
        j += 2
    print(tf_time_dict)
    return tf_time_dict, new_comparisons

In [11]:
ds3 = 'First set of DEGs from TRAIL-treated TNBC cell line'
ds3
ds4 = 'Second set of DEGS from TRAIL-treated TNBC cell line'
ds4
umap_desc = '''TRAIL-treated TNBC cell line'''
umap_desc

'TRAIL-treated TNBC cell line'

In [12]:
if compute_degs:
    deseq2_tf_time_dict_1, deseq2_adj_time_pt_comparisons = gmt_to_tf_time_dict(deseq2_degs_gmt_1)
    deseq2_subnetwork_data_1 = create_tf_time_series_graph(deseq2_tf_time_dict_1, len(deseq2_adj_time_pt_comparisons), "appyter_deseq2_adj_time_pts_top_5_tfs")

    deseq2_tf_time_dict_2, deseq2_time_pt_0_comparisons = gmt_to_tf_time_dict(deseq2_degs_gmt_2)
    deseq2_subnetwork_data_2 = create_tf_time_series_graph(deseq2_tf_time_dict_2, len(deseq2_time_pt_0_comparisons), "appyter_deseq2_compare_w_time_pt_0_top_5_tfs")

    cd_subnetwork_data_1 = create_tf_time_series_graph(cd_tf_time_dict_1, len(cd_adj_time_pt_comparisons), "appyter_cd_adj_time_pts_top_5_tfs")
    cd_subnetwork_data_2 = create_tf_time_series_graph(cd_tf_time_dict_2, len(cd_time_pt_0_comparisons), "appyter_cd_compare_w_time_pt_0_top_5_tfs")

else:
    tf_time_dict_1, comparisons_1 = gmt_to_tf_time_dict(ds3)
    subnetwork_data_1 = create_tf_time_series_graph(tf_time_dict_1, len(comparisons_1), "appyter_deseq2_adj_time_pts_top_5_tfs")

    tf_time_dict_2, comparisons_2 = gmt_to_tf_time_dict(ds4)
    subnetwork_data_2 = create_tf_time_series_graph(tf_time_dict_2, len(comparisons_2), "appyter_deseq2_compare_w_time_pt_0_top_5_tfs")

enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found
Time point comparisons: ['Hour 1 vs Hour 0', 'Hour 3 vs Hour 1', 'Hour 6 vs Hour 3', 'Hour 12 vs Hour 6', 'Hour 24 vs Hour 12']
{0: (['KLF6', 'FOSB', 'JUN', 'FOS', 'NR4A3'], ['NR4A1', 'BHLHE40', 'ZNF395', 'FOSB', 'ZNF324']), 1: (['BHLHE40', 'ATF3', 'FOSB', 'SNAI1', 'RELB'], ['BHLHE40', 'ZBED3', 'JUN', 'PPARG', 'NR2F2']), 2: (['TEAD1', 'SP3', 'ZBTB38', 'FOXM1', 'UBP1'], ['MAFF', 'HMGN3', 'ATF3', 'GTF3A', 'ZNF511']), 3: (['STAT2', 'SP100', 'BATF2', 'TRAFD1', 'IRF9'], ['GATAD2A', 'E2F1', 'ZBED4', 'ZNF598', 'SRCAP']), 4: (['TEAD1', 'STAT2', 'NFKB2', 'CREB3L2', 'STAT1'], ['ZNF239', 'ZNF146', 'PRMT3', 'FOSL1', 'CEBPZ'])}


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 2


edge added, case 1
edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


FINISHED


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found


enriched TFs found
Time point comparisons: ['Hour 1 vs Hour 0', 'Hour 3 vs Hour 0', 'Hour 6 vs Hour 0', 'Hour 12 vs Hour 0', 'Hour 24 vs Hour 0']
{0: (['KLF6', 'FOSB', 'JUN', 'FOS', 'NR4A3'], ['NR4A1', 'BHLHE40', 'ZNF395', 'FOSB', 'ZNF324']), 1: (['ATF3', 'FOSB', 'JUN', 'SNAI1', 'NR4A3'], ['BHLHE40', 'CREBL2', 'ZNF436', 'CREB3L2', 'JUN']), 2: (['ADNP2', 'PRDM4', 'ZBED4', 'TCF20', 'SRCAP'], ['ELF3', 'ZNF580', 'PPARG', 'IRF9', 'ZNF524']), 3: (['ZNF267', 'ZBTB11', 'NFKB2', 'MGA', 'RELB'], ['E2F1', 'THAP7', 'ZNF837', 'ZNF580', 'TFDP1']), 4: (['TEAD1', 'HIVEP1', 'MGA', 'ASH1L', 'NFAT5'], ['E2F1', 'DRAP1', 'THAP4', 'MYC', 'HMGA1'])}


edge added, case 1
edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 2
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1
edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 1


edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 2


edge added, case 1


edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1
edge added, case 1


edge added, case 1


edge added, case 1
FINISHED


## 4. Visualizing the enriched TFs within a time series regulatory subnetwork
### Blue and red nodes correspond to TFs enriched for upregulated and downregulated gene sets, respectively, at the given time point. 

In [13]:
def export_cytoscape_json(data, filename):
    with open(filename, "w") as f:
        json.dump(data, f, indent=2)
    display(FileLink(filename))


def visualize_network_2(data, filename):
    row_col_dict = {}
    for node in data["nodes"]:
        if "position" in node:
            continue
        node_id = node["data"]["id"]
        row_index = int(node_id.split("_")[-1])
        col_index = row_col_dict.get(row_index, 1)
        row_col_dict[row_index] = col_index + 1
        node["position"] = {"x": col_index * 150, "y": row_index * 150}

    cyto_widget = CytoscapeWidget()
    cyto_widget.graph.add_graph_from_json(data)
    cyto_widget.set_layout(name='preset')

    cyto_widget.set_style([{
        'selector': 'node',
        'style': {
            'label': 'data(label)',
            'background-color': 'data(color)',
            'border-width': 'data(borderWidth)',
            'border-color': 'data(borderColor)',
            'width': 40,
            'height': 40
        }
    }, {
        'selector': '.row-label',
        'style': {
            'label': 'data(label)',
            'background-color': '#ffffff',
            'color': '#000000',
            'font-size': '20px',
            'width': 5,
            'height': 5,
            'text-valign': 'center',
            'text-halign': 'center'
        }
    }, {
        'selector': 'edge',
        'style': {
            'width': 3,
            'line-color': 'data(lineColor)',
            'target-arrow-color': 'data(lineColor)',
            'target-arrow-shape': 'data(directed)',
            'curve-style': 'bezier',
            'arrow-scale': 1.5
        }
    }])

    display(cyto_widget)
    json_filename = f"{filename}.json"
    export_cytoscape_json(data, json_filename)


def run_visualization(input_json, comparisons, filename):
    if isinstance(input_json, dict):
        data = input_json
    else:
        with open(input_json, "r") as input_file:
            data = json.load(input_file)

    nodes = data["nodes"]
    for i, comparison in enumerate(comparisons):
        label_node = {
            'data': {
                'id': f'row_label_{i}',
                'label': comparison,
                'color': '#ffffff',
                'borderColor': '#ffffff',
                'borderWidth': 0
            },
            'classes': 'row-label',
            'position': {"x": 0, "y": i * 150}
        }
        nodes.append(label_node)

    visualize_network_2(data, filename)
    return "VISUALIZATION HAS RUN"

In [14]:
# if compute_degs:
#     print("deseq2_adj_time_pt_visualization:")
#     run_visualization(deseq2_subnetwork_data_1, deseq2_adj_time_pt_comparisons, "deseq2_adj_time_pt_visualization")
#     print("deseq2_compare_w_time_pt_0_visualization:")
#     run_visualization(deseq2_subnetwork_data_2, deseq2_time_pt_0_comparisons, "deseq2_compare_w_time_pt_0_visualization")

#     print("cd_adj_time_pt_visualization:")
#     run_visualization(cd_subnetwork_data_1, cd_adj_time_pt_comparisons, "cd_adj_time_pt_visualization")
#     print("cd_compare_w_time_pt_0_visualization:")
#     run_visualization(cd_subnetwork_data_2, cd_time_pt_0_comparisons, "cd_compare_w_time_pt_0_visualization")
# else:
#     input_json_1 = subnetwork_data_1
#     run_visualization(input_json_1, comparisons_1, "deseq2_adj_time_pt_visualization")

#     input_json_2 = subnetwork_data_2
#     run_visualization(input_json_2, comparisons_2, "deseq2_compare_w_time_pt_0_visualization")

In [15]:
if compute_degs:
    print("deseq2_adj_time_pt_visualization:")
    run_visualization(deseq2_subnetwork_data_1, deseq2_adj_time_pt_comparisons, "deseq2_adj_time_pt_visualization")

In [16]:
if compute_degs:
    print("deseq2_compare_w_time_pt_0_visualization:")
    run_visualization(deseq2_subnetwork_data_2, deseq2_time_pt_0_comparisons, "deseq2_compare_w_time_pt_0_visualization")

In [17]:
if compute_degs:
    print("cd_adj_time_pt_visualization:")
    run_visualization(cd_subnetwork_data_1, cd_adj_time_pt_comparisons, "cd_adj_time_pt_visualization")

In [18]:
if compute_degs:
    print("cd_compare_w_time_pt_0_visualization:")
    run_visualization(cd_subnetwork_data_2, cd_time_pt_0_comparisons, "cd_compare_w_time_pt_0_visualization")

In [19]:
if not compute_degs:
    input_json_1 = subnetwork_data_1
    run_visualization(input_json_1, comparisons_1, "deseq2_adj_time_pt_visualization")

    input_json_2 = subnetwork_data_2
    run_visualization(input_json_2, comparisons_2, "deseq2_compare_w_time_pt_0_visualization")

CytoscapeWidget(cytoscape_layout={'name': 'preset'}, cytoscape_style=[{'selector': 'node', 'style': {'label': …

## 5. Generating a UMAP visualization for enriched TFs

### The enriched transcription factors from each time point are colored on a UMAP plot of 700 TFs identified by ChEA to be "source TFs" (i.e. they exert regulatory effects on other TFs). The UMAP algorithm was performed using the TF-IDF scores of the TFs' target genes, meaning that TFs placed in the same cluster generally regulate similar genes. 

In [20]:
def process_scatterplot(libdict, nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1):
    print("\tTF-IDF vectorizing gene set data...")
    vec = TfidfVectorizer(max_df=maxdf, min_df=mindf)
    X = vec.fit_transform(libdict.values())
    print(X.shape)
    adata = anndata.AnnData(X)
    adata.obs.index = libdict.keys()

    print("\tPerforming Leiden clustering...")
    ### the nneighbors and min_dist parameters can be altered
    sc.pp.neighbors(adata, n_neighbors=nneighbors)
    sc.tl.leiden(adata, resolution=1.0)
    sc.tl.umap(adata, min_dist=mindist, spread=spread, random_state=42)

    new_order = adata.obs.sort_values(by='leiden').index.tolist()
    adata = adata[new_order, :]
    adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')

    df = pd.DataFrame(adata.obsm['X_umap'])
    df.columns = ['x', 'y']

    df['cluster'] = adata.obs['leiden'].values
    df['term'] = adata.obs.index
    df['genes'] = [libdict[l] for l in df['term']]

    return df


def get_scatter_colors(df):
    clusters = pd.unique(df['cluster']).tolist()
    n_clusters = len(clusters)
    gray_shades = [f'#{int(v):02x}{int(v):02x}{int(v):02x}' for v in np.linspace(50, 230, n_clusters)]
    color_mapper = {clusters[i]: gray_shades[i] for i in range(n_clusters)}
    return color_mapper


def generate_df_for_comparison(base_df, tf_pair, comparison_label, comparison_idx):
    """
    Generates a new df for each time point.
    """
    up_tfs, down_tfs = tf_pair[0], tf_pair[1]
    df = base_df.copy()
    color_mapper = get_scatter_colors(df)
    df['color'] = df['cluster'].apply(lambda x: color_mapper[x])
    df['size'] = 6
    df['time_point'] = "Not enriched"

    for idx, term in df['term'].items():
        if (term in up_tfs) and (term not in down_tfs):
            df.at[idx, 'color'] = "#1595f0"
            df.at[idx, 'size'] = 12
            df.at[idx, 'time_point'] = comparison_label
        if (term in down_tfs) and (term not in up_tfs):
            df.at[idx, 'color'] = "#f30a1a"
            df.at[idx, 'size'] = 12
            df.at[idx, 'time_point'] = comparison_label
        if (term in up_tfs) and (term in down_tfs):
            df.at[idx, 'color'] = "#26e411"
            df.at[idx, 'size'] = 12
            df.at[idx, 'time_point'] = comparison_label
    return df


from bokeh.plotting import figure
from bokeh.io.export import export_png
from bokeh.models import ColumnDataSource, HoverTool, Slider, CustomJS, Title, Label
from bokeh.layouts import column
from bokeh.palettes import Greys
from bokeh.io import show
from bokeh.plotting import output_file, save
import os
from PIL import Image


def get_scatterplot(scatterdf, tf_time_dict=None, comparisons=None, legend_description=None, image_dir=None, gif_filename=None):
    """
    Generates images navigable via a slider, as well as all the images separately.
    """
    df = scatterdf.copy()
    df['cluster_number'] = df['cluster'].apply(lambda x: int(x.split(" ")[-1]))
    print(df['cluster_number'])
    df.sort_values(by=['cluster_number'], inplace=True)
    df.drop(columns = ['cluster_number'], inplace=True)

    sources = []
    for i, label in enumerate(comparisons):
        df_comp = generate_df_for_comparison(df, tf_time_dict[i], label, i)
        source = ColumnDataSource(data=dict(x = df_comp['x'], y = df_comp['y'],
                                            gene_set = df_comp['term'], colors = df_comp['color'],
                                            label = df_comp['cluster'], size = df_comp['size'],
                                            time_point = df_comp['time_point']))
        sources.append(source)

    source = sources[0]
    tooltips = [
        ("Gene Set", "@gene_set"),
        ("Cluster", "@label"),
        ("Time point", "@time_point")
    ]

    hover_emb = HoverTool(tooltips=tooltips)
    tools_emb = [hover_emb, 'pan', 'wheel_zoom', 'reset', 'save']

    plot_emb = figure(
        width=500*2,
        height=400*2,
        tools=tools_emb,
        output_backend='canvas'
    )

    plot_emb.scatter(
        'x',
        'y',
        size = 'size',
        source = source,
        marker='circle',
        fill_color = 'colors',
        color='colors',
        legend_group = 'label',
    )

    # hide axis labels and grid lines
    plot_title = Title(text=comparisons[0], align='center')
    plot_title.text_font_size = '20pt'
    plot_title.text_font_style = 'bold'
    plot_emb.add_layout(plot_title, 'above')

    plot_emb.xaxis.major_tick_line_color = None
    plot_emb.xaxis.minor_tick_line_color = None
    plot_emb.yaxis.major_tick_line_color = None
    plot_emb.yaxis.minor_tick_line_color = None
    plot_emb.grid.grid_line_color = None
    plot_emb.xaxis.major_label_text_font_size = '0pt'
    plot_emb.yaxis.major_label_text_font_size = '0pt'

    plot_emb.xaxis.axis_label = "UMAP-1"
    plot_emb.yaxis.axis_label = "UMAP-2"
    plot_emb.xaxis.axis_label_text_font_size = '20pt'
    plot_emb.yaxis.axis_label_text_font_size = '20pt'
    plot_emb.xaxis.axis_label_text_font_style = "normal"
    plot_emb.yaxis.axis_label_text_font_style = "normal"

    plot_emb.legend.label_text_font_size = '18pt'
    plot_emb.legend.glyph_height = 20
    plot_emb.legend.glyph_width = 20

    print("legend", plot_emb.legend[0])
    plot_emb.add_layout(plot_emb.legend[0], 'right')

    plot_emb.min_border_bottom = 168

    description_label = Label(x=0, y=-7, x_units='screen', y_units='screen',
                          text=legend_description,
                          text_font_size='12pt', text_align='left')

    plot_emb.add_layout(description_label, 'below')

    ### adding a slider ###
    slider = Slider(start=0, end=len(sources) - 1, value=0, step=1, title="Comparison")
    comparison_source = ColumnDataSource(data=dict(comparisons=[str(c) for c in comparisons]))
    callback = CustomJS(args=dict(source=source, slider=slider, sources=sources, plot=plot_emb,
                                  comparison_source=comparison_source, title_obj=plot_title), code="""
        const i = slider.value;
        const new_data = sources[i].data;
        const copied_data = {};
        for (const key in new_data) {
            copied_data[key] = [...new_data[key]];  // deep copy each column
        }
        source.data = copied_data;

        const comp_labels = comparison_source.data['comparisons'];
        title_obj.text = comp_labels[i];

        source.change.emit();
    """)
    slider.js_on_change('value', callback)
    show(column(slider, plot_emb))

    ### can either show or save the plot (cannot do both)
    # output_file("top_10_tfs_deseq2_adjacent_time_pts_umap_plot.html")
    # save(column(slider, plot_emb))

    ### for isolated individual time point images ###
    frame_dir = os.path.join(os.getcwd(), image_dir)
    os.makedirs(frame_dir, exist_ok=True)
    for i, label in enumerate(comparisons):
        source.data = dict(sources[i].data)
        plot_title.text = label
        export_png(plot_emb, filename=os.path.join(frame_dir, f"frame_{i:02d}_{label}.png"))

    frame_paths = sorted([os.path.join(frame_dir, f) for f in os.listdir(frame_dir) if f.endswith(".png")])
    images = [Image.open(frame) for frame in frame_paths]
    images[0].save(gif_filename, save_all=True, append_images=images[1:], duration=1500, loop=0)
    # display(IPyImage(filename=gif_filename))
    display(FileLink(gif_filename, result_html_prefix="Click here to download: "))
    print("CREATED GIF")

    return plot_emb, source

In [21]:
r = requests.get("https://minio.dev.maayanlab.cloud/hgrn-chear/network_target_sets.gmt")
file = r.text.split("\n")

lib_dict = OrderedDict()
for line in file[:-1]:
    tokens = line.split("\t\t")
    term = tokens[0]
    genes = [x.split(',')[0].strip() for x in tokens[1].split('\t')]
    lib_dict[term] = ' '.join(genes)

## defaults: nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1
scatter_df = process_scatterplot(
    lib_dict,
    nneighbors=20,
    mindist=0.15,
)

legend_description = ("Blue dots are TFs enriched for upregulated DEGs.\n"
    "Red dots are TFs enriched for downregulated DEGs.\n"
    "Green dots are TFs enriched for both up- and downregulated DEGs.\n")
legend_description += umap_desc

if compute_degs:
    deseq2_plot_emb_1, deseq2_source_1 = get_scatterplot(scatter_df, deseq2_tf_time_dict_1, deseq2_adj_time_pt_comparisons, legend_description, "umap_png_frames_deseq2_adjacent_time_pts_top_5_tfs", "top_5_tfs_deseq2_adjacent_time_pts_umap.gif")
    deseq2_plot_emb_2, deseq2_source_2 = get_scatterplot(scatter_df, deseq2_tf_time_dict_2, deseq2_time_pt_0_comparisons, legend_description, "umap_png_frames_deseq2_compare_w_time_pt_0_top_5_tfs", "top_5_tfs_deseq2_compare_w_time_pt_0_umap.gif")

    cd_plot_emb_1, cd_source_1 = get_scatterplot(scatter_df, cd_tf_time_dict_1, cd_adj_time_pt_comparisons, legend_description, "umap_png_frames_cd_adjacent_time_pts_top_5_tfs", "top_5_tfs_cd_adjacent_time_pts_umap.gif")
    cd_plot_emb_2, cd_source_2 = get_scatterplot(scatter_df, cd_tf_time_dict_2, cd_time_pt_0_comparisons, legend_description, "umap_png_frames_cd_compare_w_time_pt_0_top_5_tfs", "top_5_tfs_cd_compare_w_time_pt_0_umap.gif")
else:
    plot_emb_1, source_1 = get_scatterplot(scatter_df, tf_time_dict_1, comparisons_1, legend_description, "umap_png_frames_deseq2_adjacent_time_pts_top_5_tfs", "top_5_tfs_deseq2_adjacent_time_pts_umap.gif")
    plot_emb_2, source_2 = get_scatterplot(scatter_df, tf_time_dict_2, comparisons_2, legend_description, "umap_png_frames_deseq2_compare_w_time_pt_0_top_5_tfs", "top_5_tfs_deseq2_compare_w_time_pt_0_umap.gif")

	TF-IDF vectorizing gene set data...
(699, 1550)
	Performing Leiden clustering...


         Falling back to preprocessing with `sc.pp.pca` and default params.
  X = _choose_representation(self._adata, use_rep=use_rep, n_pcs=n_pcs)


0       0
1       0
2       0
3       0
4       0
       ..
694    15
695    15
696    15
697    15
698    15
Name: cluster_number, Length: 699, dtype: int64
legend Legend(id='p1053', ...)


  adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')


CREATED GIF
0       0
1       0
2       0
3       0
4       0
       ..
694    15
695    15
696    15
697    15
698    15
Name: cluster_number, Length: 699, dtype: int64
legend Legend(id='p1158', ...)


CREATED GIF
