In [1]:
import pandas as pd, numpy as np, sqlite3
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests


db_path = "results.db"

# Peek at the first 10 rows of each table as a quick schema / sanity check.
# sqlite3's connection context manager only commits or rolls back the pending
# transaction -- it does NOT close the connection -- so close it explicitly.
conn = sqlite3.connect(db_path)
try:
    real = pd.read_sql_query("SELECT * FROM real LIMIT 10", conn)
    synth = pd.read_sql_query("SELECT * FROM synth LIMIT 10", conn)
finally:
    conn.close()

In [None]:
def get_filtered_data(db_path, table_name, threshold=0.5):
    """Load the rows of ``table_name`` whose prediction shifted by more than ``threshold``.

    A row is kept when its ``gene_id`` is not ``'not_found'``, its ``vcf_id`` is
    not ``'PD4120a'`` and ``ABS(wt_prediction - mut_prediction)`` is strictly
    greater than ``threshold``.

    Parameters
    ----------
    db_path : str or path-like
        SQLite database file to open.
    table_name : str
        Table to read. SQLite cannot bind identifiers, so this value is
        interpolated into the SQL text and must come from trusted code.
    threshold : float, default 0.5
        Exclusive lower bound on the absolute wt/mut prediction difference.

    Returns
    -------
    pandas.DataFrame
        The selected columns cast to compact dtypes (``int32``, ``float32``,
        ``bool`` and ``category``).

    Notes
    -----
    Prints the number of rows and the memory usage reported by
    ``DataFrame.memory_usage()``.
    NOTE(review): the ``bool`` / ``int32`` casts assume those columns contain
    no NULLs -- confirm against the database schema.
    """
    # Only the table name is interpolated; the threshold is bound as a
    # parameter so it is never spliced into the SQL text.
    query = f"""
    SELECT 
        id,
        wt_prediction,
        mut_prediction,
        pred_difference,
        vcf_id,
        mirna_accession,
        gene_id,
        is_intron,
        mutation_context,
        is_gene_upregulated,
        mutsig,
        gene_name,
        cancer_type
    FROM {table_name}
    WHERE gene_id != 'not_found'
    AND vcf_id != 'PD4120a'
    AND ABS(wt_prediction - mut_prediction) > ?
    """

    dtype_dict = {
        'id': 'int32',
        'wt_prediction': 'float32',
        'mut_prediction': 'float32',
        'pred_difference': 'float32',
        'vcf_id': 'category',
        'mirna_accession': 'category',
        'gene_id': 'category',
        'is_intron': 'bool',
        'mutation_context': 'category',
        'is_gene_upregulated': 'bool',
        'mutsig': 'category',
        'gene_name': 'category',
        'cancer_type': 'category'
    }

    # ``with sqlite3.connect(...)`` only manages the transaction and leaves the
    # connection open, so close it explicitly once the query has been read.
    conn = sqlite3.connect(db_path)
    try:
        # float() also accepts numpy scalars, which sqlite3 cannot bind directly.
        df = pd.read_sql_query(query, conn, params=(float(threshold),))
    finally:
        conn.close()

    # Cast after loading (same effect as read_sql_query's ``dtype=`` argument)
    # so the function also works on pandas versions that lack that argument.
    df = df.astype(dtype_dict)

    print(f"Total rows: {len(df)}")
    print(f"Memory usage: {df.memory_usage().sum() / 1024**2:.2f} MB")

    return df

def get_gene_regulation_counts(db_path, table_name, threshold=0.5):
    """Count up- and down-regulated rows per gene in ``table_name``.

    Rows with ``vcf_id = 'PD4120a'`` or ``gene_id = 'not_found'`` are skipped,
    as are rows whose ``ABS(wt_prediction - mut_prediction)`` does not exceed
    ``threshold``. The remaining rows are grouped by ``gene_id``.

    Parameters
    ----------
    db_path : str or path-like
        SQLite database file to open.
    table_name : str
        Table to aggregate. SQLite cannot bind identifiers, so this value is
        interpolated into the SQL text and must come from trusted code.
    threshold : float, default 0.5
        Exclusive lower bound on the absolute wt/mut prediction difference.

    Returns
    -------
    pandas.DataFrame
        One row per gene with columns ``gene_id``, ``upregulated`` and
        ``downregulated``.

    Notes
    -----
    Prints the number of genes returned.
    """
    # Only the table name is interpolated; the threshold is bound as a
    # parameter so it is never spliced into the SQL text.
    query = f"""
    SELECT 
        gene_id,
        COUNT(*) FILTER (WHERE is_gene_upregulated = TRUE) as upregulated,
        COUNT(*) FILTER (WHERE is_gene_upregulated = FALSE) as downregulated
    FROM {table_name}
    WHERE vcf_id != 'PD4120a' 
    AND gene_id != 'not_found'
    AND ABS(wt_prediction - mut_prediction) > ?
    GROUP BY gene_id
    """

    # ``with sqlite3.connect(...)`` only manages the transaction and leaves the
    # connection open, so close it explicitly once the query has been read.
    conn = sqlite3.connect(db_path)
    try:
        # float() also accepts numpy scalars, which sqlite3 cannot bind directly.
        counts_df = pd.read_sql_query(query, conn, params=(float(threshold),))
    finally:
        conn.close()

    print(f"Total unique genes: {len(counts_df)}")
    return counts_df

def calculate_log2_odds_ratio(a, b, c, d, k=0.5):
    """Return the log2 odds ratio of the 2x2 table ``[[a, b], [c, d]]``.

    The pseudocount ``k`` is added to every cell before the ratio
    ``((a + k) * (d + k)) / ((b + k) * (c + k))`` is formed, which keeps the
    result finite when a cell is zero (for ``k > 0``). The inputs only take
    part in arithmetic and ``np.log2``, so scalars and same-shaped
    arrays/Series are both accepted.
    """
    main_diagonal = (a + k) * (d + k)
    off_diagonal = (b + k) * (c + k)
    return np.log2(main_diagonal / off_diagonal)

def shrink_log2_odds(values, prior_scale=1.0, min_count=10):
    """Shrink per-gene log2 odds ratios toward their count-weighted mean.

    Each row receives a weight ``1 - exp(-total_count / min_count)``, where
    ``total_count`` is the sum of its four count columns. The result is
    ``weight * raw + (1 - weight) * prior_mean``, with ``prior_mean`` the
    weight-averaged raw log2 odds ratio, so rows with few observations are
    pulled toward the overall mean and well-supported rows are left almost
    unchanged.

    Parameters
    ----------
    values : pandas.DataFrame
        Must contain ``upregulated_real``, ``downregulated_real``,
        ``upregulated_synth`` and ``downregulated_synth``.
    prior_scale : float, default 1.0
        Currently unused; kept so existing callers keep working.
    min_count : float, default 10
        Count scale of the weighting; larger values shrink more strongly.

    Returns
    -------
    pandas.Series
        Shrunk log2 odds ratios, aligned with ``values.index``.
    """
    # Nothing to shrink (and nothing to average) for an empty frame.
    if len(values) == 0:
        return pd.Series(index=values.index, dtype=float)

    up_real = values['upregulated_real'].astype(float)
    down_real = values['downregulated_real'].astype(float)
    up_synth = values['upregulated_synth'].astype(float)
    down_synth = values['downregulated_synth'].astype(float)

    total_counts = up_real + down_real + up_synth + down_synth

    # The helper is plain arithmetic, so whole columns can be passed at once
    # instead of invoking it once per row through DataFrame.apply.
    raw_log2_odds = calculate_log2_odds_ratio(
        up_real, down_real, up_synth, down_synth, k=0.5
    )

    weights = 1 - np.exp(-total_counts / min_count)

    # np.average raises ZeroDivisionError when all weights are zero (every row
    # has zero counts); fall back to the plain mean in that case.
    if weights.sum() > 0:
        prior_mean = np.average(raw_log2_odds, weights=weights)
    else:
        prior_mean = raw_log2_odds.mean()

    return weights * raw_log2_odds + (1 - weights) * prior_mean

def perform_fisher_test_vectorized(df, pseudocount=0.01):
    """Run Fisher's exact test per row and attach raw and BH-adjusted p-values.

    For every row the 2x2 table
    ``[[upregulated_real, downregulated_real],
       [upregulated_synth, downregulated_synth]]`` (each cell plus
    ``pseudocount``) is passed to ``fisher_exact``. The p-values are stored in
    ``df['p_value']`` and their ``fdr_bh`` adjusted counterparts in
    ``df['p_adj']``. ``df`` is modified in place and also returned.

    Despite the name, the tests themselves run in a Python loop, one per row.

    NOTE(review): Fisher's exact test is defined on integer counts; scipy is
    understood to cast the table to integers, which would discard the
    pseudocount and truncate fractional counts -- confirm this is intended.
    """
    cell_columns = (
        'upregulated_real', 'downregulated_real',
        'upregulated_synth', 'downregulated_synth',
    )
    # One row of four cells per gene, then folded into (n, 2, 2) tables.
    cells = np.array([df[column] + pseudocount for column in cell_columns])
    tables = cells.T.reshape(len(df), 2, 2)

    p_values = np.fromiter(
        (fisher_exact(contingency)[1] for contingency in tables),
        dtype=float,
        count=len(df),
    )

    df['p_value'] = p_values
    df['p_adj'] = multipletests(p_values, method='fdr_bh')[1]

    return df
def add_z_score(df):
    """Add a ``z_score`` column standardising ``df['log2_odds_ratio']``.

    The z-score is ``(value - mean) / std`` over the whole column, using the
    pandas defaults for ``mean()`` and ``std()``. ``df`` is modified in place
    and also returned.
    """
    log2_or = df['log2_odds_ratio']
    centre = log2_or.mean()
    spread = log2_or.std()

    df['z_score'] = (log2_or - centre) / spread

    return df


In [3]:
# Per-gene up/down-regulation counts for both tables, keeping only rows whose
# wt/mut prediction difference exceeds 0.25.
real_counts = get_gene_regulation_counts(db_path, "real", threshold=0.25)
synth_counts = get_gene_regulation_counts(db_path, "synth", threshold=0.25)

# Keep the genes present in both tables and divide the synthetic counts by 10.
# NOTE(review): the factor 10 presumably reflects the synthetic table being
# 10x the size of the real one -- confirm and lift it into a named constant.
counts = (
    real_counts
    .merge(synth_counts, how="inner", on="gene_id", suffixes=["_real", "_synth"])
    .assign(
        upregulated_synth=lambda merged: merged["upregulated_synth"] / 10,
        downregulated_synth=lambda merged: merged["downregulated_synth"] / 10,
    )
)

# real_counts.to_csv("results/sql/real_postsql.csv", index=False)
# synth_counts.to_csv("results/sql/synth_postsql.csv", index=False)
# counts.to_csv("results/sql/merged_postsql.csv", index=False)

Total unique genes: 31303
Total unique genes: 36982


In [None]:
# counts = pd.read_csv("results/sql/merged_postsql.csv")

# Raw log2 odds ratio (real vs. synthetic) per gene. The helper is plain
# arithmetic plus np.log2, so whole columns are passed in one call instead of
# invoking it once per row through DataFrame.apply.
counts['log2_odds_ratio'] = calculate_log2_odds_ratio(
    counts['upregulated_real'],
    counts['downregulated_real'],
    counts['upregulated_synth'],
    counts['downregulated_synth'],
)

# Variant of the log2 odds ratio pulled toward the count-weighted mean for
# genes with few observations.
counts['shrunk_log2_odds'] = shrink_log2_odds(counts)

# Adds p_value / p_adj (Fisher's exact test, fdr_bh adjusted) and then z_score
# (standardised raw log2 odds ratio); both helpers modify and return `counts`.
counts = perform_fisher_test_vectorized(counts)
counts = add_z_score(counts)


In [None]:
# Persist the per-gene statistics without the DataFrame index.
# NOTE(review): the "025" suffix presumably mirrors the 0.25 threshold used
# when the counts were queried -- keep the two in sync.
results_csv_path = "results/sql/results_025.csv"
counts.to_csv(results_csv_path, index=False)


In [12]:
# Genes whose fdr_bh-adjusted Fisher p-value is below 0.05.
counts.loc[counts["p_adj"] < 0.05]

Unnamed: 0,gene_id,upregulated_real,downregulated_real,upregulated_synth,downregulated_synth,p_value,p_adj,log2_odds_ratio,shrunk_log2_odds,z_score
2,ENSG00000000419,282,47,331.1,105.4,9.441058e-04,5.024370e-03,0.925510,0.925510,0.528698
4,ENSG00000000460,3323,883,2555.7,1046.2,2.427139e-16,4.106850e-14,0.623249,0.623249,0.312830
7,ENSG00000001036,42,69,156.4,91.1,1.112474e-05,1.165065e-04,-1.485976,-1.485976,-1.193534
8,ENSG00000001084,1684,456,1551.9,558.0,8.677525e-05,6.805409e-04,0.408763,0.408763,0.159648
9,ENSG00000001167,224,112,348.7,107.9,2.892646e-03,1.279838e-02,-0.690898,-0.690898,-0.625706
...,...,...,...,...,...,...,...,...,...,...
31290,ENSG00000273398,32,84,83.3,60.4,8.923308e-07,1.306484e-05,-1.839020,-1.839020,-1.445670
31292,ENSG00000273408,0,25,10.9,29.6,4.744668e-03,1.933130e-02,-4.271696,-4.265325,-3.183036
31297,ENSG00000273471,87,8,93.8,96.5,1.503592e-13,1.409190e-11,3.404475,3.404475,2.299123
31300,ENSG00000273481,3,23,34.8,14.6,1.510276e-06,2.085406e-05,-3.972354,-3.970145,-2.969252
