In [1]:
import torch
import sys
import os
import os.path as osp
import pandas as pd
import numpy as np
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import ticker
import matplotlib as mpl
import matplotlib.patches as mpatches
from IPython.display import display
import math
from dataclasses import dataclass
import allel
import scipy
import copy
import torch.nn.functional as F

In [2]:
os.environ['USER_PATH']='/home/users/richras/Ge2Net_Repo'
os.environ['USER_SCRATCH_PATH']="/scratch/users/richras"
os.environ['IN_PATH']='/scratch/groups/cdbustam/richras/data_in'
os.environ['OUT_PATH']='/scratch/groups/cdbustam/richras/data_out'
os.environ['LOG_PATH']='/scratch/groups/cdbustam/richras/logs/'

In [3]:
os.chdir(os.environ.get('USER_PATH'))

In [4]:
%load_ext autoreload
%autoreload 2
from src.utils.dataUtil import load_path, save_file, vcf2npy, getWinInfo
from src.utils.modelUtil import Params, load_model, convert_coordinates, convert_nVector
from src.utils.labelUtil import nearestNeighbourMapping, repeat_pop_arr,get_sample_map
from src.utils.decorators import timer
from src.main.visualization import plot_coordinates_map, plot_sample, plot_changepoints
from src.main.dataset import Haplotype
from src.models.modelSelection import modelSelect
from src.models.modelParamsSelection import Selections
from src.models import Model_A, Model_B, Model_C, BOCD
from src.models.distributions import Multivariate_Gaussian
from src.main.evaluation import eval_cp_batch, reportChangePointMetrics, t_prMetrics, cpMethod, eval_cp_matrix, \
getCpPred, GcdLoss, balancedMetrics, class_accuracy, Running_Average
from src.main.settings_model import parse_args
from src.models.MCDropout import MC_Dropout
from src.models.Ge3Net import Ge3NetBase
import inference

In [5]:
test_sample_map=pd.read_csv(osp.join(os.environ.get('OUT_PATH'), \
'humans/labels/data_id_4_geo/test_sample_map.tsv'), sep="\t")
test_pop_arr = repeat_pop_arr(test_sample_map)
df_test_pop_arr = pd.DataFrame(test_pop_arr, columns=['Sample', 'vcf_idx', 'granular_pop', 'superpop'])

In [6]:
pop_sample_map=pd.read_csv(osp.join(os.environ.get('OUT_PATH'), \
'humans/labels/data_id_4_geo/pop_sample_map.tsv'), sep="\t")
pop_arr = repeat_pop_arr(pop_sample_map)
gp2sp={k:v for k,v in zip(pop_arr[:,2], pop_arr[:,3])}

In [7]:
# df_test_pop_arr.to_csv(osp.join(os.environ.get('OUT_PATH'), \
#                         'humans/benchmark/data_id_4_geo/test/pop_arr.tsv'), index=None, sep='\t')

In [8]:
def pop_mapping(y_vcf, pop_arr, type='superpop'):

    result = np.zeros((y_vcf.shape[0], y_vcf.shape[1]), dtype=np.int64)
    if type=='superpop': col_num=3
    elif type=='granular_pop': col_num=2

    idx2label_dict={k:v for k,v in zip(pop_arr[:,1], pop_arr[:,col_num])}
    result=np.vectorize(idx2label_dict.get)(y_vcf)

    return result

In [9]:
def geoConvertLatLong2nVec(coord):
    """
    Converts the result from 2 dim Lat/Long to 3 dim n vector
    """
    # ToDo: Need to change this. Too slow!!!
    lat=coord[..., 0]
    lon=coord[..., 1] 
    nVec=convert_nVector(lat,lon)
    return nVec

In [10]:
def mapping_func(arr, labels_dict, dim):
    """
    Inputs:
    arr: 2(d)D array
    labels_dict: dict with 3(d) dim array as values
    d: dimension of the output, could be 3 or more
    return:
    result: 3(d)D array
    """
    if params.geography: dim=2
    result = np.zeros((arr.shape[0], arr.shape[1], dim), dtype=np.float_)

    for d in np.arange(dim):
        labels_dict_dim={k:labels_dict[k][d] for k in labels_dict.keys()}
        result[..., d] = np.vectorize(labels_dict_dim.get)(arr)

    result=geoConvertLatLong2nVec(result)

    return result

In [11]:
labels_path = osp.join(os.environ['OUT_PATH'],'humans/labels/data_id_4_geo')
data_path = osp.join(os.environ['OUT_PATH'],'humans/labels/data_id_4_geo')
models_path=osp.join(os.environ['OUT_PATH'],'humans/training/Model_Q_exp_id_3_data_id_4_geo/') 
dataset_type='test'

In [12]:
config={}
config['data.labels']=labels_path 
config['data.dir']=data_path 
config['models.dir']=models_path
config['data.dataset_type']=dataset_type
config['cuda']='cuda'
config['model.loadBest']=True
yaml_path = osp.join(config['models.dir'], 'params.yaml')
assert osp.isfile(yaml_path), "No json configuration file found at {}".format(yaml_path)
params = Params(yaml_path)
params.rtnOuts=True
params.mc_dropout=False
params.mc_samples=100
params.cp_tol=1
params.pretrained=False
params.optimizer='AdamW'
params.test_gens=[0,2,4,8]

AssertionError: No json configuration file found at /scratch/groups/cdbustam/richras/data_out/humans/training/Model_Q_exp_id_3_data_id_4_geo/params.yaml

In [None]:
def getLabelPops(y_vcf, method, pop_arr, idx2coordinates, params):
    if method=="gnomix":
        params.win_size=857
        params.chmlen, params.n_win = getWinInfo(y_vcf.shape[1], params.win_size)
    if method=="lainet":
        params.win_size=500
        params.chmlen, params.n_win = getWinInfo(y_vcf.shape[1], params.win_size)

    y_tmp = y_vcf[:,:params.chmlen]
    y_tmp = y_tmp.reshape(-1, params.n_win, params.win_size)
    y_vcf_idx = scipy.stats.mode(y_tmp, axis=2)[0].squeeze(2)
    y_coord = mapping_func(y_vcf_idx, idx2coordinates, 2)
    test_superpop = pop_mapping(y_vcf_idx, pop_arr, type='superpop')
    test_granular_pop = pop_mapping(y_vcf_idx, pop_arr, type='granular_pop')
    return y_coord, test_superpop, test_granular_pop

In [None]:
def getLabelPopsPerSnp(y_vcf, method, pop_arr, idx2coordinates, params):
    y_coord = mapping_func(y_vcf, idx2coordinates, 2)
    test_superpop = pop_mapping(y_vcf, pop_arr, type='superpop')
    test_granular_pop = pop_mapping(y_vcf, pop_arr, type='granular_pop')
    return y_coord, test_superpop, test_granular_pop

In [None]:
def gp2Gcd(pop_arr, gpPreds, y_coord, dim):
    idx2gp={k:v for k,v in zip(pop_arr[:,1], pop_arr[:,2])}
    gp2idx={k:v for v,k in idx2gp.items()}
    gp2coordinates={i:idx2coordinates[j] for i,j in gp2idx.items()}
    pred_coord = mapping_func(gpPreds, gp2coordinates, dim)
    gcdObj=GcdLoss()
    gnomix_gcd=np.mean(gcdObj.rawGcd(pred_coord, y_coord))
    return gnomix_gcd

In [52]:
idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_1_geo', params.coordinates), en_pickle=True)

gnomix_preds, gnomix_sp_preds, curr_vcf_idx=[], [], []
for i, gen in enumerate([0,2,4,8]):
    curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                 'humans/labels/data_id_1_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
    gnomix_preds.append(load_path(osp.join(os.environ.get('OUT_PATH') ,'humans/benchmark/data_id_1_geo/test/window_size_857',\
                        'gen_' + str(gen) +'.npy')))
    gnomix_sp_preds.append(np.vectorize(gp2sp.get)(gnomix_preds[-1]))
    print(f" Metrics for generation: {gen}")
    print(f' number of samples :{curr_vcf_idx[-1].shape[0]}')
    y_coord, y_superpop, y_granular_pop=getLabelPops(curr_vcf_idx[-1], "gnomix", pop_arr, idx2coordinates, copy.deepcopy(params))
    gnomix_sp_acc = np.mean(gnomix_sp_preds[-1]==y_superpop)
    gnomix_gp_acc = np.mean(gnomix_preds[-1]==y_granular_pop)
    gnomix_gcd = gp2Gcd(pop_arr, gnomix_preds[-1], y_coord, params.dataset_dim)
    print(f" superpop accuracy:{gnomix_sp_acc}")
    print(f" granular pop accuracy:{gnomix_gp_acc}")
    print(f" gcd:{gnomix_gcd}")

gnomix_preds_all = np.vstack(gnomix_preds)
gnomix_sp_all = np.vstack(gnomix_sp_preds)
y_vcf_idx = np.vstack(curr_vcf_idx)
y_coord_all, y_superpop_all, y_granular_pop_all=getLabelPops(y_vcf_idx, "gnomix", pop_arr, idx2coordinates, copy.deepcopy(params))
gnomix_sp_acc_all = np.mean(gnomix_sp_all==y_superpop_all)
gnomix_gp_acc_all = np.mean(gnomix_preds_all==y_granular_pop_all)
gnomix_gcd_all = gp2Gcd(pop_arr, gnomix_preds_all, y_coord_all, params.dataset_dim)
print(f" \n For all the generations:")
print(f" superpop accuracy:{gnomix_sp_acc_all}")
print(f" granular pop accuracy:{gnomix_gp_acc_all}")
print(f" gcd:{gnomix_gcd_all}")


 Metrics for generation: 0
 number of samples :564
 superpop accuracy:0.9649032010734139
 granular pop accuracy:0.6007906843013227
 gcd:731.6905443983123
 Metrics for generation: 2
 number of samples :800
 superpop accuracy:0.9567195945945945
 granular pop accuracy:0.5746587837837838
 gcd:833.7782390661861
 Metrics for generation: 4
 number of samples :800
 superpop accuracy:0.9467804054054054
 granular pop accuracy:0.5528445945945946
 gcd:897.6635521956672
 Metrics for generation: 8
 number of samples :800
 superpop accuracy:0.9251013513513513
 granular pop accuracy:0.5177128378378378
 gcd:1050.6644337416844
 
 For all the generations:
 superpop accuracy:0.9470602181128497
 granular pop accuracy:0.558373454426086
 gcd:890.1344288270847


In [16]:
idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_4_geo', params.coordinates), en_pickle=True)

gnomix_preds, gnomix_sp_preds, curr_vcf_idx=[], [], []
for i, gen in enumerate([0,2,4,8]):
    curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                 'humans/labels/data_id_4_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
    gnomix_preds.append(load_path(osp.join(os.environ.get('OUT_PATH') ,'/scratch/groups/cdbustam/ge3net_eval/ge3net_eval/lainet',\
                        'gen_' + str(gen) +'.npy')))
    gnomix_sp_preds.append(np.vectorize(gp2sp.get)(gnomix_preds[-1]))
    print(f" Metrics for generation: {gen}")
    print(f' number of samples :{curr_vcf_idx[-1].shape[0]}')
    y_coord, y_superpop, y_granular_pop=getLabelPopsPerSnp(curr_vcf_idx[-1], "lainet", pop_arr, idx2coordinates, copy.deepcopy(params))
    gnomix_sp_acc = np.mean(gnomix_sp_preds[-1]==y_superpop)
    gnomix_gp_acc = np.mean(gnomix_preds[-1]==y_granular_pop)
    gnomix_gcd = gp2Gcd(pop_arr, gnomix_preds[-1], y_coord, params.dataset_dim)
    print(f" superpop accuracy:{gnomix_sp_acc}")
    print(f" granular pop accuracy:{gnomix_gp_acc}")
    print(f" gcd:{gnomix_gcd}")

gnomix_preds_all = np.vstack(gnomix_preds)
gnomix_sp_all = np.vstack(gnomix_sp_preds)
y_vcf_idx = np.vstack(curr_vcf_idx)
y_coord_all, y_superpop_all, y_granular_pop_all=getLabelPopsPerSnp(y_vcf_idx, "lainet", pop_arr, idx2coordinates, copy.deepcopy(params))
gnomix_sp_acc_all = np.mean(gnomix_sp_all==y_superpop_all)
gnomix_gp_acc_all = np.mean(gnomix_preds_all==y_granular_pop_all)
gnomix_gcd_all = gp2Gcd(pop_arr, gnomix_preds_all, y_coord_all, params.dataset_dim)
print(f" \n For all the generations:")
print(f" superpop accuracy:{gnomix_sp_acc_all}")
print(f" granular pop accuracy:{gnomix_gp_acc_all}")
print(f" gcd:{gnomix_gcd_all}")


 Metrics for generation: 0
 number of samples :540
 superpop accuracy:0.9650200056708989
 granular pop accuracy:0.5341428366396074
 gcd:832.9132711822924
 Metrics for generation: 2
 number of samples :800
 superpop accuracy:0.9445899113134432
 granular pop accuracy:0.45188934264830977
 gcd:1084.4283512918382
 Metrics for generation: 4
 number of samples :800
 superpop accuracy:0.9349370963422703
 granular pop accuracy:0.43499640449261207
 gcd:1154.4467691813595
 Metrics for generation: 8
 number of samples :800
 superpop accuracy:0.9192513665290949
 granular pop accuracy:0.43175730128225326
 gcd:1237.3200480442
 
 For all the generations:
 superpop accuracy:0.9388209191871197
 granular pop accuracy:0.45692230289929525
 gcd:1098.8875174324155


In [None]:
idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_4_geo', params.coordinates), en_pickle=True)

gnomix_preds, gnomix_sp_preds, curr_vcf_idx=[], [], []
for i, gen in enumerate([0,2,4,8]):
    curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                 'humans/labels/data_id_4_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
    gnomix_preds.append(load_path(osp.join(os.environ.get('OUT_PATH') ,'/scratch/groups/cdbustam/ge3net_eval_2/ge3net_eval/small',\
                        'gen_' + str(gen) +'.npy')))
    gnomix_sp_preds.append(np.vectorize(gp2sp.get)(gnomix_preds[-1]))
    print(f" Metrics for generation: {gen}")
    print(f' number of samples :{curr_vcf_idx[-1].shape[0]}')
    y_coord, y_superpop, y_granular_pop=getLabelPopsPerSnp(curr_vcf_idx[-1], "small", pop_arr, idx2coordinates, copy.deepcopy(params))
    gnomix_sp_acc = np.mean(gnomix_sp_preds[-1]==y_superpop)
    gnomix_gp_acc = np.mean(gnomix_preds[-1]==y_granular_pop)
    gnomix_gcd = gp2Gcd(pop_arr, gnomix_preds[-1], y_coord, params.dataset_dim)
    print(f" superpop accuracy:{gnomix_sp_acc}")
    print(f" granular pop accuracy:{gnomix_gp_acc}")
    print(f" gcd:{gnomix_gcd}")

gnomix_preds_all = np.vstack(gnomix_preds)
gnomix_sp_all = np.vstack(gnomix_sp_preds)
y_vcf_idx = np.vstack(curr_vcf_idx)
y_coord_all, y_superpop_all, y_granular_pop_all=getLabelPopsPerSnp(y_vcf_idx, "small", pop_arr, idx2coordinates, copy.deepcopy(params))
gnomix_sp_acc_all = np.mean(gnomix_sp_all==y_superpop_all)
gnomix_gp_acc_all = np.mean(gnomix_preds_all==y_granular_pop_all)
gnomix_gcd_all = gp2Gcd(pop_arr, gnomix_preds_all, y_coord_all, params.dataset_dim)
print(f" \n For all the generations:")
print(f" superpop accuracy:{gnomix_sp_acc_all}")
print(f" granular pop accuracy:{gnomix_gp_acc_all}")
print(f" gcd:{gnomix_gcd_all}")


 Metrics for generation: 0
 number of samples :540
 superpop accuracy:0.1962962962962963
 granular pop accuracy:0.03333333333333333
 gcd:6095.2288412066855
 Metrics for generation: 2
 number of samples :800
 superpop accuracy:0.20551390551652438
 granular pop accuracy:0.043309835858983646
 gcd:6046.585485533406
 Metrics for generation: 4
 number of samples :800
 superpop accuracy:0.1917231971267446
 granular pop accuracy:0.030159730317255284
 gcd:6186.379165737158
 Metrics for generation: 8
 number of samples :800
 superpop accuracy:0.20235521880218016
 granular pop accuracy:0.03237570492423049
 gcd:6035.365674518453


In [None]:
# idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_4_geo', params.coordinates), en_pickle=True)
# lainet_predLs, curr_vcf_idx=[], []
# for i, gen in enumerate([0,2,4,8]):
#     curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
#                  'humans/labels/data_id_4_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
#     # get true labels
#     y_coord, y_superpop, y_granular_pop=getLabelPops(curr_vcf_idx[-1], "lainet", pop_arr, \
#                                                          idx2coordinates, copy.deepcopy(params))
#     #load lainet predictions
#     lainet_raw_pred=load_path(osp.join('/scratch/groups/cdbustam/ge3net_eval/ge3net_eval/small', 'gen_' + str(gen) +'.npy'))
#     print(f'lainet_raw_pred shape{lainet_raw_pred.shape}')
#     lainet_pred = np.swapaxes(lainet_raw_pred,2,1)
#     lainet_pred=lainet_pred.reshape(lainet_pred.shape[0]*lainet_pred.shape[1], -1)
#     print(f"lainet_pred shape:{lainet_pred.shape}")
    
#     lainet_predLs.append(lainet_pred)
#     lainet_gp_acc = np.mean(lainet_predLs[-1]==y_granular_pop)
#     lainet_gcd = gp2Gcd(pop_arr, lainet_predLs[-1], y_coord, params.dataset_dim)
#     print(f"gen {gen} granular pop accuracy:{lainet_gp_acc}")
#     print(f" gcd:{lainet_gcd}")

In [None]:
# idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_4_geo', params.coordinates), en_pickle=True)
# lainet_predLs, curr_vcf_idx=[], []
# for i, gen in enumerate([0,2,4,8]):
#     curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
#                  'humans/labels/data_id_4_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
#     # get true labels
#     y_coord, y_superpop, y_granular_pop=getLabelPops(curr_vcf_idx[-1], "lainet", pop_arr, \
#                                                          idx2coordinates, copy.deepcopy(params))
#     #load lainet predictions
#     lainet_raw_pred=load_path(osp.join('/scratch/groups/cdbustam/ge3net_eval/ge3net_eval/lainet', f'gen_{i}.npy'))
#     lainet_pred = np.swapaxes(lainet_raw_pred,2,1)
#     lainet_pred=lainet_pred.reshape(lainet_pred.shape[0]*lainet_pred.shape[1], -1)
#     print(f"lainet_pred shape:{lainet_pred.shape}")
    
#     lainet_predLs.append(lainet_pred)
#     lainet_gp_acc = np.mean(lainet_predLs[-1]==y_granular_pop)
#     lainet_gcd = gp2Gcd(pop_arr, lainet_predLs[-1], y_coord, params.dataset_dim)
#     print(f"gen {gen} granular pop accuracy:{lainet_gp_acc}")
#     print(f" gcd:{lainet_gcd}")

In [16]:
config['model.loadBest']

True

In [17]:
params.evalExtraMainLosses=True
params.pretrained=False
params.optimizer='Adam'
_, _, model, _=inference.main(config, params)

INFO: __init__:src.main.dataset: Loading test Dataset
INFO: __init__:src.main.dataset:Loading gen 0
INFO: __init__:src.main.dataset: snps data: (540, 317410)


 device used: cuda
Loading the datasets...


INFO: __init__:src.main.dataset: y_labels data :(540, 317410)
INFO: __init__:src.main.dataset:Loading gen 2
INFO: __init__:src.main.dataset: snps data: (800, 317410)
INFO: __init__:src.main.dataset: y_labels data :(800, 317410)
INFO: __init__:src.main.dataset:Loading gen 4
INFO: __init__:src.main.dataset: snps data: (800, 317410)
INFO: __init__:src.main.dataset: y_labels data :(800, 317410)
INFO: __init__:src.main.dataset:Loading gen 8
INFO: __init__:src.main.dataset: snps data: (800, 317410)
INFO: __init__:src.main.dataset: y_labels data :(800, 317410)
INFO: transform_data:src.main.dataset:Transforming the data


Finished '_geoConvertLatLong2nVec' in 0.0621 secs
Finished 'mapping_func' in 0.3107 secs
Finished 'pop_mapping' in 0.1330 secs
Finished 'pop_mapping' in 0.1351 secs
Finished 'transform_data' in 31.7426 secs
Finished '__init__' in 34.1165 secs
Parameter count for model AuxNetwork:31747503
Parameter count for model TransformerModel:107528
Parameter count for model BiRNN:34499
Parameter count for model logits_Block:2289
Total parameters:31891819
Finished 'load_model' in 1.9346 secs
best val loss metrics : {'loss_main': 866.6805836590546, 'loss_aux': 2160.2513772460106, 'l1_loss': 0.20087590047931303, 'mse': 0.045047710068923504, 'smooth_l1': 0.022466416130238673, 'weighted_loss': 0.20087590047931303}
at epoch : 227
train loss metrics: {'loss_main': 260.368462953392, 'loss_aux': 1561.6972164885574, 'l1_loss': 0.05821169218555197, 'mse': 0.005354038258400175, 'smooth_l1': 0.0026711223615385584, 'weighted_loss': 0.05821169218555197}
best val cp metrics : {'loss_cp': 0.49852718259127005, 'prM

In [18]:
def getWin2Snp(y_pred, win_size, truncChmLen, trueChmLen=317410):
    """
    y_pred is of shape:[Number of samples x win_size x 3]
    """
    y_pred_expanded = np.zeros((y_pred.shape[0],trueChmLen,y_pred.shape[-1]))
    y_pred = y_pred[:,:,None,:]
    y_pred = np.repeat(y_pred,win_size,axis=2).reshape(y_pred.shape[0],-1,3)
    y_pred_expanded[:,:truncChmLen,:] = y_pred
    y_pred_expanded[:,-(trueChmLen-truncChmLen):,:] = np.repeat(y_pred[:,-1,:][:,None,:],(trueChmLen-truncChmLen),axis=1)
    return y_pred_expanded
    

In [19]:
a=np.random.randint(1,10,(3,4,3))
print(f"a:{a}")
a_expanded = getWin2Snp(y_pred=a, win_size=4, truncChmLen=16, trueChmLen=18)
print(f"a_expanded :{a_expanded.shape}, {a_expanded}")

a:[[[4 7 6]
  [5 9 2]
  [8 7 9]
  [1 6 1]]

 [[7 3 1]
  [6 3 7]
  [4 8 1]
  [1 4 3]]

 [[4 2 4]
  [2 4 8]
  [2 8 5]
  [1 6 2]]]
a_expanded :(3, 18, 3), [[[4. 7. 6.]
  [4. 7. 6.]
  [4. 7. 6.]
  [4. 7. 6.]
  [5. 9. 2.]
  [5. 9. 2.]
  [5. 9. 2.]
  [5. 9. 2.]
  [8. 7. 9.]
  [8. 7. 9.]
  [8. 7. 9.]
  [8. 7. 9.]
  [1. 6. 1.]
  [1. 6. 1.]
  [1. 6. 1.]
  [1. 6. 1.]
  [1. 6. 1.]
  [1. 6. 1.]]

 [[7. 3. 1.]
  [7. 3. 1.]
  [7. 3. 1.]
  [7. 3. 1.]
  [6. 3. 7.]
  [6. 3. 7.]
  [6. 3. 7.]
  [6. 3. 7.]
  [4. 8. 1.]
  [4. 8. 1.]
  [4. 8. 1.]
  [4. 8. 1.]
  [1. 4. 3.]
  [1. 4. 3.]
  [1. 4. 3.]
  [1. 4. 3.]
  [1. 4. 3.]
  [1. 4. 3.]]

 [[4. 2. 4.]
  [4. 2. 4.]
  [4. 2. 4.]
  [4. 2. 4.]
  [2. 4. 8.]
  [2. 4. 8.]
  [2. 4. 8.]
  [2. 4. 8.]
  [2. 8. 5.]
  [2. 8. 5.]
  [2. 8. 5.]
  [2. 8. 5.]
  [1. 6. 2.]
  [1. 6. 2.]
  [1. 6. 2.]
  [1. 6. 2.]
  [1. 6. 2.]
  [1. 6. 2.]]]


In [20]:
idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_4_geo', params.coordinates), en_pickle=True)
granular_pop_dict = load_path(osp.join(labels_path, 'granular_pop.pkl'), en_pickle=True)
superop_dict=load_path(osp.join(labels_path, 'superpop.pkl'), en_pickle=True)
rev_gp_dict={v:k for k,v in granular_pop_dict.items()}
rev_sp_dict={v:k for k,v in superop_dict.items()}
def getG3Net_Metrics_Classification(idx2coordinates, model, params, rev_gp_dict, rev_sp_dict):
    g3net_preds, g3net_sp_preds, g3net_gp_preds, curr_vcf_idx, curr_snps=[], [], [], [], []
    y_sp_true, y_gp_true = [], []
    gcd_classSuperpop= {}
    gcd_classGranularpop={}
    acc_classSuperpop= {}
    acc_classGranularpop={}
    model.eval()
    gcd_acc_ByGen=pd.DataFrame(columns=('Gen', 'Samples', 'Gcd'))
    gcd_acc_classSuperpop=pd.DataFrame(columns=('Gen', 'Superpop', 'Samples', 'Gcd', 'Acc'))
    gcd_acc_classGranularpop=pd.DataFrame(columns=('Gen', 'Granularpop', 'Samples', 'Gcd', 'Acc'))
    gcdObj=GcdLoss()
    mask=torch.ones(1,params.n_win,1).to(params.device)
    for i, gen in enumerate([0,2,4,8]):
        curr_vcf_idx = load_path(osp.join(os.environ.get('OUT_PATH'), \
                     'humans/labels/data_id_4_geo/test' ,'gen_' + str(gen) ,'mat_map.npy'))
        curr_snps = load_path(osp.join(os.environ.get('OUT_PATH'), \
                     'humans/labels/data_id_4_geo/test' ,'gen_' + str(gen) ,'mat_vcf_2d.npy'))
        x_data = torch.tensor(curr_snps[:,:params.chmlen], device=params.device).float()
        g3net_results = model(x_data, mask=mask)
        
        # convert from window to snp level
        pred_coord = g3net_results.coord_main.detach().cpu().numpy()
        pred_coord = getWin2Snp(pred_coord, params.win_size, params.chmlen, trueChmLen=317410)
        
        # get true labels
        y_coord, y_superpop, y_granular_pop=getLabelPopsPerSnp(curr_vcf_idx, "ge3net", pop_arr, \
                                                             idx2coordinates, copy.deepcopy(params))
        
        
        g3net_gcd = np.mean(gcdObj.rawGcd(pred_coord, y_coord))
        print(f"\n generation:{gen}")
        print(f"ge3net_gcd:{g3net_gcd}")
        
        gcd_acc_ByGen.loc[len(gcd_acc_ByGen)] = [gen, curr_vcf_idx[-1].shape[0], g3net_gcd]
        del g3net_results, x_data
        torch.cuda.empty_cache()
        
        
    return gcd_acc_ByGen


In [21]:
gcd_acc_ByGen_class=\
getG3Net_Metrics_Classification(idx2coordinates, model, params, rev_gp_dict, rev_sp_dict)


 generation:0
ge3net_gcd:588.9539274666732

 generation:2
ge3net_gcd:736.8628045032455

 generation:4
ge3net_gcd:873.8558979001231

 generation:8
ge3net_gcd:1043.3416924241556


In [55]:
idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_1_geo', params.coordinates), en_pickle=True)
granular_pop_dict = load_path(osp.join(labels_path, 'granular_pop.pkl'), en_pickle=True)
superop_dict=load_path(osp.join(labels_path, 'superpop.pkl'), en_pickle=True)
rev_gp_dict={v:k for k,v in granular_pop_dict.items()}
rev_sp_dict={v:k for k,v in superop_dict.items()}
def getG3Net_Metrics_Classification(idx2coordinates, model, params, rev_gp_dict, rev_sp_dict):
    g3net_preds, g3net_sp_preds, g3net_gp_preds, curr_vcf_idx, curr_snps=[], [], [], [], []
    y_sp_true, y_gp_true = [], []
    gcd_classSuperpop= {}
    gcd_classGranularpop={}
    acc_classSuperpop= {}
    acc_classGranularpop={}
    model.eval()
    gcd_acc_ByGen=pd.DataFrame(columns=('Gen', 'Samples', 'Gcd', 'SpAcc', 'GpAcc'))
    gcd_acc_classSuperpop=pd.DataFrame(columns=('Gen', 'Superpop', 'Samples', 'Gcd', 'Acc'))
    gcd_acc_classGranularpop=pd.DataFrame(columns=('Gen', 'Granularpop', 'Samples', 'Gcd', 'Acc'))
    
    for i, gen in enumerate([0,2,4,8]):
        curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                     'humans/labels/data_id_1_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
        curr_snps.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                     'humans/labels/data_id_1_geo/test' ,'gen_' + str(gen) ,'mat_vcf_2d.npy')))
        g3net_results = model(torch.tensor(curr_snps[-1][:,:params.chmlen], device=params.device).float(), \
                              mask=torch.ones(1,params.n_win,1).to(params.device))
        
        y_pred = torch.argmax(F.log_softmax(g3net_results.coord_main, dim=-1), dim=-1)
        y_pred = y_pred.detach().cpu().numpy()
        
        # get true labels
        y_coord, y_superpop, y_granular_pop=getLabelPops(curr_vcf_idx[-1], "ge3net", pop_arr, \
                                                             idx2coordinates, copy.deepcopy(params))
        y_sp_true.append(y_superpop)
        y_gp_true.append(y_granular_pop)
        
        g3net_gp_preds.append(y_pred)
        g3net_sp_preds.append(np.vectorize(gp2sp.get)(y_pred))
        g3net_sp_acc = np.mean(g3net_sp_preds[-1]==y_superpop)
        g3net_gp_acc = np.mean(g3net_gp_preds[-1]==y_granular_pop)
        g3net_gcd = gp2Gcd(pop_arr, y_pred, y_coord, 2)
        loss = gp2Gcd2(pop_arr, y_pred, y_coord, 2)
        
        
        print(f"\n generation:{gen}")
        print(f"granular pop accuracy:{g3net_gp_acc}")
        print(f"superpop accuracy:{g3net_sp_acc}")
        print(f"ge3net_gcd:{g3net_gcd}")
        
        gcd_acc_ByGen.loc[len(gcd_acc_ByGen)] = [gen, curr_vcf_idx[-1].shape[0], g3net_gcd, g3net_sp_acc, g3net_gp_acc]
        
        superpop_num=np.unique(y_superpop).astype(int)
        for k in superpop_num:
            idx=np.nonzero(y_superpop==k)
            num_samples = len(idx[0])
            if num_samples == 0:
                continue
            if gcd_classSuperpop.get(k) is None: 
                gcd_classSuperpop[k]=Running_Average()
                acc_classSuperpop[k]=Running_Average()
            gcd_classSuperpop[k].update(np.sum(loss[idx[0], idx[1]]), num_samples)
            acc_classSuperpop[k].update(np.sum(g3net_sp_preds[-1][idx[0], idx[1]]==k), num_samples)
            gcd_acc_classSuperpop.loc[len(gcd_acc_classSuperpop)] = [gen, rev_sp_dict[k], num_samples,
                                np.mean(loss[idx[0], idx[1]]), np.mean(g3net_sp_preds[-1][idx[0], idx[1]]==k)]
            
        granularpop_num=np.unique(y_granular_pop).astype(int)
        for k in granularpop_num:
            idx=np.nonzero(y_granular_pop==k)
            num_samples = len(idx[0])
            if num_samples == 0:
                continue
            if gcd_classGranularpop.get(k) is None: 
                gcd_classGranularpop[k]=Running_Average()
                acc_classGranularpop[k]=Running_Average()
            gcd_classGranularpop[k].update(np.sum(loss[idx[0], idx[1]]), num_samples)  
            acc_classGranularpop[k].update(np.sum(g3net_gp_preds[-1][idx[0], idx[1]]==k), num_samples)
            gcd_acc_classGranularpop.loc[len(gcd_acc_classGranularpop)] = [gen, rev_gp_dict[k], num_samples, \
                                                               np.mean(loss[idx[0], idx[1]]), \
                                                               np.mean(g3net_gp_preds[-1][idx[0], idx[1]]==k)]
            
            
    y_sp_true_all = np.vstack(y_sp_true)
    y_gp_true_all = np.vstack(y_gp_true)
    g3net_sp_all = np.vstack(g3net_sp_preds)
    g3net_gp_all = np.vstack(g3net_gp_preds)
    y_vcf_idx = np.vstack(curr_vcf_idx)
    y_coord_all, y_superpop_all, y_granular_pop_all=getLabelPops(y_vcf_idx, "ge3net", \
                                                                 pop_arr, idx2coordinates, copy.deepcopy(params))
    g3net_sp_acc_all = np.mean(g3net_sp_all==y_superpop_all)
    g3net_gp_acc_all = np.mean(g3net_gp_all==y_granular_pop_all) 
    g3net_gcd_all = gp2Gcd(pop_arr, g3net_gp_all, y_coord_all, 2)
    print(f" \n For all the generations:")
    print(f" superpop accuracy:{g3net_sp_acc_all}")
    print(f" granular pop accuracy:{g3net_gp_acc_all}")
    print(f" gcd:{g3net_gcd_all}")
    
    for i, key in enumerate(gcd_classSuperpop.keys()):
        gcd_acc_classSuperpop.loc[len(gcd_acc_classSuperpop)] = [-1, rev_sp_dict[key], \
                                                                 gcd_classSuperpop[key].steps, \
                                        gcd_classSuperpop[key](), acc_classSuperpop[key]()]
    for i, key in enumerate(gcd_classGranularpop.keys()):
        gcd_acc_classGranularpop.loc[len(gcd_acc_classGranularpop)] = [-1, rev_gp_dict[key], \
                                                                       gcd_classGranularpop[key].steps, \
                                           gcd_classGranularpop[key](), acc_classGranularpop[key]()]
        
    return gcd_acc_ByGen, gcd_acc_classSuperpop, gcd_acc_classGranularpop


In [130]:
gcd_acc_ByGen_class, gcd_acc_classSuperpop_class, gcd_acc_classGranularpop_class=getG3Net_Metrics_Classification(idx2coordinates, model, params, rev_gp_dict, rev_sp_dict)


 generation:0
granular pop accuracy:0.6357809248942882
superpop accuracy:0.9769895071257578
ge3net_gcd:638.4868143877925

 generation:2
granular pop accuracy:0.5713801261829653
superpop accuracy:0.9574250788643534
ge3net_gcd:836.1259360992749

 generation:4
granular pop accuracy:0.5161790220820189
superpop accuracy:0.9438446372239747
ge3net_gcd:962.1086060850004

 generation:8
granular pop accuracy:0.4674171924290221
superpop accuracy:0.9203509463722398
ge3net_gcd:1170.2223492864266
 
 For all the generations:
 superpop accuracy:0.9474759149754999
 granular pop accuracy:0.5406752746948663
 gcd:922.6963820820773


In [56]:
def getcps(y_pred):
    
    n_vec_dim=y_pred.shape[-1]
    data_tensor = torch.tensor(y_pred).float()
    batch_size_cpd = data_tensor.shape[0]
    mu_prior = torch.zeros((batch_size_cpd, 1,n_vec_dim))
    mean_var=torch.mean(torch.var(data_tensor, dim =1),dim=0).unsqueeze(0)
    cov_prior = (mean_var.repeat(batch_size_cpd,1).unsqueeze(1)* torch.eye(n_vec_dim)).reshape(batch_size_cpd,1,n_vec_dim,n_vec_dim)
    cov_x = cov_prior
    likelihood_model = Multivariate_Gaussian(mu_prior, cov_prior, cov_x)
    T = params.n_win
    model_cpd = BOCD.BOCD(None, T, likelihood_model, batch_size_cpd)
    _,_,_,_=model_cpd.run_recursive(data_tensor, 'cpu')
    predBOCDSample=model_cpd.cp
    pred_cps_BOCD=getCpPred(cpMethod.BOCD.name, predBOCDSample, 10.0, y_pred.shape[0], y_pred.shape[1])
    
    pred_cps_BOCD=pred_cps_BOCD.detach().cpu().numpy()
    return pred_cps_BOCD

In [57]:
def getCpBasedCentroid(cp_pred, coord_main, labels_path, labelType="granular_pop", distType="L2"):
    """
    Given a batch of haplotypes with changepoints, return the centroid of coordinates per segment
    and the option to return the nearest neighbour mapped label by superpop or granular pop
    of that segment
    Input:
        cp_pred: shape (BxW): batch size x window size matrix of 0 and 1 with 1 for a changepoint
        coord_main: shape (BXWX2) : batch_size x window_size x (lat, long) as predictions
    Output:
        centroid_coord: shape (BXWX2) : batch_size x window_size x (lat, long) as predictions with centroid 
        coordinates
        centroid_label: shape (BXW) : batch_size x window_size as superpop/granular_pop as specified from the
        kwarg label_type for the nearest neighbour class from training samples
    """
    
    last_win=cp_pred.shape[1]
    # mark the last window as changepoint
    cp_pred[:,-1]=1.0
    # get the indices of 1 and stack them up
    cp_idx=np.where(cp_pred==1)
    cp_wins=cp_idx[1]
    # transform the cp indices with the last win size. This will enable to keep track of the haplotype 
    # number and enable to transform back from 1d stacked array to 2d array with batch_size x window size
    transformedCp_idx = cp_idx[0] * last_win + cp_wins
    # transformedCp_idx is a stacked array of segments. Now convert this to a stacked array of windows
    centroid_coord_tmp=np.zeros((len(transformedCp_idx),3))
    centroid_coord=np.zeros_like(coord_main)
    centroid_label = np.zeros_like(cp_pred)
    cpWin_prev=0
    for i, win in enumerate(transformedCp_idx):
        # find haplo num by round down division with the max window
        haplo_num = win // last_win
        # find cp_win by remainder
        cp_win = win % last_win
        centroid_coord_tmp[i,0]=np.mean(coord_main[haplo_num, cpWin_prev:cp_win+1,0])
        centroid_coord_tmp[i,1]=np.mean(coord_main[haplo_num, cpWin_prev:cp_win+1,1])
        centroid_coord_tmp[i,2]=np.mean(coord_main[haplo_num, cpWin_prev:cp_win+1,2])
        centroid_coord[haplo_num, cpWin_prev:cp_win+1, :] = centroid_coord_tmp[i,:]
        cpWin_prev = (cp_win+1)% last_win
    
    if labels_path is not None:
        centroid_label_tmp=nearestNeighbourMapping(labels_path, centroid_coord_tmp, labelType=labelType , distType=distType)

        for i, win in enumerate(transformedCp_idx):
            # find haplo num by round down division with the max window
            haplo_num = win // last_win
            # find cp_win by remainder
            cp_win = win % last_win
            centroid_label[haplo_num, cpWin_prev:cp_win+1] = centroid_label_tmp[i]
            cpWin_prev = (cp_win+1)% last_win

    return centroid_coord, centroid_label

In [58]:
# G3Net - optimized for gcd
idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_1_geo', params.coordinates), en_pickle=True)
granular_pop_dict = load_path(osp.join(labels_path, 'granular_pop.pkl'), en_pickle=True)
superop_dict=load_path(osp.join(labels_path, 'superpop.pkl'), en_pickle=True)
rev_gp_dict={v:k for k,v in granular_pop_dict.items()}
rev_sp_dict={v:k for k,v in superop_dict.items()}
def getG3Net_Metrics(idx2coordinates, model, params, rev_gp_dict, rev_sp_dict):
    g3net_preds, g3net_sp_preds, g3net_gp_preds, curr_vcf_idx, curr_snps=[], [], [], [], []
    y_sp_true, y_gp_true = [], []
    gcd_classSuperpop= {}
    gcd_classGranularpop={}
    acc_classSuperpop= {}
    acc_classGranularpop={}
    model.eval()
    gcd_acc_ByGen=pd.DataFrame(columns=('Gen', 'Samples', 'Gcd', 'SpAcc', 'GpAcc'))
    gcd_acc_classSuperpop=pd.DataFrame(columns=('Gen', 'Superpop', 'Samples', 'Gcd', 'Acc'))
    gcd_acc_classGranularpop=pd.DataFrame(columns=('Gen', 'Granularpop', 'Samples', 'Gcd', 'Acc'))
    
    for i, gen in enumerate([0,2,4,8]):
        curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                     'humans/labels/data_id_1_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
        curr_snps.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                     'humans/labels/data_id_1_geo/test' ,'gen_' + str(gen) ,'mat_vcf_2d.npy')))
        enable_mcDroput=False
        if enable_mcDroput:
            test_dataset = Haplotype(config['data.dataset_type'], params, data_path, labels_path=labels_path)
            test_generator = torch.utils.data.DataLoader(test_dataset, batch_size=params.batch_size, num_workers=0, pin_memory=True)
            PredLs=[]
            for i, data_x in enumerate(test_generator):
                test_result = model._batch_validate_1_step(data_x)
                PredLs.append(torch.stack(test_result.coord_mainLs, dim=0).contiguous().detach().cpu().numpy())
            
            test_result.coord_main=np.concatenate((PredLs), axis=1)
            g3net_results=test_result.coord_main
            y_pred = g3net_results.coord_main.mean(0).detach().cpu().numpy()
        
        else:
            print(f"{curr_snps[-1][:,:params.chmlen].shape}, {params.n_win}, {params.win_size}")
            g3net_results = model(torch.tensor(curr_snps[-1][:,:params.chmlen], device=params.device).float(), \
                                  mask=torch.ones(1,params.n_win,1).to(params.device))
            y_pred = g3net_results.coord_main.detach().cpu().numpy()
        print(f"y_pred shape:{y_pred.shape}")
        
        
        # get true labels
        y_coord, y_superpop, y_granular_pop=getLabelPops(curr_vcf_idx[-1], "ge3net", pop_arr, \
                                                             idx2coordinates, copy.deepcopy(params))
        gcdObj=GcdLoss()
        metricsByCp=False
        GcdByPreds=True
        onlyGCd=False
        if metricsByCp:
            cp_pred = getcps(y_pred)
            centroid_coord, mappedGpArr = getCpBasedCentroid(cp_pred, y_pred, labels_path)
            mappedSpArr = np.vectorize(gp2sp.get)(mappedGpArr)
            loss=gcdObj.rawGcd(centroid_coord, y_coord)
            y_pred=centroid_coord
            
        else:
            if not onlyGCd:
                pred_by_win = y_pred.reshape(-1,params.dataset_dim)
                mappedGpArr = nearestNeighbourMapping(labels_path, pred_by_win, labelType="granular_pop")
                mappedGpArr = mappedGpArr.reshape(-1,params.n_win)
                mappedSpArr = nearestNeighbourMapping(labels_path, pred_by_win, labelType="superpop")
                mappedSpArr = mappedSpArr.reshape(-1,params.n_win)

            if GcdByPreds:
                loss=gcdObj.rawGcd(y_pred, y_coord)

            else:
                # this is alt loss based on GP mapped using NN
                loss=gp2Gcd2(pop_arr, mappedGpArr, y_coord, params.dataset_dim) 
        
        y_sp_true.append(y_superpop)
        y_gp_true.append(y_granular_pop)
        g3net_preds.append(y_pred)
        g3net_gcd=np.mean(loss)
        if not onlyGCd:
            g3net_sp_preds.append(mappedSpArr)
            g3net_gp_preds.append(mappedGpArr)
            g3net_sp_acc = np.mean(g3net_sp_preds[-1]==y_superpop)
            g3net_gp_acc = np.mean(g3net_gp_preds[-1]==y_granular_pop)
            gcd_acc_ByGen.loc[len(gcd_acc_ByGen)] = [gen, curr_vcf_idx[-1].shape[0], g3net_gcd, g3net_sp_acc, g3net_gp_acc]
        else:
            gcd_acc_ByGen.loc[len(gcd_acc_ByGen)] = [gen, curr_vcf_idx[-1].shape[0], g3net_gcd, 0, 0]
        
        superpop_num=np.unique(y_superpop).astype(int)
        granularpop_num=np.unique(y_granular_pop).astype(int)
        
        if not onlyGCd:
            for k in superpop_num:
                idx=np.nonzero(y_superpop==k)
                num_samples = len(idx[0])
                if num_samples == 0:
                    continue
                if gcd_classSuperpop.get(k) is None: 
                    gcd_classSuperpop[k]=Running_Average()
                    acc_classSuperpop[k]=Running_Average()
                gcd_classSuperpop[k].update(np.sum(loss[idx[0], idx[1]]), num_samples)
                acc_classSuperpop[k].update(np.sum(g3net_sp_preds[-1][idx[0], idx[1]]==k), num_samples)
                gcd_acc_classSuperpop.loc[len(gcd_acc_classSuperpop)] = [gen, rev_sp_dict[k], num_samples,
                                                                     np.mean(loss[idx[0], idx[1]]),\
                                                                    np.mean(g3net_sp_preds[-1][idx[0], idx[1]]==k)]

            for k in granularpop_num:
                idx=np.nonzero(y_granular_pop==k)
                num_samples = len(idx[0])
                if num_samples == 0:
                    continue
                if gcd_classGranularpop.get(k) is None: 
                    gcd_classGranularpop[k]=Running_Average()
                    acc_classGranularpop[k]=Running_Average()
                gcd_classGranularpop[k].update(np.sum(loss[idx[0], idx[1]]), num_samples)  
                acc_classGranularpop[k].update(np.sum(g3net_gp_preds[-1][idx[0], idx[1]]==k), num_samples)
                gcd_acc_classGranularpop.loc[len(gcd_acc_classGranularpop)] = [gen, rev_gp_dict[k], num_samples, \
                                                                   np.mean(loss[idx[0], idx[1]]), \
                                                                   np.mean(g3net_gp_preds[-1][idx[0], idx[1]]==k)]
            
        del loss
        
    y_sp_true_all = np.vstack(y_sp_true)
    y_gp_true_all = np.vstack(y_gp_true)
    g3net_preds_all = np.vstack(g3net_preds)
    y_vcf_idx = np.vstack(curr_vcf_idx)
    y_coord_all, y_superpop_all, y_granular_pop_all=getLabelPops(y_vcf_idx, "ge3net", \
                                                                 pop_arr, idx2coordinates, copy.deepcopy(params))
    print(f" \n For all the generations:")
    
    if not onlyGCd: 
        g3net_sp_all = np.vstack(g3net_sp_preds)
        g3net_gp_all = np.vstack(g3net_gp_preds)
        g3net_sp_acc_all = np.mean(g3net_sp_all==y_superpop_all)
        g3net_gp_acc_all = np.mean(g3net_gp_all==y_granular_pop_all)
        print(f" superpop accuracy:{g3net_sp_acc_all}")
        print(f" granular pop accuracy:{g3net_gp_acc_all}")
        for i, key in enumerate(gcd_classSuperpop.keys()):
            gcd_acc_classSuperpop.loc[len(gcd_acc_classSuperpop)] = [-1, rev_sp_dict[key], \
                                                                     gcd_classSuperpop[key].steps, \
                                            gcd_classSuperpop[key](), acc_classSuperpop[key]()]

        for i, key in enumerate(gcd_classGranularpop.keys()):
            gcd_acc_classGranularpop.loc[len(gcd_acc_classGranularpop)] = [-1, rev_gp_dict[key], \
                                                                           gcd_classGranularpop[key].steps, \
                                               gcd_classGranularpop[key](), acc_classGranularpop[key]()]
        del gcd_classSuperpop, gcd_classGranularpop, acc_classGranularpop, acc_classSuperpop
    
    gcdObj=GcdLoss()
    g3net_gcd_all = np.mean(gcdObj.rawGcd(g3net_preds_all, y_coord_all))
    print(f" gcd:{g3net_gcd_all}")

    return gcd_acc_ByGen, gcd_acc_classSuperpop, gcd_acc_classGranularpop, g3net_sp_all, g3net_gp_all, \
            y_sp_true_all, y_gp_true_all
    

In [59]:
def gp2Gcd2(pop_arr, gpPreds, y_coord, dim):
    idx2gp={k:v for k,v in zip(pop_arr[:,1], pop_arr[:,2])}
    gp2idx={k:v for v,k in idx2gp.items()}
    gp2coordinates={i:idx2coordinates[j] for i,j in gp2idx.items()}
    pred_coord = mapping_func(gpPreds, gp2coordinates, dim)
    gcdObj=GcdLoss()
    gnomix_gcd=gcdObj.rawGcd(pred_coord, y_coord)
    return gnomix_gcd

In [60]:
gcd_acc_ByGen, gcd_acc_classSuperpop, gcd_acc_classGranularpop, \
    g3net_sp_all, g3net_gp_all,y_superpop, y_granular_pop = \
    getG3Net_Metrics(idx2coordinates, model, params, rev_gp_dict, rev_sp_dict)
print(f" gcd_acc_ByGen accuracy:{gcd_acc_ByGen}")
# print(f" gcd_acc_classSuperpop accuracy:{gcd_acc_classSuperpop}")
# print(f" gcd_acc_classGranularpop accuracy:{gcd_acc_classGranularpop}")

(564, 317000), 634, 500
y_pred shape:(564, 634, 3)
Finished 'nearestNeighbourMapping' in 49.0742 secs
Finished 'nearestNeighbourMapping' in 49.0439 secs
(800, 317000), 634, 500
y_pred shape:(800, 634, 3)
Finished 'nearestNeighbourMapping' in 75.3704 secs
Finished 'nearestNeighbourMapping' in 74.6345 secs
(800, 317000), 634, 500
y_pred shape:(800, 634, 3)
Finished 'nearestNeighbourMapping' in 75.7987 secs
Finished 'nearestNeighbourMapping' in 75.7819 secs
(800, 317000), 634, 500
y_pred shape:(800, 634, 3)
Finished 'nearestNeighbourMapping' in 76.2164 secs
Finished 'nearestNeighbourMapping' in 75.8730 secs
 
 For all the generations:
 superpop accuracy:0.9273607155476655
 granular pop accuracy:0.27359544821772946
 gcd:1036.0143293060185
 gcd_acc_ByGen accuracy:   Gen  Samples          Gcd     SpAcc     GpAcc
0  0.0    564.0   833.230847  0.953669  0.331216
1  2.0    800.0   972.801258  0.938078  0.286088
2  4.0    800.0  1054.150112  0.925235  0.262261
3  8.0    800.0  1224.053973  0.900

In [61]:
pd.set_option('display.max_rows', None)
gcd_acc_classGranularpop

Unnamed: 0,Gen,Granularpop,Samples,Gcd,Acc
0,0,British,11412,761.422327,0.480459
1,0,Finnish,12680,587.887068,0.704968
2,0,Southern Han Chinese,12680,721.143011,0.090615
3,0,Dai Chinese,11412,925.746823,0.260778
4,0,Spanish,13948,1175.835541,0.042013
5,0,Peruvian,3804,2272.710375,0.65694
6,0,Punjabi,12680,1191.912459,0.246924
7,0,Kinh Vietnamese,12680,1326.016406,0.180915
8,0,Gambian Mandinka,13948,680.012556,0.011471
9,0,Esan,12680,617.86645,0.068139


In [121]:
from sklearn.metrics import confusion_matrix
np.set_printoptions(threshold=np.inf)
print(f"y_granular_pop :{y_granular_pop.shape}, y_granular_pop_pred :{g3net_gp_all.shape}")
gp_cm = confusion_matrix(y_granular_pop.reshape(-1,1), g3net_gp_all.reshape(-1,1))

y_granular_pop :(2964, 317), y_granular_pop_pred :(2964, 317)


In [179]:
gp_cm[-1,:].shape, len(g3net_gp_all[g3net_gp_all>71.0]), len(np.unique(y_granular_pop)), len(np.unique(g3net_gp_all))

((125,), 125845, 69, 123)

In [23]:
pd.set_option('display.max_rows', None)
gcd_acc_classGranularpop

Unnamed: 0,Gen,Granularpop,Samples,Gcd,Acc
0,0,British,5706,671.928936,0.451104
1,0,Finnish,6340,563.172199,0.679495
2,0,Southern Han Chinese,6340,539.866644,0.246372
3,0,Dai Chinese,5706,863.095733,0.296179
4,0,Spanish,6974,842.046628,0.049039
5,0,Peruvian,1902,1701.903645,0.821767
6,0,Punjabi,6340,1494.439866,0.130915
7,0,Kinh Vietnamese,6340,1229.183488,0.343691
8,0,Gambian Mandinka,6974,484.901935,0.200029
9,0,Esan,6340,394.661268,0.230915


In [134]:
# Gnomix
idx2coordinates = load_path(osp.join(os.environ.get('OUT_PATH'),'humans/labels/data_id_1_geo', params.coordinates), en_pickle=True)
granular_pop_dict = load_path(osp.join(labels_path, 'granular_pop.pkl'), en_pickle=True)
superop_dict=load_path(osp.join(labels_path, 'superpop.pkl'), en_pickle=True)
rev_gp_dict={v:k for k,v in granular_pop_dict.items()}
rev_sp_dict={v:k for k,v in superop_dict.items()}
gnomix_preds, gnomix_sp_preds, curr_vcf_idx=[], [], []
gcd_classSuperpop_Gnomix= {}
gcd_classGranularpop_Gnomix={}
acc_classSuperpop_Gnomix= {}
acc_classGranularpop_Gnomix={}
gcd_acc_ByGen_Gnomix=pd.DataFrame(columns=('Gen', 'Samples', 'Gcd', 'SpAcc', 'GpAcc'))
gcd_acc_classSuperpop_Gnomix=pd.DataFrame(columns=('Gen', 'Superpop', 'Samples', 'Gcd', 'Acc'))
gcd_acc_classGranularpop_Gnomix=pd.DataFrame(columns=('Gen', 'Granularpop', 'Samples', 'Gcd', 'Acc'))

for i, gen in enumerate([0,2,4,8]):
    curr_vcf_idx.append(load_path(osp.join(os.environ.get('OUT_PATH'), \
                 'humans/labels/data_id_1_geo/test' ,'gen_' + str(gen) ,'mat_map.npy')))
    gnomix_preds.append(load_path(osp.join(os.environ.get('OUT_PATH') ,'humans/benchmark/data_id_1_geo/test/window_size_857',\
                        'gen_' + str(gen) +'.npy')))
    gnomix_sp_preds.append(np.vectorize(gp2sp.get)(gnomix_preds[-1]))
    print(f" Metrics for generation: {gen}")
    print(f' number of samples :{curr_vcf_idx[-1].shape[0]}')
    y_coord, y_superpop, y_granular_pop=getLabelPops(curr_vcf_idx[-1], "gnomix", pop_arr, idx2coordinates, copy.deepcopy(params))
    gnomix_sp_acc = np.mean(gnomix_sp_preds[-1]==y_superpop)
    gnomix_gp_acc = np.mean(gnomix_preds[-1]==y_granular_pop)
    gnomix_gcd = gp2Gcd2(pop_arr, gnomix_preds[-1], y_coord, params.dataset_dim)
    print(f" superpop accuracy:{gnomix_sp_acc}")
    print(f" granular pop accuracy:{gnomix_gp_acc}")
    print(f" gcd:{np.mean(gnomix_gcd)}")

    gcd_acc_ByGen_Gnomix.loc[i] = [gen, curr_vcf_idx[-1].shape[0], np.mean(gnomix_gcd), gnomix_sp_acc, gnomix_gp_acc]
        
    superpop_num=np.unique(y_superpop).astype(int)
    granularpop_num=np.unique(y_granular_pop).astype(int)
        
    for k in superpop_num:
        idx=np.nonzero(y_superpop==k)
        num_samples = len(idx[0])
        if num_samples == 0:
            continue
        if gcd_classSuperpop_Gnomix.get(k) is None: 
            gcd_classSuperpop_Gnomix[k]=Running_Average()
            acc_classSuperpop_Gnomix[k]=Running_Average()
        gcd_classSuperpop_Gnomix[k].update(np.sum(gnomix_gcd[idx[0], idx[1]]), num_samples)
        acc_classSuperpop_Gnomix[k].update(np.sum(gnomix_sp_preds[-1][idx[0], idx[1]]==k), num_samples)
        gcd_acc_classSuperpop_Gnomix.loc[len(gcd_acc_classSuperpop_Gnomix)] = [gen, rev_sp_dict[k], num_samples, \
                                    np.sum(gnomix_gcd[idx[0], idx[1]]), np.sum(gnomix_sp_preds[-1][idx[0], idx[1]]==k)]
        
    for k in granularpop_num:
        idx=np.nonzero(y_granular_pop==k)
        num_samples = len(idx[0])
        if num_samples == 0:
            continue
        if gcd_classGranularpop_Gnomix.get(k) is None: 
            gcd_classGranularpop_Gnomix[k]=Running_Average()
            acc_classGranularpop_Gnomix[k]=Running_Average()
        gcd_classGranularpop_Gnomix[k].update(np.sum(gnomix_gcd[idx[0], idx[1]]), num_samples)  
        acc_classGranularpop_Gnomix[k].update(np.sum(gnomix_preds[-1][idx[0], idx[1]]==k), num_samples)
        gcd_acc_classGranularpop_Gnomix.loc[len(gcd_acc_classGranularpop_Gnomix)] = [gen, rev_gp_dict[k], num_samples, \
                                       np.sum(gnomix_gcd[idx[0], idx[1]]), np.sum(gnomix_preds[-1][idx[0], idx[1]]==k)]
    
gnomix_preds_all = np.vstack(gnomix_preds)
gnomix_sp_all = np.vstack(gnomix_sp_preds)
y_vcf_idx = np.vstack(curr_vcf_idx)
y_coord_all, y_superpop_all, y_granular_pop_all=getLabelPops(y_vcf_idx, "gnomix", pop_arr, idx2coordinates, copy.deepcopy(params))
gnomix_sp_acc_all = np.mean(gnomix_sp_all==y_superpop_all)
gnomix_gp_acc_all = np.mean(gnomix_preds_all==y_granular_pop_all)
gnomix_gcd_all = gp2Gcd(pop_arr, gnomix_preds_all, y_coord_all, params.dataset_dim)
print(f" \n For all the generations:")
print(f" superpop accuracy:{gnomix_sp_acc_all}")
print(f" granular pop accuracy:{gnomix_gp_acc_all}")
print(f" gcd:{gnomix_gcd_all}")
for i, key in enumerate(gcd_classSuperpop_Gnomix.keys()):
    gcd_acc_classSuperpop_Gnomix.loc[len(gcd_acc_classSuperpop_Gnomix)] = [-1, rev_sp_dict[key], gcd_classSuperpop_Gnomix[key].steps, \
                                    gcd_classSuperpop_Gnomix[key](), acc_classSuperpop_Gnomix[key]()]
for i, key in enumerate(gcd_classGranularpop_Gnomix.keys()):
    gcd_acc_classGranularpop_Gnomix.loc[len(gcd_acc_classGranularpop_Gnomix)] = [-1, rev_gp_dict[key], gcd_classGranularpop_Gnomix[key].steps, \
                                       gcd_classGranularpop_Gnomix[key](), acc_classGranularpop_Gnomix[key]()]
del gcd_classSuperpop_Gnomix, gcd_classGranularpop_Gnomix, acc_classGranularpop_Gnomix, acc_classSuperpop_Gnomix

print(f" gcd_acc_ByGen accuracy:{gcd_acc_ByGen_Gnomix}")
print(f" gcd_acc_classSuperpop accuracy:{gcd_acc_classSuperpop_Gnomix}")
print(f" gcd_acc_classGranularpop accuracy:{gcd_acc_classGranularpop_Gnomix}")

 Metrics for generation: 0
 number of samples :564
 superpop accuracy:0.9649032010734139
 granular pop accuracy:0.6007906843013227
 gcd:731.6905443983123
 Metrics for generation: 2
 number of samples :800
 superpop accuracy:0.9567195945945945
 granular pop accuracy:0.5746587837837838
 gcd:833.7782390661861
 Metrics for generation: 4
 number of samples :800
 superpop accuracy:0.9467804054054054
 granular pop accuracy:0.5528445945945946
 gcd:897.6635521956672
 Metrics for generation: 8
 number of samples :800
 superpop accuracy:0.9251013513513513
 granular pop accuracy:0.5177128378378378
 gcd:1050.6644337416844
 
 For all the generations:
 superpop accuracy:0.9470602181128497
 granular pop accuracy:0.558373454426086
 gcd:890.1344288270847
 gcd_acc_ByGen accuracy:   Gen  Samples          Gcd     SpAcc     GpAcc
0  0.0    564.0   731.690544  0.964903  0.600791
1  2.0    800.0   833.778239  0.956720  0.574659
2  4.0    800.0   897.663552  0.946780  0.552845
3  8.0    800.0  1050.664434  0.9

In [135]:
pd.set_option('display.max_rows', None)
gcd_acc_classGranularpop_Gnomix

Unnamed: 0,Gen,Granularpop,Samples,Gcd,Acc
0,0,British,6660,6357049.0,2864.0
1,0,Finnish,7400,3193883.0,6301.0
2,0,Southern Han Chinese,7400,8352919.0,2683.0
3,0,Dai Chinese,6660,4657200.0,3862.0
4,0,Spanish,8140,5222231.0,4966.0
5,0,Peruvian,2220,3528151.0,1834.0
6,0,Punjabi,7400,10549410.0,2982.0
7,0,Kinh Vietnamese,7400,8109880.0,3877.0
8,0,Gambian Mandinka,8140,1548212.0,7026.0
9,0,Esan,7400,2935274.0,3574.0


In [203]:
torch.cuda.empty_cache()

In [36]:
rev_gp_dict

{0: 'British',
 1: 'Finnish',
 2: 'Southern Han Chinese',
 3: 'Dai Chinese',
 4: 'Spanish',
 5: 'Peruvian',
 6: 'Punjabi',
 7: 'Kinh Vietnamese',
 8: 'Gambian Mandinka',
 9: 'Esan',
 10: 'Bengali',
 11: 'Mende',
 12: 'Sri Lankan',
 13: 'Indian Telugu',
 14: 'Yoruba',
 15: 'Han Chinese',
 16: 'Japanese',
 17: 'Luhya',
 18: 'Mexican-American',
 19: 'Tuscan',
 20: 'Gujarati',
 21: 'Brahui',
 22: 'Balochi',
 23: 'Hazara',
 24: 'Makrani',
 25: 'Sindhi',
 26: 'Pathan',
 27: 'Kalash',
 28: 'Burusho',
 29: 'Mbuti',
 30: 'Biaka',
 31: 'Bougainville',
 32: 'French',
 33: 'PapuanSepik',
 34: 'PapuanHighlands',
 35: 'Druze',
 36: 'Bedouin',
 37: 'Sardinian',
 38: 'Palestinian',
 39: 'Colombian',
 40: 'Cambodian',
 41: 'Han',
 42: 'Orcadian',
 43: 'Surui',
 44: 'Maya',
 45: 'Russian',
 46: 'Mandenka',
 47: 'Yakut',
 48: 'San',
 49: 'BantuSouthAfrica',
 50: 'Karitiana',
 51: 'Pima',
 52: 'Tujia',
 53: 'Bergamo_Italian',
 54: 'Yi',
 55: 'Miao',
 56: 'Oroqen',
 57: 'Daur',
 58: 'Mongolian',
 59: 'Hezh