In [1]:

# Imports
import matplotlib.pyplot as plt
import numpy as np
from toolbox import get_pred_energy_diff_data, common_dir, calculate_percentage_interval
import sys
import argparse
import os
import time
import pickle
from NuRadioReco.utilities import units
from scipy import stats
from radiotools import stats as rtSTATS
from mpl_toolkits.mplot3d import Axes3D
from itertools import product, combinations
from radiotools import plthelpers as php
from tensorflow import keras
from radiotools import helper as hp
import datasets
# -------

file_ids_to_load = [1, 2, 3, 4, 5]

# Settings per dataset key: (plot title, Dataset name, EM flag, noise flag).
# The last three fields are passed positionally to datasets.Dataset(name, em, noise).
DATASET_CONFIGS = {
    "ALVAREZ-HAD": ("Alvarez2009 (had.)", "ALVAREZ", False, True),
    "ARZ-HAD": ("ARZ2020 (had.)", "ARZ", False, True),
    "ARZ-EM": ("ARZ2020 (had. + EM)", "ARZ", True, True),
}


def load_file(i_file, norm=1e-6, source=None):
    """Load the neutrino-energy labels stored in one label file.

    The label file is ``<source.datapath>/<source.label_filename><i_file:04d>.npy``.
    Waveform data is no longer loaded. The old waveform path applied a 500 MHz
    bandpass filter, dropped events with NaNs and divided by ``norm``. It is gone
    because only the energy labels are needed here. That also drops the unused
    per-call read of the bandpass-filter file.

    Parameters
    ----------
    i_file : int
        File index, used as a zero-padded 4-digit suffix of the label filename.
    norm : float, optional
        Unused. Kept only so existing callers that pass it keep working.
    source : object, optional
        Object providing ``datapath`` and ``label_filename``. It is presumably a
        ``datasets.Dataset`` (TODO confirm). When omitted, the module-level
        ``dataset`` set by the loop below is used, as the old closure did.

    Returns
    -------
    tuple
        ``(None, nu_energy)``. The first element is a placeholder for the
        waveform data. ``nu_energy`` is an ndarray built from the
        ``"nu_energy"`` entry of the label file.
    """
    source = dataset if source is None else source

    t0 = time.time()
    print(f"loading file {i_file}", flush=True)
    # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
    # only use it on trusted, locally produced label files.
    labels_tmp = np.load(
        os.path.join(source.datapath, f"{source.label_filename}{i_file:04d}.npy"),
        allow_pickle=True,
    )
    print(f"finished loading file {i_file} in {time.time() - t0}s")

    # .item() unwraps the stored object, presumably a dict keyed by label name.
    nu_energy_data = np.array(labels_tmp.item()["nu_energy"])
    return None, nu_energy_data


for dataset_to_use, (dataset_title, dataset_name, dataset_em, dataset_noise) in DATASET_CONFIGS.items():

    dataset = datasets.Dataset(dataset_name, dataset_em, dataset_noise)

    # Load the labels of every requested file and concatenate them once. The old
    # per-file np.concatenate re-copied the growing array on every iteration.
    nu_energy = np.concatenate([load_file(file_id)[1] for file_id in file_ids_to_load])

    # 30 log-spaced edges from 1e16 eV to 1e19 eV. The counts are plotted at
    # these edges (bin left edges), so an extra edge at 1e20 eV closes the
    # last bin [1e19, 1e20]. Events outside [1e16, 1e20] eV are not counted.
    nu_energy_bins = np.logspace(16, 19, num=30)
    nu_energy_bins_with_one_extra = np.append(nu_energy_bins, 1e20)
    binned_resolution_nu_energy_count = stats.binned_statistic(
        nu_energy, nu_energy, bins=nu_energy_bins_with_one_extra, statistic="count"
    )[0]

    # Plot the per-bin event counts
    fig, ax = plt.subplots()
    ax.plot(nu_energy_bins, binned_resolution_nu_energy_count, "*")
    ax.set_xlabel(r"$\nu$ energy (eV)")
    ax.set_ylabel("Events")
    ax.set_xscale("log")
    ax.set_title(f"Count of events inside neutrino energy bins\nfor dataset {dataset_title}")
    fig.tight_layout()
    fig.savefig(f"{dataset_to_use}_counts_plot_NU_ENERGY_from1e16.png", dpi=300)
    plt.close(fig)

    # ___________________________________

    # # Calculate the weight data
    # max_count = max(binned_resolution_nu_energy_count)
    # weight_vector = [max_count / count_in_bin if count_in_bin != 0 else 1 for count_in_bin in binned_resolution_nu_energy_count ]

    # energy_vector_log10 = np.log10(nu_energy_bins)
    # # Plot the weights so that we can see it
    # plt.plot(energy_vector_log10, weight_vector)
    # plt.xlabel("log10 shower energy")
    # plt.ylabel("weight")
    # plt.yscale("log")
    # plt.title(f"Weight as a function of shower energy for {dataset_title}")
    # plt.tight_layout()
    # plt.savefig(f"{dataset_to_use}_weights_plot.png", dpi=300)

    # energy_weight_count_tuple = [(energy_vector_log10[i], weight_vector[i], binned_resolution_nu_energy_count[i]) for i in range(len(weight_vector))]

    # with open(f"{dataset_to_use}_weights.npy", "wb") as f:
    #     np.save(f, energy_weight_count_tuple)

loading file 1
finished loading file 1 in 0.5706546306610107s
loading file 2
finished loading file 2 in 0.563478946685791s
loading file 3
finished loading file 3 in 0.5648424625396729s
loading file 4
finished loading file 4 in 0.5572545528411865s
loading file 5
finished loading file 5 in 0.5971784591674805s
loading file 1
finished loading file 1 in 0.5113587379455566s
loading file 2
finished loading file 2 in 0.515911340713501s
loading file 3
finished loading file 3 in 0.5329535007476807s
loading file 4
finished loading file 4 in 0.5111770629882812s
loading file 5
finished loading file 5 in 0.5447676181793213s
loading file 1
finished loading file 1 in 0.5347864627838135s
loading file 2
finished loading file 2 in 0.5613644123077393s
loading file 3
finished loading file 3 in 0.5367591381072998s
loading file 4
finished loading file 4 in 0.5348429679870605s
loading file 5
finished loading file 5 in 0.5596091747283936s
