- Author: Peter Riesebos
- Purpose: (outdated code) Notebook used to prepare the eQTL and GWAS summary statistics for colocalization analysis
- Input: eQTL and GWAS summary statistics
- Output: tables to be used as input for colocalization

## Library and file imports

In [None]:
import pandas as pd
import dask.dataframe as dd
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

In [None]:
# Load the two inputs: GWAS summary statistics (GCST90292538) and the merged
# eQTL top-effects table. NOTE: `sum` shadows the builtin, but later cells rely
# on this name, so it is kept.
data_dir = "/groups/umcg-fg/tmp04/projects/gut-bulk/ongoing/2024-02-07-GutPublicRNASeq/datasets"
gwas = pd.read_csv(f"{data_dir}/GWAS/GCST90292538.h.tsv.gz", sep='\t')
sum = pd.read_csv(f"{data_dir}/combined/mbqtl_output_combined_exp_fixed/merged_topeffects_final.txt", sep='\t')

In [None]:
# Inspect the GWAS column names (later cells use rsid, p_value,
# base_pair_location, chromosome and standard_error)
gwas.columns

# Approach

1. Create a list of genes with a significant eQTL.
2. Create a list of variants that were tested in the eQTL study and lie within 200 kb of the TSS of the genes from step 1.
3. Filter the eQTL all-effects files on the genes and variants from steps 1 and 2.
4. Filter the GWAS summary statistics on the list of variants from step 2.
5. For each significant gene: determine whether there is a GWAS SNP with p < 5e-8 within 200 kb of the TSS (the TSS position is stored in the eQTL output file as GenePos or similar) that was also tested in the eQTL analysis (i.e. it must appear in the list from step 2).
6. For each gene where this is the case:
- Take all variants within 200 kb of the TSS (upstream and downstream) from the eQTL all-effects summary statistics
- Take the overlapping variants from the GWAS summary statistics
- Run coloc

For each gene for which coloc was run, store the following values in a table:
gene name, gene symbol, PP3, PP4, top GWAS SNP + p-value, top eQTL SNP + p-value, distance between the two SNPs, distance of both top SNPs to the TSS of the gene

In [None]:
# Step 1: genes with a significant eQTL. Every gene in `sum` (the merged
# top-effects table) is treated as significant; no p-value filter is applied here.
list1 = sum.loc[:, ["Gene", "GenePos", "SNP"]]

# Plain Python list of gene IDs (duplicates kept if a gene occurs in several rows)
signif_eqtl_genes = sum["Gene"].to_list()

In [None]:
# Number of entries in the step-1 gene list (duplicates are counted if a gene
# appears in several rows of `sum`)
len(signif_eqtl_genes)

In [None]:
# Superseded by the "step 2 revised" cell below; kept for reference only.
# It computed distances from the top-effects table alone, which holds only one
# SNP per row rather than every tested variant.
# # step 2: make a list with variants that are both tested in the eQTL study and are within 200kb from the tss from step 1

# # Calculate the absolute distance between GenePos (TSS) and SNPPos
# sum["distance_to_tss"] = (sum["GenePos"] - sum["SNPPos"]).abs()

# # Filter for SNPs within 200 kb (200,000 base pairs)
# within_200kb = sum[sum["distance_to_tss"] < 200_000]

# # Extract the SNPs that meet the criteria into a list
# snp_within_200kb = within_200kb["SNP"].tolist()
# print(len(snp_within_200kb))

In [None]:
# Step 2 (revised): filter the 22 per-chromosome eQTL AllEffects files down to
# the genes from step 1 (list1), reading the files in parallel worker processes.
# Fixes: `concurrent.futures` was never imported (only ThreadPoolExecutor and
# as_completed were), and `filt er_file` was a syntax error.
import concurrent.futures

# Set up logging (no effect if the root logger already has handlers)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Genes to keep; a set gives fast membership tests in Series.isin
gene_set = set(list1['Gene'])

# Paths for the 22 files (chr1..chr22)
file_paths = [f"/groups/umcg-fg/tmp04/projects/gut-bulk/ongoing/2024-02-07-GutPublicRNASeq/datasets/combined/mbqtl_output_combined_exp_fixed/combined_chr{num}-AllEffects.txt.gz" for num in range(1, 23)]


def filter_file(file_path):
    """Read one gzipped, tab-separated AllEffects file and keep rows whose Gene is in gene_set.

    Args:
        file_path: path to a ``combined_chr*-AllEffects.txt.gz`` file.

    Returns:
        DataFrame with only the rows for genes in ``gene_set``.

    NOTE(review): runs in a worker process and reads the module-level
    ``gene_set``; this relies on the 'fork' start method (the Linux default)
    so workers inherit notebook globals. Confirm before running elsewhere.
    """
    file_name = file_path.split("/")[-1]
    logging.info(f"Starting to process {file_name}")
    df = pd.read_csv(file_path, sep="\t", compression="gzip")
    filtered_df = df[df['Gene'].isin(gene_set)]
    logging.info(f"Finished processing {file_name} with {len(filtered_df)} matching rows")
    return filtered_df


# Process files in parallel; executor.map returns results in input (chromosome) order
with concurrent.futures.ProcessPoolExecutor() as executor:
    filtered_results = list(executor.map(filter_file, file_paths))

# Combine results into a single DataFrame
combined_filtered_df = pd.concat(filtered_results, ignore_index=True)
combined_filtered_df

In [None]:
# Step-2 checkpoint. To regenerate the cache, uncomment the two save lines after
# running step 2; otherwise the filtered AllEffects table is reloaded from disk.
# output_file = "AllEffectsFiltered.csv"
# combined_filtered_df.to_csv(output_file, index=False)
step2_checkpoint = "AllEffectsFiltered.csv"
combined_filtered_df = pd.read_csv(step2_checkpoint, sep=',')

In [None]:
# Display the filtered AllEffects table (genes from step 1 only)
combined_filtered_df

In [None]:
# For every (Gene, GenePos) pair in list1, collect the SNP IDs tested for that
# gene whose SNPPos lies within +/- window_size bp of GenePos.
# Result: rsid_dict maps (Gene, GenePos) -> list of SNP IDs (file order).

# Set up logging (no effect if the root logger already has handlers)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Create a dictionary to store results
rsid_dict = {}

# Define the window size in base pairs
window_size = 200_000

# Split the all-effects table by gene once. Previously every gene re-masked the
# whole table (O(genes x rows)); now each lookup only scans that gene's rows.
_rows_by_gene = {
    gene_id: gene_rows
    for gene_id, gene_rows in combined_filtered_df.groupby('Gene', sort=False)
}


def process_gene(row):
    """Return ((Gene, GenePos), SNP IDs within window_size of GenePos) for one list1 row.

    Args:
        row: a row of list1 with at least 'Gene' and 'GenePos'.

    Returns:
        Tuple of the (Gene, GenePos) key and the list of SNP IDs whose SNPPos is
        within [GenePos - window_size, GenePos + window_size] (inclusive).
        The list is empty when the gene has no rows in combined_filtered_df.
    """
    gene = row['Gene']
    gene_pos = row['GenePos']

    # Log progress for each Gene, GenePos
    logging.info(f"Processing Gene: {gene}, GenePos: {gene_pos}")

    gene_rows = _rows_by_gene.get(gene)
    if gene_rows is None:
        return (gene, gene_pos), []

    # between() is inclusive on both ends, matching the original >= / <= filter
    in_window = gene_rows['SNPPos'].between(gene_pos - window_size, gene_pos + window_size)
    return (gene, gene_pos), gene_rows.loc[in_window, 'SNP'].tolist()


# Each lookup is now small, so a thread pool no longer buys anything; run sequentially.
for _, gene_row in list1.iterrows():
    key, rsid_list = process_gene(gene_row)
    rsid_dict[key] = rsid_list

# Log final summary
logging.info(f"Completed processing {len(list1)} genes.")

In [None]:
# # Convert tuple keys to comma-separated strings
# json_compatible_dict = {f"{k[0]},{k[1]}": v for k, v in rsid_dict.items()}

# # Save to a JSON file
# with open('gene_snp_rsid_mapping.json', 'w') as file:
#     json.dump(json_compatible_dict, file)


def _parse_window_key(key):
    """Turn a 'Gene,GenePos' JSON key back into a (Gene, int GenePos) tuple.

    GenePos is restored to int so the key equals the (Gene, GenePos) tuples built
    from the DataFrames; previously it stayed a string, so e.g.
    ("ENSG...", "73834590") never matched ("ENSG...", 73834590).
    NOTE(review): assumes GenePos was written as an integer string.
    """
    gene_id, pos_text = key.rsplit(',', 1)  # GenePos is the last field
    return gene_id, int(pos_text)


# Load from the JSON file and convert keys back to tuples
with open('gene_snp_rsid_mapping.json', 'r') as file:
    loaded_dict = {_parse_window_key(k): v for k, v in json.load(file).items()}

# Printing the full mapping floods the notebook; show a short summary instead.
print(f"Loaded {len(loaded_dict)} (Gene, GenePos) windows")

In [None]:
# Number of (Gene, GenePos) windows loaded from the JSON mapping
len(loaded_dict)

In [None]:
# Union of every SNP ID found in any gene window, without duplicates.
# The order is arbitrary because the IDs pass through a set.
all_values = list({snp for window_snps in loaded_dict.values() for snp in window_snps})

In [None]:
# Display the deduplicated SNP IDs collected from all gene windows
all_values

In [None]:
# Step 3: filter the 22 eQTL AllEffects files on the genes from step 1 and the
# variants from step 2.

# Paths for the 22 files
file_paths = [f"/groups/umcg-fg/tmp04/projects/gut-bulk/ongoing/2024-02-07-GutPublicRNASeq/datasets/combined/mbqtl_output_combined_exp_fixed/combined_chr{num}-AllEffects.txt.gz" for num in range(1, 23)]

# Create a Dask DataFrame over all files. Gzip files cannot be split into
# blocks, so each file is read as one partition (blocksize=None).
df = dd.read_csv(file_paths, sep='\t', compression='gzip', blocksize=None)

# Keep rows whose Gene has a significant eQTL (step 1) and whose SNP lies in one
# of the step-2 TSS windows (all_values).
# Bug fix: the SNP column was previously tested against the gene-ID list
# (signif_eqtl_genes), so SNP IDs could never match.
filtered_df = df[df['Gene'].isin(signif_eqtl_genes) & df['SNP'].isin(all_values)]

# Compute the filtered result to get a pandas DataFrame
final_output = filtered_df.compute()

# Display the result
final_output

In [None]:
# Step-3 checkpoint. Uncomment the save line after running step 3 to refresh
# the cache; otherwise the cached result is loaded so step 3 can be skipped.
# final_output.to_csv("filtered_eqtl_results.tsv.gz", sep='\t', index=False, compression='gzip')
step3_checkpoint = "filtered_eqtl_results.tsv.gz"
final_output = pd.read_csv(step3_checkpoint, sep='\t', compression='gzip')

In [None]:
# Step 4: keep only GWAS rows whose rsid is one of the step-2 window variants
# (all_values). Bug fix: rsid was previously matched against gene IDs
# (signif_eqtl_genes), which can never be equal to an rsid.
filtered_gwas = gwas[gwas['rsid'].isin(all_values)]

In [None]:
# Number of GWAS rows left after the step-4 filter (Series.size is the row count)
filtered_gwas.base_pair_location.size

In [None]:
# Step 5: for every gene in `sum`, find GWAS SNPs from filtered_gwas with
# p_value < 5e-8 whose base_pair_location lies within +/- 200 kb (inclusive)
# of the gene's GenePos. Each (gene row, SNP row) pair becomes one merged
# record; SNP fields overwrite gene fields that share a column name.
#
# NOTE(review): only positions are compared. Nothing here checks that the GWAS
# SNP and the gene are on the same chromosome, so SNPs on other chromosomes at a
# similar coordinate also match. Add a chromosome condition once the chromosome
# column in `sum` is confirmed.
significant_gwas_snps = filtered_gwas.loc[filtered_gwas['p_value'] < 5e-8]

window_size = 200000  # 200 kb
results = []

for _, gene_info in sum.iterrows():
    lower = gene_info['GenePos'] - window_size
    upper = gene_info['GenePos'] + window_size
    in_window = significant_gwas_snps['base_pair_location'].between(lower, upper)

    gene_fields = gene_info.to_dict()
    for _, hit in significant_gwas_snps.loc[in_window].iterrows():
        # Gene columns first, then GWAS columns (which win on name clashes)
        results.append({**gene_fields, **hit.to_dict()})

# One row per (gene, nearby significant SNP); columns from both `sum` and the GWAS table
merged_results = pd.DataFrame(results)
merged_results

In [None]:
# Step 6a: for each gene from step 5, select all eQTL all-effects variants
# (final_output, step 3) whose SNPPos lies within 200 kb of the TSS (GenePos),
# upstream and downstream.
# Bug fix: the window was previously applied to GenePos, which is the same for
# every row of a gene, so the 200 kb restriction was never applied.

filtered_results = merged_results.drop_duplicates(subset=["Gene", "GenePos", "SNP"])

# TSS window half-width in base pairs
tss_window = 200_000

# Dictionary to store output with (Gene, GenePos) as keys
output_dict = {}

# A gene can have several nearby GWAS SNPs (several rows), but its window depends
# only on (Gene, GenePos), so each pair is computed once.
gene_pos_pairs = filtered_results[["Gene", "GenePos"]].drop_duplicates()
for gene, gene_pos in gene_pos_pairs.itertuples(index=False):
    matches = final_output[
        (final_output["Gene"] == gene)
        & final_output["SNPPos"].between(gene_pos - tss_window, gene_pos + tss_window)
    ]

    # Store the result in the dictionary with (Gene, GenePos) as the key
    output_dict[(gene, gene_pos)] = matches

In [None]:
# Save the distinct (Gene, GenePos) pairs that passed step 5 as a JSON list of
# {"Gene": ..., "GenePos": ...} records.
gene_pos_combinations = filtered_results.loc[:, ['Gene', 'GenePos']].drop_duplicates()
gene_pos_list = gene_pos_combinations.to_dict(orient='records')

with open('gene_pos_combinations.json', 'w') as json_out:
    json.dump(gene_pos_list, json_out, indent=4)

# Sanity check: show the first five pairs
print(gene_pos_list[:5])

In [None]:
# Step 6b: for each (Gene, GenePos) window from step 6a, keep the rows of the
# full GWAS table whose rsid is one of that window's eQTL SNPs.
gwas_overlap_dict = {
    window_key: gwas[gwas["rsid"].isin(window_rows["SNP"].unique())]
    for window_key, window_rows in output_dict.items()
}

In [None]:
# Optional cache of the step-6 dictionaries. NOTE: `pickle` is not imported in
# the imports cell; add `import pickle` before uncommenting.
# # Save dictionaries to a file
# with open('output_dict.pkl', 'wb') as f:
#     pickle.dump(output_dict, f)

# with open('gwas_overlap_dict.pkl', 'wb') as f:
#     pickle.dump(gwas_overlap_dict, f)

In [None]:
# Reload the cached step-6 dictionaries. NOTE: requires `import pickle` (not in
# the imports cell); only unpickle files this notebook wrote itself.
# # load in dictonaries
# with open('output_dict.pkl', 'rb') as f:
#     output_dict = pickle.load(f)

# with open('gwas_overlap_dict.pkl', 'rb') as f:
#     gwas_overlap_dict = pickle.load(f)

In [None]:
# Spot check one gene: look up its step-6a eQTL window
# (None if this (gene, gene_pos) pair was not selected).
gene = "ENSG00000165171"
gene_pos = 73834590
result_df = output_dict.get((gene, gene_pos))
result_df

In [None]:
# Matching GWAS rows for the same (gene, gene_pos) from step 6b (None if absent)
gwas_result_df = gwas_overlap_dict.get((gene, gene_pos))
gwas_result_df

# Old approach

## Data manipulation and inspection

In [None]:
# Old approach: work on copies so the loaded tables stay untouched
gwas_df, sum_df = gwas.copy(), sum.copy()

In [None]:
# Keep only GWAS rows with p_value < 5e-8, then show how many remain
gwas_df = gwas_df.loc[gwas_df["p_value"] < 5e-8]
gwas_df["chromosome"].size

In [None]:
# Keep one row per rsid (first occurrence) and display the result
gwas_df = gwas_df.drop_duplicates(subset=['rsid'], keep='first')
gwas_df

In [None]:
# Number of eQTL top-effect rows before de-duplicating on SNP
sum_df.Gene.size

In [None]:
# Keep one row per eQTL SNP (first occurrence) and show the remaining count.
# NOTE(review): a SNP that is the top effect for several genes keeps only its
# first gene here; confirm this is intended for the coloc input.
sum_df = sum_df.drop_duplicates(subset=['SNP'], keep='first')
sum_df["Gene"].size

In [None]:
# varbeta = squared standard error, i.e. the variance of the effect estimate
gwas_df["varbeta"] = gwas_df["standard_error"].pow(2)

In [None]:
# Same varbeta column for the eQTL table, squared from its MetaSE column
sum_df["varbeta"] = sum_df["MetaSE"].pow(2)
sum_df.head(3)

In [None]:
# Inspect the GWAS columns after adding varbeta
gwas_df.columns

## Export adjusted sum stats

In [None]:
# Write both prepared tables (tab-separated, with header, no index) to the coloc folder
coloc_dir = "/groups/umcg-fg/tmp04/projects/gut-bulk/ongoing/2024-02-07-GutPublicRNASeq/extra_scripts/coloc"
for table, file_name in ((gwas_df, "gwas.tsv"), (sum_df, "sum.tsv")):
    table.to_csv(f"{coloc_dir}/{file_name}", sep="\t", header=True, index=False)