In [None]:
#%%appyter init
# Calls appyter's `magic.init` with a zero-argument callable that returns this
# notebook's `globals()` (the default argument binds the `globals` builtin).
from appyter import magic
magic.init(lambda _=globals: _())

In [None]:
%%appyter hide_code

# Form layout: the intro banner (title, image, description) followed by the
# section that holds the study-title input.
{% do SectionField(
    name= 'appyter_intro',
    title= 'ChEA-KG-TS Appyter',
    img= 'time series.png'
) %}

{% do DescriptionField(
    name= 'appyter_description',
    text= 'The ChEA-KG Time Series Appyter visualizes the enriched transcription factor landscape from time series RNA-seq data, helping to inform hypotheses about the regulatory mechanisms governing biological processes.',
    section= 'appyter_intro'
) %}

{% do SectionField(
    name= 'title_section',
    title= 'Provide a title for your study.'
) %}

In [None]:
%%appyter code_eval

# Free-text title field; the rendered value is bound to `study_title`, which a
# later cell displays as the report's "Study title" heading.
{% set study_title = TextField(
        name= 'user_inputted_study_title',
        label= 'Title',
        default= '',
        section = 'title_section',
) %}

study_title = {{study_title}}

# <u>ChEA-KG-TS Appyter Report:</u> Temporal Transitions of Transcription Factor Regulatory Modules

Given time series RNA-seq data from either humans or mice, this Appyter first determines which transcription factors are enriched for the differentially expressed genes (DEGs) found at each time point. DEGs are computed using [DESeq2](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-014-0550-8) (Love et al., 2014) or [Characteristic Direction](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-15-79) (Clark et al., 2014) by comparing gene expression at adjacent time points (*t0 vs t1,  t1 vs t2,  t2 vs t3,  ... , tn-1 vs tn*) or comparing to the initial time point (*t0 vs t1,  t0 vs t2,  t0 vs t3,  ... , t0 vs tn*). Thus, there are 4 different options for computing DEGs. Next, the enriched transcription factors for each set of DEGs are determined using [ChEA3](https://maayanlab.cloud/chea3/) (Keenan et al., 2019). 

Afterwards, the Appyter plots a regulatory subnetwork using [ChEA-KG](https://chea-kg.maayanlab.cloud/) (project led by Anna Byrd) of the enriched TFs at each time point, enabling users to determine how the enriched TFs at one time point regulate the TFs at the subsequent time point. A UMAP plot of the enriched TFs is also generated, providing users with a visualization of enriched TFs that act in modules based on their shared target genes. 

If users have pre-computed the up- and downregulated DEGs at each time point, they can opt to upload a GMT file of those DEGs to *directly* visualize the enriched TFs. 

In [None]:
from IPython.display import display, Markdown
# Render the user-supplied study title as a top-level Markdown heading.
display(Markdown(f"# __Study title:__ {study_title}"))

## Loading the necessary packages

In [None]:
# general
import pandas as pd
import numpy as np
import json
import requests
import urllib
from IPython.display import display, FileLink, HTML, Markdown
import time
import os

# differential gene expression
from maayanlab_bioinformatics.dge import deseq2_differential_expression
from maayanlab_bioinformatics.dge import characteristic_direction
from maayanlab_bioinformatics.dge import up_down_from_characteristic_direction

# time series enriched TF subnetwork visualization
from ipycytoscape import CytoscapeWidget
from dash import html

# bar chart
import plotly.graph_objects as go
import plotly.io as pio
import uuid
import html
from pathlib import Path

# enriched TF UMAP visualization
from sklearn.feature_extraction.text import TfidfVectorizer
import scanpy as sc
import anndata
from copy import deepcopy
from collections import OrderedDict
from bokeh.io import output_notebook, show
from bokeh.io.export import export_png
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, HoverTool, Slider, CustomJS, Title, Label
from bokeh.layouts import column
from PIL import Image
output_notebook()

## Step 1. Computing DEGs for the raw gene counts matrix

DEGs are computed using [DESeq2](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-014-0550-8) (Love et al., 2014) or [Characteristic Direction](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-15-79) (Clark et al., 2014) by comparing gene expression at adjacent time points (*t0 vs t1,  t1 vs t2,  t2 vs t3,  ... , tn-1 vs tn*) or comparing to the initial time point (*t0 vs t1,  t0 vs t2,  t0 vs t3, ... , t0 vs tn*). All in all, there are 4 different methods of computing DEGs.

Upload fields are for (1) a raw gene counts matrix and (2) experiment metadata from a time series RNA-seq experiment. 

In [None]:
%%appyter hide_code

# Declares the form sections for the raw time series RNA-seq upload workflow;
# the input fields attached to section1-section5 are declared in the next cell.
{% do SectionField(
    name= 'section0',
    title= 'Upload files to this section ONLY if you are uploading raw time series RNA-seq data.',
    img= 'matrix.png'
) %}

{% do SectionField(
    name= 'section1',
    title= '1. Upload raw gene counts matrix as a CSV file.',
    img= 'mRNA.png'
) %}

{% do SectionField(
    name= 'section2',
    title= '2. Upload experiment metadata as a CSV file.',
    img= 'experiment.png'
) %}

{% do SectionField(
    name= 'section3',
    title= '3. Choose whether you would like to calculate differentially expressed genes using DESeq2, CD, or both.'
) %}

{% do SectionField(
    name= 'section4',
    title= '4. Choose the number of top enriched TFs to include in the visualizations.'
) %}

{% do SectionField(
    name= 'section5',
    title= '5. Provide a short description of your study (optional).'
) %}

In [None]:
%%appyter code_eval

# Input fields for the raw-data workflow. Each rendered field value is bound to
# a notebook variable that the cells below read.
{% do DescriptionField(
    name= 'matrix_description',
    text= 'Ensure that the first column is titled "gene_name" or "gene_id" and contains the gene names. The example file has {id}_{name}, which is also an acceptable notation.',
    section= 'section1'
) %}

{% set dataset1 = FileField(
    name= 'dataset1',
    label= 'Raw gene counts matrix',
    default= '',
    examples= {'Raw gene counts from TRAIL-treated TNBC cell line': 'https://minio.dev.maayanlab.cloud/chea-kg-timeseries/GSE271120_RawCountFile_rsemgenes.CCBR1062.csv'},
    section= 'section1'
) %}

# Raw gene counts matrix file ('' when nothing was uploaded).
ds1 = {{dataset1}}

{% do DescriptionField(
    name= 'metadata_description',
    text= 'Ensure that the columns are labeled exactly as "sample_name" and "time_pt_annotation". See example for how to format the experiment metadata file.',
    section= 'section2'
) %}

{% set dataset2 = FileField(
    name= 'dataset2',
    label= 'Experiment metadata',
    default= '',
    examples= {'Samples taken from TRAIL-treated TNBC cell line': 'https://minio.dev.maayanlab.cloud/chea-kg-timeseries/sample_metadata.csv'},
    section= 'section2'
) %}

# Experiment metadata file ('' when nothing was uploaded).
ds2 = {{dataset2}}

{% set chosen_method = ChoiceField(
    name= 'chosen_dge_method',
    label= 'Differential gene expression method(s)',
    default= 'DESeq2',
    choices= {'DESeq2': '1', 'Characteristic Direction': '2', 'Both': '3'},
    section= 'section3'
) %}

# Later cells compare dge_method against '1' (DESeq2), '2' (Characteristic
# Direction) and '3' (both), hence the conversion to str.
dge_method = {{chosen_method}}
dge_method = str(dge_method)

{% set chosen_num = ChoiceField(
    name= 'chosen_num_tfs',
    label= 'Number of top enriched TFs',
    default= '5',
    choices= {'5': '1', '10': '2'},
    section= 'section4'
) %}

num_tfs = {{chosen_num}}
num_tfs = str(num_tfs)

{% set user_description = TextField(
        name= 'user_description',
        label= 'Description',
        default= '',
        section = 'section5',
) %}

umap_desc = {{user_description}}

# DEGs are computed from raw data only when both the counts matrix and the
# metadata file were provided.
compute_degs = False
if ds1 != '' and ds2 != '':
    compute_degs = True

# Translate the choice code ('1' / '2') into the actual number of TFs (5 / 10).
if num_tfs == '1':
    num_tfs = 5
if num_tfs == '2':
    num_tfs = 10

### Differential Expression Method Option 1: DESeq2
[PyDESeq2](https://pydeseq2.readthedocs.io/en/stable/), the Python implementation of DESeq2 (Love et al., 2014), is used to compute the DEGs of time series RNA-seq data given a user-inputted raw gene counts matrix and sample metadata file describing the time points. 

DEGs are computed by comparing gene expression at adjacent time points (*t0 vs t1,  t1 vs t2,  t2 vs t3,  ... , tn-1 vs tn*) or with the initial time point (*t0 vs t1,  t0 vs t2,  t0 vs t3,  ... , t0 vs tn*).

In [None]:
if compute_degs:
    # Load the uploaded raw counts matrix and the sample annotations.
    raw_counts = pd.read_csv(ds1)
    sample_metadata = pd.read_csv(ds2)

    # The first column must hold the gene identifiers. Fail early with a clear
    # message instead of leaving `time_pt_dict` empty, which would only surface
    # as a KeyError in the differential expression cells.
    gene_col = raw_counts.columns[0]
    if gene_col not in ("gene_id", "gene_name"):
        raise ValueError(
            "The first column of the raw gene counts matrix must be titled "
            f"'gene_id' or 'gene_name', but it is titled '{gene_col}'."
        )

    # Unique time point labels, in order of first appearance in the metadata.
    time_pt_list = list(dict.fromkeys(sample_metadata["time_pt_annotation"].tolist()))

    # Maps the position of each time point to a (label, counts) pair, where
    # counts holds only that time point's sample columns, indexed by gene.
    time_pt_dict = {}
    for i, time_pt in enumerate(time_pt_list):
        samples_at_time_pt = sample_metadata.loc[sample_metadata["time_pt_annotation"] == time_pt, "sample_name"].astype(str).tolist()
        subset_counts = raw_counts[[gene_col] + samples_at_time_pt]
        rev_subset_counts = subset_counts.set_index(gene_col)
        time_pt_dict[i] = (time_pt, rev_subset_counts)

In [None]:
### [DESeq2] finding the DEGs from adjacent time pt comparisons
if compute_degs and (dge_method == '1' or dge_method == '3'):
    # Adjusted p-value cutoffs, tried from most to least permissive:
    # 0.05, 1e-2, 1e-3, ..., 1e-20.
    p_vals = [0.05] + [10**(-exponent) for exponent in range(2, 21, 1)]

    adj_time_pt_comparisons = []  # "<ctrl> v <case>" labels, one per comparison
    adj_time_pt_degs = []         # paths of the per-comparison DEG CSV files
    for i in range(len(time_pt_list) - 1):
        # Time point i is the control and time point i+1 is the case.
        controls, cases = time_pt_dict[i][1], time_pt_dict[i+1][1]
        results_df = deseq2_differential_expression(controls, cases)

        significant_genes = results_df[results_df["padj"] < p_vals[0]]
        up_count = (significant_genes["log2FoldChange"] > 0).sum()
        down_count = (significant_genes["log2FoldChange"] < 0).sum()

        # Tighten the cutoff while more than 2000 DEGs remain. The bound on idx
        # stops at the strictest cutoff instead of indexing past the end of
        # p_vals when the gene set is still too large at 1e-20.
        idx = 1
        while (up_count + down_count) > 2000 and idx < len(p_vals):
            significant_genes = results_df[results_df["padj"] < p_vals[idx]]
            up_count = (significant_genes["log2FoldChange"] > 0).sum()
            down_count = (significant_genes["log2FoldChange"] < 0).sum()
            idx += 1

        ctrl_time_pt, case_time_pt = time_pt_dict[i][0], time_pt_dict[i+1][0]
        adj_time_pt_comparisons.append(f"{ctrl_time_pt} v {case_time_pt}")

        # Note: the file name lists the case time point first, while the
        # comparison label lists the control first.
        file = f"deseq2_{case_time_pt}_v_{ctrl_time_pt}.csv"
        significant_genes.to_csv(file)
        adj_time_pt_degs.append(file)

        # p_vals[idx-1] is the cutoff that produced the reported gene set.
        print(f"time point comparison: {ctrl_time_pt}_v_{case_time_pt}")
        print(f"total DEGs: {up_count + down_count}, up genes: {up_count}, down genes: {down_count}, padj: {p_vals[idx-1]}")
        print()

In [None]:
### [DESeq2] finding the DEGs from time pt 0 comparisons
if compute_degs and (dge_method == '1' or dge_method == '3'):
    # Adjusted p-value cutoffs, tried from most to least permissive:
    # 0.05, 1e-2, 1e-3, ..., 1e-20.
    p_vals = [0.05] + [10**(-exponent) for exponent in range(2, 21, 1)]

    time_pt_0_comparisons = []  # "<ctrl> v <case>" labels, one per comparison
    time_pt_0_degs = []         # paths of the per-comparison DEG CSV files
    for i in range(1, len(time_pt_list)):
        # The first time point is always the control; time point i is the case.
        controls, cases = time_pt_dict[0][1], time_pt_dict[i][1]
        results_df = deseq2_differential_expression(controls, cases)

        significant_genes = results_df[results_df["padj"] < p_vals[0]]
        up_count = (significant_genes["log2FoldChange"] > 0).sum()
        down_count = (significant_genes["log2FoldChange"] < 0).sum()

        # Tighten the cutoff while more than 2000 DEGs remain. The bound on idx
        # stops at the strictest cutoff instead of indexing past the end of
        # p_vals when the gene set is still too large at 1e-20.
        idx = 1
        while (up_count + down_count) > 2000 and idx < len(p_vals):
            significant_genes = results_df[results_df["padj"] < p_vals[idx]]
            up_count = (significant_genes["log2FoldChange"] > 0).sum()
            down_count = (significant_genes["log2FoldChange"] < 0).sum()
            idx += 1

        ctrl_time_pt, case_time_pt = time_pt_dict[0][0], time_pt_dict[i][0]
        time_pt_0_comparisons.append(f"{ctrl_time_pt} v {case_time_pt}")

        # Note: the file name lists the case time point first, while the
        # comparison label lists the control first.
        file = f"deseq2_{case_time_pt}_v_{ctrl_time_pt}.csv"
        significant_genes.to_csv(file)
        time_pt_0_degs.append(file)

        # p_vals[idx-1] is the cutoff that produced the reported gene set.
        print(f"time point comparison: {ctrl_time_pt}_v_{case_time_pt}")
        print(f"total DEGs: {up_count + down_count}, up genes: {up_count}, down genes: {down_count}, padj: {p_vals[idx-1]}")
        print()

In [None]:
def up_gene_list(input_csv, filename=None):
    """
    Lists the upregulated DEGs found in a DESeq2 results CSV.

    A row counts as upregulated when its "log2FoldChange" is greater than 0.
    Gene identifiers are read from the first column; for identifiers that
    contain an underscore, only the text after the first underscore is kept
    (e.g. "{id}_{name}" becomes "{name}").

    Args:
        input_csv: path of the DESeq2 results CSV.
        filename: accepted but not used.

    Returns:
        List of gene names, in file order.
    """
    results = pd.read_csv(input_csv)
    gene_column = results.columns[0]
    upregulated = results[gene_column][results["log2FoldChange"] > 0]
    return [
        identifier.split("_", 1)[1] if "_" in identifier else identifier
        for identifier in upregulated
    ]


def down_gene_list(input_csv, filename=None):
    """
    Lists the downregulated DEGs found in a DESeq2 results CSV.

    A row counts as downregulated when its "log2FoldChange" is less than 0.
    Gene identifiers are read from the first column; for identifiers that
    contain an underscore, only the text after the first underscore is kept
    (e.g. "{id}_{name}" becomes "{name}").

    Args:
        input_csv: path of the DESeq2 results CSV.
        filename: accepted but not used.

    Returns:
        List of gene names, in file order.
    """
    results = pd.read_csv(input_csv)
    gene_column = results.columns[0]
    downregulated = results[gene_column][results["log2FoldChange"] < 0]
    return [
        identifier.split("_", 1)[1] if "_" in identifier else identifier
        for identifier in downregulated
    ]


def csv_to_gmt(input_csv_list, comparisons, filename):
    """
    Writes the up and down DEGs of several DESeq2 result CSVs into one GMT file.

    For the CSV at position n, two rows are written, labelled
    "<comparisons[n]> up genes" and "<comparisons[n]> down genes". Each row is
    the label, an empty description field, and the tab-separated gene names.

    Args:
        input_csv_list: paths of the DESeq2 result CSVs.
        comparisons: comparison labels, aligned with input_csv_list.
        filename: path of the GMT file to write.

    Returns:
        The path that was written (filename).
    """
    gene_sets = {}
    for position, csv_path in enumerate(input_csv_list):
        upregulated = up_gene_list(csv_path)
        downregulated = down_gene_list(csv_path)
        print(f"up genes: {len(upregulated)}, down genes: {len(downregulated)}")
        gene_sets[f"{comparisons[position]} up genes"] = upregulated
        gene_sets[f"{comparisons[position]} down genes"] = downregulated

    with open(filename, "w") as gmt_handle:
        # The double tab leaves the GMT description field empty.
        gmt_handle.writelines(
            str(term) + "\t\t" + "\t".join(genes) + "\n"
            for term, genes in gene_sets.items()
        )
    print("Finished creating the GMT file containing the PyDESeq2 DEGs at each time point.")
    return filename

# Collect the per-comparison DESeq2 CSV files into one GMT file per comparison
# scheme (adjacent time points, and comparisons with the first time point).
if compute_degs and (dge_method == '1' or dge_method == '3'):
    deseq2_degs_gmt_1 = csv_to_gmt(adj_time_pt_degs, adj_time_pt_comparisons, "appyter_deseq2_adj_time_pt_degs.gmt")
    deseq2_degs_gmt_2 = csv_to_gmt(time_pt_0_degs, time_pt_0_comparisons, "appyter_deseq2_compare_w_time_pt_0_degs.gmt")

### Differential Expression Method Option 2: Characteristic Direction

The [Characteristic Direction](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-15-79) method (Clark et al., 2014) is used to compute the DEGs of time series RNA-seq data given a user-inputted raw gene counts matrix and sample metadata file describing the time points. 

DEGs are computed by comparing gene expression at adjacent time points (*t0 vs t1,  t1 vs t2,  t2 vs t3,  ... , tn-1 vs tn*) or with the initial time point (*t0 vs t1,  t0 vs t2,  t0 vs t3,  ... , t0 vs tn*).

In [None]:
### [CD] finding the DEGs from adjacent time pt comparisons
def run_chea_kg(gene_list, num_tfs):
    """
    Queries the ChEA-KG enrichment API for the TF subnetwork of a gene list.

    The gene list is first registered through the "addList" endpoint, and the
    returned list id is then used in the enrichment query.

    Args:
        gene_list: gene names; sent as one newline-separated string.
        num_tfs: value sent as the "term_limit" of the "Integrated--meanRank"
            library (number of top-ranked TFs).

    Returns:
        The decoded JSON response when the enrichment request succeeds,
        otherwise the raw response text.
    """
    CHEA_KG = 'https://chea-kg.maayanlab.cloud/api/enrichment'

    # Step 1: upload the gene list as multipart form fields to obtain its id.
    form_fields = {
        'list': (None, "\n".join(gene_list)),
        'description': (None, "insert description here"),
    }
    add_list_response = requests.post(f"{CHEA_KG}/addList", files=form_fields)
    time.sleep(0.2)
    user_list_id = json.loads(add_list_response.text)['userListId']

    # Step 2: request the enrichment subnetwork for the uploaded list.
    enrichment_query = json.dumps({
        'min_lib': 3,  # minimum number of libraries that a TF must be ranked in
        'libraries': [
            {'library': "Integrated--meanRank", 'term_limit': num_tfs},  # term_limit = number of top-ranked TFs
        ],
        'limit': 50,  # number of edges returned; large values may cause visualization issues
        'userListId': user_list_id,
    })
    enrichment_response = requests.post(CHEA_KG, data=enrichment_query)
    if not enrichment_response.ok:
        return enrichment_response.text
    return json.loads(enrichment_response.text)


def top_tfs(gene_list, num_tfs=5):
    """
    Returns a list of the top N most enriched TFs corresponding to an input gene list.

    Args:
        gene_list: gene names submitted to ChEA-KG through run_chea_kg.
        num_tfs: number of top-ranked TFs requested (default 5).

    Returns:
        The "label" of every node in the subnetwork returned by run_chea_kg.

    Raises:
        RuntimeError: if run_chea_kg did not return a subnetwork. On a failed
            request it hands back the raw response text, which would otherwise
            fail below with an unrelated "string indices" TypeError.
    """
    enriched_tfs = run_chea_kg(gene_list, num_tfs)
    if not isinstance(enriched_tfs, dict) or "nodes" not in enriched_tfs:
        raise RuntimeError(
            f"ChEA-KG enrichment failed for a list of {len(gene_list)} genes: {enriched_tfs}"
        )
    return [node["data"]["label"] for node in enriched_tfs["nodes"]]


if compute_degs and (dge_method == '2' or dge_method == '3'):
    cd_adj_time_pt_comparisons = []  # "<ctrl> v <case>" labels
    cd_degs_1 = []                   # (up genes, down genes) per comparison
    cd_tf_time_dict_1 = {}           # comparison index -> (TFs for up genes, TFs for down genes)
    for i in range(len(time_pt_list) - 1):
        # Each time_pt_dict entry is a (label, counts) pair; time point i is
        # the control and time point i+1 is the case.
        ctrl_time_pt, controls = time_pt_dict[i]
        case_time_pt, cases = time_pt_dict[i+1]
        cd_adj_time_pt_comparisons.append(f"{ctrl_time_pt} v {case_time_pt}")

        results_df = characteristic_direction(controls, cases)

        # For identifiers containing "_", keep only the text after the first one.
        up_genes = up_down_from_characteristic_direction(results_df).up
        up_list = [gene.split("_", 1)[1] if "_" in gene else gene for gene in up_genes]

        down_genes = up_down_from_characteristic_direction(results_df).down
        down_list = [gene.split("_", 1)[1] if "_" in gene else gene for gene in down_genes]

        cd_degs_1.append((up_list, down_list))
        cd_tf_time_dict_1[i] = (top_tfs(up_list, num_tfs), top_tfs(down_list, num_tfs))

        print(f"time point comparison: {ctrl_time_pt}_v_{case_time_pt}")
        print(f"total DEGs: {len(up_list) + len(down_list)}, up genes: {len(up_list)}, down genes: {len(down_list)}")
        print()

In [None]:
### [CD] finding the DEGs from time pt 0 comparisons
if compute_degs and (dge_method == '2' or dge_method == '3'):
    cd_time_pt_0_comparisons = []  # "<ctrl> v <case>" labels
    cd_degs_2 = []                 # (up genes, down genes) per comparison
    cd_tf_time_dict_2 = {}         # comparison index -> (TFs for up genes, TFs for down genes)
    for i in range(1, len(time_pt_list)):
        # Each time_pt_dict entry is a (label, counts) pair; the first time
        # point is always the control and time point i is the case.
        ctrl_time_pt, controls = time_pt_dict[0]
        case_time_pt, cases = time_pt_dict[i]
        cd_time_pt_0_comparisons.append(f"{ctrl_time_pt} v {case_time_pt}")

        results_df = characteristic_direction(controls, cases)

        # For identifiers containing "_", keep only the text after the first one.
        up_genes = up_down_from_characteristic_direction(results_df).up
        up_list = [gene.split("_", 1)[1] if "_" in gene else gene for gene in up_genes]

        down_genes = up_down_from_characteristic_direction(results_df).down
        down_list = [gene.split("_", 1)[1] if "_" in gene else gene for gene in down_genes]

        cd_degs_2.append((up_list, down_list))
        # Keys start at 0 although the loop index starts at 1.
        cd_tf_time_dict_2[i-1] = (top_tfs(up_list, num_tfs), top_tfs(down_list, num_tfs))

        print(f"time point comparison: {ctrl_time_pt}_v_{case_time_pt}")
        print(f"total DEGs: {len(up_list) + len(down_list)}, up genes: {len(up_list)}, down genes: {len(down_list)}")
        print()

## Step 2. Performing TF enrichment analysis

The DEGs computed in Step 1 are submitted to [ChEA3](https://maayanlab.cloud/chea3/) (Keenan et al., 2019) for TF enrichment analysis. The output from this step is a set of enriched TFs for each up and down gene set at each time point. 

In [None]:
def run_chea_kg(gene_list, num_tfs):
    """
    Queries the ChEA-KG enrichment API for the TF subnetwork of a gene list.

    This cell redefines the function of the same name declared in the
    Characteristic Direction cell above; the two definitions differ only in
    the description string sent with the gene list.

    Args:
        gene_list: gene names; sent as one newline-separated string.
        num_tfs: value sent as the "term_limit" of the "Integrated--meanRank"
            library (number of top-ranked TFs).

    Returns:
        The decoded JSON response when the enrichment request succeeds,
        otherwise the raw response text.
    """
    CHEA_KG = 'https://chea-kg.maayanlab.cloud/api/enrichment'

    # Step 1: upload the gene list as multipart form fields to obtain its id.
    form_fields = {
        'list': (None, "\n".join(gene_list)),
        'description': (None, "searching for enriched TFs"),
    }
    add_list_response = requests.post(f"{CHEA_KG}/addList", files=form_fields)
    time.sleep(0.2)
    user_list_id = json.loads(add_list_response.text)['userListId']

    # Step 2: request the enrichment subnetwork for the uploaded list.
    enrichment_query = json.dumps({
        'min_lib': 3,  # minimum number of libraries that a TF must be ranked in
        'libraries': [
            {'library': "Integrated--meanRank", 'term_limit': num_tfs},  # term_limit = number of top-ranked TFs
        ],
        'limit': 50,  # number of edges returned; large values may cause visualization issues
        'userListId': user_list_id,
    })

    enrichment_response = requests.post(CHEA_KG, data=enrichment_query)
    if not enrichment_response.ok:
        return enrichment_response.text
    return json.loads(enrichment_response.text)


def top_tfs(gene_list, num_tfs=5):
    """
    Returns a list of the top N most enriched TFs corresponding to an input gene list.

    Args:
        gene_list: gene names submitted to ChEA-KG through run_chea_kg.
        num_tfs: number of top-ranked TFs requested (default 5).

    Returns:
        The "label" of every node in the subnetwork returned by run_chea_kg.

    Raises:
        RuntimeError: if run_chea_kg did not return a subnetwork. On a failed
            request it hands back the raw response text, which would otherwise
            fail below with an unrelated "string indices" TypeError.
    """
    enriched_tfs = run_chea_kg(gene_list, num_tfs)
    if not isinstance(enriched_tfs, dict) or "nodes" not in enriched_tfs:
        raise RuntimeError(
            f"ChEA-KG enrichment failed for a list of {len(gene_list)} genes: {enriched_tfs}"
        )
    return [node["data"]["label"] for node in enriched_tfs["nodes"]]

## Step 3. Constructing the regulatory subnetwork

In step 3, a network that connects the enriched TFs at each time point is constructed. A JSON file describing the regulatory subnetwork of enriched TFs is created. Edges between enriched TF nodes are determined using [ChEA-KG](https://chea-kg.maayanlab.cloud/) (project led by Anna Byrd). 

In [None]:
def fetch_chea_kg_data(start_tf, end_tf):
    """
    Outputs JSON data for shortest path connecting two TFs.

    Args:
        start_tf: label of the source transcription factor.
        end_tf: label of the target transcription factor.

    Returns:
        The decoded JSON body of the ChEA-KG knowledge graph response.

    Raises:
        requests.HTTPError: if the API answers with an error status.
    """
    base_url = "https://chea-kg.maayanlab.cloud/api/knowledge_graph"

    query_filter = {
        "start": "Transcription Factor",
        "start_field": "label",
        "start_term": start_tf,
        "end": "Transcription Factor",
        "end_field": "label",
        "end_term": end_tf
    }

    # Serialize with json.dumps so the filter stays valid JSON even when a
    # label contains a quote character (rewriting the dict repr does not).
    encoded_filter = urllib.parse.quote(json.dumps(query_filter))
    full_url = f"{base_url}?filter={encoded_filter}"

    response = requests.get(full_url)
    response.raise_for_status()
    return response.json()


def get_tf_node_info(tf_label):
    """
    Gets node information associated with single TF.

    Args:
        tf_label: label of the transcription factor to look up.

    Returns:
        The first node of the ChEA-KG response whose "label" equals tf_label.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        ValueError: if no node with that label is present in the response.
    """
    base_url = "https://chea-kg.maayanlab.cloud/api/knowledge_graph"

    query_filter = {
        "start": "Transcription Factor",
        "start_field": "label",
        "start_term": tf_label
    }
    # Serialize with json.dumps so the filter stays valid JSON even when the
    # label contains a quote character (rewriting the dict repr does not).
    encoded_filter = urllib.parse.quote(json.dumps(query_filter))
    full_url = f"{base_url}?filter={encoded_filter}"

    response = requests.get(full_url)
    response.raise_for_status()
    data = response.json()

    for node in data["nodes"]:
        if node["data"]["label"] == tf_label:
            return node
    raise ValueError(f"Node for TF '{tf_label}' not found in response.")


def get_tf_edge_info(tf_label):
    """
    Returns edge info if TF is autoregulatory.

    Args:
        tf_label: label of the transcription factor to look up.

    Returns:
        The first edge of the ChEA-KG response whose "source_label" and
        "target_label" both equal tf_label.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        ValueError: if the response contains no such self-edge.
    """
    base_url = "https://chea-kg.maayanlab.cloud/api/knowledge_graph"

    query_filter = {
        "start": "Transcription Factor",
        "start_field": "label",
        "start_term": tf_label
    }
    # Serialize with json.dumps so the filter stays valid JSON even when the
    # label contains a quote character (rewriting the dict repr does not).
    encoded_filter = urllib.parse.quote(json.dumps(query_filter))
    full_url = f"{base_url}?filter={encoded_filter}"

    response = requests.get(full_url)
    response.raise_for_status()
    data = response.json()

    for edge in data["edges"]:
        if edge["data"]["source_label"] == tf_label and edge["data"]["target_label"] == tf_label:
            return edge
    raise ValueError(f"Self-edge for TF '{tf_label}' not found in response.")


def create_tf_time_series_graph(tf_time_dict, num_comparisons, filename):
    """
    Builds the time series TF subnetwork and saves it as "<filename>.json".

    Nodes: one per TF listed in tf_time_dict, with id "<label>_up_<key>" and
    color "#80eaff" for TFs in position 0 of the entry, or "<label>_down_<key>"
    and color "#ff8a80" for TFs in position 1.

    Edges: for consecutive keys i and i+1 (i from 0 to num_comparisons-2),
    every TF at i is paired with every TF at i+1. Different TFs (case 1) get an
    edge when fetch_chea_kg_data returns exactly two nodes and one edge that
    points from the source TF to the target TF. The same TF on both sides
    (case 2) gets an edge when get_tf_edge_info finds a self-edge.

    Args:
        tf_time_dict: maps comparison index to (TFs for up genes, TFs for down genes).
        num_comparisons: number of comparisons (entries) to connect.
        filename: output path without the ".json" extension.

    Returns:
        Path of the JSON file that was written.
    """
    direction_names = {0: "up", 1: "down"}

    def attach_endpoints(edge, source_tf, target_tf, source_pos, target_pos, step):
        # Point the edge at the node ids assigned below; positions other than
        # 0 (up) and 1 (down) leave the edge untouched.
        if source_pos in direction_names and target_pos in direction_names:
            edge["data"]["source"] = f"{source_tf}_{direction_names[source_pos]}_{step}"
            edge["data"]["target"] = f"{target_tf}_{direction_names[target_pos]}_{step+1}"

    subnetwork = {"nodes": [], "edges": []}
    for time_idx in tf_time_dict.keys():
        up_tfs = tf_time_dict[time_idx][0]
        down_tfs = tf_time_dict[time_idx][1]
        for tfs, direction, color in ((up_tfs, "up", "#80eaff"), (down_tfs, "down", "#ff8a80")):
            for tf in tfs:
                node_info = get_tf_node_info(tf)
                node_info["data"]["color"] = color
                node_info["data"]["id"] = f"{node_info['data']['label']}_{direction}_{time_idx}"
                subnetwork["nodes"].append(node_info)

    print("Creating time series TF subnetwork data structure...")
    print("Case 1 = connecting two different TFs at adjacent time points")
    print("Case 2 = connecting the same TF at adjacent time points")
    for i in range(num_comparisons-1):
        for j, source_tf_list in enumerate(tf_time_dict[i]):
            for k, target_tf_list in enumerate(tf_time_dict[i+1]):
                for source_tf in source_tf_list:
                    for target_tf in target_tf_list:
                        if source_tf == target_tf:
                            # Case 2: same TF at both time points; keep it only
                            # if a self-edge exists.
                            if len(source_tf_list) == 0:
                                continue
                            try:
                                edge = get_tf_edge_info(source_tf)
                                attach_endpoints(edge, source_tf, target_tf, j, k, i)
                                subnetwork["edges"].append(edge)
                                print("Edge added to subnetwork (case 2)")
                            except ValueError:
                                pass
                            continue

                        # Case 1: two different TFs; keep only a direct edge
                        # that points from the source TF to the target TF.
                        data = fetch_chea_kg_data(source_tf, target_tf)
                        if len(data["nodes"]) != 2 or len(data["edges"]) != 1:
                            continue
                        edge = data["edges"][0]
                        if edge["data"]["source_label"] != source_tf or \
                            edge["data"]["target_label"] != target_tf:
                            continue
                        attach_endpoints(edge, source_tf, target_tf, j, k, i)
                        subnetwork["edges"].append(edge)
                        print("Edge added to subnetwork (case 1)")

    output_path = f"{filename}.json"
    with open(output_path, "w") as outfile:
        json.dump(subnetwork, outfile, indent=4)
    print("Finished creating the time series TF subnetwork graph.")
    print()
    return output_path


def gmt_to_tf_time_dict(gmt_file, num_tfs):
    """
    Converts GMT file containing DEGs to tf_time_dict.

    Every row of the file is "<term>TAB<description>TAB<gene>TAB<gene>...";
    the description may be empty. Rows must come in pairs: the "up genes" row
    of a comparison followed by its "down genes" row. Blank lines are ignored.

    Args:
        gmt_file: path of the GMT file.
        num_tfs: number of top enriched TFs requested for each gene set.

    Returns:
        A tuple (tf_time_dict, comparisons, degs) where
        tf_time_dict maps the comparison index to (TFs for up genes, TFs for
        down genes), comparisons lists the comparison labels in file order
        (term without its trailing two words), and degs lists the
        (up genes, down genes) pair of each comparison.

    Raises:
        ValueError: if a row lacks the term/description/genes fields or the
            file does not contain an even number of rows.
    """
    print("Processing GMT file of DEGs...")
    with open(gmt_file, 'r') as f:
        lines = f.readlines()

    # Parse the whole file first so that a malformed file is rejected before
    # any enrichment request is sent.
    terms = []
    degs_list = []
    for line_number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines, e.g. at the end of the file
        tokens = line.rstrip("\r\n").split("\t")
        if len(tokens) < 3:
            raise ValueError(
                f"Line {line_number} of '{gmt_file}' is not a valid GMT row: expected "
                "a term, a description (may be empty) and the genes, separated by tabs."
            )
        # tokens[1] is the description field. Genes may carry a ",weight"
        # suffix, which is dropped; empty entries are discarded.
        genes = [x.split(',')[0].strip() for x in tokens[2:]]
        terms.append(tokens[0])
        degs_list.append([gene for gene in genes if gene])

    if len(terms) % 2 != 0:
        raise ValueError(
            f"'{gmt_file}' contains {len(terms)} gene set rows; expected an 'up genes' row "
            "followed by a 'down genes' row for every time point comparison."
        )

    tfs_per_row = [top_tfs(genes, num_tfs) for genes in degs_list]

    new_degs_list = [(degs_list[i], degs_list[i+1]) for i in range(0, len(degs_list), 2)]

    new_comparisons = []
    for term in terms:
        comp = term.rsplit(' ', 2)[0] # extracts "Hour 1 vs Hour 0" from "Hour 1 vs Hour 0 up genes" in GMT file
        if comp not in new_comparisons:
            new_comparisons.append(comp)

    tf_time_dict = {
        i // 2: (tfs_per_row[i], tfs_per_row[i+1])
        for i in range(0, len(tfs_per_row), 2)
    }
    return tf_time_dict, new_comparisons, new_degs_list

In [None]:
%%appyter hide_code

# Declares the form sections for the pre-computed DEG (GMT upload) workflow;
# the input fields attached to these sections are declared in the next cell.
{% do SectionField(
    name= 'gmt_section_header',
    title= 'Upload files to this section ONLY if you pre-computed up and down DEGs and saved them in a GMT file.',
    img= 'gene.png'
)%}

{% do SectionField(
    name= 'gmt_data',
    title= '1. Upload differentially expressed genes as a GMT file.'
)%}

{% do SectionField(
    name= 'tf_section',
    title= '2. Choose the number of top enriched TFs to include in the visualizations.'
) %}

{% do SectionField(
    name= 'final_section',
    title= '3. Provide a short description of your study (optional).'
) %}

In [None]:
%%appyter code_eval

# Input fields for the GMT upload workflow. Each rendered field value is bound
# to a notebook variable that the cells below read.
{% do DescriptionField(
    name= 'description_2',
    text= 'Each row should consist of either the "up" or "down" differentially expressed genes at each time point comparison. The first element of each row should be formatted exactly as "{time point comparison} up genes" or "{time point comparison} down genes". See example files for how to format the input GMT file.',
    section= 'gmt_data'
) %}

{% set dataset3 = FileField(
    name= 'dataset3',
    label= 'GMT file containing DEGs from ADJACENT TIME POINT COMPARISONS',
    default= '',
    examples= {'First set of DEGs from TRAIL-treated TNBC cell line': 'https://minio.dev.maayanlab.cloud/chea-kg-timeseries/deseq2_adj_time_pts_degs.gmt'},
    section= 'gmt_data'
) %}

# GMT file for the adjacent time point comparisons ('' when not uploaded).
ds3 = {{dataset3}}

{% set dataset4 = FileField(
    name= 'dataset4',
    label= 'GMT file containing DEGs from COMPARISONS WITH TIME POINT 0',
    default= '',
    examples= {'Second set of DEGS from TRAIL-treated TNBC cell line': 'https://minio.dev.maayanlab.cloud/chea-kg-timeseries/deseq2_compare_w_time_pt_0_degs.gmt'},
    section= 'gmt_data'
) %}

# GMT file for the comparisons with time point 0 ('' when not uploaded).
ds4 = {{dataset4}}

{% set chosen_num = ChoiceField(
    name= 'chosen_num_tfs_2',
    label= 'Number of top enriched TFs',
    default= '5',
    choices= {'5': '1', '10': '2'},
    section= 'tf_section'
) %}

# The values from this section replace the raw-data section's num_tfs and
# umap_desc only when no raw data was uploaded (compute_degs is False).
if not compute_degs:
    num_tfs = {{chosen_num}}
    num_tfs = str(num_tfs)

    # Translate the choice code ('1' / '2') into the number of TFs (5 / 10).
    if num_tfs == '1':
        num_tfs = 5
    if num_tfs == '2':
        num_tfs = 10

{% set user_description_2 = TextField(
        name= 'user_description_2',
        label= 'Description',
        default= '',
        section = 'final_section',
) %}

if not compute_degs:
    umap_desc = {{user_description_2}}

In [None]:
# Build the TF dictionaries and the subnetwork JSON files for whichever DEG
# sources are active.
# NOTE(review): every output base name hard-codes "top_5_tfs" even though
# num_tfs can be 10 — confirm whether the file names should follow num_tfs.

# DESeq2 DEGs computed in this notebook: enrich TFs from the generated GMT files.
if compute_degs and (dge_method == '1' or dge_method == '3'):
    deseq2_tf_time_dict_1, deseq2_adj_time_pt_comparisons, deseq2_degs_1 = gmt_to_tf_time_dict(deseq2_degs_gmt_1, num_tfs)
    deseq2_subnetwork_data_1 = create_tf_time_series_graph(deseq2_tf_time_dict_1, len(deseq2_adj_time_pt_comparisons), "appyter_deseq2_adj_time_pts_top_5_tfs")

    deseq2_tf_time_dict_2, deseq2_time_pt_0_comparisons, deseq2_degs_2 = gmt_to_tf_time_dict(deseq2_degs_gmt_2, num_tfs)
    deseq2_subnetwork_data_2 = create_tf_time_series_graph(deseq2_tf_time_dict_2, len(deseq2_time_pt_0_comparisons), "appyter_deseq2_compare_w_time_pt_0_top_5_tfs")

# Characteristic Direction DEGs: the enriched TFs were already collected in
# cd_tf_time_dict_1 / cd_tf_time_dict_2 by the CD cells.
if compute_degs and (dge_method == '2' or dge_method == '3'):
    cd_subnetwork_data_1 = create_tf_time_series_graph(cd_tf_time_dict_1, len(cd_adj_time_pt_comparisons), "appyter_cd_adj_time_pts_top_5_tfs")
    cd_subnetwork_data_2 = create_tf_time_series_graph(cd_tf_time_dict_2, len(cd_time_pt_0_comparisons), "appyter_cd_compare_w_time_pt_0_top_5_tfs")

# Pre-computed DEGs: use whichever GMT files were uploaded. The results are
# stored under un-prefixed names (tf_time_dict_*, comparisons_*,
# subnetwork_data_*), except the DEG lists, which reuse the deseq2_degs_* names.
if not compute_degs:
    if ds3:
        tf_time_dict_1, comparisons_1, deseq2_degs_1 = gmt_to_tf_time_dict(ds3, num_tfs)
        subnetwork_data_1 = create_tf_time_series_graph(tf_time_dict_1, len(comparisons_1), "appyter_deseq2_adj_time_pts_top_5_tfs")

    if ds4:
        tf_time_dict_2, comparisons_2, deseq2_degs_2 = gmt_to_tf_time_dict(ds4, num_tfs)
        subnetwork_data_2 = create_tf_time_series_graph(tf_time_dict_2, len(comparisons_2), "appyter_deseq2_compare_w_time_pt_0_top_5_tfs")

## Step 4. Visualizing the enriched TFs within a time series regulatory subnetwork

The JSON file created in step 3 is used to visualize the regulatory network using ball-and-stick diagrams. Blue and red nodes correspond to TFs enriched for upregulated and downregulated gene sets at the given time point, respectively. Activation (green) or inhibition (red) arrows are drawn between TFs at adjacent time points (as determined by [ChEA-KG](https://chea-kg.maayanlab.cloud/)), therefore forming a regulatory subnetwork of all the enriched TFs. 

In [None]:
def export_cytoscape_json(data, filename):
    """Save `data` as indented JSON in the static directory and display a link to the file.

    The directory is taken from the ``APP_STATIC`` environment variable, falling back to
    ``/app/static``, and is created if it does not exist yet.

    Args:
        data: JSON-serialisable object to write.
        filename: Name of the file to create inside the static directory.
    """
    static_dir = os.environ.get("APP_STATIC", "/app/static")
    os.makedirs(static_dir, exist_ok=True)

    destination = os.path.join(static_dir, filename)
    with open(destination, "w") as handle:
        json.dump(data, handle, indent=2)

    display(FileLink(destination))


def visualize_network_2(data, filename):
    """Place the nodes on a grid, show the graph in a CytoscapeWidget and export it as JSON.

    Every node without a ``position`` is positioned in place: its row is the integer after the
    last underscore of its id, and its column is the order in which nodes of that row are
    encountered (starting at 1). Nodes that already carry a position are left untouched.

    Args:
        data: Graph dict with a ``"nodes"`` list; it is mutated and then passed to the widget.
        filename: Base name of the export; ``.json`` is appended.
    """
    grid_step = 150
    next_col_for_row = {}  # row index -> column that the next node of that row will occupy
    for node in data["nodes"]:
        if "position" not in node:
            row = int(node["data"]["id"].split("_")[-1])
            col = next_col_for_row.get(row, 1)
            next_col_for_row[row] = col + 1
            node["position"] = {"x": col * grid_step, "y": row * grid_step}

    node_rule = {
        'selector': 'node',
        'style': {
            'label': 'data(label)',
            'background-color': 'data(color)',
            'border-width': 'data(borderWidth)',
            'border-color': 'data(borderColor)',
            'width': 50,
            'height': 50,
            'font-size': '16px',
            'font-weight': 'bold',
            'text-valign': 'center',
            'text-halign': 'center'
        }
    }
    row_label_rule = {
        'selector': '.row-label',
        'style': {
            'label': 'data(label)',
            'background-color': '#ffffff',
            'color': '#000000',
            'font-size': '20px',
            'width': 5,
            'height': 5,
            'text-valign': 'center',
            'text-halign': 'center'
        }
    }
    edge_rule = {
        'selector': 'edge',
        'style': {
            'width': 3,
            'line-color': 'data(lineColor)',
            'target-arrow-color': 'data(lineColor)',
            'target-arrow-shape': 'data(directed)',
            'curve-style': 'bezier',
            'arrow-scale': 1.5
        }
    }

    widget = CytoscapeWidget()
    widget.graph.add_graph_from_json(data)
    widget.set_layout(name='preset')  # use the positions assigned above
    widget.set_style([node_rule, row_label_rule, edge_rule])
    display(widget)

    export_cytoscape_json(data, f"{filename}.json")


def run_visualization(input_json, comparisons, filename):
    """Add one row-label node per comparison to the graph and visualize it.

    Args:
        input_json: Either the graph dict itself or a path to a JSON file holding it.
            A dict is modified in place (label nodes are appended to its ``"nodes"`` list).
        comparisons: Row labels; the i-th label is placed at x=0, y=i*150.
        filename: Base name forwarded to ``visualize_network_2`` for the JSON export.

    Returns:
        The string ``"VISUALIZATION HAS RUN"``.
    """
    if isinstance(input_json, dict):
        data = input_json
    else:
        with open(input_json, "r") as input_file:
            data = json.load(input_file)

    nodes = data["nodes"]
    # The caller's dict is mutated, so re-running a cell would otherwise append a second set of
    # label nodes with the same ids; skip labels that are already present.
    existing_ids = {node.get("data", {}).get("id") for node in nodes}
    for i, comparison in enumerate(comparisons):
        label_id = f'row_label_{i}'
        if label_id in existing_ids:
            continue
        label_node = {
            'data': {
                'id': label_id,
                'label': comparison,
                'color': '#ffffff',
                'borderColor': '#ffffff',
                'borderWidth': 0
            },
            'classes': 'row-label',
            'position': {"x": 0, "y": i * 150}
        }
        nodes.append(label_node)

    visualize_network_2(data, filename)
    return "VISUALIZATION HAS RUN"

In [None]:
if compute_degs and (dge_method == '1' or dge_method == '3'):
    # For each comparison, report which of its enriched TFs are themselves among the up / down
    # genes of that comparison and, when one exists, of the following comparison.
    # Print order per comparison: up (current), up (next), down (current), down (next).
    for i, comparison in enumerate(deseq2_adj_time_pt_comparisons):
        all_enriched_tfs = list(deseq2_tf_time_dict_1[i][0]) + list(deseq2_tf_time_dict_1[i][1])
        stop = min(i + 2, len(deseq2_adj_time_pt_comparisons))  # current comparison plus the next one, if any
        for direction_index, direction in enumerate(("up", "down")):
            for target in range(i, stop):
                overlapping_tfs = [tf for tf in all_enriched_tfs if tf in deseq2_degs_1[target][direction_index]]
                print(f"{comparison} enriched TFs found in {deseq2_adj_time_pt_comparisons[target]} {direction} genes: {overlapping_tfs}")

    run_visualization(deseq2_subnetwork_data_1, deseq2_adj_time_pt_comparisons, "deseq2_adj_time_pt_visualization")
    display(Markdown(f"##### *__Figure 1.1__: Subnetwork of enriched TFs determined from comparing gene expression at adjacent time points (DEGs computed using PyDESeq2). Blue TFs are enriched for up genes while red TFs are enriched for down genes. Activation and inhibition arrows depict the potential regulatory effects of a TF in an earlier time point on a TF in the subsequent time point.*"))

In [None]:
if compute_degs and (dge_method == '1' or dge_method == '3'):
    # For each comparison, report which of its enriched TFs are themselves among the up / down
    # genes of that comparison and, when one exists, of the following comparison.
    # Print order per comparison: up (current), up (next), down (current), down (next).
    for i, comparison in enumerate(deseq2_time_pt_0_comparisons):
        all_enriched_tfs = list(deseq2_tf_time_dict_2[i][0]) + list(deseq2_tf_time_dict_2[i][1])
        stop = min(i + 2, len(deseq2_time_pt_0_comparisons))  # current comparison plus the next one, if any
        for direction_index, direction in enumerate(("up", "down")):
            for target in range(i, stop):
                overlapping_tfs = [tf for tf in all_enriched_tfs if tf in deseq2_degs_2[target][direction_index]]
                print(f"{comparison} enriched TFs found in {deseq2_time_pt_0_comparisons[target]} {direction} genes: {overlapping_tfs}")

    run_visualization(deseq2_subnetwork_data_2, deseq2_time_pt_0_comparisons, "deseq2_compare_w_time_pt_0_visualization")
    display(Markdown(f"##### *__Figure 1.2__: Subnetwork of enriched TFs determined from comparing gene expression at each time point to time point 0 (DEGs computed using PyDESeq2). Blue TFs are enriched for up genes while red TFs are enriched for down genes. Activation and inhibition arrows depict the potential regulatory effects of a TF in an earlier time point on a TF in the subsequent time point.*"))

In [None]:
# Figure numbers for the two CD subnetwork captions below: 1.1 / 1.2 when dge_method is '2',
# 1.3 / 1.4 when it is '3' (the DESeq2 cells above already use 1.1 / 1.2 in that case).
if compute_degs and dge_method in ('2', '3'):
    fig_nums = {'2': ('1.1', '1.2'), '3': ('1.3', '1.4')}[dge_method]

In [None]:
if compute_degs and (dge_method == '2' or dge_method == '3'):
    # For each comparison, report which of its enriched TFs are themselves among the up / down
    # genes of that comparison and, when one exists, of the following comparison.
    # Print order per comparison: up (current), up (next), down (current), down (next).
    for i, comparison in enumerate(cd_adj_time_pt_comparisons):
        all_enriched_tfs = list(cd_tf_time_dict_1[i][0]) + list(cd_tf_time_dict_1[i][1])
        stop = min(i + 2, len(cd_adj_time_pt_comparisons))  # current comparison plus the next one, if any
        for direction_index, direction in enumerate(("up", "down")):
            for target in range(i, stop):
                overlapping_tfs = [tf for tf in all_enriched_tfs if tf in cd_degs_1[target][direction_index]]
                print(f"{comparison} enriched TFs found in {cd_adj_time_pt_comparisons[target]} {direction} genes: {overlapping_tfs}")

    run_visualization(cd_subnetwork_data_1, cd_adj_time_pt_comparisons, "cd_adj_time_pt_visualization")
    display(Markdown(f"##### *__Figure {fig_nums[0]}__: Subnetwork of enriched TFs determined from comparing gene expression at adjacent time points (DEGs computed using CD). Blue TFs are enriched for up genes while red TFs are enriched for down genes. Activation and inhibition arrows depict the potential regulatory effects of a TF in an earlier time point on a TF in the subsequent time point.*"))

In [None]:
if compute_degs and (dge_method == '2' or dge_method == '3'):
    # For each comparison, report which of its enriched TFs are themselves among the up / down
    # genes of that comparison and, when one exists, of the following comparison.
    # Print order per comparison: up (current), up (next), down (current), down (next).
    for i, comparison in enumerate(cd_time_pt_0_comparisons):
        all_enriched_tfs = list(cd_tf_time_dict_2[i][0]) + list(cd_tf_time_dict_2[i][1])
        stop = min(i + 2, len(cd_time_pt_0_comparisons))  # current comparison plus the next one, if any
        for direction_index, direction in enumerate(("up", "down")):
            for target in range(i, stop):
                overlapping_tfs = [tf for tf in all_enriched_tfs if tf in cd_degs_2[target][direction_index]]
                print(f"{comparison} enriched TFs found in {cd_time_pt_0_comparisons[target]} {direction} genes: {overlapping_tfs}")

    run_visualization(cd_subnetwork_data_2, cd_time_pt_0_comparisons, "cd_compare_w_time_pt_0_visualization")
    display(Markdown(f"##### *__Figure {fig_nums[1]}__: Subnetwork of enriched TFs determined from comparing gene expression at each time point to time point 0 (DEGs computed using CD). Blue TFs are enriched for up genes while red TFs are enriched for down genes. Activation and inhibition arrows depict the potential regulatory effects of a TF in an earlier time point on a TF in the subsequent time point.*"))

In [None]:
if not compute_degs:
    # Same report and visualization as the cells above, for the gene sets read from ds3 / ds4.
    # For each comparison, the enriched TFs are checked against the up / down genes of that
    # comparison and, when one exists, of the following comparison.
    # Print order per comparison: up (current), up (next), down (current), down (next).
    if ds3:
        for i, comparison in enumerate(comparisons_1):
            all_enriched_tfs = list(tf_time_dict_1[i][0]) + list(tf_time_dict_1[i][1])
            stop = min(i + 2, len(comparisons_1))  # current comparison plus the next one, if any
            for direction_index, direction in enumerate(("up", "down")):
                for target in range(i, stop):
                    overlapping_tfs = [tf for tf in all_enriched_tfs if tf in deseq2_degs_1[target][direction_index]]
                    print(f"{comparison} enriched TFs found in {comparisons_1[target]} {direction} genes: {overlapping_tfs}")

        run_visualization(subnetwork_data_1, comparisons_1, "adj_time_pt_visualization")
        display(Markdown(f"##### *__Figure 1.1__: Subnetwork of enriched TFs determined from comparing gene expression at adjacent time points. Blue TFs are enriched for up genes while red TFs are enriched for down genes. Activation and inhibition arrows depict the potential regulatory effects of a TF in an earlier time point on a TF in the subsequent time point.*"))

    if ds4:
        for i, comparison in enumerate(comparisons_2):
            all_enriched_tfs = list(tf_time_dict_2[i][0]) + list(tf_time_dict_2[i][1])
            stop = min(i + 2, len(comparisons_2))  # current comparison plus the next one, if any
            for direction_index, direction in enumerate(("up", "down")):
                for target in range(i, stop):
                    overlapping_tfs = [tf for tf in all_enriched_tfs if tf in deseq2_degs_2[target][direction_index]]
                    print(f"{comparison} enriched TFs found in {comparisons_2[target]} {direction} genes: {overlapping_tfs}")

        run_visualization(subnetwork_data_2, comparisons_2, "compare_w_time_pt_0_visualization")
        display(Markdown(f"##### *__Figure 1.2__: Subnetwork of enriched TFs determined from comparing gene expression at each time point to time point 0. Blue TFs are enriched for up genes while red TFs are enriched for down genes. Activation and inhibition arrows depict the potential regulatory effects of a TF in an earlier time point on a TF in the subsequent time point.*"))

## Step 5. Bar graph representation of enriched TF ranks

The bar graphs below depict the ranks of the top enriched TFs across the TF-target gene set libraries in ChEA3 (Keenan et al., 2019). Bar graphs are shown for the enriched TFs for the "up" and "down" gene sets at each time point. 

Here, a TF's rank within a given library refers to how well the TF's target genes within that library overlap with the input gene set (e.g. up genes at the "hour 3 vs hour 1" comparison) compared to other TFs within that library. The top enriched TFs are computed by finding the TFs with the lowest average rank across all six ChEA3 libraries (referred to as the MeanRank method).  

In [None]:
def get_chea3_results(gene_set, query_name):
    """Submit a gene set to the ChEA3 enrichment endpoint and return the decoded JSON reply.

    Args:
        gene_set: Value sent as ``gene_set`` in the JSON payload.
        query_name: Value sent as ``query_name`` in the JSON payload.

    Returns:
        The response body parsed with ``json.loads``.

    Raises:
        Exception: If the HTTP response is not OK.
    """
    ADDLIST_URL = 'https://maayanlab.cloud/chea3/api/enrich/'
    body = json.dumps({'gene_set': gene_set, 'query_name': query_name})
    response = requests.post(ADDLIST_URL, data=body)
    if response.ok:
        # One-second pause before returning; presumably spaces out successive
        # requests to the service -- TODO confirm.
        time.sleep(1)
        return json.loads(response.text)
    raise Exception('Error analyzing gene list')


def indexfinder(lib_score_list, value):
    """Return the 1-based position of `value` in `lib_score_list`, not counting zeros.

    The list is scanned from the start; every non-zero entry that is passed over advances
    the position by one, zero entries do not. The position reached at the first entry equal
    to `value` is returned.

    Returns:
        The position as an int, or None when `value` does not occur in the list.
    """
    rank = 1
    for score in lib_score_list:
        if score == value:
            return rank
        if score != 0:
            rank += 1
    return None

In [None]:
def _first_occurrence_ranks(sorted_values):
    """Map each value of an ascending list to its 1-based rank among the non-zero entries.

    The rank of a value is one more than the number of non-zero entries before its first
    occurrence, so tied values share the rank of the first of them.
    """
    ranks = {}
    position = 1
    for value in sorted_values:
        ranks.setdefault(value, position)
        if value != 0:
            position += 1
    return ranks


def get_bar_graph(gene_set, query_name, num_tfs, title):
    """Display a stacked horizontal bar chart of per-library ranks for the top enriched TFs.

    The gene set is submitted through ``get_chea3_results`` and the ``'Integrated--meanRank'``
    table of the reply is used. Each TF's ``'Library'`` string ("library,value;library,value")
    is parsed, the values are converted to within-library ranks, and the ranks are averaged
    per TF. The ``num_tfs`` TFs with the lowest average rank that are listed by at least three
    libraries are plotted, one stacked bar per TF with one segment per library.

    Args:
        gene_set: Gene set forwarded to ``get_chea3_results``.
        query_name: Query name forwarded to ``get_chea3_results``.
        num_tfs: Maximum number of TFs to plot.
        title: Text used as the y-axis title.

    Returns:
        None. The figure is rendered to HTML and displayed inside an iframe.
    """
    term_limit = num_tfs
    min_libraries = 3  # a TF must be listed by at least this many libraries to be plotted

    c_lib_palette = {'ARCHS4 Coexpression': 'rgb(196, 8, 8)',
                     'ENCODE ChIP-seq': 'rgb(244, 109, 67)',
                     'Enrichr Queries': 'rgb(242, 172, 68)',
                     'GTEx Coexpression': 'rgb(236, 252, 68)',
                     'Literature ChIP-seq': 'rgb(165, 242, 162)',
                     'ReMap ChIP-seq': 'rgb(92, 217, 78)'}

    libs_sorted = ['ARCHS4 Coexpression', 'ENCODE ChIP-seq', 'Enrichr Queries',
                   'GTEx Coexpression', 'Literature ChIP-seq', 'ReMap ChIP-seq']

    results = get_chea3_results(gene_set, query_name)
    mr_results = results['Integrated--meanRank']

    # Parse every TF's "library,value;library,value;..." string once.
    parsed_scores = []
    for tfentry in mr_results:
        pairs = []
        for item in tfentry['Library'].split(';'):
            library, value = item.split(',')
            pairs.append((library, int(value)))
        parsed_scores.append(pairs)

    # Per library: value -> rank among the TFs that the library lists. A TF that a library
    # does not list counts as 0 there and is left out of that library's ranking.
    rank_lookup = {}
    for lib in libs_sorted:
        lib_values = sorted(dict(pairs).get(lib, 0) for pairs in parsed_scores)
        rank_lookup[lib] = _first_occurrence_ranks(lib_values)

    # MeanRank: average of a TF's within-library ranks over the libraries that list it.
    scored = []
    for tfentry, pairs in zip(mr_results, parsed_scores):
        lib_ranks = [(lib, rank_lookup[lib][value]) for lib, value in pairs]
        sum_rank = sum(rank for _, rank in lib_ranks)
        scored.append({
            'TF': tfentry.get('TF'),
            'LibRanks': lib_ranks,
            'SumRank': sum_rank,
            'AvgRank': sum_rank / len(lib_ranks),
        })
    scored.sort(key=lambda entry: entry['AvgRank'])

    # Keep the best-ranked TFs that are listed by enough libraries, up to term_limit of them.
    top_entries = []
    for entry in scored:
        if len(top_entries) >= term_limit:
            break
        if len(entry['LibRanks']) >= min_libraries:
            top_entries.append(entry)
    top_entries.reverse()  # best-ranked TF last in the y-axis order

    sorted_tfs = [entry['TF'] for entry in top_entries]

    # Bar segments are sized to the number of TFs actually selected, so x and y always have
    # the same length even when fewer than term_limit TFs pass the library filter.
    # Each segment is (rank * AvgRank) / SumRank, so the stacked bar of a TF totals its AvgRank.
    c_lib_means = {lib: [0] * len(top_entries) for lib in libs_sorted}
    for i, entry in enumerate(top_entries):
        for lib, rank in entry['LibRanks']:
            c_lib_means[lib][i] = float((rank * entry['AvgRank']) / entry['SumRank'])

    fig = go.Figure(data=[go.Bar(name=c_lib,
                                 x=c_lib_means[c_lib],
                                 y=sorted_tfs,
                                 marker=go.bar.Marker(color=c_lib_palette[c_lib]),
                                 orientation='h')
                          for c_lib in libs_sorted])
    h = 400 if term_limit <= 10 else 400 + 10 * term_limit
    fig.update_layout(barmode='stack')
    fig.update_layout(
        xaxis_title='Average of Ranks Across All Libraries',
        yaxis_title=title,
        font=dict(
            size=12,
            color='black'
        ),
        width=900,
        height=h
    )

    # Embed the standalone figure HTML in an iframe via srcdoc (escaped so it fits in the attribute).
    html_str = fig.to_html(include_plotlyjs='cdn')
    escaped_html = html.escape(html_str)

    iframe = f"""
    <iframe srcdoc="{escaped_html}" width="100%" height="{h}" frameborder="0"></iframe>
    """

    display(HTML(iframe))

In [None]:
if compute_degs and (dge_method == '1' or dge_method == '3'):
    # One bar chart per (comparison, direction); j is the running sub-figure number, so within
    # each comparison the UP chart gets the odd number and the DOWN chart the following even one.
    j = 1
    for i, (up, down) in enumerate(deseq2_degs_1):
        for direction, genes in (("UP", up), ("DOWN", down)):
            get_bar_graph(genes, deseq2_adj_time_pt_comparisons[i], num_tfs, f"Enriched TFs at {deseq2_adj_time_pt_comparisons[i]}")
            display(Markdown(f"##### *__Figure 2.1.{j}__: Ranks of enriched TFs for {direction} genes at {deseq2_adj_time_pt_comparisons[i]} (computed using DESeq2) in their respective ChEA3 libraries. Bar length is proportional to rank (a narrower bar means a better rank).*"))
            j += 1

In [None]:
if compute_degs and (dge_method == '1' or dge_method == '3'):
    # One bar chart per (comparison, direction); j is the running sub-figure number, so within
    # each comparison the UP chart gets the odd number and the DOWN chart the following even one.
    j = 1
    for i, (up, down) in enumerate(deseq2_degs_2):
        for direction, genes in (("UP", up), ("DOWN", down)):
            get_bar_graph(genes, deseq2_time_pt_0_comparisons[i], num_tfs, f"Enriched TFs at {deseq2_time_pt_0_comparisons[i]}")
            display(Markdown(f"##### *__Figure 2.2.{j}__: Ranks of enriched TFs for {direction} genes at {deseq2_time_pt_0_comparisons[i]} (computed using DESeq2) in their respective ChEA3 libraries. Bar length is proportional to rank (a narrower bar means a better rank).*"))
            j += 1

In [None]:
# Figure-number prefixes for the two CD bar-chart series below: 2.1 / 2.2 when dge_method is
# '2', 2.3 / 2.4 when it is '3' (the DESeq2 cells above already use 2.1 / 2.2 in that case).
if compute_degs and dge_method in ('2', '3'):
    fig_nums = {'2': ('2.1', '2.2'), '3': ('2.3', '2.4')}[dge_method]

In [None]:
if compute_degs and (dge_method == '2' or dge_method == '3'):
    # One bar chart per (comparison, direction); j is the running sub-figure number, so within
    # each comparison the UP chart gets the odd number and the DOWN chart the following even one.
    j = 1
    for i, (up, down) in enumerate(cd_degs_1):
        for direction, genes in (("UP", up), ("DOWN", down)):
            get_bar_graph(genes, cd_adj_time_pt_comparisons[i], num_tfs, f"Enriched TFs at {cd_adj_time_pt_comparisons[i]}")
            display(Markdown(f"##### *__Figure {fig_nums[0]}.{j}__: Ranks of enriched TFs for {direction} genes at {cd_adj_time_pt_comparisons[i]} (computed using CD) in their respective ChEA3 libraries. Bar length is proportional to rank (a narrower bar means a better rank).*"))
            j += 1

In [None]:
if compute_degs and (dge_method == '2' or dge_method == '3'):
    # One bar chart per (comparison, direction); j is the running sub-figure number, so within
    # each comparison the UP chart gets the odd number and the DOWN chart the following even one.
    j = 1
    for i, (up, down) in enumerate(cd_degs_2):
        for direction, genes in (("UP", up), ("DOWN", down)):
            get_bar_graph(genes, cd_time_pt_0_comparisons[i], num_tfs, f"Enriched TFs at {cd_time_pt_0_comparisons[i]}")
            display(Markdown(f"##### *__Figure {fig_nums[1]}.{j}__: Ranks of enriched TFs for {direction} genes at {cd_time_pt_0_comparisons[i]} (computed using CD) in their respective ChEA3 libraries. Bar length is proportional to rank (a narrower bar means a better rank).*"))
            j += 1

In [None]:
if not compute_degs:
    # Bar charts for the gene sets read from ds3 / ds4. One chart per (comparison, direction);
    # j is the running sub-figure number, so within each comparison the UP chart gets the odd
    # number and the DOWN chart the following even one.
    if ds3:
        j = 1
        for i, (up, down) in enumerate(deseq2_degs_1):
            for direction, genes in (("UP", up), ("DOWN", down)):
                get_bar_graph(genes, comparisons_1[i], num_tfs, f"Enriched TFs at {comparisons_1[i]}")
                display(Markdown(f"##### *__Figure 2.1.{j}__: Ranks of enriched TFs for {direction} genes at {comparisons_1[i]} in their respective ChEA3 libraries. Bar length is proportional to rank (a narrower bar means a better rank).*"))
                j += 1

    if ds4:
        j = 1
        for i, (up, down) in enumerate(deseq2_degs_2):
            for direction, genes in (("UP", up), ("DOWN", down)):
                get_bar_graph(genes, comparisons_2[i], num_tfs, f"Enriched TFs at {comparisons_2[i]}")
                display(Markdown(f"##### *__Figure 2.2.{j}__: Ranks of enriched TFs for {direction} genes at {comparisons_2[i]} in their respective ChEA3 libraries. Bar length is proportional to rank (a narrower bar means a better rank).*"))
                j += 1

## Step 6. Generating a UMAP visualization of the enriched TFs

The enriched transcription factors from each time point are highlighted on a UMAP plot of 700 TFs identified by ChEA-KG to be "source TFs" (i.e. they exert regulatory effects on other TFs). The UMAP embedding was computed from the TF-IDF scores of the TFs' target genes, so TFs that regulate similar genes are placed close to each other; the clusters shown in the plot were assigned with the Leiden algorithm. 

Hovering over each data point reveals the gene identity of the TF, the cluster the TF belongs to, and if the TF is enriched at that time point. 

Use the slider to navigate between time points or click on the corresponding GIF link.

In [None]:
def process_tf_data(libdict, nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1):
    """Embed and cluster the terms of `libdict` from the TF-IDF vectors of their values.

    Args:
        libdict: Mapping of term -> text document; presumably a whitespace-separated gene
            list per TF -- TODO confirm against the caller.
        nneighbors: ``n_neighbors`` passed to ``sc.pp.neighbors``.
        mindist: ``min_dist`` passed to ``sc.tl.umap``.
        spread: ``spread`` passed to ``sc.tl.umap``.
        maxdf: ``max_df`` passed to ``TfidfVectorizer``.
        mindf: ``min_df`` passed to ``TfidfVectorizer``.

    Returns:
        DataFrame with one row per term, ordered by Leiden label, with columns ``x`` and ``y``
        (UMAP coordinates), ``cluster`` ("Cluster <label>"), ``term`` and ``genes`` (the
        original `libdict` value of the term).
    """
    # TF-IDF vectorize the documents; one observation per term.
    vec = TfidfVectorizer(max_df=maxdf, min_df=mindf)
    X = vec.fit_transform(libdict.values())
    adata = anndata.AnnData(X)
    adata.obs.index = libdict.keys()

    # Neighbor graph -> Leiden clusters -> UMAP coordinates (UMAP seed fixed at 42).
    sc.pp.neighbors(adata, n_neighbors=nneighbors)
    sc.tl.leiden(adata, resolution=1.0)
    sc.tl.umap(adata, min_dist=mindist, spread=spread, random_state=42)

    # Reorder the observations by cluster label. Indexing returns a view, so take an explicit
    # copy before writing to .obs instead of relying on the implicit copy (and warning) that
    # assigning into a view triggers.
    new_order = adata.obs.sort_values(by='leiden').index.tolist()
    adata = adata[new_order, :].copy()
    adata.obs['leiden'] = 'Cluster ' + adata.obs['leiden'].astype('object')

    df = pd.DataFrame(adata.obsm['X_umap'])
    df.columns = ['x', 'y']

    df['cluster'] = adata.obs['leiden'].values
    df['term'] = adata.obs.index
    df['genes'] = [libdict[l] for l in df['term']]

    return df


def get_scatter_colors(df):
    """Assign a gray hex color to every distinct value of ``df['cluster']``.

    The gray levels are spread evenly from 50 to 230 and handed out in the order in which
    the clusters first appear in the column.

    Returns:
        dict mapping cluster value -> ``'#rrggbb'`` string with equal red, green and blue.
    """
    cluster_names = pd.unique(df['cluster']).tolist()
    gray_levels = np.linspace(50, 230, len(cluster_names))
    return {
        name: '#' + format(int(level), '02x') * 3
        for name, level in zip(cluster_names, gray_levels)
    }


def generate_df_for_comparison(base_df, tf_pair, comparison_label):
    """Return a copy of `base_df` styled for one time-point comparison.

    Every row starts with the gray of its cluster, size 6 and time point "Not enriched".
    Rows whose term is in the first element of `tf_pair` only are colored blue, rows in the
    second element only are colored red, and rows in both are colored green; all of these get
    size 12 and `comparison_label` as their time point. `base_df` itself is not modified.
    """
    up_tfs = tf_pair[0]
    down_tfs = tf_pair[1]

    df = base_df.copy()
    cluster_grays = get_scatter_colors(df)
    df['color'] = df['cluster'].apply(lambda name: cluster_grays[name])
    df['size'] = 6
    df['time_point'] = "Not enriched"

    for idx, term in df['term'].items():
        is_up = term in up_tfs
        is_down = term in down_tfs
        if is_up and is_down:
            highlight = "#26e411"
        elif is_up:
            highlight = "#1595f0"
        elif is_down:
            highlight = "#f30a1a"
        else:
            continue  # not enriched: keep the cluster gray
        df.at[idx, 'color'] = highlight
        df.at[idx, 'size'] = 12
        df.at[idx, 'time_point'] = comparison_label
    return df


def generate_legend_descriptions(df, tf_time_dict, comparisons, user_desc):
    """Build one text description per comparison listing the enriched TFs found in `df`.

    For comparison ``i`` the terms of ``df['term']`` are split into those in the up set only,
    the down set only, and both sets of ``tf_time_dict[i]``. Each non-empty group contributes
    one line with its terms sorted alphabetically. The description is the comparison label,
    the group lines, a blank line and the stripped `user_desc`.

    Returns:
        list of str, in the order of `comparisons`.
    """
    descriptions = []
    for i, comparison in enumerate(comparisons):
        up_tfs, down_tfs = tf_time_dict[i]

        up_only, down_only, both = [], [], []
        for tf in df['term'].tolist():
            in_up = tf in up_tfs
            in_down = tf in down_tfs
            if in_up and in_down:
                both.append(tf)
            elif in_up:
                up_only.append(tf)
            elif in_down:
                down_only.append(tf)

        groups = (
            ("Enriched transcription factors for up genes: ", sorted(up_only)),
            ("Enriched transcription factors for down genes: ", sorted(down_only)),
            ("Enriched transcription factors for up AND down genes: ", sorted(both)),
        )
        lines = [prefix + ", ".join(names) for prefix, names in groups if names]

        descriptions.append(
            str(comparison + "\n") + "\n".join(lines) + "\n\n" + user_desc.strip()
        )
    return descriptions


def get_scatterplot(scatterdf, tf_time_dict=None, comparisons=None, image_dir=None, gif_filename=None):
    """
    Render the UMAP scatter plot with a comparison slider, export one PNG per
    comparison, and stitch those PNGs into an animated GIF.

    Parameters
    ----------
    scatterdf : DataFrame with ``x``, ``y``, ``term`` and ``cluster`` columns.
        Cluster labels must end in an integer after a space, since rows are
        ordered by that trailing number.
    tf_time_dict : indexable by comparison position; each entry is an
        ``(up_tfs, down_tfs)`` pair.
    comparisons : sequence of comparison labels, one per slider position.
    image_dir : directory (relative to the current working directory) that
        receives the per-comparison PNG frames.
    gif_filename : file name of the GIF, written under ``$APP_STATIC``
        (default ``/app/static``).

    Despite the ``None`` defaults, all arguments are required in practice.
    The module-level ``umap_desc`` string is read at call time and appended to
    every description shown below the plot.

    Side effects: shows the slider + plot, writes the PNG frames and the GIF,
    displays a download link and prints a confirmation message.

    Returns
    -------
    (plot_emb, source) : the Bokeh figure and the live ``ColumnDataSource``
        backing it. After the export loop both reflect the LAST comparison.
    """
    # Order rows by the integer at the end of the cluster label.
    df = scatterdf.copy()
    df['cluster_number'] = df['cluster'].apply(lambda x: int(x.split(" ")[-1]))
    df.sort_values(by=['cluster_number'], inplace=True)
    df.drop(columns = ['cluster_number'], inplace=True)

    # One data source per comparison; the slider swaps them into `source`.
    sources = []
    for i, label in enumerate(comparisons):
        df_comp = generate_df_for_comparison(df, tf_time_dict[i], label)
        source = ColumnDataSource(data=dict(x = df_comp['x'], y = df_comp['y'],
                                            gene_set = df_comp['term'], colors = df_comp['color'],
                                            label = df_comp['cluster'], size = df_comp['size'],
                                            time_point = df_comp['time_point']))
        sources.append(source)

    # The plotted source starts as a deep copy of the first comparison so that
    # later updates to it never touch sources[0].
    source = ColumnDataSource(data=deepcopy(sources[0].data))
    tooltips = [
        ("Gene Set", "@gene_set"),
        ("Cluster", "@label"),
        ("Time point", "@time_point")
    ]

    hover_emb = HoverTool(tooltips=tooltips)
    tools_emb = [hover_emb, 'pan', 'wheel_zoom', 'reset', 'save']

    plot_emb = figure(
        width=500*2,
        height=400*2,
        tools=tools_emb,
        output_backend='canvas'
    )

    plot_emb.scatter(
        'x',
        'y',
        size = 'size',
        source = source,
        marker = 'circle',
        fill_color = 'colors',
        color = 'colors'
    )

    # Empty dummy glyphs, one per cluster, exist only to create legend entries
    # with the cluster color.
    color_mapper = get_scatter_colors(df)
    for cluster, gray_color in color_mapper.items():
        dummy_source = ColumnDataSource(data=dict(x=[None], y=[None]))
        plot_emb.scatter(
            'x',
            'y',
            source=dummy_source,
            fill_color=gray_color,
            color=gray_color,
            marker='circle',
            size=10,
            legend_label=cluster
        )

    # hide tick marks, tick labels and grid lines
    plot_emb.xaxis.major_tick_line_color = None
    plot_emb.xaxis.minor_tick_line_color = None
    plot_emb.yaxis.major_tick_line_color = None
    plot_emb.yaxis.minor_tick_line_color = None
    plot_emb.grid.grid_line_color = None
    plot_emb.xaxis.major_label_text_font_size = '0pt'
    plot_emb.yaxis.major_label_text_font_size = '0pt'

    plot_emb.xaxis.axis_label = "UMAP-1"
    plot_emb.yaxis.axis_label = "UMAP-2"
    plot_emb.xaxis.axis_label_text_font_size = '20pt'
    plot_emb.yaxis.axis_label_text_font_size = '20pt'
    plot_emb.xaxis.axis_label_text_font_style = "normal"
    plot_emb.yaxis.axis_label_text_font_style = "normal"

    plot_emb.legend.label_text_font_size = '18pt'
    plot_emb.legend.glyph_height = 20
    plot_emb.legend.glyph_width = 20

    plot_emb.add_layout(plot_emb.legend[0], 'right')

    # Extra bottom border leaves room for the description label.
    plot_emb.min_border_bottom = 168

    legend_descriptions = generate_legend_descriptions(df, tf_time_dict, comparisons, umap_desc)
    description_source = ColumnDataSource(data=dict(descriptions=legend_descriptions))
    description_label = Label(x=0, y=-7, x_units='screen', y_units='screen',
                          text=legend_descriptions[0],
                          text_font_size='10pt', text_align='left', name='dynamic_description')
    plot_emb.add_layout(description_label, 'below')

    ### adding a slider ###
    slider = Slider(start=0, end=len(sources) - 1, value=0, step=1, title="Comparison")
    comparison_source = ColumnDataSource(data=dict(comparisons=[str(c) for c in comparisons]))
    callback = CustomJS(args=dict(source=source, slider=slider, sources=sources, plot=plot_emb,
                                  comparison_source=comparison_source, description_source=description_source,
                                  description_label=description_label), code="""
        const i = slider.value;
        const new_data = sources[i].data;
        const copied_data = {};
        for (const key in new_data) {
            copied_data[key] = [...new_data[key]];
        }
        source.data = copied_data;

        // Update legend description
        const descriptions = description_source.data['descriptions'];
        description_label.text = descriptions[i];

        source.change.emit();
    """)
    slider.js_on_change('value', callback)
    show(column(slider, plot_emb))

    ### can either show or save the plot (cannot do both)

    ### for isolated individual time point images ###
    frame_dir = os.path.join(os.getcwd(), image_dir)
    os.makedirs(frame_dir, exist_ok=True)
    # Track exactly the frames written by THIS call. Listing the directory
    # would also pick up stale PNGs from earlier runs, and lexicographic
    # sorting would misorder frames once there are 100 or more.
    frame_paths = []
    for i, label in enumerate(comparisons):
        source.data = dict(sources[i].data)
        description_label.text = legend_descriptions[i]
        frame_path = os.path.join(frame_dir, f"frame_{i:02d}_{label}.png")
        export_png(plot_emb, filename=frame_path)
        frame_paths.append(frame_path)

    output_dir = os.environ.get("APP_STATIC", "/app/static")
    os.makedirs(output_dir, exist_ok=True)
    new_gif_filename = os.path.join(output_dir, gif_filename)

    images = [Image.open(frame) for frame in frame_paths]
    try:
        images[0].save(new_gif_filename, save_all=True, append_images=images[1:], duration=1500, loop=0)
    finally:
        # Release the file handles held by the opened frames.
        for image in images:
            image.close()

    display(FileLink(new_gif_filename, result_html_prefix="Click here to download: "))
    print("GIF of UMAP plot was successfully created.")

    return plot_emb, source

In [None]:
# Download the TF -> target gene library (GMT format) used to build the UMAP.
# A timeout and an explicit status check make a failed download surface as a
# clear HTTP error instead of a confusing parse failure further down.
r = requests.get("https://minio.dev.maayanlab.cloud/hgrn-chear/network_target_sets.gmt", timeout=60)
r.raise_for_status()
file = r.text.split("\n")

# Each non-empty line is "<term>\t\t<gene>[,<extra>]\t<gene>[,<extra>]...";
# keep only the gene symbol in front of the first comma.
lib_dict = OrderedDict()
for line in file:
    # Skip blank lines (including the empty piece after a trailing newline)
    # rather than assuming the payload ends with exactly one newline.
    if not line.strip():
        continue
    tokens = line.split("\t\t")
    term = tokens[0]
    genes = [x.split(',')[0].strip() for x in tokens[1].split('\t')]
    lib_dict[term] = ' '.join(genes)

## defaults: nneighbors=30, mindist=0.1, spread=1.0, maxdf=1.0, mindf=1
scatter_df = process_tf_data(
    lib_dict,
    nneighbors=20,
    mindist=0.15,
)

legend_description = ("Blue dots are TFs enriched for upregulated DEGs.\n"
    "Red dots are TFs enriched for downregulated DEGs.\n"
    "Green dots are TFs enriched for both up- and downregulated DEGs.\n")
legend_description += umap_desc

# DESeq2-based DEGs (methods '1' and '3') are always Figures 3.1 / 3.2.
if compute_degs and dge_method in ('1', '3'):
    deseq2_plot_emb_1, deseq2_source_1 = get_scatterplot(scatter_df, deseq2_tf_time_dict_1, deseq2_adj_time_pt_comparisons, "umap_png_frames_deseq2_adjacent_time_pts", f"top_{num_tfs}_tfs_deseq2_adjacent_time_pts_umap.gif")
    display(Markdown(f"##### *__Figure 3.1__: UMAP of the enriched TFs determined from comparing gene expression at adjacent time points (DEGs computed using PyDESeq2). The top {num_tfs} TFs for the up- and downregulated DEGs are shown in blue and red, respectively, on the UMAP plot. Green TFs (if any) are enriched in both the up and down gene sets.*"))
    deseq2_plot_emb_2, deseq2_source_2 = get_scatterplot(scatter_df, deseq2_tf_time_dict_2, deseq2_time_pt_0_comparisons, "umap_png_frames_deseq2_compare_w_time_pt_0", f"top_{num_tfs}_tfs_deseq2_compare_w_time_pt_0_umap.gif")
    display(Markdown(f"##### *__Figure 3.2__: UMAP of the enriched TFs determined from comparing gene expression at each time point to time point 0 (DEGs computed using PyDESeq2). The top {num_tfs} TFs for the up- and downregulated DEGs are shown in blue and red, respectively, on the UMAP plot. Green TFs (if any) are enriched in both the up and down gene sets.*"))

# CD-based DEGs (methods '2' and '3'): the figures are numbered 3.1 / 3.2 when
# CD is the only method, and 3.3 / 3.4 when they follow the DESeq2 figures.
if compute_degs and dge_method in ('2', '3'):
    fig_nums = ('3.1', '3.2') if dge_method == '2' else ('3.3', '3.4')
    cd_plot_emb_1, cd_source_1 = get_scatterplot(scatter_df, cd_tf_time_dict_1, cd_adj_time_pt_comparisons, "umap_png_frames_cd_adjacent_time_pts", f"top_{num_tfs}_tfs_cd_adjacent_time_pts_umap.gif")
    display(Markdown(f"##### *__Figure {fig_nums[0]}__: UMAP of the enriched TFs determined from comparing gene expression at adjacent time points (DEGs computed using CD). The top {num_tfs} TFs for the up- and downregulated DEGs are shown in blue and red, respectively, on the UMAP plot. Green TFs (if any) are enriched in both the up and down gene sets.*"))
    cd_plot_emb_2, cd_source_2 = get_scatterplot(scatter_df, cd_tf_time_dict_2, cd_time_pt_0_comparisons, "umap_png_frames_cd_compare_w_time_pt_0", f"top_{num_tfs}_tfs_cd_compare_w_time_pt_0_umap.gif")
    display(Markdown(f"##### *__Figure {fig_nums[1]}__: UMAP of the enriched TFs determined from comparing gene expression at each time point to time point 0 (DEGs computed using CD). The top {num_tfs} TFs for the up- and downregulated DEGs are shown in blue and red, respectively, on the UMAP plot. Green TFs (if any) are enriched in both the up and down gene sets.*"))

# DEGs not computed by this notebook: plot whichever comparison sets exist.
if not compute_degs:
    if ds3:
        plot_emb_1, source_1 = get_scatterplot(scatter_df, tf_time_dict_1, comparisons_1, "umap_png_frames_deseq2_adjacent_time_pts", f"top_{num_tfs}_tfs_deseq2_adjacent_time_pts_umap.gif")
        display(Markdown(f"##### *__Figure 3.1__: UMAP of the enriched TFs determined from comparing gene expression at adjacent time points. The top {num_tfs} TFs for the up- and downregulated DEGs are shown in blue and red, respectively, on the UMAP plot. Green TFs (if any) are enriched in both the up and down gene sets.*"))

    if ds4:
        plot_emb_2, source_2 = get_scatterplot(scatter_df, tf_time_dict_2, comparisons_2, "umap_png_frames_deseq2_compare_w_time_pt_0", f"top_{num_tfs}_tfs_deseq2_compare_w_time_pt_0_umap.gif")
        display(Markdown(f"##### *__Figure 3.2__: UMAP of the enriched TFs determined from comparing gene expression at each time point to time point 0. The top {num_tfs} TFs for the up- and downregulated DEGs are shown in blue and red, respectively, on the UMAP plot. Green TFs (if any) are enriched in both the up and down gene sets.*"))