In [1]:
import pandas as pd, numpy as np, sqlite3
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests
import os

db_path = "results.db"
THRESHOLD = 0.35
ADDITIONAL_FILTER = True
FILTER_VAL = 0.25

In [2]:
def get_filtered_data(db_path, table_name, threshold=0.5, apply_additional_filter=False, 
                     low_threshold=0.1, high_threshold=0.9):
    base_query = f"""
    SELECT 
        id,
        wt_prediction,
        mut_prediction,
        pred_difference,
        vcf_id,
        mirna_accession,
        gene_id,
        is_intron,
        mutation_context,
        is_gene_upregulated,
        mutsig,
        gene_name,
        cancer_type
    FROM {table_name}
    WHERE gene_id != 'not_found'
    AND vcf_id != 'PD4120a'
    AND ABS(wt_prediction - mut_prediction) > {threshold}
    """
    
    if apply_additional_filter:
        additional_filter = f"""
        AND (
            (pred_difference < 0 AND mut_prediction < {low_threshold}) OR
            (pred_difference > 0 AND mut_prediction > {high_threshold})
        )
        """
        query = base_query + additional_filter
    else:
        query = base_query

    dtype_dict = {
        'id': 'int32',
        'wt_prediction': 'float32',
        'mut_prediction': 'float32',
        'pred_difference': 'float32',
        'vcf_id': 'category',
        'mirna_accession': 'category',
        'gene_id': 'category',
        'is_intron': 'bool',
        'mutation_context': 'category',
        'is_gene_upregulated': 'bool',
        'mutsig': 'category',
        'gene_name': 'category',
        'cancer_type': 'category'
    }

    with sqlite3.connect(db_path) as conn:
        df = pd.read_sql_query(
            query, 
            conn,
            dtype=dtype_dict
        )

    return df

def get_gene_regulation_counts(db_path, table_name, threshold=0.5, apply_additional_filter=False,
                             low_threshold=0.1, high_threshold=0.9):
    
    base_query = f"""
    SELECT 
        gene_id,
        COUNT(*) FILTER (WHERE is_gene_upregulated = TRUE) as upregulated,
        COUNT(*) FILTER (WHERE is_gene_upregulated = FALSE) as downregulated
    FROM {table_name}
    WHERE vcf_id != 'PD4120a' 
    AND gene_id != 'not_found'
    AND ABS(wt_prediction - mut_prediction) > {threshold}
    """
    
    if apply_additional_filter:
        additional_filter = f"""
        AND (
            (pred_difference < 0 AND mut_prediction < {low_threshold}) OR
            (pred_difference > 0 AND mut_prediction > {high_threshold})
        )
        """
        query = base_query + additional_filter + " GROUP BY gene_id"
    else:
        query = base_query + " GROUP BY gene_id"
    
    with sqlite3.connect(db_path) as conn:
        counts_df = pd.read_sql_query(query, conn)
    
    print(f"Total unique genes: {len(counts_df)}")
    return counts_df

def calculate_log2_odds_ratio(a, b, c, d, k=0.5):

    odds_ratio = ((a + k) * (d + k)) / ((b + k) * (c + k))
    return np.log2(odds_ratio)

def shrink_log2_odds(values, prior_scale=1.0, min_count=10):
    """Implement empirical Bayes shrinkage for log2 odds ratios."""
    total_counts = (values['upregulated_real'] + values['downregulated_real'] +
                   values['upregulated_synth'] + values['downregulated_synth'])
    
    raw_log2_odds = values.apply(lambda row: calculate_log2_odds_ratio(
        row['upregulated_real'], 
        row['downregulated_real'], 
        row['upregulated_synth'], 
        row['downregulated_synth'],
        k=0.5
    ), axis=1)

    weights = 1 - np.exp(-total_counts / min_count)
    prior_mean = np.average(raw_log2_odds, weights=weights)
    prior_var = np.var(raw_log2_odds)
    posterior_scale = prior_scale / (1 + weights * prior_scale)
    
    return weights * raw_log2_odds + (1 - weights) * prior_mean

def fetch_real_data(db_path, THRESHOLD):
    q_real = f"""
    SELECT 
        id,
        wt_prediction,
        mut_prediction,
        pred_difference,
        vcf_id,
        mirna_accession,
        gene_id,
        is_intron,
        mutation_context,
        is_gene_upregulated,
        mutsig,
        gene_name,
        cancer_type
    FROM real
    WHERE gene_id != 'not_found'
    AND vcf_id != 'PD4120a'
    AND ABS(wt_prediction - mut_prediction) > {THRESHOLD}
    """

    with sqlite3.connect(db_path) as conn:
        real = pd.read_sql_query(q_real, conn)
    
    return real

def perform_fisher_test_vectorized(df, pseudocount=0.01, bonf_holm=False):
    # Add pseudocount to the table
    table = np.array([
        [df['upregulated_real'] + pseudocount, df['downregulated_real'] + pseudocount],
        [df['upregulated_synth'] + pseudocount, df['downregulated_synth'] + pseudocount]
    ]).transpose((2, 0, 1))  # reshape for 2x2 tables

    p_values = np.zeros(len(df))

    for i in range(len(df)):
        _, p_values[i] = fisher_exact(table[i])

    df['p_value'] = p_values
    
    if bonf_holm:
        df['p_adj'] = multipletests(p_values, method='holm')[1]
    
    else:
        df['p_adj'] = multipletests(p_values, method='fdr_bh')[1]
    
    return df

def add_z_score(df):
    # Calculate mean and standard deviation of log2 odds ratios
    mean_log2or = df['log2_odds_ratio'].mean()
    std_log2or = df['log2_odds_ratio'].std()
    
    # Calculate Z-score
    df['z_score'] = (df['log2_odds_ratio'] - mean_log2or) / std_log2or
    
    return df



In [3]:
df = get_filtered_data(db_path, "real", THRESHOLD, apply_additional_filter=ADDITIONAL_FILTER, low_threshold=FILTER_VAL, high_threshold=1-FILTER_VAL)

df.head()

Unnamed: 0,id,wt_prediction,mut_prediction,pred_difference,vcf_id,mirna_accession,gene_id,is_intron,mutation_context,is_gene_upregulated,mutsig,gene_name,cancer_type
0,19,0.534925,0.16758,-0.367,PD10010a,MIMAT0019762,ENSG00000172987,True,C[C>T]G,True,SBS1,HPSE2,nnn
1,21,0.559073,0.132445,-0.427,PD10010a,MIMAT0019906,ENSG00000172987,True,C[C>T]G,True,SBS1,HPSE2,nnn
2,37,0.841528,0.234101,-0.607,PD10010a,MIMAT0003257,ENSG00000120029,True,G[C>A]A,True,SBS5,C10orf76,nnn
3,38,0.675805,0.149498,-0.526,PD10010a,MIMAT0004671,ENSG00000120029,True,G[C>A]A,True,SBS5,C10orf76,nnn
4,47,0.542836,0.102328,-0.441,PD10010a,MIMAT0018948,ENSG00000120029,True,G[C>A]A,True,SBS5,C10orf76,nnn


In [3]:
real_counts = get_gene_regulation_counts(db_path, "real", THRESHOLD, apply_additional_filter=ADDITIONAL_FILTER, low_threshold=FILTER_VAL, high_threshold=1-FILTER_VAL)
synth_counts = get_gene_regulation_counts(db_path, "synth", THRESHOLD, apply_additional_filter=ADDITIONAL_FILTER, low_threshold=FILTER_VAL, high_threshold=1-FILTER_VAL)

counts = pd.merge(real_counts, synth_counts, how="inner", on="gene_id", suffixes=["_real", "_synth"])
counts["upregulated_synth"] = counts["upregulated_synth"] / 10
counts["downregulated_synth"] = counts["downregulated_synth"] / 10

counts['log2_odds_ratio'] = counts.apply(lambda row: calculate_log2_odds_ratio(
    row['upregulated_real'], 
    row['downregulated_real'], 
    row['upregulated_synth'], 
    row['downregulated_synth']
), axis=1)

counts['shrunk_log2_odds'] = shrink_log2_odds(counts)
counts = add_z_score(counts)
counts = perform_fisher_test_vectorized(counts, bonf_holm=False)

counts["is_significant"] = counts['p_value'] < 0.05
counts["is_significant_adj"] = counts['p_adj'] < 0.05


Total unique genes: 30384
Total unique genes: 36318


In [4]:
export_path = f"results/last/{THRESHOLD:.2f}".replace("0.", "0")
os.makedirs(export_path, exist_ok=True)
count_sign = len(counts[counts.p_value < 0.05])
count_sign_adj = len(counts[counts.p_adj < 0.05])
filter_string = "FILTER_VAL:.2f".replace(".", "")

if ADDITIONAL_FILTER:
    counts.to_csv(f"{export_path}/counts_sig{count_sign}_adj{count_sign_adj}_filter{FILTER_VAL:.2f}.csv", index=False)

else:
    counts.to_csv(f"{export_path}/counts_sig{count_sign}_adj{count_sign_adj}.csv", index=False)
    
    

In [5]:
counts[counts.p_adj < 0.05]


Unnamed: 0,gene_id,upregulated_real,downregulated_real,upregulated_synth,downregulated_synth,log2_odds_ratio,shrunk_log2_odds,z_score,p_value,p_adj,is_significant,is_significant_adj
4,ENSG00000000460,671,160,510.6,186.6,0.615018,0.615018,0.324390,0.000581,0.032158,True,True
17,ENSG00000001630,112,29,70.2,52.7,1.520854,1.520854,1.024575,0.000160,0.014166,True,True
37,ENSG00000003400,73,9,74.0,37.0,1.961395,1.961395,1.365101,0.000296,0.020947,True,True
51,ENSG00000004487,26,37,114.5,46.1,-1.804130,-1.804130,-1.545541,0.000052,0.006783,True,True
59,ENSG00000004809,50,48,109.4,42.5,-1.295484,-1.295484,-1.152372,0.001128,0.046861,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...
29829,ENSG00000269994,14,22,66.2,19.1,-2.400705,-2.400691,-2.006676,0.000106,0.011064,True,True
30230,ENSG00000272647,36,1,30.0,19.5,3.996053,3.995385,2.937832,0.000057,0.007164,True,True
30318,ENSG00000273081,0,3,25.7,1.9,-6.255815,-5.953892,-4.986565,0.001095,0.046280,True,True
30326,ENSG00000273118,851,447,999.6,379.8,-0.466816,-0.466816,-0.511836,0.000116,0.011573,True,True
