# Filter TF network edges by expected counts
### Describes a regimen to filter human TF interactions in the edge list produced by `build_edge_list.ipynb`
### The second part of this notebook formats the edge list into neo4j assertions

In [1]:
import csv
from collections import defaultdict
import numpy as np
import json
import pandas as pd
import random
import statistics 
import scipy.stats as stats
import tqdm
from itertools import islice
from matplotlib import pyplot as plt

In [2]:
# Directory layout shared by the cells below.
# NOTE(review): `input` shadows the Python builtin of the same name; it is kept
# because later cells build paths from it.
input = './build_network_out'                 # holds tf_transpose.gmt
output = './filter_assertions_out'            # holds the unfiltered edge list and img/ outputs
raw_data = './raw_data'                       # holds the RummaGEO GMT file
assertions_dir = './kg_assertions_for_neo4j'
benchmark_folder = f"{output}/benchmarking"   # per-method statistics and filtered edge lists

Load the unfiltered edge list into a DataFrame and index it by `(source, target, direction)`.

In [3]:
# Read the unfiltered edge list: one row per (source, target, direction) with a `count` column.
# NOTE(review): the file is read from `output`, not `input` -- confirm this is where
# build_edge_list.ipynb writes it.
initial_file = f"{output}/unfiltered_edge_list.csv"
edge_index_levels = ["source", "target", "direction"]

unindexed_counts = pd.read_csv(initial_file)

# Unique node names, captured before the columns are moved into the index.
initial_sources = unindexed_counts['source'].unique()
initial_targets = unindexed_counts['target'].unique()

# Index by (source, target, direction) so individual edges can be looked up with .loc.
initial_counts = unindexed_counts.set_index(edge_index_levels)

## Test expected counts methods
#### Choose an expected counts method, then proceed to filter results

### Method 1: Shuffle signature sets between source TFs and recalculate counts 

Upload signature sets which are enriched by each TF. Produces a dictionary formatted as `tf:[signature1 ... signatureN]`

In [None]:
benchmark_method = 'signature_shuffling'
N_TRIALS = 50

# Read tf_transpose.gmt into a mapping of tf -> set of signature names.
# Each non-empty line is whitespace-delimited: the first token is the TF, the
# remaining tokens are the signatures listed for that TF.
signature_file = f"{input}/tf_transpose.gmt"
signature_sets = defaultdict(set)


with open(signature_file, "r") as file:
    for line in tqdm.tqdm(file):
        fields = line.split()
        if not fields:
            # A blank (e.g. trailing) line has no TF token; skip it instead of
            # failing on the unpacking below.
            continue
        tf, *signature = fields
        # A TF that appears on several lines keeps only its last line's signatures.
        signature_sets[tf] = set(signature)
    

# now shuffle sets between terms
def shuffle_signatures(signature_sets):
    """Return a new dict that randomly reassigns the signature sets among the TFs.

    The keys (TFs) keep their original order; the values are the same set objects
    as in `signature_sets`, permuted with `random.shuffle`. The input mapping is
    not modified.
    """
    permuted_sets = list(signature_sets.values())
    random.shuffle(permuted_sets)

    return {tf: sigs for tf, sigs in zip(signature_sets, permuted_sets)}

Upload RummaGEO gene sets to produce counts

In [None]:
def read_gmt2(path):
  """Parse a whitespace-delimited GMT file into a dict of gene sets.

  Each line is split on whitespace: the first token is the signature name, the
  second is its up/down tag, and the remaining tokens are the set members.
  The returned dict maps "<signature> <tag>" to the set of members. A line with
  fewer than two tokens raises ValueError.
  """
  print("Reading {}".format(path))
  gene_sets = {}
  with open(path, "r") as handle:
    for raw_line in tqdm.tqdm(handle):
      signature, direction_tag, *members = raw_line.strip().split()
      gene_sets[f"{signature} {direction_tag}"] = set(members)

  return gene_sets

geo_gmt = read_gmt2(f"{raw_data}/human-geo-auto.gmt")

##### Recalculate counts
Produce an edge list using the randomly shuffled signature sets. Repeat several times and store counts. 

Produces a dictionary formatted as:
```python
{ source: {
    target: {
            '+': np.array(up_counts_all_trials),
            '-': np.array(dn_counts_all_trials),
        }
    }
}
```

In [None]:
# Null distribution for Method 1: edge counts obtained after shuffling the
# signature sets between source TFs, repeated N_TRIALS times.
all_high_tfs = list(signature_sets.keys())
tf_universe = set(all_high_tfs)

# expected_counts[source][target][direction][i] = count observed in trial i
expected_counts = {source : {target : {
    "+": np.zeros(N_TRIALS),
    "-": np.zeros(N_TRIALS),
  } for target in all_high_tfs} for source in all_high_tfs}

# A signature's direction and the TFs found among its genes do not depend on
# which source it is assigned to, so resolve them once instead of once per
# trial. Restricting to `tf_universe` replaces the scan over every target.
signature_hits = {}
for signatures in signature_sets.values():
  for signature in signatures:
    if signature in signature_hits:
      continue

    # Signature names end in "-up" or "-dn"; the GMT keys use a space instead.
    spl = signature.rsplit("-", 1)
    if len(spl) != 2:
      raise ValueError("Signature {} has no '-up'/'-dn' suffix".format(signature))
    direction = "+" if spl[1] == "up" else "-" # "dn"
    joined_sig = " ".join(spl)

    if joined_sig not in geo_gmt:
      raise Exception("Signature {} not found".format(joined_sig))

    signature_hits[signature] = (direction, tf_universe.intersection(geo_gmt[joined_sig]))

for i in tqdm.tqdm(range(N_TRIALS)):
  shuffled_sets = shuffle_signatures(signature_sets)

  # Every signature now assigned to `source` contributes one count to each TF
  # that appears in that signature's gene set.
  for source in all_high_tfs:
    counts_for_source = expected_counts[source]
    for signature in shuffled_sets[source]:
      direction, hit_targets = signature_hits[signature]
      for target in hit_targets:
        counts_for_source[target][direction][i] += 1


### Method 2: Randomly generate new edges by randomly selecting (weighted) sources and targets independently   


Generate a random edge list by drawing sources and targets independently, each weighted by its occurrence in the initial dataset.

In [None]:
benchmark_method = 'node_weighted'
n_iterations = 50

# Per-node totals of the observed counts; these are the sampling weights.
source_counts = initial_counts.groupby(['source']).sum()
target_counts = initial_counts.groupby(['target']).sum()

# Labels and weights are taken from the same frame so their order matches.
source_labels, source_weights = source_counts.index.tolist(), source_counts['count'].tolist()
target_labels, target_weights = target_counts.index.tolist(), target_counts['count'].tolist()

# Number of draws per iteration: the total observed count on each side.
num_s_edges = sum(source_weights)
num_t_edges = sum(target_weights)

# expected_counts[source][target][direction][i] = count drawn in iteration i
expected_counts = {}
for source in initial_sources:
    expected_counts[source] = {
        target: {"+": np.zeros(n_iterations), "-": np.zeros(n_iterations)}
        for target in initial_targets
    }

for i in tqdm.tqdm(range(n_iterations)):
    # Draw sources and targets independently, then pair them up positionally.
    random_source = random.choices(source_labels, weights=source_weights, k=num_s_edges)
    random_target = random.choices(target_labels, weights=target_weights, k=num_t_edges)

    if len(random_source) != len(random_target):
        # Only reported, not raised: this iteration's counts stay at zero.
        print('Source length != target length')
        continue

    for source, target in zip(random_source, random_target):
        # Each drawn edge gets a direction chosen uniformly at random.
        sign = "+" if random.randint(0, 1) == 0 else "-"
        expected_counts[source][target][sign][i] += 1
                

In [None]:
# Flatten `expected_counts` into a DataFrame: one row per
# (source, target, direction), one column per iteration.
hindex = pd.MultiIndex.from_product([initial_sources, initial_targets, ["+", "-"]],
  names = ["source", "target", "direction"])
iteration_columns = [f'count_{i}' for i in range(n_iterations)]

# Iterating the index yields (source, target, direction) tuples in row order, so
# the stacked trial arrays line up with `hindex`. Building the matrix once gives
# float columns and avoids a separate .loc write per row. The reshape keeps the
# frame well-formed when there are no rows.
trial_matrix = np.array(
  [expected_counts[source][target][direction] for source, target, direction in hindex],
  dtype=float,
).reshape(len(hindex), n_iterations)

# Ensure the DataFrame index is sorted
df = pd.DataFrame(trial_matrix, index=hindex, columns=iteration_columns).sort_index()

# optional - save random counts
# df.to_csv(f'randomcounts_{benchmark_method}.csv', index = False)

### Method 3: Randomly generate an edge list using weighted pairs 

In [4]:
benchmark_method = 'edge_weighted'
n_iterations = 50

# Sampling weights: observed count of each (source, target) pair, summed over
# directions. They depend only on `initial_counts`, so they are computed once
# here rather than on every call to find_random_edges().
edge_counts = initial_counts.groupby(['source', 'target']).sum()
# labels and weights come from the same frame, so their order matches
edge_labels = edge_counts.index.tolist()
edge_weights = edge_counts['count'].tolist()
# total number of edges counted = number of draws per iteration
num_edges = sum(edge_weights)

def find_random_edges():
    """Draw one randomized edge list.

    Returns a list of `num_edges` (source, target) tuples sampled with
    replacement from the observed pairs, weighted by their observed counts.
    """
    return random.choices(edge_labels, weights=edge_weights, k=num_edges)

# expected_counts[source][target][direction][i] = count drawn in iteration i
expected_counts = {source : {target : {
    "+": np.zeros(n_iterations),
    "-": np.zeros(n_iterations),
  } for target in initial_targets} for source in initial_sources}

for i in tqdm.tqdm(range(n_iterations)):
    for source, target in find_random_edges():
        # each drawn edge gets a direction chosen uniformly at random
        sign = "+" if random.randint(0, 1) == 0 else "-"
        expected_counts[source][target][sign][i] += 1
                

100%|██████████| 50/50 [14:21<00:00, 17.22s/it]


## Filter results using expected counts

Calculate edge statistics - observed counts, expected counts, mean, stdev, z-score, and p-value

In [7]:
# One row per possible (source, target, relation); rows stay NaN unless the edge
# was observed and its randomized counts have non-zero spread.
hindex = pd.MultiIndex.from_product([initial_sources, initial_targets,['+', '-']],
  names = ["source", "target", "relation"])
edge_statistics = pd.DataFrame(index = hindex, columns = ["observed", "expected", "expected stdev", "z-score", "p-value"])


for source in tqdm.tqdm(initial_sources):
  for target in initial_targets:
    for relation in ['+','-']:
      # observed count -- if the key doesn't exist, then the observed count is zero
      try:
        obsv_counts = initial_counts.loc[(source, target, relation), 'count']
      except KeyError:
        obsv_counts = 0

      # Non-existent edges are never stored, so skip them before doing any statistics.
      if not obsv_counts > 0:
        continue

      trial_counts = expected_counts[source][target][relation]
      mean = statistics.mean(trial_counts)
      stdev = statistics.stdev(trial_counts)

      # A zero stdev would give an undefined/infinite z-score; leave the row as NaN.
      if stdev == 0:
        continue

      z_score = (obsv_counts - mean) / (stdev)
      # Upper-tail probability. The survival function keeps precision for large
      # z, where `1 - cdf(z)` rounds to exactly 0.
      p_value = stats.norm.sf(z_score)

      edge_statistics.loc[(source, target, relation)] = [obsv_counts, mean, stdev, z_score, p_value]


100%|██████████| 728/728 [02:12<00:00,  5.51it/s]


Remove insignificant edges.\
*The z-score is used as the filter because it allows finer-grained thresholds. Uncomment the p-value lines to filter by p-value instead.*

In [77]:
# drop NAs (edges for which no statistics were stored)
edge_statistics = edge_statistics.dropna()

z_sorted = edge_statistics.sort_values(by='z-score', ascending = False)
# p_sorted = edge_statistics.sort_values(by='p-value')

# z = 2.325 corresponds to a one-sided normal tail probability of roughly 0.01
Z_MIN = 2.325
# P_MIN = 1e-12

# OR - optional - read in from file:
# benchmark_method = 'signature_shuffling'
# edge_statistics = pd.read_csv(f'{benchmark_folder}/{benchmark_method}/{benchmark_method}_z_sorted_edge_stats.csv', delimiter = '\t')

# A frame read back from file carries source/target/relation as ordinary columns,
# while the frame computed in this notebook already has them as index levels.
# Only move them into the index when they are columns; calling set_index on the
# in-memory frame raises KeyError.
index_levels = ['source', 'target', 'relation']
if set(index_levels).issubset(edge_statistics.columns):
    edge_statistics = edge_statistics.set_index(index_levels)

# Keep finite z-scores above the threshold (infinite scores are excluded).
z_scores = edge_statistics['z-score'].astype(float)
significant_edges = edge_statistics.loc[np.isfinite(z_scores) & (z_scores > Z_MIN)]


For pairs that have significant edges in both directions, keep only the edge with the most significance.

In [78]:
# Operate on a copy: rows are dropped from it in place at the end.
significant_edges_copy = significant_edges.copy()
direction_edges_to_drop = []

# For every source-target pair that is significant in both relations, mark the
# relation with the lower z-score for removal ('-' is removed on a tie).
for (source, target), pair_rows in significant_edges_copy.groupby(level=['source', 'target']):
    relations = pair_rows.index.get_level_values('relation')

    # nothing to resolve unless both '+' and '-' are present
    if '+' not in relations or '-' not in relations:
        continue

    up_z = pair_rows.loc[relations == '+', 'z-score'].values[0]
    dn_z = pair_rows.loc[relations == '-', 'z-score'].values[0]

    weaker_relation = '+' if up_z < dn_z else '-'
    direction_edges_to_drop.append((source, target, weaker_relation))

significant_edges_copy.drop(direction_edges_to_drop, inplace=True)


In [79]:
# keep only the top 3 most significant edges from each source
max_src_edges = 3

# Within each source, order by z-score from highest to lowest so that head()
# returns the most significant edges. An ascending sort would keep the weakest.
ranked_edges = significant_edges_copy.sort_values(
    by=['source', 'z-score'], ascending=[True, False]
)
final_edge_list = ranked_edges.groupby(level='source').head(max_src_edges)

#### OPTIONAL: saved output and histograms

Save collected statistics

In [None]:
# Write the z-score-sorted statistics table as a tab-separated file.
# NOTE(review): this writes directly into `benchmark_folder`, while the filtered
# edge list is written to a per-method subfolder -- confirm the intended location.
stats_path = f"{benchmark_folder}/{benchmark_method}_z_sorted_edge_stats.csv"
z_sorted.to_csv(stats_path, sep='\t')
# p_sorted.to_csv(f"{output}/p_sorted_edge_statistics", sep='\t')

Save edge list

In [80]:
# Write the filtered edge list (index levels included) to the method's subfolder.
filtered_edges_path = f"{benchmark_folder}/{benchmark_method}/edge_list_filtered.csv"
final_edge_list.to_csv(filtered_edges_path)

Plot histogram of edge counts

In [None]:
# Plot log-scaled histogram of counts

def tf_histogram(name, counts, num_bins=300, fig_size=(10,6)):
    """Draw a histogram of `counts` on a new figure with a log-scaled y axis.

    `name` selects the title: "all", "pos" or "neg". Any other value leaves the
    figure untitled. The figure is neither shown nor saved here.
    """
    titles = {
        "all": 'All TF-TF interaction counts',
        "pos": 'Positive interaction counts',
        "neg": 'Negative interaction counts',
    }

    plt.figure(figsize=fig_size)
    plt.hist(counts, bins = num_bins, edgecolor ='none')
    plt.yscale('log')
    plt.xlabel('Interaction count', fontsize=14)
    plt.ylabel('Frequency', fontsize=14)

    if name in titles:
        plt.title(titles[name], fontsize=16)


# Observed counts of the retained edges: all together, then split by relation.
all_counts = final_edge_list['observed']
# rows whose relation is "+"
positive_counts = final_edge_list.loc[(slice(None), slice(None), "+"), :]['observed']
# rows whose relation is "-"
negative_counts = final_edge_list.loc[(slice(None), slice(None), "-"), :]['observed']

# (title key, data, output image) for each histogram
histogram_jobs = [
    ("all", all_counts, f"{output}/img/all_histo_filtered_weighted_shuffling.png"),
    ("pos", positive_counts, f"{output}/img/pos_histo_filtered_weighted_shuffling.png"),
    ("neg", negative_counts, f"{output}/img/neg_histo_filtered_weighted_shuffling.png"),
]

for title_key, counts, image_path in histogram_jobs:
    tf_histogram(title_key, counts)
    plt.savefig(image_path)
    plt.show()

## Format data for UI ingestion
### Format nodes: [id,label]


In [81]:
# Set the label used to describe the node type (used in the output file name)
node_name = "Transcription Factor"

# Collect every TF that appears as a source or a target of a retained edge.
# Each node is stored as (id, label), with the TF name used for both.
edge_index = final_edge_list.index
node_ids = set(edge_index.get_level_values('source')) | set(edge_index.get_level_values('target'))
nodes = {(tf, tf) for tf in node_ids}

# Sort before writing: set iteration order for strings changes between Python
# sessions, which would otherwise reorder the CSV rows on every run.
node_df = pd.DataFrame(sorted(nodes), columns=['id', 'label'])
node_df.to_csv(f'{benchmark_folder}/{benchmark_method}/{node_name}.nodes.csv', index = False)

### Format edges: [source,relation,target]

In [82]:
# Column order expected by the ingestion format.
new_index = ['source','relation','target']

# Human-readable names for the two relation symbols.
relation_rename = {
    '+': 'upregulates',
    '-': 'downregulates'
}

# Turn the edge index into columns, reorder them, and rename the relations.
index_frame = (
    final_edge_list.index.to_frame()[new_index]
    .assign(relation=lambda frame: frame['relation'].replace(relation_rename))
)

# Write one edge file per relation type present in the filtered list.
relation_types = index_frame['relation'].unique()

for relation in relation_types:
    file_name = f"{benchmark_folder}/{benchmark_method}/{node_name}.{relation}.{node_name}.edges.csv"
    filtered_df = index_frame.loc[index_frame['relation'] == relation]
    filtered_df.to_csv(file_name, index=False)