# Notebook 02: Pipeline for supervised results

This notebook goes through the pipeline for obtaining results using the supervised method.

In [1]:
import os
import argparse
import numpy as np

from run_ANODE_training import main as train_DE
from run_classifier_data_creation import main as create_data
from run_classifier_training import main as train_classifier
from run_ANODE_evaluation import main as eval_ANODE
from evaluation_utils import full_single_evaluation, classic_ANODE_eval, minimum_val_loss_model_evaluation

In [2]:
# ---------------------------------------------------------------------------
# Notebook configuration: every tunable of the pipeline is collected here.
# Not all of these names are read by the cells visible in this notebook.
# ---------------------------------------------------------------------------
mode = 'supervised'
# Directory that is searched for the innerdata_*/outerdata_* .npy input files.
data_dir = '../separated_data'
# Output directory; also used as the input directory of the classifier training.
save_dir = 'supervised_models'
# Shift on jet mass variables to be applied.
datashift = 0.
# Shift is not correlated to the actual mjj but randomized.
random_shift = False
# Whether to apply an (ANODE paper) fiducial cut on the data (and samples).
fiducial_cut = False
# Suppress the processing of the extra signal sample.
no_extra_signal = True
verbose = False

# ANODE model config file (.yml).
DE_config_file = '../DE_MAF_model.yml'
# Number of density estimation training epochs.
DE_epochs = 100
# Batch size during density estimation training.
DE_batch_size = 256
# Skips the density estimation (loads existing files instead).
# NOTE(review): the name looks like a typo for `DE_skip`; confirm before renaming.
DF_skip = False
# Turns off the logit transform in the density estimator.
DE_no_logit = False

# File name for the density estimator.
DE_file_name = 'my_ANODE_model'

# Classifier model config file (.yml).
cf_config_file = '../classifier.yml'

# Number of classifier training epochs.
cf_epochs = 100
# Number of samples to be generated. Currently the samples will be cut down to match data proportion.
cf_n_samples = 130000

# Sample the conditional from a KDE fit rather than a uniform distribution.
cf_realistic_conditional = False
# Bandwidth of the KDE fit (used when realistic_conditional is selected).
cf_KDE_bandwidth = 0.01
# Add the full number of samples to the training set rather than mixing it in equal parts with data.
cf_oversampling = True
# Turns off the logit transform in the classifier.
cf_no_logit = True
# Space-separated list of pre-sampled npy files of physical variables if the sampling has been
# done externally. The format is (mjj, mj1, dmj, tau21_1, tau21_2).
cf_external_samples = ""
# Lower boundary of signal region.
cf_SR_min = 3.3
# Upper boundary of signal region.
cf_SR_max = 3.7
# Number of independent classifier training runs.
cf_n_runs = 10
# Batch size during classifier training.
cf_batch_size = 128
# Use the conditional variable as classifier input during training.
cf_use_mjj = False
# Weight the classes according to their occurrence in the training set.
# Necessary if the training set was intentionally oversampled.
cf_use_class_weights = False
# Central value of signal region. Must only be given for using CWoLa with weights.
cf_SR_center = 3.5
# Make use of extra background (for supervised and idealized AD).
cf_extra_bkg = True
# Define a separate validation set to pick the classifier epochs.
cf_separate_val_set = True
# Save the tensorflow model after each epoch instead of saving predictions.
cf_save_model = True
# Skips the creation of the classifier dataset (loads existing files instead).
cf_skip_create = False
# Skips the training of the classifier (loads existing files instead).
cf_skip_train = False

In [3]:
# Keyword arguments for ``create_data``. The three mode switches are fixed to
# the fully supervised setting; everything else is forwarded from the
# configuration cell.
data_creation_kwargs = dict(
    # input / output locations and general options
    savedir=save_dir,
    datashift=datashift,
    data_dir=data_dir,
    random_shift=random_shift,
    config_file=DE_config_file,
    verbose=verbose,
    fiducial_cut=fiducial_cut,
    # sampling options
    n_samples=cf_n_samples,
    realistic_conditional=cf_realistic_conditional,
    KDE_bandwidth=cf_KDE_bandwidth,
    oversampling=cf_oversampling,
    no_extra_signal=no_extra_signal,
    # mode switches: exactly the supervised mode is active
    CWoLa=False,
    supervised=True,
    idealized_AD=False,
    # preprocessing of the classifier inputs / of the density estimator
    no_logit=cf_no_logit,
    no_logit_trained=DE_no_logit,
    external_samples=cf_external_samples,
    # signal region window
    SR_min=cf_SR_min,
    SR_max=cf_SR_max,
    extra_bkg=cf_extra_bkg,
    separate_val_set=cf_separate_val_set,
    # no density-estimator checkpoints are passed in the supervised mode
    ANODE_models=[],
)

In [4]:
import torch
import numpy as np
from data_handler import LHCORD_data_handler, sample_handler, mix_data_samples, plot_data_sample_comparison
from density_estimator import DensityEstimator

def _as_path_list(value):
    """Normalise a path argument to a list of path strings.

    Accepts ``None``, a whitespace-separated string (the format used in the
    configuration cell) or any iterable of paths.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return value.split()
    return list(value)


def create_data(**kwargs):
    """Create the classifier data sets and the data/sample comparison plots.

    All options are passed as keyword arguments. The keys read here are:

    - ``data_dir``: directory holding the ``innerdata_*`` / ``outerdata_*``
      ``.npy`` files. If it contains ``innerdata_val.npy`` the finer
      train/val/test file layout is used.
    - ``savedir``: output directory (created if missing); handed to
      ``mix_data_samples`` and used as prefix for the comparison plots.
    - ``supervised``, ``idealized_AD``, ``CWoLa``: mode switches.
    - ``external_samples``: externally produced sample files, either a list of
      paths or a whitespace-separated string.
    - ``ANODE_models``: density-estimator checkpoints (list of paths or a
      whitespace-separated string). Required only if none of the three modes
      above is active and no external samples are given.
    - ``config_file``, ``verbose``: forwarded to ``DensityEstimator``.
    - ``datashift``, ``random_shift``: optional shift applied to the data.
    - ``fiducial_cut``, ``no_logit``, ``no_logit_trained``: preprocessing
      options (``no_logit_trained`` for the density-estimator space,
      ``no_logit`` for the classifier inputs).
    - ``n_samples``, ``realistic_conditional``, ``KDE_bandwidth``,
      ``SR_min``, ``SR_max``: sampling options.
    - ``oversampling``, ``separate_val_set``, ``extra_bkg``,
      ``no_extra_signal``: data-mixing / bookkeeping options.

    Raises:
        AssertionError: if samples would have to be drawn from ANODE models
            but no model path was given.
    """
    external_sample_paths = _as_path_list(kwargs['external_samples'])
    anode_model_paths = _as_path_list(kwargs['ANODE_models'])

    # Emptiness is checked on the normalised lists so that both "" and [] count
    # as "nothing given" (a plain comparison against "" never matches a list).
    samples_from_anode = not (kwargs['supervised'] or kwargs['idealized_AD'] or kwargs['CWoLa']
                              or external_sample_paths)
    assert anode_model_paths or not samples_from_anode, (
        "ANODE models need to be given unless CWoLa, supervised, idealized_AD or"
        " external sampling is used.")

    # the plots below are written into savedir, so make sure it exists
    os.makedirs(kwargs['savedir'], exist_ok=True)

    # selecting appropriate device
    CUDA = torch.cuda.is_available()
    print("cuda available:", CUDA)
    device = torch.device("cuda:0" if CUDA else "cpu")

    # checking for data separation
    data_files = os.listdir(kwargs['data_dir'])
    finer_data_split = "innerdata_val.npy" in data_files

    if finer_data_split:
        innerdata_train_path = [os.path.join(kwargs['data_dir'], 'innerdata_train.npy')]
        innerdata_val_path = [os.path.join(kwargs['data_dir'], 'innerdata_val.npy')]
        innerdata_test_path = [os.path.join(kwargs['data_dir'], 'innerdata_test.npy')]
        if "innerdata_extrabkg_test.npy" in data_files:
            innerdata_test_path.append(os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_test.npy'))
        extrasig_path = None
        if kwargs['supervised']:
            # supervised training uses only the extra signal / extra background files
            innerdata_train_path = [
                os.path.join(kwargs['data_dir'], 'innerdata_extrasig_train.npy'),
                os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_train.npy')]
            innerdata_val_path = [
                os.path.join(kwargs['data_dir'], 'innerdata_extrasig_val.npy'),
                os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_val.npy')]
            extra_bkg = None
        elif kwargs['idealized_AD']:
            extra_bkg = [os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_train.npy'),
                         os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_val.npy')]
        else:
            extra_bkg = None

    else:
        innerdata_train_path = os.path.join(kwargs['data_dir'], 'innerdata_train.npy')
        extrasig_path = os.path.join(kwargs['data_dir'], 'innerdata_extrasig.npy')
        if kwargs['extra_bkg']:
            extra_bkg = os.path.join(kwargs['data_dir'], 'innerdata_extrabkg.npy')
        else:
            extra_bkg = None
        innerdata_val_path = None
        innerdata_test_path = os.path.join(kwargs['data_dir'], 'innerdata_test.npy')

    # data preprocessing
    data = LHCORD_data_handler(innerdata_train_path,
                               innerdata_test_path,
                               os.path.join(kwargs['data_dir'], 'outerdata_train.npy'),
                               os.path.join(kwargs['data_dir'], 'outerdata_test.npy'),
                               extrasig_path,
                               inner_extrabkg_path=extra_bkg,
                               inner_val_path=innerdata_val_path,
                               batch_size=256,
                               device=device)
    if kwargs['datashift'] != 0:
        print("applying a datashift of", kwargs['datashift'])
        data.shift_data(kwargs['datashift'], constant_shift=False, random_shift=kwargs['random_shift'],
                        shift_mj1=True, shift_dm=True, additional_shift=False)

    # samples are only drawn in the (non-CWoLa) density-estimation based modes
    needs_samples = not kwargs['supervised'] and not kwargs['idealized_AD']

    if kwargs['CWoLa']:
        # data preprocessing
        samples = None
        data.preprocess_CWoLa_data(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                   outer_range=(kwargs['SR_min']-0.2, kwargs['SR_max']+0.2))

    else:
        # data preprocessing in the space the density estimator was trained in
        data.preprocess_ANODE_data(fiducial_cut=kwargs['fiducial_cut'],
                                   no_logit=kwargs['no_logit_trained'],
                                   no_mean_shift=kwargs['no_logit_trained'])

        # model instantiation
        if external_sample_paths:
            model_list = None
            loaded_samples = [np.load(sample_path) for sample_path in external_sample_paths]
            external_sample = np.concatenate(loaded_samples)
        else:
            model_list = []
            for model_path in anode_model_paths:
                anode = DensityEstimator(kwargs['config_file'],
                                         eval_mode=True,
                                         load_path=model_path,
                                         device=device, verbose=kwargs['verbose'],
                                         bound=kwargs['no_logit_trained'])
                model_list.append(anode.model)
            external_sample = None

        # generate samples
        if needs_samples:
            uniform_cond = not kwargs['realistic_conditional']
            samples = sample_handler(model_list, kwargs['n_samples'], data, cond_min=kwargs['SR_min'],
                                     cond_max=kwargs['SR_max'], uniform_cond=uniform_cond,
                                     external_sample=external_sample,
                                     device=device, no_logit=kwargs['no_logit_trained'],
                                     no_mean_shift=kwargs['no_logit_trained'],
                                     KDE_bandwidth=kwargs['KDE_bandwidth'])
        else:
            samples = None

        # redo data preprocessing if the classifier should not use logit but ANODE did.
        # This step must use the classifier setting (`no_logit`); repeating it with
        # `no_logit_trained` would make the `no_logit` option a no-op in this branch.
        data.preprocess_ANODE_data(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                   no_mean_shift=kwargs['no_logit'])

        # sample preprocessing, in the same space as the re-processed data
        if needs_samples:
            samples.preprocess_samples(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                       no_mean_shift=kwargs['no_logit'])

    # sample mixing
    X_train, y_train, X_test, y_test, X_extrasig, y_extrasig = mix_data_samples(
        data, samples_handler=samples, oversampling=kwargs['oversampling'],
        savedir=kwargs['savedir'], CWoLa=kwargs['CWoLa'], supervised=kwargs['supervised'],
        idealized_AD=kwargs['idealized_AD'], separate_val_set=kwargs['separate_val_set'] or finer_data_split)

    # sanity checks
    if not kwargs['CWoLa'] and needs_samples:
        samples.sanity_check(savefig=os.path.join(kwargs['savedir'], "sanity_check"), suppress_show=True)
        samples.sanity_check_after_cuts(savefig=os.path.join(kwargs['savedir'], "sanity_check_cuts"),
                                        suppress_show=True)

    has_val_set = kwargs['supervised'] or kwargs['separate_val_set'] or finer_data_split
    if has_val_set:
        # in these modes the fifth/sixth return slot of mix_data_samples is used as validation set
        X_val = X_extrasig
        if kwargs['supervised']:
            # supervised: the plotted labels are taken from the last column of X
            y_train = X_train[:, -1]
            y_test = X_test[:, -1]
            y_val = X_val[:, -1]
        else:
            y_val = X_val[:, -2]
        plot_data_sample_comparison(X_val, y_val, title="validation set",
                                    savefig=os.path.join(kwargs['savedir'],
                                                         "data_sample_comparison_val"),
                                    suppress_show=True)

    plot_data_sample_comparison(X_train, y_train, title="training set",
                                savefig=os.path.join(kwargs['savedir'], "data_sample_comparison_train"),
                                suppress_show=True)
    plot_data_sample_comparison(X_test, y_test, title="test set",
                                savefig=os.path.join(kwargs['savedir'], "data_sample_comparison_test"),
                                suppress_show=True)

    print("number of training data =", X_train.shape[0])
    print("number of test data =", X_test.shape[0])
    if not kwargs['no_extra_signal']:
        if has_val_set:
            print("number of validation data =", X_val.shape[0])
        elif extrasig_path is not None:
            print("number of extra signal data =", X_extrasig.shape[0])

In [5]:
# Run the data creation step with the supervised settings assembled above;
# outputs (comparison plots and, presumably, the mixed arrays) go to `save_dir`.
create_data(**data_creation_kwargs)

cuda available: False


  return n/db/n.sum(), bin_edges
  return n/db/n.sum(), bin_edges
  return n/db/n.sum(), bin_edges


number of training data = 163668
number of test data = 359949


In [6]:
# Keyword arguments for ``train_classifier``. The training reads its inputs
# from ``data_dir`` and writes to ``savedir``; both point at ``save_dir`` here.
classifier_kwargs = dict(
    config_file=cf_config_file,
    data_dir=save_dir,
    savedir=save_dir,
    verbose=verbose,
    epochs=cf_epochs,
    n_runs=cf_n_runs,
    batch_size=cf_batch_size,
    no_extra_signal=no_extra_signal,
    use_mjj=cf_use_mjj,
    supervised=True,
    # class weights are switched on automatically when oversampling is used
    use_class_weights=cf_oversampling or cf_use_class_weights,
    CWoLa=False,
    SR_center=cf_SR_center,
    save_model=cf_save_model,
    separate_val_set=cf_separate_val_set,
)

In [7]:
from classifier_training_utils import train_n_models, plot_classifier_losses
from evaluation_utils import minimum_val_loss_model_evaluation
import matplotlib as mpl

def train_classifier(**kwargs):
    """Train the classifier ensemble and plot the loss curves of every run.

    All options are passed as keyword arguments. The keys read here are:

    - ``data_dir``: directory containing ``X_train.npy``, ``X_test.npy``,
      ``y_train.npy``, ``y_test.npy`` and, depending on the mode,
      ``X_extrasig.npy`` / ``X_validation.npy``.
    - ``savedir``: output directory (created if missing).
    - ``save_model``: if true, ``<savedir>/model`` is passed to
      ``train_n_models`` as model prefix and
      ``minimum_val_loss_model_evaluation`` is run afterwards.
    - ``no_extra_signal``, ``supervised``, ``separate_val_set``: decide which
      optional arrays are loaded.
    - ``n_runs``, ``config_file``, ``epochs``, ``batch_size``, ``use_mjj``,
      ``use_class_weights``, ``CWoLa``, ``SR_center``, ``verbose``: forwarded
      to ``train_n_models``.

    The loss plots are always written to
    ``<savedir>/model_run<i>_loss_plot``, independent of ``save_model``.
    """
    data_dir = kwargs['data_dir']
    savedir = kwargs['savedir']

    # loading the data
    # TODO get rid of the y's since the information is fully included in X
    X_train = np.load(os.path.join(data_dir, 'X_train.npy'))
    X_test = np.load(os.path.join(data_dir, 'X_test.npy'))
    y_train = np.load(os.path.join(data_dir, 'y_train.npy'))
    y_test = np.load(os.path.join(data_dir, 'y_test.npy'))
    if kwargs['no_extra_signal'] or kwargs['supervised']:
        X_extrasig = None
    else:
        X_extrasig = np.load(os.path.join(data_dir, 'X_extrasig.npy'))
    if kwargs['supervised'] or kwargs['separate_val_set']:
        X_val = np.load(os.path.join(data_dir, 'X_validation.npy'))
    else:
        X_val = None

    # The output directory is needed for the loss plots in every configuration,
    # not only when the models are saved.
    os.makedirs(savedir, exist_ok=True)
    output_prefix = os.path.join(savedir, "model")
    save_model = output_prefix if kwargs['save_model'] else None

    # actual training
    loss_matris, val_loss_matris = train_n_models(
        kwargs['n_runs'], kwargs['config_file'], kwargs['epochs'], X_train, y_train, X_test, y_test,
        X_extrasig=X_extrasig, X_val=X_val, use_mjj=kwargs['use_mjj'], batch_size=kwargs['batch_size'],
        supervised=kwargs['supervised'], use_class_weights=kwargs['use_class_weights'],
        CWoLa=kwargs['CWoLa'], SR_center=kwargs['SR_center'], verbose=kwargs['verbose'],
        savedir=savedir, save_model=save_model)

    if kwargs['save_model']:
        minimum_val_loss_model_evaluation(data_dir, savedir, n_epochs=10,
                                          use_mjj=kwargs['use_mjj'],
                                          extra_signal=not kwargs['no_extra_signal'])

    # Use the path prefix (never the possibly-None `save_model`) for the plot
    # names, so plotting also works when the models are not saved.
    for i in range(loss_matris.shape[0]):
        plot_classifier_losses(
            loss_matris[i], val_loss_matris[i],
            savefig=output_prefix + "_run" + str(i) + "_loss_plot",
            suppress_show=True
        )

In [8]:
# Train the classifier runs configured above (`cf_n_runs` runs of `cf_epochs` epochs each).
train_classifier(**classifier_kwargs)

Training model nr 0...
Running a fully supervised training. Sig/bkg labels will be known!
training epoch nr 0
training loss: 0.2777864336967468
validation loss: 0.23900023102760315
training epoch nr 1
training loss: 0.23200424015522003
validation loss: 0.2299966961145401
training epoch nr 2
training loss: 0.22715002298355103
validation loss: 0.22659741342067719
training epoch nr 3
training loss: 0.2241223156452179
validation loss: 0.2269790768623352
training epoch nr 4
training loss: 0.22208960354328156
validation loss: 0.22076523303985596
training epoch nr 5
training loss: 0.22081901133060455
validation loss: 0.2215927392244339
training epoch nr 6
training loss: 0.21962608397006989
validation loss: 0.2197658121585846
training epoch nr 7
training loss: 0.21914184093475342
validation loss: 0.21811790764331818
training epoch nr 8
training loss: 0.2185097187757492
validation loss: 0.21846142411231995
training epoch nr 9
training loss: 0.21782803535461426
validation loss: 0.219537511467933

training loss: 0.20654766261577606
validation loss: 0.21912983059883118
training epoch nr 89
training loss: 0.20662467181682587
validation loss: 0.21938170492649078
training epoch nr 90
training loss: 0.20639172196388245
validation loss: 0.2196572721004486
training epoch nr 91
training loss: 0.20612020790576935
validation loss: 0.22356414794921875
training epoch nr 92
training loss: 0.2061387449502945
validation loss: 0.22078698873519897
training epoch nr 93
training loss: 0.20614902675151825
validation loss: 0.21900416910648346
training epoch nr 94
training loss: 0.20596162974834442
validation loss: 0.22175095975399017
training epoch nr 95
training loss: 0.20608656108379364
validation loss: 0.22279411554336548
training epoch nr 96
training loss: 0.20558755099773407
validation loss: 0.22289323806762695
training epoch nr 97
training loss: 0.2058422565460205
validation loss: 0.21990355849266052
training epoch nr 98
training loss: 0.2055191695690155
validation loss: 0.22080416977405548
tr

training loss: 0.20786093175411224
validation loss: 0.21910080313682556
training epoch nr 77
training loss: 0.20870114862918854
validation loss: 0.22002992033958435
training epoch nr 78
training loss: 0.20790426433086395
validation loss: 0.21824993193149567
training epoch nr 79
training loss: 0.20784704387187958
validation loss: 0.21993288397789001
training epoch nr 80
training loss: 0.2079704999923706
validation loss: 0.21782156825065613
training epoch nr 81
training loss: 0.2077990025281906
validation loss: 0.220712810754776
training epoch nr 82
training loss: 0.20784153044223785
validation loss: 0.21989117562770844
training epoch nr 83
training loss: 0.20735687017440796
validation loss: 0.22185219824314117
training epoch nr 84
training loss: 0.2076929658651352
validation loss: 0.2220170497894287
training epoch nr 85
training loss: 0.20757025480270386
validation loss: 0.22066150605678558
training epoch nr 86
training loss: 0.20745868980884552
validation loss: 0.22109712660312653
trai

training loss: 0.20860816538333893
validation loss: 0.21706876158714294
training epoch nr 65
training loss: 0.20817753672599792
validation loss: 0.21814817190170288
training epoch nr 66
training loss: 0.2086290866136551
validation loss: 0.21706177294254303
training epoch nr 67
training loss: 0.2082018256187439
validation loss: 0.2188032567501068
training epoch nr 68
training loss: 0.2078743427991867
validation loss: 0.21948829293251038
training epoch nr 69
training loss: 0.20790313184261322
validation loss: 0.2184378057718277
training epoch nr 70
training loss: 0.20772239565849304
validation loss: 0.219774529337883
training epoch nr 71
training loss: 0.20763985812664032
validation loss: 0.2191118746995926
training epoch nr 72
training loss: 0.2079009860754013
validation loss: 0.2176922857761383
training epoch nr 73
training loss: 0.20744067430496216
validation loss: 0.21881310641765594
training epoch nr 74
training loss: 0.20741592347621918
validation loss: 0.21930797398090363
training

training loss: 0.20982994139194489
validation loss: 0.2175634205341339
training epoch nr 53
training loss: 0.20980796217918396
validation loss: 0.21670004725456238
training epoch nr 54
training loss: 0.20936807990074158
validation loss: 0.21933333575725555
training epoch nr 55
training loss: 0.20963183045387268
validation loss: 0.21619993448257446
training epoch nr 56
training loss: 0.20958445966243744
validation loss: 0.2187022715806961
training epoch nr 57
training loss: 0.2093384861946106
validation loss: 0.21866875886917114
training epoch nr 58
training loss: 0.20926526188850403
validation loss: 0.2170756608247757
training epoch nr 59
training loss: 0.20905430614948273
validation loss: 0.21874749660491943
training epoch nr 60
training loss: 0.20923812687397003
validation loss: 0.22044776380062103
training epoch nr 61
training loss: 0.20884515345096588
validation loss: 0.22052566707134247
training epoch nr 62
training loss: 0.20900671184062958
validation loss: 0.2173640877008438
tra

training loss: 0.21189184486865997
validation loss: 0.2209012806415558
training epoch nr 41
training loss: 0.21172957122325897
validation loss: 0.2167162001132965
training epoch nr 42
training loss: 0.21158365905284882
validation loss: 0.21828989684581757
training epoch nr 43
training loss: 0.2118532359600067
validation loss: 0.21655015647411346
training epoch nr 44
training loss: 0.21188542246818542
validation loss: 0.21646150946617126
training epoch nr 45
training loss: 0.2117512822151184
validation loss: 0.2159683257341385
training epoch nr 46
training loss: 0.21155601739883423
validation loss: 0.2164163887500763
training epoch nr 47
training loss: 0.2111426740884781
validation loss: 0.21858753263950348
training epoch nr 48
training loss: 0.21110008656978607
validation loss: 0.2165045142173767
training epoch nr 49
training loss: 0.21164579689502716
validation loss: 0.21712768077850342
training epoch nr 50
training loss: 0.2111506164073944
validation loss: 0.21729618310928345
trainin

training loss: 0.21311110258102417
validation loss: 0.21896104514598846
training epoch nr 29
training loss: 0.21303772926330566
validation loss: 0.21822889149188995
training epoch nr 30
training loss: 0.21303631365299225
validation loss: 0.21654409170150757
training epoch nr 31
training loss: 0.21290573477745056
validation loss: 0.21731796860694885
training epoch nr 32
training loss: 0.21283023059368134
validation loss: 0.21817468106746674
training epoch nr 33
training loss: 0.21270258724689484
validation loss: 0.22134211659431458
training epoch nr 34
training loss: 0.21250277757644653
validation loss: 0.21754556894302368
training epoch nr 35
training loss: 0.21260994672775269
validation loss: 0.21627819538116455
training epoch nr 36
training loss: 0.2121584713459015
validation loss: 0.21633462607860565
training epoch nr 37
training loss: 0.2119145393371582
validation loss: 0.21766172349452972
training epoch nr 38
training loss: 0.21246108412742615
validation loss: 0.21560470759868622


training loss: 0.21457016468048096
validation loss: 0.21700550615787506
training epoch nr 17
training loss: 0.21482600271701813
validation loss: 0.2171909064054489
training epoch nr 18
training loss: 0.21438559889793396
validation loss: 0.2178259789943695
training epoch nr 19
training loss: 0.21381118893623352
validation loss: 0.2186523675918579
training epoch nr 20
training loss: 0.2137349545955658
validation loss: 0.21962516009807587
training epoch nr 21
training loss: 0.2141512930393219
validation loss: 0.21829918026924133
training epoch nr 22
training loss: 0.21340319514274597
validation loss: 0.21759635210037231
training epoch nr 23
training loss: 0.21397462487220764
validation loss: 0.21690234541893005
training epoch nr 24
training loss: 0.21389532089233398
validation loss: 0.2153298407793045
training epoch nr 25
training loss: 0.2131940871477127
validation loss: 0.21701224148273468
training epoch nr 26
training loss: 0.2131909877061844
validation loss: 0.21601392328739166
traini

training loss: 0.2214488536119461
validation loss: 0.2219044715166092
training epoch nr 5
training loss: 0.2200639247894287
validation loss: 0.21971647441387177
training epoch nr 6
training loss: 0.2195703685283661
validation loss: 0.21790678799152374
training epoch nr 7
training loss: 0.2186463326215744
validation loss: 0.22052302956581116
training epoch nr 8
training loss: 0.21793444454669952
validation loss: 0.21843716502189636
training epoch nr 9
training loss: 0.21698947250843048
validation loss: 0.21837688982486725
training epoch nr 10
training loss: 0.2168710082769394
validation loss: 0.21888960897922516
training epoch nr 11
training loss: 0.216077983379364
validation loss: 0.21862490475177765
training epoch nr 12
training loss: 0.21610337495803833
validation loss: 0.21789583563804626
training epoch nr 13
training loss: 0.21587032079696655
validation loss: 0.22023963928222656
training epoch nr 14
training loss: 0.21538807451725006
validation loss: 0.2161852866411209
training epo

training loss: 0.2057347446680069
validation loss: 0.2235286384820938
training epoch nr 94
training loss: 0.2060060054063797
validation loss: 0.22063715755939484
training epoch nr 95
training loss: 0.20575805008411407
validation loss: 0.22072423994541168
training epoch nr 96
training loss: 0.2058558166027069
validation loss: 0.22181719541549683
training epoch nr 97
training loss: 0.20571641623973846
validation loss: 0.22086039185523987
training epoch nr 98
training loss: 0.2052917778491974
validation loss: 0.2220257669687271
training epoch nr 99
training loss: 0.20586198568344116
validation loss: 0.22238734364509583
Training model nr 8...
Running a fully supervised training. Sig/bkg labels will be known!
training epoch nr 0
training loss: 0.2683088481426239
validation loss: 0.23525655269622803
training epoch nr 1
training loss: 0.22936506569385529
validation loss: 0.2262234389781952
training epoch nr 2
training loss: 0.2238752394914627
validation loss: 0.22170905768871307
training epoc

training loss: 0.20738515257835388
validation loss: 0.2205781489610672
training epoch nr 82
training loss: 0.20726615190505981
validation loss: 0.22012102603912354
training epoch nr 83
training loss: 0.2067732959985733
validation loss: 0.22185204923152924
training epoch nr 84
training loss: 0.20722928643226624
validation loss: 0.22160471975803375
training epoch nr 85
training loss: 0.2067716419696808
validation loss: 0.2204953283071518
training epoch nr 86
training loss: 0.20669607818126678
validation loss: 0.22109541296958923
training epoch nr 87
training loss: 0.20698295533657074
validation loss: 0.2215975821018219
training epoch nr 88
training loss: 0.20681658387184143
validation loss: 0.2212776094675064
training epoch nr 89
training loss: 0.20656268298625946
validation loss: 0.2221362292766571
training epoch nr 90
training loss: 0.20615017414093018
validation loss: 0.225407674908638
training epoch nr 91
training loss: 0.20641562342643738
validation loss: 0.22149650752544403
trainin

training loss: 0.20866726338863373
validation loss: 0.21826374530792236
training epoch nr 70
training loss: 0.20853741466999054
validation loss: 0.2189447283744812
training epoch nr 71
training loss: 0.20843107998371124
validation loss: 0.22029878199100494
training epoch nr 72
training loss: 0.20848870277404785
validation loss: 0.21855010092258453
training epoch nr 73
training loss: 0.20830926299095154
validation loss: 0.2218886762857437
training epoch nr 74
training loss: 0.2082894891500473
validation loss: 0.2184862643480301
training epoch nr 75
training loss: 0.2084088772535324
validation loss: 0.218919575214386
training epoch nr 76
training loss: 0.20824715495109558
validation loss: 0.21870948374271393
training epoch nr 77
training loss: 0.20800325274467468
validation loss: 0.21969862282276154
training epoch nr 78
training loss: 0.2078389823436737
validation loss: 0.2199505865573883
training epoch nr 79
training loss: 0.2079342007637024
validation loss: 0.2201385796070099
training 

In [5]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import sklearn
from sklearn import metrics
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from scipy.interpolate import interp1d

from evaluation_utils import load_predictions, tprs_fprs_sics, minumum_validation_loss_ensemble, compare_on_various_runs
def full_single_evaluation(data_dir, prediction_dir, n_ensemble_epochs=10, extra_signal=True,
                           sic_range=(0,20), savefig=None, suppress_show=False, return_all=False):
    """Evaluate the ROCs and SICs of one training with minimum-validation-loss epoch picking.

    Args:
        data_dir: forwarded to ``load_predictions`` as data location.
        prediction_dir: forwarded to ``load_predictions`` as prediction location.
        n_ensemble_epochs: number of epochs handed to
            ``minumum_validation_loss_ensemble`` (only used if the loaded
            predictions are not ensembled yet).
        extra_signal: forwarded to ``load_predictions``.
        sic_range: passed to ``compare_on_various_runs`` as ``sic_lim``.
        savefig: forwarded to ``compare_on_various_runs``.
        suppress_show: forwarded to ``compare_on_various_runs``.
        return_all: forwarded to ``compare_on_various_runs``.

    Returns:
        Whatever ``compare_on_various_runs`` returns for this single training.
    """
    X_test, y_test, all_predictions, val_losses = load_predictions(
        data_dir, prediction_dir, extra_signal=extra_signal)

    # a second axis of length one means the ensembling has been done already
    already_ensembled = all_predictions.shape[1] == 1
    ensemble_predictions = (
        all_predictions if already_ensembled
        else minumum_validation_loss_ensemble(
            all_predictions, val_losses, n_epochs=n_ensemble_epochs))

    # the SIC values are returned as well but not needed below
    tprs, fprs, sics = tprs_fprs_sics(ensemble_predictions, y_test, X_test)

    n_entries = ensemble_predictions.shape[0]
    plot_options = dict(
        sic_lim=sic_range, savefig=savefig, only_median=False, continuous_colors=False,
        reduced_legend=False, suppress_show=suppress_show, return_all=return_all)
    return compare_on_various_runs(
        [tprs], [fprs], [np.zeros(n_entries)], [""], **plot_options)

In [6]:
# Evaluate the trained runs; predictions and data are both read from `save_dir`
# and the figure path prefix is `<save_dir>/result_SIC`. The return value is
# assigned to `_` only to keep it out of the cell output.
_ = full_single_evaluation(save_dir, save_dir, n_ensemble_epochs=10,
                           extra_signal=not no_extra_signal, sic_range=(0, 20),
                           savefig=os.path.join(save_dir, 'result_SIC'))