# Notebook 02: Pipeline for CWoLa results

This notebook goes through the pipeline for obtaining results using the idealized CWoLa (Classification Without Labels) method.

In [1]:
import os
import argparse
import numpy as np

from run_ANODE_training import main as train_DE
from run_classifier_data_creation import main as create_data
from run_classifier_training import main as train_classifier
from run_ANODE_evaluation import main as eval_ANODE
from evaluation_utils import full_single_evaluation, classic_ANODE_eval, minimum_val_loss_model_evaluation

In [2]:
# ---------------------------------------------------------------------------
# Configuration for the CWoLa pipeline. Every tunable value used by the cells
# below is defined here.
# ---------------------------------------------------------------------------
mode = 'CWoLa'
data_dir = '../separated_data'
save_dir = 'CWoLa_models'
# Shift to be applied to the jet mass variables (0 disables the shift).
datashift = 0.
# If True, the shift is randomized instead of being correlated to the actual mjj.
random_shift = False
# Whether to apply an (ANODE paper) fiducial cut on the data (and samples).
fiducial_cut = False
# Suppress the processing of the extra signal sample.
no_extra_signal = True
verbose = False

# ANODE model config file (.yml).
DE_config_file = '../DE_MAF_model.yml'
# Number of density estimation training epochs.
DE_epochs = 100
# Batch size during density estimation training.
DE_batch_size = 256
# Skips the density estimation (loads existing files instead).
# NOTE(review): the name looks like a typo for `DE_skip` (all other density
# estimator settings use the `DE_` prefix) -- confirm before renaming.
DF_skip = False
# Turns off the logit transform in the density estimator.
DE_no_logit = False

# File name for the density estimator.
DE_file_name = 'my_ANODE_model'

# Classifier model config file (.yml).
cf_config_file = '../classifier.yml'

# Number of classifier training epochs.
cf_epochs = 100
# Number of samples to be generated. Currently the samples will be cut down to match data proportion.
cf_n_samples = 130000

# Sample the conditional from a KDE fit rather than a uniform distribution.
cf_realistic_conditional = False
# Bandwidth of the KDE fit (used when realistic_conditional is selected).
cf_KDE_bandwidth = 0.01
# Add the full number of samples to the training set rather than mixing it in equal parts with data.
cf_oversampling = True
# Turns off the logit transform in the classifier.
cf_no_logit = True
# Space-separated list of pre-sampled npy files of physical variables if the sampling has been
# done externally. The format is (mjj, mj1, dmj, tau21_1, tau21_2).
cf_external_samples = ""
# Lower boundary of signal region.
cf_SR_min = 3.3
# Upper boundary of signal region.
cf_SR_max = 3.7
# Number of independent classifier training runs.
cf_n_runs = 10
# Batch size during classifier training.
cf_batch_size = 128
# Use the conditional variable as classifier input during training.
cf_use_mjj = False
# Weight the classes according to their occurrence in the training set.
# Necessary if the training set was intentionally oversampled.
cf_use_class_weights = True
# Central value of signal region. Must only be given for using CWoLa with weights.
cf_SR_center = 3.5
# Make use of extra background (for supervised and idealized AD).
cf_extra_bkg = False
# Define a separate validation set to pick the classifier epochs.
cf_separate_val_set = True
# Save the tensorflow model after each epoch instead of saving predictions.
cf_save_model = True
# Skips the creation of the classifier dataset (loads existing files instead).
cf_skip_create = False
# Skips the training of the classifier (loads existing files instead).
cf_skip_train = False

In [3]:
# Keyword arguments forwarded to `create_data`. The three mode switches are
# fixed so that only the CWoLa branch of the data creation is taken.
data_creation_kwargs = dict(
    # locations and general preprocessing options
    savedir=save_dir,
    datashift=datashift,
    data_dir=data_dir,
    random_shift=random_shift,
    config_file=DE_config_file,
    verbose=verbose,
    fiducial_cut=fiducial_cut,
    # sampling options
    n_samples=cf_n_samples,
    realistic_conditional=cf_realistic_conditional,
    KDE_bandwidth=cf_KDE_bandwidth,
    oversampling=cf_oversampling,
    no_extra_signal=no_extra_signal,
    # mode switches
    CWoLa=True,
    supervised=False,
    idealized_AD=False,
    # logit settings of the classifier and of the trained density estimator
    no_logit=cf_no_logit,
    no_logit_trained=DE_no_logit,
    external_samples=cf_external_samples,
    # signal region boundaries
    SR_min=cf_SR_min,
    SR_max=cf_SR_max,
    extra_bkg=cf_extra_bkg,
    separate_val_set=cf_separate_val_set,
    # no density estimator models are passed in this notebook
    ANODE_models=[],
)

In [4]:
import torch
import numpy as np
from data_handler import LHCORD_data_handler, sample_handler, mix_data_samples, plot_data_sample_comparison
from density_estimator import DensityEstimator

def _as_path_list(paths):
    """Return ``paths`` as a list of path strings.

    ``None`` gives an empty list, a string is split on whitespace (the
    space-separated form described for the external samples setting), and any
    other iterable is copied into a list.
    """
    if paths is None:
        return []
    if isinstance(paths, str):
        return paths.split()
    return list(paths)


def create_data(**kwargs):
    """Build the classifier datasets and the accompanying comparison plots.

    All settings are read from keyword arguments. Keys used:
    ``data_dir``, ``savedir``, ``datashift``, ``random_shift``, ``fiducial_cut``,
    ``CWoLa``, ``supervised``, ``idealized_AD``, ``no_logit``, ``no_logit_trained``,
    ``SR_min``, ``SR_max``, ``extra_bkg``, ``separate_val_set``, ``oversampling``,
    ``no_extra_signal``, ``external_samples``, ``ANODE_models``, ``config_file``,
    ``verbose``, ``n_samples``, ``realistic_conditional`` and ``KDE_bandwidth``
    (the last five only outside of the CWoLa mode).

    ``external_samples`` and ``ANODE_models`` may each be given as a list of
    paths or as a whitespace-separated string.

    Returns:
        None. The mixed datasets are produced by ``mix_data_samples`` (which is
        handed ``savedir``), plots are written below ``savedir`` and the
        dataset sizes are printed.

    Raises:
        AssertionError: if neither a mode that works without density estimator
            models (CWoLa, supervised, idealized_AD), nor external samples, nor
            any ANODE model path is given.
    """
    # Normalise both path specifications once so that the empty string, an
    # empty list and None are all treated as "nothing given".
    external_sample_paths = _as_path_list(kwargs['external_samples'])
    anode_model_paths = _as_path_list(kwargs['ANODE_models'])

    needs_anode_models = (
        not (kwargs['supervised'] or kwargs['idealized_AD'] or kwargs['CWoLa'])
        and not external_sample_paths)
    assert not (needs_anode_models and not anode_model_paths), (
        "ANODE models need to be given unless CWoLa, supervised, idealized_AD or"
        " external sampling is used.")

    # selecting appropriate device
    CUDA = torch.cuda.is_available()
    print("cuda available:", CUDA)
    device = torch.device("cuda:0" if CUDA else "cpu")

    # checking for data separation: a dedicated validation file marks the finer split
    data_files = os.listdir(kwargs['data_dir'])
    finer_data_split = "innerdata_val.npy" in data_files

    if finer_data_split:
        innerdata_train_path = [os.path.join(kwargs['data_dir'], 'innerdata_train.npy')]
        innerdata_val_path = [os.path.join(kwargs['data_dir'], 'innerdata_val.npy')]
        innerdata_test_path = [os.path.join(kwargs['data_dir'], 'innerdata_test.npy')]
        if "innerdata_extrabkg_test.npy" in data_files:
            innerdata_test_path.append(os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_test.npy'))
        extrasig_path = None
        if kwargs['supervised']:
            # supervised training replaces the data by the extra signal/background files
            innerdata_train_path = []
            innerdata_val_path = []
            innerdata_train_path.append(os.path.join(kwargs['data_dir'], 'innerdata_extrasig_train.npy'))
            innerdata_val_path.append(os.path.join(kwargs['data_dir'], 'innerdata_extrasig_val.npy'))
            innerdata_train_path.append(os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_train.npy'))
            innerdata_val_path.append(os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_val.npy'))
            extra_bkg = None
        elif kwargs['idealized_AD']:
            extra_bkg = [os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_train.npy'),
                         os.path.join(kwargs['data_dir'], 'innerdata_extrabkg_val.npy')]
        else:
            extra_bkg = None

    else:
        innerdata_train_path = os.path.join(kwargs['data_dir'], 'innerdata_train.npy')
        extrasig_path = os.path.join(kwargs['data_dir'], 'innerdata_extrasig.npy')
        if kwargs['extra_bkg']:
            extra_bkg = os.path.join(kwargs['data_dir'], 'innerdata_extrabkg.npy')
        else:
            extra_bkg = None
        innerdata_val_path = None
        innerdata_test_path = os.path.join(kwargs['data_dir'], 'innerdata_test.npy')

    # data preprocessing
    data = LHCORD_data_handler(innerdata_train_path,
                               innerdata_test_path,
                               os.path.join(kwargs['data_dir'], 'outerdata_train.npy'),
                               os.path.join(kwargs['data_dir'], 'outerdata_test.npy'),
                               extrasig_path,
                               inner_extrabkg_path=extra_bkg,
                               inner_val_path=innerdata_val_path,
                               batch_size=256,
                               device=device)
    if kwargs['datashift'] != 0:
        print("applying a datashift of", kwargs['datashift'])
        data.shift_data(kwargs['datashift'], constant_shift=False, random_shift=kwargs['random_shift'],
                        shift_mj1=True, shift_dm=True, additional_shift=False)

    if kwargs['CWoLa']:
        # data preprocessing; the outer range extends 0.2 beyond each SR boundary
        samples = None
        data.preprocess_CWoLa_data(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                   outer_range=(kwargs['SR_min']-0.2, kwargs['SR_max']+0.2))

    else:
        # data preprocessing with the settings the density estimator was trained with
        data.preprocess_ANODE_data(fiducial_cut=kwargs['fiducial_cut'],
                                   no_logit=kwargs['no_logit_trained'],
                                   no_mean_shift=kwargs['no_logit_trained'])

        # model instantiation
        if external_sample_paths:
            model_list = None
            loaded_samples = [np.load(sample_path) for sample_path in external_sample_paths]
            external_sample = np.concatenate(loaded_samples)
        else:
            model_list = []
            for model_path in anode_model_paths:
                anode = DensityEstimator(kwargs['config_file'],
                                         eval_mode=True,
                                         load_path=model_path,
                                         device=device, verbose=kwargs['verbose'],
                                         bound=kwargs['no_logit_trained'])
                model_list.append(anode.model)
            external_sample = None

        # generate samples
        if not kwargs['supervised'] and not kwargs['idealized_AD']:
            uniform_cond = not kwargs['realistic_conditional']
            samples = sample_handler(model_list, kwargs['n_samples'], data, cond_min=kwargs['SR_min'],
                                     cond_max=kwargs['SR_max'], uniform_cond=uniform_cond,
                                     external_sample=external_sample,
                                     device=device, no_logit=kwargs['no_logit_trained'],
                                     no_mean_shift=kwargs['no_logit_trained'],
                                     KDE_bandwidth=kwargs['KDE_bandwidth'])
        else:
            samples = None

        # redo data preprocessing if the classifier should not use logit but ANODE did;
        # this has to use the classifier setting ('no_logit'), otherwise the second
        # pass just repeats the first one
        data.preprocess_ANODE_data(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                   no_mean_shift=kwargs['no_logit'])

        # sample preprocessing, consistent with the (redone) data preprocessing
        if not kwargs['supervised'] and not kwargs['idealized_AD']:
            samples.preprocess_samples(fiducial_cut=kwargs['fiducial_cut'], no_logit=kwargs['no_logit'],
                                       no_mean_shift=kwargs['no_logit'])


    # sample mixing
    X_train, y_train, X_test, y_test, X_extrasig, y_extrasig = mix_data_samples(
        data, samples_handler=samples, oversampling=kwargs['oversampling'],
        savedir=kwargs['savedir'], CWoLa=kwargs['CWoLa'], supervised=kwargs['supervised'],
        idealized_AD=kwargs['idealized_AD'], separate_val_set=kwargs['separate_val_set'] or finer_data_split)

    # sanity checks (only available when samples were generated)
    if not kwargs['CWoLa'] and not kwargs['supervised'] and not kwargs['idealized_AD']:
        samples.sanity_check(savefig=os.path.join(kwargs['savedir'], "sanity_check"), suppress_show=True)
        samples.sanity_check_after_cuts(savefig=os.path.join(kwargs['savedir'], "sanity_check_cuts"),
                                        suppress_show=True)

    if kwargs['supervised'] or kwargs['separate_val_set'] or finer_data_split:
        # with a validation set, the fifth return value of mix_data_samples is used as
        # the validation data
        X_val = X_extrasig
        if kwargs['supervised']:
            y_train = X_train[:, -1]
            y_test = X_test[:, -1]
            y_val = X_val[:, -1]
        else:
            y_val = X_val[:, -2]
        plot_data_sample_comparison(X_val, y_val, title="validation set",
                                    savefig=os.path.join(kwargs['savedir'],
                                                         "data_sample_comparison_val"),
                                    suppress_show=True)

    plot_data_sample_comparison(X_train, y_train, title="training set",
                                savefig=os.path.join(kwargs['savedir'], "data_sample_comparison_train"),
                                suppress_show=True)
    plot_data_sample_comparison(X_test, y_test, title="test set",
                                savefig=os.path.join(kwargs['savedir'], "data_sample_comparison_test"),
                                suppress_show=True)

    print("number of training data =", X_train.shape[0])
    print("number of test data =", X_test.shape[0])
    if not kwargs['no_extra_signal']:
        if kwargs['supervised'] or kwargs['separate_val_set'] or finer_data_split:
            print("number of validation data =", X_val.shape[0])
        elif extrasig_path is not None:
            print("number of extra signal data =", X_extrasig.shape[0])

In [5]:
# Run the dataset creation with the settings collected in `data_creation_kwargs`.
# Note that `create_data` here is the function defined in this notebook, which
# shadows the one imported from `run_classifier_data_creation` in the first cell.
create_data(**data_creation_kwargs)

cuda available: False


  return n/db/n.sum(), bin_edges
  return n/db/n.sum(), bin_edges


number of training data = 125093
number of test data = 359934


In [6]:
# Keyword arguments forwarded to `train_classifier`. The datasets are read from
# and the results written to the same directory (`save_dir`).
classifier_kwargs = dict(
    config_file=cf_config_file,
    data_dir=save_dir,
    savedir=save_dir,
    verbose=verbose,
    # training schedule
    epochs=cf_epochs,
    n_runs=cf_n_runs,
    batch_size=cf_batch_size,
    no_extra_signal=no_extra_signal,
    use_mjj=cf_use_mjj,
    supervised=False,
    # class weights are switched on whenever the training set was oversampled
    use_class_weights=cf_oversampling or cf_use_class_weights,
    CWoLa=True,
    SR_center=cf_SR_center,
    save_model=cf_save_model,
    separate_val_set=cf_separate_val_set,
)

In [7]:
from classifier_training_utils import train_n_models, plot_classifier_losses
from evaluation_utils import minimum_val_loss_model_evaluation
import matplotlib as mpl

def train_classifier(**kwargs):
    """Train the classifiers on the previously created datasets and plot their losses.

    All settings are read from keyword arguments. Keys used:
    ``data_dir``, ``savedir``, ``no_extra_signal``, ``supervised``,
    ``separate_val_set``, ``save_model``, ``n_runs``, ``config_file``, ``epochs``,
    ``use_mjj``, ``batch_size``, ``use_class_weights``, ``CWoLa``, ``SR_center``
    and ``verbose``.

    The arrays ``X_train.npy``, ``X_test.npy``, ``y_train.npy`` and ``y_test.npy``
    are loaded from ``data_dir``; ``X_extrasig.npy`` and ``X_validation.npy`` are
    loaded only when the corresponding options ask for them.

    Returns:
        None. One loss plot per run is written to
        ``<savedir>/model_run<i>_loss_plot``; if ``save_model`` is set, the
        minimum-validation-loss evaluation is run as well.
    """
    # loading the data
    # TODO get rid of the y's since the information is fully included in X
    data_dir = kwargs['data_dir']
    X_train = np.load(os.path.join(data_dir, 'X_train.npy'))
    X_test = np.load(os.path.join(data_dir, 'X_test.npy'))
    y_train = np.load(os.path.join(data_dir, 'y_train.npy'))
    y_test = np.load(os.path.join(data_dir, 'y_test.npy'))
    if kwargs['no_extra_signal'] or kwargs['supervised']:
        X_extrasig = None
    else:
        X_extrasig = np.load(os.path.join(data_dir, 'X_extrasig.npy'))
    if kwargs['supervised'] or kwargs['separate_val_set']:
        X_val = np.load(os.path.join(data_dir, 'X_validation.npy'))
    else:
        X_val = None

    # The output directory is needed for the loss plots in every case, not only
    # when the models themselves are saved.
    os.makedirs(kwargs['savedir'], exist_ok=True)
    output_prefix = os.path.join(kwargs['savedir'], "model")
    # `None` tells the training routine not to save the models.
    save_model = output_prefix if kwargs['save_model'] else None

    # actual training
    loss_matrix, val_loss_matrix = train_n_models(
        kwargs['n_runs'], kwargs['config_file'], kwargs['epochs'], X_train, y_train, X_test, y_test,
        X_extrasig=X_extrasig, X_val=X_val, use_mjj=kwargs['use_mjj'], batch_size=kwargs['batch_size'],
        supervised=kwargs['supervised'], use_class_weights=kwargs['use_class_weights'],
        CWoLa=kwargs['CWoLa'], SR_center=kwargs['SR_center'], verbose=kwargs['verbose'],
        savedir=kwargs['savedir'], save_model=save_model)

    if kwargs['save_model']:
        # evaluation based on the saved models, using 10 epochs per run
        minimum_val_loss_model_evaluation(kwargs['data_dir'], kwargs['savedir'], n_epochs=10,
                                use_mjj=kwargs['use_mjj'], extra_signal=not kwargs['no_extra_signal'])

    # The plot name is built from `output_prefix` rather than from `save_model`,
    # which is None when the models are not saved.
    for i in range(loss_matrix.shape[0]):
        plot_classifier_losses(
            loss_matrix[i], val_loss_matrix[i],
            savefig=output_prefix+"_run"+str(i)+"_loss_plot",
            suppress_show=True
        )

In [8]:
# Run the classifier training with the settings collected in `classifier_kwargs`.
# `train_classifier` is the function defined in this notebook, which shadows the
# one imported from `run_classifier_training` in the first cell.
train_classifier(**classifier_kwargs)

Training model nr 0...
training epoch nr 0
training loss: 0.6932655572891235
validation loss: 0.6912049651145935
training epoch nr 1
training loss: 0.6931723952293396
validation loss: 0.6924153566360474
training epoch nr 2
training loss: 0.6931125521659851
validation loss: 0.6926141381263733
training epoch nr 3
training loss: 0.6931195259094238
validation loss: 0.6879608035087585
training epoch nr 4
training loss: 0.6931072473526001
validation loss: 0.6912862658500671
training epoch nr 5
training loss: 0.6930382251739502
validation loss: 0.6973282694816589
training epoch nr 6
training loss: 0.6930860280990601
validation loss: 0.6895106434822083
training epoch nr 7
training loss: 0.6929660439491272
validation loss: 0.6931555271148682
training epoch nr 8
training loss: 0.6929461359977722
validation loss: 0.6855359077453613
training epoch nr 9
training loss: 0.6929247379302979
validation loss: 0.6935411095619202
training epoch nr 10
training loss: 0.6928834915161133
validation loss: 0.697

training loss: 0.6895545125007629
validation loss: 0.6979000568389893
training epoch nr 91
training loss: 0.6893881559371948
validation loss: 0.7023988962173462
training epoch nr 92
training loss: 0.6894462704658508
validation loss: 0.6972627639770508
training epoch nr 93
training loss: 0.6894827485084534
validation loss: 0.7077352404594421
training epoch nr 94
training loss: 0.6892298460006714
validation loss: 0.6919747591018677
training epoch nr 95
training loss: 0.6892366409301758
validation loss: 0.6956527829170227
training epoch nr 96
training loss: 0.6891956925392151
validation loss: 0.7000218629837036
training epoch nr 97
training loss: 0.689211368560791
validation loss: 0.6929028630256653
training epoch nr 98
training loss: 0.6891415119171143
validation loss: 0.6985769271850586
training epoch nr 99
training loss: 0.6891866326332092
validation loss: 0.7058608531951904
Training model nr 1...
training epoch nr 0
training loss: 0.693309485912323
validation loss: 0.6934164762496948


validation loss: 0.6949307918548584
training epoch nr 81
training loss: 0.6909645795822144
validation loss: 0.6990415453910828
training epoch nr 82
training loss: 0.6910693645477295
validation loss: 0.6967562437057495
training epoch nr 83
training loss: 0.690881609916687
validation loss: 0.6942293643951416
training epoch nr 84
training loss: 0.6908169388771057
validation loss: 0.697820246219635
training epoch nr 85
training loss: 0.6907696723937988
validation loss: 0.6934583187103271
training epoch nr 86
training loss: 0.6908106803894043
validation loss: 0.6947358846664429
training epoch nr 87
training loss: 0.6907904744148254
validation loss: 0.6913406848907471
training epoch nr 88
training loss: 0.6907577514648438
validation loss: 0.6962429285049438
training epoch nr 89
training loss: 0.690734326839447
validation loss: 0.6979172825813293
training epoch nr 90
training loss: 0.6907042264938354
validation loss: 0.6952803730964661
training epoch nr 91
training loss: 0.6907477974891663
va

training loss: 0.6912575364112854
validation loss: 0.6936166882514954
training epoch nr 72
training loss: 0.6911873817443848
validation loss: 0.6935134530067444
training epoch nr 73
training loss: 0.6911927461624146
validation loss: 0.6909866333007812
training epoch nr 74
training loss: 0.6911748647689819
validation loss: 0.6931139230728149
training epoch nr 75
training loss: 0.6910248398780823
validation loss: 0.6975204944610596
training epoch nr 76
training loss: 0.6909784078598022
validation loss: 0.6927093863487244
training epoch nr 77
training loss: 0.691091001033783
validation loss: 0.6930666565895081
training epoch nr 78
training loss: 0.690910816192627
validation loss: 0.6960469484329224
training epoch nr 79
training loss: 0.6908534169197083
validation loss: 0.6934594511985779
training epoch nr 80
training loss: 0.6908184289932251
validation loss: 0.6966127157211304
training epoch nr 81
training loss: 0.6908301115036011
validation loss: 0.6963210701942444
training epoch nr 82
t

validation loss: 0.7033535838127136
training epoch nr 62
training loss: 0.691527247428894
validation loss: 0.698296308517456
training epoch nr 63
training loss: 0.6915127635002136
validation loss: 0.6950126886367798
training epoch nr 64
training loss: 0.6915889978408813
validation loss: 0.6942485570907593
training epoch nr 65
training loss: 0.6915349960327148
validation loss: 0.6884159445762634
training epoch nr 66
training loss: 0.6915715336799622
validation loss: 0.6961824297904968
training epoch nr 67
training loss: 0.6915377378463745
validation loss: 0.6922828555107117
training epoch nr 68
training loss: 0.6914750337600708
validation loss: 0.6967470049858093
training epoch nr 69
training loss: 0.6914718151092529
validation loss: 0.6999849677085876
training epoch nr 70
training loss: 0.691487729549408
validation loss: 0.6903036236763
training epoch nr 71
training loss: 0.6914499998092651
validation loss: 0.6938517093658447
training epoch nr 72
training loss: 0.691392183303833
valida

training loss: 0.6913809180259705
validation loss: 0.689378023147583
training epoch nr 53
training loss: 0.6913110613822937
validation loss: 0.6956019401550293
training epoch nr 54
training loss: 0.6913222074508667
validation loss: 0.6894998550415039
training epoch nr 55
training loss: 0.691218912601471
validation loss: 0.693843424320221
training epoch nr 56
training loss: 0.6912829279899597
validation loss: 0.6857420206069946
training epoch nr 57
training loss: 0.6911927461624146
validation loss: 0.6889767050743103
training epoch nr 58
training loss: 0.6911465525627136
validation loss: 0.6875969767570496
training epoch nr 59
training loss: 0.6910190582275391
validation loss: 0.6955615878105164
training epoch nr 60
training loss: 0.6911520957946777
validation loss: 0.6932112574577332
training epoch nr 61
training loss: 0.6911537647247314
validation loss: 0.692004382610321
training epoch nr 62
training loss: 0.6910246014595032
validation loss: 0.6983087658882141
training epoch nr 63
tra

validation loss: 0.697100043296814
training epoch nr 43
training loss: 0.6917247176170349
validation loss: 0.6942140460014343
training epoch nr 44
training loss: 0.6917914748191833
validation loss: 0.6973629593849182
training epoch nr 45
training loss: 0.6917471885681152
validation loss: 0.7025663256645203
training epoch nr 46
training loss: 0.6917513012886047
validation loss: 0.6977724432945251
training epoch nr 47
training loss: 0.6916401982307434
validation loss: 0.696074366569519
training epoch nr 48
training loss: 0.6916109919548035
validation loss: 0.6980382204055786
training epoch nr 49
training loss: 0.6915914416313171
validation loss: 0.6949172616004944
training epoch nr 50
training loss: 0.6915772557258606
validation loss: 0.6952934861183167
training epoch nr 51
training loss: 0.6915083527565002
validation loss: 0.7010878324508667
training epoch nr 52
training loss: 0.6914380192756653
validation loss: 0.7014341950416565
training epoch nr 53
training loss: 0.6913827657699585
v

training loss: 0.6919915080070496
validation loss: 0.6957857012748718
training epoch nr 34
training loss: 0.6919534206390381
validation loss: 0.6960039734840393
training epoch nr 35
training loss: 0.6919654607772827
validation loss: 0.6958020925521851
training epoch nr 36
training loss: 0.6918777227401733
validation loss: 0.6965941786766052
training epoch nr 37
training loss: 0.6918999552726746
validation loss: 0.6975207924842834
training epoch nr 38
training loss: 0.6918730139732361
validation loss: 0.6922696232795715
training epoch nr 39
training loss: 0.691680908203125
validation loss: 0.6889004111289978
training epoch nr 40
training loss: 0.6915870904922485
validation loss: 0.7025080919265747
training epoch nr 41
training loss: 0.6915934085845947
validation loss: 0.7026206254959106
training epoch nr 42
training loss: 0.6916223764419556
validation loss: 0.6989195942878723
training epoch nr 43
training loss: 0.6915925741195679
validation loss: 0.6934776306152344
training epoch nr 44


validation loss: 0.6919596195220947
training epoch nr 24
training loss: 0.6924770474433899
validation loss: 0.6911084651947021
training epoch nr 25
training loss: 0.6924301981925964
validation loss: 0.6928812861442566
training epoch nr 26
training loss: 0.6923457980155945
validation loss: 0.6930832862854004
training epoch nr 27
training loss: 0.692326009273529
validation loss: 0.6970106363296509
training epoch nr 28
training loss: 0.6922925114631653
validation loss: 0.6960746049880981
training epoch nr 29
training loss: 0.6922073364257812
validation loss: 0.6950013041496277
training epoch nr 30
training loss: 0.6922035813331604
validation loss: 0.6988515257835388
training epoch nr 31
training loss: 0.6921547651290894
validation loss: 0.6969897150993347
training epoch nr 32
training loss: 0.6920498609542847
validation loss: 0.6969448924064636
training epoch nr 33
training loss: 0.6921267509460449
validation loss: 0.6961122751235962
training epoch nr 34
training loss: 0.6921188235282898


training loss: 0.6927737593650818
validation loss: 0.6909342408180237
training epoch nr 15
training loss: 0.6927642226219177
validation loss: 0.6916611194610596
training epoch nr 16
training loss: 0.6926897764205933
validation loss: 0.691599428653717
training epoch nr 17
training loss: 0.6926308274269104
validation loss: 0.6921172142028809
training epoch nr 18
training loss: 0.6927464604377747
validation loss: 0.6925827860832214
training epoch nr 19
training loss: 0.6926538944244385
validation loss: 0.691990852355957
training epoch nr 20
training loss: 0.6925605535507202
validation loss: 0.6917730569839478
training epoch nr 21
training loss: 0.6926665902137756
validation loss: 0.6977949738502502
training epoch nr 22
training loss: 0.6925734877586365
validation loss: 0.6946200728416443
training epoch nr 23
training loss: 0.6926218271255493
validation loss: 0.6919918060302734
training epoch nr 24
training loss: 0.6925466060638428
validation loss: 0.6945841908454895
training epoch nr 25
t

validation loss: 0.6922244429588318
training epoch nr 5
training loss: 0.6931259632110596
validation loss: 0.6934130191802979
training epoch nr 6
training loss: 0.6930693984031677
validation loss: 0.6935009360313416
training epoch nr 7
training loss: 0.6930195689201355
validation loss: 0.6952805519104004
training epoch nr 8
training loss: 0.6930373907089233
validation loss: 0.691798210144043
training epoch nr 9
training loss: 0.6929882764816284
validation loss: 0.6939482688903809
training epoch nr 10
training loss: 0.6929425597190857
validation loss: 0.6925357580184937
training epoch nr 11
training loss: 0.69284987449646
validation loss: 0.6930391192436218
training epoch nr 12
training loss: 0.6929166913032532
validation loss: 0.6943768858909607
training epoch nr 13
training loss: 0.6928553581237793
validation loss: 0.6920607686042786
training epoch nr 14
training loss: 0.6928407549858093
validation loss: 0.6925715804100037
training epoch nr 15
training loss: 0.6928205490112305
validat

training loss: 0.6896779537200928
validation loss: 0.7006023526191711
training epoch nr 96
training loss: 0.6896700263023376
validation loss: 0.7070900201797485
training epoch nr 97
training loss: 0.6896584033966064
validation loss: 0.6924435496330261
training epoch nr 98
training loss: 0.6897239685058594
validation loss: 0.7102804780006409
training epoch nr 99
training loss: 0.6895762085914612
validation loss: 0.7107054591178894
minimum validation loss epochs: [94 14 28  3  4 24  6  0  8 32]
minimum validation loss epochs: [31 67 44 47 14 22  6 45 87 15]
minimum validation loss epochs: [20 60  2 42 48 73 61 53 52 57]
minimum validation loss epochs: [65 85 25 70 54 87 32  1 23 83]
minimum validation loss epochs: [89 63 96 77 81 66 84 56  1 83]
minimum validation loss epochs: [ 0 18 10 13  7 15  1  9 16 74]
minimum validation loss epochs: [20 73 39 18 19 22 60 59 16 68]
minimum validation loss epochs: [17 10 44 18 21  3 24 13 16 11]
minimum validation loss epochs: [65 53 26 25 46 75 44 

In [13]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import torch
import sklearn
from sklearn import metrics
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from scipy.interpolate import interp1d


from evaluation_utils import (load_predictions, tprs_fprs_sics,
                              minumum_validation_loss_ensemble, compare_on_various_runs)
# Evaluates the ROCs and SICs of one given training with minimum validation loss epoch picking.
def full_single_evaluation(data_dir, prediction_dir, n_ensemble_epochs=10, extra_signal=True,
                           sic_range=(0,20), savefig=None, suppress_show=False, return_all=False):
    """Evaluate a single training and hand the result to ``compare_on_various_runs``.

    Args:
        data_dir: passed to ``load_predictions`` as the data location.
        prediction_dir: passed to ``load_predictions`` as the prediction location.
        n_ensemble_epochs: ``n_epochs`` for ``minumum_validation_loss_ensemble``;
            only used if the loaded predictions are not ensembled yet.
        extra_signal: forwarded to ``load_predictions``.
        sic_range: forwarded to ``compare_on_various_runs`` as ``sic_lim``.
        savefig, suppress_show, return_all: forwarded to ``compare_on_various_runs``.

    Returns:
        Whatever ``compare_on_various_runs`` returns.
    """
    X_test, y_test, loaded_predictions, val_losses = load_predictions(
        data_dir, prediction_dir, extra_signal=extra_signal)

    # A second axis of length one means that the ensembling was done already.
    already_ensembled = loaded_predictions.shape[1] == 1
    if already_ensembled:
        ensembled_predictions = loaded_predictions
    else:
        ensembled_predictions = minumum_validation_loss_ensemble(
            loaded_predictions, val_losses, n_epochs=n_ensemble_epochs)

    tprs, fprs, sics = tprs_fprs_sics(ensembled_predictions, y_test, X_test)

    n_runs = ensembled_predictions.shape[0]
    plot_options = dict(
        sic_lim=sic_range,
        savefig=savefig,
        only_median=False,
        continuous_colors=False,
        reduced_legend=False,
        suppress_show=suppress_show,
        return_all=return_all,
    )
    return compare_on_various_runs([tprs], [fprs], [np.zeros(n_runs)], [""], **plot_options)

In [14]:
# Evaluate the training stored in `save_dir` (used both as data and as prediction
# directory); `savefig` points to `<save_dir>/result_SIC`. The return value is
# assigned to `_` so that the cell does not display it.
_ = full_single_evaluation(save_dir, save_dir, n_ensemble_epochs=10,
                           extra_signal=not no_extra_signal, sic_range=(0, 20),
                           savefig=os.path.join(save_dir, 'result_SIC'))