In [45]:
import numpy as np
import h5py
import glob
from pathlib import Path
import json
import matplotlib.pyplot as plt
import seaborn as sns

# Deep learning imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Scikit-learn for preprocessing
from sklearn.preprocessing import StandardScaler, RobustScaler, PowerTransformer
from sklearn.preprocessing import MinMaxScaler, QuantileTransformer
from sklearn.model_selection import train_test_split

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

In [70]:
def count_objects_vectorized(e_feats, p_feats):
    # Count electrons with energy > 0
    electron_count = np.sum(e_feats[:, :, 0] > 0, axis=1)
    
    # Count photons with energy > 0
    photon_count = np.sum(p_feats[:, :, 0] > 0, axis=1)
    
    # Total count combines electron and photon counts
    total_count = electron_count + photon_count
    
    # Create mask for events with at least 2 objects
    valid_events = total_count >= 2
    
    return valid_events

In [71]:
class SignalDataReader:
    """Handles reading and preprocessing of signal physics event data."""
    
    def __init__(self, data_dir, scaler_path):
        self.data_dir = Path(data_dir)
        self.scaler_path = Path(scaler_path)

        # Load feature lists and scaler parameters
        with open(self.scaler_path, 'r') as f:
            self.scaler_params = json.load(f)

        # Define feature lists based on README
        self.electron_features_list = [
            'electron_E', 'electron_pt', 'electron_eta', 'electron_phi',
            'electron_time',
            'electron_d0', 'electron_z0', 'electron_dpt',
            'electron_nPIX', 'electron_nMissingLayers',
            'electron_chi2', 'electron_numberDoF',  # Will need to handle ratio
            'electron_f1', 'electron_f3', 'electron_z'
        ]
        
        self.photon_features_list = [
            'photon_E', 'photon_pt', 'photon_eta', 'photon_phi',
            'photon_time',
            'photon_maxEcell_E',
            'photon_f1', 'photon_f3', 'photon_r1', 'photon_r2',
            'photon_etas1', 'photon_phis1', 'photon_z'
        ]
        
        # Initialize and load scalers
        self._initialize_and_load_scalers()

        # Load and preprocess all data
        self.load_all_data()

    def _initialize_and_load_scalers(self):
        """Initialize specialized scalers for each feature group based on saved parameters."""
        print(f"Loading specialized scalers from {self.scaler_path}...")
        
        # Initialize electron scalers and feature groups
        self.electron_scalers = {}
        self.electron_feature_groups = {}
        
        for group_name, params in self.scaler_params['electron'].items():
            # Get feature indices
            self.electron_feature_groups[group_name] = params.get('feature_indices', [])
            
            # Create the appropriate scaler type based on the saved parameters
            scaler_type = params.get('type', 'StandardScaler')
            
            # Initialize the scaler based on type and load its parameters
            self._initialize_scaler(self.electron_scalers, group_name, scaler_type, params)
        
        # Initialize photon scalers and feature groups
        self.photon_scalers = {}
        self.photon_feature_groups = {}
        
        for group_name, params in self.scaler_params['photon'].items():
            # Get feature indices
            self.photon_feature_groups[group_name] = params.get('feature_indices', [])
            
            # Create the appropriate scaler type based on the saved parameters
            scaler_type = params.get('type', 'StandardScaler')
            
            # Initialize the scaler based on type and load its parameters
            self._initialize_scaler(self.photon_scalers, group_name, scaler_type, params)
        
        # Initialize vertex scaler (simple case)
        vertex_params = self.scaler_params['vertex']
        self.vertex_scaler = StandardScaler()
        if 'mean' in vertex_params and 'scale' in vertex_params:
            self.vertex_scaler.mean_ = np.array(vertex_params['mean'])
            self.vertex_scaler.scale_ = np.array(vertex_params['scale'])
            self.vertex_scaler.var_ = np.square(self.vertex_scaler.scale_)
        
        print("Scalers loaded successfully.")

    def _initialize_scaler(self, scalers_dict, group_name, scaler_type, params):
        """Initialize a specific scaler based on its type and parameters."""
        if scaler_type == 'StandardScaler':
            scaler = StandardScaler()
            if 'mean' in params and 'scale' in params:
                scaler.mean_ = np.array(params['mean'])
                scaler.scale_ = np.array(params['scale'])
                scaler.var_ = np.square(scaler.scale_)
        
        elif scaler_type == 'RobustScaler':
            scaler = RobustScaler()
            if 'center' in params and 'scale' in params:
                scaler.center_ = np.array(params['center'])
                scaler.scale_ = np.array(params['scale'])
        
        elif scaler_type == 'MinMaxScaler':
            scaler = MinMaxScaler(feature_range=(-1, 1))
            scaler.min_ = np.array(params["min_"])
            scaler.scale_ = np.array(params["scale_"])
        
        elif scaler_type == 'PowerTransformer':
            method = params.get('method', 'yeo-johnson')
            standardize = params.get('standardize', True)
            scaler = PowerTransformer(method=method, standardize=standardize)
            scaler.lambdas_ = np.array(params['lambdas'])

            # Restore internal StandardScaler for standardization
            if standardize:
                scaler._scaler = StandardScaler()
                scaler._scaler.mean_ = np.array(params['mean'])
                scaler._scaler.scale_ = np.array(params['scale'])
            
            # Restore internal scaler manually if standardization was used
            if standardize and 'mean' in params and 'scale' in params:
                scaler._scaler = StandardScaler()
                scaler._scaler.mean_ = np.array(params['mean'])
                scaler._scaler.scale_ = np.array(params['scale'])
        
        else:
            # Default to StandardScaler if unknown type
            print(f"Warning: Unknown scaler type '{scaler_type}' for {group_name}, using StandardScaler")
        
        # Store the initialized scaler
        scalers_dict[group_name] = scaler

    def load_all_data(self):
        """Load and preprocess signal data from HDF5 files."""
        print("Loading all signal data...")
        
        # Initialize as None for first file
        self.electron_features = None
        self.photon_features = None
        self.vertex_features = None
        
        file_count = 0
        
        for file_path in self.data_dir.glob("*.h5"):
            with h5py.File(file_path, 'r', rdcc_nbytes=10*1024*1024) as f:
                n_events = len(f['events/PV_x'])
                print(f"Processing {file_path.name}: {n_events} events")
                
                # Load all data at once
                electrons = {feat: f[f'events/electrons/{feat}'][:] for feat in self.electron_features_list}
                photons = {feat: f[f'events/photons/{feat}'][:] for feat in self.photon_features_list}
                vertices = np.stack([
                    f['events/PV_x'][:],
                    f['events/PV_y'][:],
                    f['events/PV_z'][:]
                ], axis=1)
                
                # Process all events at once - for signal, only filter based on energy
                e_mask = (electrons['electron_E'] > 0)
                
                # Initialize arrays for all events
                e_feats = np.zeros((n_events, 4, len(self.electron_features_list)))
                p_feats = np.zeros((n_events, 4, len(self.photon_features_list)))
                
                # Process all events at once
                for feat_idx, feat in enumerate(self.electron_features_list):
                    e_feats[..., feat_idx] = electrons[feat]
                    e_feats[..., feat_idx] = np.where(e_mask, e_feats[..., feat_idx], 0)  # Zero out electrons failing selection
                
                for feat_idx, feat in enumerate(self.photon_features_list):
                    p_feats[..., feat_idx] = photons[feat]

                # Apply event filtering: Require at least two objects
                electron_count = np.sum(e_feats[:, :, 0] > 0, axis=1)
                photon_count = np.sum(p_feats[:, :, 0] > 0, axis=1)
                total_count = electron_count + photon_count

                # Create mask for events with at least 2 objects
                valid_events = total_count >= 2
                
                # Apply the filter
                e_feats = e_feats[valid_events]
                p_feats = p_feats[valid_events]
                vertices = vertices[valid_events]

                # Add to main arrays
                if self.electron_features is None:
                    self.electron_features = e_feats
                    self.photon_features = p_feats
                    self.vertex_features = vertices
                else:
                    self.electron_features = np.concatenate([self.electron_features, e_feats])
                    self.photon_features = np.concatenate([self.photon_features, p_feats])
                    self.vertex_features = np.concatenate([self.vertex_features, vertices])
                
                file_count += 1
                print(f"Processed {file_count} files, total events: {len(self.electron_features):,}")
        
        print(f"\nFinal dataset:")
        print(f"Total files processed: {file_count}")
        print(f"Total events: {len(self.electron_features):,}")
        print(f"Shapes: electrons {self.electron_features.shape}, photons {self.photon_features.shape}, vertices {self.vertex_features.shape}")
        
        # Apply saved scalers
        print("\nApplying saved scalers to signal data...")
        self._transform_features()
        
        print(f"Final dataset: {len(self.electron_features):,} events")
        print(f"Shapes: electrons {self.electron_features.shape}, photons {self.photon_features.shape}, vertices {self.vertex_features.shape}")    
    
    def _transform_features(self):
        """Transform features using loaded scalers."""
        # Create working copies
        e_feats_transformed = self.electron_features.copy()
        p_feats_transformed = self.photon_features.copy()
        
        # Process electron features by group
        for group_name, feature_indices in self.electron_feature_groups.items():
            # Get all feature data for this group at once
            group_values = np.column_stack([
                self.electron_features[:, :, idx].reshape(-1, 1) 
                for idx in feature_indices
            ])
            
            # Transform all features in the group together
            transformed_values = self.electron_scalers[group_name].transform(group_values)
            
            # Split back into individual features and update
            for i, feat_idx in enumerate(feature_indices):
                feat_transformed = transformed_values[:, i].reshape(
                    self.electron_features.shape[0], self.electron_features.shape[1]
                )
                e_feats_transformed[:, :, feat_idx] = feat_transformed
        
        # Process photon features by group
        for group_name, feature_indices in self.photon_feature_groups.items():
            # Get all feature data for this group at once
            group_values = np.column_stack([
                self.photon_features[:, :, idx].reshape(-1, 1) 
                for idx in feature_indices
            ])
            
            # Transform all features in the group together
            transformed_values = self.photon_scalers[group_name].transform(group_values)
            
            # Split back into individual features and update
            for i, feat_idx in enumerate(feature_indices):
                feat_transformed = transformed_values[:, i].reshape(
                    self.photon_features.shape[0], self.photon_features.shape[1]
                )
                p_feats_transformed[:, :, feat_idx] = feat_transformed
        
        # For vertices, simple standard scaling
        self.vertex_features = self.vertex_scaler.transform(self.vertex_features)
        
        # Update features with transformed versions
        self.electron_features = e_feats_transformed
        self.photon_features = p_feats_transformed 
    
    def transform_new_data(self, electron_features, photon_features, vertex_features):
        """Transform new data using loaded scalers."""
        # Create working copies
        e_feats_transformed = electron_features.copy()
        p_feats_transformed = photon_features.copy()
        
        # Process electron features by group
        for group_name, feature_indices in self.electron_feature_groups.items():
            # Get all feature data for this group at once
            group_values = np.column_stack([
                electron_features[:, :, idx].reshape(-1, 1) 
                for idx in feature_indices
            ])
            
            # Transform all features in the group together
            transformed_values = self.electron_scalers[group_name].transform(group_values)
            
            # Split back into individual features and update
            for i, feat_idx in enumerate(feature_indices):
                feat_transformed = transformed_values[:, i].reshape(
                    electron_features.shape[0], electron_features.shape[1]
                )
                e_feats_transformed[:, :, feat_idx] = feat_transformed
        
        # Process photon features by group
        for group_name, feature_indices in self.photon_feature_groups.items():
            # Get all feature data for this group at once
            group_values = np.column_stack([
                photon_features[:, :, idx].reshape(-1, 1) 
                for idx in feature_indices
            ])
            
            # Transform all features in the group together
            transformed_values = self.photon_scalers[group_name].transform(group_values)
            
            # Split back into individual features and update
            for i, feat_idx in enumerate(feature_indices):
                feat_transformed = transformed_values[:, i].reshape(
                    photon_features.shape[0], photon_features.shape[1]
                )
                p_feats_transformed[:, :, feat_idx] = feat_transformed
        
        # For vertices, simple standard scaling
        v_feats_transformed = self.vertex_scaler.transform(vertex_features)
        
        return e_feats_transformed, p_feats_transformed, v_feats_transformed

    def get_all(self):
        """Return all processed data as a tuple."""
        all_data = (
            self.electron_features,
            self.photon_features,
            self.vertex_features
        )
        return all_data

In [72]:
# Test data loading and preprocessing
data_dir = "/fs/ddn/sdf/group/atlas/d/hjia625/VLL-DP/VLL_classifier/hdf5_signal_output"
scalar_path = "/fs/ddn/sdf/group/atlas/d/hjia625/VLL-DP/VLL_classifier/src/output/scaler_params.json"
reader = SignalDataReader(data_dir, scalar_path)

Loading specialized scalers from /fs/ddn/sdf/group/atlas/d/hjia625/VLL-DP/VLL_classifier/src/output/scaler_params.json...
Scalers loaded successfully.
Loading all signal data...
Processing signal_543784.e8564_e8528_s4277_s4114_r15530_r15514_p6069.43297402._000001.trees.h5: 50000 events
Processed 1 files, total events: 38,720
Processing signal_543785.e8564_e8528_s4277_s4114_r15530_r15514_p6069.43297406._000001.trees.h5: 50000 events
Processed 2 files, total events: 46,987
Processing signal_543828.e8564_e8528_s4277_s4114_r15530_r15514_p6069.43297651._000001.trees.h5: 50000 events
Processed 3 files, total events: 59,845
Processing signal_543822.e8564_e8528_s4277_s4114_r15530_r15514_p6069.43297639._000001.trees.h5: 50000 events
Processed 4 files, total events: 61,561
Processing signal_543832.e8564_e8528_s4237_s4114_r15540_r15516_p6069.43297658._000001.trees.h5: 50000 events
Processed 5 files, total events: 110,277
Processing signal_543790.e8564_e8528_s4277_s4114_r15530_r15514_p6069.4329743

In [73]:
# Get train/val/test splits
all_signal = reader.get_all()

In [74]:
electron_features, photon_features, vertices = all_signal
np.array(electron_features).shape

(2080280, 4, 15)

In [75]:
np.save('signal_data_e.npy', electron_features)
np.save('signal_data_p.npy', photon_features)
np.save('signal_data_v.npy', vertices)

In [32]:
print(electron_features[3])

[[ 1.54841698e+00  1.54757093e+00 -9.64014450e-01 -1.32190589e+00
  -1.40644550e-01  5.41521502e+00 -2.67910218e+00 -2.40094867e-02
   1.00000000e+00  0.00000000e+00  2.19650166e-01  1.59134593e+00
   5.54905102e+01  2.28443714e-01  1.97551516e-05]
 [ 1.60289098e+00  1.51562423e+00 -5.71524059e-01 -1.22031674e+00
  -3.16543430e-01 -6.61583138e+00 -6.54915314e+01 -5.82080841e-01
   1.00000000e+00  0.00000000e+00  2.19650166e-01  1.60533610e+00
   4.97479642e+01 -8.80099228e-02 -2.45832784e-03]
 [ 1.56390265e+00  1.47006406e+00 -1.39032884e+00 -5.30572935e-03
  -3.53752375e-02 -2.10090446e+01 -6.62908249e+01 -1.31678909e-01
   0.00000000e+00  0.00000000e+00  6.92112628e-01  1.59134593e+00
   4.39527342e+01  7.63784909e-01 -4.06860080e-03]
 [ 1.49462312e+00  1.46598289e+00 -1.18399247e+00 -6.81469054e-01
   4.22006011e-01  1.84539413e+01 -7.22690353e+01 -1.24363974e-01
   1.00000000e+00  0.00000000e+00  6.92112628e-01  1.58928858e+00
   1.75476819e+01  2.94126450e+00 -6.55712443e-03]]


In [67]:
# Electrons - shape (1, 4, 15)
e_feats = np.array([
    [  # Event 0
        # Electron 0 (15 features)
        [65.2, 55.3, 1.2, 0.5, 1.1, 0.05, 0.15, 0.08, 6, 1, 1.5, 7, 0.45, 0.15, 35.0],
        # Electron 1 
        [48.7, 42.1, -0.8, 2.1, 0.9, -0.03, 0.08, 0.07, 5, 0, 0.9, 6, 0.52, 0.18, -42.0], 
        # Electron 2
        [37.1, 35.6, 0.5, -1.8, 1.3, 0.02, 0.11, 0.09, 7, 1, 1.2, 8, 0.38, 0.22, 15.0],
        # Electron 3 (padding - all zeros)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0.0, 0, 0.0, 0.0, 0.0]
    ],
    [  # Event 0
        # Electron 0 (15 features)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0.0, 0, 0.0, 0.0, 0.0],
        # Electron 1 
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0.0, 0, 0.0, 0.0, 0.0], 
        # Electron 2
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0.0, 0, 0.0, 0.0, 0.0],
        # Electron 3 (padding - all zeros)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 0.0, 0, 0.0, 0.0, 0.0]
    ]
])

# Photons - shape (1, 4, 13)
p_feats = np.array([
    [  # Event 0
        # Photon 0 (13 features)
        [72.5, 68.3, 0.7, 1.2, 0.8, 12.5, 0.55, 0.12, 0.83, 0.91, 0.025, 0.015, 24.0],
        # Photon 1
        [63.1, 61.2, -1.1, -0.5, 0.9, 10.2, 0.48, 0.15, 0.78, 0.87, 0.031, 0.018, -18.0],
        # Photon 2 (padding - all zeros)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        # Photon 3 (padding - all zeros)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    ],
    [  # Event 0
        # Photon 0 (13 features)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        # Photon 1
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        # Photon 2 (padding - all zeros)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        # Photon 3 (padding - all zeros)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    ]
])

# Vertex - shape (1, 3)
v_feats = np.array([
    [0.05, -0.03, 12.4], [0.0, 0.0, 0.0]  # x, y, z coordinates
])

In [68]:
e_feats_transformed, p_feats_transformed, v_feats_transformed = reader.transform_new_data(e_feats, p_feats, v_feats)

In [69]:
print(e_feats_transformed, p_feats_transformed, v_feats_transformed)

[[[ 1.59097474e+00  1.62705790e+00  4.78493591e-01  1.59154887e-01
    1.10000000e+00  5.00000000e-02  1.50000000e-01  8.00000000e-02
    2.77477172e+00  2.26772628e+00  1.66822786e+00  1.24570654e+00
    2.58125923e+00  2.49537196e+01  4.48486359e-02]
  [ 1.56164954e+00  1.60462253e+00 -3.21643342e-01  6.68450768e-01
    9.00000000e-01 -3.00000000e-02  8.00000000e-02  7.00000000e-02
    2.21405888e+00 -3.00438871e-01  1.51182443e+00  1.19477551e+00
    3.07491851e+00  3.00278616e+01 -5.17347392e-02]
  [ 1.52884413e+00  1.58833992e+00  1.98445664e-01 -5.72957942e-01
    1.30000000e+00  2.00000000e-02  1.10000000e-01  9.00000000e-02
    3.33548457e+00  2.26772628e+00  1.61353262e+00  1.28644253e+00
    2.08759994e+00  3.67933842e+01  1.97620449e-02]
  [-6.34503288e-01 -6.34618939e-01 -1.58856875e-03 -7.58910003e-08
    0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
   -5.89505351e-01 -3.00438871e-01 -6.33984016e-01 -6.34877193e-01
   -5.92264747e-01 -4.16990332e-01  9.47