In [None]:
# Standard library imports
import json
import csv
import os
import sys
from pprint import pprint
import time
import gc 

# Third party imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch_geometric
from torch.nn.utils import clip_grad_norm_
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt

# Local application imports
sys.path.append("/pbs/home/e/erodrigu/TesisPhDEzequielRodriguez/Code")
from UHECRs_gnn import(SD433UMDatasetHomogeneous,
                       GNNWithAttentionDiscriminator3Heads,
                       GNNWithAttentionDiscriminator3HeadsDualInput,
                       MaskNodes,
                       MaskMdCounters,
                       SilentPrunner,
                       MaskRandomNodes,
)

from my_utils.my_basic_utils import (
    create_bins,
    filter_dataframe,
)

from my_utils.my_style import MyStyle

# Project directory layout, resolved from the current working directory:
# code_PATH is the parent of the cwd, project_PATH is one level above it,
# and data_PATH is the "data" folder inside the project directory.
code_PATH = os.path.abspath(os.pardir)
project_PATH = os.path.abspath(os.path.join(code_PATH, os.pardir))
data_PATH = os.path.join(project_PATH, "data")

In [None]:
def compute_weighted_histogram(df, target_col, score_col, weight_col, bins=100, density=False):
    """Histogram `df[score_col]` with unit-normalized weights and per-bin errors.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the score and weight columns.
    target_col : str
        Unused; kept so existing call sites keep working.
    score_col : str
        Column whose values are histogrammed.
    weight_col : str
        Column with the per-row weights. The weights are divided by their sum
        before filling the histogram.
    bins : int or sequence, default 100
        Forwarded to `numpy.histogram`.
    density : bool, default False
        Forwarded to `numpy.histogram`. When True the errors are rescaled by
        the same factor as the histogram.

    Returns
    -------
    bin_centers : numpy.ndarray
        Midpoints of the bin edges.
    histogram : numpy.ndarray
        Weighted bin contents.
    errors : numpy.ndarray
        Square root of the sum of squared (normalized) weights in each bin.
    """
    scores = df[score_col].to_numpy()
    weights = df[weight_col]

    # Normalize the weights so that they add up to one
    normalized_weights = (weights / np.sum(weights)).to_numpy()

    histogram, bin_edges = np.histogram(
        scores,
        bins=bins,
        weights=normalized_weights,
        density=density,
    )

    # Sum of squared weights per bin, filled with the very same edges so that
    # the errors use exactly the binning of the histogram (np.histogram closes
    # the last bin on the right; np.digitize would drop values on that edge).
    sum_w2, _ = np.histogram(scores, bins=bin_edges, weights=normalized_weights**2)
    errors = np.sqrt(sum_w2)

    if density:
        # np.histogram(density=True) divides by (in-range weight sum * bin width);
        # apply the same factor so that errors and histogram share units.
        in_range, _ = np.histogram(scores, bins=bin_edges, weights=normalized_weights)
        errors = errors / (in_range.sum() * np.diff(bin_edges))

    # Bin centers for plotting
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

    return bin_centers, histogram, errors

def weighted_median(values, weights):
    """Return the weighted median of `values`.

    Both inputs are converted to lists when they are not lists already. The
    weights are divided by their sum, the (value, weight) pairs are sorted by
    value, and the first value at which the running weight reaches half of the
    total normalized weight is returned.

    Raises
    ------
    ValueError
        If `values` and `weights` differ in length.

    Returns None when no value reaches the threshold (e.g. for empty input).
    """
    values = values if isinstance(values, list) else list(values)
    weights = weights if isinstance(weights, list) else list(weights)

    if len(values) != len(weights):
        raise ValueError("Values and weights must have the same length.")

    weight_sum = sum(weights)

    # Pair every value with its normalized weight and order the pairs by value
    pairs = sorted(
        [(value, weight / weight_sum) for value, weight in zip(values, weights)],
        key=lambda pair: pair[0],
    )

    # Half of the total normalized weight (the total is 1.0 up to rounding)
    half_weight = sum(weight for _, weight in pairs) / 2.0

    running_weight = 0.0
    for value, weight in pairs:
        running_weight += weight
        if running_weight >= half_weight:
            return value
    return None

### Version check

In [None]:
# Report the installed torch / torch_geometric versions and whether CUDA is usable
print(f"Torch version: {torch.__version__}")
print(f"Torch CUDA version: {torch.version.cuda}")
print(f"Cuda available: {torch.cuda.is_available()}")
print(f"Torch geometric version: {torch_geometric.__version__}")
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device is: {device}")

In [None]:
# IPython shell escape: runs the `nvidia-smi` command and shows its output in the cell
!nvidia-smi

### Dataset Index Loading

In [None]:
folder_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/"
dir_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/JSONfiles/"

# indexes
primaries = ["Proton", "Photon"]
energy_bins = ["16.5_17.0", "17.0_17.5"]
atms = ["01", "03", "08", "09"]
indexes = [
    f"index_hadron_rec_{x}_{y}_{z}.csv"
    for x in primaries
    for y in energy_bins
    for z in atms
]

# Keys shared by the "hadron_rec" and "photon_rec" index files of one sample
index_cols = ["filename", "atm_model", "shower_id", "use_id"]

# Merge each pair of index files, collect the results and concatenate once
# (calling pd.concat inside the loop re-copies all previous rows every time).
index_frames = []
for index_name in indexes:
    # NOTE: despite its name, proton_index holds the "hadron_rec" index of
    # whichever primary index_name refers to (Proton or Photon).
    proton_index = pd.read_csv(folder_path + index_name, on_bad_lines="skip")
    photon_rec_index = pd.read_csv(
        folder_path + index_name.replace("hadron", "photon"), on_bad_lines="skip"
    )

    # Cast the merge keys to a common dtype on both sides before merging
    proton_index[index_cols] = proton_index[index_cols].astype('string')
    photon_rec_index[index_cols] = photon_rec_index[index_cols].astype('string')

    index_ = pd.merge(
        proton_index,
        photon_rec_index,
        on=index_cols,
        how="left",
    )
    index_frames.append(index_)

index = pd.concat(index_frames, ignore_index=True)

index = index.drop_duplicates()
index = index.drop_duplicates(subset=["filename"])
# The primary name is the prefix of the filename (text before the first "_");
# rows whose primary is Iron or Helium are dropped from this index.
index["mass_group"] = index["filename"].str.split(pat="_", expand=True)[0]
index = index[(index["mass_group"] != "Iron") & (index["mass_group"] != "Helium")]
print(f"Events before quality cuts: {len(index)}")

### Quality Cuts and Binning

In [None]:
# Shuffle the rows with a fixed seed: the stratified splits applied afterwards
# select rows by position, so an unseeded shuffle would yield a different
# train/validation/test partition on every run even though the splitters
# themselves are seeded.
index = index.sample(frac=1, random_state=42)

# Label: 1 for photon showers, 0 for proton showers (taken from the filename)
index.loc[index["filename"].str.contains("Photon"), "isPhoton"] = 1
index.loc[index["filename"].str.contains("Proton"), "isPhoton"] = 0

index["sin2zenith"] = np.sin(index["zenithMC"]) ** 2

# photon efficiency from fit from simulations
#index["est_efficiency"] = (
#    15.4074
#    + 17.4996 * (np.log10(index["energyMC"]) - 17)
#    - 12.7485 * index["sin2zenith"]
#    - 20.7650 * index["sin2zenith"] ** 2
#    - 13.1239 * (np.log10(index["energyMC"]) - 17) * index["sin2zenith"]
#)
#index["est_efficiency"] = expit(index["est_efficiency"])

# Quality cuts handed to filter_dataframe
feature_filters = {
    "zenithMC": {"filter_type": "range", "max_cut": np.deg2rad(45)},
    "photon_energy": {"filter_type": "range", "min_cut": 1},
    #"est_efficiency": {"filter_type": "range", "min_cut": 0.9},
    "isT5": {"filter_type": "value", "value_to_keep": 1}
}
index = filter_dataframe(index, feature_filters)

# Energy bins (column "e_bin") between 10^16.5 and 10^17.5
index, e_bin_centers, e_bin_edges, e_labels = create_bins(
    index,
    lower_val=10**16.5,
    upper_val=10**17.5,
    num=6,
    unbinned_col="energyMC",
    bin_column_name="e_bin",
    bin_width="equal_logarithmic",
)

# sin^2(zenith) bins (column "z_bin") between 0 and sin^2(45 deg)
index, z_bin_centers, z_bin_edges, z_labels = create_bins(
    index,
    lower_val=0,
    upper_val=np.sin(np.deg2rad(45)) ** 2,
    num=4,
    unbinned_col="sin2zenith",
    bin_column_name="z_bin",
    bin_width="equal",
)

# Drop events without an assigned energy bin
index = index.loc[~index["e_bin"].isnull()]

# corrupted or problematic ADSTs
exclude_list = [
"Photon_17.0_17.5_011102_11",
"Photon_17.0_17.5_080595_20"
]
index = index[~index['filename'].isin(exclude_list)]

print(f"Events after quality cuts: {len(index)}")

### Balanced Dataset Division

In [None]:
# Stratification key: class label combined with the energy and zenith bins
index["categorical_balance"] = (
    index["isPhoton"].astype(str) + "_" + index["e_bin"].astype(str) + "_" + index["z_bin"].astype(str)
)

random_seed = 42

# First split: 75% development set / 25% test set.
# n_splits=1, so the splitter yields exactly one pair of positional index arrays.
stratified_split = StratifiedShuffleSplit(
    n_splits=1, test_size=0.25, random_state=random_seed
)
dev_index_, test_index_ = next(
    stratified_split.split(index, index["categorical_balance"])
)
dev_index = index.iloc[dev_index_]
test_index = index.iloc[test_index_]

# Second split: the development set is divided into train and validation sets
validation_size = 0.25  # Adjust as needed
split = StratifiedShuffleSplit(
    n_splits=1, test_size=validation_size, random_state=random_seed
)
train_index_, validation_index_ = next(
    split.split(dev_index, dev_index["categorical_balance"])
)
train_index = dev_index.iloc[train_index_]
validation_index = dev_index.iloc[validation_index_]

# Report the number of events in each subset
print("Train dataset size:", len(train_index))
print("Validation dataset size:", len(validation_index))
print("Test dataset size:", len(test_index))

In [None]:
# Display summary statistics of the filtered index table as the cell output
index.describe()

### Heavy background df

In [None]:
# Heavy-primary (Helium, Iron) background index, built like the main index:
# one merged frame per index file, concatenated once at the end.
index_bg_frames = []

# indexes
primaries = ["Helium", "Iron"]
energy_bins = ["16.5_17.0", "17.0_17.5"]
atms = ["01", "03", "08", "09"]
indexes = [
    f"index_hadron_rec_{x}_{y}_{z}.csv"
    for x in primaries
    for y in energy_bins
    for z in atms
]

for index_name in indexes:
    # NOTE: despite its name, proton_index holds the "hadron_rec" index of
    # the Helium/Iron sample that index_name refers to.
    proton_index = pd.read_csv(folder_path + index_name, on_bad_lines="skip")
    photon_rec_index = pd.read_csv(
        folder_path + index_name.replace("hadron", "photon"), on_bad_lines="skip"
    )
    index_bg_ = pd.merge(
        proton_index,
        photon_rec_index,
        on=["filename", "atm_model", "shower_id", "use_id"],
        how="left",
    )
    index_bg_frames.append(index_bg_)

index_bg = pd.concat(index_bg_frames, ignore_index=True)

index_bg = index_bg.drop_duplicates()
index_bg = index_bg.drop_duplicates(subset=["filename"])
# The primary name is the prefix of the filename (text before the first "_")
index_bg["mass_group"] = index_bg["filename"].str.split(pat="_", expand=True)[0]

# Same quality cuts as for the photon/proton index
feature_filters = {
    "zenithMC": {"filter_type": "range", "max_cut": np.deg2rad(45)},
    "photon_energy": {"filter_type": "range", "min_cut": 1},
    #"est_efficiency": {"filter_type": "range", "min_cut": 0.9},
    "isT5": {"filter_type": "value", "value_to_keep": 1}

}

print(f"Events before quality cuts: {len(index_bg)}")

# All background events carry the non-photon label
index_bg["isPhoton"] = 0
index_bg["sin2zenith"] = np.sin(index_bg["zenithMC"]) ** 2

index_bg = filter_dataframe(index_bg, feature_filters)

index_bg, e_bin_centers, e_bin_edges, e_labels = create_bins(
    index_bg,
    lower_val=10**16.5,
    upper_val=10**17.5,
    num=6,
    unbinned_col="energyMC",
    bin_column_name="e_bin",
    bin_width="equal_logarithmic",
)

index_bg, z_bin_centers, z_bin_edges, z_labels = create_bins(
    index_bg,
    lower_val=0,
    upper_val=np.sin(np.deg2rad(45)) ** 2,
    num=4,
    unbinned_col="sin2zenith",
    bin_column_name="z_bin",
    bin_width="equal",
)

# Drop events without an assigned energy bin
index_bg = index_bg.loc[~index_bg["e_bin"].isnull()]

# Seeded shuffle so that the event order (and therefore the content of the
# loader subsets built from it) is the same on every run.
index_bg = index_bg.sample(frac=1, random_state=42)
index_bg = index_bg[(index_bg["mass_group"]=="Helium") | (index_bg["mass_group"]=="Iron")]

# problem in ADSTs
exclude_list_bg = [
"Helium_16.5_17.0_010628_20",
"Helium_16.5_17.0_080138_01",
"Helium_17.0_17.5_011087_07",
"Helium_16.5_17.0_010355_19"
]
index_bg = index_bg[~index_bg['filename'].isin(exclude_list_bg)]

# Reported after the exclusion list so that the count matches what is used downstream
print(f"Events after quality cuts: {len(index_bg)}")

### Generation of Normalization Dictionary

In [None]:
# Define paths and other parameters
# dir_path: base folder used to build the per-event JSON file paths
# (same value as the dir_path assigned in the index-loading cell).
dir_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/JSONfiles/"
# root_path: passed as the "root" argument of the dataset class
root_path = "/sps/pauger/users/erodriguez/PhotonDiscrimination/root/"

# Build the JSON file path of one event from its index row
def construct_path(row, base_path):
    """Return "<base_path><mass_group>/<filename>.json" for an index row.

    `row` must support lookup of the 'mass_group' and 'filename' keys;
    `base_path` is prepended verbatim (no separator is inserted after it).
    """
    mass_group = row['mass_group']
    filename = row['filename']
    return f"{base_path}{mass_group}/{filename}.json"

# JSON paths of the events to evaluate: the photon/proton test split plus the
# Helium and Iron background samples (one path per index row).
#train_paths = train_index.apply(lambda row: construct_path(row, dir_path), axis=1).tolist()
#val_paths = validation_index.apply(lambda row: construct_path(row, dir_path), axis=1).tolist()
test_paths = [construct_path(row, dir_path) for _, row in test_index.iterrows()]
test_helium_paths = [
    construct_path(row, dir_path)
    for _, row in index_bg[index_bg["mass_group"]=="Helium"].iterrows()
]
test_iron_paths = [
    construct_path(row, dir_path)
    for _, row in index_bg.loc[index_bg["mass_group"]=="Iron"].iterrows()
]

In [None]:
# Option flags handed to the dataset class as "augmentation_options"
augmentation_and_normalization_options={"mask_PMTs":True,
                                        "AoP_and_saturation":True,
                                        "log_normalize_traces":True,
                                        "log_normalize_signals":True,
                                        "mask_MD_mods":True}

# Per-feature normalization constants handed to the dataset class as
# "normalization_dict". Each entry names a method ('standardization' with
# mean/std, or 'min_max_scaling' with min/max). The commented-out entries are
# alternative constants labelled "with silent"; the active ones are labelled
# "without silent".
normalization_dict = {
                     # with silent
                     #'deltaTimeHottest': {'mean': -13.743452072143555,
                     #                     'method': 'standardization',
                     #                     'std': 446.0340270996094},
                     # without silent
                     'deltaTimeHottest': {'mean': -47.27531568592806,
                                          'method': 'standardization',
                                          'std': 603.6062483849028},
                     'effective_area': {'max': 31.380000000000003,
                                        'method': 'min_max_scaling',
                                        'min': 0},
                     'pmt_number': {'max': 3, 'method': 'min_max_scaling', 'min': 1},
                     #'rho_mu': {'min': -2.0,
                     #           'method': 'min_max_scaling',
                     #           'max': (64*3)/(3*10*np.cos(np.deg2rad(45)))},
                     'rho_mu':{'mean': 0.21510971141596952,
                               'method': 'standardization',
                               'std': 0.9057438827119321},
                     # with silent
                     #'x': {'mean': -2.081512212753296,
                     #      'method': 'standardization',
                     #      'std': 779.2147827148438},
                     #'y': {'mean': 0.3692401945590973,
                     #      'method': 'standardization',
                     #      'std': 779.95458984375},
                     #'z': {'mean': -0.17404018342494965,
                     #      'method': 'standardization',
                     #      'std': 9.372444152832031},
                     # without silent
                     'x': {'mean': -1.01887806305201,
                           'method': 'standardization',
                           'std': 364.5171135403138},
                     'y': {'mean': -0.06542848666299515,
                           'method': 'standardization',
                           'std': 356.8056580363697},
                     'z': {'mean': -0.16543275617686548,
                           'method': 'standardization',
                           'std': 6.410170534866934}
                    }


# Echo the constants in use
pprint(normalization_dict)

### Datasets and Loaders

In [None]:
# n_loaders: number of subsets (one DataLoader each) the path lists are divided into
# n_time_bins: forwarded to the dataset class through dataset_args
n_loaders = 12
n_time_bins = 60

# Keyword arguments shared by every dataset instance
dataset_args = {"root":root_path,
                "augmentation_options":augmentation_and_normalization_options,
                "normalization_dict":normalization_dict,
                "include_silent":False,
                # use the variable defined above instead of repeating the literal
                "n_time_bins":n_time_bins}

from typing import Callable

def get_loader_list(paths: list, dataset_class: Callable, dataset_args: dict = None, n_loaders:int =10, batch_size: int =32, num_workers: int = 8):
    """Split `paths` into `n_loaders` contiguous subsets and wrap each in a DataLoader.

    Every path is assigned to exactly one subset: subset sizes differ by at
    most one, the first `len(paths) % n_loaders` subsets taking the extra
    element (plain floor division would silently discard that remainder).

    Parameters
    ----------
    paths : list
        File paths; each subset is passed to the dataset as `file_paths`.
    dataset_class : Callable
        Called as `dataset_class(file_paths=subset, **dataset_args)`.
    dataset_args : dict, optional
        Extra keyword arguments for `dataset_class`; none when omitted.
    n_loaders : int
        Number of subsets / loaders to create.
    batch_size, num_workers : int
        Forwarded to `DataLoader`.

    Returns
    -------
    list
        `n_loaders` DataLoader objects, in the order of `paths`.
    """
    if dataset_args is None:
        dataset_args = {}

    base_size, remainder = divmod(len(paths), n_loaders)

    loaders = []
    start = 0
    for loader_idx in range(n_loaders):
        # the first `remainder` subsets receive one additional path
        stop = start + base_size + (1 if loader_idx < remainder else 0)
        dataset = dataset_class(file_paths=paths[start:stop], **dataset_args)
        loaders.append(
            DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        )
        start = stop
    return loaders

# Build the three DataLoader lists (photon/proton test split, Helium, Iron)
# with identical dataset and loader settings; only the path list differs.
test_loaders, test_helium_loaders, test_iron_loaders = (
    get_loader_list(
        paths=path_list,
        dataset_class=SD433UMDatasetHomogeneous,
        dataset_args=dataset_args,
        n_loaders=n_loaders,
        batch_size=32,
        num_workers=8,
    )
    for path_list in (test_paths, test_helium_paths, test_iron_paths)
)

In [None]:
#print("Parameters: ", sum(p.numel() for p in triheaded_model.parameters()))

# Run configuration; dual_input, include_silent and loss_strategy select the
# checkpoint file through model_id.
test_run = False
include_silent = False
dual_input = True
loss_strategy = "equal"  # Options: "equal", "prioritize_sdmd", "alternate"
# Model ID and file paths
model_id = f"dual_input_{str(dual_input)}_silent_{str(include_silent)}_loss_strategy_{str(loss_strategy)}_no_TA_BN"
# map_location="cpu" lets a checkpoint saved from a GPU be loaded on a host
# without CUDA; the model is evaluated on the CPU anyway.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source. Recent PyTorch releases may also require weights_only=False
# to load a fully pickled model; confirm against the installed version.
trained_model = torch.load(f"{model_id}.pth", map_location="cpu")
trained_model.to("cpu")
trained_model.eval() # add check for TA batch norm mean and var

### Overall performance

In [None]:
# One dictionary per evaluated graph; converted into a DataFrame after the loop
df_data = []

# Inference runs on the CPU
device = "cpu"

# transformations for individual input
masking_transformation_sd =  MaskNodes(max_nodes2prune=2)
masking_transformation_md =  MaskMdCounters(rho_mu_column_idx=7, effective_area_column_idx=6, silent_value=1/((64*3)/(3*10*np.cos(np.deg2rad(45)))-(-2.0)))

# transformations for dual input (silent_prunner is instantiated but not applied below)
silent_prunner = SilentPrunner(silent_col_index = 4, silent_value=0.0)
random_masking = MaskRandomNodes(max_nodes2prune=2)

# NOTE(review): no random seed is set in this cell; if the masking
# transformations are stochastic the scores change between runs — confirm.
for subset_idx, test_loader in enumerate(test_loaders):
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(test_loader)):
            try:
                # Apply transformations
                data_sd = masking_transformation_sd.collate(data)
                if dual_input:
                    # dual input: the model receives [SD graph, SD+MD graph]
                    data_sdmd = random_masking.collate(data_sd)
                    inputs, targets = [data_sd, data_sdmd], data_sd.y
                else:
                    # single input
                    data_sdmd = masking_transformation_md.collate(data_sd)
                    inputs, targets = data_sdmd, data_sdmd.y

                # Sanity checks: every input and target must be finite (no NaN/inf)
                assert torch.all(torch.isfinite(data_sd.x)), "NaN detected in data_sd.x"
                assert torch.all(torch.isfinite(data_sdmd.x)), "NaN detected in data_sdmd.x"
                assert torch.all(torch.isfinite(data_sd.x_traces)), "NaN detected in data_sd.x_traces"
                assert torch.all(torch.isfinite(data_sdmd.x_traces)), "NaN detected in data_sdmd.x_traces"
                assert torch.all(torch.isfinite(targets)), "NaN detected in targets"

            except AssertionError as e:
                # Report the failed check and skip this batch.
                # subset_idx is kept distinct from the per-graph index below so
                # that the reported subset number stays correct.
                print(f"Skipping batch {batch_idx} in subset {subset_idx+1} due to assertion error: {e}")
                continue

            # The model returns three outputs, stored below as the
            # SD, SD+MD and MD predictions respectively
            output1, output2, output3 = trained_model(inputs)

            # Iterate over each graph in the batch
            for graph_idx, (graph_id, target, prediction1, prediction2, prediction3) in enumerate(zip(data.id, targets, output1, output2, output3)):
                # Nodes of the SD+MD batch that belong to this graph
                graph_sdmd_mask = graph_idx == data_sdmd.batch
                graph_data = {
                    "graph_id": graph_id,
                    "target": target.item(),
                    "prediction_sd": prediction1.item(),
                    "prediction_sdmd": prediction2.item(),
                    "prediction_md": prediction3.item(),
                    "station_list": data_sdmd.station_list[graph_idx],
                    "PMT_list": data_sdmd.x[graph_sdmd_mask,4].numpy(),
                    "MD_area": data_sdmd.x[graph_sdmd_mask,6].numpy()
                }
                df_data.append(graph_data)

gc.collect()
# Convert the list of dictionaries into a DataFrame
df = pd.DataFrame(df_data)

# Attach the index information of each event and add the weight columns
test_df = pd.merge(index, df, how='inner', left_on='filename', right_on='graph_id')
test_df["spectrum_weight_gamma"] = test_df["energyMC"]**-2
test_df["spectrum_weight_hadron"] =test_df["energyMC"]**-3
test_df["no_weight"] = 1
test_df.to_csv(f"{model_id}_test_set.csv")

In [None]:
# One dictionary per evaluated Helium graph; converted into a DataFrame after the loop
df_data = []

# Uses the transformations, model and flags defined in the previous cells
for subset_idx, test_loader in enumerate(test_helium_loaders):
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(test_loader)):
            try:
                # Apply transformations
                data_sd = masking_transformation_sd.collate(data)
                if dual_input:
                    # dual input: the model receives [SD graph, SD+MD graph]
                    data_sdmd = random_masking.collate(data_sd)
                    inputs, targets = [data_sd, data_sdmd], data_sd.y
                else:
                    # single input
                    data_sdmd = masking_transformation_md.collate(data_sd)
                    inputs, targets = data_sdmd, data_sdmd.y

                # Sanity checks: every input and target must be finite (no NaN/inf)
                assert torch.all(torch.isfinite(data_sd.x)), "NaN detected in data_sd.x"
                assert torch.all(torch.isfinite(data_sdmd.x)), "NaN detected in data_sdmd.x"
                assert torch.all(torch.isfinite(data_sd.x_traces)), "NaN detected in data_sd.x_traces"
                assert torch.all(torch.isfinite(data_sdmd.x_traces)), "NaN detected in data_sdmd.x_traces"
                assert torch.all(torch.isfinite(targets)), "NaN detected in targets"

            except AssertionError as e:
                # Report the failed check and skip this batch.
                # subset_idx is kept distinct from the per-graph index below so
                # that the reported subset number stays correct.
                print(f"Skipping batch {batch_idx} in subset {subset_idx+1} due to assertion error: {e}")
                continue

            # The model returns three outputs, stored below as the
            # SD, SD+MD and MD predictions respectively
            output1, output2, output3 = trained_model(inputs)

            # Iterate over each graph in the batch
            for graph_idx, (graph_id, target, prediction1, prediction2, prediction3) in enumerate(zip(data.id, targets, output1, output2, output3)):
                # Nodes of the SD+MD batch that belong to this graph
                graph_sdmd_mask = graph_idx == data_sdmd.batch
                graph_data = {
                    "graph_id": graph_id,
                    "target": target.item(),
                    "prediction_sd": prediction1.item(),
                    "prediction_sdmd": prediction2.item(),
                    "prediction_md": prediction3.item(),
                    "station_list": data_sdmd.station_list[graph_idx],
                    "PMT_list": data_sdmd.x[graph_sdmd_mask,4].numpy(),
                    "MD_area": data_sdmd.x[graph_sdmd_mask,6].numpy()
                }
                df_data.append(graph_data)

gc.collect()
# Convert the list of dictionaries into a DataFrame
df = pd.DataFrame(df_data)

# Attach the background index information and add the weight columns
test_helium_df = pd.merge(index_bg, df, how='inner', left_on='filename', right_on='graph_id')
test_helium_df["spectrum_weight_hadron"] = test_helium_df["energyMC"]**-3
test_helium_df["no_weight"] = 1
test_helium_df.to_csv(f"{model_id}_test_set_helium.csv")

In [None]:
# One dictionary per evaluated Iron graph; converted into a DataFrame after the loop
df_data = []

# Uses the transformations, model and flags defined in the previous cells
for subset_idx, test_loader in enumerate(test_iron_loaders):
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(test_loader)):
            try:
                # Apply transformations
                data_sd = masking_transformation_sd.collate(data)
                if dual_input:
                    # dual input: the model receives [SD graph, SD+MD graph]
                    data_sdmd = random_masking.collate(data_sd)
                    inputs, targets = [data_sd, data_sdmd], data_sd.y
                else:
                    # single input
                    data_sdmd = masking_transformation_md.collate(data_sd)
                    inputs, targets = data_sdmd, data_sdmd.y

                # Sanity checks: every input and target must be finite (no NaN/inf)
                assert torch.all(torch.isfinite(data_sd.x)), "NaN detected in data_sd.x"
                assert torch.all(torch.isfinite(data_sdmd.x)), "NaN detected in data_sdmd.x"
                assert torch.all(torch.isfinite(data_sd.x_traces)), "NaN detected in data_sd.x_traces"
                assert torch.all(torch.isfinite(data_sdmd.x_traces)), "NaN detected in data_sdmd.x_traces"
                assert torch.all(torch.isfinite(targets)), "NaN detected in targets"

            except AssertionError as e:
                # Report the failed check and skip this batch.
                # subset_idx is kept distinct from the per-graph index below so
                # that the reported subset number stays correct.
                print(f"Skipping batch {batch_idx} in subset {subset_idx+1} due to assertion error: {e}")
                continue

            # The model returns three outputs, stored below as the
            # SD, SD+MD and MD predictions respectively
            output1, output2, output3 = trained_model(inputs)

            # Iterate over each graph in the batch
            for graph_idx, (graph_id, target, prediction1, prediction2, prediction3) in enumerate(zip(data.id, targets, output1, output2, output3)):
                # Nodes of the SD+MD batch that belong to this graph
                graph_sdmd_mask = graph_idx == data_sdmd.batch
                graph_data = {
                    "graph_id": graph_id,
                    "target": target.item(),
                    "prediction_sd": prediction1.item(),
                    "prediction_sdmd": prediction2.item(),
                    "prediction_md": prediction3.item(),
                    "station_list": data_sdmd.station_list[graph_idx],
                    "PMT_list": data_sdmd.x[graph_sdmd_mask,4].numpy(),
                    "MD_area": data_sdmd.x[graph_sdmd_mask,6].numpy()
                }
                df_data.append(graph_data)

gc.collect()
# Convert the list of dictionaries into a DataFrame
df = pd.DataFrame(df_data)

# Attach the background index information and add the weight columns
test_iron_df = pd.merge(index_bg, df, how='inner', left_on='filename', right_on='graph_id')
test_iron_df["spectrum_weight_hadron"] = test_iron_df["energyMC"]**-3
test_iron_df["no_weight"] = 1
test_iron_df.to_csv(f"{model_id}_test_set_iron.csv")

In [None]:
# Reload previously stored prediction tables (photon/proton, Helium, Iron).
# NOTE(review): these file names differ from the "{model_id}_test_set*.csv"
# files written by the inference cells, so the plots below do not use the
# predictions computed in this session — confirm this is intended.
test_df, test_helium_df, test_iron_df = (
    pd.read_csv(csv_name)
    for csv_name in (
        "last_dual_input_True_silent_False_loss_strategy_equal_new_60_test_set.csv",
        "last_dual_input_True_silent_False_loss_strategy_equal_new_60_test_set_helium.csv",
        "last_dual_input_True_silent_False_loss_strategy_equal_new_60_test_set_iron.csv",
    )
)

In [None]:
#myStyle = MyStyle('1fig-long', markers=None)
from matplotlib.gridspec import GridSpec

# 2x2 grid with equally sized rows and columns on an 18x6 inch figure
fig = plt.figure(figsize=(18,6))
gs = GridSpec(nrows=2, ncols=2, height_ratios=[1, 1], width_ratios=[1, 1])

# Split the test table by true label: photons (target 1) and protons (target 0)
ph_df = test_df[test_df["target"] == 1]
pr_df = test_df[test_df["target"] == 0]

# Plot settings read by the plotting cells below
n_bins = 30
fontsize = fontsize_label = 17
label_list = [r"$\gamma$", "p", "He", "Fe"]
color_list = ["blue", "#FF6666", "#CC3333",'#4B0000']

# Function to compute and plot histogram
def plot_histogram(ax, df_list, color_list, label_list, score_col, xlabel, weights=False):
    """Draw normalized score histograms with error bars for several samples.

    The first entry of `df_list` is drawn with circle markers and its weighted
    median is marked with a dashed vertical line; all other entries are drawn
    with triangle markers. Reads the notebook globals `n_bins` (number of bins)
    and `fontsize` (x-label font size).

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    df_list : list of pandas.DataFrame
        Samples to plot; colors and labels are matched by position.
    color_list, label_list : list
        Color and legend label of each sample.
    score_col : str
        Column holding the score to histogram.
    xlabel : str
        Label of the x axis.
    weights : bool, default False
        If True, the first sample is weighted with "spectrum_weight_gamma" and
        the others with "spectrum_weight_hadron"; if False every sample uses
        the "no_weight" column.
    """
    weighted_median_ph = None

    for idx, df in enumerate(df_list):
        is_first = idx == 0

        # The selected column is used for BOTH the median and the histogram,
        # so weights=False really produces unweighted histograms.
        if not weights:
            weight_col = "no_weight"
        elif is_first:
            weight_col = "spectrum_weight_gamma"
        else:
            weight_col = "spectrum_weight_hadron"

        bin_centers, norm_counts, errors = compute_weighted_histogram(
            df,
            target_col="target",
            score_col=score_col,
            weight_col=weight_col,
            bins=n_bins,
        )

        if is_first:
            weighted_median_ph = weighted_median(np.array(df[score_col]), df[weight_col])
            marker = 'o'
        else:
            marker = "v"
            if idx == 2:
                # The last five bins of the third sample are not drawn.
                # NOTE(review): presumably hides sparsely populated high-score
                # bins — confirm, and consider making this a parameter.
                bin_centers = bin_centers[:-5]
                norm_counts = norm_counts[:-5]
                errors = errors[:-5]

        ax.errorbar(bin_centers, norm_counts, yerr=errors, color=color_list[idx], label=label_list[idx], fmt=marker, linestyle='')

    # Mark the weighted median of the first sample (skipped if df_list is empty)
    if weighted_median_ph is not None:
        ax.axvline(x=weighted_median_ph, label=r"$\gamma$ median", color="black", linestyle="dashed")
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_yscale("log")
    ax.set_ylim(10**-5, 1)

# Samples drawn in every panel, in the order of label_list / color_list
hist_samples = [ph_df, pr_df, test_helium_df, test_iron_df]

# Plot SD (top-right panel)
ax1 = fig.add_subplot(gs[0, 1])
plot_histogram(ax1, hist_samples, score_col="prediction_sd", label_list=label_list, color_list = color_list, weights=True, xlabel="SD score")
ax1.set_xlim(-10, 12)

# Plot MD (bottom-right panel)
ax2 = fig.add_subplot(gs[1, 1])
plot_histogram(ax2, hist_samples, score_col="prediction_md", label_list=label_list, color_list = color_list, weights=True, xlabel="UMD score")
ax2.set_xlim(-10, 12)

# Plot SD+MD (left panel spanning both rows)
ax3 = fig.add_subplot(gs[:, 0])
plot_histogram(ax3, hist_samples, score_col="prediction_sdmd", label_list=label_list, color_list = color_list, weights=True, xlabel="SD-UMD score")
#ax3.set_ylabel(r"$\text{spectrum-weighted } \, 1/N \times \mathrm{d}N/\mathrm{d}\text{score}$", fontsize=fontsize)
ax3.set_ylabel(r"$E^{-\alpha}$ weighted norm. counts", fontsize=fontsize)
ax3.set_xlim(-10, 15)

# The legend goes on the SD-UMD panel (the current axes at this point)
ax3.legend(fontsize=fontsize_label, loc="lower right")

# Apply the layout BEFORE saving so that the stored image matches the one shown
plt.tight_layout()
fig.savefig('hist_separation.png', dpi=300)
plt.show()

In [None]:
x_feature = "showerSize"
# (a stray 'o' string literal used to be concatenated onto this label)
x_label = r"$\lg{ \left( S_{300} \right) }$"

#myStyle = MyStyle('1fig-long', markers=None)
fig = plt.figure(figsize=(18,6))
gs = GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])  # Define the layout

ph_df = test_df.loc[test_df["target"]==1, :]
pr_df = test_df.loc[test_df["target"]==0, :]
# NOTE(review): Helium events with an SD-UMD score of 5 or more are not drawn
# in any panel — confirm that this selection is intended.
he_df = test_helium_df[test_helium_df["prediction_sdmd"]<5]
fe_df = test_iron_df

# (sample, color, legend label, alpha) for every primary drawn in each panel
scatter_samples = [
    (ph_df, "blue", r'$\gamma$', 0.45),
    (pr_df, "#FF6666", 'p', 0.65),
    (he_df, "#CC3333", 'He', 0.65),
    (fe_df, '#4B0000', 'Fe', 0.65),
]

# Score vs. log10(x_feature): SD (top-left), UMD (bottom-left), SD-UMD (right)
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[:, 1])

for ax, score_col, y_label in [
    (ax1, "prediction_sd", "SD score"),
    (ax2, "prediction_md", "UMD score"),
    (ax3, "prediction_sdmd", "SD-UMD score"),
]:
    for sample_df, color, label, alpha in scatter_samples:
        ax.scatter(np.log10(sample_df[x_feature]), sample_df[score_col], alpha=alpha, color=color, label=label, s=4)
    ax.set_ylabel(y_label)
    ax.set_ylim(-9.0, 12)
    ax.set_xlim(0.0, 2.1)

# Only the bottom-left and right panels carry an x label; the legend goes on the right panel
ax2.set_xlabel(x_label)
ax3.set_xlabel(x_label)
ax3.legend(framealpha=1, ncol=4)

plt.tight_layout()
plt.show()

In [None]:
x_feature = "photon_energy"
x_label = r"lg($E_{\gamma}$ / eV)"
fontsize = 18

#myStyle = MyStyle('1fig-long', markers=None)
fig = plt.figure(figsize=(18, 6))
# 2x2 layout: SD score (top-left), UMD score (bottom-left) and the combined
# SD-UMD score spanning the whole right column.
gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])

ph_df = test_df.loc[test_df["target"] == 1, :]
pr_df = test_df.loc[test_df["target"] == 0, :]
# Helium events with an SD-UMD score of 5 or more are left out of these plots.
he_df = test_helium_df[test_helium_df["prediction_sdmd"] < 5]
fe_df = test_iron_df

# (sample, legend label, colour, alpha) in drawing order: later entries are drawn on top.
samples = [
    (ph_df, r'$\gamma$', "blue", 0.45),
    (pr_df, 'p', "#FF6666", 0.65),
    (he_df, 'He', "#CC3333", 0.65),
    (fe_df, 'Fe', '#4B0000', 0.65),
]

ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[:, 1])

# (axis, score column, y-axis label)
panels = [
    (ax1, "prediction_sd", "SD score"),
    (ax2, "prediction_md", "UMD score"),
    (ax3, "prediction_sdmd", "SD-UMD score"),
]

# Scatter every sample's score against lg(photon_energy) on each panel.
for ax, score_col, y_label in panels:
    for sample_df, label, color, alpha in samples:
        ax.scatter(np.log10(sample_df[x_feature]), sample_df[score_col],
                   alpha=alpha, color=color, label=label, s=4)
    ax.set_ylabel(y_label, fontsize=fontsize)
    ax.set_ylim(-9.0, 12)
    ax.set_xlim(16.5, 17.5)

# Only the bottom-left and the right panel carry an x label.
for ax in (ax2, ax3):
    ax.set_xlabel(x_label, fontsize=fontsize)
# The legend is deliberately kept at 17 pt, as in the previously saved figure.
ax3.legend(framealpha=1, ncol=4, fontsize=17)

plt.tight_layout()
fig.savefig('scores_vs_energy.png', dpi=300)
plt.show()

In [None]:
x_feature = "sin2zenith"
x_label = r"$\sin^2 \theta$"

#myStyle = MyStyle('1fig-long', markers=None)
fig = plt.figure(figsize=(18, 6))
# 2x2 layout: SD score (top-left), UMD score (bottom-left) and the combined
# SD-UMD score spanning the whole right column.
gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])

ph_df = test_df.loc[test_df["target"] == 1, :]
pr_df = test_df.loc[test_df["target"] == 0, :]
# Helium events with an SD-UMD score of 5 or more are left out of these plots.
he_df = test_helium_df[test_helium_df["prediction_sdmd"] < 5]
fe_df = test_iron_df

# (sample, legend label, colour, alpha) in drawing order: later entries are drawn on top.
samples = [
    (ph_df, r'$\gamma$', "blue", 0.45),
    (pr_df, 'p', "#FF6666", 0.65),
    (he_df, 'He', "#CC3333", 0.65),
    (fe_df, 'Fe', '#4B0000', 0.65),
]

ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[:, 1])

# (axis, score column, y-axis label)
panels = [
    (ax1, "prediction_sd", "SD score"),
    (ax2, "prediction_md", "UMD score"),
    (ax3, "prediction_sdmd", "SD-UMD score"),
]

# `fontsize` is not set in this cell; it is taken from the notebook namespace.
for ax, score_col, y_label in panels:
    for sample_df, label, color, alpha in samples:
        ax.scatter(sample_df[x_feature], sample_df[score_col],
                   alpha=alpha, color=color, label=label, s=4)
    ax.set_ylabel(y_label, fontsize=fontsize)
    ax.set_ylim(-9.0, 12)
    ax.set_xlim(0, 0.5)

# Only the bottom-left and the right panel carry an x label.
for ax in (ax2, ax3):
    ax.set_xlabel(x_label, fontsize=fontsize)
ax3.legend(framealpha=1, ncol=4, fontsize=fontsize)

plt.tight_layout()
fig.savefig('scores_vs_zenith.png', dpi=300)
plt.show()

In [None]:
x_feature = "Xmax"
x_label = r"$X_{max}$"

#myStyle = MyStyle('1fig-long', markers=None)
fig = plt.figure(figsize=(18, 6))
# 2x2 layout: SD score (top-left), UMD score (bottom-left) and the combined
# SD-UMD score spanning the whole right column.
gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])

ph_df = test_df.loc[test_df["target"] == 1, :]
pr_df = test_df.loc[test_df["target"] == 0, :]
# Helium events with an SD-UMD score of 5 or more are left out of these plots.
he_df = test_helium_df[test_helium_df["prediction_sdmd"] < 5]
fe_df = test_iron_df

# (sample, legend label, colour, alpha) in drawing order: later entries are drawn on top.
samples = [
    (ph_df, r'$\gamma$', "blue", 0.45),
    (pr_df, 'p', "#FF6666", 0.65),
    (he_df, 'He', "#CC3333", 0.65),
    (fe_df, 'Fe', '#4B0000', 0.65),
]

ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[:, 1])

# (axis, score column, y-axis label, upper y limit); the combined panel gets
# more headroom than the two single-detector panels.
panels = [
    (ax1, "prediction_sd", "SD score", 12),
    (ax2, "prediction_md", "UMD score", 12),
    (ax3, "prediction_sdmd", "SD-UMD score", 15),
]

# `fontsize` is not set in this cell; it is taken from the notebook namespace.
for ax, score_col, y_label, y_max in panels:
    for sample_df, label, color, alpha in samples:
        ax.scatter(sample_df[x_feature], sample_df[score_col],
                   alpha=alpha, color=color, label=label, s=4)
    ax.set_ylabel(y_label, fontsize=fontsize)
    ax.set_ylim(-9.0, y_max)
    ax.set_xlim(500, 1000)

# Only the bottom-left and the right panel carry an x label.
for ax in (ax2, ax3):
    ax.set_xlabel(x_label, fontsize=fontsize)
ax3.legend(framealpha=1, ncol=4, fontsize=fontsize)

plt.tight_layout()
fig.savefig('scores_vs_Xmax.png', dpi=300)
plt.show()

In [None]:
x_feature = "muonNumber"
x_label = r"lg($N_{\mu}$)"

#myStyle = MyStyle('1fig-long', markers=None)
fig = plt.figure(figsize=(18, 6))
# 2x2 layout: SD score (top-left), UMD score (bottom-left) and the combined
# SD-UMD score spanning the whole right column.
gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])

ph_df = test_df.loc[test_df["target"] == 1, :]
pr_df = test_df.loc[test_df["target"] == 0, :]
# Helium events with an SD-UMD score of 5 or more are left out of these plots.
he_df = test_helium_df[test_helium_df["prediction_sdmd"] < 5]
fe_df = test_iron_df

# (sample, legend label, colour, alpha) in drawing order: later entries are drawn on top.
samples = [
    (ph_df, r'$\gamma$', "blue", 0.45),
    (pr_df, 'p', "#FF6666", 0.65),
    (he_df, 'He', "#CC3333", 0.65),
    (fe_df, 'Fe', '#4B0000', 0.65),
]

ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[:, 1])

# (axis, score column, y-axis label, upper y limit); the combined panel gets
# more headroom than the two single-detector panels.
panels = [
    (ax1, "prediction_sd", "SD score", 12),
    (ax2, "prediction_md", "UMD score", 12),
    (ax3, "prediction_sdmd", "SD-UMD score", 15),
]

# `fontsize` is not set in this cell; it is taken from the notebook namespace.
# The x range is left to matplotlib's autoscaling in this figure.
for ax, score_col, y_label, y_max in panels:
    for sample_df, label, color, alpha in samples:
        ax.scatter(np.log10(sample_df[x_feature]), sample_df[score_col],
                   alpha=alpha, color=color, label=label, s=4)
    ax.set_ylabel(y_label, fontsize=fontsize)
    ax.set_ylim(-9.0, y_max)

# Only the bottom-left and the right panel carry an x label.
for ax in (ax2, ax3):
    ax.set_xlabel(x_label, fontsize=fontsize)
ax3.legend(framealpha=1, ncol=4, fontsize=fontsize)

plt.tight_layout()
fig.savefig('scores_vs_muon.png', dpi=300)
plt.show()

In [None]:
#myStyle = MyStyle('1fig-long', markers=None)
fig, axs = plt.subplots(1, 2, figsize=(18, 6))
ph_df = test_df.loc[test_df["target"] == 1, :]
pr_df = test_df.loc[test_df["target"] == 0, :]

xmin, xmax = -10, 25
ymin, ymax = -10, 15

# (sample, legend label, colour, alpha) in drawing order. The full helium and
# iron samples are shown here, without any cut on the combined score.
samples = [
    (ph_df, r'$\gamma$', "blue", 0.45),
    (pr_df, 'p', "r", 0.65),
    (test_helium_df, 'He', "gold", 0.65),
    (test_iron_df, 'Fe', "darkblue", 0.65),
]

# (axis, score column on the y axis, y-axis label); x is always the combined score.
panels = [
    (axs[0], "prediction_sd", "SD score"),
    (axs[1], "prediction_md", "MD score"),
]

for ax, score_col, y_label in panels:
    for sample_df, label, color, alpha in samples:
        ax.scatter(sample_df["prediction_sdmd"], sample_df[score_col],
                   alpha=alpha, color=color, label=label, s=4)
    ax.set_ylabel(y_label)
    ax.set_xlabel("SD-MD score")
    ax.set_xlim(xmin, xmax)

# Dotted guide lines at a score of 5 on both axes, drawn on the SD panel only.
axs[0].vlines(x=5, ymin=ymin, ymax=ymax, linestyle="dotted", color="black")
axs[0].hlines(y=5, xmin=xmin, xmax=xmax, linestyle="dotted", color="black")
axs[0].set_ylim(ymin, ymax)

plt.tight_layout()
plt.show()

In [None]:
myStyle = MyStyle('1fig-long', markers=None)
fig, axs = plt.subplots(1, 2)

ph_df = merged_test_df.loc[merged_test_df["target"] == 1, :]
pr_df = merged_test_df.loc[merged_test_df["target"] == 0, :]
he_df = merged_test_df_bg

# (sample, legend label, colour, alpha) in drawing order.
# NOTE(review): the background sample was labelled 'p' (same as the proton
# sample); 'He' is assumed from the variable name -- confirm.
samples = [
    (ph_df, r'$\gamma$', "blue", 0.45),
    (pr_df, 'p', "r", 0.65),
    (he_df, 'He', "gold", 0.65),
]

# (axis, score column on the x axis, x-axis label); y is always the M1 column.
panels = [
    (axs[0], "prediction_sdmd", "SD-MD score"),
    (axs[1], "prediction_md", "MD score"),
]

for ax, score_col, score_label in panels:
    for sample_df, label, color, alpha in samples:
        ax.scatter(sample_df[score_col], sample_df["M1"],
                   alpha=alpha, color=color, label=label, s=4)
    # Label each axis with the quantity that is actually plotted on it.
    ax.set_xlabel(score_label)
    ax.set_ylabel("M1")

plt.tight_layout()
plt.show()

In [None]:
myStyle = MyStyle('1fig-long', markers=None)
fig, axs = plt.subplots(1, 2)

ph_df = merged_test_df.loc[merged_test_df["target"] == 1, :]
pr_df = merged_test_df.loc[merged_test_df["target"] == 0, :]
he_df = merged_test_df_bg

# (sample, legend label, colour, alpha) in drawing order.
# NOTE(review): the background sample was labelled 'p' (same as the proton
# sample); 'He' is assumed from the variable name -- confirm.
samples = [
    (ph_df, r'$\gamma$', "blue", 0.45),
    (pr_df, 'p', "r", 0.65),
    (he_df, 'He', "gold", 0.65),
]

# (axis, column on the x axis, x-axis label)
panels = [
    (axs[0], "showerSize", r"$\lg{ \left( S_{300} \right) }$"),
    (axs[1], "sin2zenith", r"$\sin^2{\theta_{\text{MC}}}$"),
]

for ax, x_col, x_title in panels:
    for sample_df, label, color, alpha in samples:
        # NOTE(review): the plotted quantity is -M1 while the axis label reads
        # "M1" -- confirm which sign convention the figure should advertise.
        ax.scatter(sample_df[x_col], -1 * sample_df["M1"],
                   alpha=alpha, color=color, label=label, s=4)
    ax.set_xlabel(x_title)
    ax.set_ylabel("M1")

plt.tight_layout()
plt.show()

In [None]:
# Events with a combined SD-MD score above 5 although `isPhoton` is 0.
is_false_positive = (merged_test_df["isPhoton"] == 0) & (merged_test_df["prediction_sdmd"] > 5)
FP = merged_test_df.loc[is_false_positive, :]

# List the distinct showers involved, then summarise the selected rows.
print(FP.shower_id.unique())
FP.describe()

## Missing hottest Counter

In [None]:
from torch_geometric.data import Dataset, Data, Batch
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import subgraph, dropout_node

class MaskHottest(BaseTransform):
    """Graph transform that removes the first node (index 0) of a graph.

    The transform is meant to prune the hottest station of an event.
    NOTE(review): this relies on the hottest station being stored as node 0 of
    every graph -- confirm against the dataset construction.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, input_data):
        """
        Prune hottest station from graph.

        Args:
            input_data (torch_geometric.data.Data): A single graph carrying
                `x`, `x_traces`, `pos`, `edge_index`, `station_list` and
                `distance2hottest_list`. It is cloned, not modified.

        Returns:
            torch_geometric.data.Data: The data object without the hottest station.
        """
        num_nodes = input_data.num_nodes
        data = input_data.clone()

        # Keep every node except node 0; build the index on the graph's device
        # so that it can be used together with `edge_index`.
        nodes2keep = torch.arange(1, num_nodes, device=data.edge_index.device)
        keep_ids = nodes2keep.tolist()

        # Update the feature matrix, x_traces, station_list, distance2hottest_list
        # and pos to keep only the selected nodes
        data.x = data.x[nodes2keep, :]
        data.x_traces = data.x_traces[nodes2keep, :]
        #data.total_signal = data.total_signal[nodes2keep, :]
        data.station_list = [data.station_list[node] for node in keep_ids]
        data.distance2hottest_list = [data.distance2hottest_list[node] for node in keep_ids]
        data.pos = data.pos[nodes2keep]

        # Drop the edges touching node 0 and renumber the remaining nodes from 0
        new_edge_index, _ = subgraph(nodes2keep, edge_index=data.edge_index.long(), relabel_nodes=True, num_nodes=num_nodes)
        data.edge_index = new_edge_index

        # Update the number of nodes in the data object
        data.num_nodes = data.x.size(0)

        return data

    def collate(self, batch):
        """
        Apply the transform to every graph of a batch.

        The batch is split into one `Data` object per graph (using `batch.batch`
        and `batch.ptr`), each graph is pruned with `__call__`, and the pruned
        graphs are gathered into a new batch.

        Args:
            batch (torch_geometric.data.Batch): Batched graphs carrying `x`,
                `y`, `x_traces`, `pos`, `edge_index`, `station_list` and
                `distance2hottest_list`.

        Returns:
            torch_geometric.data.Batch: A new batch built from the pruned
            graphs. Only the attributes listed above are carried over.
        """
        total_nodes = len(batch.batch)
        # Loop-invariant: convert the batched edge index only once.
        batch_edge_index = batch.edge_index.long()

        new_graphs = []
        for graph_batch_id in batch.batch.unique().tolist():

            # get features
            graph_mask = batch.batch == graph_batch_id
            _graph_x = batch.x[graph_mask]
            _graph_y = batch.y[graph_batch_id]
            _graph_x_traces = batch.x_traces[graph_mask]
            #_total_signal =  batch.total_signal[graph_mask]
            _station_list = batch.station_list[graph_batch_id]
            _distance2hottest_list = batch.distance2hottest_list[graph_batch_id]
            _pos = batch.pos[graph_mask]

            # get graph edges from batch: `ptr` holds the first node index of each
            # graph, so this graph owns nodes [ptr[i], ptr[i + 1]). torch.arange
            # yields an integer index tensor even when that range is empty.
            first_node = int(batch.ptr[graph_batch_id])
            last_node = int(batch.ptr[graph_batch_id + 1])
            nodes2keep = torch.arange(first_node, last_node, device=batch_edge_index.device)
            _graph_x_edge_index, _ = subgraph(nodes2keep, edge_index=batch_edge_index, relabel_nodes=True, num_nodes=total_nodes)

            # the new graph or "subgraph"
            new_graph = Data(x=_graph_x, edge_index=_graph_x_edge_index, pos=_pos)
            new_graph.x_traces = _graph_x_traces
            #new_graph.total_signal = _total_signal
            new_graph.station_list = _station_list
            new_graph.distance2hottest_list = _distance2hottest_list
            new_graph.y = _graph_y

            new_graphs.append(new_graph)

        # Now, use the Batch class to collect the transformed graphs into a batch
        return Batch.from_data_list([self(graph) for graph in new_graphs])

masking_transformation = MaskHottest()

In [None]:
# Initialize an empty list to store dictionaries for DataFrame rows
df_data = []

# transformations
masking_transformation = MaskHottest()

with torch.no_grad():
    for test_loader in test_loaders:
        for data in tqdm(test_loader):
            # The SD input keeps the graph produced by `masking_transformation_sd`;
            # the SD-MD input additionally goes through MaskHottest.
            sd_graph = masking_transformation_sd.collate(data)
            sdmd_graph = masking_transformation.collate(sd_graph)
            # Keep only the first `sd_node_features` node features for the SD input
            sd_graph.x = sd_graph.x[:, :sd_node_features]
            inputs, targets = [sd_graph, sdmd_graph], sd_graph.y
            # saving preds and targets
            output1, output2, output3 = trained_model(inputs)

            # Iterate over each graph in the batch. A dedicated index name is used
            # so that no enclosing loop variable is overwritten.
            per_graph = zip(data.id, targets, output1, output2, output3)
            for graph_idx, (graph_id, target, prediction1, prediction2, prediction3) in enumerate(per_graph):
                # Node masks selecting this graph inside each batched input
                graph_mask_sd = sd_graph.batch == graph_idx
                graph_mask_sdmd = sdmd_graph.batch == graph_idx
                # Create a dictionary for each graph's data
                graph_data = {
                    "graph_id": graph_id,
                    "target": target.item(),
                    "prediction_sd": prediction1.item(),
                    "prediction_sdmd": prediction2.item(),
                    "prediction_md": prediction3.item(),
                    "original_station_list": data.station_list[graph_idx],
                    "sd_station_list": sd_graph.station_list[graph_idx],
                    "sdmd_station_list": sdmd_graph.station_list[graph_idx],
                    "PMT_list_sd": sd_graph.x[graph_mask_sd, 4].numpy(),
                    "PMT_list_sdmd": sdmd_graph.x[graph_mask_sdmd, 4].numpy(),
                    "MD_area_list": sdmd_graph.x[graph_mask_sdmd, 5].numpy(),
                    "sdmd_distance2hottest": sdmd_graph.distance2hottest_list[graph_idx],
                }
                # Append the dictionary to the list
                df_data.append(graph_data)

gc.collect()
# Convert the list of dictionaries into a DataFrame (built once)
df = pd.DataFrame(df_data)
# Number of stations in the original, SD and SD-MD graphs
df["orig_stations"] = df["original_station_list"].apply(len)
df["sd_stations"] = df["sd_station_list"].apply(len)
df["sdmd_stations"] = df["sdmd_station_list"].apply(len)

# Merge the event index with the predictions ("filename" <-> "graph_id")
merged_test_df_wohottest = pd.merge(index, df, how='inner', left_on='filename', right_on='graph_id')
merged_test_df_wohottest["spectrum_weight_gamma"] = merged_test_df_wohottest["energyMC"]**-2
merged_test_df_wohottest["spectrum_weight_hadron"] = merged_test_df_wohottest["energyMC"]**-3
merged_test_df_wohottest["no_weight"] = 1
merged_test_df_wohottest['n_1stcrown_stations'] = merged_test_df_wohottest['sdmd_distance2hottest'].apply(count_below_cutoff, cutoff=500)

In [None]:
# Initialize an empty list to store dictionaries for DataFrame rows
df_data_bg = []

# transformations
masking_transformation = MaskHottest()

with torch.no_grad():
    for test_loader in test_bg_loaders:
        for data in tqdm(test_loader):
            # The SD input keeps the graph produced by `masking_transformation_sd`;
            # the SD-MD input additionally goes through MaskHottest.
            sd_graph = masking_transformation_sd.collate(data)
            sdmd_graph = masking_transformation.collate(sd_graph)
            # Keep only the first `sd_node_features` node features for the SD input
            sd_graph.x = sd_graph.x[:, :sd_node_features]
            inputs, targets = [sd_graph, sdmd_graph], sd_graph.y
            # saving preds and targets
            output1, output2, output3 = trained_model(inputs)

            # Iterate over each graph in the batch. A dedicated index name is used
            # so that no enclosing loop variable is overwritten.
            per_graph = zip(data.id, targets, output1, output2, output3)
            for graph_idx, (graph_id, target, prediction1, prediction2, prediction3) in enumerate(per_graph):
                # Node masks selecting this graph inside each batched input
                graph_mask_sd = sd_graph.batch == graph_idx
                graph_mask_sdmd = sdmd_graph.batch == graph_idx
                # Create a dictionary for each graph's data
                graph_data = {
                    "graph_id": graph_id,
                    "target": target.item(),
                    "prediction_sd": prediction1.item(),
                    "prediction_sdmd": prediction2.item(),
                    "prediction_md": prediction3.item(),
                    "original_station_list": data.station_list[graph_idx],
                    "sd_station_list": sd_graph.station_list[graph_idx],
                    "sdmd_station_list": sdmd_graph.station_list[graph_idx],
                    "PMT_list_sd": sd_graph.x[graph_mask_sd, 4].numpy(),
                    "PMT_list_sdmd": sdmd_graph.x[graph_mask_sdmd, 4].numpy(),
                    "MD_area_list": sdmd_graph.x[graph_mask_sdmd, 5].numpy(),
                    "sdmd_distance2hottest": sdmd_graph.distance2hottest_list[graph_idx],
                }
                # Append the dictionary to the list
                df_data_bg.append(graph_data)

gc.collect()
# Convert the list of dictionaries into a DataFrame (built once)
df_bg = pd.DataFrame(df_data_bg)
# Number of stations in the original, SD and SD-MD graphs
df_bg["orig_stations"] = df_bg["original_station_list"].apply(len)
df_bg["sd_stations"] = df_bg["sd_station_list"].apply(len)
df_bg["sdmd_stations"] = df_bg["sdmd_station_list"].apply(len)

# Merge the background event index with the predictions ("filename" <-> "graph_id")
merged_test_df_bg_wohottest = pd.merge(index_bg, df_bg, how='inner', left_on='filename', right_on='graph_id')
merged_test_df_bg_wohottest["spectrum_weight_gamma"] = merged_test_df_bg_wohottest["energyMC"]**-2
merged_test_df_bg_wohottest["spectrum_weight_hadron"] = merged_test_df_bg_wohottest["energyMC"]**-3
merged_test_df_bg_wohottest["no_weight"] = 1
merged_test_df_bg_wohottest['n_1stcrown_stations'] = merged_test_df_bg_wohottest['sdmd_distance2hottest'].apply(count_below_cutoff, cutoff=500)

In [None]:
myStyle = MyStyle('1fig-long', markers=None)
fig, axs = plt.subplots(1, 2)

# One line style / legend suffix per dataset, in the order they are looped over
linestyles = ["solid", "dotted"]
labels = ["w. hottest counter", "w.o. hottest counter"]

# Plot settings, set once instead of on every loop iteration.
n_bins = 30
# NOTE(review): `fontsize` is not used in this cell, but the assignment is kept
# because it overwrites the notebook-wide `fontsize` that other cells read.
fontsize = 20
fontsize_label = 18

# Only the photon and proton distributions are drawn; the background samples
# (merged_test_df_bg, merged_test_df_bg_wohottest) are not part of this figure.
for idx, dataset in enumerate([merged_test_df, merged_test_df_wohottest]):

    # Split the sample by target value
    ph_df = dataset.loc[dataset["target"] == 1, :]
    pr_df = dataset.loc[dataset["target"] == 0, :]
    linestyle = linestyles[idx]

    # Left panel: SD-MD score, right panel: MD score
    for idy, score_col in enumerate(["prediction_sdmd", "prediction_md"]):

        # Photons use the "spectrum_weight_gamma" column as weights,
        # protons the "spectrum_weight_hadron" column.
        ph_bin_centers, ph_norm_counts, ph_errors = compute_weighted_histogram(ph_df, target_col="target",
                                                                               score_col=score_col,
                                                                               weight_col="spectrum_weight_gamma",
                                                                               bins=n_bins)
        pr_bin_centers, pr_norm_counts, pr_errors = compute_weighted_histogram(pr_df, target_col="target",
                                                                               score_col=score_col,
                                                                               weight_col="spectrum_weight_hadron",
                                                                               bins=n_bins)

        axs[idy].step(pr_bin_centers, pr_norm_counts, where='mid', color="red", label=f'p - {labels[idx]}', linestyle=linestyle)
        axs[idy].step(ph_bin_centers, ph_norm_counts, where='mid', color="blue", label=rf'$\gamma$ - {labels[idx]}', linestyle=linestyle)


axs[0].set_xlabel("SD-MD score"); axs[1].set_xlabel("MD score")

for ax in axs:
    ax.set_yscale("log")
    ax.set_ylim(10**-5, 1)
    ax.legend(fontsize=fontsize_label, ncol=2, framealpha=1)

plt.tight_layout()
plt.show()