In [35]:
import numpy as onp
import pandas as pd
import csv

In [36]:
ethnic_dic=pd.read_csv('sampleID.csv',usecols=['Sample (Male/Female/Unknown)','Population(s)'])

In [37]:
def process_dict(data):
    data=data.rename(columns={'Sample (Male/Female/Unknown)':'ID','Population(s)':'eth'})
    data['eth']=data['eth'].apply(lambda x: x.split(',')[-1][1:])
    data['ID']=data['ID'].apply(lambda x: x.split(' ')[0])
    data.index=data['ID']
    data=data.drop('ID',axis=1)
    data=data.drop('NA18498')
    return data    

In [38]:
eth_ID=process_dict(ethnic_dic)

In [39]:
eth=eth_ID['eth'].unique()

In [40]:
removal_list=['HG00104',
 'HG00134',
 'HG00135',
 'HG00152',
 'HG00156',
 'HG00249',
 'HG00270',
 'HG00302',
 'HG00303',
 'HG00312',
 'HG00359',
 'HG00377',
 'HG01471',
 'HG02168',
 'HG02169',
 'HG02170',
 'HG02173',
 'HG02176',
 'HG02358',
 'HG02405',
 'HG02436',
 'HG03171',
 'HG03393',
 'HG03398',
 'HG03431',
 'HG03462',
 'HG03549',
 'HG04301',
 'HG04302',
 'HG04303',
 'NA18527',
 'NA18576',
 'NA18791',
 'NA18955',
 'NA19044',
 'NA19359',
 'NA19371',
 'NA19398',
 'NA20537',
 'NA20816',
 'NA20829',
 'NA20831',
 'NA20873',
 'NA20883',
 'NA21121']

In [41]:
def process_data(data,num_causal_snps=200):
    data=data[data['ALT'].isin(['A','C','G','T'])] #select SNPs with single ALT allele
    data['INFO']=data['INFO'].apply(lambda x: float(x.split(';')[3].split('=')[-1])) #extract allele freq information
    data=data[data['INFO']>0.05] #choose SNPs with allele freq more than 0.05
    data.index=data['POS'] #set ID col as index
    data=data.drop(['ID','#CHROM','POS','REF','ALT','QUAL','FILTER','INFO','FORMAT'],axis=1) #drop columns other than individual data
    data=data.applymap(lambda x: 2 if x=='1|1' else(0 if x=='0|0' else 1)) #sets 0|0 to 0 ...
    data=data.drop(removal_list,axis=1)
    data=data.T
    causal_snps=onp.arange(0,len(data.columns),len(data.columns)//num_causal_snps)
    data=data[data.columns[causal_snps]]
    data['eth']=eth_ID['eth']
    data['ID']=data.index
    data=data.set_index(['eth','ID'])
    return data
    

In [42]:
#in final form we will read all rows
num_rows=20000
header_line=42
data=pd.read_csv('ALL.autosomes.shapeit2_integrated_v1a.GRCh38.20181129.phased.vcf.gz',sep='\t',
                header=header_line,nrows=num_rows)

In [43]:
data1=process_data(data)

In [45]:
data1.to_csv('processed_data.csv')

In [48]:
data2=pd.read_csv('processed_data.csv',index_col=[0,1])

In [51]:
eth_ID.to_csv('eth_ID.csv')

In [55]:
eth_ID1=pd.read_csv('eth_ID.csv',index_col=0)

In [59]:
eth_ID

Unnamed: 0_level_0,eth
ID,Unnamed: 1_level_1
HG00096,GBR
HG00097,GBR
HG00099,GBR
HG00100,GBR
HG00101,GBR
...,...
NA21137,SAS
NA21141,SAS
NA21142,SAS
NA21143,SAS


In [10]:
eth_coef={}
eth_errors={}
eth_errors_sigma2=onp.random.randint(1,5, size=len(eth))
eth_errors.update([(ethnicity,errors/100) for ethnicity,errors in zip(eth,eth_errors_sigma2)])

for i in eth:
    eth_coef[i]=onp.random.uniform(-1,1,(data1.shape[1]))


In [79]:
eth_errors

{'GBR': 0.02,
 'FIN': 0.04,
 'EAS': 0.04,
 'PUR': 0.01,
 'CLM': 0.03,
 'IBS': 0.03,
 'PEL': 0.02,
 'SAS': 0.01,
 'KHV': 0.01,
 'ACB': 0.01,
 'GWD': 0.04,
 'ESN': 0.03,
 'MSL': 0.02,
 'STU': 0.03,
 'EUR': 0.04,
 'YRI': 0.03,
 'JPT': 0.02,
 'LWK': 0.04,
 'ASW': 0.03,
 'MXL': 0.04,
 'TSI': 0.02}

In [125]:
data1.iloc[0].name[0]

'GBR'

In [135]:
y1=data1.apply(lambda a: a@eth_coef[a.name[0]]+onp.random.normal(scale=eth_errors[a.name[0]]),axis=1)

In [140]:
data1.loc['GBR'].iloc[1]@eth_coef['GBR']+onp.random.normal(scale=eth_errors['GBR'])

1.9246659226313187

In [97]:
data1.loc['SAS'].iloc[-1]@eth_coef['SAS']

-2.3412896055127295

In [138]:
y1

eth  ID     
GBR  HG00096    2.171494
     HG00097    1.967864
     HG00099   -2.985690
     HG00100   -0.473637
     HG00101    0.934244
                  ...   
SAS  NA21137   -1.158354
     NA21141   -1.882105
     NA21142   -1.018477
     NA21143   -3.316611
     NA21144   -2.335341
Length: 2503, dtype: float64

In [98]:
x

Unnamed: 0_level_0,POS,51479,55545,77961,84002,87190,275654,631490,734210,779968,792275,...,1325753,1329159,1331096,1334949,1337898,1341079,1343267,1346453,1351587,1353091
eth,ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
GBR,HG00096,1,0,0,1,0,0,0,0,2,0,...,1,2,2,2,1,2,0,0,2,2
GBR,HG00097,0,0,0,0,0,2,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
GBR,HG00099,0,2,0,0,0,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
GBR,HG00100,0,0,0,1,0,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
GBR,HG00101,0,1,1,0,1,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SAS,NA21137,0,0,0,0,0,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
SAS,NA21141,1,0,0,1,0,0,0,0,2,0,...,1,2,2,2,1,2,0,0,2,2
SAS,NA21142,0,0,0,0,0,1,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
SAS,NA21143,0,0,0,1,0,1,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2


In [66]:
eth_errors_sigma2=onp.random.randint(1,5, size=len(eth))

In [70]:
eth_errors={}
eth_errors.update([(ethnicity,errors/100) for ethnicity,errors in zip(eth,eth_errors_sigma2)])


In [71]:
eth_errors

{'GBR': 0.02,
 'FIN': 0.04,
 'EAS': 0.04,
 'PUR': 0.01,
 'CLM': 0.03,
 'IBS': 0.03,
 'PEL': 0.02,
 'SAS': 0.01,
 'KHV': 0.01,
 'ACB': 0.01,
 'GWD': 0.04,
 'ESN': 0.03,
 'MSL': 0.02,
 'STU': 0.03,
 'EUR': 0.04,
 'YRI': 0.03,
 'JPT': 0.02,
 'LWK': 0.04,
 'ASW': 0.03,
 'MXL': 0.04,
 'TSI': 0.02}

In [32]:
x

Unnamed: 0_level_0,POS,51479,55545,77961,84002,87190,275654,631490,734210,779968,792275,...,1325753,1329159,1331096,1334949,1337898,1341079,1343267,1346453,1351587,1353091
eth,ID,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
GBR,HG00096,1,0,0,1,0,0,0,0,2,0,...,1,2,2,2,1,2,0,0,2,2
GBR,HG00097,0,0,0,0,0,2,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
GBR,HG00099,0,2,0,0,0,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
GBR,HG00100,0,0,0,1,0,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
GBR,HG00101,0,1,1,0,1,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SAS,NA21137,0,0,0,0,0,0,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
SAS,NA21141,1,0,0,1,0,0,0,0,2,0,...,1,2,2,2,1,2,0,0,2,2
SAS,NA21142,0,0,0,0,0,1,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2
SAS,NA21143,0,0,0,1,0,1,0,0,2,0,...,0,2,2,2,2,2,0,0,2,2


In [13]:
task='MXL' #the ethnic group we want to test
x_train=x.drop('MXL',axis=0)
y_train=y.drop('MXL',axis=0)
eth_train=onp.delete(eth, onp.where(eth == 'MXL'))
x_test=x.loc['MXL'].to_numpy()
y_test=y.loc['MXL'].to_numpy()
y_test=onp.reshape(y_test,(-1,1))

  x_train=x.drop('MXL',axis=0)
  y_train=y.drop('MXL',axis=0)


In [14]:
eth_train=onp.delete(eth, onp.where(eth == 'MXL'))

In [15]:
import jax
import jax.numpy as np
from jax import vmap,grad
from functools import partial
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense,Relu,Flatten
#from jax import random
import random
from jax import jit
from keras.datasets import mnist,fashion_mnist
import numpy as onp
from jax.example_libraries import optimizers
from jax.tree_util import tree_multimap

In [16]:
num_features=x.shape[1] #number of causal snps
reg_weight=0.1
epochs=20000
in_shape=(-1,num_features)
ethnic_grp_min_pop=20 #min population among all subpopulatiosn
batch_size=20 #inner batch size for inner loop
#K=20 #K-shot learning
num_task_sample= 4 #number of tasks to sample to meta train
lr=0.001
rng=jax.random.PRNGKey(1)



In [17]:
net_init,net_apply=Dense(1)
out_shape, net_params=net_init(rng,input_shape=in_shape)

In [18]:
def loss(params,inputs,targets):
    predictions=net_apply(params,inputs)
    for i in range(len(net_params)):
        l1_params=np.linalg.norm(net_params[i],1)
   # print('##################')
   # print(targets.shape)
   # print(predictions.shape)
   # print(reg_weight*np.linalg.norm(net_params[0][0],1))
   # print('##################')
    return np.mean((targets-predictions)**2) + reg_weight*np.linalg.norm(net_params[0],1)

def accuracy(params, inputs, targets):
    predictions = net_apply(params,inputs)
    return np.mean((targets-predictions)**2)


In [19]:
def inner_update(p,x1,y1):
    grads= grad(loss)(p,x1,y1)
    inner_sgd_fn= lambda g, state: (state - lr*g)
    #return tree_multimap(inner_sgd_fn,grads,p)
    return [(w - lr * dw)
          for w,dw in zip(p, grads)]

def maml_loss(p,x1,y1,x2,y2):
    p2= inner_update(p,x1,y1)
    return loss(p2,x2,y2)

In [20]:
#need ethnic_grp_pop list which contains population of each ethinc group
#training data is a list of arrays each of which corresponds to an ethinic groups 
#inner batch size < ethnic_grp_min_pop
def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    ethnic_grp_sample=random.sample(list(eth_train), k=outer_batch_size)

    def get_batch():
        xs, ys = [], []
        for j in ethnic_grp_sample:
            indices = onp.random.randint(ethnic_grp_min_pop,size=inner_batch_size)
            x= x_train.loc[j].iloc[indices].to_numpy()
            y= y_train.loc[j].iloc[indices].to_numpy()
            xs.append(x)
            ys.append(y)
        return np.stack(xs), np.stack(ys)
    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2

In [21]:
#meta training
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
out_shape, net_params = net_init(rng,in_shape)
opt_state = opt_init(net_params)

# vmapped version of maml loss.
# returns scalar for all tasks.
def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):
    task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)
    return np.mean(task_losses)

@jit
def step(i, opt_state, x1, y1, x2, y2):
    p = get_params(opt_state)
    g = grad(batch_maml_loss)(p, x1, y1, x2, y2)
    l = batch_maml_loss(p, x1, y1, x2, y2)
    return opt_update(i, g, opt_state), l

np_batched_maml_loss = []

for i in range(epochs):
    x1_b, y1_b, x2_b, y2_b = sample_tasks(num_task_sample, batch_size)
    opt_state, l = step(i, opt_state, x1_b, y1_b, x2_b, y2_b)
    np_batched_maml_loss.append(l)
    if i % 1000 == 0:
        print(i,'maml_loss',l)
net_params = get_params(opt_state)

0 maml_loss 16.025219
1000 maml_loss 18.555271
2000 maml_loss 17.865814
3000 maml_loss 17.448845
4000 maml_loss 17.072521
5000 maml_loss 19.350431
6000 maml_loss 20.385563
7000 maml_loss 12.411094
8000 maml_loss 15.943439
9000 maml_loss 19.711412
10000 maml_loss 14.990958
11000 maml_loss 12.222893
12000 maml_loss 14.535436
13000 maml_loss 10.182553
14000 maml_loss 14.971959
15000 maml_loss 17.603727
16000 maml_loss 14.175284
17000 maml_loss 19.570267
18000 maml_loss 14.906768
19000 maml_loss 14.148574


In [30]:
#meta testing
#meta test; train with batch_size many examples from validation set on desired task

#pre update prediction
pre_predictions = vmap(partial(net_apply, net_params))(x_test)
pre_error= loss(net_params,x_test,y_test)
print('pre update MSE='+str(pre_error))
#post-update prediction
indx=onp.random.randint(x_test.shape[0],size=batch_size)
test_indx=onp.delete(onp.arange(x_test.shape[0]),indx)
x1, y1 = x_test[indx] , y_test[indx]
for i in range(batch_size):
    
    net_params = inner_update(net_params,x1,y1)
    #print('training loss '+str(l))
    #train_accuracy= accuracy(net_params,x1,y1)
    #print('train accuracy',train_accuracy)
    #post_error= loss(net_params,x_test[test_indx],y_test[test_indx])
    #print('Post step ' + str(i)+' update test MSE='+str(post_error))
   
#post_predictions = vmap(partial(net_apply, net_params))(x_test)
maml_error= accuracy(net_params,x_test[test_indx],y_test[test_indx])
print('Test Error on Task: MSE = ',maml_error)




pre update MSE=12.301481
Test Error on Task: MSE =  10.882493


In [29]:
#default regression model
basenet_init,basenet_apply=Dense(1)
out_shape, basenet_params=basenet_init(rng,input_shape=in_shape)

for i in range(epochs):
    basenet_params=inner_update(basenet_params,x_train.to_numpy(),y_train.to_numpy())
    if i%1000==0:
        train_loss=loss(basenet_params,x_train.to_numpy(),y_train.to_numpy())
        print(i,'training loss',train_loss)
        
indx=onp.random.randint(x_test.shape[0],size=batch_size)
test_indx=onp.delete(onp.arange(x_test.shape[0]),indx)
x1, y1 = x_test[indx] , y_test[indx]
for i in range(batch_size):
    basenet_params=inner_update(basenet_params,x1,y1)
    
test_error=accuracy(basenet_params,x_test,y_test)

print('test error MSE', test_error)
    

0 training loss 20.57894
1000 training loss 20.0394
2000 training loss 20.003063
3000 training loss 19.986656
4000 training loss 19.977318
5000 training loss 19.971405
6000 training loss 19.967402
7000 training loss 19.964552
8000 training loss 19.962437
9000 training loss 19.960823
10000 training loss 19.959547
11000 training loss 19.958519
12000 training loss 19.957678
13000 training loss 19.956974
14000 training loss 19.956385
15000 training loss 19.955873
16000 training loss 19.955433
17000 training loss 19.955048
18000 training loss 19.954708
19000 training loss 19.954405
test error MSE 13.447605


In [31]:
indx=onp.random.randint(x_test.shape[0],size=batch_size)
test_indx=onp.delete(onp.arange(x_test.shape[0]),indx)
x1, y1 = x_test[indx] , y_test[indx]
for i in range(batch_size):
    basenet_params=inner_update(basenet_params,x1,y1)
    
lin_test_error=accuracy(basenet_params,x_test,y_test)

print('test error MSE', lin_test_error)
    

test error MSE 11.684541


In [25]:
L=['maml_error ',str(maml_error)+'\n','lin_error ',str(lin_test_error)]
f=open('results.txt','w')
f.writelines(L)
f.close()

In [28]:
eth_ID.value_counts()

eth
SAS    387
EAS    301
GWD    113
YRI    107
TSI    107
IBS    107
PUR    104
JPT    104
STU    102
FIN     99
KHV     99
LWK     99
EUR     99
ESN     99
ACB     96
CLM     94
GBR     91
MSL     85
PEL     85
MXL     64
ASW     61
dtype: int64