In [None]:
import itertools
import os
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import fisher_exact
from scipy.stats import ttest_ind
from statannotations.Annotator import Annotator

from Customize import JAVA_PATH, CPRD_CODE_PATH, COHORT_SAVE_PATH,MODEL_SAVE_PATH, get_symptom_medcode, get_icdcode

In [None]:
# Analysis settings. parse_args([]) parses an empty list instead of sys.argv
# (inside a notebook sys.argv holds the kernel's own launch flags), so the
# defaults below are what actually run; pass e.g. parse_args(['--k', '4']) to override.
parser = ArgumentParser()
parser.add_argument('--disease', default='AD', type=str)
parser.add_argument('--ukb_dir', default='UKB_AD_data', type=str)
parser.add_argument('--experient_dir', default='AD', type=str)  # flag spelling kept as-is (public CLI name)
parser.add_argument('--model_name', default='cl_maskage_b32', type=str)
parser.add_argument('--stage', default='before', type=str)
parser.add_argument('--seed', default=2024, type=int)
parser.add_argument('--follow_up_year', default=5, type=int)
parser.add_argument('--k', default=5, type=int)
args = parser.parse_args([])

stage = args.stage

# Cohort input directory: <COHORT_SAVE_PATH>/<ukb_dir>_<stage>
ukb_path = os.path.join(COHORT_SAVE_PATH, f'{args.ukb_dir}_{stage}')

# Model output tree: <MODEL_SAVE_PATH>/<experient_dir>/<model_name>_<stage>/results/<disease>-<k>
experient_dir = os.path.join(MODEL_SAVE_PATH, args.experient_dir)
model_save_dir = os.path.join(experient_dir, f'{args.model_name}_{stage}')
results_dir = os.path.join(model_save_dir, 'results')
results_save_dir = os.path.join(results_dir, f'{args.disease}-{args.k}')


In [None]:
# Output folder for the genetic analyses of this model/disease/k combination.
prs_save_dir = os.path.join(results_save_dir, 'genetic_analyses')
os.makedirs(prs_save_dir, exist_ok=True)

# Per-patient results (cluster `label`, age, Sex, ...) written by an earlier step.
# NOTE: read_pickle can execute arbitrary code — only load trusted, locally produced files.
result_df = pd.read_pickle(os.path.join(results_save_dir, 'result_df_'+stage+'_EHR_UKB.pkl'))

# Whitespace-delimited genotype table; presumably PLINK `--recode A` output
# (FID/IID/... then one allele-dosage column per SNP named <rsid>_<allele>) — TODO confirm.
# sep=r'\s+' is the documented replacement for the deprecated delim_whitespace=True.
prs = pd.read_csv('adpd_snps.raw', sep=r'\s+')
prs = prs.rename(columns={'FID': 'patid'})

In [None]:
# AD-associated SNP columns of `prs`, named <rsid>_<allele>:
#   rs429358_C    - APOE4
#   rs7412_T      - APOE2
#   rs143332484_T - TREM2
#   rs3764650_G   - ABCA7
#   rs3752246_G   - ABCA7
ad_snp = ['rs429358_C', 'rs7412_T', 'rs143332484_T', 'rs3764650_G', 'rs3752246_G']

# Keep the join key plus the SNP columns. patid is cast to str for the merge;
# presumably result_df['patid'] is str as well (merge needs matching key dtypes) — verify.
prs = prs.assign(patid=prs['patid'].astype(str)).loc[:, ['patid', *ad_snp]]

# Attach cluster label and demographics; only patients present in both tables survive.
new_df = pd.merge(prs, result_df.loc[:, ['patid', 'label', 'age', 'Sex']], how='inner', on='patid')

# Reference frequency per SNP, used as the denominator of the enrichment ratio in the
# heatmap cell. NOTE(review): source and definition (allele vs. carrier frequency)
# are not stated here — confirm before interpreting the ratios.
normal_values = {
    'rs429358_C': 0.06817,
    'rs7412_T': 0.082919,
    'rs143332484_T': 0.010499,
    'rs3764650_G': 0.093706,
    'rs3752246_G': 0.17085,
}

In [None]:
# Carrier enrichment per label: for each SNP, the fraction of genotyped patients
# carrying >= 1 copy of the counted allele (dosage > 0), divided by the reference
# frequency in `normal_values`. Values > 1 suggest over-representation.
# NOTE(review): if normal_values are allele (not carrier) frequencies, the
# "no enrichment" baseline is not 1 — confirm.
grouped = new_df.groupby('label')

ratio_by_label = {}
for label, group in grouped:
    ratios = {}
    for snp in ad_snp:
        dosage = group[snp].astype(float)
        n_called = dosage.notna().sum()
        # Missing genotypes (NaN) are excluded from both counts: NaN.ne(0) is True,
        # so `.ne(0)` / len(group) would count them as carriers.
        carrier_frac = dosage.gt(0).sum() / n_called if n_called > 0 else np.nan
        ratios[snp] = carrier_frac / normal_values[snp]
    ratio_by_label[label] = ratios

# Rows = SNPs, columns = label values.
heatmap_data = pd.DataFrame(ratio_by_label)

# Labels may mix integer cluster ids with the string 'Control'; plain sorted()
# raises TypeError on int-vs-str, so order numeric labels first, then strings.
heatmap_data = heatmap_data.reindex(
    sorted(heatmap_data.columns, key=lambda c: (isinstance(c, str), c)), axis=1)

fig, ax = plt.subplots(figsize=(8, 4))
sns.heatmap(heatmap_data, cmap='Reds', annot=True, fmt=".2f",
            annot_kws={"size": 10},  # font size of the cell values
            cbar=False,              # no colour bar; values are annotated
            xticklabels=True, yticklabels=True, ax=ax)
ax.set_xlabel('Label')
ax.set_ylabel('SNP')
ax.set_title('Percentage of Mutation Ratio Compare with Population by Label')
fig.savefig(os.path.join(prs_save_dir, 'snp_Ratio_heatmap.png'))
plt.show()

In [None]:
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
import os

# Exclude 'Control' and reset index
new_df_non = new_df[new_df['label'] != 'Control'].copy()
new_df_non.reset_index(drop=True, inplace=True)

# Rename numeric labels to string labels
label_mapping = {1: 'Cluster 1', 2: 'Cluster 2', 3: 'Cluster 3', 4: 'Cluster 4', 5: 'Cluster 5'}
new_df_non['label'] = new_df_non['label'].map(label_mapping)

# Define clusters (order)
clusters = ['Cluster 1', 'Cluster 2', 'Cluster 3', 'Cluster 4', 'Cluster 5']

# Create a list to store the results
results = []

# Loop over each SNP and each cluster (compared to all others)
for snp in ad_prs:
    for cluster in clusters:
        # Select data for the current cluster and for the rest ("Other")
        group_df = new_df_non[new_df_non['label'] == cluster]
        other_df = new_df_non[new_df_non['label'] != cluster]
        
        # Calculate counts for nonzero (mutated) vs zero
        nonzero_group = group_df[snp].astype(float).ne(0).sum()
        total_group = len(group_df)
        nonzero_other = other_df[snp].astype(float).ne(0).sum()
        total_other = len(other_df)
        
        # Build the 2x2 contingency table:
        #           Mutated            Not Mutated
        # Group:    nonzero_group      total_group - nonzero_group
        # Other:    nonzero_other      total_other - nonzero_other
        table = [[nonzero_group, total_group - nonzero_group],
                 [nonzero_other, total_other - nonzero_other]]
        
        # Run Fisher's exact test (if counts allow; otherwise p_value = NaN)
        try:
            odds_ratio, p_value = fisher_exact(table)
        except Exception as e:
            p_value = np.nan
        
        # Compute the percentage of mutated samples in each group
        perc_group = (nonzero_group / total_group * 100) if total_group > 0 else np.nan
        perc_other = (nonzero_other / total_other * 100) if total_other > 0 else np.nan
        
        # Append the results as a row in the results list
        results.append({
            'SNP': snp,
            'Cluster': cluster,
            'Group_n': total_group,
            'Other_n': total_other,
            'Group_Nonzero': nonzero_group,
            'Other_Nonzero': nonzero_other,
            'Group_%': perc_group,
            'Other_%': perc_other,
            'Odds_Ratio': odds_ratio,
            'p_value': p_value
        })

# Convert the list to a DataFrame
results_df = pd.DataFrame(results)
results_df.to_csv(os.path.join(prs_save_dir, f'UKB_{args.disease}_SNP_fisher_test.csv'))
