In [5]:
"""
This code is crucial for calculating the class weights (f1, f2, f3) 
that  will  be  used  in the loss function to account for the class 
imbalance  in the training data. Since my training dataset contains 
varying  numbers  of photons (pid = 22), neutrons (pid = 2112), and 
other  particle  types,  applying inverse frequency weights ensures 
that the loss function does not overfit to the more frequent classes.
"""

import sys
sys.path.append("../../src")
import pandas as pd
import glob
import numpy as np
from tqdm import tqdm
import h5py

# Set the project directory and the number of files to read
project_dir = "../../projects/supercell.10.08.2024.18.28/"
Nfiles = 100

# Collect all relevant files in the destination directory
files = glob.glob(f"{project_dir}/training/*.h5")

# Limit the number of files to read based on user input or available files
Nfiles = np.amin([Nfiles, len(files)])
files = files[:Nfiles]

# Initialize counts for each particle type
N1, N2, N3 = 0, 0, 0  # N1: photons (pid = 22), N2: neutrons (pid = 2112), N3: others

# Loop over each file
for i, file in enumerate(tqdm(files, desc="Processing files")):
    # Open the h5 file and get 'X', 'y' and 'misc' from it
    with h5py.File(file, 'r') as f:
        y = f['y']
        strip_pid = np.array(y[:,:,1])
        strip_pid = strip_pid[strip_pid!=0]
        strip_pid = strip_pid[strip_pid!=-1]
        N1 = np.sum(strip_pid==22)
        N2 = np.sum(strip_pid==2112)
        N3 = np.sum((strip_pid!=22)&(strip_pid!=2112))

# Total number of events processed
N_total = N1 + N2 + N3

# Compute inverse frequency weights for each class
f1 = N_total / N1 if N1 > 0 else 1.0  # Handle potential division by zero
f2 = N_total / N2 if N2 > 0 else 1.0
f3 = N_total / N3 if N3 > 0 else 1.0

# Display class counts and the computed inverse frequency weights
print(f"Total photon count (pid=22): {N1}")
print(f"Total neutron count (pid=2112): {N2}")
print(f"Total other particles count: {N3}")
print(f"Inverse frequency weights: f1={f1:.2f}, f2={f2:.2f}, f3={f3:.2f}")


Processing files: 100%|██████████████████████| 100/100 [00:06<00:00, 15.29it/s]

Total photon count (pid=22): 2609
Total neutron count (pid=2112): 612
Total other particles count: 6491
Inverse frequency weights: f1=3.72, f2=15.87, f3=1.50



