In [1]:
import torch
import sys
import os
import os.path as osp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import ticker
import matplotlib as mpl
from sklearn.metrics import precision_recall_curve

In [2]:
# Switch the working directory to the repository checkout. Later cells use
# repo-relative references (`!./ini.sh`, `from src...`, `import test`), which
# presumably rely on this being the current directory.
# NOTE(review): absolute, user-specific path -- adjust when running elsewhere.
root_dir = "/home/users/richras/Ge2Net_Repo"
os.chdir(root_dir)

In [3]:
# Run the repository's ini.sh setup script (relative to the current directory).
# NOTE(review): `!` executes in a subshell, so variables the script exports do
# not reach this kernel's os.environ -- presumably why the next cell sets them
# explicitly.
!./ini.sh

set environment variables
All done


In [4]:
# Publish the project's path settings to this kernel's environment in a single
# call; every key and value is written exactly as given.
os.environ.update({
    'USER_PATH': '/home/users/richras/Ge2Net_Repo',
    'USER_SCRATCH_PATH': "/scratch/users/richras",
    'IN_PATH': '/scratch/groups/cdbustam/richras/data_in',
    'OUT_PATH': '/scratch/groups/cdbustam/richras/data_out',
    'LOG_PATH': '/scratch/groups/cdbustam/richras/logs/',
})

In [5]:
%load_ext autoreload
%autoreload 2
from src.utils.dataUtil import load_path, save_file, vcf2npy, get_recomb_rate, interpolate_genetic_pos, form_windows,\
getValueBySelection
from src.utils.modelUtil import Params, load_model
from src.utils.decorators import timer
from src.utils.labelUtil import repeat_pop_arr, getSuperpopBins
from src.models import AuxiliaryTask, LSTM, Attention, BasicBlock, Model_A, Model_B, Model_C
from src.models.distributions import Multivariate_Gaussian
from src.main.evaluation import eval_cp_batch, reportChangePointMetrics, t_prMetrics, cpMethod
from src.main.settings_model import parse_args, MODEL_CLASS
import test

In [19]:
# Dataset and model to evaluate: chm22 pca full dataset and model.
# Every location is resolved relative to the OUT_PATH environment variable.
dataset_type = 'test'
data_path = osp.join(os.environ['OUT_PATH'], 'humans/data/data_id_3_pca')
labels_path = osp.join(os.environ['OUT_PATH'], 'humans/data/data_id_3_pca/labels')
models_path = osp.join(
    os.environ['OUT_PATH'],
    'humans/training/Model_B_exp_id_4_data_id_3_pca/',
)

In [20]:
# load the params file and run test.py

In [21]:
# Run configuration handed to test.main: data/model locations, the split to
# evaluate and the device string.
config = {
    'data.labels': labels_path,
    'data.dir': data_path,
    'models.dir': models_path,
    'data.dataset_type': dataset_type,
    'cuda': 'cuda',
}

# The trained model's hyper-parameters are read from params.json in the model
# directory; stop early if that file is missing.
json_path = osp.join(config['models.dir'], 'params.json')
assert osp.isfile(json_path), "No json configuration file found at {}".format(json_path)
params = Params(json_path)

# Evaluation-time overrides: return the model outputs and enable MC dropout
# with 100 samples.
params.rtnOuts = True
params.mc_dropout = True
params.mc_samples = 100

results, test_dataset = test.main(config, params)

 device used: cuda
Loading the datasets...
Finished 'transform_data' in 17.0328 secs
Finished '__init__' in 78.6597 secs
model ['Model_B.model_B'] : AuxiliaryTask.AuxNetwork
model ['Model_B.model_B'] : LSTM.BiRNN
model ['Model_B.model_B'] : BasicBlock.logits_Block
best val loss metrics : {'gcd': None, 'mse': 0.20966243415029062, 'smooth_l1': 0.09892145020944967, 'weighted_loss': 0.3524936370038568, 'loss_main': 0.3524936370038568, 'loss_aux': 0.8679943420903375}
at epoch : 94
train loss metrics: {'gcd': None, 'mse': 0.09787766041908866, 'smooth_l1': 0.046797061544382804, 'weighted_loss': 0.2386510676702706, 'loss_main': 0.21970770299296133, 'loss_aux': 0.8715258090012372}
best val cp metrics : {'loss_cp': 1.3071895907791774e-06, 'Precision': 1.0, 'Recall': 0.4043835616438356, 'Accuracy': 0.9863497688086083, 'A_major': 1.0, 'BalancedAccuracy': 0.7021917808219178}
train cp metrics: {'loss_cp': 2.140980156204126e-06, 'Precision': 1.0, 'Recall': 0.6669736034376919, 'Accuracy': 0.9923624061

In [22]:
# Shape of the change-point labels for the evaluated split; the recorded run
# shows torch.Size([2964, 317]).
test_dataset.data['cps'].shape

torch.Size([2964, 317])

In [23]:
# Shapes of the three model outputs used below. In the recorded run the first
# two axes match the labels (2964, 317); cp_logits has 1 trailing channel,
# coord_main and y_var have 3.
results.t_out.cp_logits.shape, results.t_out.coord_main.shape, results.t_out.y_var.shape

((2964, 317, 1), (2964, 317, 3), (2964, 317, 3))

In [24]:
%%HTML
<style type="text/css">
/* Draw solid black 1px borders and black text on every cell and header of
   tables carrying the "dataframe" class (the class on rendered DataFrames). */
table.dataframe td, table.dataframe th {
    border: 1px  black solid !important;
  color: black !important;
}
</style>

In [25]:
@timer
def prMetricsByThresh(method_name, cp_pred_raw, cp_target, steps):
    """Sweep a decision threshold over [0, 1] and tabulate change-point metrics.

    ``reportChangePointMetrics`` is called once for each of ``steps + 1``
    evenly spaced thresholds (both 0.0 and 1.0 included), and the metrics it
    returns are collected together with the threshold that produced them.

    Parameters
    ----------
    method_name : str
        Change-point method identifier, forwarded unchanged to
        ``reportChangePointMetrics`` (callers pass ``cpMethod.<member>.name``).
    cp_pred_raw : tensor
        Raw predictions, forwarded unchanged.
    cp_target : tensor
        Change-point targets, forwarded unchanged; ``cp_target.shape[1]`` is
        passed on as the sequence length.
    steps : int
        Number of equal-width intervals to split [0, 1] into; must be >= 1.

    Returns
    -------
    pandas.DataFrame
        One row per threshold. Columns are ``t_prMetrics._fields`` followed by
        ``'thresh'``; any additional keys returned by
        ``reportChangePointMetrics`` are kept after those.

    Raises
    ------
    ValueError
        If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError("steps must be >= 1, got {}".format(steps))

    seqlen = cp_target.shape[1]
    min_prob = 0.0
    max_prob = 1.0
    columns = list(t_prMetrics._fields) + ['thresh']

    # linspace pins both endpoints and the number of points. arange with a
    # float step and a padded stop can emit an extra threshold above max_prob,
    # depending on rounding.
    records = []
    for thresh in np.linspace(min_prob, max_prob, steps + 1):
        # Copy into a plain dict so each record is independent of whatever
        # object reportChangePointMetrics returned.
        prMetrics = dict(reportChangePointMetrics(method_name, cp_pred_raw, cp_target, seqlen, thresh))
        prMetrics['thresh'] = thresh
        records.append(prMetrics)

    # Build the frame once: DataFrame.append was removed in pandas 2.0 and
    # re-copied the whole frame on every iteration.
    df = pd.DataFrame(records)
    extra_columns = [col for col in df.columns if col not in columns]
    return df.reindex(columns=columns + extra_columns)

In [26]:
# Threshold sweep (20 steps) for the neural_network method, scored on the
# model's cp_logits output against the change-point labels.
df_nn = prMetricsByThresh(
    cpMethod.neural_network.name,
    torch.tensor(results.t_out.cp_logits).float(),
    test_dataset.data['cps'].unsqueeze(2).float(),
    20,
)
df_nn

Finished 'prMetricsByThresh' in 15.0383 secs


Unnamed: 0,Precision,Recall,Accuracy,A_major,BalancedAccuracy,thresh
0,0.017822,1.0,0.048657,0.032968,0.516484,0.0
1,0.042087,0.92218,0.655469,0.651624,0.786902,0.05
2,0.046379,0.884937,0.706812,0.704686,0.794812,0.1
3,0.046747,0.879263,0.711363,0.709418,0.79434,0.15
4,0.046794,0.877395,0.712335,0.710438,0.793917,0.2
5,0.047094,0.872191,0.716549,0.714854,0.793523,0.25
6,0.059157,0.749204,0.819975,0.822863,0.786033,0.3
7,0.071257,0.594242,0.89556,0.903193,0.748717,0.35
8,0.067497,0.543478,0.90392,0.912849,0.728163,0.4
9,0.996626,0.269286,0.983281,0.999989,0.634638,0.45


In [27]:
# Precision-recall curve over all flattened label/score pairs. cp_logits is
# passed as the score as-is, so the returned thresholds are on that same scale.
precision, recall, thresholds = precision_recall_curve(
    test_dataset.data['cps'].detach().cpu().numpy().flatten(),
    results.t_out.cp_logits.flatten(),
)

In [28]:
# Inspect the three arrays returned by precision_recall_curve. In the recorded
# run the thresholds are negative (about -4.74 to -0.09).
precision, recall, thresholds

(array([0.01673608, 0.01673503, 0.01673505, ..., 0.5       , 0.        ,
        1.        ]),
 array([1.00000000e+00, 9.99936407e-01, 9.99936407e-01, ...,
        6.35930048e-05, 0.00000000e+00, 0.00000000e+00]),
 array([-4.738904  , -4.73328   , -4.7315416 , ..., -0.13503137,
        -0.12965311, -0.09035808], dtype=float32))

In [29]:
# Threshold sweep (20 steps) for the gradient method, scored on the model's
# coord_main output. Left as the cell's last expression so the frame is shown
# with the notebook's rich table display rather than print()'s plain text,
# matching the df_nn cell.
prMetricsByThresh(
    cpMethod.gradient.name,
    torch.tensor(results.t_out.coord_main).float(),
    test_dataset.data['cps'].unsqueeze(2).float(),
    20,
)

Finished 'prMetricsByThresh' in 9.0654 secs
    Precision    Recall  Accuracy   A_major  BalancedAccuracy  thresh
0    0.017881  1.000000  0.051772  0.036136          0.518068    0.00
1    0.334019  0.839978  0.951593  0.954127          0.897053    0.05
2    0.529031  0.806928  0.976245  0.980137          0.893532    0.10
3    0.649115  0.788282  0.983920  0.988459          0.888371    0.15
4    0.717262  0.771951  0.987406  0.992411          0.882181    0.20
5    0.770964  0.759673  0.989247  0.994637          0.877155    0.25
6    0.807340  0.748504  0.990277  0.995981          0.872242    0.30
7    0.836436  0.738577  0.990903  0.996868          0.867722    0.35
8    0.860834  0.727823  0.991345  0.997576          0.862700    0.40
9    0.880875  0.719159  0.991674  0.998132          0.858646    0.45
10   0.893181  0.710182  0.991756  0.998453          0.854318    0.50
11   0.904334  0.700741  0.991811  0.998731          0.849736    0.55
12   0.914051  0.688568  0.991746  0.998944   

In [30]:
# Threshold sweep (20 steps) for the mc_dropout method, scored on the model's
# y_var output. Left as the cell's last expression so the frame is shown with
# the notebook's rich table display rather than print()'s plain text, matching
# the df_nn cell.
prMetricsByThresh(
    cpMethod.mc_dropout.name,
    torch.tensor(results.t_out.y_var).float(),
    test_dataset.data['cps'].unsqueeze(2).float(),
    20,
)

Finished 'prMetricsByThresh' in 7.4925 secs
    Precision    Recall  Accuracy   A_major  BalancedAccuracy  thresh
0    0.017822  1.000000  0.048657  0.032968          0.516484    0.00
1    0.438972  0.710644  0.920012  0.924702          0.817673    0.05
2    0.525303  0.566634  0.951765  0.960339          0.763486    0.10
3    0.583514  0.457536  0.963347  0.974736          0.716136    0.15
4    0.637952  0.390869  0.970277  0.983471          0.687170    0.20
5    0.705055  0.345043  0.974572  0.989003          0.667023    0.25
6    0.762848  0.315954  0.977040  0.992311          0.654132    0.30
7    0.811474  0.297258  0.978710  0.994507          0.645882    0.35
8    0.853438  0.284836  0.979782  0.995925          0.640381    0.40
9    0.883706  0.278528  0.980584  0.996940          0.637734    0.45
10   0.908890  0.274868  0.981208  0.997689          0.636278    0.50
11   0.924831  0.272223  0.981636  0.998203          0.635213    0.55
12   0.940855  0.271001  0.981983  0.998603   

In [31]:
# Inspect the raw y_var array that feeds the mc_dropout sweep (shape
# (2964, 317, 3) in the recorded run). Presumably the per-coordinate variance
# across the MC-dropout samples -- TODO confirm against test.main.
results.t_out.y_var

array([[[3.20023799e-04, 5.42979408e-03, 9.68458597e-03],
        [1.42561854e-04, 1.34131662e-03, 5.84353600e-03],
        [8.83527828e-05, 1.64349936e-03, 9.06498730e-03],
        ...,
        [9.61053593e-05, 1.60546741e-04, 1.61960884e-03],
        [1.12768328e-04, 2.53641745e-04, 1.13692658e-03],
        [2.74925027e-04, 2.00151629e-03, 4.89353249e-03]],

       [[4.45431302e-04, 5.46904630e-04, 2.40227149e-04],
        [2.16418575e-05, 2.57266602e-05, 1.19130855e-04],
        [9.73592432e-06, 1.25187053e-05, 1.06627405e-04],
        ...,
        [4.51691267e-05, 3.00417469e-05, 2.72930221e-04],
        [4.64569603e-05, 2.62130779e-05, 3.05910653e-04],
        [6.64300896e-05, 1.31366367e-04, 5.32812264e-04]],

       [[2.90593161e-04, 6.66218519e-04, 6.20053441e-04],
        [7.95598462e-05, 2.86212366e-04, 5.55668317e-04],
        [7.90640770e-05, 2.30348422e-04, 2.72534584e-04],
        ...,
        [1.38824980e-04, 4.00882069e-04, 7.13199915e-05],
        [1.38684321e-04, 4.05