In [None]:
import numpy as np
import pandas as pd
import os, sys
import anndata
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve
from plotnine import *
import scipy.stats as stats
from scipy.stats import mannwhitneyu, fisher_exact

from grelu.resources import load_model

import warnings
warnings.filterwarnings('ignore')

## Paths

In [None]:
# Root of this Decima run; GWAS (44 traits) outputs live under `out_dir`.
save_dir = '/gstore/data/resbioai/grelu/decima/20240823'
# eQTL overlap table for the positive variants. Kept under its own name so it is not
# confused with `eqtl_file` in the eQTL section, which points at a different table.
# NOTE(review): this path is not read anywhere in this notebook.
eqtl_overlap_file = '/gstore/data/omni/regulatory_elements/decima_files/decima_pos_eQTL_overlaps.csv'

out_dir = os.path.join(save_dir, 'gwas_44traits')
pos_dir = os.path.join(out_dir, 'positive_variants')  # GWAS (positive) variants
neg_dir = os.path.join(out_dir, 'negative_variants')  # matched control (negative) variants

pos_file = os.path.join(pos_dir, 'positive_variants_and_traits.csv')
matched_neg_file = os.path.join(neg_dir, 'negatives_matched.csv')

In [None]:
# Aggregated prediction files: one positive and one negative file for each prediction
# set (Decima, gene-level -> 'RNA', TSS-level -> 'CAGE' in the plots below).
decima_pos_file, gene_pos_file, tss_pos_file = [
    os.path.join(pos_dir, f'{kind}_preds_agg.h5ad') for kind in ('decima', 'gene', 'tss')
]
decima_neg_file, gene_neg_file, tss_neg_file = [
    os.path.join(neg_dir, f'{kind}_preds_agg.h5ad') for kind in ('decima', 'gene', 'tss')
]

## Load predictions

In [None]:
# Read the positive/negative AnnData objects for each prediction set and report their shapes.
decima_pos_preds, decima_neg_preds = (anndata.read_h5ad(f) for f in (decima_pos_file, decima_neg_file))
print(decima_pos_preds.shape, decima_neg_preds.shape)

gene_pos_preds, gene_neg_preds = (anndata.read_h5ad(f) for f in (gene_pos_file, gene_neg_file))
print(gene_pos_preds.shape, gene_neg_preds.shape)

tss_pos_preds, tss_neg_preds = (anndata.read_h5ad(f) for f in (tss_pos_file, tss_neg_file))
print(tss_pos_preds.shape, tss_neg_preds.shape)

## Load variants

In [None]:
# One row per (pos_variant, gene); the `variant` column holds the list of matched negative variants.
match = (
    pd.read_csv(matched_neg_file)
    .groupby(['pos_variant', 'gene'])['variant']
    .apply(list)
    .reset_index()
)

In [None]:
# Positive variants with their traits. NOTE(review): `traits` is not referenced later in this notebook.
traits = pd.read_csv(filepath_or_buffer=pos_file)

## Calculate aggregate score for each variant-gene pair

In [None]:
# Collapse each variant-gene pair into one VEP score: the mean absolute predicted effect
# along axis 1 of X (presumably the prediction tracks -- confirm). Also store the unsigned
# TSS distance, which the classification cells use as a baseline.
all_preds = (decima_pos_preds, decima_neg_preds, gene_pos_preds,
             gene_neg_preds, tss_pos_preds, tss_neg_preds)
for ad in all_preds:
    abs_effects = np.abs(ad.X)
    ad.obs['score'] = abs_effects.mean(axis=1)
    ad.obs['abs_tss_dist'] = ad.obs['tss_dist'].abs()

## Compare overall scores

In [None]:
# Compare aggregate VEP scores of GWAS variants vs. matched controls for each model.
# Also sets obs['label'] / obs['label_str'] (1/'GWAS' for positives, 0/'Control' for
# negatives), which the classification cells below rely on.
for p, n, m in [[decima_pos_preds, decima_neg_preds, 'Decima'],
                [gene_pos_preds, gene_neg_preds, 'RNA'],
                [tss_pos_preds, tss_neg_preds, 'CAGE']]:
    p.obs['label'] = 1
    p.obs['label_str'] = 'GWAS'
    n.obs['label'] = 0
    n.obs['label_str'] = 'Control'

    labels = p.obs.label_str.tolist() + n.obs.label_str.tolist()
    scores = p.obs.score.tolist() + n.obs.score.tolist()

    # One-sided test: are GWAS scores stochastically larger than control scores?
    print(m, p.obs.score.mean(), n.obs.score.mean(),
          mannwhitneyu(p.obs.score, n.obs.score, alternative='greater').pvalue)
    # y must name the frame's 'score' column; 'scores' only resolved by falling back
    # to the leaked Python list of the same name.
    display((
        ggplot(pd.DataFrame({'label': labels, 'score': scores}), aes(x='label', y='score'))
        + geom_boxplot(outlier_size=.1) + theme_classic() + theme(figure_size=(3, 2.6))
        + scale_y_log10() + ggtitle(m) + xlab('Variant Type') + ylab('VEP score')
    ))

## Overall classification (pooled GWAS vs. control variants, with a TSS-distance baseline)

In [None]:
# Pooled classification: how well does each model's VEP score separate GWAS variants from
# matched controls, compared with a TSS-distance baseline (closer to the TSS ranks higher)?
# Uses obs['label'] (1 = GWAS, 0 = Control), set in the "Compare overall scores" cell.
for p, n, m in [[decima_pos_preds, decima_neg_preds, 'Decima'],
                [gene_pos_preds, gene_neg_preds, 'RNA'],
                [tss_pos_preds, tss_neg_preds, 'CAGE']]:
    labels = p.obs.label.tolist() + n.obs.label.tolist()
    scores = p.obs.score.tolist() + n.obs.score.tolist()
    # Distances are negated below so that smaller distances rank as more likely positives.
    dists = np.array(p.obs['abs_tss_dist'].tolist() + n.obs['abs_tss_dist'].tolist())

    ap = np.round(average_precision_score(labels, scores), 2)
    auroc = np.round(roc_auc_score(labels, scores), 2)
    ap_dist = np.round(average_precision_score(labels, -dists), 2)
    auroc_dist = np.round(roc_auc_score(labels, -dists), 2)

    print(f'{m}: n_pos={len(p)} n_neg={len(n)} '
          f'AP={ap} (distance {ap_dist}) AUROC={auroc} (distance {auroc_dist})')

    prec_model, rec_model, _ = precision_recall_curve(labels, scores)
    prec_dist, rec_dist, _ = precision_recall_curve(labels, -dists)
    # ignore_index gives the stacked curves a unique RangeIndex instead of duplicated labels.
    df = pd.concat([
        pd.DataFrame({'Precision': prec_model, 'Recall': rec_model, 'Method': m}),
        pd.DataFrame({'Precision': prec_dist, 'Recall': rec_dist, 'Method': 'Distance'}),
    ], ignore_index=True)
    df['Method'] = pd.Categorical(df['Method'], categories=[m, 'Distance'])

    display((
        ggplot(df, aes(x='Recall', y='Precision', color='Method'))
        + geom_point() + theme_classic() + theme(figure_size=(4, 2)) + ggtitle(m)
    ))

## Per-variant classification (each GWAS variant vs. its own matched controls)

In [None]:
# For each GWAS variant, compare its score with its own matched controls: the
# N_MATCHED_NEG matched negatives for the same gene that lie closest to the TSS.
# Adds to p.obs:
#   is_best - positive scores above every one of its matched controls
#   pval    - one-sided normal-approximation p-value vs. the controls' mean/std
#   sig_05  - pval < 0.05
N_MATCHED_NEG = 12  # matched negatives kept per positive, ranked by |TSS distance|

# (pos_variant, gene) -> list of matched negative variant IDs
matched_lookup = {
    (rec.pos_variant, rec.gene): rec.variant for rec in match.itertuples(index=False)
}


def closest_matched_negatives(neg_ad, gene, neg_variants, k=N_MATCHED_NEG):
    """Return obs names of the `k` negatives for `gene` among `neg_variants` nearest the TSS."""
    obs = neg_ad.obs
    hits = obs[(obs.gene == gene) & (obs.variant.isin(neg_variants))]
    return hits.sort_values('abs_tss_dist').head(k).index.tolist()


for p, n, m in [[decima_pos_preds, decima_neg_preds, 'Decima'],
                [gene_pos_preds, gene_neg_preds, 'RNA'],
                [tss_pos_preds, tss_neg_preds, 'CAGE']]:
    is_best, pvals, used_neg = [], [], set()
    for row in p.obs.itertuples():
        # Selecting per positive avoids double-counting negatives shared between positives
        # and missing-key crashes for positives without a matched set.
        neg_ids = closest_matched_negatives(
            n, row.gene, matched_lookup.get((row.variant, row.gene), []))
        used_neg.update(neg_ids)
        neg_scores = n.obs.loc[neg_ids, 'score']
        if neg_scores.empty:
            # No matched controls: cannot be "best" and no p-value is defined.
            is_best.append(False)
            pvals.append(np.nan)
            continue
        is_best.append(row.score > neg_scores.max())
        z = (row.score - neg_scores.mean()) / neg_scores.std()
        pvals.append(stats.norm.sf(z))

    p.obs['pval'] = pvals
    p.obs['sig_05'] = [x < .05 for x in pvals]  # NaN p-values count as not significant
    p.obs['is_best'] = is_best

    # Mean of a boolean column = fraction True (0.0 instead of a KeyError when none are True).
    print(m, len(p), len(used_neg), p.obs.is_best.mean(), p.obs.sig_05.mean())

## eQTLs

In [None]:
# Positive (GWAS) variant table, derived from the configured positive-variant directory
# rather than a duplicated absolute path.
pos_variant_file = os.path.join(pos_dir, 'positive_variants.csv')
# Fine-mapped eQTL table used in this section (tab-separated).
eqtl_file = '/gstore/data/omni/regulatory_elements/decima_files/fine_mapped_OT_eqtl.txt'

In [None]:
# Positive variant table (comma-separated) and eQTL table (tab-separated).
pos_variants = pd.read_csv(pos_variant_file)
eqtl = pd.read_csv(eqtl_file, sep='\t')

In [None]:
# Unique GWAS variant IDs. NOTE(review): `gwas_variants` is not referenced later in this notebook.
gwas_variants = pd.unique(pos_variants['variant'])
# eQTL tag variants with a 'chr' prefix, presumably to match the prediction variant IDs -- confirm.
eqtl_variants = 'chr' + eqtl.tag_variant_id

In [None]:
# Flag positives whose variant is an eQTL tag variant. Bind `p` explicitly: otherwise it is
# whatever the last loop above left behind (tss_pos_preds, i.e. CAGE), not Decima.
p = decima_pos_preds
p.obs['has_eqtl'] = p.obs.variant.isin(eqtl_variants)

In [None]:
# Fraction of distinct GWAS variants that are eQTL tag variants (a variant appearing in
# several obs rows, e.g. with several genes, is counted once).
p.obs.drop_duplicates(subset=['variant', 'has_eqtl'])['has_eqtl'].value_counts(normalize=True)

In [None]:
# Contingency counts: eQTL overlap (rows) vs. per-variant significance at p < 0.05 (columns).
df = p.obs.value_counts(subset=['has_eqtl', 'sig_05']).unstack()
df

In [None]:
# NOTE(review): hardcoded counts, presumably transcribed from the contingency table above.
# They will not update when the data change; TODO derive this ratio from `df` instead.
58/(38+37)

In [None]:
# Fisher's exact test for association between eQTL overlap and per-variant significance.
# crosstab + reindex always yields a full 2x2 integer table; value_counts().unstack() left
# NaN for a missing combination, which fisher_exact cannot handle.
eqtl_sig_table = (
    pd.crosstab(p.obs['has_eqtl'], p.obs['sig_05'])
    .reindex(index=[False, True], columns=[False, True], fill_value=0)
)
fisher_exact(eqtl_sig_table)

## Save Decima results

In [None]:
# Stack Decima positive and negative predictions into one AnnData; join='outer' keeps the
# union along the non-concatenated axis.
decima_pos_neg = [decima_pos_preds, decima_neg_preds]
ad = anndata.concat(decima_pos_neg, join='outer')

In [None]:
# Persist the combined Decima predictions for downstream analyses.
out_h5ad = os.path.join(out_dir, 'gwas_variant_predictions_matched.h5ad')
ad.write_h5ad(out_h5ad)