In [1]:
import numpy as np
import copy
import time
import sys
sys.path.append('scripts')
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchinfo import summary
from torch.utils import data
import h5py
import pickle
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.isotonic import IsotonicRegression

import os
from CNNEvaluate import *

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
sys.path.append('../CaloChallenge/code')
sys.path.append('CaloChallenge/code')
from utils import *
from models import *
from XMLHandler import *


Parameters

In [2]:
device='cpu'
binning_file="../CaloChallenge/code/binning_dataset_1_photons.xml"
config = '../configs/config_dataset1_photon.json'
cls_lr = 0.0001
cls_batch_size = 100
cls_n_epochs = 50
INPUT_FILE = 'generated_photons.h5'
REFERENCE_FILE = '../CaloChallenge/Datasets/dataset_1_photons_2.hdf5'
output_dir = 'results'

In [3]:
dataset_config = LoadJson(config)

bins=XMLHandler("photon", binning_file)
NN_embed=NNConverter(bins=bins).to(device=device)
cond_dim = dataset_config['COND_SIZE_UNET']
layer_sizes = [16,16,32,32,32]
mid_attn = dataset_config.get("MID_ATTN", False)
compress_Z = dataset_config.get("COMPRESS_Z", False)
E_embed = dataset_config.get("COND_EMBED", 'sin')

RZ_shape = dataset_config['SHAPE_PAD'][1:]

R_Z_inputs = dataset_config.get('R_Z_INPUT', False)
phi_inputs = dataset_config.get('PHI_INPUT', False)

in_channels = 1

if(R_Z_inputs): in_channels = 3
if(phi_inputs): in_channels += 1

calo_summary_shape = list(copy.copy(RZ_shape))
calo_summary_shape.insert(0, 1)
calo_summary_shape[1] = in_channels

calo_summary_shape[0] = 1

### File Reading

In [4]:
source_file = h5py.File(INPUT_FILE, 'r')
reference_file = h5py.File(REFERENCE_FILE, 'r')

reference_showers = reference_file['showers']
reference_energies = reference_file['incident_energies']

reference_data = np.hstack((reference_showers, reference_energies, np.zeros_like(reference_energies)))

source_showers = source_file['showers']
source_energies = source_file['incident_energies']

source_data = np.hstack((source_showers, source_energies, np.ones_like(source_energies)))

train, test, val = ttv_split(source_data, reference_data)



train_data = data.TensorDataset(torch.tensor(train).float())
test_data = data.TensorDataset(torch.tensor(test).float())
val_data = data.TensorDataset(torch.tensor(val).float())


train_dataloader = data.DataLoader(train_data, batch_size=cls_batch_size, shuffle=True)
test_dataloader = data.DataLoader(test_data, batch_size= cls_batch_size, shuffle=False)
val_dataloader = data.DataLoader(val_data, batch_size=cls_batch_size, shuffle=False)

### Model

In [5]:
model = CNN(cond_dim = cond_dim, out_dim = 1, channels = in_channels, layer_sizes = layer_sizes
            ,cylindrical = dataset_config.get('CYLINDRICAL', False),
            data_shape = calo_summary_shape, NN_embed=NN_embed, RZ_shape = RZ_shape, mid_attn = mid_attn)

total_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model has {} parameters".format(int(total_parameters)))

optimizer = torch.optim.Adam(model.parameters(), lr=cls_lr)

Model has 311665 parameters


In [14]:
train_and_evaluate_cls(model, train_dataloader, test_dataloader, optimizer, cls_n_epochs, device)

Epoch   1, step    0 / 1452; loss 1.7897
Epoch   1, step  726 / 1452; loss 0.7219
Accuracy on training set is 0.5056404958677686
Accuracy on test set is 0.5121900826446281
AUC on test set is 0.5121900826446281
BCE loss of test set is 0.7023, JSD of the two dists is -0.0131
Epoch   2, step    0 / 1452; loss 0.6819


In [None]:
classifier = load_classifier(classifier, args)
with torch.no_grad():
    print("Now looking at independent dataset:")
    eval_acc, eval_auc, eval_JSD = evaluate_cls(classifier, val_dataloader, device,
                                                final_eval=True,
                                                calibration_data=test_dataloader)
    print("Final result of classifier test (AUC / JSD):")
    print("{:.4f} / {:.4f}".format(eval_auc, eval_JSD))

In [9]:
def train_and_evaluate_cls(model, data_train, data_test, optim, cls_n_epochs, device):
    """ train the model and evaluate along the way"""
    best_eval_acc = float('-inf')
    best_epoch = -1
    try:
        for i in range(cls_n_epochs):
            train_cls(model, data_train, optim, i, device)
            with torch.no_grad():
                eval_acc, _, _ = evaluate_cls(model, data_test, device)
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                best_epoch = i+1
                filename = 'CNN_1_photons.pt'
                torch.save({'model_state_dict':model.state_dict()},
                           os.path.join(output_dir, filename))
            if eval_acc == 1.:
                break
    except KeyboardInterrupt:
        # training can be cut short with ctrl+c, for example if overfitting between train/test set
        # is clearly visible
        pass

In [10]:
def train_cls(model, data_train, optim, epoch, device):
    """ train one step """
    model.train()
    for i, data_batch in enumerate(data_train):
        
        data_batch = data_batch[0].to(device)
        
        
        #input_vector, target_vector = torch.split(data_batch, [data_batch.size()[1]-1, 1], dim=1)
        input_vector, cond_vector, target_vector = data_batch[:, :-2], data_batch[:,-2], data_batch[:, -1]
        
        output_vector = model(input_vector, cond_vector)
        
        criterion = torch.nn.BCEWithLogitsLoss()
        loss = criterion(output_vector, target_vector.unsqueeze(1))

        optim.zero_grad()
        loss.backward()
        optim.step()

        if i % (len(data_train)//2) == 0:
            print('Epoch {:3d}, step {:4d} / {}; loss {:.4f}'.format(
                epoch+1, i, len(data_train), loss.item()))
        # PREDICTIONS
        pred = torch.round(torch.sigmoid(output_vector.detach()))
        target = torch.round(target_vector.detach())
        if i == 0:
            res_true = target
            res_pred = pred
        else:
            res_true = torch.cat((res_true, target), 0)
            res_pred = torch.cat((res_pred, pred), 0)

    try:
        print("Accuracy on training set is",
          accuracy_score(res_true.cpu(), np.clip(res_pred.cpu(), 0., 1.0)))
    except:
        print("Nans")

In [11]:
def evaluate_cls(model, data_test, device, final_eval=False, calibration_data=None):
    """ evaluate on test set """
    model.eval()
    for j, data_batch in enumerate(data_test):
        
        data_batch = data_batch[0].to(device)
        input_vector, cond_vector, target_vector = data_batch[:, :-2], data_batch[:, -2], data_batch[:, -1]
        
        
        output_vector = model(input_vector, cond_vector)
        pred = output_vector.reshape(-1)
        target = target_vector.double()
        if j == 0:
            result_true = target
            result_pred = pred
        else:
            result_true = torch.cat((result_true, target), 0)
            result_pred = torch.cat((result_pred, pred), 0)
    BCE = torch.nn.BCEWithLogitsLoss()(result_pred, result_true)
    result_pred = torch.round(torch.sigmoid(result_pred)).cpu().numpy()
    result_true = result_true.cpu().numpy()
    result_pred = np.clip(np.round(result_pred), 0., 1.0)
    #print(np.amin(result_pred), np.amax(result_pred), np.sum(np.isnan(result_pred)))
    try:
        eval_acc = accuracy_score(result_true, result_pred)
    except:
        print("Nans")
        result_pred[np.isnan(result_pred)] = 0.5
        eval_acc = accuracy_score(result_true, result_pred)
    print("Accuracy on test set is", eval_acc)
    eval_auc = roc_auc_score(result_true, result_pred)
    print("AUC on test set is", eval_auc)
    JSD = - BCE + np.log(2.)
    print("BCE loss of test set is {:.4f}, JSD of the two dists is {:.4f}".format(BCE,
                                                                                  JSD/np.log(2.)))
    if final_eval:
        prob_true, prob_pred = calibration_curve(result_true, result_pred, n_bins=10)
        print("unrescaled calibration curve:", prob_true, prob_pred)
        calibrator = calibrate_classifier(model, calibration_data, device)
        rescaled_pred = calibrator.predict(result_pred)
        eval_acc = accuracy_score(result_true, np.clip(np.round(rescaled_pred), 0., 1.0))
        print("Rescaled accuracy is", eval_acc)
        eval_auc = roc_auc_score(result_true, rescaled_pred)
        print("rescaled AUC of dataset is", eval_auc)
        prob_true, prob_pred = calibration_curve(result_true, rescaled_pred, n_bins=10)
        print("rescaled calibration curve:", prob_true, prob_pred)
        # calibration was done after sigmoid, therefore only BCELoss() needed here:
        BCE = torch.nn.BCELoss()(torch.tensor(rescaled_pred), torch.tensor(result_true))
        JSD = - BCE.cpu().numpy() + np.log(2.)
        otp_str = "rescaled BCE loss of test set is {:.4f}, "+\
            "rescaled JSD of the two dists is {:.4f}"
        print(otp_str.format(BCE, JSD/np.log(2.)))
    return eval_acc, eval_auc, JSD/np.log(2.)

In [12]:
def calibrate_classifier(model, calibration_data, device):
    
    """ reads in calibration data and performs a calibration with isotonic regression"""
    model.eval()
    assert calibration_data is not None, ("Need calibration data for calibration!")
    for j, data_batch in enumerate(calibration_data):
        
        
        data_batch = data_batch[0].to(device)
        input_vector, target_vector = data_batch[:, :-1], data_batch[:, -1]
        output_vector = model(input_vector)
        pred = torch.sigmoid(output_vector).reshape(-1)
        target = target_vector.to(torch.float64)
        if j == 0:
            result_true = target
            result_pred = pred
        else:
            result_true = torch.cat((result_true, target), 0)
            result_pred = torch.cat((result_pred, pred), 0)
    result_true = result_true.cpu().numpy()
    result_pred = result_pred.cpu().numpy()
    iso_reg = IsotonicRegression(out_of_bounds='clip', y_min=1e-6, y_max=1.-1e-6).fit(result_pred,
                                                                                      result_true)
    return iso_reg

In [13]:
def load_classifier(constructed_model, output_dir, filename, device):
    """ loads a saved model """
    checkpoint = torch.load(os.path.join(output_dir, filename),
                            map_location=device)
    constructed_model.load_state_dict(checkpoint['model_state_dict'])
    constructed_model.to(device)
    constructed_model.eval()
    print('classifier loaded successfully')
    return constructed_model