In [None]:
# Cross Validation

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, KFold
import os

def create_plink_split_files(
    fam_filepath,
    pheno_filepath,
    output_prefix: str,
    is_continuous: bool = False,
    trait: str = None,
    n_splits: int = 10,
    random_state: int = 42
):
    """Write per-group train/tune/test sample ID and phenotype files for k-fold CV.

    Samples from ``fam_filepath`` are left-joined with the phenotype table in
    ``pheno_filepath`` on (FID, IID); samples without a phenotype are dropped.
    The remaining samples are partitioned into ``n_splits`` folds.  For output
    group ``k = i + 1``, fold ``i`` is the test set, fold ``(i + 1) % n_splits``
    is the tune set, and all remaining folds form the training set.

    Files written (``output_prefix`` is used as a directory)::

        {output_prefix}/group_{k}/ids/{train,tune,test}_ids.txt      FID<TAB>IID, no header
        {output_prefix}/group_{k}/pheno/{train,tune,test}_pheno.txt  FID, IID, Pheno, with header
        {output_prefix}/fold_counts.tsv                              per-group sample counts

    Args:
        fam_filepath: Whitespace-delimited file whose first two columns are FID
            and IID.  A first row whose first field is ``FID`` is treated as a
            header and skipped.
        pheno_filepath: Path to a whitespace-delimited file with a header row
            and exactly three columns, interpreted as FID, IID and phenotype.
            It is opened twice (header, then data), so it must be a path.
        output_prefix: Output directory for this trait.
        is_continuous: If True, folds come from shuffled ``KFold``; otherwise
            ``StratifiedKFold`` stratifies on the (categorical) phenotype.
        trait: Label stored in the ``Trait`` column of ``fold_counts.tsv``;
            may be None.
        n_splits: Number of folds.  Must be at least 3 so that every group has
            a non-empty training set (one fold is test, one is tune).
        random_state: Seed for the fold shuffling.

    Raises:
        ValueError: If ``n_splits < 3`` or the phenotype file does not have
            exactly three columns.
    """
    if n_splits < 3:
        # Test and tune each consume one fold; with fewer than three folds the
        # training set would be empty.
        raise ValueError(f"n_splits must be at least 3, got {n_splits}")

    print(f"\n{'='*25}\nHandling Phenotype File: {os.path.basename(str(pheno_filepath))}\n{'='*25}")

    # Step 1: load the sample list (IDs kept as strings).
    master_samples = pd.read_csv(fam_filepath, sep=r'\s+', header=None, usecols=[0, 1], names=['FID', 'IID'], dtype=str)
    # Skip a header row if the sample file has one.
    if not master_samples.empty and master_samples.iloc[0, 0] == 'FID':
        master_samples = master_samples.iloc[1:].reset_index(drop=True)
    master_samples['FID'] = master_samples['FID'].astype(str)
    master_samples['IID'] = master_samples['IID'].astype(str)

    # Load the phenotype table.  The header is read first so the two ID columns
    # can be forced to str by name: letting pandas infer them as numbers would
    # turn e.g. "007" into "7" and break the join with the string .fam IDs.
    header_cols = pd.read_csv(pheno_filepath, sep=r'\s+', header=0, nrows=0).columns
    id_dtypes = {col: str for col in header_cols[:2]}
    pheno_data = pd.read_csv(pheno_filepath, sep=r'\s+', header=0, dtype=id_dtypes)
    pheno_data.columns = ['FID', 'IID', 'Pheno']  # raises ValueError unless exactly 3 columns
    pheno_data['FID'] = pheno_data['FID'].astype(str)
    pheno_data['IID'] = pheno_data['IID'].astype(str)

    # Step 2: keep every listed sample, attaching its phenotype where known.
    merged_data = pd.merge(master_samples, pheno_data, on=['FID', 'IID'], how='left')

    # Step 3: remove samples with missing phenotypes.
    original_count = len(merged_data)
    merged_data.dropna(subset=['Pheno'], inplace=True)
    final_count = len(merged_data)

    if original_count != final_count:
        print(f"Removed {original_count - final_count} samples with missing phenotypes.")

    print(f"Final sample count after removing missing phenotypes: {final_count}")

    # Step 4: assign samples to folds.  The splitters only need the sample
    # count from X, so a one-column dummy matrix is enough.
    y_pheno = merged_data['Pheno'].values
    sample_ids_for_split = merged_data[['FID', 'IID']]
    dummy_x = np.zeros((len(y_pheno), 1))
    if is_continuous:
        print("Using continuous phenotype for stratification.")
        splitter = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        fold_splits = splitter.split(dummy_x)
    else:
        print("Using categorical phenotype for stratification.")
        # Float-coded classes (e.g. 1.0 / 2.0) are cast to int labels.
        if pd.api.types.is_float_dtype(y_pheno):
            y_pheno = y_pheno.astype(int)
        splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        fold_splits = splitter.split(dummy_x, y_pheno)
    # Only the held-out index set of each split is needed: those sets are
    # disjoint and together cover every sample, i.e. they are the folds.
    all_fold_indices = [test_idx for _, test_idx in fold_splits]

    # Step 5: write one output group per fold.
    fold_rows = []
    for i in range(n_splits):
        test_fold_num = i
        tune_fold_num = (i + 1) % n_splits
        train_fold_nums = [j for j in range(n_splits) if j not in (test_fold_num, tune_fold_num)]

        split_indices = {
            'train': np.concatenate([all_fold_indices[j] for j in train_fold_nums]),
            'tune': all_fold_indices[tune_fold_num],
            'test': all_fold_indices[test_fold_num],
        }

        group_dir = f"{output_prefix}/group_{i+1}"
        os.makedirs(f"{group_dir}/ids", exist_ok=True)
        os.makedirs(f"{group_dir}/pheno", exist_ok=True)

        for split_name, indices in split_indices.items():
            sample_ids_for_split.iloc[indices].to_csv(f"{group_dir}/ids/{split_name}_ids.txt", sep='\t', index=False, header=False)
            merged_data.iloc[indices].to_csv(f"{group_dir}/pheno/{split_name}_pheno.txt", sep='\t', index=False, header=True)

        fold_rows.append({
            'Trait': trait,
            'Total_Samples': len(sample_ids_for_split),
            'Fold': i + 1,
            'Train_Count': len(split_indices['train']),
            'Tune_Count': len(split_indices['tune']),
            'Test_Count': len(split_indices['test']),
        })

    # Built once from collected rows rather than concatenated inside the loop.
    meta_data = pd.DataFrame(fold_rows)
    meta_data.to_csv(f"{output_prefix}/fold_counts.tsv", sep='\t', index=False, header=True)
    print(f"Fold counts saved to {output_prefix}/fold_counts.tsv")

    print(f'Successfully created all {n_splits} ID files for each fold in {os.path.basename(str(pheno_filepath))}.')

if __name__ == "__main__":

    # EAS (CAS) inputs: sample list plus a directory of phenotype files whose
    # stems are the trait short names.
    eas_fam_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/CAS/geno/CAS_final/CAS_merged_qc_final.fam"
    eas_trait_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/CAS/pheno/trait/data/"
    # EUR (UKB White British) inputs: the sample list comes from the trait table
    # below, and phenotype file stems are the trait codes.
    eur_fam_path = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/trait/trait_ukb_white_british.txt"
    eur_trait_dir = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/UKB/pheno/trait/White_British/"

    # Trait code -> short name.  The code is the EUR file stem; the short name
    # is the CAS file stem, the output directory name and the fold_counts label.
    trait_dict = {
        'p48': 'waist',
        'p50': 'height',
        'p102': 'pulse',
        'p4079': 'dbp',
        'p4080': 'sbp',
        'p20116': 'smoke',
        'p20117': 'drink',
        'p21001': 'bmi',
        'p21002': 'weight',
        'p30000': 'wbc',
        'p30010': 'rbc',
        'p30020': 'hb',
        'p30080': 'plt',
        'p30120': 'lymph',
        'p30130': 'mono',
        'p30140': 'neut',
        'p30150': 'eos',
        'p30620': 'alt',
        'p30650': 'ast',
        'p30670': 'bun',
        'p30690': 'cholesterol',
        'p30700': 'creatinine',
        'p30730': 'ggt',
        'p30740': 'glucose',
        'p30760': 'hdl',
        'p30780': 'ldl',
        'p30870': 'triglycerides',
        'p30880': 'ua',
    }

    # Codes split as categorical phenotypes (smoke, drink); all others are continuous.
    categorical_codes = {'p20116', 'p20117'}

    cas_output_root = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/CAS/"
    ukb_output_root = "/data1/jiapl_group/lishuhua/project/PRS_benchmark/real_data/Cross_Validation/UKB_EUR/"

    # CAS: categorical traits are read from the "_raw" file, continuous ones from "_int".
    for code, name in trait_dict.items():
        categorical = code in categorical_codes
        if categorical:
            cas_pheno_file = f"{eas_trait_dir}{name}_raw.txt"
        else:
            cas_pheno_file = f"{eas_trait_dir}{name}_int.txt"
        create_plink_split_files(
            fam_filepath=eas_fam_path,
            pheno_filepath=cas_pheno_file,
            output_prefix=f"{cas_output_root}{name}",
            is_continuous=not categorical,
            n_splits=10,
            trait=name,
            random_state=42,
        )

    # UKB EUR: every trait, categorical or not, is read from "{code}_int.txt".
    for code, name in trait_dict.items():
        create_plink_split_files(
            fam_filepath=eur_fam_path,
            pheno_filepath=f"{eur_trait_dir}{code}_int.txt",
            output_prefix=f"{ukb_output_root}{name}",
            is_continuous=code not in categorical_codes,
            n_splits=10,
            trait=name,
            random_state=42,
        )