# Load all libraries:

In [1]:
import csv
import torch
import random
import warnings
import subprocess
import scipy as sp
import numpy as np
import pandas as pd
import concurrent.futures
from functools import partial
from datetime import datetime
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
from statsmodels.formula.api import ols
from pandas_plink import read_plink1_bin, read_plink

# Helper Functions:

In [2]:
def _from_globals(value, name):
    """Return ``value``, or the module-level global called ``name`` when ``value`` is None.

    Keeps doMPP backward compatible: earlier versions silently read ``tar_chroms``,
    ``tar_pos`` and ``tar_a1`` from the notebook's global namespace.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"{name!r} was not passed to doMPP and no global of that name is defined"
        ) from None


def doMPP(X_tar, tar_snps, tar_corr, aux_betas, L1_penalty, L2_penalty, *,
          tar_chroms=None, tar_pos=None, tar_a1=None, mu=0.1, maxfun=10):
    """Estimate target-population SNP effects pulled toward auxiliary effects.

    Minimises, with L-BFGS-B limited to ``maxfun`` function evaluations,

        f(B) = B' X'X B - 2 B' r + L2_penalty * B'B
               + L1_penalty * sum_j mu * log(cosh((B_j - aux_j) / mu))

    where X = X_tar and r = tar_corr.  The last term is the smoothed absolute
    deviation |B_j - aux_j| used by the original code
    (log(0.5*exp(-t/mu) + 0.5*exp(t/mu)) == log(cosh(t/mu))).

    Parameters
    ----------
    X_tar : 2-D float array whose columns correspond, in order, to ``tar_snps``.
    tar_snps : sequence of SNP IDs; its length is the number of coefficients.
    tar_corr : 1-D array with one entry per SNP (the caller passes BETA * N,
        presumably an approximation of X'y -- confirm).
    aux_betas : 1-D array (or scalar) of auxiliary effects B is shrunk toward.
    L1_penalty : weight of the smoothed-L1 term.
    L2_penalty : weight of the ridge term B'B.
    tar_chroms, tar_pos, tar_a1 : optional per-SNP CHR / POS / A1 columns for the
        output.  When omitted, the module-level globals of the same name are
        used, which is what earlier versions of this function always did.
    mu : smoothing parameter of the L1 approximation (default 0.1, as before).
    maxfun : L-BFGS-B function-evaluation budget (default 10, as before).

    Returns
    -------
    pandas.DataFrame with columns CHR, SNP, POS, A1, BETA (one row per SNP).

    Raises
    ------
    NameError
        If CHR/POS/A1 are neither passed nor available as globals.
    """
    tar_chroms = _from_globals(tar_chroms, 'tar_chroms')
    tar_pos = _from_globals(tar_pos, 'tar_pos')
    tar_a1 = _from_globals(tar_a1, 'tar_a1')

    # Tiny random start (same distribution as before) so B is not exactly zero.
    Beta_hat_init = np.random.normal(0, 0.00000000001, len(tar_snps))

    X = np.asarray(X_tar, dtype=float)
    XtX = X.T @ X
    r = np.asarray(tar_corr, dtype=float)
    aux = np.asarray(aux_betas, dtype=float)

    # NOTE(review): an earlier version computed abs(1 - L2_penalty) * X'X into a
    # copy that was never used, so the objective uses the *unscaled* X'X.  That
    # behaviour is preserved here to reproduce existing results; confirm
    # whether the (1 - L2_penalty) shrinkage of X'X was intended.

    def func(Bs, L1_penalty, L2_penalty):
        # mu*log(cosh(t/mu)) rewritten as mu*(|t|/mu + log1p(exp(-2|t|/mu)) - log 2):
        # identical value, but it cannot overflow for large |t| / mu.
        a = np.abs(Bs - aux) / mu
        smooth_l1 = mu * (a + np.log1p(np.exp(-2.0 * a)) - np.log(2.0))
        return (Bs @ XtX @ Bs
                - 2.0 * (Bs @ r)
                + L2_penalty * (Bs @ Bs)
                + L1_penalty * np.sum(smooth_l1))

    def jacfunc(Bs, L1_penalty, L2_penalty):
        # d/dB of mu*log(cosh((B - aux)/mu)) is tanh((B - aux)/mu); np.tanh is
        # overflow-free, unlike the explicit exp ratio it replaces.
        return (2.0 * (XtX @ Bs)
                - 2.0 * r
                + 2.0 * L2_penalty * Bs
                + L1_penalty * np.tanh((Bs - aux) / mu))

    ans = sp.optimize.minimize(func, jac=jacfunc, x0=Beta_hat_init,
                               args=(L1_penalty, L2_penalty),
                               method='L-BFGS-B', options={'maxfun': maxfun})

    return pd.DataFrame({'CHR': tar_chroms, 'SNP': tar_snps, 'POS': tar_pos,
                         'A1': tar_a1, 'BETA': ans.x})

# Edit the settings below to reproduce results for other simulation configurations:

In [3]:
# Root folder containing GWAS_SummaryStats/ and the SAS LD reference files.
inPATH = './Simul_Analysis_1/'

# Which auxiliary populations to combine (looked up by the analysis cell).
aux_pops = 'EurEasAmrAfr'

# Simulation replicates to process, in order.
exp_list = [
    'exp1.4.1',
    'exp1.4.2',
    'exp1.4.3',
]

In [4]:
# Column layout of the space-delimited, header-less .sumstat files.
SUMSTAT_COLUMNS = ['CHR', 'SNP', 'GENETIC.DIST', 'BP', 'A1', 'A2',
                   'BETA', 'SE', 'T', 'P', 'N']

# Weight applied to each auxiliary population's BETA, in (EUR, EAS, AMR, AFR) order.
# NOTE(review): the three-population weights are 0.33 each (summing to 0.99); they
# are kept verbatim so existing results reproduce exactly -- 1/3 may have been intended.
AUX_POP_WEIGHTS = {
    'EurOnly': (1.0, 0.0, 0.0, 0.0),
    'EasOnly': (0.0, 1.0, 0.0, 0.0),
    'AmrOnly': (0.0, 0.0, 1.0, 0.0),
    'AfrOnly': (0.0, 0.0, 0.0, 1.0),
    'EurAmr': (0.5, 0.0, 0.5, 0.0),
    'EurEas': (0.5, 0.5, 0.0, 0.0),
    'EurAfr': (0.5, 0.0, 0.0, 0.5),
    'EasAmr': (0.0, 0.5, 0.5, 0.0),
    'EasAfr': (0.0, 0.5, 0.0, 0.5),
    'AmrAfr': (0.0, 0.0, 0.5, 0.5),
    'EurEasAmr': (0.33, 0.33, 0.33, 0.0),
    'EurEasAfr': (0.33, 0.33, 0.0, 0.33),
    'EurAmrAfr': (0.33, 0.0, 0.33, 0.33),
    'EasAmrAfr': (0.0, 0.33, 0.33, 0.33),
    'EurEasAmrAfr': (0.25, 0.25, 0.25, 0.25),
}


def _read_sumstats(root, pop, exp_num):
    """Read <root>GWAS_SummaryStats/<pop>_<exp_num>.sumstat into a DataFrame."""
    return pd.read_csv(root + 'GWAS_SummaryStats/' + pop + '_' + str(exp_num) + '.sumstat',
                       delimiter=' ', header=None, names=SUMSTAT_COLUMNS)


def _restrict_to(ss_df, snp_set):
    """Keep rows whose SNP is in snp_set, drop repeated SNP IDs, renumber rows from 0."""
    kept = ss_df[ss_df['SNP'].isin(snp_set)].drop_duplicates(subset='SNP', keep='first')
    return kept.reset_index(drop=True)


def _betas_for(ss_df, snp_order):
    """Return ss_df's BETA values as floats, reordered to follow snp_order."""
    return ss_df.set_index('SNP')['BETA'].reindex(snp_order).to_numpy(dtype=float)


# Validate before any expensive I/O. Previously an unknown value printed a message
# and then crashed with a NameError on the never-assigned weighted sum.
if aux_pops not in AUX_POP_WEIGHTS:
    raise ValueError('Invalid Aux Pops! ' + repr(aux_pops) + ' is not one of: '
                     + ', '.join(AUX_POP_WEIGHTS))

# External LD: the SAS reference panel is the same for every experiment, so read it once.
G_sas = read_plink1_bin(inPATH + 'sas_chr22_geno_1KG.bed',
                        inPATH + 'sas_chr22_geno_1KG.bim',
                        inPATH + 'sas_chr22_geno_1KG.fam',
                        verbose=False)
X_sas = G_sas.values
# Swap genotype codes 0 <-> 2 (1 and any other value, e.g. NaN, are left unchanged).
# NOTE(review): presumably re-codes dosages to count the other allele -- confirm
# against the allele pandas_plink counts by default.
X_sas = np.where(X_sas == 2, 0, np.where(X_sas == 0, 2, X_sas))
X_sas = np.array(X_sas, dtype=float)
all_snps_sas_ld_full = G_sas.snp.values
# Column index of each SNP ID in X_sas (for a repeated ID the last column wins).
ld_column = {snp: j for j, snp in enumerate(all_snps_sas_ld_full)}

results = {}  # exp_num -> DataFrame returned by doMPP for that experiment

for exp_num in exp_list:
    # AUX and TAR summary statistics
    ss_df_aux1 = _read_sumstats(inPATH, 'eur', exp_num)
    ss_df_aux2 = _read_sumstats(inPATH, 'eas', exp_num)
    ss_df_aux3 = _read_sumstats(inPATH, 'amr', exp_num)
    ss_df_aux4 = _read_sumstats(inPATH, 'afr', exp_num)
    ss_df_tar = _read_sumstats(inPATH, 'sas', exp_num)

    # SNPs present in every summary-statistics file and in the LD panel.
    common_snp_set = (set(ss_df_aux1['SNP']) & set(ss_df_aux2['SNP'])
                      & set(ss_df_aux3['SNP']) & set(ss_df_aux4['SNP'])
                      & set(ss_df_tar['SNP']) & set(all_snps_sas_ld_full))
    if not common_snp_set:
        raise ValueError(str(exp_num) + ': no SNPs shared by all summary statistics and the LD panel')
    common_snps = list(common_snp_set)

    ss_df_tar = _restrict_to(ss_df_tar, common_snp_set)
    ss_df_aux1 = _restrict_to(ss_df_aux1, common_snp_set)
    ss_df_aux2 = _restrict_to(ss_df_aux2, common_snp_set)
    ss_df_aux3 = _restrict_to(ss_df_aux3, common_snp_set)
    ss_df_aux4 = _restrict_to(ss_df_aux4, common_snp_set)

    N_tar = int(ss_df_tar['N'][0])
    N_aux1 = int(ss_df_aux1['N'][0])
    N_aux2 = int(ss_df_aux2['N'][0])
    N_aux3 = int(ss_df_aux3['N'][0])
    N_aux4 = int(ss_df_aux4['N'][0])

    # Target columns in target-file order; this order defines the SNP axis below.
    # tar_chroms / tar_pos / tar_a1 stay module-level because doMPP may read them as globals.
    tar_snps = ss_df_tar['SNP'].tolist()
    tar_a1 = ss_df_tar['A1'].tolist()
    tar_pvals = ss_df_tar['P'].to_numpy(dtype=float)
    tar_betas = ss_df_tar['BETA'].to_numpy(dtype=float)
    tar_chroms = ss_df_tar['CHR'].tolist()
    tar_pos = ss_df_tar['BP'].to_numpy(dtype=float)
    tar_corr = tar_betas * N_tar

    # Auxiliary betas aligned SNP-by-SNP to tar_snps (previously this relied on every
    # file listing the shared SNPs in the same row order).
    # NOTE(review): alleles (A1/A2) are not harmonised across populations.
    aux1_betas = _betas_for(ss_df_aux1, tar_snps)
    aux2_betas = _betas_for(ss_df_aux2, tar_snps)
    aux3_betas = _betas_for(ss_df_aux3, tar_snps)
    aux4_betas = _betas_for(ss_df_aux4, tar_snps)

    # Genotype columns in the same SNP order as tar_snps / tar_corr.
    temp_indices_tar = [ld_column[snp] for snp in tar_snps]
    X_tar = X_sas[:, temp_indices_tar]

    # Weighted combination, summed EUR, EAS, AMR, AFR left to right as before.
    weights = AUX_POP_WEIGHTS[aux_pops]
    aux_betas = sum(betas * weight for betas, weight in
                    zip((aux1_betas, aux2_betas, aux3_betas, aux4_betas), weights))

    print('Start time: '+str(datetime.now()))
    result_df = doMPP(X_tar, tar_snps, tar_corr, aux_betas, 100, 0.05)
    print('End time: '+str(datetime.now()))
    results[exp_num] = result_df
    print(str(exp_num)+' ... Done!')

Start time: 2025-08-09 11:48:22.485163
End time: 2025-08-09 11:48:23.590493
exp1.4.1 ... Done!
Start time: 2025-08-09 11:48:46.874741
End time: 2025-08-09 11:48:48.242642
exp1.4.2 ... Done!
Start time: 2025-08-09 11:49:11.987758
End time: 2025-08-09 11:49:13.297667
exp1.4.3 ... Done!
