In [1]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import jetnet

In [2]:
from utils.plotting import HighLevelFeatures as HLF

def extract_shower_and_energy(given_file, which):
    """ reads .hdf5 file and returns samples and their energy """
    print("Extracting showers from {} file ...".format(which))
    if which == 0.:
        shower = given_file['showers'][:]
        energy = given_file['incident_energies'][:]
    else:
        shower = given_file['showers'][:]
        energy = given_file['incidence energy'][:]
    print("Extracting showers from {} file: DONE.\n".format(which))
    return shower, energy

def prepare_high_data_for_classifier(test, e_inc, hlf_class, label):
    """ takes hdf5_file, extracts high-level features, appends label, returns array """
    # voxel, E_inc = extract_shower_and_energy(hdf5_file, label)
    voxel, E_inc = test, e_inc
    E_tot = hlf_class.GetEtot()
    E_layer = []
    for layer_id in hlf_class.GetElayers():
        E_layer.append(hlf_class.GetElayers()[layer_id].reshape(-1, 1))
    EC_etas = []
    EC_phis = []
    Width_etas = []
    Width_phis = []
    for layer_id in hlf_class.layersBinnedInAlpha:
        EC_etas.append(hlf_class.GetECEtas()[layer_id].reshape(-1, 1))
        EC_phis.append(hlf_class.GetECPhis()[layer_id].reshape(-1, 1))
        Width_etas.append(hlf_class.GetWidthEtas()[layer_id].reshape(-1, 1))
        Width_phis.append(hlf_class.GetWidthPhis()[layer_id].reshape(-1, 1))
    E_layer = np.concatenate(E_layer, axis=1)
    EC_etas = np.concatenate(EC_etas, axis=1)
    EC_phis = np.concatenate(EC_phis, axis=1)
    Width_etas = np.concatenate(Width_etas, axis=1)
    Width_phis = np.concatenate(Width_phis, axis=1)
    ret = np.concatenate([np.log10(E_inc), np.log10(E_layer+1e-8), EC_etas/1e2, EC_phis/1e2,
                          Width_etas/1e2, Width_phis/1e2, label*np.ones_like(E_inc)], axis=1)
    return ret

In [3]:
def check_and_replace_nans_infs(data):
    if np.isnan(data).any() or np.isinf(data).any():
        print("Data contains NaNs or Infs. Handling them...")
        # Replace NaNs and Infs with zeros (or you can choose a different strategy)
        data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
    return data

def get_fpd_kpd_metrics(test_data, gen_data, syn_bool, hlf, ref_hlf):
    print("TESTING HELLO")
    if syn_bool == True:
        data_showers = (np.array(test_data['showers']))
        energy = (np.array(test_data['incident_energies']))
        gen_showers = (np.array(gen_data['showers'], dtype=float))
        hlf.Einc = energy
    else:
        data_showers = test_data
        gen_showers = gen_data
    hlf.CalculateFeatures(data_showers)
    ref_hlf.CalculateFeatures(gen_showers)
    hlf_test_data = prepare_high_data_for_classifier(test_data, hlf.Einc, hlf, 0.)[:, :-1]
    hlf_gen_data = prepare_high_data_for_classifier(gen_data, hlf.Einc, ref_hlf, 1.)[:, :-1]
    hlf_test_data = check_and_replace_nans_infs(hlf_test_data)
    hlf_gen_data = check_and_replace_nans_infs(hlf_gen_data)
    fpd_val, fpd_err = jetnet.evaluation.fpd(hlf_test_data, hlf_gen_data)
    kpd_val, kpd_err = jetnet.evaluation.kpd(hlf_test_data, hlf_gen_data)
    
    result_str = (
        f"FPD (x10^3): {fpd_val*1e3:.4f} ± {fpd_err*1e3:.4f}\n" 
        f"KPD (x10^3): {kpd_val*1e3:.4f} ± {kpd_err*1e3:.4f}"
    )
    
    print(result_str)
    return fpd_val, fpd_err, kpd_val, kpd_err

In [4]:
def get_fpd_kpd_metrics_(test_data, gen_data):
    data_showers = (np.array(test_data['showers']))
    energy = (np.array(test_data['incident_energies']))
    gen_showers = (np.array(gen_data['showers'], dtype=float))
    hlf.CalculateFeatures(data_showers)
    ref_hlf.CalculateFeatures(gen_showers)
    hlf.Einc = energy
    hlf_test_data = prepare_high_data_for_classifier(test_data, hlf.Einc, hlf, 0.)[:, :-1]
    hlf_gen_data = prepare_high_data_for_classifier(gen_data, hlf.Einc, ref_hlf, 1.)[:, :-1]
    # hlf_test_data = check_and_replace_nans_infs(hlf_test_data)
    # hlf_gen_data = check_and_replace_nans_infs(hlf_gen_data)
    fpd_val, fpd_err = jetnet.evaluation.fpd(hlf_test_data, hlf_gen_data)
    kpd_val, kpd_err = jetnet.evaluation.kpd(hlf_test_data, hlf_gen_data)
    result_str = (
        f"FPD (x10^3): {fpd_val*1e3:.4f} ± {fpd_err*1e3:.4f}\n"
        f"KPD (x10^3): {kpd_val*1e3:.4f} ± {kpd_err*1e3:.4f}"
    )
    print(result_str)
    return fpd_val, fpd_err, kpd_val, kpd_err



In [8]:
if __name__ == "__main__":
    # test_data = h5py.File('/fast_scratch_1/caloqvae/test_data/dataset_2_2.hdf5', 'r')
    # gen_data = h5py.File("/fast_scratch_1/caloqvae/syn_data/dataset2_synthetic_denim-smoke-166en130.hdf5", 'r')
    hlf = HLF.HighLevelFeatures('electron', filename='/raid/javier/Datasets/CaloVAE/data/atlas_dataset2and3/binning_dataset_2.xml', wandb=False)
    ref_hlf = HLF.HighLevelFeatures('electron', filename='/raid/javier/Datasets/CaloVAE/data/atlas_dataset2and3/binning_dataset_2.xml', wandb=False)

In [6]:
    norm = 0.008
    print(fpd_val - norm)
    print(kpd_val)

NameError: name 'fpd_val' is not defined

In [6]:
    modelname = 'mild-salad-468'
    modelname = 'morning-bush-469'
    # modelname = 'dutiful-gorge-467'
    modelname = 'robust-tree-339'
    modelname='fluent-dawn-488'
    # modelname='dry-galaxy-489'
    modelname='skilled-night-490'
    modelname='giddy-violet-575'
    modelname= 'generous-water-216'
    fpath = f'/raid/javier/Datasets/CaloVAE/data/synData/dataset2_synthetic_{modelname}.hdf5'
    test_data = h5py.File('/raid/javier/Datasets/CaloVAE/data/atlas_dataset2and3/dataset_2_2.hdf5', 'r')
    gen_data = h5py.File(fpath, 'r')
    # get_fpd_kpd_metrics(test_data, gen_data)

In [9]:
    get_fpd_kpd_metrics_(test_data, gen_data)

FPD (x10^3): 445.4664 ± 2.4697
KPD (x10^3): 0.7796 ± 0.0527


(0.4454663827130533,
 0.002469694347808809,
 0.0007795959320133772,
 5.272728554136906e-05)

In [30]:
    get_fpd_kpd_metrics_(test_data, gen_data)

FPD (x10^3): 494.3088 ± 2.5972
KPD (x10^3): 0.8791 ± 0.0732


(0.4943088182279927,
 0.0025972268461431913,
 0.0008791455794538994,
 7.324259237435177e-05)

In [28]:
    #zephyr
    get_fpd_kpd_metrics_(test_data, gen_data)

FPD (x10^3): 479.7403 ± 2.5953
KPD (x10^3): 0.7910 ± 0.0470


(0.47974033657200993,
 0.002595293143427049,
 0.0007910003374627106,
 4.70376286785968e-05)

In [32]:
    get_fpd_kpd_metrics_(test_data, gen_data)

FPD (x10^3): 489.5326 ± 3.9541
KPD (x10^3): 1.0659 ± 0.0854


(0.4895325545689304,
 0.003954118311942111,
 0.001065932720673679,
 8.5405953064085e-05)

In [36]:
    get_fpd_kpd_metrics_(test_data, gen_data)

FPD (x10^3): 489.2272 ± 2.1890
KPD (x10^3): 1.0346 ± 0.1822


(0.48922716378384185,
 0.002188991523305713,
 0.0010346138544368166,
 0.00018216326708355655)

In [38]:
    get_fpd_kpd_metrics_(test_data, gen_data)

FPD (x10^3): 480.9062 ± 2.2294
KPD (x10^3): 0.9094 ± 0.0582


(0.48090616330299885,
 0.0022293699485623733,
 0.0009094030224379157,
 5.818459916231782e-05)

In [6]:
##################

In [7]:
import pickle

In [18]:
with open('/home/javier/Projects/CaloQVAE/figs/skilled-night-490/xtarget_samples.pickle', 'rb') as handle:
    test_data_arr = pickle.load(handle)
    
with open('/home/javier/Projects/CaloQVAE/figs/skilled-night-490/xrecon_samples.pickle', 'rb') as handle:
    gen_data_recon_arr = pickle.load(handle)
    
with open('/home/javier/Projects/CaloQVAE/figs/skilled-night-490/xgen_samples.pickle', 'rb') as handle:
    gen_data_arr = pickle.load(handle)
    
with open('/home/javier/Projects/CaloQVAE/figs/skilled-night-490/xgen_samples_qpu.pickle', 'rb') as handle:
    gen_data_qpu_arr = pickle.load(handle)
    
with open('/home/javier/Projects/CaloQVAE/figs/skilled-night-490/entarget_samples.pickle', 'rb') as handle:
    entarget_samples = pickle.load(handle)    
    

  return torch.load(io.BytesIO(b))


In [19]:

test_data = {'showers':test_data_arr, 'incident_energies':entarget_samples}
gen_data_recon = {'showers':gen_data_recon_arr, 'incident_energies':entarget_samples}
gen_data = {'showers':gen_data_arr, 'incident_energies':entarget_samples}
gen_data_qpu = {'showers':gen_data_qpu_arr, 'incident_energies':entarget_samples}

In [16]:
get_fpd_kpd_metrics_(test_data, test_data)

  fpd_val, fpd_err = jetnet.evaluation.fpd(hlf_test_data, hlf_gen_data)


FPD (x10^3): 0.3475 ± 0.2304
KPD (x10^3): -0.0308 ± 0.0493


(0.00034754066487014713,
 0.0002304036701224789,
 -3.08029726767689e-05,
 4.932830488774131e-05)

In [20]:
get_fpd_kpd_metrics_(test_data, gen_data_recon)
get_fpd_kpd_metrics_(test_data, gen_data)
get_fpd_kpd_metrics_(test_data, gen_data_qpu)

  fpd_val, fpd_err = jetnet.evaluation.fpd(hlf_test_data, hlf_gen_data)


FPD (x10^3): 364.8854 ± 2.5351
KPD (x10^3): 1.2111 ± 0.3002
FPD (x10^3): 538.2631 ± 2.4200
KPD (x10^3): 1.1997 ± 0.1776
FPD (x10^3): 523.6575 ± 1.7569
KPD (x10^3): 1.7247 ± 0.2225


(0.5236574615628374,
 0.0017568697326292401,
 0.0017246968855182399,
 0.00022247322298027405)