In [None]:
import pandas as pd
import numpy as np
import glob

# Directory searched by merge_beta_files() for the per-chromosome input files.
beta_file_path = (
    "/data1/jiapl_group/lishuhua/project/PRS_benchmark"
    "/software/prscsx/res/data/"
)
# Directory that receives the per-trait "<trait>_merged.txt" files.
output_path = (
    "/data1/jiapl_group/lishuhua/project/PRS_benchmark"
    "/software/prscsx/res/merged_data/"
)

def merge_beta_files(trait, beta_dir=None, out_dir=None):
    """Concatenate the per-chromosome effect-size files of one trait.

    Files matching ``{trait}_EAS_*_chr*.txt`` inside the input directory are
    read as whitespace-delimited, header-less tables, stacked in chromosome
    order and written to ``{trait}_merged.txt`` (tab-separated, no header,
    no index) inside the output directory.

    Parameters
    ----------
    trait : str
        File-name prefix of the trait, e.g. ``"height_p50"``.
    beta_dir : str, optional
        Directory to search. Defaults to the notebook global
        ``beta_file_path``.
    out_dir : str, optional
        Directory to write into; it is created when missing. Defaults to the
        notebook global ``output_path``.

    Returns
    -------
    None
        Nothing is written when no input file matches the pattern.
    """
    # Local imports: the import block of this cell does not provide os / re.
    import os
    import re

    # Defaults are resolved at call time. Later cells rebind the globals
    # `beta_file_path` / `output_path`, so callers that re-run this function
    # after those cells should pass the directories explicitly.
    if beta_dir is None:
        beta_dir = beta_file_path
    if out_dir is None:
        out_dir = output_path

    def _chromosome_order(path):
        """Sort key: numeric chromosome first (chr2 before chr10), then name."""
        base = os.path.basename(path)
        hit = re.search(r"_chr(\d+)\.txt$", base)
        if hit:
            return (0, int(hit.group(1)), base)
        # Non-numeric suffixes are placed after the numbered chromosomes.
        return (1, 0, base)

    search_path = os.path.join(beta_dir, f"{trait}_EAS_*_chr*.txt")
    file_list = sorted(glob.glob(search_path), key=_chromosome_order)
    if not file_list:
        print(f"No files found for trait: {trait}")
        return

    print(f"Found {len(file_list)} files for trait: {trait}")
    df_list = []
    for f in file_list:
        print(f"Processing file: {f}")
        df_list.append(pd.read_csv(f, sep=r'\s+', header=None))

    merged_df = pd.concat(df_list, ignore_index=True)
    # to_csv does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    output_file_path = os.path.join(out_dir, f"{trait}_merged.txt")
    merged_df.to_csv(output_file_path, sep='\t', header=False, index=False)

# Phenotype field code -> short trait label. The merge loop that follows
# joins each pair as "<label>_<code>" to form the file-name prefix.
trait_dict = dict([
    ('p48', 'waist'),
    ('p50', 'height'),
    ('p102', 'pulse'),
    ('p4079', 'dbp'),
    ('p4080', 'sbp'),
    ('p20116', 'smoke'),
    ('p20117', 'drink'),
    ('p21001', 'bmi'),
    ('p21002', 'weight'),
    ('p30000', 'wbc'),
    ('p30010', 'rbc'),
    ('p30020', 'hb'),
    ('p30080', 'plt'),
    ('p30120', 'lymph'),
    ('p30130', 'mono'),
    ('p30140', 'neut'),
    ('p30150', 'eos'),
    ('p30620', 'alt'),
    ('p30650', 'ast'),
    ('p30670', 'bun'),
    ('p30690', 'cholesterol'),
    ('p30700', 'creatinine'),
    ('p30730', 'ggt'),
    ('p30740', 'glucose'),
    ('p30760', 'hdl'),
    ('p30780', 'ldl'),
    ('p30870', 'triglycerides'),
    ('p30880', 'ua'),
])

# Merge every trait in turn. The file-name prefix handed to
# merge_beta_files() is "<label>_<field code>", e.g. "waist_p48".
for trait in trait_dict:
    name = trait_dict[trait]
    trait_name = f'{name}_{trait}'
    print(f"Merging files for trait: {trait_name}")
    merge_beta_files(trait_name)

In [None]:
import pandas as pd
import os
from multiprocessing import Pool, cpu_count

# Inputs / outputs of the PLINK scoring step.
beta_file_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/prscsx/res/merged_data/"
test_bfile_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/Chinese/1_merged/merged"
plink_path = "/data1/jiapl_group/lishuhua/software/general/plink"
output_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/prscsx/res/test_prs"

# Make sure the directory behind every --out prefix exists before PLINK runs.
os.makedirs(output_path, exist_ok=True)

merged_suffix = "_merged.txt"
tasks = []
# Inputs are the per-trait merged files, e.g. "alt_p30620_merged.txt".
# Only names carrying the merged suffix are scored, so the trait prefix is
# always recovered cleanly; sorting keeps the task order reproducible.
for beta_file in sorted(os.listdir(beta_file_path)):
    if not beta_file.endswith(merged_suffix):
        continue
    beta = os.path.join(beta_file_path, beta_file)
    trait = beta_file[:-len(merged_suffix)]
    output = os.path.join(output_path, f"{trait}_PRS")
    # PLINK 1.9 --score <file> i j k: column i = variant ID, j = allele code,
    # k = score. NOTE(review): assumes the merged files keep the variant ID,
    # effect allele and weight in columns 2, 4 and 6 - confirm against the
    # per-chromosome input format.
    cmd = f"{plink_path} --bfile {test_bfile_path} --score {beta} 2 4 6 no-mean-imputation --out {output}"
    tasks.append(cmd)

if not tasks:
    print(f"No '*{merged_suffix}' files found in {beta_file_path}; nothing to score.")
else:
    # Leave one core free, never ask for zero workers (Pool(0) raises
    # ValueError on a single-core machine) and never spawn more workers than
    # there are commands.
    n_workers = max(1, min(len(tasks), cpu_count() - 1))
    with Pool(n_workers) as pool:
        exit_statuses = pool.map(os.system, tasks)

    # os.system reports failure through a non-zero status; surface it instead
    # of silently continuing with missing .profile files.
    failed_tasks = [cmd for cmd, status in zip(tasks, exit_statuses) if status != 0]
    print(f"PLINK scoring finished: {len(tasks) - len(failed_tasks)} succeeded, {len(failed_tasks)} failed.")
    for cmd in failed_tasks:
        print(f"FAILED: {cmd}")

In [None]:
import pandas as pd
import os
import sys

# Directory with the "<trait>_PRS.profile" files and the directory that
# receives the tab-separated copies.
prs_file_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/prscsx/res/test_prs"
output_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/prscsx/res/test_prs_cleaned"

# to_csv does not create missing parent directories.
os.makedirs(output_path, exist_ok=True)

profile_suffix = "_PRS.profile"
for prs_file in sorted(os.listdir(prs_file_path)):
    if not prs_file.endswith(profile_suffix):
        continue

    # The profile is space-padded text: str.split() with no argument splits
    # on runs of whitespace and drops the padding. Blank lines are skipped so
    # they cannot turn into all-None rows. The `with` block closes the file
    # even if reading fails.
    with open(os.path.join(prs_file_path, prs_file), "r") as prs_data:
        prs_list = [line.split() for line in prs_data if line.strip()]

    if not prs_list:
        print(f"Skipping empty profile: {prs_file}")
        continue

    # First row is the header; every value is kept as the original text.
    prs_df = pd.DataFrame(prs_list[1:], columns=prs_list[0])
    trait = prs_file[:-len(profile_suffix)]
    prs_df.to_csv(os.path.join(output_path, f"{trait}_PRS.tsv"), sep='\t', index=False, header=True)

In [None]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import (r2_score, mean_squared_error, roc_auc_score,
average_precision_score, roc_curve)
from sklearn.calibration import calibration_curve
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Inputs: cleaned PRS tables, covariates and per-trait phenotype files.
# Outputs: two summary tables in `output_path` and figures in `figure_path`.
cleaned_prs_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/prscsx/res/test_prs_cleaned"
covar_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/covar/covars_chinese_final.tsv"
output_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/prscsx/res/full_res"
figure_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/prscsx/res/full_res/figures"
pheno_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/trait/Chinese"

n_bootstrap = 1000  # NOTE(review): defined but not used anywhere in this cell

# Phenotypes evaluated with logistic regression; all others are treated as
# continuous and evaluated with OLS.
binary_traits = {"p20116_int", "p20117_int"}
n_quantiles = 5
reference_quintile = 2  # zero-based index of the middle bin, i.e. "Quintile 3"

# savefig / to_csv do not create missing directories.
os.makedirs(output_path, exist_ok=True)
os.makedirs(figure_path, exist_ok=True)

result_continus = []
result_binary = []
covar_cols = ["FID", "IID", "age", "sex"] + [f"PC{i}" for i in range(1, 11)]
covars = pd.read_csv(covar_path, sep='\t', usecols=covar_cols)
base_covars = ["age", "sex"] + [f"PC{i}" for i in range(1, 11)]
full_covars = base_covars + ["SCORE"]
model_cols = ["trait"] + full_covars

# The phenotype directory is not modified by this cell, so list it once.
pheno_files = sorted(os.listdir(pheno_dir))

for prs_file in sorted(os.listdir(cleaned_prs_path)):
    if not prs_file.endswith("_PRS.tsv"):
        continue
    # "<label>_<field>_PRS.tsv" -> "<field>_int", the phenotype file prefix.
    trait = prs_file.split("_")[1]
    trait = trait + '_int'
    for pheno_file in pheno_files:
        if not (pheno_file.startswith(trait) and pheno_file.endswith(".txt")):
            continue

        pheno = pd.read_csv(os.path.join(pheno_dir, pheno_file), sep='\t')
        pheno.columns = ["FID", "IID", "trait"]
        prs = pd.read_csv(os.path.join(cleaned_prs_path, prs_file), sep='\t')
        merged_data = pd.merge(pheno, prs, on=["FID", "IID"], how="inner")
        merged_data = pd.merge(merged_data, covars, on=["FID", "IID"], how="inner")

        # Report whether the inner joins kept every phenotype row.
        if merged_data.shape[0] == pheno.shape[0]:
            print(f"All rows matched for {pheno_file} and {prs_file}.")
        else:
            print(f"Warning: {merged_data.shape[0]} of {pheno.shape[0]} phenotype rows matched for {pheno_file} and {prs_file}.")

        # The models below cannot handle missing values; drop incomplete rows
        # up front and say how many were removed.
        n_merged = merged_data.shape[0]
        merged_data = merged_data.dropna(subset=model_cols).reset_index(drop=True)
        if merged_data.shape[0] != n_merged:
            print(f"Dropped {n_merged - merged_data.shape[0]} rows with missing values for {pheno_file}.")
        if merged_data.empty:
            print(f"No usable rows for {pheno_file} and {prs_file}; skipping.")
            continue

        if trait in binary_traits:
            # ---- Binary trait: logistic regression ----
            # Ensure the outcome is coded 0/1; a two-valued column is recoded
            # with the larger value as 1.
            if not set(merged_data["trait"].unique()).issubset({0, 1}):
                unique_vals = sorted(merged_data["trait"].unique())
                if len(unique_vals) == 2:
                    merged_data["trait"] = (merged_data["trait"] == unique_vals[1]).astype(int)
                else:
                    raise ValueError(f"Trait column contains unexpected values: {unique_vals}")

            X_full = sm.add_constant(merged_data[full_covars])
            logit_model = sm.Logit(merged_data["trait"], X_full).fit(disp=0)
            pred_prob = logit_model.predict(X_full)

            # Step 1: discrimination of the full model (covariates + PRS),
            # evaluated on the same rows the model was fitted on.
            auc = roc_auc_score(merged_data["trait"], pred_prob)
            pr_auc = average_precision_score(merged_data["trait"], pred_prob)

            # Step 2: odds ratio per 1-SD of the PRS.
            # The standardized score replaces SCORE in the design matrix:
            # prs_scaled is a linear function of SCORE, so including both
            # together with the constant is exactly collinear and makes the
            # fit singular.
            merged_data["prs_scaled"] = (merged_data["SCORE"] - merged_data["SCORE"].mean()) / merged_data["SCORE"].std()
            X_full_scaled = sm.add_constant(merged_data[base_covars + ["prs_scaled"]])
            logit_model_scaled = sm.Logit(merged_data["trait"], X_full_scaled).fit(disp=0)
            or_per_sd = np.exp(logit_model_scaled.params["prs_scaled"])

            # Step 3: odds ratio of each PRS quintile against the middle
            # quintile (index 2, reported as "Quintile 3").
            merged_data['prs_quintile'] = pd.qcut(merged_data['SCORE'], n_quantiles, labels=False, duplicates='drop')
            # duplicates='drop' can yield fewer bins than requested.
            observed_quintiles = set(merged_data['prs_quintile'].unique())
            or_quintiles = {}

            for q in range(n_quantiles):
                if q == reference_quintile:
                    or_quintiles[f'Quintile {q+1}'] = 1.0
                    continue
                if q not in observed_quintiles or reference_quintile not in observed_quintiles:
                    # Nothing to compare; record NaN rather than fitting a
                    # model on a constant indicator.
                    or_quintiles[f'Quintile {q+1}'] = np.nan
                    continue

                # Compare only the current quintile with the reference quintile.
                temp_df = merged_data[merged_data['prs_quintile'].isin([q, reference_quintile])].copy()
                temp_df['is_current_quintile'] = (temp_df['prs_quintile'] == q).astype(int)

                X_quintile = sm.add_constant(temp_df[['is_current_quintile'] + base_covars])
                model_q = sm.Logit(temp_df["trait"], X_quintile).fit(disp=0)
                or_quintiles[f'Quintile {q+1}'] = np.exp(model_q.params['is_current_quintile'])

            print("2.3.1 Quintile Odds Ratios (OR): (Reference: Quintile 3)")
            for q_name, or_val in or_quintiles.items():
                print(f" {q_name}: {or_val:.2f}")

            # Plot the per-quintile odds ratios.
            plt.figure(figsize=(10, 6))
            sns.barplot(x=list(or_quintiles.keys()), y=list(or_quintiles.values()), color='coral')
            plt.axhline(1.0, color='black', linestyle='--')
            plt.title('Disease Risk Odds Ratio (OR) vs. PRS Quintiles', fontsize=16)
            plt.xlabel('PRS Quintile', fontsize=12)
            plt.ylabel('Odds Ratio (vs. Quintile 3)', fontsize=12)
            plt.savefig(os.path.join(figure_path, f"{trait}_prs_quintile_or.png"))
            # Close the figure so it is not kept alive for the rest of the loop.
            plt.close()

            # --- 2.4 Calibration ---
            prob_true, prob_pred = calibration_curve(merged_data["trait"], pred_prob, n_bins=10, strategy='uniform')

            plt.figure(figsize=(8, 8))
            plt.plot(prob_pred, prob_true, marker='o', linewidth=1, label='Calibration Curve')
            plt.plot([0, 1], [0, 1], linestyle='--', color='black', label='Perfect Calibration')
            plt.xlabel('Mean Predicted Probability', fontsize=12)
            plt.ylabel('Fraction of Positives (True Probability)', fontsize=12)
            plt.title('Calibration Plot', fontsize=16)
            plt.legend()
            plt.savefig(os.path.join(figure_path, f"{trait}_calibration_curve.png"))
            plt.close()

            result_binary.append({
                "trait": trait,
                "auc": auc,
                "pr_auc": pr_auc,
                "or_per_sd": or_per_sd,
                "quantile_or": or_quintiles
            })

        else:
            # ---- Continuous trait: OLS ----
            X_base = sm.add_constant(merged_data[base_covars])
            X_full = sm.add_constant(merged_data[full_covars])

            # Step 1: incremental R2 of the PRS over the covariates.
            model_base = sm.OLS(merged_data["trait"], X_base).fit()
            model_full = sm.OLS(merged_data["trait"], X_full).fit()
            r2_incremental = model_full.rsquared - model_base.rsquared

            # Step 2: Pearson correlation between the PRS and the
            # covariate-adjusted phenotype (residuals of the base model).
            pheno_residuals = model_base.resid
            corr = pearsonr(merged_data["SCORE"], pheno_residuals)[0]

            # Step 3: RMSE of the full model's fitted values.
            prediction_full = model_full.predict(X_full)
            rmse = np.sqrt(mean_squared_error(merged_data["trait"], prediction_full))

            # Step 4: mean trait value per PRS quintile.
            merged_data['quantile'] = pd.qcut(merged_data['SCORE'], n_quantiles, labels=False, duplicates='drop')
            quantile_means = merged_data.groupby('quantile')['trait'].mean()
            # The "decile" names, result keys and file name are kept so the
            # outputs stay compatible; the values are the top and bottom
            # *quintile* means.
            top_decile_mean = quantile_means.iloc[-1]
            bottom_decile_mean = quantile_means.iloc[0]

            plt.figure(figsize=(10, 6))
            sns.barplot(x=quantile_means.index + 1, y=quantile_means.values, color='skyblue')
            # The score is split into 5 bins, so label them as quintiles.
            plt.title('Quantitative Trait Mean vs. PRS Quintiles', fontsize=16)
            plt.xlabel('PRS Quintile (1=Lowest, 5=Highest)', fontsize=12)
            plt.ylabel(f'Mean {trait}', fontsize=12)
            plt.savefig(os.path.join(figure_path, f"{trait}_prs_decile_mean.png"))
            plt.close()

            result_continus.append({
                "trait": trait,
                "r2_incremental": r2_incremental,
                "r2": model_full.rsquared,
                "rmse": rmse,
                "pearson_r": corr,
                "top_decile_mean": top_decile_mean,
                "bottom_decile_mean": bottom_decile_mean
            })

# One row per evaluated trait.
result_continus_df = pd.DataFrame(result_continus)
result_continus_df.to_csv(os.path.join(output_path, "results_continuous.tsv"), sep='\t', index=False)
result_binary_df = pd.DataFrame(result_binary)
result_binary_df.to_csv(os.path.join(output_path, "results_binary.tsv"), sep='\t', index=False)