# Load all libraries:

In [1]:
import csv
import torch
import random
import warnings
import subprocess
import scipy as sp
import numpy as np
import pandas as pd
import concurrent.futures
from functools import partial
from datetime import datetime
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
from statsmodels.formula.api import ols
from pandas_plink import read_plink1_bin, read_plink

# Helper Functions:

In [2]:
# Default location of the example data set read by evaluate().
_EVAL_DATA_DIR = '/Users/ritwizkamal/BIRDS_server/AA_FinalGithub/example_data/'

# (PLINK fileset stem, true-phenotype file) for the train, validation and
# test splits, in the order evaluate() reports them.
_EVAL_SPLITS = (
    ('target_training_geno', 'pheno_tar_sas_train_truepheno.txt'),
    ('target_validation_geno', 'pheno_tar_sas_val_truepheno.txt'),
    ('target_testing_geno', 'pheno_tar_sas_test_truepheno.txt'),
)


def _load_genotype_matrix(prefix):
    """Read the PLINK fileset ``prefix``.{bed,bim,fam} into a float matrix.

    Genotype codes 0 and 2 are swapped after loading; every other value is
    left untouched.
    """
    geno = read_plink1_bin(prefix + '.bed', prefix + '.bim', prefix + '.fam',
                           verbose=False)
    dosage = geno.values
    dosage = np.where(dosage == 2, 0, np.where(dosage == 0, 2, dosage))
    return np.array(dosage, dtype=float)


def _score_split(data_dir, geno_stem, pheno_file, pred_betas, true_betas):
    """Return ``(R2, correlation ratio)`` of ``pred_betas`` on one data split."""
    X = _load_genotype_matrix(data_dir + geno_stem)

    pheno = pd.read_csv(data_dir + pheno_file, delimiter=' ',
                        header=None, names=['Pheno'])
    y_true = np.array(pheno['Pheno'], dtype=float)

    # R2 of the simple regression of the true phenotype on the predicted one.
    y_pred = np.matmul(X, pred_betas)
    fit_data = pd.DataFrame({'Y_true': y_true, 'Y_pred': y_pred})
    r2 = ols('Y_true ~ Y_pred', data=fit_data).fit().rsquared

    # Correlation achieved by the predicted betas, relative to the
    # correlation achieved by the true betas on the same genotypes.
    y_best = np.matmul(X, true_betas)
    numerator = sp.stats.pearsonr(y_true, y_pred)[0]
    denominator = sp.stats.pearsonr(y_true, y_best)[0]
    return r2, numerator / denominator


def evaluate(metric, pred_betas_df, data_dir=None):
    """Score predicted SNP effects on the train, validation and test splits.

    Parameters
    ----------
    metric : str
        ``'R2'`` for the R-squared of ``Y_true ~ Y_pred``, or ``'CR'`` for the
        correlation ratio ``corr(Y_true, X @ pred) / corr(Y_true, X @ true)``.
    pred_betas_df : pandas.DataFrame
        Must contain a ``'BETA'`` column with one effect size per genotype
        column of the evaluation filesets.
    data_dir : str, optional
        Directory prefix (including the trailing separator) holding the
        genotype filesets, the true phenotypes and
        ``beta_tar_sas_truebetas.txt``. Defaults to the example-data location.

    Returns
    -------
    tuple
        ``(train, validation, test)`` values of the requested metric.

    Raises
    ------
    ValueError
        If ``metric`` is not ``'R2'`` or ``'CR'``. This is checked before any
        file is read, so a typo fails immediately instead of surfacing as an
        unbound-variable error after all the I/O.
    """
    if metric not in ('R2', 'CR'):
        raise ValueError("metric must be 'R2' or 'CR', got %r" % (metric,))
    if data_dir is None:
        data_dir = _EVAL_DATA_DIR

    #Pred Betas
    pred_betas = np.array(pred_betas_df['BETA'], dtype=float)

    #True Betas
    true_betas_df = pd.read_csv(data_dir + 'beta_tar_sas_truebetas.txt', delimiter=' ',
                                header=None, names=['BETA'])
    true_betas = np.array(true_betas_df['BETA'], dtype=float)

    # _score_split returns (R2, CR); pick the column the caller asked for.
    metric_pos = 0 if metric == 'R2' else 1
    return tuple(
        _score_split(data_dir, geno_stem, pheno_file, pred_betas, true_betas)[metric_pos]
        for geno_stem, pheno_file in _EVAL_SPLITS
    )

In [3]:
def doMPP(X_tar, tar_snps, tar_corr, aux_betas, L1_penalty, L2_penalty,
          chroms=None, positions=None, a1_alleles=None, maxfun=10):
    """Estimate target-population SNP effects, shrunk towards auxiliary betas.

    Minimises, with L-BFGS-B and an analytic gradient,

        B' (X'X) B - 2 B' r + L2 * B'B + L1 * sum_j mu*log(cosh((B_j - aux_j)/mu))

    where ``X = X_tar``, ``r = tar_corr`` and ``mu = 0.1``. The last term is a
    smooth approximation of ``L1 * sum_j |B_j - aux_j|``.

    Parameters
    ----------
    X_tar : numpy.ndarray
        2-D genotype matrix with one column per entry of ``tar_snps``.
    tar_snps : sequence
        SNP identifiers; its length sets the number of coefficients.
    tar_corr : array-like
        Vector ``r`` of the linear term, one value per SNP.
    aux_betas : array-like
        Auxiliary effect sizes the solution is pulled towards, one per SNP.
    L1_penalty, L2_penalty : float
        Weights of the smoothed-L1 and ridge terms.
    chroms, positions, a1_alleles : sequence, optional
        Values for the ``CHR``, ``POS`` and ``A1`` output columns. When
        omitted, the notebook-level ``tar_chroms``, ``tar_pos`` and ``tar_a1``
        are used, so existing six-argument calls keep working.
    maxfun : int, optional
        Maximum number of objective evaluations given to L-BFGS-B.

    Returns
    -------
    pandas.DataFrame
        Columns ``CHR``, ``SNP``, ``POS``, ``A1``, ``BETA``; one row per SNP.
    """
    aux_betas = np.asarray(aux_betas, dtype=float)
    r = np.asarray(tar_corr, dtype=float)

    #Initialize Beta_hats (tiny random values around zero; unseeded)
    Beta_hat_init = np.random.normal(0, 0.00000000001, len(tar_snps))

    X_tensor = torch.from_numpy(np.asarray(X_tar, dtype=float))
    XtX = torch.matmul(X_tensor.T, X_tensor)
    r_tensor = torch.from_numpy(r)

    # Smoothing width of the |.| approximation.
    mu = 0.1

    # NOTE(review): the previous version also built abs(1-L2_penalty)*XtX but
    # never used it; the objective has always used the unscaled XtX. Confirm
    # whether that scaling was meant to be applied.

    def func(Bs, L1_penalty, L2_penalty):
        """Objective value at coefficient vector ``Bs``."""
        B = torch.from_numpy(Bs)
        term1 = torch.matmul(torch.matmul(B, XtX), B).item()
        term2 = 2*torch.matmul(B, r_tensor).item()
        term3 = L2_penalty*np.dot(Bs, Bs)
        scaled = (Bs - aux_betas)/mu
        # log(0.5*e^-s + 0.5*e^s) computed as logaddexp(-s, s) - log(2), which
        # does not overflow for large |s| the way explicit exp() calls do.
        smooth_abs = mu*(np.logaddexp(-scaled, scaled) - np.log(2.0))
        term4 = L1_penalty*np.sum(smooth_abs)
        return term1-term2+term3+term4

    def jacfunc(Bs, L1_penalty, L2_penalty):
        """Gradient of ``func`` with respect to ``Bs``."""
        term1 = 2*torch.matmul(XtX, torch.from_numpy(Bs)).numpy()
        term2 = 2*r
        term3 = L2_penalty*2*Bs
        # (e^s - e^-s)/(e^s + e^-s) is tanh(s); np.tanh saturates at +/-1
        # instead of producing inf/inf.
        term4 = L1_penalty*np.tanh((Bs - aux_betas)/mu)
        return term1-term2+term3+term4

    ans = sp.optimize.minimize(func, jac=jacfunc, x0=Beta_hat_init,
                               args=(L1_penalty, L2_penalty),
                               method='L-BFGS-B', options={'maxfun': maxfun})

    # Fall back to the notebook globals only when the caller did not pass the
    # annotation columns explicitly.
    final_chr = tar_chroms if chroms is None else chroms
    final_pos = tar_pos if positions is None else positions
    final_a1 = tar_a1 if a1_alleles is None else a1_alleles

    final_results_df = pd.DataFrame({'CHR': final_chr, 'SNP': tar_snps, 'POS': final_pos,
                                     'A1': final_a1, 'BETA': ans.x})

    return final_results_df

# Load and Pre-Process Data:

In [4]:
# NOTE(review): absolute, machine-specific path; point this at your own copy
# of the example data before running.
inPATH = '/Users/ritwizkamal/BIRDS_server/AA_FinalGithub/example_data/'

# Column layout shared by every (headerless) summary-statistics file.
SUMSTAT_COLUMNS = ['CHR','SNP','GENETIC.DIST','BP','A1','A2','BETA','SE','T','P','N']


def _read_sumstat(file_name):
    """Read one space-delimited, headerless summary-statistics file from inPATH."""
    return pd.read_csv(inPATH+file_name, delimiter=' ', header=None,
                       names=SUMSTAT_COLUMNS)


def _align_to_target(sumstat_df, snp_order):
    """Return the rows of ``sumstat_df`` for ``snp_order``, in exactly that order.

    Only the first row of any repeated SNP id is kept, so the result has one
    row per entry of ``snp_order``.
    """
    unique_rows = sumstat_df.drop_duplicates(subset='SNP')
    return unique_rows.set_index('SNP', drop=False).loc[snp_order].reset_index(drop=True)


#AUX Summary Statistics
ss_df_aux1 = _read_sumstat('aux_eur_sumstat.txt')
ss_df_aux2 = _read_sumstat('aux_eas_sumstat.txt')
ss_df_aux3 = _read_sumstat('aux_amr_sumstat.txt')
ss_df_aux4 = _read_sumstat('aux_afr_sumstat.txt')

#TAR Summary Statistics
ss_df_tar = _read_sumstat('target_sas_sumstat.txt')

#External LD
G_sas  = read_plink1_bin(inPATH+'sas_chr22_geno_1KG.bed',\
                         inPATH+'sas_chr22_geno_1KG.bim',\
                         inPATH+'sas_chr22_geno_1KG.fam',\
                         verbose = False)
X_sas = G_sas.values
# Swap genotype codes 0 and 2; all other values are left as loaded.
X_sas = np.where(X_sas == 2, 0, np.where(X_sas == 0, 2, X_sas))
X_sas = np.array(X_sas, dtype=float)
all_snps_sas_ld_full = G_sas.snp.values

# SNPs present in every summary-statistics file and in the LD reference panel.
common_snp_set = set(ss_df_aux1['SNP']) & set(ss_df_aux2['SNP']) &\
                 set(ss_df_aux3['SNP']) & set(ss_df_aux4['SNP']) &\
                 set(ss_df_tar['SNP']) & set(all_snps_sas_ld_full)
common_snps = list(common_snp_set)
if not common_snps:
    raise ValueError('No SNP is shared by all summary statistics and the LD reference.')

#Preprocess Target and Auxiliary Summ. Stats. File
# The target file defines the SNP order. Every auxiliary table and the LD
# genotype columns are re-ordered to match it, so element i of every vector
# below refers to the same SNP even if the input files are sorted differently.
ss_df_tar = ss_df_tar[ss_df_tar['SNP'].isin(common_snps)]
ss_df_tar = ss_df_tar.drop_duplicates(subset='SNP').reset_index(drop=True)
tar_snps = list(ss_df_tar['SNP'])

ss_df_aux1 = _align_to_target(ss_df_aux1, tar_snps)
ss_df_aux2 = _align_to_target(ss_df_aux2, tar_snps)
ss_df_aux3 = _align_to_target(ss_df_aux3, tar_snps)
ss_df_aux4 = _align_to_target(ss_df_aux4, tar_snps)

# Sample size is taken from the first row of each table.
N_tar = int(ss_df_tar['N'][0])
N_aux1 = int(ss_df_aux1['N'][0])
N_aux2 = int(ss_df_aux2['N'][0])
N_aux3 = int(ss_df_aux3['N'][0])
N_aux4 = int(ss_df_aux4['N'][0])

tar_a1 = list(ss_df_tar['A1'])
tar_pvals = np.asarray(list(ss_df_tar['P']), dtype='float')
tar_betas = np.asarray(list(ss_df_tar['BETA']), dtype='float')
tar_chroms = list(ss_df_tar['CHR'])
tar_pos = np.asarray(list(ss_df_tar['BP']), dtype='float')
tar_corr = np.asarray(list(ss_df_tar['BETA']), dtype='float')*N_tar

aux1_snps = list(ss_df_aux1['SNP'])
aux1_betas = np.asarray(list(ss_df_aux1['BETA']), dtype='float')

aux2_snps = list(ss_df_aux2['SNP'])
aux2_betas = np.asarray(list(ss_df_aux2['BETA']), dtype='float')

aux3_snps = list(ss_df_aux3['SNP'])
aux3_betas = np.asarray(list(ss_df_aux3['BETA']), dtype='float')

aux4_snps = list(ss_df_aux4['SNP'])
aux4_betas = np.asarray(list(ss_df_aux4['BETA']), dtype='float')

#Preprocess Target Ref. file
# Column of each SNP in the LD panel (first occurrence wins), then pick the
# columns in target order.
ld_col_of = {}
for col, snp in enumerate(all_snps_sas_ld_full):
    ld_col_of.setdefault(snp, col)
temp_set = set(tar_snps)
temp_indices_tar = [ld_col_of[snp] for snp in tar_snps]
# Set membership keeps this filter linear in the number of LD SNPs.
all_snps_sas_ld_full = [i for i in all_snps_sas_ld_full if i in common_snp_set]
X_tar = X_sas[:, temp_indices_tar]

# Everything handed to the solver must describe the same SNPs in the same order.
assert aux1_snps == aux2_snps == aux3_snps == aux4_snps == tar_snps
assert X_tar.shape[1] == len(tar_snps)

# Assign Weightage to each auxiliary population:

In [5]:
pops_tuple = 'EurEasAmrAfr'

# Weight given to each auxiliary population's betas, always in the order
# (Eur = aux1, Eas = aux2, Amr = aux3, Afr = aux4), for every supported
# combination name.
# NOTE(review): the three-population weights are 0.33 each and therefore sum
# to 0.99 rather than 1; kept as-is, confirm whether 1/3 was intended.
AUX_POP_WEIGHTS = {
    'EurOnly':      (1.0, 0.0, 0.0, 0.0),
    'EasOnly':      (0.0, 1.0, 0.0, 0.0),
    'AmrOnly':      (0.0, 0.0, 1.0, 0.0),
    'AfrOnly':      (0.0, 0.0, 0.0, 1.0),

    'EurAmr':       (0.5, 0.0, 0.5, 0.0),
    'EurEas':       (0.5, 0.5, 0.0, 0.0),
    'EurAfr':       (0.5, 0.0, 0.0, 0.5),
    'EasAmr':       (0.0, 0.5, 0.5, 0.0),
    'EasAfr':       (0.0, 0.5, 0.0, 0.5),
    'AmrAfr':       (0.0, 0.0, 0.5, 0.5),

    'EurEasAmr':    (0.33, 0.33, 0.33, 0.0),
    'EurEasAfr':    (0.33, 0.33, 0.0, 0.33),
    'EurAmrAfr':    (0.33, 0.0, 0.33, 0.33),
    'EasAmrAfr':    (0.0, 0.33, 0.33, 0.33),

    'EurEasAmrAfr': (0.25, 0.25, 0.25, 0.25),
}

if pops_tuple not in AUX_POP_WEIGHTS:
    # Fail here instead of printing and carrying on: continuing would either
    # raise a NameError on `temp` or silently reuse a `temp` left over from an
    # earlier run of this cell.
    raise ValueError("Invalid Aux Pops! Got %r; expected one of %s"
                     % (pops_tuple, sorted(AUX_POP_WEIGHTS)))

aux_weights = AUX_POP_WEIGHTS[pops_tuple]
# Weights of the populations that actually take part in the combination.
per_pop_weight = [w for w in aux_weights if w != 0.0]

# Weighted sum of the four auxiliary beta vectors; the zero-weight terms are
# kept so the arithmetic matches the per-combination formulas exactly.
temp = (aux1_betas*aux_weights[0]) + (aux2_betas*aux_weights[1]) + \
       (aux3_betas*aux_weights[2]) + (aux4_betas*aux_weights[3])

aux_betas = temp

# Run MultiPopPred:

In [6]:
# Fit the model with L1 penalty 100 and L2 penalty 0.05, printing wall-clock
# timestamps immediately before and after the call.
print(f'Start time: {datetime.now()!s}')
result_df = doMPP(
    X_tar=X_tar,
    tar_snps=tar_snps,
    tar_corr=tar_corr,
    aux_betas=aux_betas,
    L1_penalty=100,
    L2_penalty=0.05,
)
print(f'End time: {datetime.now()!s}')

Start time: 2025-07-08 14:36:59.986567
End time: 2025-07-08 14:37:01.451206


In [7]:
# Correlation ratio of the fitted betas on each data split; evaluate()
# returns the values in (train, validation, test) order.
out = evaluate('CR', result_df)
for split_label, split_value in zip(('Train', 'Val', 'Test'), out):
    print(f'Correlation Ratio {split_label}: {split_value!s}')

Correlation Ratio Train: 0.8224401608651833
Correlation Ratio Val: 0.6293851166163114
Correlation Ratio Test: 0.6415652635350249


In [8]:
# R2 of the fitted betas on each data split; evaluate() returns the values
# in (train, validation, test) order.
out = evaluate('R2', result_df)
for split_label, split_value in zip(('Train', 'Val', 'Test'), out):
    print(f'R2 {split_label}: {split_value!s}')

R2 Train: 0.4545801105715441
R2 Val: 0.28005939003214186
R2 Test: 0.28068516089225237
