# Load all libraries:

In [1]:
import csv
import torch
import random
import warnings
import subprocess
import scipy as sp
import numpy as np
import pandas as pd
import concurrent.futures
from functools import partial
from datetime import datetime
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
from statsmodels.formula.api import ols
from pandas_plink import read_plink1_bin, read_plink

# Helper Functions:

In [2]:
def doMPP(X_tar, Y_train_true, tar_snps, aux_betas, L1_penalty,
          tar_a0=None, tar_a1=None, mu=0.1, max_fun=10, seed=None):
    """Fit per-SNP effect sizes that are pulled towards auxiliary betas.

    Minimises, with L-BFGS-B,

        ||Y_train_true - X_tar @ beta||^2
            + L1_penalty * sum(mu * log(cosh((beta - aux_betas) / mu)))

    where ``mu * log(cosh(t / mu))`` is a smooth stand-in for ``|t|``.

    Parameters
    ----------
    X_tar : array-like, shape (n_samples, n_snps)
        Genotype matrix; column ``j`` must correspond to ``tar_snps[j]``.
    Y_train_true : array-like, shape (n_samples,)
        Phenotype vector.
    tar_snps : sequence
        SNP identifiers, one per column of ``X_tar``.
    aux_betas : array-like, shape (n_snps,) or scalar
        Auxiliary effect sizes the solution is shrunk towards.
    L1_penalty : float
        Weight of the smoothed absolute-deviation penalty.
    tar_a0, tar_a1 : sequence, optional
        Allele labels written to the ``A0`` / ``A1`` output columns. When
        omitted, the module-level ``tar_a0`` / ``tar_a1`` variables are used,
        which is how this function historically obtained them.
    mu : float, default 0.1
        Smoothing width of the penalty.
    max_fun : int, default 10
        ``maxfun`` option of L-BFGS-B. The default caps the optimiser at
        roughly ten objective evaluations, so the returned betas are usually
        NOT a converged optimum; raise it for a full fit.
    seed : int, optional
        Seed for the random starting point. ``None`` draws from NumPy's
        global random state, as before.

    Returns
    -------
    pandas.DataFrame
        Columns ``SNP``, ``A0``, ``A1``, ``BETA`` with one row per SNP.

    Raises
    ------
    ValueError
        If ``X_tar`` is not 2-D with one column per entry of ``tar_snps``, or
        ``aux_betas`` cannot be broadcast to ``n_snps``.
    NameError
        If allele labels are neither passed nor defined at module level.
    """
    # Imported here so the function does not depend on some other import
    # having loaded the scipy.optimize submodule as a side effect.
    from scipy.optimize import minimize

    # Legacy behaviour: fall back to the notebook-level allele lists.
    try:
        if tar_a0 is None:
            tar_a0 = globals()['tar_a0']
        if tar_a1 is None:
            tar_a1 = globals()['tar_a1']
    except KeyError as missing:
        raise NameError("name " + str(missing) + " is not defined; pass "
                        "tar_a0/tar_a1 to doMPP explicitly") from None

    X = np.asarray(X_tar, dtype=float)
    y = np.asarray(Y_train_true, dtype=float)
    n_snps = len(tar_snps)
    if X.ndim != 2 or X.shape[1] != n_snps:
        raise ValueError("X_tar must have one column per SNP: got shape "
                         + str(X.shape) + " for " + str(n_snps) + " SNPs")
    # Raises ValueError on a length mismatch; still accepts a scalar.
    aux = np.broadcast_to(np.asarray(aux_betas, dtype=float), (n_snps,))

    # Near-zero random start (same scale as before).
    init_scale = 0.00000000001
    if seed is None:
        beta_init = np.random.normal(0, init_scale, n_snps)
    else:
        beta_init = np.random.default_rng(seed).normal(0, init_scale, n_snps)

    # Precomputed once; the gradient only needs these two products.
    XtX = X.T @ X
    Xty = X.T @ y
    log_two = np.log(2.0)

    def objective(betas):
        residual = y - X @ betas
        scaled = (betas - aux) / mu
        # log(0.5*e^-s + 0.5*e^s) written with logaddexp so that large |s|
        # cannot overflow to inf.
        smooth_abs = mu * (np.logaddexp(scaled, -scaled) - log_two)
        return residual @ residual + L1_penalty * smooth_abs.sum()

    def gradient(betas):
        # (e^s - e^-s) / (e^s + e^-s) is tanh(s); np.tanh never overflows.
        return 2.0 * (XtX @ betas - Xty) + L1_penalty * np.tanh((betas - aux) / mu)

    fit = minimize(objective, x0=beta_init, jac=gradient,
                   method='L-BFGS-B', options={'maxfun': max_fun})

    return pd.DataFrame({'SNP': tar_snps, 'A0': tar_a0,
                         'A1': tar_a1, 'BETA': fit.x})

# Change here to reproduce results on other simulation configurations:

In [3]:
# Root directory of one simulation analysis. The loop below reads from its
# 'Additional/Results_SOTAmethods/', 'Genotypes/' and 'Phenotypes/' subfolders.
inPATH = './Simul_Analysis_1/'

# Which auxiliary populations contribute to the prior betas. Must be one of the
# keys handled by the weighting step in the loop below (e.g. 'EurOnly',
# 'EurEas', 'EurEasAmr', 'EurEasAmrAfr').
aux_pops = 'EurEasAmrAfr'

# Experiment identifiers; each one is interpolated into the input file names.
exp_list = ['exp1.4.1','exp1.4.2','exp1.4.3']

In [4]:
# Column layout of every lassosum beta file read below (the files have no header row).
SUMSTATS_COLUMNS = ['rsid', 'a1', 'a0', 'weight']

# Auxiliary populations, in the order the weight tuples below refer to them.
AUX_POP_CODES = ('eur', 'eas', 'amr', 'afr')

# Mixing weight of each auxiliary population's betas, keyed by `aux_pops`.
# NOTE(review): the three-population mixes use 0.33 each (sum 0.99, not 1);
# kept exactly as in the original analysis -- confirm whether 1/3 was intended.
AUX_POP_WEIGHTS = {
    'EurOnly':      (1.0, 0.0, 0.0, 0.0),
    'EasOnly':      (0.0, 1.0, 0.0, 0.0),
    'AmrOnly':      (0.0, 0.0, 1.0, 0.0),
    'AfrOnly':      (0.0, 0.0, 0.0, 1.0),
    'EurAmr':       (0.5, 0.0, 0.5, 0.0),
    'EurEas':       (0.5, 0.5, 0.0, 0.0),
    'EurAfr':       (0.5, 0.0, 0.0, 0.5),
    'EasAmr':       (0.0, 0.5, 0.5, 0.0),
    'EasAfr':       (0.0, 0.5, 0.0, 0.5),
    'AmrAfr':       (0.0, 0.0, 0.5, 0.5),
    'EurEasAmr':    (0.33, 0.33, 0.33, 0.0),
    'EurEasAfr':    (0.33, 0.33, 0.0, 0.33),
    'EurAmrAfr':    (0.33, 0.0, 0.33, 0.33),
    'EasAmrAfr':    (0.0, 0.33, 0.33, 0.33),
    'EurEasAmrAfr': (0.25, 0.25, 0.25, 0.25),
}

# Penalty weight passed to doMPP for every experiment.
L1_PENALTY = 7.5


def _read_lassosum_betas(pop_code, exp_num):
    """Read the lassosum beta table of one population for one experiment."""
    return pd.read_csv(inPATH+'Additional/Results_SOTAmethods/lassosumExtLD_pred_betas_'
                       +pop_code+'_'+str(exp_num)+'.txt',
                       delimiter='\t', header=None, names=SUMSTATS_COLUMNS)


def _align_to_genotype(ss_df, ordered_snps):
    """Return the rows of `ss_df` for `ordered_snps`, in exactly that order.

    Raises ValueError if one of the kept rsids occurs more than once, because
    the row that belongs to a genotype column would then be ambiguous.
    """
    kept = ss_df[ss_df['rsid'].isin(set(ordered_snps))]
    return kept.set_index('rsid', verify_integrity=True).loc[ordered_snps].reset_index()


# Fail before any file is read instead of printing and then crashing on (or
# silently reusing) an undefined/stale mixture further down.
if aux_pops not in AUX_POP_WEIGHTS:
    raise ValueError("Invalid Aux Pops! Got "+repr(aux_pops)
                     +"; expected one of: "+", ".join(AUX_POP_WEIGHTS))

# Fitted betas of every experiment, keyed by experiment id.
results_by_exp = {}

for exp_num in exp_list:
    #AUX and TAR Summary Statistics
    ss_df_aux_by_pop = {pop: _read_lassosum_betas(pop, exp_num) for pop in AUX_POP_CODES}
    ss_df_tar = _read_lassosum_betas('sas', exp_num)

    #True X
    geno_prefix = inPATH+'Genotypes/sas_train_'+str(exp_num)+'_CHR22'
    G_sas = read_plink1_bin(geno_prefix+'.bed', geno_prefix+'.bim', geno_prefix+'.fam',
                            verbose = False)
    X_sas = G_sas.values
    # Swap the 0 and 2 codes; any other value (including NaN) is left as is.
    X_sas = np.where(X_sas == 2, 0, np.where(X_sas == 0, 2, X_sas))
    X_sas = np.array(X_sas, dtype=float)
    all_snps_sas_ld_full = G_sas.snp.values.tolist()

    # SNPs present in the genotype file and in every summary-statistics table.
    common_snp_set = set(all_snps_sas_ld_full) & set(ss_df_tar['rsid'])
    for ss_df in ss_df_aux_by_pop.values():
        common_snp_set &= set(ss_df['rsid'])
    common_snps = list(common_snp_set)

    #Preprocess Target Ref. file
    # Genotype column order is the single reference order for everything below.
    temp_indices_tar = [i for i, snp in enumerate(all_snps_sas_ld_full) if snp in common_snp_set]
    all_snps_sas_ld_full = [all_snps_sas_ld_full[i] for i in temp_indices_tar]
    X_tar = X_sas[:, temp_indices_tar]

    #Preprocess Target and Auxiliary Summ. Stats. File
    # Re-index every table to the genotype column order so that X_tar columns,
    # SNP labels, target betas and auxiliary betas all refer to the same SNP
    # at the same position, whatever the row order of the individual files.
    # NOTE(review): a1/a0 are not compared across populations here; this
    # assumes all files report weights for the same effect allele -- confirm.
    ss_df_tar = _align_to_genotype(ss_df_tar, all_snps_sas_ld_full)
    ss_df_aux1, ss_df_aux2, ss_df_aux3, ss_df_aux4 = (
        _align_to_genotype(ss_df_aux_by_pop[pop], all_snps_sas_ld_full)
        for pop in AUX_POP_CODES
    )

    # tar_a0 / tar_a1 must stay module-level names: doMPP reads them as
    # globals when they are not passed explicitly.
    tar_snps = list(ss_df_tar['rsid'])
    tar_a1 = list(ss_df_tar['a1'])
    tar_a0 = list(ss_df_tar['a0'])
    tar_betas = np.asarray(list(ss_df_tar['weight']), dtype='float')

    aux1_betas, aux2_betas, aux3_betas, aux4_betas = (
        np.asarray(list(ss_df['weight']), dtype='float')
        for ss_df in (ss_df_aux1, ss_df_aux2, ss_df_aux3, ss_df_aux4)
    )

    #True Y
    Y_train_true = pd.read_csv(inPATH+'Phenotypes/pheno_sas_train_'+str(exp_num)+'.truepheno',
                               delimiter=' ', header=None, names=['Pheno'])
    Y_train_true = np.array(Y_train_true['Pheno'], dtype=float)

    if X_tar.shape[0] != len(Y_train_true):
        raise ValueError(str(exp_num)+": genotype matrix has "+str(X_tar.shape[0])
                         +" samples but phenotype file has "+str(len(Y_train_true)))

    #Assign Weightage to each auxiliary population:
    pops_tuple = aux_pops
    per_pop_weight = AUX_POP_WEIGHTS[pops_tuple]
    aux_betas = (aux1_betas*per_pop_weight[0]) + (aux2_betas*per_pop_weight[1]) + \
                (aux3_betas*per_pop_weight[2]) + (aux4_betas*per_pop_weight[3])

    print('Start time: '+str(datetime.now()))
    result_df = doMPP(X_tar, Y_train_true, tar_snps, aux_betas, L1_PENALTY)
    # Keep every experiment's fit; result_df alone only holds the last one.
    results_by_exp[exp_num] = result_df
    print('End time: '+str(datetime.now()))
    print(str(exp_num)+' ... Done!')

Start time: 2025-08-09 11:22:52.031523
End time: 2025-08-09 11:22:52.237106
exp1.4.1 ... Done!
Start time: 2025-08-09 11:22:52.577076
End time: 2025-08-09 11:22:52.648502
exp1.4.2 ... Done!
Start time: 2025-08-09 11:22:52.856572
End time: 2025-08-09 11:22:52.873382
exp1.4.3 ... Done!
