In [None]:
# get the UKB EUR ids list and split it into 10 iterations
import pandas as pd
import os
import numpy as np
from sklearn.model_selection import KFold

def check_distribution(subset_df, set_name, pheno_list):
    """Print the class proportions of each phenotype and flag degenerate ones.

    For every column named in ``pheno_list`` the missing values are dropped
    and the relative frequency of each remaining value is printed.

    Args:
        subset_df: DataFrame containing the columns listed in ``pheno_list``.
        set_name: Label used in the printed messages (e.g. "Train").
        pheno_list: Names of the phenotype columns to inspect.

    Returns:
        False as soon as a column has fewer than two distinct non-missing
        values (the remaining columns are not inspected); True otherwise.
    """
    print(f"Checking distribution for {set_name} set...")
    for pheno in pheno_list:
        counts = subset_df[pheno].dropna().value_counts()
        if len(counts) >= 2:
            proportions = counts / counts.sum()
            print(f"{set_name} set for {pheno} class distribution:\n{proportions}\n")
            continue
        # Zero or one observed class: report it and stop at the first failure.
        observed = counts.index.tolist()
        print(f"Warning: {set_name} set for {pheno} has only one class: {observed}")
        return False
    return True

# Input tables (tab-separated) and the root directory that receives one
# fold_<k> sub-directory per cross-validation fold.
ukb_eur_pheno_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/trait/trait_ukb_white_british.txt"
ukb_eur_covar_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/covar/covars_white_british_final.tsv"
ukb_eur_pheno_output_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
# Columns that must show at least two classes in every split of a fold.
binary_pheno_cols = ['p20116_int', 'p20117_int']
n_splits = 10

pheno = pd.read_csv(ukb_eur_pheno_path, sep="\t", header=0)
covar = pd.read_csv(ukb_eur_covar_path, sep="\t", header=0)
pheno_cols = pheno.columns.tolist()
pheno_raw_cols = ["FID", "IID", "eid"] + [col for col in pheno_cols if col.endswith("_raw")]
pheno_int_cols = ["FID", "IID", "eid"] + [col for col in pheno_cols if col.endswith("_int")]
covar_cols = covar.columns.tolist()
# Sample IDs followed by age, sex and PC1..PC20.
covar_for_gwas_cols = ["FID", "IID", "age", "sex"] + [f"PC{k}" for k in range(1, 21)]

pheno_covar = pd.merge(pheno, covar, on=["FID", "IID"], how="inner")
if pheno_covar.shape[0] != pheno.shape[0]:
    print("pheno_covar and pheno do not match, please check!")
print(f'Processing {n_splits}-fold cross-validation for {pheno_covar.shape[0]} samples.')

### core function: k-fold cross-validation ###
# Each KFold test partition is used once as the test set; the next partition
# (cyclically) is the tuning set and the remaining partitions form the
# training set.
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
all_fold_indices = [test_idx for _, test_idx in kf.split(pheno_covar)]
print(f"Starting {n_splits}-fold cross-validation splitting...")
for i in range(n_splits):
    test_indices = all_fold_indices[i]
    test_df = pheno_covar.iloc[test_indices]

    tune_fold_index = (i + 1) % n_splits
    tune_indices = all_fold_indices[tune_fold_index]
    tune_df = pheno_covar.iloc[tune_indices]

    train_fold_selection = [j for j in range(n_splits) if j != i and j != tune_fold_index]
    # Concatenate the index arrays directly instead of flattening them one
    # element at a time in Python; the resulting order is unchanged.
    train_indices = np.concatenate([all_fold_indices[j] for j in train_fold_selection])
    train_df = pheno_covar.iloc[train_indices]

    # Check distribution of binary phenotypes in each set
    if not (check_distribution(train_df, "Train", binary_pheno_cols) and
            check_distribution(tune_df, "Validation", binary_pheno_cols) and
            check_distribution(test_df, "Test", binary_pheno_cols)):
        # NOTE: a skipped fold gets no output directory at all.
        print(f"Distribution check failed for fold {i+1}. Please investigate.")
        continue

    # Save the splits
    print(f"Saving fold {i+1} splits...")
    output_path = os.path.join(ukb_eur_pheno_output_path, f"fold_{i+1}")
    os.makedirs(output_path, exist_ok=True)

    split_frames = [("train", train_df), ("tune", tune_df), ("test", test_df)]
    # (file suffix, column subset); None means "write every column".
    per_split_outputs = [
        ("pheno_covar", None),
        ("pheno", pheno_cols),
        ("pheno_raw", pheno_raw_cols),
        ("pheno_int", pheno_int_cols),
        ("covar", covar_cols),
    ]
    for suffix, columns in per_split_outputs:
        for split_name, frame in split_frames:
            subset = frame if columns is None else frame[columns]
            subset.to_csv(os.path.join(output_path, f"{split_name}_{suffix}.txt"), sep="\t", index=False)
    # The GWAS covariate table is only written for the training split.
    train_df[covar_for_gwas_cols].to_csv(os.path.join(output_path, "train_covar_for_gwas.txt"), sep="\t", index=False)

In [None]:
# check the phenotype missing rate in each fold in the training set
import pandas as pd
import os
# Root directory that holds one fold_<k> sub-directory per cross-validation fold.
ukb_eur_pheno_output_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
# Number of folds iterated over below (fold_1 .. fold_<n_splits>).
n_splits = 10

def check_missing_rate(df, set_name):
    """Print and return the fraction of missing values in every column.

    Args:
        df: DataFrame to inspect.
        set_name: Label used in the printed messages (e.g. "Train").

    Returns:
        Series indexed by column name holding the per-column missing fraction.
    """
    print(f"Checking missing rate for {set_name} set...")
    # The mean of the boolean missing-value mask is the per-column missing fraction.
    missing_mask = df.isna()
    missing_rates = missing_mask.mean()
    print(f"{set_name} set missing rates:\n{missing_rates}\n")
    return missing_rates

# For every fold, compute the per-column missing rate of the training table
# and store it in <fold dir>/missing_rate.log.
for i in range(n_splits):
    output_path = os.path.join(ukb_eur_pheno_output_path, f"fold_{i+1}")
    # Only the training split is inspected, so the tune/test tables are not
    # loaded (they were previously read and then never used).
    train_df = pd.read_csv(os.path.join(output_path, "train_pheno_covar.txt"), sep="\t")

    train_missing_rate = check_missing_rate(train_df, "Train")

    # Save the same report that was printed above.
    log_file_path = os.path.join(output_path, "missing_rate.log")
    with open(log_file_path, "w") as log_file:
        log_file.write(f"Fold {i+1} missing rates:\n")
        log_file.write(train_missing_rate.to_string())
        log_file.write("\n")  # end the log with a newline

In [None]:
# write the FID/IID sample ID lists of the train, tune and test sets for each fold
import pandas as pd
import os
# Root directory that holds one fold_<k> sub-directory per cross-validation fold.
ukb_eur_pheno_output_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
# Number of folds iterated over below (fold_1 .. fold_<n_splits>).
n_splits = 10

def get_df_ids_list(df, output_path, set_name):
    """Write the FID and IID columns of ``df`` to a headerless, tab-separated file.

    The file is named ``<set_name in lower case>_ids.txt`` and is created
    inside ``output_path``.

    Args:
        df: DataFrame containing "FID" and "IID" columns.
        output_path: Directory that receives the ID file.
        set_name: Split label; used in the message and in the file name.
    """
    print(f"Getting IDs list for {set_name} set...")
    ids_file = os.path.join(output_path, f"{set_name.lower()}_ids.txt")
    df.to_csv(ids_file, columns=["FID", "IID"], sep="\t", index=False, header=False)

# For every fold, load the three split tables and then write one ID list per split.
split_names = ("train", "tune", "test")
for i in range(n_splits):
    output_path = os.path.join(ukb_eur_pheno_output_path, f"fold_{i+1}")
    # All three tables are read before any ID file is written.
    split_tables = {
        name: pd.read_csv(os.path.join(output_path, f"{name}_pheno_covar.txt"), sep="\t")
        for name in split_names
    }

    for name, table in split_tables.items():
        get_df_ids_list(table, output_path, name)


In [None]:
import pandas as pd
import os
import sys

# eas: /data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/CAS/alt/group_1/ids

def process_single_file(bfile_path, ids_path, output_path, plink_exec):
    """Run plink to extract the samples listed in ``ids_path`` into a new bed fileset.

    The command is passed as an argument list (no shell), so paths that
    contain spaces or shell metacharacters are handled correctly. As before,
    a failing plink run does not raise; it is reported with a warning so it
    no longer goes unnoticed.

    Args:
        bfile_path: Prefix of the input fileset (passed to ``--bfile``).
        ids_path: Sample ID file (passed to ``--keep``).
        output_path: Prefix of the output fileset (passed to ``--out``).
        plink_exec: Path of the plink executable.
    """
    import subprocess  # local import: this cell does not import subprocess at the top

    cmd = [
        plink_exec,
        "--bfile", bfile_path,
        "--keep", ids_path,
        "--make-bed",
        "--out", output_path,
    ]
    try:
        completed = subprocess.run(cmd, check=False)
    except OSError as err:
        # Executable missing or not runnable: keep the original non-raising behaviour.
        print(f"Warning: could not run {plink_exec}: {err}")
        return
    if completed.returncode != 0:
        print(f"Warning: plink exited with status {completed.returncode} for output {output_path}")

if __name__ == "__main__":
    # For every fold and autosome, extract the genotypes of the test and tune
    # samples into per-fold directories. The "valid_*" names refer to the
    # test split (they read test_ids.txt and write below .../test/).
    eur_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
    eur_geno_base_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/White_British/0_sample_qc/"
    valid_output_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/test/"
    tuning_output_base_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/tune/"
    plink_exec = "/data1/jiapl_group/lishuhua/software/general/plink"

    for i in range(1, 11):
        print(f'Processing fold: {i}')
        fold_name = f"fold_{i}"
        valid_ids_path = os.path.join(eur_base_dir, fold_name, "test_ids.txt")
        tune_ids_path = os.path.join(eur_base_dir, fold_name, "tune_ids.txt")
        valid_output_dir = os.path.join(valid_output_base_dir, fold_name)
        tune_output_dir = os.path.join(tuning_output_base_dir, fold_name)
        # exist_ok avoids the separate existence check.
        os.makedirs(valid_output_dir, exist_ok=True)
        os.makedirs(tune_output_dir, exist_ok=True)
        for chrom in range(1, 23):
            print(f'Processing chromosome: {chrom}')
            chrom_name = f"chr{chrom}"
            bfile_path = os.path.join(eur_geno_base_path, chrom_name)
            valid_output_prefix = os.path.join(valid_output_dir, chrom_name)
            tune_output_prefix = os.path.join(tune_output_dir, chrom_name)
            process_single_file(bfile_path, valid_ids_path, valid_output_prefix, plink_exec)
            process_single_file(bfile_path, tune_ids_path, tune_output_prefix, plink_exec)
        print(f'Finished processing fold: {i}')

In [None]:
# run GWAS for each fold (continuous traits)
import pandas as pd
import os
import subprocess
import glob
import numpy as np

def run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, pheno_list, covar_list, is_continuous, threads_num, chrom, output_path, plink_exec):
    """Run one plink ``--glm`` association analysis for a fold and chromosome.

    The phenotype file ``train_pheno_int_dropna.txt`` and the covariate file
    ``train_covar_for_gwas.txt`` are read from ``pheno_covar_path``.

    Args:
        fold_index: Fold number, used only in the printed messages.
        pheno_covar_path: Directory holding the fold's phenotype/covariate files.
        bfile_path: Prefix of the genotype fileset (passed to ``--bfile``).
        pheno_list: Phenotype column names, joined with commas for ``--pheno-name``.
        covar_list: Covariate column names, joined with commas for ``--covar-name``.
        is_continuous: True selects the continuous branch; False selects the
            binary branch, which adds ``no-firth`` and a ``_binary`` output suffix.
        threads_num: Value passed to ``--threads``.
        chrom: Chromosome label used in the output prefix and messages.
        output_path: Output directory; created if it does not exist.
        plink_exec: Path of the plink executable.

    Raises:
        subprocess.CalledProcessError: If plink exits with a non-zero status.
    """
    print(f"Running GWAS for fold {fold_index} in chromosome {chrom}...")
    pheno_file = os.path.join(pheno_covar_path, "train_pheno_int_dropna.txt")
    covar_file = os.path.join(pheno_covar_path, "train_covar_for_gwas.txt")
    os.makedirs(output_path, exist_ok=True)

    # The two branches differ only in the --glm modifiers, the relative order
    # of two flags and the output suffix.
    if is_continuous:
        print(f"Running continuous GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "cols=+a1freq",
            "--no-input-missing-phenotype",
            "--threads", str(threads_num),
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}'
    else:
        print(f"Running binary GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "no-firth", "cols=+a1freq",
            "--threads", str(threads_num),
            "--no-input-missing-phenotype",
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}_binary'

    command = [
        plink_exec,
        "--bfile", bfile_path,
        "--pheno", pheno_file,
        "--pheno-name", ",".join(pheno_list),
        "--covar", covar_file,
        "--covar-name", ",".join(covar_list),
        *model_args,
        "--covar-variance-standardize",
        "--out", out_prefix,
    ]
    subprocess.run(command, check=True)

if __name__ == "__main__":
    # Run the continuous-trait GWAS on the training split of every fold,
    # one plink call per chromosome.
    plink_exec = "/data1/jiapl_group/lishuhua/software/general/plink2"
    eur_bfile_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/White_British/0_sample_qc/"
    output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/train/gwas/"
    pheno_covar_file_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
    # Field prefix -> short trait name; only the keys are used below.
    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020': 'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua',
    }
    pheno_cols = [f"{key}_int" for key in trait_dict]
    # The two binary traits are excluded from the continuous run.
    binary_pheno_cols = ['p20116_int', 'p20117_int']
    continuous_pheno_cols = [col for col in pheno_cols if col not in binary_pheno_cols]
    covar_list = ['age', 'sex'] + [f'PC{k}' for k in range(1, 21)]

    # Run GWAS for each fold
    for fold_index in range(1, 11):
        pheno_covar_path = os.path.join(pheno_covar_file_dir, f"fold_{fold_index}")
        output_path = os.path.join(output_dir, f"fold_{fold_index}")
        os.makedirs(output_path, exist_ok=True)

        # Re-write the phenotype table with missing values spelled "NA".
        # Despite the "_dropna" file name no rows are removed here; the name
        # is kept because run_gwas_for_fold reads exactly this file.
        train_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int.txt")
        new_train_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int_dropna.txt")
        pheno_df = pd.read_csv(train_pheno_path, sep="\t")
        pheno_df = pheno_df.fillna('NA')
        pheno_df.to_csv(new_train_pheno_path, sep="\t", index=False)

        for chrom in range(1, 23):
            bfile_path = os.path.join(eur_bfile_dir, f"chr{chrom}")
            run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, continuous_pheno_cols, covar_list, True, 16, chrom, output_path, plink_exec)

    print("GWAS runs completed for all folds.")

In [None]:
# run GWAS for each fold (Binary traits)
import pandas as pd
import os
import subprocess
import glob
import numpy as np

def run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, pheno_list, covar_list, is_continuous, threads_num, chrom, output_path, plink_exec):
    """Run one plink ``--glm`` association analysis for a fold and chromosome.

    The phenotype file ``train_pheno_int_dropna.txt`` and the covariate file
    ``train_covar_for_gwas.txt`` are read from ``pheno_covar_path``.

    Args:
        fold_index: Fold number, used only in the printed messages.
        pheno_covar_path: Directory holding the fold's phenotype/covariate files.
        bfile_path: Prefix of the genotype fileset (passed to ``--bfile``).
        pheno_list: Phenotype column names, joined with commas for ``--pheno-name``.
        covar_list: Covariate column names, joined with commas for ``--covar-name``.
        is_continuous: True selects the continuous branch; False selects the
            binary branch, which adds ``no-firth`` and a ``_binary`` output suffix.
        threads_num: Value passed to ``--threads``.
        chrom: Chromosome label used in the output prefix and messages.
        output_path: Output directory; created if it does not exist.
        plink_exec: Path of the plink executable.

    Raises:
        subprocess.CalledProcessError: If plink exits with a non-zero status.
    """
    print(f"Running GWAS for fold {fold_index} in chromosome {chrom}...")
    pheno_file = os.path.join(pheno_covar_path, "train_pheno_int_dropna.txt")
    covar_file = os.path.join(pheno_covar_path, "train_covar_for_gwas.txt")
    os.makedirs(output_path, exist_ok=True)

    # The two branches differ only in the --glm modifiers, the relative order
    # of two flags and the output suffix.
    if is_continuous:
        print(f"Running continuous GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "cols=+a1freq",
            "--no-input-missing-phenotype",
            "--threads", str(threads_num),
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}'
    else:
        print(f"Running binary GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "no-firth", "cols=+a1freq",
            "--threads", str(threads_num),
            "--no-input-missing-phenotype",
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}_binary'

    command = [
        plink_exec,
        "--bfile", bfile_path,
        "--pheno", pheno_file,
        "--pheno-name", ",".join(pheno_list),
        "--covar", covar_file,
        "--covar-name", ",".join(covar_list),
        *model_args,
        "--covar-variance-standardize",
        "--out", out_prefix,
    ]
    subprocess.run(command, check=True)

if __name__ == "__main__":
    # Run the binary-trait GWAS on the training split of every fold,
    # one plink call per chromosome.
    plink_exec = "/data1/jiapl_group/lishuhua/software/general/plink2"
    eur_bfile_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/White_British/0_sample_qc/"
    output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/train/gwas/"
    pheno_covar_file_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
    # Field prefix -> short trait name; only the keys are used below.
    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020': 'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua',
    }
    pheno_cols = [f"{key}_int" for key in trait_dict]
    # Only the two binary traits are analysed in this run.
    binary_pheno_cols = ['p20116_int', 'p20117_int']
    continuous_pheno_cols = [col for col in pheno_cols if col not in binary_pheno_cols]
    covar_list = ['age', 'sex'] + [f'PC{k}' for k in range(1, 21)]

    # Run GWAS for each fold
    for fold_index in range(1, 11):
        pheno_covar_path = os.path.join(pheno_covar_file_dir, f"fold_{fold_index}")
        output_path = os.path.join(output_dir, f"fold_{fold_index}")
        os.makedirs(output_path, exist_ok=True)

        # Re-write the phenotype table with missing values spelled "NA".
        # Despite the "_dropna" file name no rows are removed here; the name
        # is kept because run_gwas_for_fold reads exactly this file.
        train_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int.txt")
        new_train_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int_dropna.txt")
        pheno_df = pd.read_csv(train_pheno_path, sep="\t")
        pheno_df = pheno_df.fillna('NA')
        pheno_df.to_csv(new_train_pheno_path, sep="\t", index=False)

        for chrom in range(1, 23):
            bfile_path = os.path.join(eur_bfile_dir, f"chr{chrom}")
            run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, binary_pheno_cols, covar_list, False, 16, chrom, output_path, plink_exec)

    print("GWAS runs completed for all folds.")

In [None]:
# run GWAS for each fold, continuous traits (faster version)
import pandas as pd
import os
import subprocess
import glob
import numpy as np

def run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, pheno_list, covar_list, is_continuous, threads_num, chrom, output_path, plink_exec):
    """Run one plink ``--glm`` association analysis for a fold and chromosome.

    The phenotype file ``train_pheno_int_dropna_2.txt`` and the covariate file
    ``train_covar_for_gwas.txt`` are read from ``pheno_covar_path``.

    Args:
        fold_index: Fold number, used only in the printed messages.
        pheno_covar_path: Directory holding the fold's phenotype/covariate files.
        bfile_path: Prefix of the genotype fileset (passed to ``--bfile``).
        pheno_list: Phenotype column names, joined with commas for ``--pheno-name``.
        covar_list: Covariate column names, joined with commas for ``--covar-name``.
        is_continuous: True selects the continuous branch; False selects the
            binary branch, which adds ``no-firth`` and a ``_binary`` output suffix.
        threads_num: Value passed to ``--threads``.
        chrom: Chromosome label used in the output prefix and messages.
        output_path: Output directory; created if it does not exist.
        plink_exec: Path of the plink executable.

    Raises:
        subprocess.CalledProcessError: If plink exits with a non-zero status.
    """
    print(f"Running GWAS for fold {fold_index} in chromosome {chrom}...")
    pheno_file = os.path.join(pheno_covar_path, "train_pheno_int_dropna_2.txt")
    covar_file = os.path.join(pheno_covar_path, "train_covar_for_gwas.txt")
    os.makedirs(output_path, exist_ok=True)

    # The two branches differ only in the --glm modifiers, the relative order
    # of two flags and the output suffix.
    if is_continuous:
        print(f"Running continuous GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "cols=+a1freq",
            "--no-input-missing-phenotype",
            "--threads", str(threads_num),
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}'
    else:
        print(f"Running binary GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "no-firth", "cols=+a1freq",
            "--threads", str(threads_num),
            "--no-input-missing-phenotype",
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}_binary'

    command = [
        plink_exec,
        "--bfile", bfile_path,
        "--pheno", pheno_file,
        "--pheno-name", ",".join(pheno_list),
        "--covar", covar_file,
        "--covar-name", ",".join(covar_list),
        *model_args,
        "--covar-variance-standardize",
        "--out", out_prefix,
    ]
    subprocess.run(command, check=True)

if __name__ == "__main__":
    # Run the continuous-trait GWAS on the training split, one plink call per
    # chromosome, using a phenotype table restricted to complete rows.
    plink_exec = "/data1/jiapl_group/lishuhua/software/general/plink2"
    eur_bfile_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/White_British/0_sample_qc/"
    output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/train/gwas/"
    pheno_covar_file_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
    # Field prefix -> short trait name; only the keys are used below.
    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020': 'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua',
    }
    pheno_cols = [f"{key}_int" for key in trait_dict]
    # The two binary traits are excluded from the continuous run.
    binary_pheno_cols = ['p20116_int', 'p20117_int']
    continuous_pheno_cols = [col for col in pheno_cols if col not in binary_pheno_cols]
    covar_list = ['age', 'sex'] + [f'PC{k}' for k in range(1, 21)]

    # NOTE(review): the loop starts at fold 4, so folds 1-3 are not processed
    # by this cell even though the final message says "all folds" -- confirm
    # this is an intentional resume point.
    for fold_index in range(4, 11):
        pheno_covar_path = os.path.join(pheno_covar_file_dir, f"fold_{fold_index}")
        output_path = os.path.join(output_dir, f"fold_{fold_index}")
        os.makedirs(output_path, exist_ok=True)

        # Keep only samples with a value for every phenotype in pheno_cols
        # (binary and continuous alike) and write the table that
        # run_gwas_for_fold reads.
        # NOTE(review): this is listwise deletion across all traits, so each
        # trait is analysed on the common complete-case subset -- confirm
        # this is intended.
        train_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int.txt")
        complete_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int_dropna_2.txt")
        pheno_df = pd.read_csv(train_pheno_path, sep="\t")
        pheno_df = pheno_df.dropna(subset=pheno_cols)
        pheno_df.to_csv(complete_pheno_path, sep="\t", index=False)

        for chrom in range(1, 23):
            bfile_path = os.path.join(eur_bfile_dir, f"chr{chrom}")
            run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, continuous_pheno_cols, covar_list, True, 16, chrom, output_path, plink_exec)

    print("GWAS runs completed for all folds.")

In [None]:
# run GWAS for each fold, binary traits (faster version)
import pandas as pd
import os
import subprocess
import glob
import numpy as np

def run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, pheno_list, covar_list, is_continuous, threads_num, chrom, output_path, plink_exec):
    """Run one plink ``--glm`` association analysis for a fold and chromosome.

    The phenotype file ``train_pheno_int_dropna_2.txt`` and the covariate file
    ``train_covar_for_gwas.txt`` are read from ``pheno_covar_path``.

    Args:
        fold_index: Fold number, used only in the printed messages.
        pheno_covar_path: Directory holding the fold's phenotype/covariate files.
        bfile_path: Prefix of the genotype fileset (passed to ``--bfile``).
        pheno_list: Phenotype column names, joined with commas for ``--pheno-name``.
        covar_list: Covariate column names, joined with commas for ``--covar-name``.
        is_continuous: True selects the continuous branch; False selects the
            binary branch, which adds ``no-firth`` and a ``_binary`` output suffix.
        threads_num: Value passed to ``--threads``.
        chrom: Chromosome label used in the output prefix and messages.
        output_path: Output directory; created if it does not exist.
        plink_exec: Path of the plink executable.

    Raises:
        subprocess.CalledProcessError: If plink exits with a non-zero status.
    """
    print(f"Running GWAS for fold {fold_index} in chromosome {chrom}...")
    pheno_file = os.path.join(pheno_covar_path, "train_pheno_int_dropna_2.txt")
    covar_file = os.path.join(pheno_covar_path, "train_covar_for_gwas.txt")
    os.makedirs(output_path, exist_ok=True)

    # The two branches differ only in the --glm modifiers, the relative order
    # of two flags and the output suffix.
    if is_continuous:
        print(f"Running continuous GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "cols=+a1freq",
            "--no-input-missing-phenotype",
            "--threads", str(threads_num),
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}'
    else:
        print(f"Running binary GWAS for fold {fold_index}...")
        model_args = [
            "--glm", "hide-covar", "no-firth", "cols=+a1freq",
            "--threads", str(threads_num),
            "--no-input-missing-phenotype",
        ]
        out_prefix = f'{output_path}/gwas_chr{chrom}_binary'

    command = [
        plink_exec,
        "--bfile", bfile_path,
        "--pheno", pheno_file,
        "--pheno-name", ",".join(pheno_list),
        "--covar", covar_file,
        "--covar-name", ",".join(covar_list),
        *model_args,
        "--covar-variance-standardize",
        "--out", out_prefix,
    ]
    subprocess.run(command, check=True)

if __name__ == "__main__":
    # Run the binary-trait GWAS on the training split, one plink call per
    # chromosome, using a phenotype table restricted to complete rows.
    plink_exec = "/data1/jiapl_group/lishuhua/software/general/plink2"
    eur_bfile_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/geno/White_British/0_sample_qc/"
    output_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/train/gwas/"
    pheno_covar_file_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/pheno/"
    # Field prefix -> short trait name; only the keys are used below.
    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020': 'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua',
    }
    pheno_cols = [f"{key}_int" for key in trait_dict]
    # Only the two binary traits are analysed in this run.
    binary_pheno_cols = ['p20116_int', 'p20117_int']
    continuous_pheno_cols = [col for col in pheno_cols if col not in binary_pheno_cols]
    covar_list = ['age', 'sex'] + [f'PC{k}' for k in range(1, 21)]

    # NOTE(review): the loop starts at fold 2, so fold 1 is not processed by
    # this cell even though the final message says "all folds" -- confirm
    # this is an intentional resume point.
    for fold_index in range(2, 11):
        pheno_covar_path = os.path.join(pheno_covar_file_dir, f"fold_{fold_index}")
        output_path = os.path.join(output_dir, f"fold_{fold_index}")
        os.makedirs(output_path, exist_ok=True)

        # Keep only samples with a value for every phenotype in pheno_cols
        # (binary and continuous alike) and write the table that
        # run_gwas_for_fold reads.
        # NOTE(review): this is listwise deletion across all traits, so the
        # binary traits are analysed only on samples that also have every
        # continuous trait -- confirm this is intended.
        train_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int.txt")
        complete_pheno_path = os.path.join(pheno_covar_path, "train_pheno_int_dropna_2.txt")
        pheno_df = pd.read_csv(train_pheno_path, sep="\t")
        pheno_df = pheno_df.dropna(subset=pheno_cols)
        pheno_df.to_csv(complete_pheno_path, sep="\t", index=False)

        for chrom in range(1, 23):
            bfile_path = os.path.join(eur_bfile_dir, f"chr{chrom}")
            run_gwas_for_fold(fold_index, pheno_covar_path, bfile_path, binary_pheno_cols, covar_list, False, 24, chrom, output_path, plink_exec)

    print("GWAS runs completed for all folds.")