# Notebook 02: Pipeline for idealized AD results

This notebook goes through the pipeline for obtaining results using the idealized AD (idealized anomaly detection) method.

In [1]:
import os
import argparse
import numpy as np

from run_ANODE_training import main as train_DE
from run_classifier_data_creation import main as create_data
from run_classifier_training import main as train_classifier
from run_ANODE_evaluation import main as eval_ANODE
from evaluation_utils import full_single_evaluation, classic_ANODE_eval, minimum_val_loss_model_evaluation

In [2]:
# ---------------------------------------------------------------------------
# General options
# ---------------------------------------------------------------------------
# NOTE(review): `mode` is not read by any cell visible in this notebook; the
# mode flags are hard-coded in `data_creation_kwargs` instead.
mode = 'idealized_AD'
# Directory holding the input .npy files (inner/outer data splits).
data_dir = '../separated_data'
# Directory that receives the classifier data sets, models and plots.
save_dir = 'idealized_AD_models'
# Shift on jet mass variables to be applied (0 disables the shift).
datashift = 0.
# If True, the shift is randomized rather than correlated to the actual mjj.
random_shift = False
# Whether to apply an (ANODE paper) fiducial cut on the data (and samples).
fiducial_cut = False
# Suppress the processing of the extra signal sample.
no_extra_signal = True
# Verbosity flag forwarded to the density estimator and classifier training.
verbose = False

# ---------------------------------------------------------------------------
# Density estimator (DE) options
# ---------------------------------------------------------------------------
# ANODE model config file (.yml).
DE_config_file = '../DE_MAF_model.yml'
# Number of density estimation training epochs.
DE_epochs = 100
# Batch size during density estimation training.
DE_batch_size = 256
# Skips the density estimation (loads existing files instead).
# NOTE(review): spelled `DF_skip` while every sibling option uses the `DE_`
# prefix -- possibly a typo; it is not read by any cell visible here.
DF_skip = False
# Turns off the logit transform in the density estimator.
DE_no_logit = False

# File name for the density estimator.
DE_file_name = 'my_ANODE_model'

# ---------------------------------------------------------------------------
# Classifier (cf) options
# ---------------------------------------------------------------------------
# Classifier model config file (.yml).
cf_config_file = '../classifier.yml'

# Number of classifier training epochs.
cf_epochs = 100
# Number of samples to be generated. Currently the samples will be cut down to match data proportion.
cf_n_samples = 130000

# Sample the conditional from a KDE fit rather than a uniform distribution.
cf_realistic_conditional = False
# Bandwidth of the KDE fit (used when realistic_conditional is selected).
cf_KDE_bandwidth = 0.01
# Add the full number of samples to the training set rather than mixing it in equal parts with data.
cf_oversampling = True
# Turns off the logit transform in the classifier.
cf_no_logit = True
# Space-separated list of pre-sampled npy files of physical variables if the sampling has been
# done externally. The format is (mjj, mj1, dmj, tau21_1, tau21_2).
cf_external_samples = ""
# Lower boundary of signal region.
cf_SR_min = 3.3
# Upper boundary of signal region.
cf_SR_max = 3.7
# Number of independent classifier training runs.
cf_n_runs = 10
# Batch size during classifier training.
cf_batch_size = 128
# Use the conditional variable as classifier input during training.
cf_use_mjj = False
# Weight the classes according to their occurrence in the training set.
# Necessary if the training set was intentionally oversampled.
cf_use_class_weights = True
# Central value of signal region. Must only be given for using CWoLa with weights.
cf_SR_center = 3.5
# Make use of extra background (for supervised and idealized AD).
cf_extra_bkg = True
# Define a separate validation set to pick the classifier epochs.
cf_separate_val_set = True
# Save the tensorflow model after each epoch instead of saving predictions.
cf_save_model = True
# Skips the creation of the classifier dataset (loads existing files instead).
cf_skip_create = False
# Skips the training of the classifier (loads existing files instead).
cf_skip_train = False

The history saving thread hit an unexpected error (DatabaseError('database disk image is malformed')).History will not be written to the database.


In [3]:
# Keyword arguments for `create_data`, assembled from the configuration cell.
data_creation_kwargs = {
    # --- general options ---
    'savedir': save_dir,
    'datashift': datashift,
    'data_dir': data_dir,
    'random_shift': random_shift,
    'config_file': DE_config_file,
    'verbose': verbose,
    'fiducial_cut': fiducial_cut,
    # --- sampling options ---
    'n_samples': cf_n_samples,
    'realistic_conditional': cf_realistic_conditional,
    'KDE_bandwidth': cf_KDE_bandwidth,
    'oversampling': cf_oversampling,
    'no_extra_signal': no_extra_signal,
    # --- mode flags: only the idealized AD mode is switched on here ---
    'CWoLa': False,
    'supervised': False,
    'idealized_AD': True,
    # 'no_logit' is the classifier-side option, 'no_logit_trained' the
    # density-estimator-side option (see the configuration cell).
    'no_logit': cf_no_logit,
    'no_logit_trained': DE_no_logit,
    'external_samples': cf_external_samples,
    'SR_min': cf_SR_min,
    'SR_max': cf_SR_max,
    'extra_bkg': cf_extra_bkg,
    'separate_val_set': cf_separate_val_set,
    # No density-estimator models are passed in this notebook.
    'ANODE_models': []
}

In [4]:
import torch
import numpy as np
from data_handler import LHCORD_data_handler, sample_handler, mix_data_samples, plot_data_sample_comparison
from density_estimator import DensityEstimator

def _as_path_list(value):
    """Normalize a path option to a list of paths.

    Accepts ``None`` (-> empty list), a whitespace-separated string (-> split
    into its parts; the empty string yields an empty list) or any iterable of
    paths (-> list of its items).
    """
    if value is None:
        return []
    if isinstance(value, str):
        return value.split()
    return list(value)


def create_data(**kwargs):
    """Create the classifier data sets and the accompanying comparison plots.

    The function loads the inner/outer data via ``LHCORD_data_handler``,
    optionally shifts it, preprocesses it, optionally draws samples from the
    given density estimators (or loads external samples), mixes data and
    samples with ``mix_data_samples`` and finally produces the
    "data vs. sample" comparison plots for the training, test and (if present)
    validation sets.

    Keyword arguments read from ``kwargs``:
        data_dir, savedir: input directory and output directory.
        CWoLa, supervised, idealized_AD: mode flags.
        ANODE_models: density estimator model paths (list, or a
            whitespace-separated string).
        external_samples: pre-sampled .npy files (list, or a
            whitespace-separated string); empty means "sample from the models".
        config_file, verbose: forwarded to ``DensityEstimator``.
        datashift, random_shift: optional shift applied to the data.
        fiducial_cut: forwarded to all preprocessing calls.
        no_logit_trained: preprocessing setting the density estimator was
            trained with (used for the first preprocessing and for sampling).
        no_logit: preprocessing setting for the classifier inputs (used for
            the CWoLa preprocessing and for the final data/sample
            preprocessing).
        SR_min, SR_max: signal region boundaries.
        n_samples, realistic_conditional, KDE_bandwidth, oversampling,
        extra_bkg, separate_val_set, no_extra_signal: see the configuration
            cell of the notebook.

    Raises:
        AssertionError: if density estimator models are required (no CWoLa /
            supervised / idealized_AD mode and no external samples) but none
            were given.
    """
    # Both options may arrive as a list or as a space-separated string (the
    # notebook configuration documents the string form); iterating over a raw
    # string would otherwise walk over single characters.
    external_samples = _as_path_list(kwargs['external_samples'])
    anode_models = _as_path_list(kwargs['ANODE_models'])

    # The original check compared ANODE_models against "" only, which never
    # matches an empty list. Raised explicitly so it also holds under `python -O`.
    needs_anode_models = (
        not (kwargs['supervised'] or kwargs['idealized_AD'] or kwargs['CWoLa'])
        and len(external_samples) == 0)
    if needs_anode_models and len(anode_models) == 0:
        raise AssertionError(
            "ANODE models need to be given unless CWoLa, supervised, idealized_AD or"
            " external sampling is used.")

    # selecting appropriate device
    CUDA = torch.cuda.is_available()
    print("cuda available:", CUDA)
    device = torch.device("cuda:0" if CUDA else "cpu")

    # checking for data separation: a dedicated validation file signals the
    # finer train/val/test split of the input directory
    data_dir = kwargs['data_dir']
    data_files = os.listdir(data_dir)
    finer_data_split = "innerdata_val.npy" in data_files

    if finer_data_split:
        innerdata_train_path = [os.path.join(data_dir, 'innerdata_train.npy')]
        innerdata_val_path = [os.path.join(data_dir, 'innerdata_val.npy')]
        innerdata_test_path = [os.path.join(data_dir, 'innerdata_test.npy')]
        if "innerdata_extrabkg_test.npy" in data_files:
            innerdata_test_path.append(os.path.join(data_dir, 'innerdata_extrabkg_test.npy'))
        extrasig_path = None
        if kwargs['supervised']:
            # supervised mode trains on the extra signal vs. extra background files
            innerdata_train_path = [os.path.join(data_dir, 'innerdata_extrasig_train.npy'),
                                    os.path.join(data_dir, 'innerdata_extrabkg_train.npy')]
            innerdata_val_path = [os.path.join(data_dir, 'innerdata_extrasig_val.npy'),
                                  os.path.join(data_dir, 'innerdata_extrabkg_val.npy')]
            extra_bkg = None
        elif kwargs['idealized_AD']:
            extra_bkg = [os.path.join(data_dir, 'innerdata_extrabkg_train.npy'),
                         os.path.join(data_dir, 'innerdata_extrabkg_val.npy')]
        else:
            extra_bkg = None

    else:
        innerdata_train_path = os.path.join(data_dir, 'innerdata_train.npy')
        extrasig_path = os.path.join(data_dir, 'innerdata_extrasig.npy')
        if kwargs['extra_bkg']:
            extra_bkg = os.path.join(data_dir, 'innerdata_extrabkg.npy')
        else:
            extra_bkg = None
        innerdata_val_path = None
        innerdata_test_path = os.path.join(data_dir, 'innerdata_test.npy')

    # data loading
    data = LHCORD_data_handler(innerdata_train_path,
                               innerdata_test_path,
                               os.path.join(data_dir, 'outerdata_train.npy'),
                               os.path.join(data_dir, 'outerdata_test.npy'),
                               extrasig_path,
                               inner_extrabkg_path=extra_bkg,
                               inner_val_path=innerdata_val_path,
                               batch_size=256,
                               device=device)
    if kwargs['datashift'] != 0:
        print("applying a datashift of", kwargs['datashift'])
        data.shift_data(kwargs['datashift'], constant_shift=False, random_shift=kwargs['random_shift'],
                        shift_mj1=True, shift_dm=True, additional_shift=False)

    if kwargs['CWoLa']:
        # data preprocessing
        samples = None
        data.preprocess_CWoLa_data(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                   outer_range=(kwargs['SR_min']-0.2, kwargs['SR_max']+0.2))

    else:
        # data preprocessing with the settings the density estimator was trained with
        data.preprocess_ANODE_data(fiducial_cut=kwargs['fiducial_cut'],
                                   no_logit=kwargs['no_logit_trained'],
                                   no_mean_shift=kwargs['no_logit_trained'])

        # model instantiation
        if len(external_samples) > 0:
            model_list = None
            loaded_samples = [np.load(sample_path) for sample_path in external_samples]
            external_sample = np.concatenate(loaded_samples)
        else:
            model_list = []
            for model_path in anode_models:
                anode = DensityEstimator(kwargs['config_file'],
                                         eval_mode=True,
                                         load_path=model_path,
                                         device=device, verbose=kwargs['verbose'],
                                         bound=kwargs['no_logit_trained'])
                model_list.append(anode.model)
            external_sample = None

        # generate samples (not needed in the supervised / idealized AD modes)
        uses_samples = not kwargs['supervised'] and not kwargs['idealized_AD']
        if uses_samples:
            uniform_cond = not kwargs['realistic_conditional']
            samples = sample_handler(model_list, kwargs['n_samples'], data, cond_min=kwargs['SR_min'],
                                     cond_max=kwargs['SR_max'], uniform_cond=uniform_cond,
                                     external_sample=external_sample,
                                     device=device, no_logit=kwargs['no_logit_trained'],
                                     no_mean_shift=kwargs['no_logit_trained'],
                                     KDE_bandwidth=kwargs['KDE_bandwidth'])
        else:
            samples = None

        # redo data preprocessing if the classifier should not use logit but ANODE did.
        # This has to use the classifier option `no_logit`; passing
        # `no_logit_trained` again would merely repeat the first preprocessing
        # and silently ignore the classifier setting.
        data.preprocess_ANODE_data(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                   no_mean_shift=kwargs['no_logit'])

        # sample preprocessing, consistent with the (classifier) data preprocessing
        if uses_samples:
            samples.preprocess_samples(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                       no_mean_shift=kwargs['no_logit'])

    # a validation set exists if explicitly requested or if the input
    # directory already provides the finer split
    has_val_set = kwargs['separate_val_set'] or finer_data_split

    # sample mixing
    X_train, y_train, X_test, y_test, X_extrasig, y_extrasig = mix_data_samples(
        data, samples_handler=samples, oversampling=kwargs['oversampling'],
        savedir=kwargs['savedir'], CWoLa=kwargs['CWoLa'], supervised=kwargs['supervised'],
        idealized_AD=kwargs['idealized_AD'], separate_val_set=has_val_set)

    # sanity checks
    if not kwargs['CWoLa'] and not kwargs['supervised'] and not kwargs['idealized_AD']:
        samples.sanity_check(savefig=os.path.join(kwargs['savedir'], "sanity_check"), suppress_show=True)
        samples.sanity_check_after_cuts(savefig=os.path.join(kwargs['savedir'], "sanity_check_cuts"),
                                        suppress_show=True)

    if kwargs['supervised'] or has_val_set:
        # in these modes the fifth return value of mix_data_samples is used as
        # the validation set
        X_val = X_extrasig
        if kwargs['supervised']:
            y_train = X_train[:, -1]
            y_test = X_test[:, -1]
            y_val = X_val[:, -1]
        else:
            y_val = X_val[:, -2]
        plot_data_sample_comparison(X_val, y_val, title="validation set",
                                    savefig=os.path.join(kwargs['savedir'],
                                                         "data_sample_comparison_val"),
                                    suppress_show=True)

    plot_data_sample_comparison(X_train, y_train, title="training set",
                                savefig=os.path.join(kwargs['savedir'], "data_sample_comparison_train"),
                                suppress_show=True)
    plot_data_sample_comparison(X_test, y_test, title="test set",
                                savefig=os.path.join(kwargs['savedir'], "data_sample_comparison_test"),
                                suppress_show=True)

    print("number of training data =", X_train.shape[0])
    print("number of test data =", X_test.shape[0])
    if not kwargs['no_extra_signal']:
        if kwargs['supervised'] or has_val_set:
            print("number of validation data =", X_val.shape[0])
        elif extrasig_path is not None:
            print("number of extra signal data =", X_extrasig.shape[0])

In [5]:
# Build the classifier data sets for the idealized AD mode using the options assembled above.
create_data(**data_creation_kwargs)

cuda available: False
Using extra background...
number of training data = 196372
number of test data = 359883


In [7]:
# Keyword arguments for `train_classifier`, assembled from the configuration cell.
classifier_kwargs = {
    'config_file': cf_config_file,
    # The classifier reads its input arrays from the directory the data
    # creation step was pointed at, and writes its results there as well.
    'data_dir': save_dir,
    'savedir': save_dir,
    'verbose': verbose,
    'epochs': cf_epochs,
    'n_runs': cf_n_runs,
    'batch_size': cf_batch_size,
    'no_extra_signal': no_extra_signal,
    'use_mjj': cf_use_mjj,
    'supervised': False,
    # Class weights are switched on whenever the training set was oversampled.
    'use_class_weights': cf_oversampling or cf_use_class_weights,
    'CWoLa': False,
    'SR_center': cf_SR_center,
    'save_model': cf_save_model,
    'separate_val_set': cf_separate_val_set
}

In [9]:
from classifier_training_utils import train_n_models, plot_classifier_losses
from evaluation_utils import minimum_val_loss_model_evaluation
import matplotlib as mpl

def train_classifier(**kwargs):
    """Train the classifier ensemble on previously created data sets.

    Loads ``X_train.npy``, ``X_test.npy``, ``y_train.npy`` and ``y_test.npy``
    (plus ``X_extrasig.npy`` / ``X_validation.npy`` depending on the options)
    from ``kwargs['data_dir']``, runs ``train_n_models`` and writes one loss
    plot per training run into ``kwargs['savedir']``.

    Keyword arguments read from ``kwargs``:
        data_dir: directory containing the input .npy files.
        savedir: output directory (created if missing).
        config_file, epochs, n_runs, batch_size, use_mjj, supervised,
        use_class_weights, CWoLa, SR_center, verbose: forwarded to
            ``train_n_models``.
        no_extra_signal: if True (or in supervised mode) no extra signal
            sample is loaded.
        separate_val_set: if True (or in supervised mode) the validation set
            is loaded from ``X_validation.npy``.
        save_model: if True, the path prefix ``<savedir>/model`` is handed to
            ``train_n_models`` and ``minimum_val_loss_model_evaluation`` is
            run afterwards.
    """
    data_dir = kwargs['data_dir']
    savedir = kwargs['savedir']

    # loading the data
    # TODO get rid of the y's since the information is fully included in X
    X_train = np.load(os.path.join(data_dir, 'X_train.npy'))
    X_test = np.load(os.path.join(data_dir, 'X_test.npy'))
    y_train = np.load(os.path.join(data_dir, 'y_train.npy'))
    y_test = np.load(os.path.join(data_dir, 'y_test.npy'))
    if kwargs['no_extra_signal'] or kwargs['supervised']:
        X_extrasig = None
    else:
        X_extrasig = np.load(os.path.join(data_dir, 'X_extrasig.npy'))
    if kwargs['supervised'] or kwargs['separate_val_set']:
        X_val = np.load(os.path.join(data_dir, 'X_validation.npy'))
    else:
        X_val = None

    # The loss plots are always written below savedir, so the directory is
    # needed regardless of `save_model`; exist_ok avoids the check-then-create race.
    os.makedirs(savedir, exist_ok=True)

    # Common path prefix for the saved models and the loss plots. It is kept
    # separate from `save_model`, which must be None when no models are
    # stored: building the plot file name from `save_model` raised a
    # TypeError (None + str) whenever `save_model` was False.
    file_prefix = os.path.join(savedir, "model")
    save_model = file_prefix if kwargs['save_model'] else None

    # actual training
    loss_matris, val_loss_matris = train_n_models(
        kwargs['n_runs'], kwargs['config_file'], kwargs['epochs'], X_train, y_train, X_test, y_test,
        X_extrasig=X_extrasig, X_val=X_val, use_mjj=kwargs['use_mjj'], batch_size=kwargs['batch_size'],
        supervised=kwargs['supervised'], use_class_weights=kwargs['use_class_weights'],
        CWoLa=kwargs['CWoLa'], SR_center=kwargs['SR_center'], verbose=kwargs['verbose'],
        savedir=savedir, save_model=save_model)

    if kwargs['save_model']:
        minimum_val_loss_model_evaluation(data_dir, savedir, n_epochs=10,
                                          use_mjj=kwargs['use_mjj'],
                                          extra_signal=not kwargs['no_extra_signal'])

    # one loss plot per independent training run
    for i in range(loss_matris.shape[0]):
        plot_classifier_losses(
            loss_matris[i], val_loss_matris[i],
            savefig=file_prefix + "_run" + str(i) + "_loss_plot",
            suppress_show=True
        )

In [10]:
# Run the classifier trainings; the per-epoch losses are printed below.
train_classifier(**classifier_kwargs)

Training model nr 0...
training epoch nr 0
training loss: 0.6932780742645264
validation loss: 0.6931359171867371
training epoch nr 1
training loss: 0.6931953430175781
validation loss: 0.6930888891220093
training epoch nr 2
training loss: 0.6930838227272034
validation loss: 0.6933363676071167
training epoch nr 3
training loss: 0.6931409239768982
validation loss: 0.6930994987487793
training epoch nr 4
training loss: 0.6930557489395142
validation loss: 0.6930897235870361
training epoch nr 5
training loss: 0.6930726170539856
validation loss: 0.6930766105651855
training epoch nr 6
training loss: 0.6930475831031799
validation loss: 0.6930747628211975
training epoch nr 7
training loss: 0.6930184364318848
validation loss: 0.6932281851768494
training epoch nr 8
training loss: 0.6930239200592041
validation loss: 0.6930232048034668
training epoch nr 9
training loss: 0.6929574608802795
validation loss: 0.6930456757545471
training epoch nr 10
training loss: 0.6928954124450684
validation loss: 0.693

training loss: 0.69047611951828
validation loss: 0.6955103278160095
training epoch nr 91
training loss: 0.6904341578483582
validation loss: 0.6953031420707703
training epoch nr 92
training loss: 0.6903659105300903
validation loss: 0.69511878490448
training epoch nr 93
training loss: 0.6902026534080505
validation loss: 0.6952337622642517
training epoch nr 94
training loss: 0.6904141902923584
validation loss: 0.6950365304946899
training epoch nr 95
training loss: 0.6902703046798706
validation loss: 0.6953434944152832
training epoch nr 96
training loss: 0.6901812553405762
validation loss: 0.6951842308044434
training epoch nr 97
training loss: 0.6902047991752625
validation loss: 0.6950226426124573
training epoch nr 98
training loss: 0.6900937557220459
validation loss: 0.6955899000167847
training epoch nr 99
training loss: 0.6901094317436218
validation loss: 0.6949779987335205
Training model nr 1...
training epoch nr 0
training loss: 0.6932125687599182
validation loss: 0.693444550037384
tra

validation loss: 0.6946978569030762
training epoch nr 81
training loss: 0.6906936168670654
validation loss: 0.6947641968727112
training epoch nr 82
training loss: 0.6906795501708984
validation loss: 0.6949481964111328
training epoch nr 83
training loss: 0.6906285881996155
validation loss: 0.6949411034584045
training epoch nr 84
training loss: 0.6905679106712341
validation loss: 0.6949799060821533
training epoch nr 85
training loss: 0.6905386447906494
validation loss: 0.6950817108154297
training epoch nr 86
training loss: 0.6906304955482483
validation loss: 0.6955601572990417
training epoch nr 87
training loss: 0.6905013918876648
validation loss: 0.6947298049926758
training epoch nr 88
training loss: 0.6904920339584351
validation loss: 0.695368230342865
training epoch nr 89
training loss: 0.6903680562973022
validation loss: 0.6955840587615967
training epoch nr 90
training loss: 0.6903074383735657
validation loss: 0.6953136920928955
training epoch nr 91
training loss: 0.6904801726341248


training loss: 0.6911621689796448
validation loss: 0.6944054961204529
training epoch nr 72
training loss: 0.691191554069519
validation loss: 0.6944600939750671
training epoch nr 73
training loss: 0.6912329792976379
validation loss: 0.6939171552658081
training epoch nr 74
training loss: 0.6911283731460571
validation loss: 0.694666862487793
training epoch nr 75
training loss: 0.6910624504089355
validation loss: 0.6944803595542908
training epoch nr 76
training loss: 0.6911155581474304
validation loss: 0.6941702961921692
training epoch nr 77
training loss: 0.6912094354629517
validation loss: 0.6943863034248352
training epoch nr 78
training loss: 0.6910942196846008
validation loss: 0.6946521997451782
training epoch nr 79
training loss: 0.6911026835441589
validation loss: 0.6941227912902832
training epoch nr 80
training loss: 0.6909779906272888
validation loss: 0.6943683624267578
training epoch nr 81
training loss: 0.6910178661346436
validation loss: 0.6942960023880005
training epoch nr 82
t

validation loss: 0.6934456825256348
training epoch nr 62
training loss: 0.6916823387145996
validation loss: 0.6935083270072937
training epoch nr 63
training loss: 0.6916695833206177
validation loss: 0.6937739253044128
training epoch nr 64
training loss: 0.6917261481285095
validation loss: 0.6936299800872803
training epoch nr 65
training loss: 0.6917118430137634
validation loss: 0.6935216784477234
training epoch nr 66
training loss: 0.6916903853416443
validation loss: 0.6936642527580261
training epoch nr 67
training loss: 0.6915190815925598
validation loss: 0.6939173340797424
training epoch nr 68
training loss: 0.6916424632072449
validation loss: 0.6939753890037537
training epoch nr 69
training loss: 0.6915319561958313
validation loss: 0.6936342120170593
training epoch nr 70
training loss: 0.6915045380592346
validation loss: 0.6943835020065308
training epoch nr 71
training loss: 0.6915168762207031
validation loss: 0.6941301822662354
training epoch nr 72
training loss: 0.691563069820404


training loss: 0.6921729445457458
validation loss: 0.6931846737861633
training epoch nr 53
training loss: 0.6921829581260681
validation loss: 0.693294107913971
training epoch nr 54
training loss: 0.6920837163925171
validation loss: 0.6931566596031189
training epoch nr 55
training loss: 0.6920790076255798
validation loss: 0.6933624148368835
training epoch nr 56
training loss: 0.692138135433197
validation loss: 0.693336009979248
training epoch nr 57
training loss: 0.6919810771942139
validation loss: 0.6935259699821472
training epoch nr 58
training loss: 0.6919556260108948
validation loss: 0.6933832168579102
training epoch nr 59
training loss: 0.6919387578964233
validation loss: 0.6933124661445618
training epoch nr 60
training loss: 0.6919011473655701
validation loss: 0.6933375000953674
training epoch nr 61
training loss: 0.6919958591461182
validation loss: 0.6933783292770386
training epoch nr 62
training loss: 0.6919989585876465
validation loss: 0.6935216784477234
training epoch nr 63
tr

validation loss: 0.6933934092521667
training epoch nr 43
training loss: 0.6923066973686218
validation loss: 0.6933528780937195
training epoch nr 44
training loss: 0.6921607851982117
validation loss: 0.693278431892395
training epoch nr 45
training loss: 0.6922233700752258
validation loss: 0.6934046745300293
training epoch nr 46
training loss: 0.6921082139015198
validation loss: 0.6935862302780151
training epoch nr 47
training loss: 0.6921659708023071
validation loss: 0.6933176517486572
training epoch nr 48
training loss: 0.6921831965446472
validation loss: 0.6932965517044067
training epoch nr 49
training loss: 0.6920884847640991
validation loss: 0.6934993267059326
training epoch nr 50
training loss: 0.6920489072799683
validation loss: 0.6938296556472778
training epoch nr 51
training loss: 0.6920707821846008
validation loss: 0.6935250163078308
training epoch nr 52
training loss: 0.6920574903488159
validation loss: 0.6934195756912231
training epoch nr 53
training loss: 0.6921036839485168


training loss: 0.6923003196716309
validation loss: 0.6930958032608032
training epoch nr 34
training loss: 0.6923609375953674
validation loss: 0.6932232975959778
training epoch nr 35
training loss: 0.6923403739929199
validation loss: 0.6931245923042297
training epoch nr 36
training loss: 0.6923523545265198
validation loss: 0.6931108832359314
training epoch nr 37
training loss: 0.6923301815986633
validation loss: 0.6933255195617676
training epoch nr 38
training loss: 0.6921852231025696
validation loss: 0.693166196346283
training epoch nr 39
training loss: 0.6921879649162292
validation loss: 0.6933891177177429
training epoch nr 40
training loss: 0.6921969056129456
validation loss: 0.6931802034378052
training epoch nr 41
training loss: 0.6922004818916321
validation loss: 0.6931835412979126
training epoch nr 42
training loss: 0.6922093033790588
validation loss: 0.6932344436645508
training epoch nr 43
training loss: 0.6921924352645874
validation loss: 0.6933586597442627
training epoch nr 44


validation loss: 0.6930369138717651
training epoch nr 24
training loss: 0.692470133304596
validation loss: 0.6931178569793701
training epoch nr 25
training loss: 0.6924535632133484
validation loss: 0.6930464506149292
training epoch nr 26
training loss: 0.6924849152565002
validation loss: 0.6929125785827637
training epoch nr 27
training loss: 0.6923926472663879
validation loss: 0.6930100917816162
training epoch nr 28
training loss: 0.6923136115074158
validation loss: 0.6931660771369934
training epoch nr 29
training loss: 0.6923646926879883
validation loss: 0.6932098865509033
training epoch nr 30
training loss: 0.6923887729644775
validation loss: 0.6930358409881592
training epoch nr 31
training loss: 0.6923261284828186
validation loss: 0.6931883692741394
training epoch nr 32
training loss: 0.6922976970672607
validation loss: 0.6931703686714172
training epoch nr 33
training loss: 0.6921765208244324
validation loss: 0.6932862401008606
training epoch nr 34
training loss: 0.6922177076339722


training loss: 0.6928734183311462
validation loss: 0.6928950548171997
training epoch nr 15
training loss: 0.692744255065918
validation loss: 0.6928768754005432
training epoch nr 16
training loss: 0.6927787661552429
validation loss: 0.6929330229759216
training epoch nr 17
training loss: 0.69269198179245
validation loss: 0.6929281949996948
training epoch nr 18
training loss: 0.6926786303520203
validation loss: 0.6928935050964355
training epoch nr 19
training loss: 0.6926341652870178
validation loss: 0.6930919885635376
training epoch nr 20
training loss: 0.6926212906837463
validation loss: 0.6932684779167175
training epoch nr 21
training loss: 0.6926352977752686
validation loss: 0.6930117607116699
training epoch nr 22
training loss: 0.6925264596939087
validation loss: 0.6930029392242432
training epoch nr 23
training loss: 0.6926083564758301
validation loss: 0.6934126615524292
training epoch nr 24
training loss: 0.69260174036026
validation loss: 0.6929417252540588
training epoch nr 25
trai

validation loss: 0.693078339099884
training epoch nr 5
training loss: 0.6929639577865601
validation loss: 0.6930943131446838
training epoch nr 6
training loss: 0.6929739713668823
validation loss: 0.6930409073829651
training epoch nr 7
training loss: 0.6929610967636108
validation loss: 0.6932111382484436
training epoch nr 8
training loss: 0.6929855346679688
validation loss: 0.6929931044578552
training epoch nr 9
training loss: 0.6929386258125305
validation loss: 0.6929919123649597
training epoch nr 10
training loss: 0.6928797364234924
validation loss: 0.6929677724838257
training epoch nr 11
training loss: 0.692887008190155
validation loss: 0.6930912137031555
training epoch nr 12
training loss: 0.6928954124450684
validation loss: 0.6929931044578552
training epoch nr 13
training loss: 0.6928495168685913
validation loss: 0.6930425763130188
training epoch nr 14
training loss: 0.692820131778717
validation loss: 0.6930117011070251
training epoch nr 15
training loss: 0.6928517818450928
validat

training loss: 0.6906145811080933
validation loss: 0.694676399230957
training epoch nr 96
training loss: 0.6904985308647156
validation loss: 0.6948760747909546
training epoch nr 97
training loss: 0.6904053688049316
validation loss: 0.6950919032096863
training epoch nr 98
training loss: 0.6903219819068909
validation loss: 0.6956607103347778
training epoch nr 99
training loss: 0.6902612447738647
validation loss: 0.6952342391014099
minimum validation loss epochs: [ 8 23 21 20 14 13 12  9 11 17]
minimum validation loss epochs: [13 27 24 23 18 26 11 22 30 16]
minimum validation loss epochs: [18 17 24 16 25 12 20 14 11 19]
minimum validation loss epochs: [16  9 18 17 15 14  6 10 19  8]
minimum validation loss epochs: [15 26 19 23 17 31 16 12 10 29]
minimum validation loss epochs: [11 21 25 18 17 16 15 13 12 29]
minimum validation loss epochs: [21  8 16 15 23 12  9 18 20 19]
minimum validation loss epochs: [16 12 19 17 26 14 18  8 13 11]
minimum validation loss epochs: [10 13 14 15 18  7 17  

In [11]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import sklearn
from sklearn import metrics
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from scipy.interpolate import interp1d

from evaluation_utils import load_predictions, tprs_fprs_sics, minumum_validation_loss_ensemble, compare_on_various_runs
def full_single_evaluation(data_dir, prediction_dir, n_ensemble_epochs=10, extra_signal=True,
                           sic_range=(0,20), savefig=None, suppress_show=False, return_all=False):
    """Evaluate the ROCs and SICs of one training with minimum-validation-loss epoch picking.

    Args:
        data_dir: directory handed to ``load_predictions`` for the test data.
        prediction_dir: directory handed to ``load_predictions`` for the
            predictions and validation losses.
        n_ensemble_epochs: number of epochs passed to
            ``minumum_validation_loss_ensemble`` as ``n_epochs``.
        extra_signal: forwarded to ``load_predictions``.
        sic_range: forwarded to ``compare_on_various_runs`` as ``sic_lim``.
        savefig, suppress_show, return_all: forwarded to
            ``compare_on_various_runs``.

    Returns:
        Whatever ``compare_on_various_runs`` returns.
    """
    X_test, y_test, predictions, val_losses = load_predictions(
        data_dir, prediction_dir, extra_signal=extra_signal)

    # a second axis of length one means the ensembling was done already
    already_ensembled = predictions.shape[1] == 1
    if already_ensembled:
        ensemble_predictions = predictions
    else:
        ensemble_predictions = minumum_validation_loss_ensemble(
            predictions, val_losses, n_epochs=n_ensemble_epochs)

    tprs, fprs, _sics = tprs_fprs_sics(ensemble_predictions, y_test, X_test)

    n_runs = ensemble_predictions.shape[0]
    plot_options = dict(
        sic_lim=sic_range,
        savefig=savefig,
        only_median=False,
        continuous_colors=False,
        reduced_legend=False,
        suppress_show=suppress_show,
        return_all=return_all,
    )
    return compare_on_various_runs([tprs], [fprs], [np.zeros(n_runs)], [""], **plot_options)

In [12]:
# Evaluate the trained classifiers; data and predictions are both read from
# save_dir and 'result_SIC' below save_dir is passed as the savefig target.
# The return value is assigned to `_` so that it is not echoed by the cell.
_ = full_single_evaluation(save_dir, save_dir, n_ensemble_epochs=10,
                           extra_signal=not no_extra_signal, sic_range=(0, 20),
                           savefig=os.path.join(save_dir, 'result_SIC'))