In [None]:
# demo: /data1/jiapl_group/lishuhua/project/PRS_benchmark/software/lassosum/res/beta/EAS/ua/group_10_beta.txt
import os
import subprocess

import numpy as np
import pandas as pd

# --- Score every trait / CV-group beta file on the test genotypes with plink2 ---
beta_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/lassosum/res/beta/EAS/"
test_bfile_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/Chinese/1_merged/merged"
output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/lassosum/res/test_prs/EAS/"

trait_list = ['waist', 'height', 'pulse', 'dbp', 'sbp', 'smoke', 'drink', 'bmi', 'wbc', 'rbc', 'hb', 'plt', 'lymph', 'mono', 'neut', 'eos', 'alt', 'ast', 'bun', 'cholesterol', 'creatinine', 'glucose', 'ggt', 'hdl', 'ldl', 'triglycerides', 'ua']

# Positional column numbers passed to `plink2 --score` for every beta file
# (variant-ID column, allele column, score column, per the plink2 --score syntax).
SCORE_COLUMNS = ("5", "3", "8")


def build_score_command(bfile, beta_path, out_prefix):
    """Return the plink2 argument list that scores ``bfile`` with the weights in ``beta_path``.

    The beta file is read with a header line, and ``no-mean-imputation`` is passed
    to plink2. Results are written under ``out_prefix`` (plink2 appends ``.sscore``).
    """
    return [
        "plink2",
        "--bfile", bfile,
        "--score", beta_path, *SCORE_COLUMNS, "header", "no-mean-imputation",
        "--out", out_prefix,
    ]


for trait in trait_list:
    for group in range(1, 11):
        beta_path = os.path.join(beta_base_dir, f"{trait}/group_{group}_beta.txt")
        if not os.path.exists(beta_path):
            print(f"Warning: Beta file for trait {trait} group {group} not found. Skipping.")
            continue
        output_prefix = os.path.join(output_dir, f"{trait}/group_{group}")
        os.makedirs(os.path.dirname(output_prefix), exist_ok=True)
        command = build_score_command(test_bfile_path, beta_path, output_prefix)
        print(" ".join(command))
        # Run without a shell (no quoting pitfalls) and report failures instead of
        # silently moving on, which is what an ignored os.system() status did.
        completed = subprocess.run(command)
        if completed.returncode != 0:
            print(f"Warning: plink2 exited with code {completed.returncode} for trait {trait} group {group}.")

In [None]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import roc_auc_score, average_precision_score, mean_squared_error
from scipy.stats import pearsonr
import os
import warnings

# statsmodels may raise ConvergenceWarning for some model fits; silence only that category.
from statsmodels.tools.sm_exceptions import ConvergenceWarning
warnings.filterwarnings('ignore', category=ConvergenceWarning)

# --- 1. Metric Calculation Functions ---

def calculate_continuous_metrics(df, base_covars, full_covars):
    """Compute PRS performance metrics for a continuous trait.

    Two OLS models (each with an intercept added by ``sm.add_constant``) are fitted
    to ``df["trait"]``: a base model on ``base_covars`` and a full model on
    ``full_covars``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``"trait"``, ``"SCORE"`` and every column in ``full_covars``.
        The frame is not modified.
    base_covars : list of str
        Covariate columns of the base model.
    full_covars : list of str
        Covariate columns of the full model (the caller passes ``base_covars + ["SCORE"]``).

    Returns
    -------
    dict
        ``r2_incremental``: full-model R² minus base-model R².
        ``r2_full``: full-model R².
        ``rmse``: in-sample RMSE of the full model.
        ``nrmse_mean`` / ``nrmse_range`` / ``nrmse_std``: RMSE divided by the trait's
        mean / range / std (NaN when that denominator is 0).
        ``pearson_r``: Pearson correlation of ``SCORE`` with the base-model residuals.
        ``top_quintile_mean`` / ``bottom_quintile_mean``: mean trait in the highest /
        lowest ``SCORE`` bin of a 5-way ``qcut`` (tied bin edges are dropped, so fewer
        than 5 bins are possible). NaN when no bins could be formed.
    """
    y = df["trait"]
    X_base = sm.add_constant(df[base_covars])
    X_full = sm.add_constant(df[full_covars])

    # Incremental R²
    model_base = sm.OLS(y, X_base).fit()
    model_full = sm.OLS(y, X_full).fit()
    r2_incremental = model_full.rsquared - model_base.rsquared

    # Pearson correlation between the PRS and the phenotype after removing covariate effects
    corr, _ = pearsonr(df["SCORE"], model_base.resid)

    # RMSE (in-sample, on the same design matrix the full model was fitted on)
    prediction_full = model_full.predict(X_full)
    rmse = np.sqrt(mean_squared_error(y, prediction_full))

    def _normalised(denominator):
        # Guard against division by zero for degenerate (constant / zero-mean) traits.
        return rmse / denominator if denominator != 0 else np.nan

    nrmse_mean = _normalised(y.mean())
    nrmse_range = _normalised(y.max() - y.min())
    nrmse_std = _normalised(y.std())

    # Quantile means: bin on a standalone Series so the caller's frame is not mutated.
    quintile = pd.qcut(df["SCORE"], 5, labels=False, duplicates="drop")
    quantile_means = y.groupby(quintile).mean()

    return {
        "r2_incremental": r2_incremental,
        "r2_full": model_full.rsquared,
        "rmse": rmse,
        "nrmse_mean": nrmse_mean,
        "nrmse_range": nrmse_range,
        "nrmse_std": nrmse_std,
        "pearson_r": corr,
        "top_quintile_mean": quantile_means.iloc[-1] if not quantile_means.empty else np.nan,
        "bottom_quintile_mean": quantile_means.iloc[0] if not quantile_means.empty else np.nan
    }

def calculate_binary_metrics(df, base_covars, full_covars):
    """Compute PRS performance metrics for a binary (0/1) trait.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``"trait"``, ``"SCORE"`` and every column in ``full_covars``.
        The frame is not modified; helper columns are added to a private copy.
    base_covars : list of str
        Covariate columns used alongside the PRS in the per-SD and quintile models.
    full_covars : list of str
        Columns of the logistic model used for AUC / PR-AUC.

    Returns
    -------
    dict
        ``auc`` / ``pr_auc``: ROC-AUC and average precision of the full logistic
        model's in-sample predicted probabilities.
        ``or_per_sd``: odds ratio per 1-SD increase of the standardised PRS,
        adjusted for ``base_covars``.
        ``OR_Quintile_1`` .. ``OR_Quintile_5``: odds ratio of each PRS quintile
        versus the middle quintile (which is reported as 1.0), adjusted for
        ``base_covars``; NaN when a quintile is missing (tied edges dropped by
        ``qcut``), the two-quintile subset has only one outcome class, or the fit fails.
    """
    # Work on a copy: the helper columns added below must not leak into the caller's frame.
    df = df.copy()

    # AUC and PR-AUC
    X_full = sm.add_constant(df[full_covars])
    logit_model = sm.Logit(df["trait"], X_full).fit(disp=0)
    pred_prob = logit_model.predict(X_full)
    auc = roc_auc_score(df["trait"], pred_prob)
    pr_auc = average_precision_score(df["trait"], pred_prob)

    # OR per 1-SD
    df["prs_scaled"] = (df["SCORE"] - df["SCORE"].mean()) / df["SCORE"].std()
    logit_model_scaled = sm.Logit(df["trait"], sm.add_constant(df[base_covars + ["prs_scaled"]])).fit(disp=0)
    or_per_sd = np.exp(logit_model_scaled.params["prs_scaled"])

    # Quantile OR
    n_quintiles = 5
    reference_quintile = 2  # 0-based label of the middle quintile
    df['prs_quintile'] = pd.qcut(df['SCORE'], n_quintiles, labels=False, duplicates='drop')
    present_quintiles = set(df['prs_quintile'].dropna().unique())

    or_quintiles = {}
    for q in range(n_quintiles):
        key = f'OR_Quintile_{q+1}'
        if q == reference_quintile:
            or_quintiles[key] = 1.0
            continue

        # Both the current and the reference quintile must actually occur in the data
        # (qcut with duplicates='drop' can produce fewer than 5 bins).
        if q not in present_quintiles or reference_quintile not in present_quintiles:
            or_quintiles[key] = np.nan
            continue

        temp_df = df[df['prs_quintile'].isin([q, reference_quintile])].copy()

        # A logistic fit needs both outcome classes in the two-quintile subset.
        if temp_df['trait'].nunique() < 2:
            or_quintiles[key] = np.nan
            continue

        temp_df['is_current_quintile'] = (temp_df['prs_quintile'] == q).astype(int)
        X_quintile = sm.add_constant(temp_df[['is_current_quintile'] + base_covars])
        try:
            model_q = sm.Logit(temp_df["trait"], X_quintile).fit(disp=0)
            or_quintiles[key] = np.exp(model_q.params['is_current_quintile'])
        except Exception:
            # Deliberate best-effort: a failed fit for one quintile (e.g. singular design)
            # is reported as NaN rather than aborting the whole trait.
            or_quintiles[key] = np.nan

    results = {
        "auc": auc,
        "pr_auc": pr_auc,
        "or_per_sd": or_per_sd,
    }
    results.update(or_quintiles)  # Merge quantile ORs into the results dictionary
    return results

# --- 2. Main Execution Flow ---

def _ids_as_str(df):
    """Cast the FID and IID columns of ``df`` to ``str`` (in place) and return ``df``."""
    df["FID"] = df["FID"].astype(str)
    df["IID"] = df["IID"].astype(str)
    return df


def _read_pheno(pheno_path):
    """Read a tab-separated phenotype file as columns FID, IID, trait (IDs as str).

    The file's own header is discarded. This assumes it has exactly three columns.
    """
    pheno = pd.read_csv(pheno_path, sep='\t')
    pheno.columns = ["FID", "IID", "trait"]
    return _ids_as_str(pheno)


def _read_prs(prs_path):
    """Read a plink2 ``.sscore`` file, renaming '#FID'->'FID' and 'SCORE1_AVG'->'SCORE' (IDs as str)."""
    prs = pd.read_csv(prs_path, sep='\t').rename(columns={"#FID": "FID", "SCORE1_AVG": "SCORE"})
    return _ids_as_str(prs)


def _merge_and_clean(pheno, prs, covars, numeric_cols, label):
    """Inner-join phenotype, PRS and covariates on FID/IID; coerce ``numeric_cols`` and drop NaN rows.

    ``label`` is only used in the warning printed when rows are dropped.
    """
    merged = (
        pheno.merge(prs, on=["FID", "IID"], how="inner")
        .merge(covars, on=["FID", "IID"], how="inner")
    )
    # Defensive type conversion: non-numeric entries become NaN and are dropped below.
    for col in numeric_cols:
        if col in merged.columns:
            merged[col] = pd.to_numeric(merged[col], errors='coerce')
    n_before = len(merged)
    merged = merged.dropna(subset=numeric_cols)
    if n_before > len(merged):
        print(f"--> Warning: Dropped {n_before - len(merged)} rows due to non-numeric data or NaNs in trait {label}.")
    return merged


def _save_continuous_results(results, output_dir):
    """Write the list of metric dicts to EAS_metrics_with_ci.csv with 'trait' and 'group' first."""
    continuous_df = pd.DataFrame(results)
    leading = [col for col in ("trait", "group") if col in continuous_df.columns]
    cols = leading + [col for col in continuous_df.columns if col not in leading]
    continuous_df = continuous_df[cols]
    continuous_df.to_csv(os.path.join(output_dir, "EAS_metrics_with_ci.csv"), index=False)
    print("\nContinuous trait results saved to EAS_metrics_with_ci.csv")
    print(continuous_df)


def main():
    """Evaluate lassosum EAS test-set PRS for every trait and CV group; save one metrics CSV.

    For each trait in ``trait_dict`` and each group 1..10 with an existing
    ``<name>/group_<g>.sscore`` file, the phenotype, PRS and covariates are merged and
    cleaned. ``calculate_continuous_metrics`` is then applied. All rows (tagged with
    ``trait`` and ``group``) are written once, after every trait/group is processed.
    """
    # --- Parameter Settings ---
    # !! Note: Please modify the variables below according to your actual paths !!
    cleaned_prs_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/lassosum/res/test_prs/EAS/"
    covar_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/covar/covars_chinese_final.tsv"
    output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/software/lassosum/res/test_res/"
    pheno_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/trait/Chinese/"

    os.makedirs(output_dir, exist_ok=True)

    # --- Initialization ---
    base_covars = ["age", "sex"] + [f"PC{i}" for i in range(1, 11)]
    covar_cols = ["FID", "IID"] + base_covars
    full_covars = base_covars + ["SCORE"]
    numeric_cols = ["trait", "SCORE"] + base_covars

    # Covariates are shared by every trait/group: load them and normalise IDs once.
    covars = _ids_as_str(pd.read_csv(covar_path, sep='\t', usecols=covar_cols))

    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020': 'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua'
    }
    # Fields whose phenotype file uses the '_int' suffix; all others use '_raw'.
    int_coded_fields = {'p20116', 'p20117'}

    final_results_continuous = []
    for trait, name in trait_dict.items():
        suffix = "int" if trait in int_coded_fields else "raw"
        trait_id_from_prs = f"{trait}_{suffix}"
        pheno_path = os.path.join(pheno_dir, f"{trait_id_from_prs}.txt")
        pheno = None  # read lazily, once per trait, only if some group has a PRS file

        for group in range(1, 11):
            trait_prs_path = os.path.join(cleaned_prs_path, f"{name}/group_{group}.sscore")
            if not os.path.exists(trait_prs_path):
                print(f"Warning: PRS file for trait {name} group {group} not found. Skipping.")
                continue
            print(f"\nProcessing Trait: {name} (group {group})")

            if pheno is None:
                pheno = _read_pheno(pheno_path)
            prs = _read_prs(trait_prs_path)
            merged_data = _merge_and_clean(pheno, prs, covars, numeric_cols, trait_id_from_prs)
            if merged_data.empty:
                print(f"--> Error: No valid samples remained for trait {trait_id_from_prs} after cleaning. Skipping.")
                continue
            print(f"Data merged and cleaned for trait {trait_id_from_prs}. Total samples: {len(merged_data)}")

            # NOTE(review): the '_int' traits (smoke, drink) are also scored with the
            # continuous metrics; calculate_binary_metrics is not used — confirm intended.
            analysis_report = calculate_continuous_metrics(merged_data, base_covars, full_covars)
            analysis_report['trait'] = trait_id_from_prs
            analysis_report['group'] = group  # distinguishes the 10 CV rows of a trait
            final_results_continuous.append(analysis_report)

    # --- 3. Save Final Results ---
    if final_results_continuous:
        _save_continuous_results(final_results_continuous, output_dir)


if __name__ == '__main__':
    main()