In [None]:
import os
# Limit MKL, numexpr and OpenMP to one thread each. This is done in the cell
# before numpy/scipy are imported so their native thread pools see the setting.
os.environ.update(
    MKL_NUM_THREADS='1',
    NUMEXPR_NUM_THREADS='1',
    OMP_NUM_THREADS='1',
)

In [None]:
import scipy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import roc_auc_score

In [None]:
# PRS-CSx function
def prs_csx(src='PRScsx/PRScsx.py',
            ref_dir=None,
            bim_prefix=None,
            sst_file=None,
            n_gwas=None, # example: '79550,257730'
            pop='EAS,EUR',
            phi=1e-2,
            chrom=None,
            meta=True,
            seed=68,
            out_dir=None,
            out_name=None,
            python_exe=None):
    """Run the PRS-CSx command-line script ``src`` with the given options.

    Each keyword argument from ``ref_dir`` to ``out_name`` is forwarded as
    ``--<name>=<value>``. Options whose value is ``None`` are omitted, so
    the script applies its own default. Without this the script would
    receive the literal string ``'None'``, as it did for the default
    ``chrom=None``.

    Parameters
    ----------
    src : str
        Path to ``PRScsx.py``.
    ref_dir, bim_prefix, sst_file, n_gwas, pop, phi, chrom, meta, seed,
    out_dir, out_name
        Values for the PRS-CSx options of the same name. Each is formatted
        with ``str``; ``None`` means "do not pass this option".
    python_exe : str, optional
        Interpreter used to run ``src``. Defaults to the interpreter running
        this notebook (``sys.executable``). A bare ``python`` on PATH may be
        a different environment or may not exist at all.

    Returns
    -------
    int
        Exit status of the PRS-CSx process (0 on success).
    """
    import subprocess
    import sys

    options = {
        'ref_dir': ref_dir,
        'bim_prefix': bim_prefix,
        'sst_file': sst_file,
        'n_gwas': n_gwas,
        'pop': pop,
        'phi': phi,
        'chrom': chrom,
        'meta': meta,
        'seed': seed,
        'out_dir': out_dir,
        'out_name': out_name,
    }
    cmd = [python_exe or sys.executable, src]
    cmd += [f'--{name}={value}' for name, value in options.items() if value is not None]
    # Passing an argument list without a shell means paths containing spaces
    # or shell metacharacters reach the script verbatim.
    return subprocess.run(cmd).returncode

def concat_prscsx_output(out_dir=None,
                        out_name=None,
                        pop='META'):
    """Combine the per-file PRS-CSx posterior-effect tables for one population.

    Reads every file in ``out_dir`` whose name starts with
    ``'{out_name}_{pop}_pst_eff'``. Each file is read as a headerless table
    with columns ``CHROM, SNP, POS, A1, A2, pst_eff``. The tables are
    stacked and sorted by chromosome, then position.

    Matching on the exact prefix ensures that files from another run
    (e.g. ``run2_…`` when ``out_name='run'``) are not mixed in. It also
    excludes other populations whose code happens to appear in ``out_name``.

    Parameters
    ----------
    out_dir : str
        Directory holding the PRS-CSx output files.
    out_name : str
        The ``--out_name`` the PRS-CSx run was launched with.
    pop : str
        Population label in the file names, e.g. ``'META'`` or ``'EUR'``.

    Returns
    -------
    pandas.DataFrame
        The combined table with a fresh 0..n-1 index. If no file matches, an
        empty frame with the same columns is returned.
    """
    columns = ['CHROM', 'SNP', 'POS', 'A1', 'A2', 'pst_eff']
    prefix = f'{out_name}_{pop}_pst_eff'
    frames = [
        pd.read_table(os.path.join(out_dir, fname), header=None, names=columns)
        for fname in sorted(os.listdir(out_dir))
        if fname.startswith(prefix)
    ]
    if not frames:
        # pd.concat raises on an empty list; keep the original empty-table result.
        return pd.DataFrame([], columns=columns)
    # One concat call instead of growing the frame inside the loop, which is quadratic.
    res = pd.concat(frames, ignore_index=True)
    return res.sort_values(by=['CHROM', 'POS'], ignore_index=True)

def prscsx_score(target_prefix=None,
                out_dir=None,
                out_name=None,
                pop='META',
                plink='plink'):
    """Score a target PLINK fileset with PRS-CSx posterior effects and report AUC.

    Steps:

    1. Combine the PRS-CSx effect files for ``pop`` with
       ``concat_prscsx_output`` and write them, tab-separated and without a
       header, to ``{out_dir}/{out_name}_{pop}.concat.txt``.
    2. Run ``plink --score`` on that file, using column 2 (SNP), column 4
       (A1) and column 6 (pst_eff). Output goes to the
       ``{out_dir}/{out_name}_{pop}`` prefix.
    3. Read the resulting ``.profile`` and standardize ``SCORE`` to mean 0
       and SD 1 (population SD, ``np.std`` default).
    4. Print the ROC AUC of ``SCORE`` against ``PHENO``, then show a stacked
       histogram of the score by phenotype.

    Parameters
    ----------
    target_prefix : str
        ``--bfile`` prefix of the target genotypes.
    out_dir, out_name, pop
        Locate the PRS-CSx outputs (see ``concat_prscsx_output``).
    plink : str
        The plink executable to invoke.

    Returns
    -------
    pandas.DataFrame
        The ``.profile`` table with the standardized ``SCORE`` column.

    Raises
    ------
    subprocess.CalledProcessError
        If plink exits with a non-zero status.
    """
    import subprocess

    prscsx_output = concat_prscsx_output(out_dir=out_dir, out_name=out_name, pop=pop)
    score_prefix = os.path.join(out_dir, f'{out_name}_{pop}')
    file_path = score_prefix + '.concat.txt'
    prscsx_output.to_csv(file_path, sep='\t', header=False, index=False)
    # check=True: if plink fails, we must not go on to read a stale .profile
    # left behind by an earlier run. The argument list (no shell) keeps paths
    # with spaces intact.
    subprocess.run([plink, '--bfile', target_prefix,
                    '--allow-no-sex',
                    '--score', file_path, '2', '4', '6',
                    '--out', score_prefix],
                   check=True)
    res = pd.read_table(f'{score_prefix}.profile', sep=r'\s+')
    res['SCORE'] = (res['SCORE'] - np.mean(res['SCORE'])) / np.std(res['SCORE'])
    # plink's default missing-phenotype code is -9. Leaving those samples in
    # would hand roc_auc_score a third class and make it raise.
    scored = res[res['PHENO'] != -9]
    print('AUC score =', roc_auc_score(scored['PHENO'], scored['SCORE']))
    sns.histplot(data=res, x='SCORE', hue='PHENO', multiple="stack", kde=True)
    plt.show()
    return res