# Building an edge list for a TF-TF network using RummaGEO gene sets 
### This notebook explains how to produce an edge list of interactions between human TFs for ingestion into the Knowledge Graph UI (https://github.com/MaayanLab/Gene-Knowledge-Graph)

In [1]:
import pyenrichr as py
import os
import numpy as np
import pandas as pd
import tqdm
import warnings
import json
import csv
import statistics as stats
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from collections import defaultdict

In [2]:
raw_data = './raw_data'
output = './build_edge_list_out'

## TF enrichment analysis of RummaGEO sets
### For each gene set, produce a ranked list of all of 1,632 ChEA3 transcription factors
TF enrichment analysis for each of the 171k RummaGEO gene sets is performed using a local version of ChEA3.

*For more information:* 
> RummaGEO: https://rummageo.com/  \
> ChEA3: https://maayanlab.cloud/chea3/

 

In [None]:

# read a GMT file and save to a dictionary

def read_gmt1(gmt_file_path):
    gene_sets = {}
    
    with open(gmt_file_path, 'r') as file:
        for line in file:
            parts = line.strip().split('\t')
            if len(parts) > 5:
                gene_sets[parts[0].strip()] = set(parts[1:])
    return gene_sets


# divide a dictionary into N subdictionaries

def split_dict(original_dict, N):
    iter_dict = iter(original_dict.items())
    list_of_dicts = []
    
    while True:
        small_dict = {}
        
        try:
            for _ in range(N):
                key, value = next(iter_dict)
                small_dict[key] = value
        except StopIteration:
            if small_dict:
                list_of_dicts.append(small_dict)
            break

        list_of_dicts.append(small_dict)
    
    return list_of_dicts


# determine the mean rank of a TF across all ChEA3 libraries

def mean_rank(results):
    if results:
        # Extract gene set names from the first result
        sigs = list(results[next(iter(results))].columns)

    # Extract unique transcription factor (tf) names from all indices
    tfs = list(set(tf for result in results.values() for tf in result.index))

    # Initialize arrays for scores and counts
    tf_scores = np.zeros((len(tfs), len(sigs)))
    tf_counts = np.zeros((len(tfs), len(sigs)))

    # Create a mapping from transcription factor names to indices
    tf_index_map = {t: idx for idx, t in enumerate(tfs)}

    # Aggregate scores and counts
    for l, result in results.items():
        for t in result.index:
            temp = result.loc[t]
            idx = tf_index_map[t]
            if len(temp.shape) == 1:
                tf_scores[idx, :] += temp.values
                tf_counts[idx, :] += 1
            else:
                for ii in range(temp.shape[0]):
                    tf_scores[idx, :] += temp.iloc[ii, :].values
                    tf_counts[idx, :] += 1

    # Calculate the mean scores by dividing tf_scores by tf_counts element-wise
    mean_scores = pd.DataFrame(np.round(np.divide(tf_scores, tf_counts, out=np.zeros_like(tf_scores), where=tf_counts != 0)), index=tfs, columns=sigs)

    return mean_scores

Import RummaGEO data and split into 1k subsets

In [None]:
rummageo = f"{raw_data}/human-geo-auto.gmt"
geo = read_gmt1(rummageo)
geo_split = split_dict(geo, 1000)

For each gene set, rank TFs based on enrichment and record the mean rank across ChEA3 libraries.\
**This step can take several hours.**

In [None]:
warnings.simplefilter("ignore")

fisher = py.enrichment.FastFisher(34000)

# Pre-read the GMT files and store the results in a dictionary
libraries = {}
for lib in os.listdir("raw_data/chea3libs"):
    libraries[lib] = read_gmt1(f"raw_data/chea3libs/{lib}")

# Process each item in geo_split
mranks = []

for i in tqdm.tqdm(range(len(geo_split))):
    results = {}
    for lib_name, lib_gmt in libraries.items():
        # Use the pre-read library data
        res = py.enrichment.fisher(geo_split[i], lib_gmt, min_set_size=10, verbose=False, fisher=fisher)
        temp = py.enrichment.consolidate(res).rank(axis=0)
        temp.index = [x.split("_")[0] for x in temp.index]
        results[lib_name] = temp
    
    mr = mean_rank(results)
    mranks.append(mr)

Save mean ranks matrix

In [None]:
mranks = pd.concat(mranks, axis=1)
mranks.to_csv(f"{output}/mean_ranks_matrix.csv")

## Preliminary filtering
* Retain only the **top 10** highly ranked TFs for each set
* Retain only GSE studies with clear **control and perturbation** groups


In [8]:
# upload TF rank matrix

def preprocess(path):
  matrix = pd.read_csv(path)
  matrix.index = matrix.iloc[:, 0]
  matrix = matrix.drop(columns = matrix.columns[0])
  matrix = matrix.astype(np.int64)
  return matrix


# produce a library of signatures and the most highly enriched TFs (rank <= 10)

def filter_by_rank(matrix, rank_method = "min", threshold = 15):

  max_rank = 10 

  results = {Signature : None for Signature in matrix.columns}

  for Signature in matrix.columns:
    # this is ranking the index based on the TF rankings, so then the index is used to access the TF name
    rank_index = matrix[Signature].rank(method = rank_method).astype(int)
    rank_top_tfs = [rank <= max_rank for rank in rank_index]

    # Remove outliers above specified threshold
    if len(matrix.index[rank_top_tfs]) <= threshold:
      results[Signature] = matrix.index[rank_top_tfs].to_list()
    else:
      del results[Signature]
  
  # returns a matrix of gene set names x top 30 transcription factors (basically transposed GMT)
  return results


In [9]:
mean_ranks_matrix = preprocess(f"{output}/mean_ranks_matrix.csv")
high_rank_matrix = filter_by_rank(mean_ranks_matrix)

Now filter out GSE studies that do not clearly include control and perturbation groups

In [7]:
file_to_keep = f"{raw_data}/single_perturbation_gses.txt"
gses_to_keep = []
filtered_matrix = {} # make a copy to be safe 

with open(file_to_keep, 'r') as file:
    gses_to_keep = [line.strip() for line in file]

for sig in high_rank_matrix.keys():
    gse_tag = sig.split("-")[0]
    if gse_tag in gses_to_keep:
        filtered_matrix[sig] = high_rank_matrix[sig]

Save output as `.gmt` file

In [14]:
def write_gmt(path, tf_library):
  with open(path, "w") as file:
    for signature, geneset in tqdm.tqdm(tf_library.items()):
      file.write(f"{signature}\t\t")
      for gene in geneset:
        file.write(f"{gene}\t")
      file.write("\n")

write_gmt(f"{output}/filtered_tfsets.gmt", filtered_matrix)

100%|██████████| 29330/29330 [00:00<00:00, 427097.18it/s]


## Create TF-TF edge list 
### Count the number of times each TF-TF interaction occurs.
* Create the transpose of the GMT  where each line is a TF followed by the gene sets in which it is ranked (`transpose_human.gmt`)
* For a given highly ranked TF, find all sets where it is present. Mark a source-target edge with any TFs present in the same set.\
* **Directionality** of the edge is determined by the sign (up/down) of the gene set where the relationship occurs. 

Upload high-enrichment TF GMT and create transpose

In [8]:
# Read a GMT file formatted as "signature id tf1 tf2 tf3 ... tf30"

def read_gmt2(path):
  gmt = {}
  print("Reading {}".format(path))
  with open(path, "r") as file:
    for line in tqdm.tqdm(file):
      # 'id' refers to up or down tag
      signature, id, *tf = line.strip().split()
      gmt[" ".join([signature, id])] = set(tf)

  return gmt

def gmt_transpose(library):
  tfs = list(set([tf for sublist in library.values() for tf in sublist]))
  transpose = {}
  for tf in tqdm.tqdm(tfs):
    collect = list()
    for signature in library:
      if tf in library[signature]:
        collect.append(signature)
    transpose[tf] = set(collect)
  return dict(transpose)


In [9]:
tf_sets = read_gmt2(f"{output}/filtered_tfsets.gmt")
tf_transpose = gmt_transpose(tf_sets)

Reading ./build_network_out/filtered_tfsets.gmt


29330it [00:00, 522036.46it/s]
100%|██████████| 1588/1588 [00:02<00:00, 529.83it/s]


Save file

In [10]:
def write_transpose_gmt(library, path):
  with open(path, "w") as file:
    for tf, signatures in tqdm.tqdm(library.items()):
      file.write(f"{tf}\t\t")
      for signature in signatures:
        signature_join = "-".join(signature.split())
        file.write(f"{signature_join}\t")
      file.write("\n")

write_transpose_gmt(tf_transpose, f"{output}/tf_transpose.gmt")

100%|██████████| 1588/1588 [00:00<00:00, 8334.78it/s]


In [6]:
def read_transpose(path):
  gmt = {}
  print("Reading {}".format(path))
  with open(path, "r") as file:
    for line in tqdm.tqdm(file):
        tf, *signature = line.strip().split()
        gmt[tf] = set(signature)
  return gmt

In [7]:
geo_gmt = read_gmt2("raw_data/human-geo-auto.gmt")

Reading raw_data/human-geo-auto.gmt


171441it [00:36, 4667.17it/s] 


Reading ./build_network_out/tf_transpose.gmt


1588it [00:00, 23971.25it/s]


Store counts in a nested dictionary\
Dimensions: source_tf --> target_tf --> direction (up/down)

In [None]:
high_tfs = list(tf_transpose.keys())

# Create dict
edge_matrix = {source : {target : {
  "+": 0,
  "-": 0
} for target in high_tfs} for source in high_tfs}

# Calculate counts
for source in tqdm.tqdm(high_tfs):
  for signature in tf_transpose[source]:

    spl = signature.rsplit("-", 1)
    dir = "+" if spl[1] == "up" else "-" # "dn"
    joined_sig = " ".join(spl)

    # if the gene set name is in the keys for the rummaGEO GMT
    if joined_sig in geo_gmt.keys():
      
      # for each TF in the list of all TFs
      for target in high_tfs:
        
        if target in geo_gmt[joined_sig]: # TF is in list of DEGs
          edge_matrix[source][target][dir] += 1
   
    else:
      raise Exception("Signature {} not found".format(joined_sig))
    


100%|██████████| 1588/1588 [00:32<00:00, 48.22it/s]


Number below ten: 
	up: 2333187
	dn: 2337917
	with both below ten: 2298226


100%|██████████| 1588/1588 [00:36<00:00, 43.55it/s]


In [32]:
# Count the number of True values in the dictionary
below_ten_up_counts = 0
below_ten_dn_counts = 0

filtered_edge_matrix = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

# iterate through each source, filtering edges and counting how many are removed
for source, target_list in edge_matrix.items():

    # for each target find the counts
    for target, count_dict in target_list.items():
        for direction,count in count_dict.items():
            # determine which counts are below 10
            if direction == '+':
                if count < 10:
                    below_ten_up_counts += 1
                else:
                    filtered_edge_matrix[source][target][direction] = count
            if direction == '-':
                if count < 10:
                    below_ten_dn_counts += 1
                else:
                    filtered_edge_matrix[source][target][direction] = count


# sum the length of each dictionary for each source in edge-mat, multiply by 2
total_edges = 2 * sum(
    len(targets) for targets in edge_matrix.values() 
)

print(f"Total number of edges: {total_edges}")
print("Number of edges below ten: ")
print(f"\tup: {below_ten_up_counts}")
print(f"\tdn: {below_ten_dn_counts}")
print(f"Sources lost: {len(edge_matrix.keys()) - (len(filtered_edge_matrix.keys()))}")
print(f"Sources remaining: {(len(filtered_edge_matrix.keys()))}")

Total number of edges: 5043488
Number of edges below ten: 
	up: 2333187
	dn: 2337917
Sources lost: 860
Sources remaining: 728


Flatten results into pandas DataFrame object

In [44]:
s = filtered_edge_matrix.keys()
t = set()
[t.update(targs.keys()) for targs in filtered_edge_matrix.values()]

targets = list(t)

hindex = pd.MultiIndex.from_product([s, t, ["+", "-"]],
  names = ["source", "target", "direction"])
df = pd.DataFrame(index = hindex, columns = ["count"])

for source, targets in tqdm.tqdm(filtered_edge_matrix.items()):
  for target, data in targets.items():
    for dir, count in data.items():
      df.loc[(source, target, dir)] = count

  0%|          | 0/728 [00:00<?, ?it/s]

100%|██████████| 728/728 [00:06<00:00, 115.19it/s]


Remove edges with empty counts and save results as a CSV file

In [45]:
df_cleaned = df.dropna(subset=['count'])

df_cleaned.to_csv(f"{output}/edge_list_unfiltered.csv")

## Optional - visualize count distributions
### Generates three log-scale histograms of TF-TF interaction counts: all counts, negative counts, and positive counts

In [1]:
# Plot log-scaled histogram of counts

def tf_histogram(name, counts, num_bins=300, fig_size=(10,6)):
    plt.figure(figsize=fig_size)
    plt.hist(counts, bins = num_bins, edgecolor ='none')
    plt.yscale('log')
    plt.xlabel('Interaction count', fontsize=14)
    plt.ylabel('Frequency', fontsize=14)
    if name == "all":
        plt.title('All TF-TF interaction counts', fontsize=16)
    elif name == "pos":
        plt.title('Positive interaction counts', fontsize=16)
    elif name == "neg": 
        plt.title('Negative interaction counts', fontsize=16)


# get count values
all_counts = df_cleaned['count'].dropna()
# Filter the DataFrame for rows where the direction is "+"
positive_counts = df_cleaned.loc[(slice(None), slice(None), "+"), :]['count']
# Filter the DataFrame for rows where the direction is "+"
negative_counts = (df_cleaned.loc[(slice(None), slice(None), "-"), :])['count']

tf_histogram("all", all_counts)
plt.savefig(f"{output}/img/all_counts_histo.png")
plt.show()

tf_histogram("pos", positive_counts)
plt.savefig(f"{output}/img/pos_counts_histo.png")
plt.show()

tf_histogram("neg", negative_counts)
plt.savefig(f"{output}/img/neg_counts_histo.png")
plt.show()

NameError: name 'df_cleaned' is not defined