In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import yaml
import pandas as pd
import numpy as np
from functools import partial
import multiprocessing
import tensorflow as tf
from pathlib import Path
from time import strftime
from shutil import rmtree
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib
matplotlib.use("TKAgg", force=True)
%matplotlib inline
#pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

In [2]:
from rtapipe.lib.utils.misc import dotdict
from rtapipe.lib.dataset.data_manager import DataManager
from rtapipe.lib.datasource.Photometry3 import OnlinePhotometry, SimulationParams
from rtapipe.lib.plotting.plotting import plot_sequences
from rtapipe.lib.models.anomaly_detector_builder import AnomalyDetectorBuilder
from rtapipe.lib.evaluation.pval import get_pval_table, get_threshold_for_sigma
from rtapipe.lib.standardanalysis.li_ma import LiMa
from rtapipe.lib.evaluation.pval import get_pval_table, get_threshold_for_sigma, get_sigma_from_pvalue, get_sigma_for_ts_array, get_sigma_from_ts

# Loading the models

In [3]:
def load_model(model_id, config_path="./trained_models.yaml"):
    """Load a trained anomaly detector and its p-value table from the model registry.

    Parameters
    ----------
    model_id : int
        Value of the ``id`` field of the entry to load from the registry's
        ``models`` list.
    config_path : str or Path, optional
        YAML registry of trained models (default ``./trained_models.yaml``).

    Returns
    -------
    dotdict
        The registry entry, extended with ``ad`` (the detector built by
        ``AnomalyDetectorBuilder.getAnomalyDetector``) and ``pvalue_table``
        (read from the entry's ``pval_path``).

    Raises
    ------
    yaml.YAMLError
        If the registry is not valid YAML.
    ValueError
        If no entry has ``id == model_id``.
    """
    with open(config_path, "r") as f:
        # Let parse errors propagate instead of continuing with an undefined config.
        configs = yaml.safe_load(f) or {}

    matches = [c for c in configs.get("models") or [] if c["id"] == model_id]
    if not matches:
        raise ValueError(f"No model with id={model_id} found in {config_path}")
    # If several entries share the id, the last one is used.
    model_config = dotdict(matches[-1])

    model_config.ad = AnomalyDetectorBuilder.getAnomalyDetector(
        name=model_config.name,
        timesteps=model_config.timesteps,
        nfeatures=model_config.nfeatures,
        # NOTE(review): passed as the string "True" (unchanged); confirm the builder does not expect a bool.
        load_model="True",
        training_epoch_dir=model_config.path,
        training=False,
    )
    model_config.pvalue_table = get_pval_table(model_config.pval_path)
    return model_config

In [4]:
#model_config_cnn = load_model(0)
#model_config_cnn

In [5]:
# Registry entry 1 is the RNN detector evaluated in the rest of this notebook.
rnn_model_id = 1
model_config_rnn = load_model(model_id=rnn_model_id)
model_config_rnn

AnomalyDetector_rnn_l2_u32 - input shape: (5,3)


{'id': 1,
 'patience': 5,
 'integrationtime': 5,
 'timesteps': 5,
 'nfeatures': 3,
 'scaler': 'minmax',
 'name': 'AnomalyDetector_rnn_l2_u32',
 'path': '/data01/homes/baroncelli/phd/rtapipe/notebooks/run_20221116-101109_mr_patience_5/model_AnomalyDetector_rnn_l2_u32_dataset_train_itime_5_a_tsl_5_nbins_3_tsl_3600/epochs/epoch_10',
 'epoch': 10,
 'pval_path': '/data01/homes/baroncelli/phd/rtapipe/notebooks/run_20221116-101109_mr_patience_5/model_AnomalyDetector_rnn_l2_u32_dataset_train_itime_5_a_tsl_5_nbins_3_tsl_3600/epochs/epoch_10/pvalues/pval_20221121-132236/pvalue_bins_100_0.numpy.txt',
 'scaler_path': '/data01/homes/baroncelli/phd/rtapipe/notebooks/run_20221116-101109_mr_patience_5/model_AnomalyDetector_rnn_l2_u32_dataset_train_itime_5_a_tsl_5_nbins_3_tsl_3600/fitted_scaler.pickle',
 'ad': <rtapipe.lib.models.anomaly_detector_rnn.AnomalyDetector_rnn_l2_u32 at 0x2af318d23f10>,
 'pvalue_table':     threshold  threshold_err        pvalue    pvalue_err  sigma
 0    0.000227       0.000

# Loading the test dataset

In [6]:
# Root output directory: passed to DataManager and used as the parent of the prediction plots.
output_dir = "./logs/check_recostructions_for_templates_out"

In [7]:
# One label per energy-bin feature, in feature order; bin edges (TeV) are contiguous:
# 0.04-0.117, 0.117-0.342, 0.342-1. The text after "EB_" is used as the plot's y-axis label.
features_names = ["EB_0.04-0.117", "EB_0.117-0.342", "EB_0.342-1"]

In [8]:
# DataManager is used below to load the FITS/cached time series, window the test set and load the scaler.
data_manager = DataManager(output_dir)

In [9]:
# NOTE(review): `dataset_id` is not referenced anywhere else in this notebook, and its
# "itime_5_c" tag differs from the "itime_5_h" folder loaded below — confirm which is intended.
dataset_id="test_itime_5_c_tsl_5_nbins_3"

In [10]:
# Read (up to 500) FITS files of the test simulations and record how many were found.
max_fits_files = 500
dataset_folder = "/data01/homes/baroncelli/phd/rtapipe/scripts/ml/dataset_generation/test/itime_5_h/fits_data"
fits_files = DataManager.load_fits_data(dataset_folder, limit=max_fits_files)
test_set_size = len(fits_files)

Loaded 419 files


In [11]:
# Simulation parameters of the test set (GRB templates, onset=250 within tobs=500).
sim_params = SimulationParams(runid=None, onset=250, emin=0.04, emax=1, tmin=0, tobs=500, offset=0.5, irf="North_z40_5h_LST", roi=2.5, caldb="prod5-v0.1", simtype="grb")
multiple_templates = True
add_target_region = True
integration_time = 5
number_of_energy_bins = 3
tsl = 100
threads = 30
normalize = True

# Set to True to rebuild the time series from the FITS files instead of loading the cache.
recompute_timeseries = False
if recompute_timeseries:
    data_manager.transform_to_timeseries(fits_files, sim_params, add_target_region, integration_time=integration_time, number_of_energy_bins=number_of_energy_bins, tsl=tsl, normalize=normalize, threads=threads, multiple_templates=multiple_templates)
else:
    # Use the parameters above rather than repeating the literals (5, 100), so they cannot drift apart.
    data_manager.load_saved_data(integration_time, tsl)

Loading cached data from run0002_ID000044_it_5_tsl_100.npy
Loading cached data from run0005_ID000225_it_5_tsl_100.npy
Loading cached data from run0009_ID000191_it_5_tsl_100.npy
Loading cached data from run0011_ID000139_it_5_tsl_100.npy
Loading cached data from run0013_ID000321_it_5_tsl_100.npy
Loading cached data from run0016_ID000340_it_5_tsl_100.npy
Loading cached data from run0017_ID000132_it_5_tsl_100.npy
Loading cached data from run0017_ID000261_it_5_tsl_100.npy
Loading cached data from run0017_ID000302_it_5_tsl_100.npy
Loading cached data from run0019_ID000070_it_5_tsl_100.npy
Loading cached data from run0019_ID000268_it_5_tsl_100.npy
Loading cached data from run0019_ID000338_it_5_tsl_100.npy
Loading cached data from run0021_ID000373_it_5_tsl_100.npy
Loading cached data from run0023_ID000267_it_5_tsl_100.npy
Loading cached data from run0025_ID000471_it_5_tsl_100.npy
Loading cached data from run0026_ID000127_it_5_tsl_100.npy
Loading cached data from run0027_ID000288_it_5_tsl_100.n

Loading cached data from run0285_ID000014_it_5_tsl_100.npy
Loading cached data from run0286_ID000021_it_5_tsl_100.npy
Loading cached data from run0286_ID000099_it_5_tsl_100.npy
Loading cached data from run0286_ID000206_it_5_tsl_100.npy
Loading cached data from run0290_ID000056_it_5_tsl_100.npy
Loading cached data from run0291_ID000001_it_5_tsl_100.npy
Loading cached data from run0291_ID000193_it_5_tsl_100.npy
Loading cached data from run0292_ID000156_it_5_tsl_100.npy
Loading cached data from run0294_ID000203_it_5_tsl_100.npy
Loading cached data from run0298_ID000339_it_5_tsl_100.npy
Loading cached data from run0301_ID000376_it_5_tsl_100.npy
Loading cached data from run0302_ID000382_it_5_tsl_100.npy
Loading cached data from run0309_ID000061_it_5_tsl_100.npy
Loading cached data from run0314_ID000223_it_5_tsl_100.npy
Loading cached data from run0315_ID000129_it_5_tsl_100.npy
Loading cached data from run0317_ID000219_it_5_tsl_100.npy
Loading cached data from run0318_ID000143_it_5_tsl_100.n

Loading cached data from run0668_ID000294_it_5_tsl_100.npy
Loading cached data from run0672_ID000265_it_5_tsl_100.npy
Loading cached data from run0677_ID000457_it_5_tsl_100.npy
Loading cached data from run0684_ID000011_it_5_tsl_100.npy
Loading cached data from run0684_ID000346_it_5_tsl_100.npy
Loading cached data from run0687_ID000058_it_5_tsl_100.npy
Loading cached data from run0688_ID000044_it_5_tsl_100.npy
Loading cached data from run0688_ID000432_it_5_tsl_100.npy
Loading cached data from run0689_ID000454_it_5_tsl_100.npy
Loading cached data from run0689_ID000513_it_5_tsl_100.npy
Loading cached data from run0697_ID000167_it_5_tsl_100.npy
Loading cached data from run0700_ID000206_it_5_tsl_100.npy
Loading cached data from run0709_ID000364_it_5_tsl_100.npy
Loading cached data from run0710_ID000198_it_5_tsl_100.npy
Loading cached data from run0711_ID000203_it_5_tsl_100.npy
Loading cached data from run0712_ID000163_it_5_tsl_100.npy
Loading cached data from run0719_ID000018_it_5_tsl_100.n

## Plot some samples

In [None]:
#for template in list(data_manager.data.keys())[0:5]:
#    data_manager.plot_timeseries(template, data_manager.data[template], 1, sim_params, output_dir, labels=features_names)

# TESTING

In [None]:
# Load the scaler stored with the model under test
# (use model_config_cnn.scaler_path instead when evaluating the CNN).
rnn_scaler_path = model_config_rnn.scaler_path
data_manager.load_scaler(rnn_scaler_path)

In [None]:
# Window every template's series into model inputs. The window length must match the
# model's input length (its `timesteps`), so take it from the model config instead of
# hard-coding 5.
test_all_x, test_all_y = data_manager.get_test_set_all_templates(verbose=False, onset=250, integration_time=integration_time, sub_window_size=model_config_rnn.timesteps, stride=1)

# RNN

## Standard metrics - 3 sigma threshold

In [None]:
# Inspect the p-value table (threshold, pvalue, sigma and their errors) used to pick the detection threshold.
model_config_rnn.pvalue_table

In [None]:
# Target significance (in sigma) for the detection threshold in this section.
SIGMA_THRESHOLD = 3

In [None]:
# Convert the target significance into a detection threshold via the model's p-value table,
# install it on the detector and report the significance that threshold corresponds to.
pvalue_table_rnn = model_config_rnn.pvalue_table
sigma_threshold = get_threshold_for_sigma(pvalue_table_rnn, SIGMA_THRESHOLD)
model_config_rnn.ad.threshold = sigma_threshold
achieved_sigma = get_sigma_from_ts(pvalue_table_rnn, model_config_rnn.ad.threshold)
print(f"Threshold: {model_config_rnn.ad.threshold} corresponding to {achieved_sigma} sigma")

In [None]:
#print(f"************** Evaluating {model_config_rnn.name} patience={model_config_rnn.patience} **************")
#metrics = model_config_rnn.ad.evaluate(test_all_x, test_all_y)
#for k,v in metrics.items():
#    print(k,v)
#print("detection_delay:",model_config_rnn.ad.detection_delay(test_all_y, model_config_rnn.ad.predict(test_all_x), test_set_size, model_config_rnn.timesteps)#)

## Standard metrics and reconstruction plots - 5 sigma threshold

In [None]:
from rtapipe.lib.plotting.PlotConfig import PlotConfig

# Plot configuration; only its fig_size is inspected in the next cell.
pc = PlotConfig()


In [None]:
# Figure size used by plot_predictions2: twice PlotConfig's width and height.
(2 * pc.fig_size[0], 2 * pc.fig_size[1])

In [None]:
# Target significance for this section; re-binds the constant set to 3 in the previous section.
SIGMA_THRESHOLD = 5

In [None]:
# Derive the detection threshold for SIGMA_THRESHOLD from the p-value table,
# apply it to the RNN detector and echo the corresponding significance.
sigma_threshold = get_threshold_for_sigma(model_config_rnn.pvalue_table, SIGMA_THRESHOLD)
detector = model_config_rnn.ad
detector.threshold = sigma_threshold
print("Threshold: {} corresponding to {} sigma".format(
    detector.threshold,
    get_sigma_from_ts(model_config_rnn.pvalue_table, detector.threshold),
))

In [None]:
# Redundant with the first cell (harmless). IPython magic, not plain Python.
%matplotlib inline

def _time_bin_labels(window_offset, timesteps, integration_time):
    """Return one "start-end" label per timestep of a window starting at bin `window_offset`.

    Assumes consecutive windows are shifted by one bin (stride=1, as requested from
    get_test_set_all_templates in this notebook) — TODO confirm if the stride changes.
    """
    return [
        f"{(window_offset + t) * integration_time}-{(window_offset + t + 1) * integration_time}"
        for t in range(timesteps)
    ]


def _feature_label(name):
    """Short y-axis label: strip an "EB_" prefix and a " TeV" suffix when present."""
    if "EB_" in name:
        name = name.split("EB_", 1)[1]
    return name.split(" TeV")[0]


def _classification(real_label, pred_label):
    """Confusion-matrix cell ("TP"/"FN"/"TN"/"FP") with "grb" as the positive class."""
    if real_label == "grb":
        return "TP" if pred_label == "grb" else "FN"
    return "TN" if pred_label == "bkg" else "FP"


def plot_predictions2(samples, samplesLabels, c_threshold, recostructions, mse_per_sample, mse_per_sample_features, features_names=None, integration_time=5, epoch="", max_plots=5, showFig=False, saveFig=True, outputDir="./", figName="predictions.png", max_samples=5, windows_per_template=96, sigma=5):
    """Plot ground truth vs. reconstruction, `max_samples` samples per figure.

    Each figure has one row per feature and one column per sample. A sample is
    predicted "grb" when ``mse_per_sample > c_threshold``; misclassified samples get
    a grey background. The first row's titles show the sample index, its averaged
    mse and the TP/FN/TN/FP outcome.

    Parameters
    ----------
    samples, recostructions : array, shape (n_samples, timesteps, n_features)
        Input windows and the model's reconstructions.
    samplesLabels : sequence
        True labels; 1 means "grb", anything else "bkg".
    c_threshold : float
        Detection threshold compared against `mse_per_sample`.
    mse_per_sample : array, shape (n_samples,)
    mse_per_sample_features : array, shape (n_samples, n_features)
    features_names : list of str, optional
        One name per feature. The part after "EB_" (and before " TeV") is used as the
        row label. Replaced by "Feature i" when missing or of the wrong length.
    integration_time : int
        Width of one timestep; used for the x-tick labels.
    epoch : str or int
        Embedded in the saved file name.
    max_plots : int
        Maximum number of figures to produce.
    showFig, saveFig : bool
        Display and/or save each figure.
    outputDir : str or Path
        Directory for saved figures (created if missing).
    figName : str
        File-name prefix; files are "<figName>_epoch_<epoch>_plot_<p>.png".
    max_samples : int
        Samples (columns) per figure.
    windows_per_template : int
        Number of consecutive windows per template in `samples`. Used to recover a
        window's position inside its template for the tick labels
        (96 = 100 - 5 + 1 for tsl=100, timesteps=5, stride=1 — TODO confirm for other setups).
    sigma : int or float
        Significance shown in the figure title.
    """
    pc = PlotConfig()

    total_samples, timesteps, n_features = samples.shape
    if features_names is None or len(features_names) != n_features:
        features_names = [f"Feature {i}" for i in range(n_features)]
    y_labels = [_feature_label(name) for name in features_names]

    mask = mse_per_sample > c_threshold
    ymin, ymax = 0, 1
    # Ticks sit on the plotted points: plot(y) places them at x = 0..timesteps-1.
    tick_positions = list(range(timesteps))

    # Ceiling division, so the last partially filled figure is plotted too.
    num_plots = -(-total_samples // max_samples)

    for p in tqdm(range(num_plots)):
        if p == max_plots:
            print("Max plots reached")
            break

        start = p * max_samples
        stop = min(start + max_samples, total_samples)
        n_cols = stop - start

        real_labels = ["grb" if lab == 1 else "bkg" for lab in samplesLabels[start:stop]]
        pred_labels = ["grb" if lab == 1 else "bkg" for lab in mask[start:stop]]

        # squeeze=False keeps `ax` 2-D even with a single feature or a single column.
        fig, ax = plt.subplots(n_features, n_cols, figsize=(pc.fig_size[0]*2, pc.fig_size[1]*2), squeeze=False)
        fig.suptitle(f"{sigma}σ threshold={round(c_threshold, 3)}")
        fig.supylabel('Energy bins (TeV)')

        for i in range(n_cols):
            k = start + i  # index of the sample in `samples`
            tick_labels = _time_bin_labels(k % windows_per_template, timesteps, integration_time)
            misclassified = real_labels[i] != pred_labels[i]

            for f in range(n_features):
                cell = ax[f, i]
                cell.plot(recostructions[k][:, f], color='red', marker='o', markersize=6, linestyle='dashed', label="reconstruction")
                cell.plot(samples[k][:, f], color="blue", marker='o', markersize=6, linestyle='dashed', label="ground truth")
                cell.set_ylim(ymin, ymax)
                cell.set_xticks(tick_positions, tick_labels, fontsize=10)
                cell.set_xlabel("mse={:.6f}".format(mse_per_sample_features[k, f]))

                if misclassified:
                    cell.set_facecolor('#e6e6e6')

                # Only the first column shows the energy-bin labels.
                if i == 0:
                    cell.set_ylabel(y_labels[f])

                # Only the first row shows the sample index, averaged mse and outcome.
                if f == 0:
                    cell.set_title("Sample {}\nW. avg mse={:.3f}\nClassification={}".format(
                        k, mse_per_sample[k], _classification(real_labels[i], pred_labels[i])))

        handles, labels = ax[0, 0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='upper left')
        fig.tight_layout()

        # Save before showing, so the file is written however the backend treats the figure after show().
        if saveFig:
            out_dir = Path(outputDir)
            out_dir.mkdir(parents=True, exist_ok=True)
            fig.savefig(out_dir.joinpath(f"{figName}_epoch_{epoch}_plot_{p}.png"), dpi=200)

        if showFig:
            plt.show()

        # Close this figure explicitly to avoid accumulating open figures across iterations.
        plt.close(fig)


In [None]:
# Folder for the reconstruction plots. parents=True so it also works when
# output_dir itself has not been created yet.
outputDir = Path(output_dir).joinpath("predictions")
outputDir.mkdir(parents=True, exist_ok=True)

In [None]:
# Evaluate the RNN detector on the windowed test set with the threshold set above.
# NOTE(review): the plotting cell below reads ad.loss_f.mse_per_sample(_features) from the
# detector's state; presumably populated by this evaluate() call (or by reconstruct()) —
# confirm, since it makes the plots depend on cell execution order.
metrics = model_config_rnn.ad.evaluate(test_all_x, test_all_y)

In [None]:
# Plot (up to 999 figures of) RNN reconstructions against the ground truth and save them.
showFig = True
saveFig = True
rnn_ad = model_config_rnn.ad
rnn_threshold = rnn_ad.threshold
rnn_reconstructions = rnn_ad.reconstruct(test_all_x)
# The MSE arrays are read from the loss object's state after reconstruct(), preserving the
# original evaluation order — NOTE(review): confirm which call populates them.
rnn_mse_per_sample = rnn_ad.loss_f.mse_per_sample.numpy()
rnn_mse_per_sample_features = rnn_ad.loss_f.mse_per_sample_features.numpy()
plot_predictions2(
    test_all_x,
    test_all_y,
    rnn_threshold,
    rnn_reconstructions,
    rnn_mse_per_sample,
    rnn_mse_per_sample_features,
    features_names=features_names,
    epoch=model_config_rnn.epoch,
    max_plots=999,
    showFig=showFig,
    saveFig=saveFig,
    outputDir=outputDir,
    figName=f"{model_config_rnn.name}_patience_{model_config_rnn.patience}",
)