# Load all libraries:

In [1]:
import csv
import torch
import random
import warnings
import subprocess
import scipy as sp
import numpy as np
import pandas as pd
import concurrent.futures
from functools import partial
from datetime import datetime
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
from statsmodels.formula.api import ols
from pandas_plink import read_plink1_bin, read_plink

# Helper Functions:

In [2]:
def _legacy_global(name):
    """Return the module-level variable ``name`` (legacy fallback used by doMPP).

    Raises:
        NameError: if no module-level variable called ``name`` exists.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError("name '" + name + "' is not defined; "
                        "pass it to doMPP explicitly") from None


def doMPP(X_tar, Y_train_true, tar_snps, L1_penalty,
          tar_chroms=None, tar_pos=None, tar_a1=None,
          mu=0.1, maxfun=10, seed=None):
    """Fit penalised least-squares SNP effects with L-BFGS-B.

    Minimises over ``b``::

        sum((Y_train_true - X_tar @ b)**2)
            + L1_penalty * sum(mu * log(cosh(b / mu)))

    where ``mu * log(cosh(b / mu))`` is a smooth stand-in for ``|b|``.

    Args:
        X_tar: 2-D numeric array; one column per entry of ``tar_snps``.
        Y_train_true: 1-D numeric array with one value per row of ``X_tar``.
        tar_snps: sequence of SNP identifiers, in ``X_tar`` column order.
        L1_penalty: weight of the smoothed absolute-value penalty.
        tar_chroms, tar_pos, tar_a1: per-SNP chromosome, position and A1
            allele, same order as ``tar_snps``. When left as ``None`` the
            module-level variable of the same name is used, which is how the
            original 4-argument call worked.
        mu: smoothing width of the penalty.
        maxfun: L-BFGS-B cap on the number of function evaluations.
            NOTE(review): 10 is very low, so the optimiser will normally stop
            before converging -- confirm this is intended.
        seed: optional seed for the random starting point. ``None`` keeps the
            original behaviour of drawing from NumPy's global random state.

    Returns:
        pandas.DataFrame with columns ``CHR``, ``SNP``, ``POS``, ``A1`` and
        ``BETA`` (one row per SNP).

    Raises:
        ValueError: if the shapes of ``X_tar``, ``Y_train_true`` and
            ``tar_snps`` are inconsistent.
        NameError: if SNP metadata is neither passed nor defined globally.
    """
    # Imported here so the function does not rely on ``sp.optimize`` having
    # been loaded as a side effect of ``import scipy as sp``.
    from scipy.optimize import minimize

    # Prefer explicit metadata; fall back to the notebook globals otherwise.
    if tar_chroms is None:
        tar_chroms = _legacy_global('tar_chroms')
    if tar_pos is None:
        tar_pos = _legacy_global('tar_pos')
    if tar_a1 is None:
        tar_a1 = _legacy_global('tar_a1')

    # float64 throughout: SciPy hands float64 iterates to func/jacfunc and
    # torch.matmul requires matching dtypes.
    X = np.asarray(X_tar, dtype=np.float64)
    y = np.asarray(Y_train_true, dtype=np.float64).ravel()
    n_snps = len(tar_snps)
    if X.ndim != 2 or X.shape[1] != n_snps:
        raise ValueError("X_tar must be 2-D with one column per SNP: got shape "
                         + str(X.shape) + " for " + str(n_snps) + " SNPs")
    if y.shape[0] != X.shape[0]:
        raise ValueError("Y_train_true has " + str(y.shape[0])
                         + " values but X_tar has " + str(X.shape[0]) + " rows")

    # Initialize Beta_hats with a tiny random perturbation around zero.
    if seed is None:
        Beta_hat_init = np.random.normal(0, 0.00000000001, n_snps)
    else:
        Beta_hat_init = np.random.default_rng(seed).normal(0, 0.00000000001, n_snps)

    # Precompute X'X and X'y once; the gradient reuses them at every step.
    X_tensor = torch.from_numpy(X)
    XtX = torch.matmul(X_tensor.T, X_tensor)
    Xty_np = torch.matmul(X_tensor.T, torch.from_numpy(y)).numpy()

    def func(Bs, L1_penalty):
        """Objective: residual sum of squares plus the smoothed penalty."""
        resid = y - torch.matmul(X_tensor, torch.from_numpy(Bs)).numpy()
        z = Bs / mu
        # log(0.5*e^-z + 0.5*e^z) evaluated via logaddexp so that large |z|
        # cannot overflow exp() and turn the objective into inf.
        smooth_abs = mu * (np.logaddexp(-z, z) - np.log(2.0))
        return np.dot(resid, resid) + L1_penalty * np.sum(smooth_abs)

    def jacfunc(Bs, L1_penalty):
        """Gradient of ``func``: 2*X'X b - 2*X'y + L1_penalty * tanh(b/mu)."""
        grad_ls = 2 * (torch.matmul(XtX, torch.from_numpy(Bs)).numpy() - Xty_np)
        # (e^z - e^-z) / (e^z + e^-z) is tanh(z); np.tanh never yields inf/inf.
        return grad_ls + L1_penalty * np.tanh(Bs / mu)

    ans = minimize(func, jac=jacfunc, x0=Beta_hat_init, args=(L1_penalty,),
                   method='L-BFGS-B', options={'maxfun': maxfun})

    final_results_df = pd.DataFrame({'CHR': tar_chroms, 'SNP': tar_snps, 'POS': tar_pos,
                                     'A1': tar_a1, 'BETA': ans.x})

    return final_results_df

# Change here to reproduce results on other simulation configurations:

In [3]:
# Root folder of the simulation; the analysis loop reads from its
# 'GWAS_SummaryStats/', 'Genotypes/' and 'Phenotypes/' sub-folders.
inPATH = './Simul_Analysis_1/'

# Population tag used in every input file name. For 'sas' the genotype and
# phenotype file names carry an extra '_train' part.
pop = 'sas'

# Experiment identifiers; the analysis loop runs once per entry.
exp_list = ['exp1.4.1','exp1.4.2','exp1.4.3']

In [4]:
# L1 penalty weight handed to doMPP for every experiment.
L1_PENALTY = 7.5

# Fitted coefficient table of every experiment, keyed by experiment id, so
# that earlier results are not lost when `result_df` is overwritten.
results_by_exp = {}

for exp_num in exp_list:

    # Genotype / phenotype files of the 'sas' population carry an extra
    # '_train' part in their names; all other populations do not.
    if pop == 'sas':
        file_tag = str(pop)+'_train_'+str(exp_num)
    else:
        file_tag = str(pop)+'_'+str(exp_num)
    geno_prefix = inPATH+'Genotypes/'+file_tag+'_CHR22'
    pheno_path = inPATH+'Phenotypes/pheno_'+file_tag+'.truepheno'

    #TAR Summary Statistics
    ss_df_tar = pd.read_csv(inPATH+'GWAS_SummaryStats/'+str(pop)+'_'+str(exp_num)+'.sumstat',
                            delimiter=' ', header=None,
                            names=['CHR','SNP','GENETIC.DIST','BP','A1','A2','BETA','SE','T','P','N'])

    #True X
    G_sas = read_plink1_bin(geno_prefix+'.bed', geno_prefix+'.bim', geno_prefix+'.fam',
                            verbose=False)

    X_sas = G_sas.values
    # Swap the 0 and 2 codes, leaving every other value as it is.
    # NOTE(review): presumably this makes the dosage count the A1 allele of
    # the summary statistics -- confirm against the reader's allele coding.
    # Missing genotypes (NaN) are passed through unchanged.
    X_sas = np.where(X_sas == 2, 0, np.where(X_sas == 0, 2, X_sas))
    X_sas = np.array(X_sas, dtype=float)
    all_snps_sas_ld_full = G_sas.snp.values

    # Column of each SNP id in the genotype matrix (first occurrence wins).
    geno_col_of_snp = {}
    for col, snp in enumerate(all_snps_sas_ld_full):
        geno_col_of_snp.setdefault(snp, col)

    common_snps = list(set(ss_df_tar['SNP']) & set(geno_col_of_snp))
    common_snp_set = set(common_snps)

    #Preprocess Target and Auxiliary Summ. Stats. File
    # Duplicate SNP ids are dropped so that labels and genotype columns stay
    # one-to-one.
    ss_df_tar = ss_df_tar[ss_df_tar['SNP'].isin(common_snp_set)]
    ss_df_tar = ss_df_tar.drop_duplicates(subset='SNP')
    ss_df_tar = ss_df_tar.reset_index(drop=True)

    if ss_df_tar.empty:
        raise ValueError('No SNPs shared between summary statistics and genotypes for '
                         + str(exp_num))

    N_tar = int(ss_df_tar['N'].iloc[0])

    # doMPP reads tar_chroms, tar_pos and tar_a1 from module scope when they
    # are not passed in, so these names must stay defined here.
    tar_snps = list(ss_df_tar['SNP'])
    tar_a1 = list(ss_df_tar['A1'])
    tar_pvals = np.asarray(list(ss_df_tar['P']), dtype='float')
    tar_betas = np.asarray(list(ss_df_tar['BETA']), dtype='float')
    tar_chroms = list(ss_df_tar['CHR'])
    tar_pos = np.asarray(list(ss_df_tar['BP']), dtype='float')
    tar_corr = np.asarray(list(ss_df_tar['BETA']), dtype='float')*N_tar

    #Preprocess Target Ref. file
    # Pick genotype columns in the order of tar_snps: doMPP labels the i-th
    # fitted coefficient with tar_snps[i], so the two orders must agree even
    # when the summary statistics and the .bim file are sorted differently.
    temp_indices_tar = [geno_col_of_snp[snp] for snp in tar_snps]
    all_snps_sas_ld_full = [i for i in all_snps_sas_ld_full if i in common_snp_set]
    X_tar = X_sas[:, temp_indices_tar]

    #True Y
    Y_train_true = pd.read_csv(pheno_path, delimiter=' ', header=None, names=['Pheno'])
    Y_train_true = np.array(Y_train_true['Pheno'], dtype=float)

    # NOTE(review): phenotype rows are assumed to be in the same sample order
    # as the genotype rows; only the counts can be checked here.
    if len(Y_train_true) != X_tar.shape[0]:
        raise ValueError(str(exp_num)+': '+str(len(Y_train_true))+' phenotypes but '
                         + str(X_tar.shape[0])+' genotyped samples')

    print('Start time: '+str(datetime.now()))
    result_df = doMPP(X_tar, Y_train_true, tar_snps, L1_PENALTY)
    results_by_exp[exp_num] = result_df
    print('End time: '+str(datetime.now()))
    print(str(exp_num)+' ... Done!')

Start time: 2025-08-09 09:51:49.545563
End time: 2025-08-09 09:51:50.200758
exp1.4.1 ... Done!
Start time: 2025-08-09 09:51:50.774235
End time: 2025-08-09 09:51:51.456082
exp1.4.2 ... Done!
Start time: 2025-08-09 09:51:52.022833
End time: 2025-08-09 09:51:52.789961
exp1.4.3 ... Done!
