In [1]:
import torch
import sys
import os
import os.path as osp
import pandas as pd
import numpy as np
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import ticker
import matplotlib as mpl
from matplotlib.colors import LogNorm
import matplotlib.patches as mpatches
from sklearn.metrics import precision_recall_curve, recall_score, precision_score, \
balanced_accuracy_score, accuracy_score
from sklearn import metrics
import math

In [2]:
# Register the repository, scratch, input, output and log locations as
# environment variables in a single call.
os.environ.update({
    'USER_PATH': '/home/users/richras/Ge2Net_Repo',
    'USER_SCRATCH_PATH': "/scratch/users/richras",
    'IN_PATH': '/scratch/groups/cdbustam/richras/data_in',
    'OUT_PATH': '/scratch/groups/cdbustam/richras/data_out',
    'LOG_PATH': '/scratch/groups/cdbustam/richras/logs/',
})

In [3]:
# Run the rest of the notebook from the directory named by USER_PATH so the
# `src.*` imports below resolve.
os.chdir(os.getenv('USER_PATH'))

In [17]:
%load_ext autoreload
%autoreload 2
from src.utils.dataUtil import load_path, save_file, vcf2npy
from src.utils.modelUtil import Params, load_model, convert_coordinates
from src.utils.labelUtil import getSuperpopBins, repeat_pop_arr
from src.utils.decorators import timer
from src.models.modelSelection import modelSelect
from src.models.modelParamsSelection import Selections
from src.models import Model_A, Model_B, Model_C, BOCD
from src.models.distributions import Multivariate_Gaussian
from src.main.evaluation import eval_cp_batch, reportChangePointMetrics, t_prMetrics, cpMethod, eval_cp_matrix, \
getCpPred
from src.main.settings_model import parse_args
from src.models.Ge3Net import Ge3NetBase
import test

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load the trained model and evaluate it on the validation data

In [26]:
# Dataset and model to evaluate: chm22 PCA full dataset and the matching model.
dataset_type = 'valid'

# Labels and data are read from the same directory under OUT_PATH.
labels_path = osp.join(os.environ['OUT_PATH'], 'humans/labels/data_id_1_geo')
data_path = labels_path

# Directory of the trained model (trailing slash kept as in the stored path).
models_path = osp.join(
    os.environ['OUT_PATH'],
    'humans/training/Model_B_exp_id_18_data_id_1_geo/',
)

In [27]:
# Settings handed to the evaluation entry point.
config = {
    'data.labels': labels_path,
    'data.dir': data_path,
    'models.dir': models_path,
    'data.dataset_type': dataset_type,
    'cuda': 'cuda',
    'model.loadBest': True,
}

# Build the Params object from the params.json stored in the model directory.
json_path = osp.join(config['models.dir'], 'params.json')
assert osp.isfile(json_path), "No json configuration file found at {}".format(json_path)
params = Params(json_path)

# Notebook-specific overrides applied on top of the stored parameters.
params.rtnOuts = True
params.mc_dropout = False
params.mc_samples = 100
params.cp_tol = 0

# Run the evaluation; the returned objects are used by every cell below.
results, test_dataset, model = test.main(config, params)

 device used: cuda
Loading the datasets...
Finished '_geoConvertLatLong2nVec' in 0.0750 secs
Finished 'mapping_func' in 0.3573 secs
Finished 'pop_mapping' in 0.1592 secs
Finished 'pop_mapping' in 0.1578 secs
Finished 'transform_data' in 40.2787 secs
Finished '__init__' in 111.2466 secs
Parameter count for model AuxNetwork:31747503
Parameter count for model BiRNN:60355
Parameter count for model logits_Block:2289
Total parameters:2289
best val loss metrics : {'l1_loss': 0.2579494377707586, 'mse': 0.059578730873031095, 'smooth_l1': 0.02974117295702163, 'weighted_loss': 0.2579494377707586, 'loss_main': tensor(1207.7690, device='cuda:0'), 'loss_aux': tensor(2491.7239, device='cuda:0')}
at epoch : 71
train loss metrics: {'l1_loss': 0.17245596720737696, 'mse': 0.0362445987646263, 'smooth_l1': 0.01810576623101305, 'weighted_loss': 0.17245596720737696, 'loss_main': 768.0802761591885, 'loss_aux': 2262.628921121907}
best val cp metrics : {'loss_cp': tensor(0.0896, device='cuda:0'), 'prMetrics': O


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Finished '_evaluateAccuracy' in 0.0277 secs
Finished 'getBalancedClassGcd' in 0.0288 secs
Finished 'getExtraGcdMetrics' in 0.0292 secs
Finished '_evaluate' in 0.0573 secs
Finished '_evaluateAccuracy' in 0.0049 secs
Finished 'getBalancedClassGcd' in 0.0284 secs
Finished 'getExtraGcdMetrics' in 0.0287 secs
Finished '_evaluate' in 0.0339 secs
Finished '_evaluateAccuracy' in 0.0050 secs
Finished 'getBalancedClassGcd' in 0.0300 secs
Finished 'getExtraGcdMetrics' in 0.0303 secs
Finished '_evaluate' in 0.0358 secs
Finished '_evaluateAccuracy' in 0.0049 secs
Finished 'getBalancedClassGcd' in 0.0338 secs
Finished 'getExtraGcdMetrics' in 0.0341 secs
Finished '_evaluate' in 0.0394 secs
Finished '_evaluateAccuracy' in 0.0049 secs
Finished 'getBalancedClassGcd' in 0.0593 secs
Finished 'getExtraGcdMetrics' in 0.0596 secs
Finished '_evaluate' in 0.0651 secs
Finished '_evaluateAccuracy' in 0.0051 secs
Finished 'getBalancedClassGcd' in 0.0595 secs
Finished 'getExtraGcdMetrics' in 0.0598 secs
Finished '

In [28]:
def load_model(model_path, model_init, optimizer=None):
    """Load a checkpoint into ``model_init`` and print the metrics stored with it.

    NOTE: defining this function in the notebook replaces the ``load_model``
    imported from ``src.utils.modelUtil`` for the rest of the session.

    Parameters
    ----------
    model_path : str
        Path of the checkpoint file passed to ``torch.load``.
    model_init : object exposing ``load_state_dict``
        Freshly constructed model that receives the stored weights.
    optimizer : optional
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    The same ``model_init`` object, after ``load_state_dict`` was applied.

    Raises
    ------
    FileNotFoundError
        If ``model_path`` does not exist.
    KeyError
        If the checkpoint lacks one of the metric entries printed below.
    """
    if not osp.exists(model_path):
        print(f'{model_path} does not exist')
        # Raising a plain string is itself a TypeError ("exceptions must derive
        # from BaseException"), so raise a real exception carrying the message.
        raise FileNotFoundError(f'{model_path} does not exist')

    checkpoint = torch.load(model_path)

    print(f"best val loss metrics : {checkpoint['val_accr']['t_accr']}")
    print(f"at epoch : {checkpoint['epoch']}")
    print(f"train loss metrics: {checkpoint['train_accr']['t_accr']}")

    print(f"best val cp metrics : {checkpoint['val_accr']['t_cp_accr']}")
    print(f"train cp metrics: {checkpoint['train_accr']['t_cp_accr']}")

    print(f"best val sp metrics : {checkpoint['val_accr']['t_sp_accr']}")
    print(f"train sp metrics: {checkpoint['train_accr']['t_sp_accr']}")

    print(f"best val balanced gcd metrics : {checkpoint['val_accr']['t_balanced_gcd']}")
    print(f"train balanced gcd metrics: {checkpoint['train_accr']['t_balanced_gcd']}")

    # Prints the complete state dict; the output can be very long.
    print(f"checkpoint['model_state_dict']:{checkpoint['model_state_dict']}")
    model_init.load_state_dict(checkpoint['model_state_dict'])

    return model_init

In [None]:
# model_path = osp.join(config['models.dir'], 'models_dir')
# modelOption=modelSelect.get_selection()
# option = Selections.get_selection()
# criterion = option['loss'][params.criteria](reduction='sum', alpha=params.criteria_alpha, geography=params.geography)
# cp_criterion=option['cpMetrics']['loss_cp']
# model_init = modelOption['models'][params.model](params, criterion, cp_criterion)
# if config['model.loadBest']:
#     model = load_model(''.join([str(model_path),'/best.pt']), model_init)
# else:
#     model= load_model(''.join([str(model_path),'/last.pt']), model_init)


In [None]:
# Move the model to params.device, then display whether its first parameter
# is on a CUDA device.
model.to(params.device)
next(model.parameters()).is_cuda

In [None]:
# Display t_accr and t_cp_accr from the evaluation results
# (presumably the loss metrics and the change-point metrics -- confirm).
results.t_accr, results.t_cp_accr

In [None]:
# Display t_balanced_gcd from the evaluation results.
results.t_balanced_gcd

In [25]:
# Average over the leading axis of the predictions (presumably the MC-dropout
# samples -- TODO confirm). Every later cell treats coord_main as 3-D (indexed
# by sample, then window, then coordinate), so only reduce when the extra
# leading axis is present. This keeps the cell idempotent: re-running it, or
# running it on output that is already 3-D, no longer averages away the
# sample axis.
if results.t_out.coord_main.ndim == 4:
    results.t_out.coord_main = results.t_out.coord_main.mean(0)

In [None]:
# Set up the BOCD change-point model and run it over the predicted coordinates.
y_pred = results.t_out.coord_main
# Size of the last axis of the predictions.
n_vec_dim=y_pred.shape[-1]
data_tensor = torch.tensor(y_pred).float()
# Axis 0 of the prediction tensor is used as the BOCD batch size.
batch_size_cpd = data_tensor.shape[0]
# Zero prior mean of shape (batch_size_cpd, 1, n_vec_dim).
mu_prior = torch.zeros((batch_size_cpd, 1,n_vec_dim))
# Variance along axis 1, averaged over axis 0, with a leading singleton axis added.
mean_var=torch.mean(torch.var(data_tensor, dim =1),dim=0).unsqueeze(0)
# mean_var repeated per batch element and multiplied by the identity matrix,
# then reshaped to (batch_size_cpd, 1, n_vec_dim, n_vec_dim).
cov_prior = (mean_var.repeat(batch_size_cpd,1).unsqueeze(1)* torch.eye(n_vec_dim)).reshape(batch_size_cpd,1,n_vec_dim,n_vec_dim)
# The same tensor object is passed as both covariance arguments below.
cov_x = cov_prior
likelihood_model = Multivariate_Gaussian(mu_prior, cov_prior, cov_x)
T = params.n_win
model_cpd = BOCD.BOCD(None, T, likelihood_model, batch_size_cpd)
# The four return values are discarded; later cells read model_cpd.cp instead.
_,_,_,_=model_cpd.run_recursive(data_tensor, 'cpu')

In [None]:
# Label dictionaries and the population/sample map stored in the labels directory.
granular_pop_dict = load_path(osp.join(labels_path, 'granular_pop.pkl'), en_pickle=True)
superop_dict = load_path(osp.join(labels_path, 'superpop.pkl'), en_pickle=True)
pop_sample_map = pd.read_csv(osp.join(labels_path, 'pop_sample_map.tsv'), sep="\t")
pop_arr = repeat_pop_arr(pop_sample_map)

# Change-point targets of the evaluated dataset; axis 1 gives the sequence length.
cp_target = test_dataset.data['cps']
seqlen = cp_target.shape[1]

# Inverse lookup of granular_pop_dict: value -> key.
rev_pop_dict = dict(zip(granular_pop_dict.values(), granular_pop_dict.keys()))

In [None]:
# Sample indices looked at so far (uncomment one to reproduce):
# index = 2500
# index = 2620
# index = 2650
# index = 3000
# index = 3100
# index = 2300
# index = 2320
# index = 2338  # interesting Mansi pop
# index = 2339  # nice example showing Karitiana
# index = 2344  # interesting clusters of Mongolia
# index = 2346  # interesting African diversity
# index = 2355  # interesting Iranian sample
# index = 2366  # Piapioco and Pima both
# index = 2388  # interesting example
# index = 2398  # Biaka/Luhya
# index = 2403  # Quechua
# index = 2414  # incorrect Khomani San
# index = 2421  # Uyghur
# index = 2465  # Balochi, similar to Makrani (slave trade)
# index = 2482  # interesting Indo-European
# index = np.random.choice(idxOfInterest)
# index = 830
index = 510
print(index)

# Per-sample views: change-point targets, predicted and labelled coordinates,
# and the granular population ids translated back to names.
true_cpsSample = cp_target[index].detach().cpu().numpy()
y_predSample = results.t_out.coord_main[index]
y_trueSample = test_dataset.data['y'][index].detach().cpu().numpy()
granularpopSample = test_dataset.data['granular_pop'][index].detach().cpu().numpy()
namesSample = [rev_pop_dict[pop_id] for pop_id in granularpopSample.astype(int)]

In [None]:
# Display the shape of the selected sample's predictions.
y_predSample.shape

In [None]:
# Gradient-based change points for the selected sample, converted to a numpy
# array after squeezing out the batch dimension of 1.
pred_cps = (
    getCpPred(cpMethod.gradient.name, y_predSample, 0.1, 1, len(true_cpsSample))
    .squeeze(0)
    .detach()
    .cpu()
    .numpy()
)

In [None]:
# BOCD output for the selected sample, turned into change points and converted
# to a numpy array after squeezing out the batch dimension of 1.
predBOCDSample = model_cpd.cp[index]
pred_cps_BOCD = (
    getCpPred(cpMethod.BOCD.name, predBOCDSample, 5.0, 1, len(true_cpsSample))
    .squeeze(0)
    .detach()
    .cpu()
    .numpy()
)

In [None]:
# Mark remWin windows on each side of every predicted BOCD change point as
# change points too. The slice end is exclusive, so it must be i+remWin+1 for
# the band to be symmetric around i (the previous end, i+remWin, covered three
# windows before the change point but only two after it).
# NOTE: pred_cps_BOCD is modified in place, so re-running this cell widens the
# already widened mask again; re-run the BOCD prediction cell first.
cpIdx = np.nonzero(pred_cps_BOCD)[0]
remWin = 3
for i in cpIdx:
    low = max(0, i - remWin)
    high = min(seqlen, i + remWin + 1)
    pred_cps_BOCD[low:high] = 1

In [None]:
# Map the sample's predicted coordinates (flattened to rows of 3 values) to
# superpopulation bins; the result colours the predicted strip in plot_sample.
mappedSpArr=getSuperpopBins(labels_path, y_predSample.reshape(-1,3))
# mappedSpArr=mappedSpArr.squeeze(1)

In [None]:
# Labelled coordinates and superpopulation labels of the whole dataset,
# flattened to one row (resp. one entry) per window.
y_predsTrue = test_dataset.data['y'].detach().cpu().numpy().reshape(-1, 3)
superpopsTrue = test_dataset.data['superpop'].detach().cpu().numpy().reshape(-1)

In [None]:
@timer
def plot_sample(granularPopSample, y_predSample, y_trueSample, **kwargs):
    """Plot one sample: 3-D predicted vs. labelled coordinates plus two chromosome strips.

    Parameters
    ----------
    granularPopSample : sequence
        One granular-population label per window; drives colours and the legend.
    y_predSample : array-like, indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]``
        Predicted coordinates of the sample.
    y_trueSample : array-like, indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]``
        Labelled coordinates of the sample.
    **kwargs
        backgroundAxis : if not None, ``plot_all`` draws ``y_preds`` / ``superpops``
            on the 3-D axes behind the sample.
        y_preds, superpops : forwarded to ``plot_all`` when ``backgroundAxis`` is given.
        cpShow : when left as None, windows flagged in ``pred_cps`` are relabelled
            "UNK" and drawn grey; any other value disables that masking.
        pred_cps : per-window mask whose non-zero entries mark change points.
        mappedSpArr : required; per-window superpopulation value used to colour
            the predicted strip. The caller's array is not modified.

    Raises
    ------
    ValueError
        If ``mappedSpArr`` is not supplied.
    """
    backgroundAxis = kwargs.get('backgroundAxis')
    y_preds = kwargs.get('y_preds')
    superpops = kwargs.get('superpops')
    cpShow = kwargs.get('cpShow')
    pred_cps = kwargs.get('pred_cps')
    mappedSpArr = kwargs.get('mappedSpArr')
    if mappedSpArr is None:
        raise ValueError("plot_sample requires the 'mappedSpArr' keyword argument")
    # Work on a copy so masking change points below does not overwrite the
    # caller's array (which would leak -1 entries into later calls).
    mappedSpArr = np.array(mappedSpArr)

    # base_ax is the full-figure axes created by plt.subplots; the three panels
    # are added on top of it through the gridspec.
    fig, base_ax = plt.subplots(figsize=(12, 14))
    gs1 = fig.add_gridspec(nrows=3, ncols=1, height_ratios=[13, 1, 1])
    ax1 = fig.add_subplot(gs1[0], projection='3d')
    ax2 = fig.add_subplot(gs1[1])
    ax3 = fig.add_subplot(gs1[2])
    for axis in [base_ax, ax2, ax3]:
        axis.set_yticks([])
        axis.spines['top'].set_color('none')
        axis.spines['left'].set_color('none')
        axis.spines['right'].set_color('none')
    base_ax.set_xticks([])
    plt.subplots_adjust(hspace=0.01)

    if backgroundAxis is not None:
        lgnd, colorsPop_sp_dict = plot_all(ax1, y_preds, superpops, cpShow=False)
    else:
        continentaPops = list(superop_dict.values())
        colors_Sp = sns.color_palette("bright", 10)
        # Sequential deletes: drops entries 1, 5 and 7 of the original palette.
        del colors_Sp[1]
        del colors_Sp[4]
        del colors_Sp[5]
        colorsPop_sp_dict = {k: v for k, v in zip(continentaPops, colors_Sp)}
        patches = [mpatches.Patch(color=colorsPop_sp_dict[val], label=k)
                   for k, val in superop_dict.items()]
        lgnd = ax1.legend(handles=patches, loc="upper right", fontsize=15)
    # -1 marks windows masked as change points. It has to be present for both
    # branches: the dict returned by plot_all does not contain it.
    colorsPop_sp_dict[-1] = (0.7, 0.7, 0.7)  # grey color
    # Keep the superpopulation legend when the second legend is created below.
    ax1.add_artist(lgnd)

    pop_labels = np.unique(granularPopSample)
    colors_pop = sns.color_palette("rainbow", len(pop_labels))
    colors_pop_dict = {k: v for k, v in zip(pop_labels, colors_pop)}

    alpha = [1] * len(y_predSample)
    # Default to the unmasked names so the scatter below also works when
    # cpShow is passed (previously this name was only bound inside the branch).
    granularPopNames = list(granularPopSample)
    if cpShow is None:
        if pred_cps is None:
            unknownIdx = np.array([], dtype=int)
        else:
            unknownIdx = np.nonzero(pred_cps)[0]
        unknown = set(unknownIdx.tolist())
        granularPopNames = ["UNK" if i in unknown else name
                            for i, name in enumerate(granularPopSample)]
        colors_pop_dict["UNK"] = (0.9, 0.9, 0.9)  # grey color
        mappedSpArr[unknownIdx] = -1

    # Predictions as dots (RGBA with the per-point alpha), labels as large crosses.
    ax1.scatter(y_predSample[:, 0], y_predSample[:, 1], y_predSample[:, 2],
                color=[colors_pop_dict[x] + (y,) for x, y in zip(granularPopNames, alpha)],
                s=50, zorder=0)
    ax1.scatter(y_trueSample[:, 0], y_trueSample[:, 1], y_trueSample[:, 2],
                color=[colors_pop_dict[x] for x in granularPopSample],
                marker='X', s=200, zorder=0)

    # Transparent panes and grid lines on the 3-D axes.
    for axis in [ax1.xaxis, ax1.yaxis, ax1.zaxis]:
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
        axis._axinfo["grid"]['color'] = (1, 1, 1, 0)

    patches = [mpatches.Patch(color=colors_pop_dict[val], label=val) for val in pop_labels]
    patches.append(mpatches.Patch(color=(0.9, 0.9, 0.9), label="UNK"))
    ax1.legend(handles=patches, loc="upper left", fontsize=20)

    # chm plot ground truth
    ax2.scatter(np.arange(len(y_trueSample)), np.ones(len(y_trueSample)),
                color=[colors_pop_dict[x] for x in granularPopSample], marker='s')
    ax2.set_title('Labeled Chromosome22', fontsize=30, y=0.3)
    # chm plot of predictions
    ax3.scatter(np.arange(len(mappedSpArr)), np.ones(len(mappedSpArr)),
                color=[colorsPop_sp_dict[x] + (y,) for x, y in zip(mappedSpArr, alpha)],
                marker='s')
    ax3.set_title('Predicted Chromosome22', fontsize=30, y=0.3)

    # Fixed tick positions and their labels for both chromosome strips.
    positions = [0, 160, 300]
    x_labels = [0, 150000, 317000]
    for strip_ax in [ax2, ax3]:
        strip_ax.xaxis.set_minor_locator(ticker.MultipleLocator(4))
        strip_ax.xaxis.set_ticks_position('bottom')
        strip_ax.tick_params(which='major', width=2, length=10, labelsize=25)
        strip_ax.tick_params(which='minor', width=2, length=10, labelsize=10)
        strip_ax.set_xlim(0, 316)
        strip_ax.set_ylim(0.99, 1.09)
        strip_ax.xaxis.set_major_locator(ticker.FixedLocator(positions))
        strip_ax.xaxis.set_major_formatter(ticker.FixedFormatter(x_labels))

    fig.tight_layout()
    plt.show()
    plt.close('all')

In [None]:
# Plot the selected sample; windows flagged by the BOCD change-point mask are
# shown as "UNK" because cpShow is left at its default.
plot_sample(namesSample, y_predSample, y_trueSample, pred_cps=pred_cps_BOCD, mappedSpArr=mappedSpArr)

In [None]:
def plot_coordinates_map(granularPopNames, pred_coord, lbl_coord, **kwargs):
    """Plot predicted and labelled coordinates of one sample on a plotly geo map.

    Parameters
    ----------
    granularPopNames : sequence of str
        Population name per window; used as hover text and to choose the
        marker colour of the predicted points.
    pred_coord : array-like
        Predicted coordinates; column 0 is used as latitude, column 1 as longitude.
    lbl_coord : array-like
        Labelled coordinates with the same column layout, drawn as squares.
    **kwargs
        cpShow : when truthy, windows with a non-zero entry in ``pred_cps`` are
            renamed "UNK" and coloured grey.
        pred_cps : per-window change-point mask; ignored when None.
    """
    cpShow = kwargs.get('cpShow')
    pred_cps = kwargs.get('pred_cps')
    alpha = [1.0] * len(lbl_coord)
    if cpShow and pred_cps is not None:
        unknown = set(np.nonzero(pred_cps)[0].tolist())
        granularPopNames = ["UNK" if i in unknown else name
                            for i, name in enumerate(granularPopNames)]
    fig = go.Figure(go.Scattergeo())

    pop_labels = np.unique(granularPopNames)
    colors_pop = sns.color_palette("rainbow", len(pop_labels))
    # seaborn returns channels in [0, 1], while the rgba(...) strings built
    # below (and the UNK colour) use the 0-255 scale; convert so the colours
    # are not rendered near-black.
    colors_pop_dict = {k: tuple(int(round(255 * c)) for c in v)
                       for k, v in zip(pop_labels, colors_pop)}
    colors_pop_dict['UNK'] = (188, 188, 188)  # grey color

    fig.add_trace(go.Scattergeo(
        lon=pred_coord[:, 1], lat=pred_coord[:, 0], text=granularPopNames,
        marker_color=['rgba' + str(colors_pop_dict[x] + (y,))
                      for x, y in zip(granularPopNames, alpha)]))
    fig.add_trace(go.Scattergeo(
        lon=lbl_coord[:, 1], lat=lbl_coord[:, 0], marker=dict(symbol='square'),
        text=granularPopNames))
    # Applies to every trace, so one call after both traces are added suffices.
    fig.update_traces(marker_size=5)

    fig.show()
    # Kept from the original: shows and closes any matplotlib figures still open.
    plt.show()
    plt.close('all')

In [None]:
# Display the shape of the selected sample's labelled coordinates.
y_trueSample.shape

In [None]:
# Convert the three coordinate columns of the labels and of the predictions with
# convert_coordinates; plot_coordinates_map reads column 0 of the result as
# latitude and column 1 as longitude.
lbl_coord=convert_coordinates(y_trueSample[:,0], y_trueSample[:,1], y_trueSample[:,2])
pred_coord=convert_coordinates(y_predSample[:,0], y_predSample[:,1], y_predSample[:,2])
# NOTE: this creates a notebook-level `granularPopNames` variable.
granularPopNames=namesSample
plot_coordinates_map(granularPopNames, pred_coord, lbl_coord, pred_cps=pred_cps_BOCD, cpShow=True)

In [None]:
# Unique first-axis indices of dataset entries whose granular_pop equals the
# value stored for "San" in granular_pop_dict.
idxOfInterest=torch.unique(torch.nonzero(test_dataset.data['granular_pop']==granular_pop_dict["San"])[:,0])

In [None]:
# Display the selected indices.
idxOfInterest

In [None]:
@timer
def plot_all(ax, y_preds, superpops, **kwargs):
    """Scatter all predictions on ``ax`` coloured by superpopulation and add a legend.

    Parameters
    ----------
    ax : 3-D matplotlib axes
        Axes receiving the scatter and the legend.
    y_preds : array-like, indexed as ``[:, 0]``, ``[:, 1]``, ``[:, 2]``
        Coordinates to draw, one row per point.
    superpops : sequence
        Superpopulation value per point; must be among ``superop_dict`` values.
    **kwargs
        cpShow : when left as None and ``pred_cps`` is given, points flagged as
            change points become fully transparent and all others fully opaque.
            Otherwise every point is drawn with alpha 0.03.
        pred_cps : per-point change-point mask compared against 0.

    Returns
    -------
    (legend, dict)
        The legend created on ``ax`` and the superpopulation -> colour mapping.
    """
    cpShow = kwargs.get('cpShow')
    pred_cps = kwargs.get('pred_cps')
    continentaPops = list(superop_dict.values())
    colors_pop = sns.color_palette("bright", 10)
    # Sequential deletes: drops entries 1, 5 and 7 of the original palette.
    del colors_pop[1]
    del colors_pop[4]
    del colors_pop[5]
    colors_pop_dict = {k: v for k, v in zip(continentaPops, colors_pop)}

    alpha = [0.03] * len(y_preds)
    # Only derive alpha from the mask when one was supplied; without this guard
    # a call that omits both keywords failed on `(None == 0).astype`.
    if cpShow is None and pred_cps is not None:
        alpha = (np.asarray(pred_cps) == 0).astype(float)

    ax.scatter(y_preds[:, 0], y_preds[:, 1], y_preds[:, 2],
               color=[colors_pop_dict[x] + (y,) for x, y in zip(superpops, alpha)],
               marker=".", s=0.05, zorder=-1)

    patches = [mpatches.Patch(color=colors_pop_dict[val], label=k)
               for k, val in superop_dict.items()]
    lgnd = ax.legend(handles=patches, loc="upper right", fontsize=15)
    return lgnd, colors_pop_dict


In [None]:
# Gradient-based change points for the whole evaluated set.
y_preds = results.t_out.coord_main
# Pass the full prediction array: the batch size and sequence length given here
# are those of y_preds, so handing over the single-sample y_predSample (as
# before) did not match them.
y_predCps = getCpPred(cpMethod.gradient.name, y_preds, 0.1, y_preds.shape[0], y_preds.shape[1])
y_predCps = y_predCps.detach().cpu().numpy().reshape(-1,)
# Flatten predictions and superpopulation labels to one row / entry per window.
y_preds = y_preds.reshape(-1, 3)
superpops = test_dataset.data['superpop'].detach().cpu().numpy().reshape(-1,)