In [None]:
# EAS clump res demo: /data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/clumped/waist/group_1/r2_0.1_w_500/res.clumps

import pandas as pd
import os
import numpy as np

trait_list = ['waist', 'height', 'pulse', 'dbp', 'sbp', 'smoke', 'drink', 'bmi', 'wbc', 'rbc', 'hb', 'plt', 'lymph', 'mono', 'neut', 'eos', 'alt', 'ast', 'bun', 'cholesterol', 'creatinine', 'glucose', 'ggt', 'hdl', 'ldl', 'triglycerides', 'ua']
eas_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/CAS/"
clump_res_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/clumped/"
output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/filtered/"

# P-value thresholds; one filtered file is written per threshold.
p_val_list = [5e-8, 5e-7, 5e-6, 5e-5, 5e-4, 5e-3, 5e-2, 0.5, 1]

# Columns used to match clumped index variants to their GWAS rows.
variant_key_cols = ["#CHROM", "POS", "ID"]

# For every trait / cross-validation group: keep the clumped variants whose
# clump P is below each threshold and write their full GWAS rows to disk.
for trait in trait_list:
    for group in range(1, 11):
        gwas_file = os.path.join(eas_base_dir, f"{trait}/group_{group}/gwas/train.Pheno.glm.linear")
        if not os.path.exists(gwas_file):
            print(f"File {gwas_file} does not exist. Skipping.")
            continue
        clump_file = os.path.join(clump_res_dir, f"{trait}/group_{group}/r2_0.1_w_500/res.clumps")
        if not os.path.exists(clump_file):
            print(f"File {clump_file} does not exist. Skipping.")
            continue

        # Both inputs depend only on (trait, group): load them once here
        # instead of re-reading the clump file for every threshold.
        gwas_df = pd.read_csv(gwas_file, sep="\t")
        clump_df = pd.read_csv(clump_file, sep="\t")

        for p_val in p_val_list:
            print(f"Processing trait {trait}, group {group}, p_val {p_val}")
            df_filtered = clump_df.loc[clump_df['P'] <= p_val, variant_key_cols]
            # Inner merge keeps only variants present in the GWAS table and
            # attaches the GWAS columns needed for scoring.
            res = pd.merge(df_filtered, gwas_df, on=variant_key_cols, how='inner')
            if res.shape[0] == 0:
                print(f"No SNPs passed the p-value threshold {p_val} for trait {trait}, group {group}. Skipping.")
                continue
            output_subdir = os.path.join(output_dir, f"{trait}/EAS/group_{group}/")
            os.makedirs(output_subdir, exist_ok=True)
            output_file = os.path.join(output_subdir, f"pval_{p_val}.clumped")
            res.to_csv(output_file, sep="\t", index=False, header=True)
            print(f"Saved filtered clump file to {output_file}")

In [None]:
# calculate PRS based on the filtered clump files on EAS test data
# test data demo: /data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/test/alt/group_1/test
import pandas as pd
import os
import numpy as np

trait_list = ['waist', 'height', 'pulse', 'dbp', 'sbp', 'smoke', 'drink', 'bmi', 'wbc', 'rbc', 'hb', 'plt', 'lymph', 'mono', 'neut', 'eos', 'alt', 'ast', 'bun', 'cholesterol', 'creatinine', 'glucose', 'ggt', 'hdl', 'ldl', 'triglycerides', 'ua']
clump_res_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/filtered/"
eas_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/CAS/"
ct_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/"
output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/"

p_val_list = [5e-8, 5e-7, 5e-6, 5e-5, 5e-4, 5e-3, 5e-2, 0.5, 1]

# Score every (trait, group, threshold) combination on the group's test
# genotypes with plink2, using the filtered clump file as the weight table.
for trait in trait_list:
    for group in range(1, 11):
        test_bfile = os.path.join(ct_base_dir, f"test/{trait}/group_{group}/combine")
        if not os.path.exists(test_bfile + ".bed"):
            print(f"File {test_bfile}.bed does not exist. Skipping.")
            continue
        for p_val in p_val_list:
            print(f"Processing trait {trait}, group {group}, p_val {p_val}")
            clump_file = os.path.join(clump_res_dir, f"{trait}/EAS/group_{group}/pval_{p_val}.clumped")
            if not os.path.exists(clump_file):
                print(f"File {clump_file} does not exist. Skipping.")
                continue
            output_subdir = os.path.join(output_dir, f"EAS/{trait}/group_{group}/")
            os.makedirs(output_subdir, exist_ok=True)
            output_prefix = os.path.join(output_subdir, f"pval_{p_val}")
            # NOTE(review): "3 5 12" are the variant-ID / allele / weight column
            # numbers in the score file; this presumes a fixed .glm.linear column
            # layout -- confirm column 5 is the tested allele and 12 the effect size.
            prs_command = f"plink2 --bfile {test_bfile} --score {clump_file} 3 5 12 header no-mean-imputation --out {output_prefix}"
            exit_status = os.system(prs_command)
            if exit_status != 0:
                # Surface failures here; otherwise they only show up later as a
                # missing .sscore file.
                print(f"plink2 failed (exit status {exit_status}) for trait {trait}, group {group}, p_val {p_val}")

In [None]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import roc_auc_score, average_precision_score, mean_squared_error
from scipy.stats import pearsonr
import os
import warnings

# Ignore warnings that may arise from certain fits in statsmodels
from statsmodels.tools.sm_exceptions import ConvergenceWarning
warnings.simplefilter('ignore', ConvergenceWarning)

# --- 1. Metric Calculation Functions ---

def calculate_continuous_metrics(df, base_covars, full_covars):
    """Evaluate a score against a continuous phenotype.

    Fits two OLS models on ``df["trait"]`` -- covariates only (``base_covars``)
    and covariates plus score (``full_covars``) -- and derives summary metrics
    from them.

    Args:
        df: DataFrame with "trait", "SCORE" and all listed covariate columns.
            A "quantile" column (score bin from ``pd.qcut``) is written to it
            in place.
        base_covars: column names for the covariate-only model.
        full_covars: column names for the full model.

    Returns:
        dict with keys r2_incremental, r2_full, rmse, nrmse_mean, nrmse_range,
        nrmse_std, pearson_r, top_quintile_mean, bottom_quintile_mean.
    """
    y = df["trait"]
    design_full = sm.add_constant(df[full_covars])

    # Nested OLS fits: the R² difference is the variance explained by the score
    # on top of the covariates.
    fit_base = sm.OLS(y, sm.add_constant(df[base_covars])).fit()
    fit_full = sm.OLS(y, design_full).fit()
    delta_r2 = fit_full.rsquared - fit_base.rsquared

    # Correlate the score with the covariate-adjusted phenotype.
    pearson_result = pearsonr(df["SCORE"], fit_base.resid)
    score_corr = pearson_result[0]

    # In-sample prediction error of the full model, raw and normalised.
    fitted = fit_full.predict(design_full)
    rmse = np.sqrt(mean_squared_error(y, fitted))

    trait_mean = y.mean()
    trait_span = y.max() - y.min()
    trait_sd = y.std()
    # A zero denominator yields NaN rather than a division error / inf.
    if trait_mean != 0:
        nrmse_mean = rmse / trait_mean
    else:
        nrmse_mean = np.nan
    if trait_span != 0:
        nrmse_range = rmse / trait_span
    else:
        nrmse_range = np.nan
    if trait_sd != 0:
        nrmse_std = rmse / trait_sd
    else:
        nrmse_std = np.nan

    # Mean phenotype per score bin (up to 5 bins; duplicate edges are dropped).
    df['quantile'] = pd.qcut(df['SCORE'], 5, labels=False, duplicates='drop')
    bin_means = df.groupby('quantile')['trait'].mean()
    if bin_means.empty:
        top_mean = np.nan
        bottom_mean = np.nan
    else:
        top_mean = bin_means.iloc[-1]
        bottom_mean = bin_means.iloc[0]

    return {
        "r2_incremental": delta_r2,
        "r2_full": fit_full.rsquared,
        "rmse": rmse,
        "nrmse_mean": nrmse_mean,
        "nrmse_range": nrmse_range,
        "nrmse_std": nrmse_std,
        "pearson_r": score_corr,
        "top_quintile_mean": top_mean,
        "bottom_quintile_mean": bottom_mean
    }

def calculate_binary_metrics(df, base_covars, full_covars):
    """Calculates all performance metrics for a binary trait for a given dataset (df).

    Args:
        df: DataFrame with a 0/1 "trait" column, a "SCORE" column and every
            column named in ``base_covars`` / ``full_covars``. The helper
            columns "prs_scaled" and "prs_quintile" are written to it in place.
        base_covars: covariate column names (without the score).
        full_covars: covariate column names including the score.

    Returns:
        dict with "auc", "pr_auc", "or_per_sd" and "OR_Quintile_1" ..
        "OR_Quintile_5". Quintile ORs are relative to the middle bin (fixed at
        1.0) and are NaN when a bin is missing, degenerate, or the fit fails.
    """
    # AUC and PR-AUC from the in-sample predictions of the full logistic model
    design_full = sm.add_constant(df[full_covars])
    logit_model = sm.Logit(df["trait"], design_full).fit(disp=0)
    pred_prob = logit_model.predict(design_full)
    auc = roc_auc_score(df["trait"], pred_prob)
    pr_auc = average_precision_score(df["trait"], pred_prob)

    # OR per 1-SD: refit with the standardised score so exp(coef) is per SD
    df["prs_scaled"] = (df["SCORE"] - df["SCORE"].mean()) / df["SCORE"].std()
    logit_model_scaled = sm.Logit(df["trait"], sm.add_constant(df[base_covars + ["prs_scaled"]])).fit(disp=0)
    or_per_sd = np.exp(logit_model_scaled.params["prs_scaled"])

    # Quantile OR
    df['prs_quintile'] = pd.qcut(df['SCORE'], 5, labels=False, duplicates='drop')
    reference_quintile = 2 # Middle quintile
    # Bins that actually occur; fewer than 5 when qcut dropped duplicate edges.
    observed_quintiles = set(df['prs_quintile'].dropna().unique())
    or_quintiles = {}
    for q in range(5):
        key = f'OR_Quintile_{q+1}'
        if q == reference_quintile:
            or_quintiles[key] = 1.0
            continue

        # Both the current and the reference bin must exist in the data.
        # (The previous `.isin([...]).all()` test required *every* row to fall
        # in these two bins, which is never true with five bins, so all
        # quintile ORs were reported as NaN.)
        if q not in observed_quintiles or reference_quintile not in observed_quintiles:
            or_quintiles[key] = np.nan
            continue

        temp_df = df[df['prs_quintile'].isin([q, reference_quintile])].copy()

        # Check for sufficient data in both groups for stable model fitting
        if temp_df['trait'].nunique() < 2 or temp_df['prs_quintile'].nunique() < 2:
            or_quintiles[key] = np.nan
            continue

        temp_df['is_current_quintile'] = (temp_df['prs_quintile'] == q).astype(int)
        X_quintile = sm.add_constant(temp_df[['is_current_quintile'] + base_covars])
        try:
            model_q = sm.Logit(temp_df["trait"], X_quintile).fit(disp=0)
            or_quintiles[key] = np.exp(model_q.params['is_current_quintile'])
        except Exception:
            # Deliberate best-effort: a failed fit for one bin is reported as
            # NaN instead of aborting the whole metric set.
            or_quintiles[key] = np.nan

    results = {
        "auc": auc,
        "pr_auc": pr_auc,
        "or_per_sd": or_per_sd,
    }
    results.update(or_quintiles) # Merge quantile ORs into the results dictionary
    return results

# --- 2. Main Execution Flow ---

def main() -> None:
    """Evaluate EAS scores for every trait, CV group and p-value threshold.

    For each combination whose ``pval_<threshold>.sscore`` file exists, the
    phenotype, covariate and score tables are inner-joined on FID/IID, coerced
    to numeric, stripped of rows with missing values, and passed to
    ``calculate_continuous_metrics``. The accumulated results are written to
    ``EAS_in_sample_metrics.csv`` in ``output_dir``.

    All paths are hard-coded below; the function takes no arguments and
    returns nothing.
    """
    # --- Parameter Settings ---
    # !! Note: Please modify the variables below according to your actual paths !!
    cleaned_prs_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/"
    covar_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/CAS/pheno/covariates.txt"
    output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/"
    pheno_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/CAS/"
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # --- Initialization ---
    # One result dict per (trait, group, threshold); binary traits land here too.
    final_results_continuous = []
    covar_cols = ["FID", "IID", "age", "sex"] + [f"PC{i}" for i in range(1, 11)]
    # Base model: age, sex and PC1-PC10; the full model adds the score.
    base_covars = ["age", "sex"] + [f"PC{i}" for i in range(1, 11)]
    full_covars = base_covars + ["SCORE"]

    # --- Data Loading and Processing ---
    covars = pd.read_csv(covar_path, sep='\t', usecols=covar_cols)

    # Trait code -> directory name used under the PRS and phenotype folders.
    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020':'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua'
    }
    p_val_list = [5e-8, 5e-7, 5e-6, 5e-5, 5e-4, 5e-3, 5e-2, 0.5, 1]
    # Example phenotype location (the code below reads test_pheno.txt, not test.txt):
    # /data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/CAS/alt/group_1/pheno/test.txt

    for trait, name in trait_dict.items():
        for i in range(1, 11):
            for p_val in p_val_list:
                trait_prs_path = os.path.join(cleaned_prs_path, f"EAS/{name}/group_{i}/pval_{p_val}.sscore")
                if not os.path.exists(trait_prs_path):
                    print(f"Warning: PRS file for trait {name} group {i} not found. Skipping.")
                    continue
                print(f"\nProcessing Trait: {name}, Group: {i}, P-value: {p_val}")
                # Label stored in the results: "_int" suffix for smoke/drink,
                # "_raw" for every other trait.
                if trait == 'p20116' or trait == 'p20117':
                    trait_id_from_prs = f"{trait}_int"
                else:
                    trait_id_from_prs = f"{trait}_raw"
                pheno_path = os.path.join(pheno_dir, f"{name}/group_{i}/pheno/test_pheno.txt")
                pheno = pd.read_csv(pheno_path, sep='\t')
                # Assumes exactly three columns in the order FID, IID, phenotype -- TODO confirm.
                pheno.columns = ["FID", "IID", "trait"]
                print(pheno.head())
                prs = pd.read_csv(trait_prs_path, sep='\t')
                # Normalise the .sscore header to the column names used below.
                prs.rename(columns={"#FID": "FID"}, inplace=True)
                prs.rename(columns={"SCORE1_AVG": "SCORE"}, inplace=True)

                # Cast the join keys to str in all three tables so the merges
                # do not fail on mismatched dtypes.
                for df in [pheno, prs, covars]:
                    df["FID"] = df["FID"].astype(str)
                    df["IID"] = df["IID"].astype(str)
                merged_data = pd.merge(pheno, covars, on=["FID", "IID"], how="inner")
                merged_data = pd.merge(merged_data, prs, on=["FID", "IID"], how="inner")
                print(merged_data.head())
                # Defensive data cleaning and type conversion
                numeric_cols = ["trait", "SCORE", "age", "sex"] + [f"PC{i}" for i in range(1, 11)]
                for col in numeric_cols:
                    if col in merged_data.columns:
                        # Non-numeric entries become NaN and are dropped just below.
                        merged_data[col] = pd.to_numeric(merged_data[col], errors='coerce')

                original_rows = len(merged_data)
                merged_data.dropna(subset=numeric_cols, inplace=True)
                new_rows = len(merged_data)
                if original_rows > new_rows:
                    print(f"--> Warning: Dropped {original_rows - new_rows} rows due to non-numeric data or NaNs in trait {trait_id_from_prs}.")
                if new_rows == 0:
                    print(f"--> Error: No valid samples remained for trait {trait_id_from_prs} after cleaning. Skipping.")
                    continue
                print(f"Data merged and cleaned for trait {trait_id_from_prs}. Total samples: {len(merged_data)}")

                if trait_id_from_prs in ["p20116_int", "p20117_int"]:
                    # Binary trait analysis
                    # Ensure binary trait is 0/1 coded
                    unique_vals = sorted(merged_data["trait"].unique())
                    if not set(unique_vals).issubset({0, 1}):
                        if len(unique_vals) == 2:
                            # The larger of the two observed values is mapped to 1.
                            print(f"Converting binary trait from {unique_vals} to 0/1.")
                            merged_data["trait"] = (merged_data["trait"] == unique_vals[1]).astype(int)
                        else:
                            print(f"Error: Binary trait column for {trait_id_from_prs} contains unexpected values: {unique_vals}. Skipping.")
                            continue
                        
                    # Direct calculation of metrics
                    # NOTE(review): this branch also calls calculate_continuous_metrics
                    # (OLS on the 0/1 outcome); calculate_binary_metrics is never used
                    # here -- confirm this is intentional.
                    analysis_report = calculate_continuous_metrics(merged_data, base_covars, full_covars)
                    analysis_report['trait'] = trait_id_from_prs
                    analysis_report['p_val_threshold'] = p_val
                    final_results_continuous.append(analysis_report)
                else:
                    # Continuous trait analysis
                    # Direct calculation of metrics
                    analysis_report = calculate_continuous_metrics(merged_data, base_covars, full_covars)
                    analysis_report['trait'] = trait_id_from_prs
                    analysis_report['p_val_threshold'] = p_val
                    final_results_continuous.append(analysis_report)

        # --- 3. Save Final Results ---
        # This block is indented inside the per-trait loop, so the CSV is
        # rewritten (and the full table printed) after every trait.
        # NOTE(review): confirm whether a single save after the loop was intended.
        if final_results_continuous:
            continuous_df = pd.DataFrame(final_results_continuous)
            # Reorder columns to have 'trait' first
            cols = ['trait'] + [col for col in continuous_df.columns if col != 'trait']
            continuous_df = continuous_df[cols]
            continuous_df.to_csv(os.path.join(output_dir, "EAS_in_sample_metrics.csv"), index=False)
            print("\nContinuous trait results saved to EAS_in_sample_metrics.csv")
            print(continuous_df)

# Run the evaluation when the cell/script is executed directly.
if __name__ == '__main__':
    main()

In [None]:
import pandas as pd
import os
import numpy as np

# Alternative (relative) location of the metrics table:
# r2_data_path = "../../../../PRS_benchmark/data/result/real_data/ct/EAS_in_sample_metrics.csv"
r2_data_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/EAS_in_sample_metrics.csv"
# Alternative LD reference (the CAS genotypes themselves):
# eas_bfile_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/CAS/geno/CAS_final/CAS_merged_qc_final"
eas_bfile_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/tlprs/reference/EAS_1kg/1000G.EAS.QC.hm3.ind"
full_eas_gwas_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/CAS/gwas/gwas/"
output_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/full_model/"

# Trait code -> trait name used in GWAS and output file names.
trait_dict = {
    'p48': 'waist',
    'p50': 'height',
    'p102': 'pulse',
    'p4079': 'dbp',
    'p4080': 'sbp',
    'p20116': 'smoke',
    'p20117': 'drink',
    'p21001': 'bmi',
    'p30000': 'wbc',
    'p30010': 'rbc',
    'p30020':'hb',
    'p30080': 'plt',
    'p30120': 'lymph',
    'p30130': 'mono',
    'p30140': 'neut',
    'p30150': 'eos',
    'p30620': 'alt',
    'p30650': 'ast',
    'p30670': 'bun',
    'p30690': 'cholesterol',
    'p30700': 'creatinine',
    'p30730': 'ggt',
    'p30740': 'glucose',
    'p30760': 'hdl',
    'p30780': 'ldl',
    'p30870': 'triglycerides',
    'p30880': 'ua'
}

r2_data = pd.read_csv(r2_data_path)

# Mean incremental R² over all rows sharing a (trait, threshold) pair, i.e.
# across the cross-validation groups recorded in the metrics table.
res_df = (
    r2_data
    .groupby(['trait', 'p_val_threshold'], as_index=False)['r2_incremental']
    .mean()
    .rename(columns={'r2_incremental': 'average_incremental_r2'})
)
res_df = res_df.sort_values(by=['trait', 'average_incremental_r2'], ascending=[True, False])

# For each trait keep the threshold with the highest average incremental R².
best_res = res_df.loc[res_df.groupby('trait')['average_incremental_r2'].idxmax()]
best_res = best_res.sort_values(by='trait')
print(best_res.shape)

# Re-clump the full-sample GWAS of every trait at its selected threshold.
for row in best_res.itertuples(index=False):
    trait = row.trait
    # Result labels look like "<code>_raw" / "<code>_int"; the code is the part
    # before the first underscore.
    trait_name = trait_dict[trait.split("_")[0]]
    if trait_name in ['smoke', 'drink']:
        gwas_path = os.path.join(full_eas_gwas_dir, f"{trait_name}_raw.{trait_name}_raw.glm.logistic")
    else:
        gwas_path = os.path.join(full_eas_gwas_dir, f"{trait_name}_int.{trait_name}_int.glm.linear")
    p_val_threshold = row.p_val_threshold
    output_prefix = os.path.join(output_base_dir, f"EAS/{trait_name}")
    os.makedirs(os.path.dirname(output_prefix), exist_ok=True)
    command = f"plink2 --bfile {eas_bfile_path} --clump {gwas_path} --clump-p1 {p_val_threshold} --clump-r2 0.1 --clump-kb 500 --out {output_prefix}"
    exit_status = os.system(command)
    if exit_status != 0:
        # Do not let a failed clump pass unnoticed.
        print(f"plink2 clumping failed (exit status {exit_status}) for trait {trait_name}")

Unnamed: 0,trait,p_val_threshold,average_incremental_r2
96,p30620_raw,5e-06,0.005615636
100,p30620_raw,0.05,0.005498912
98,p30620_raw,0.0005,0.005405407
97,p30620_raw,5e-05,0.004948829
99,p30620_raw,0.005,0.004800611
102,p30620_raw,1.0,0.003208825
101,p30620_raw,0.5,0.002872917
95,p30620_raw,5e-07,3.557112e-07


(27, 3)


In [None]:
plink2 --bfile /data1/jiapl_group/lishuhua/project/PRS_benchmark/software/tlprs/reference/EAS_1kg/1000G.EAS.QC.hm3.ind --clump /data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/CAS/gwas/gwas/alt_int.alt_int.glm.linear --clump-kb 500 --clump-p1 0.05 --clump-r2 0.1 --out /data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/full_model/EAS/alt

In [None]:
# EUR clump res demo: /data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/clumped/alt/group_1/r2_0.1_w_500/ukb_chr10.clumps
# EUR gwas demo: gwas_chr2.p30010_int.glm.linear

import pandas as pd
import os
import numpy as np

clump_res_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/clumped/"
eur_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/train/gwas/"
output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/filtered/"

# P-value thresholds; one filtered file is written per threshold and chromosome.
p_val_list = [5e-8, 5e-7, 5e-6, 5e-5, 5e-4, 5e-3, 5e-2, 0.5, 1]

# Trait code (used in GWAS file names) -> trait name (used in directory names).
trait_dict = {
    'p48': 'waist',
    'p50': 'height',
    'p102': 'pulse',
    'p4079': 'dbp',
    'p4080': 'sbp',
    'p20116': 'smoke',
    'p20117': 'drink',
    'p21001': 'bmi',
    'p30000': 'wbc',
    'p30010': 'rbc',
    'p30020':'hb',
    'p30080': 'plt',
    'p30120': 'lymph',
    'p30130': 'mono',
    'p30140': 'neut',
    'p30150': 'eos',
    'p30620': 'alt',
    'p30650': 'ast',
    'p30670': 'bun',
    'p30690': 'cholesterol',
    'p30700': 'creatinine',
    'p30730': 'ggt',
    'p30740': 'glucose',
    'p30760': 'hdl',
    'p30780': 'ldl',
    'p30870': 'triglycerides',
    'p30880': 'ua'
}

# Columns used to match clumped index variants to their GWAS rows.
variant_key_cols = ["#CHROM", "POS", "ID"]

for trait_code, trait in trait_dict.items():
    # NOTE(review): only group 10 is processed here, while the scoring and
    # combining cells iterate over groups 1-10 -- confirm this restriction is
    # intentional (e.g. a partial re-run) before relying on the outputs.
    for group in range(10, 11):
        for chrom in range(1, 23):
            print(f"Processing trait {trait}, group {group}, chrom {chrom}")
            gwas_file = os.path.join(eur_base_dir, f"fold_{group}/gwas_chr{chrom}.{trait_code}_int.glm.linear")
            if trait in ['smoke', 'drink']:
                gwas_file = os.path.join(eur_base_dir, f"fold_{group}/gwas_chr{chrom}_binary.{trait_code}_int.glm.linear")
            if not os.path.exists(gwas_file):
                print(f"File {gwas_file} does not exist. Skipping.")
                continue
            clump_file = os.path.join(clump_res_dir, f"{trait}/group_{group}/r2_0.1_w_500/ukb_chr{chrom}.clumps")
            if not os.path.exists(clump_file):
                print(f"File {clump_file} does not exist. Skipping.")
                continue

            # Both inputs depend only on (trait, group, chrom): read them once
            # instead of re-reading the clump file for every threshold.
            gwas_df = pd.read_csv(gwas_file, sep="\t")
            clump_df = pd.read_csv(clump_file, sep="\t")

            for p_val in p_val_list:
                print()
                df_filtered = clump_df.loc[clump_df['P'] <= p_val, variant_key_cols]
                # Inner merge attaches the GWAS columns for the retained variants.
                res = pd.merge(df_filtered, gwas_df, on=variant_key_cols, how='inner')
                if res.shape[0] == 0:
                    print(f"No SNPs passed the p-value threshold {p_val} for trait {trait}, group {group}. Skipping.")
                    continue
                output_subdir = os.path.join(output_dir, f"{trait}/EUR/group_{group}/")
                os.makedirs(output_subdir, exist_ok=True)
                output_file = os.path.join(output_subdir, f"chr_{chrom}_pval_{p_val}.clumped")
                res.to_csv(output_file, sep="\t", index=False, header=True)
                print(f"Saved filtered clump file to {output_file}")

In [None]:
import pandas as pd
import os
import numpy as np

# Trait code -> trait name (used in directory names).
trait_dict = {
    'p48': 'waist',
    'p50': 'height',
    'p102': 'pulse',
    'p4079': 'dbp',
    'p4080': 'sbp',
    'p20116': 'smoke',
    'p20117': 'drink',
    'p21001': 'bmi',
    'p30000': 'wbc',
    'p30010': 'rbc',
    'p30020':'hb',
    'p30080': 'plt',
    'p30120': 'lymph',
    'p30130': 'mono',
    'p30140': 'neut',
    'p30150': 'eos',
    'p30620': 'alt',
    'p30650': 'ast',
    'p30670': 'bun',
    'p30690': 'cholesterol',
    'p30700': 'creatinine',
    'p30730': 'ggt',
    'p30740': 'glucose',
    'p30760': 'hdl',
    'p30780': 'ldl',
    'p30870': 'triglycerides',
    'p30880': 'ua'
}

p_val_list = [5e-8, 5e-7, 5e-6, 5e-5, 5e-4, 5e-3, 5e-2, 0.5, 1]

clump_res_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/filtered/"
eur_bfile_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/test/"
output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/"

# Score each (trait, fold, chromosome, threshold) combination on the fold's
# per-chromosome test genotypes with plink2.
for trait_code, trait in trait_dict.items():
    for group in range(1, 11):
        for chrom in range(1, 23):
            for p_val in p_val_list:
                clump_file = os.path.join(clump_res_dir, f"{trait}/EUR/group_{group}/chr_{chrom}_pval_{p_val}.clumped")
                if not os.path.exists(clump_file):
                    print(f"File {clump_file} does not exist. Skipping.")
                    continue
                eur_bfile = os.path.join(eur_bfile_dir, f"fold_{group}/chr{chrom}")
                if not os.path.exists(eur_bfile + ".bed"):
                    print(f"File {eur_bfile}.bed does not exist. Skipping.")
                    continue
                output_subdir = os.path.join(output_dir, f"EUR/{trait}/group_{group}/")
                os.makedirs(output_subdir, exist_ok=True)
                output_prefix = os.path.join(output_subdir, f"chr_{chrom}_pval_{p_val}")
                # NOTE(review): "3 5 12" are the variant-ID / allele / weight column
                # numbers in the score file; this presumes a fixed .glm.linear column
                # layout -- confirm column 5 is the tested allele and 12 the effect size.
                prs_command = f"plink2 --bfile {eur_bfile} --score {clump_file} 3 5 12 header no-mean-imputation --out {output_prefix}"
                exit_status = os.system(prs_command)
                if exit_status != 0:
                    # Surface failures here; otherwise they only show up later as
                    # a missing .sscore file when chromosomes are combined.
                    print(f"plink2 failed (exit status {exit_status}) for trait {trait}, group {group}, chrom {chrom}, p_val {p_val}")

In [None]:
### COMBINE ALL .sscore FILES INTO ONE .tsv FILE FOR EACH TRAIT
import pandas as pd
import os
import numpy as np

prs_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/EUR/"
output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/EUR/"
trait_list = ['waist', 'height', 'pulse', 'dbp', 'sbp', 'smoke', 'drink', 'bmi', 'wbc', 'rbc', 'hb', 'plt', 'lymph', 'mono', 'neut', 'eos', 'alt', 'ast', 'bun', 'cholesterol', 'creatinine', 'glucose', 'ggt', 'hdl', 'ldl', 'triglycerides', 'ua']
p_val_list = [5e-8, 5e-7, 5e-6, 5e-5, 5e-4, 5e-3, 5e-2, 0.5, 1]

# Gather the per-chromosome .sscore tables of each (trait, group, threshold)
# into a single pval_<threshold>.tsv; combinations already written are skipped.
for trait in trait_list:
    for group in range(1, 11):
        for p_val in p_val_list:
            print(f"Processing trait {trait}, group {group}, p_val {p_val}")
            output_path = os.path.join(output_dir, f"{trait}/group_{group}/")
            os.makedirs(output_path, exist_ok=True)

            combined_file = os.path.join(output_path, f"pval_{p_val}.tsv")
            if os.path.exists(combined_file):
                print(f"File {combined_file} already exists. Skipping.")
                continue

            # Load whichever of chromosomes 1-22 have a score file.
            all_chr_dfs = []
            for chrom in range(1, 23):
                prs_file_path = os.path.join(prs_base_dir, f"{trait}/group_{group}/chr_{chrom}_pval_{p_val}.sscore")
                if os.path.exists(prs_file_path):
                    all_chr_dfs.append(pd.read_csv(prs_file_path, sep='\t'))
                else:
                    print(f"File {prs_file_path} does not exist. Skipping.")

            if len(all_chr_dfs) == 0:
                print(f"No chromosome files found for trait {trait}, group {group}, p_val {p_val}. Skipping.")
                continue

            # NOTE(review): the tables are stacked row-wise, so the output has one
            # row per sample per chromosome rather than one summed score per
            # sample -- confirm the downstream step aggregates them.
            combined_prs_df = pd.concat(all_chr_dfs, ignore_index=True)
            combined_prs_df.to_csv(combined_file, sep='\t', index=False, header=True)

In [None]:
## TEST FOR EUR IN SAMPLE
import pandas as pd
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import roc_auc_score, average_precision_score, mean_squared_error
from scipy.stats import pearsonr
import os
import warnings

# Ignore warnings that may arise from certain fits in statsmodels
from statsmodels.tools.sm_exceptions import ConvergenceWarning
warnings.simplefilter('ignore', ConvergenceWarning)

# --- 1. Metric Calculation Functions ---

def calculate_continuous_metrics(df, base_covars, full_covars):
    """Evaluate a score against a continuous phenotype.

    Fits two OLS models on ``df["trait"]`` -- covariates only (``base_covars``)
    and covariates plus score (``full_covars``) -- and derives summary metrics
    from them.

    Args:
        df: DataFrame with "trait", "SCORE" and all listed covariate columns.
            A "quantile" column (score bin from ``pd.qcut``) is written to it
            in place.
        base_covars: column names for the covariate-only model.
        full_covars: column names for the full model.

    Returns:
        dict with keys r2_incremental, r2_full, rmse, nrmse_mean, nrmse_range,
        nrmse_std, pearson_r, top_quintile_mean, bottom_quintile_mean.
    """
    y = df["trait"]
    design_full = sm.add_constant(df[full_covars])

    # Nested OLS fits: the R² difference is the variance explained by the score
    # on top of the covariates.
    fit_base = sm.OLS(y, sm.add_constant(df[base_covars])).fit()
    fit_full = sm.OLS(y, design_full).fit()
    delta_r2 = fit_full.rsquared - fit_base.rsquared

    # Correlate the score with the covariate-adjusted phenotype.
    pearson_result = pearsonr(df["SCORE"], fit_base.resid)
    score_corr = pearson_result[0]

    # In-sample prediction error of the full model, raw and normalised.
    fitted = fit_full.predict(design_full)
    rmse = np.sqrt(mean_squared_error(y, fitted))

    trait_mean = y.mean()
    trait_span = y.max() - y.min()
    trait_sd = y.std()
    # A zero denominator yields NaN rather than a division error / inf.
    if trait_mean != 0:
        nrmse_mean = rmse / trait_mean
    else:
        nrmse_mean = np.nan
    if trait_span != 0:
        nrmse_range = rmse / trait_span
    else:
        nrmse_range = np.nan
    if trait_sd != 0:
        nrmse_std = rmse / trait_sd
    else:
        nrmse_std = np.nan

    # Mean phenotype per score bin (up to 5 bins; duplicate edges are dropped).
    df['quantile'] = pd.qcut(df['SCORE'], 5, labels=False, duplicates='drop')
    bin_means = df.groupby('quantile')['trait'].mean()
    if bin_means.empty:
        top_mean = np.nan
        bottom_mean = np.nan
    else:
        top_mean = bin_means.iloc[-1]
        bottom_mean = bin_means.iloc[0]

    return {
        "r2_incremental": delta_r2,
        "r2_full": fit_full.rsquared,
        "rmse": rmse,
        "nrmse_mean": nrmse_mean,
        "nrmse_range": nrmse_range,
        "nrmse_std": nrmse_std,
        "pearson_r": score_corr,
        "top_quintile_mean": top_mean,
        "bottom_quintile_mean": bottom_mean
    }

def calculate_binary_metrics(df, base_covars, full_covars):
    """Calculates all performance metrics for a binary trait for a given dataset (df).

    Args:
        df: DataFrame with a 0/1 "trait" column, a "SCORE" column and every
            column named in ``base_covars`` / ``full_covars``. The helper
            columns "prs_scaled" and "prs_quintile" are written to it in place.
        base_covars: covariate column names (without the score).
        full_covars: covariate column names including the score.

    Returns:
        dict with "auc", "pr_auc", "or_per_sd" and "OR_Quintile_1" ..
        "OR_Quintile_5". Quintile ORs are relative to the middle bin (fixed at
        1.0) and are NaN when a bin is missing, degenerate, or the fit fails.
    """
    # AUC and PR-AUC from the in-sample predictions of the full logistic model
    design_full = sm.add_constant(df[full_covars])
    logit_model = sm.Logit(df["trait"], design_full).fit(disp=0)
    pred_prob = logit_model.predict(design_full)
    auc = roc_auc_score(df["trait"], pred_prob)
    pr_auc = average_precision_score(df["trait"], pred_prob)

    # OR per 1-SD: refit with the standardised score so exp(coef) is per SD
    df["prs_scaled"] = (df["SCORE"] - df["SCORE"].mean()) / df["SCORE"].std()
    logit_model_scaled = sm.Logit(df["trait"], sm.add_constant(df[base_covars + ["prs_scaled"]])).fit(disp=0)
    or_per_sd = np.exp(logit_model_scaled.params["prs_scaled"])

    # Quantile OR
    df['prs_quintile'] = pd.qcut(df['SCORE'], 5, labels=False, duplicates='drop')
    reference_quintile = 2 # Middle quintile
    # Bins that actually occur; fewer than 5 when qcut dropped duplicate edges.
    observed_quintiles = set(df['prs_quintile'].dropna().unique())
    or_quintiles = {}
    for q in range(5):
        key = f'OR_Quintile_{q+1}'
        if q == reference_quintile:
            or_quintiles[key] = 1.0
            continue

        # Both the current and the reference bin must exist in the data.
        # (The previous `.isin([...]).all()` test required *every* row to fall
        # in these two bins, which is never true with five bins, so all
        # quintile ORs were reported as NaN.)
        if q not in observed_quintiles or reference_quintile not in observed_quintiles:
            or_quintiles[key] = np.nan
            continue

        temp_df = df[df['prs_quintile'].isin([q, reference_quintile])].copy()

        # Check for sufficient data in both groups for stable model fitting
        if temp_df['trait'].nunique() < 2 or temp_df['prs_quintile'].nunique() < 2:
            or_quintiles[key] = np.nan
            continue

        temp_df['is_current_quintile'] = (temp_df['prs_quintile'] == q).astype(int)
        X_quintile = sm.add_constant(temp_df[['is_current_quintile'] + base_covars])
        try:
            model_q = sm.Logit(temp_df["trait"], X_quintile).fit(disp=0)
            or_quintiles[key] = np.exp(model_q.params['is_current_quintile'])
        except Exception:
            # Deliberate best-effort: a failed fit for one bin is reported as
            # NaN instead of aborting the whole metric set.
            or_quintiles[key] = np.nan

    results = {
        "auc": auc,
        "pr_auc": pr_auc,
        "or_per_sd": or_per_sd,
    }
    results.update(or_quintiles) # Merge quantile ORs into the results dictionary
    return results

# --- 2. Main Execution Flow ---

def main():
    """Score every EUR clumping+thresholding PRS file and save the metrics.

    For each trait in ``trait_dict``, each group 1-10 and each threshold in
    ``p_val_list``, the PRS file ``EUR/<name>/group_<i>/pval_<p>.tsv`` under
    ``cleaned_prs_path`` is merged with the phenotype and covariate tables on
    (FID, IID). Rows with non-numeric or missing values in the outcome, the
    score or any covariate are dropped, and the remaining data is passed to
    ``calculate_continuous_metrics``. Each returned report is tagged with the
    trait id and threshold, and the accumulated reports are written to
    ``EUR_in_sample_metrics.csv`` in ``output_dir``.

    Missing PRS files, merges that leave no valid samples, and binary traits
    with an unusable coding are reported on stdout and skipped. A missing
    phenotype or covariate file raises from ``pd.read_csv``.
    """
    # --- Parameter Settings ---
    # !! Note: Please modify the variables below according to your actual paths !!
    cleaned_prs_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/"
    covar_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/covar/covars_white_british_final.tsv"
    output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/"
    pheno_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/trait/White_British/"

    os.makedirs(output_dir, exist_ok=True)

    # --- Initialization ---
    final_results_continuous = []
    pc_cols = [f"PC{k}" for k in range(1, 21)]
    base_covars = ["age", "sex"] + pc_cols
    covar_cols = ["FID", "IID"] + base_covars
    full_covars = base_covars + ["SCORE"]
    # Every column handed to the metric function (outcome, score and all of
    # base_covars) must be numeric and non-missing, so the cleaning list is
    # derived from base_covars instead of a separately hard-coded PC range.
    numeric_cols = ["trait", "SCORE"] + base_covars

    # --- Data Loading and Processing ---
    covars = pd.read_csv(covar_path, sep='\t', usecols=covar_cols)
    # Merge keys are compared as strings; the covariate table is shared by all
    # iterations, so it is cast once here.
    covars["FID"] = covars["FID"].astype(str)
    covars["IID"] = covars["IID"].astype(str)

    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020':'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua'
    }
    p_val_list = [5e-8, 5e-7, 5e-6, 5e-5, 5e-4, 5e-3, 5e-2, 0.5, 1]
    # /data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/CAS/alt/group_1/pheno/test.txt

    for trait, name in trait_dict.items():
        # The two binary traits use the "_int" phenotype file, all others "_raw".
        is_binary_trait = trait == 'p20116' or trait == 'p20117'
        trait_id_from_prs = f"{trait}_int" if is_binary_trait else f"{trait}_raw"
        pheno_path = os.path.join(pheno_dir, f"{trait_id_from_prs}.txt")
        # The phenotype table depends only on the trait, so it is read at most
        # once per trait, and only when a PRS file for that trait exists.
        pheno = None

        for i in range(1, 11):
            for p_val in p_val_list:
                trait_prs_path = os.path.join(cleaned_prs_path, f"EUR/{name}/group_{i}/pval_{p_val}.tsv")
                if not os.path.exists(trait_prs_path):
                    print(f"Warning: PRS file for trait {name} group {i} not found. Skipping.")
                    continue
                print(f"\nProcessing Trait: {name}, Group: {i}, P-value: {p_val}")

                if pheno is None:
                    pheno = pd.read_csv(pheno_path, sep='\t')
                    pheno.columns = ["FID", "IID", "trait"]
                    print(pheno.head())
                    pheno["FID"] = pheno["FID"].astype(str)
                    pheno["IID"] = pheno["IID"].astype(str)

                prs = pd.read_csv(trait_prs_path, sep='\t')
                prs = prs.rename(columns={"#FID": "FID", "SCORE1_AVG": "SCORE"})
                prs["FID"] = prs["FID"].astype(str)
                prs["IID"] = prs["IID"].astype(str)

                merged_data = pd.merge(pheno, covars, on=["FID", "IID"], how="inner")
                merged_data = pd.merge(merged_data, prs, on=["FID", "IID"], how="inner")
                print(merged_data.head())

                # Defensive data cleaning and type conversion
                for col in numeric_cols:
                    if col in merged_data.columns:
                        merged_data[col] = pd.to_numeric(merged_data[col], errors='coerce')

                original_rows = len(merged_data)
                merged_data.dropna(subset=numeric_cols, inplace=True)
                new_rows = len(merged_data)
                if original_rows > new_rows:
                    print(f"--> Warning: Dropped {original_rows - new_rows} rows due to non-numeric data or NaNs in trait {trait_id_from_prs}.")
                if new_rows == 0:
                    print(f"--> Error: No valid samples remained for trait {trait_id_from_prs} after cleaning. Skipping.")
                    continue
                print(f"Data merged and cleaned for trait {trait_id_from_prs}. Total samples: {len(merged_data)}")

                if is_binary_trait:
                    # Ensure binary trait is 0/1 coded: a two-valued column is
                    # recoded so the larger value becomes 1; anything else is skipped.
                    unique_vals = sorted(merged_data["trait"].unique())
                    if not set(unique_vals).issubset({0, 1}):
                        if len(unique_vals) == 2:
                            print(f"Converting binary trait from {unique_vals} to 0/1.")
                            merged_data["trait"] = (merged_data["trait"] == unique_vals[1]).astype(int)
                        else:
                            print(f"Error: Binary trait column for {trait_id_from_prs} contains unexpected values: {unique_vals}. Skipping.")
                            continue

                # NOTE(review): binary traits are scored with the same
                # calculate_continuous_metrics call as continuous traits (as in
                # the original code). The function defined just before main()
                # returns AUC / PR-AUC / odds ratios - confirm which is intended.
                analysis_report = calculate_continuous_metrics(merged_data, base_covars, full_covars)
                analysis_report['trait'] = trait_id_from_prs
                analysis_report['p_val_threshold'] = p_val
                final_results_continuous.append(analysis_report)

        # --- 3. Save Final Results ---
        # Runs once per trait: the CSV is rewritten with everything collected
        # so far, so results of finished traits survive a later failure.
        if final_results_continuous:
            continuous_df = pd.DataFrame(final_results_continuous)
            # Reorder columns to have 'trait' first
            cols = ['trait'] + [col for col in continuous_df.columns if col != 'trait']
            continuous_df = continuous_df[cols]
            continuous_df.to_csv(os.path.join(output_dir, "EUR_in_sample_metrics.csv"), index=False)
            print("\nContinuous trait results saved to EUR_in_sample_metrics.csv")
            print(continuous_df)

# Entry point: call main() only when this code runs as the top-level module
# (which includes execution as a notebook cell), not when it is imported.
if __name__ == '__main__':
    main()

In [3]:
import os
import subprocess

import numpy as np
import pandas as pd

# Pick, for every trait, the p-value threshold with the highest mean
# incremental R^2 in the in-sample metrics file, then run plink2 clumping on the
# full EUR GWAS summary statistics with that threshold, one chromosome at a time.

# r2_data_path = "../../../../PRS_benchmark/data/result/real_data/ct/EUR_in_sample_metrics.csv"
r2_data_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/test_in_sample/EUR_in_sample_metrics.csv"
eur_bfile_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/White_British/0_sample_qc/"
full_eur_gwas_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/gwas/White_British/"
output_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/ct/res/full_model/"

trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020':'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua'
    }

r2_data = pd.read_csv(r2_data_path)

# Mean incremental R^2 over all rows sharing a (trait, threshold) pair.
res_df = (
    r2_data.groupby(['trait', 'p_val_threshold'])['r2_incremental']
    .mean()
    .reset_index(name='average_incremental_r2')
)
res_df = res_df.sort_values(by=['trait', 'average_incremental_r2'], ascending=[True, False])

# A trait whose averages are all NaN has no defined maximum; idxmax would then
# yield a label that .loc cannot look up, so such traits are reported and dropped.
scored_res = res_df.dropna(subset=['average_incremental_r2'])
unscored_traits = sorted(set(res_df['trait']) - set(scored_res['trait']))
if unscored_traits:
    print(f"Warning: no usable r2_incremental values for traits {unscored_traits}. Skipping them.")

# for each trait, get the best p_val_threshold
best_res = scored_res.loc[scored_res.groupby('trait')['average_incremental_r2'].idxmax()]
best_res = best_res.sort_values(by='trait')
# display(best_res)
print(best_res.shape)

for row in best_res.itertuples(index=False):
    # p102_int_chr1.p102_int.glm.linear
    # p20116_int_chr1.p20116_int.glm.logistic
    trait = row.trait.replace("_raw", "").replace("_int", "")
    trait_name = trait_dict[trait]
    p_val_threshold = row.p_val_threshold
    # smoke/drink summary statistics carry the .glm.logistic suffix, the rest .glm.linear
    glm_suffix = "logistic" if trait_name in ['smoke', 'drink'] else "linear"

    # All chromosomes of a trait share one output directory; create it once.
    os.makedirs(os.path.join(output_base_dir, f"EUR/{trait_name}"), exist_ok=True)

    for chrom in range(1, 23):
        print(f"Processing trait {trait_name} with p-value threshold {p_val_threshold} on chr {chrom}")
        eur_bfile_path = os.path.join(eur_bfile_dir, f"chr{chrom}")
        gwas_path = os.path.join(full_eur_gwas_path, f"{trait}_int_chr{chrom}.{trait}_int.glm.{glm_suffix}")
        output_prefix = os.path.join(output_base_dir, f"EUR/{trait_name}/chr_{chrom}")

        if not os.path.exists(gwas_path):
            print(f"File {gwas_path} does not exist. Skipping.")
            continue

        # Argument list instead of a shell string: no shell parsing of paths,
        # and the exit status is available for checking.
        command = [
            "plink2",
            "--bfile", eur_bfile_path,
            "--clump", gwas_path,
            "--rm-dup", "exclude-mismatch",
            "--clump-p1", f"{p_val_threshold}",
            "--clump-r2", "0.1",
            "--clump-kb", "500",
            "--out", output_prefix,
        ]
        completed = subprocess.run(command)
        if completed.returncode != 0:
            print(f"Warning: plink2 exited with status {completed.returncode} for trait {trait_name}, chr {chrom}.")

Unnamed: 0,trait,p_val_threshold,average_incremental_r2
6,p102_raw,0.05,0.000983
16,p4079_raw,0.5,0.000901
24,p4080_raw,0.05,0.000997
35,p48_raw,1.0,0.001428
42,p50_raw,0.05,0.003941


(5, 3)
