# Building the TF-TF network edge list
This is the first of three notebooks required to build a TF-TF interaction network compatible with the KG UI. Here, we build an initial edge list of TF-TF interactions by performing ChEA3 transcription factor enrichment analysis (TFEA) on RummaGEO gene sets. 

After finishing this notebook, users should filter the network using `filter_assertions.ipynb`. 
## Set-Up 
First, import required packages. 

In [4]:
import pyenrichr as py
import os
import numpy as np
import pandas as pd
import tqdm
import warnings
from matplotlib import pyplot as plt
from collections import defaultdict

Now define our raw data and output directories, and create them if they don't already exist: 
1. `raw_data` holds all raw data files. We'll also add one subdirectory for the TFEA libraries, `chea3libs`
2. `edge_constructing_files` will store our output files. 

In [5]:
raw_data = './raw_data'
output = './edge_constructing_files'

if not os.path.exists(raw_data):
    os.mkdir(raw_data)
    os.mkdir(f'{raw_data}/chea3libs')
if not os.path.exists(output):
    os.mkdir(output)

## TFEA using RummaGEO sets
### Downloading required input data
TF enrichment analysis for each of the 171k RummaGEO gene sets is performed using a local version of ChEA3.
This requires files from two locations:
1. RummaGEO:
- Navigate to https://rummageo.com/, then click on downloads. Click on `human-geo-auto.gmt.gz` to download all human gene sets.
- Unzip the file and move it to `raw_data`
2. ChEA3: 
- Navigate to https://maayanlab.cloud/chea3/. Click on 'downloads' in the top right
- Download the following six files (all of the files with file type `primary`) by clicking on their names:
    1. ARCHS4_Coexpression.gmt
    2. ENCODE_ChIP-seq.gmt
    3. Enrichr_Queries.gmt
    4. GTEx_Coexpression.gmt
    5. Literature_ChIP-seq.gmt
    6. ReMap_ChIP-seq.gmt
- Move all six files to `raw_data/chea3libs`
 

### Building the mean ranks matrix
Define functions for uploading a gmt file as a dictionary, and to divide a large dictionary into N smaller subdictionaries (this will help us process the RummaGEO file). 

In [11]:
# read a GMT file into a dictionary: this is formatted depending on the gmt being uploaded, specified by 'mode' 
def read_gmt_file(path, mode):
  gmt = {}
  print("Reading {}".format(path))
  with open(path, "r") as file:
    # Read a tab-delimited GMT formatted as "term gene1 gene2 ... geneN" and check for set size
    if mode == 'raw':
      for line in tqdm.tqdm(file):
        parts = line.strip().split('\t')
        if len(parts) > 5:
          gmt[parts[0].strip()] = set(parts[1:])

    # Read a GMT file formatted as "signature id tf1 tf2 tf3 ... tf10"
    elif mode == 'signed':
      for line in tqdm.tqdm(file):
        signature, id, *tf = line.strip().split()
        gmt[" ".join([signature, id])] = set(tf)

    # similar to mode = 'raw'
    elif mode == 'transpose':
      for line in tqdm.tqdm(file):
        tf, *signature = line.strip().split()
        gmt[tf] = set(signature)
           
  return gmt

# divides a dictionary into N subdictionaries
def split_dict(original_dict, N):
    iter_dict = iter(original_dict.items())
    list_of_dicts = []
    
    while True:
        small_dict = {}
        
        try:
            for _ in range(N):
                key, value = next(iter_dict)
                small_dict[key] = value
        except StopIteration:
            if small_dict:
                list_of_dicts.append(small_dict)
            break

        list_of_dicts.append(small_dict)
    
    return list_of_dicts

Import RummaGEO data and split into subdictionaries, to help with processing.

In [4]:
rummageo = f"{raw_data}/human-geo-auto.gmt"
n_subdicts = 1000
geo = read_gmt_file(rummageo, mode='raw')
geo_split = split_dict(geo, n_subdicts)

Reading ./raw_data/human-geo-auto.gmt


171441it [00:29, 5889.63it/s] 


For each gene set, rank TFs based on enrichment and record the mean rank across ChEA3 libraries.\
**This step can take several hours.**

In [5]:
# determine the mean rank of a TF across all ChEA3 libraries
def mean_rank(results):
    if results:
        # Extract gene set names from the first result
        sigs = list(results[next(iter(results))].columns)

    # Extract unique transcription factor (tf) names from all indices
    tfs = list(set(tf for result in results.values() for tf in result.index))

    # Initialize arrays for scores and counts
    tf_scores = np.zeros((len(tfs), len(sigs)))
    tf_counts = np.zeros((len(tfs), len(sigs)))

    # Create a mapping from transcription factor names to indices
    tf_index_map = {t: idx for idx, t in enumerate(tfs)}

    # Aggregate scores and counts
    for l, result in results.items():
        for t in result.index:
            temp = result.loc[t]
            idx = tf_index_map[t]
            if len(temp.shape) == 1:
                tf_scores[idx, :] += temp.values
                tf_counts[idx, :] += 1
            else:
                for ii in range(temp.shape[0]):
                    tf_scores[idx, :] += temp.iloc[ii, :].values
                    tf_counts[idx, :] += 1

    # Calculate the mean scores by dividing tf_scores by tf_counts element-wise
    mean_scores = pd.DataFrame(np.round(np.divide(tf_scores, tf_counts, out=np.zeros_like(tf_scores), where=tf_counts != 0)), index=tfs, columns=sigs)

    return mean_scores

In [6]:
warnings.simplefilter("ignore")

fisher = py.enrichment.FastFisher(34000)

# Pre-read the GMT files and store the results in a dictionary
libraries = {}
for lib in os.listdir(f"{raw_data}/chea3libs"):
    libraries[lib] = read_gmt_file(f"{raw_data}/chea3libs/{lib}", mode='raw')

# Process each item in geo_split
mranks = []

for i in tqdm.tqdm(range(len(geo_split))):
    results = {}
    for lib_name, lib_gmt in libraries.items():
        # Use the pre-read library data
        res = py.enrichment.fisher(geo_split[i], lib_gmt, min_set_size=10, verbose=False, fisher=fisher)
        temp = py.enrichment.consolidate(res).rank(axis=0)
        temp.index = [x.split("_")[0] for x in temp.index]
        results[lib_name] = temp
    
    mr = mean_rank(results)
    mranks.append(mr)

Reading ./raw_data/chea3libs/ARCHS4_Coexpression.gmt


1628it [00:00, 35693.78it/s]


Reading ./raw_data/chea3libs/ReMap_ChIP-seq.gmt


297it [00:00, 4728.17it/s]


Reading ./raw_data/chea3libs/Enrichr_Queries.gmt


1404it [00:00, 26203.79it/s]


Reading ./raw_data/chea3libs/GTEx_Coexpression.gmt


1607it [00:00, 21202.54it/s]


Reading ./raw_data/chea3libs/ENCODE_ChIP-seq.gmt


552it [00:00, 7549.76it/s]


Reading ./raw_data/chea3libs/Literature_ChIP-seq.gmt


307it [00:00, 5293.66it/s]
100%|██████████| 172/172 [6:23:24<00:00, 133.75s/it]  


Save the mean ranks matrix to a file. 

In [7]:
mranks = pd.concat(mranks, axis=1)
mranks.to_csv(f"{output}/mean_ranks_matrix.csv")

### Preliminary pruning
Working from the mean ranks matrix, we perform two preliminary pruning steps:

1. Remove all but the **top 10** most highly ranked TFs for each signature, which are most likely to be regulators of the genes in their respective signatures


In [1]:
# upload TF rank matrix
def preprocess(path):
  matrix = pd.read_csv(path)
  matrix.index = matrix.iloc[:, 0]
  matrix = matrix.drop(columns = matrix.columns[0])
  matrix = matrix.astype(np.int64)
  return matrix

# produces a library of signatures and the most highly enriched TFs (rank <= 10)
def filter_by_rank(matrix, rank_method = "min", threshold = 15):

  max_rank = 10 

  results = {Signature : None for Signature in matrix.columns}

  for Signature in matrix.columns:
    # this is ranking the index based on the TF rankings, so then the index is used to access the TF name
    rank_index = matrix[Signature].rank(method = rank_method).astype(int)
    rank_top_tfs = [rank <= max_rank for rank in rank_index]

    # Remove outliers above specified threshold
    if len(matrix.index[rank_top_tfs]) <= threshold:
      results[Signature] = matrix.index[rank_top_tfs].to_list()
    else:
      del results[Signature]
  
  # returns a matrix of gene set names x top 30 transcription factors (basically transposed GMT)
  return results

In [6]:
mean_ranks_matrix = preprocess(f"{output}/mean_ranks_matrix.csv")
high_rank_matrix = filter_by_rank(mean_ranks_matrix)

2. Remove any GSE studies without clear **control and perturbation** groups, which ambiguates the sign of the gene set


In [7]:
gses_to_keep_f = f"{raw_data}/single_perturbation_gses.txt"
gses_to_keep = []
filtered_matrix = {} # make a copy to be safe 

with open(gses_to_keep_f, 'r') as file:
    gses_to_keep = [line.strip() for line in file]

for sig in high_rank_matrix.keys():
    gse_tag = sig.split("-")[0].strip()
    if gse_tag in gses_to_keep:
        filtered_matrix[sig] = high_rank_matrix[sig]

print(len(filtered_matrix))

29328


Save output as `.gmt` file

In [8]:
def write_gmt(path, library, transpose=False):
  with open(path, "w") as file:
    for term, items in tqdm.tqdm(library.items()):
      file.write(f"{term}\t\t")
      for item in items:
        if transpose:
          signature_join = "-".join(item.split())
          file.write(f"{signature_join}\t")
        else:
          file.write(f"{item}\t")
      file.write("\n")

write_gmt(f"{output}/filtered_tfsets.gmt", filtered_matrix)

100%|██████████| 29328/29328 [00:00<00:00, 687225.12it/s]


## Creating the TF-TF edge list 
### Count the number of times each TF-TF interaction occurs.
* Create the transpose of the GMT  where each line is a TF followed by the gene sets in which it is ranked (`transpose_human.gmt`)
* For a given highly ranked TF, find all sets where it is present. Mark a source-target edge with any TFs present in the same set.\
* **Directionality** of the edge is determined by the sign (up/down) of the gene set where the relationship occurs. 

Create the transpose matrix of a GMT.

In [9]:
def gmt_transpose(library):
  tfs = list(set([tf for sublist in library.values() for tf in sublist]))
  transpose = {}
  for tf in tqdm.tqdm(tfs):
    collect = []
    for signature in library:
      if tf in library[signature]:
        collect.append(signature)
    transpose[tf] = set(collect)
  return dict(transpose)

In [12]:
tf_sets = read_gmt_file(f"{output}/filtered_tfsets.gmt", mode='signed')
tf_transpose = gmt_transpose(tf_sets)

Reading ./edge_constructing_files/filtered_tfsets.gmt


29328it [00:00, 242532.51it/s]
100%|██████████| 1588/1588 [00:02<00:00, 656.06it/s]


Save file

In [13]:
write_gmt(f"{output}/tf_transpose.gmt", tf_transpose, transpose=True)

100%|██████████| 1588/1588 [00:00<00:00, 11181.06it/s]


Store counts in a nested dictionary\
Dimensions: source_tf --> target_tf --> direction (up/down)

In [14]:
geo_gmt = read_gmt_file("raw_data/human-geo-auto.gmt", mode='raw')

Reading raw_data/human-geo-auto.gmt


171441it [00:19, 8748.74it/s] 


In [15]:
high_rank_tfs = list(tf_transpose.keys())

# Create dict
edge_matrix = {source : {target : {
  "+": 0,
  "-": 0
} for target in high_rank_tfs} for source in high_rank_tfs}

# Calculate counts
for source in tqdm.tqdm(high_rank_tfs):
  for signature in tf_transpose[source]:

    spl = signature.rsplit(" ", 1)
    dir = "+" if spl[1] == "up" else "-" # "dn"
    joined_sig = " ".join(spl)

    # if the gene set name is in the keys for the rummaGEO GMT
    if joined_sig in geo_gmt.keys():
      
      # for each TF in the list of all TFs
      for target in high_rank_tfs:
        
        if target in geo_gmt[joined_sig]: # TF is in list of DEGs
          edge_matrix[source][target][dir] += 1
   
    else:
      raise Exception("Signature {} not found".format(joined_sig))

100%|██████████| 1588/1588 [00:36<00:00, 44.04it/s]


Flatten results into pandas DataFrame object. Only save edges with count > 10.

In [16]:
s = edge_matrix.keys()
t = {target for targets in edge_matrix.values() for target in targets.keys()}
targets = list(t)

hindex = pd.MultiIndex.from_product([s, t, ["+", "-"]],
  names = ["source", "target", "direction"])

filtered_mat = []

for source, targets in tqdm.tqdm(edge_matrix.items()):
  for target, data in targets.items():
    for direction, count in data.items():
      if count >= 10:
        filtered_mat.append((source, target, direction, count))

df = pd.DataFrame(filtered_mat, columns =['source', 'target', 'direction', 'count'])
df.set_index(['source', 'target', 'direction'], inplace=True)

100%|██████████| 1588/1588 [00:00<00:00, 2444.82it/s]


Save results as a CSV file

In [17]:
df.to_csv(f"{output}/edge_list_unfiltered.csv")

## Optional - visualize count distributions
### Generates three log-scale histograms of TF-TF interaction counts: all counts, negative counts, and positive counts

In [None]:
# Plot log-scaled histogram of counts

def tf_histogram(name, counts, num_bins=300, fig_size=(10,6)):
    plt.figure(figsize=fig_size)
    plt.hist(counts, bins = num_bins, edgecolor ='none')
    plt.yscale('log')
    plt.xlabel('Interaction count', fontsize=14)
    plt.ylabel('Frequency', fontsize=14)
    if name == "all":
        plt.title('All TF-TF interaction counts', fontsize=16)
    elif name == "pos":
        plt.title('Positive interaction counts', fontsize=16)
    elif name == "neg": 
        plt.title('Negative interaction counts', fontsize=16)


# get count values
all_counts = df['count'].dropna()
# Filter the DataFrame for rows where the direction is "+"
positive_counts = df.loc[(slice(None), slice(None), "+"), :]['count']
# Filter the DataFrame for rows where the direction is "+"
negative_counts = (df.loc[(slice(None), slice(None), "-"), :])['count']

tf_histogram("all", all_counts)
plt.savefig(f"{output}/img/all_counts_histo.png")
plt.show()

tf_histogram("pos", positive_counts)
plt.savefig(f"{output}/img/pos_counts_histo.png")
plt.show()

tf_histogram("neg", negative_counts)
plt.savefig(f"{output}/img/neg_counts_histo.png")
plt.show()