In [2]:
import pandas as pd
from scipy import stats
from joblib import Parallel, delayed
from tqdm import tqdm

In [3]:
def get_fisher_stat(row, total_genes=68_258_704, defensive_genes=244_022):
    random_prob = 1 - (1-defensive_genes/total_genes)**20
    random_defense_total = random_prob*total_genes
    fisher = stats.fisher_exact([[row['defense_count'], row['background_count'] - row['defense_count']], 
                                 [random_defense_total, total_genes - random_defense_total]], alternative='greater')
    return fisher[1]

In [4]:
model_seq_info = pd.read_parquet('../data3/interim/model_seq_info.pq')
model_cluster_assemblies = (pd.read_parquet('../data3/interim/model_cluster_assemblies.pq')
                            .reset_index())
model_cluster_neighbors = pd.read_parquet('../data3/interim/model_cluster_defense_neighbors.pq')

In [5]:
cluster_background_stats = (model_cluster_assemblies.groupby('cluster_id')
                            .agg(background_count=('assembly_stub', 'nunique')))

In [6]:
results_list = list()
for fold, fold_df in model_seq_info.groupby('test_fold'):
    print(fold)
    fold_test_defense_genes = fold_df['defense_gene'].unique()
    filtered_cluster_defense_neighbors = (model_cluster_neighbors
                                          .loc[(~model_cluster_neighbors['defense_neighbor'].isin(fold_test_defense_genes) & 
                                                model_cluster_neighbors['cluster_id'].isin(fold_df['seq_id'])), 
                                               ['seq_id', 'cluster_id', 'assembly']]
                                          .drop_duplicates())
    cluster_defense_stats = (filtered_cluster_defense_neighbors.groupby('cluster_id')
                                     .agg(defense_count=('assembly', 'nunique')))
    fold_clusters = (fold_df[['seq_id']])
    fold_cluster_enrichment = (fold_clusters
                           .merge(cluster_background_stats, left_on='seq_id', 
                                  right_index=True, how='inner')
                           .merge(cluster_defense_stats, left_on='seq_id', 
                                  right_index=True, how='left'))
    fold_cluster_enrichment['defense_count'] = fold_cluster_enrichment['defense_count'].fillna(0)
    fold_cluster_enrichment['frac_defensive'] = (fold_cluster_enrichment['defense_count']/
                                                 fold_cluster_enrichment['background_count'])
    fold_cluster_enrichment['fisher_p'] = Parallel(n_jobs=48)(delayed(get_fisher_stat)(row) for _, row in 
                                                         tqdm(fold_cluster_enrichment.iterrows(), 
                                                              total=len(fold_cluster_enrichment), 
                                                              position=0))
    results_list.append(fold_cluster_enrichment)

0


100%|██████████| 44006/44006 [00:10<00:00, 4204.52it/s]


1


100%|██████████| 45951/45951 [00:11<00:00, 4073.71it/s]


2


100%|██████████| 37456/37456 [00:12<00:00, 2991.98it/s]


3


100%|██████████| 38950/38950 [00:14<00:00, 2638.06it/s]


4


100%|██████████| 34095/34095 [00:06<00:00, 5238.97it/s]


In [7]:
results_df = pd.concat(results_list)
results_df['one_minus_fisher_p'] = 1 - results_df['fisher_p']

In [8]:
results_df

Unnamed: 0,seq_id,background_count,defense_count,frac_defensive,fisher_p,one_minus_fisher_p
0,00011a6f43ddef04b38ddb80e079e995128c758b8d0fef...,1,1.0,1.000000,0.069122,0.930878
1,00d5886a101316db1f7fae8c378cbd8c54fdc7e3260834...,7,2.0,0.285714,0.079486,0.920514
2,11773e80a189f17a60c241e286666f00a4cb7e06c7179d...,1,0.0,0.000000,1.000000,0.000000
3,1636cbe8d2ef79e4a6d97f3e2a4c94d33f5de611c7f2ca...,5,3.0,0.600000,0.002970,0.997030
4,3071bdad201ef0e6279ecaca5095bc298e8461345e78ea...,1,0.0,0.000000,1.000000,0.000000
...,...,...,...,...,...,...
200437,fff86b7c99c52d59c18d7217da734f72a75be4311b11b9...,7,0.0,0.000000,1.000000,0.000000
200442,fff92ed969130402b18c2f965ed2c47376001310513555...,1,0.0,0.000000,1.000000,0.000000
200447,fffafb6f04ef826d1fe339e4e4dfbd7555ddc77fe50f19...,1,0.0,0.000000,1.000000,0.000000
200453,fffe22c006b65defec0d43d82f17b7a57e3b3e75e64b5b...,1,0.0,0.000000,1.000000,0.000000


In [9]:
out_df = (results_df.rename(columns={'one_minus_fisher_p': 'Guilt-by-association (p-value)', 
                                     'frac_defensive': 'Guilt-by-association (frequency)'})
          .melt(id_vars='seq_id', value_vars=['Guilt-by-association (p-value)', 'Guilt-by-association (frequency)'], 
                value_name='prediction', var_name='method'))

In [10]:
out_df.to_parquet('../data3/interim/cv_predictions_gba.pq', index=False)