In [1]:
import numpy as np
import os
import math
import sys
import matplotlib.pyplot as plt
# 20 sec

In [2]:
import sqlite3 as sql
import pandas as pd
import pyarrow.parquet as pq

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningDataModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# 10 sec

RuntimeError: operator torchvision::nms does not exist

In [None]:
sys.path.append('/groups/icecube/cyan/Utils')
from PlotUtils import setMplParam, getColour, getHistoParam 
# getHistoParam usage (unpacking order as noted by the author):
# Nbins, binwidth, bins, counts, bin_centers = getHistoParam(...)
from DB_lister import list_content, list_tables
from ExternalFunctions import nice_string_output, add_text_to_ax
setMplParam()

In [None]:
# Input Parquet files used by the cells below: the PMTfied feature table and the
# truth table that shares the same file stem (…022012.000111).
# NOTE(review): absolute cluster paths — the notebook only runs where /groups/icecube is mounted.
PMTfied_data = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedParquet/PMTfied/Level2_NuMu_NuGenCCNC.022012.000111_PMTfied.parquet"
truth_data = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedParquet/truth/Level2_NuMu_NuGenCCNC.022012.000111_truth.parquet"

In [None]:
def get_table_event_count(conn: sql.Connection, table: str) -> int:
    """
    Count how many different ``event_no`` values a table holds.

    Parameters:
    - conn (sql.Connection): An open SQLite connection.
    - table (str): Name of the table to inspect. It is pasted into the SQL
                   text as-is, so it must be a trusted identifier.

    Returns:
    - int: The number of distinct ``event_no`` values in ``table``.
    """
    statement = f"SELECT COUNT(DISTINCT event_no) FROM {table}"
    # execute() hands back the cursor itself, so the single result row can be
    # fetched in the same expression.
    first_row = conn.cursor().execute(statement).fetchone()
    return first_row[0]

In [None]:
def convertDBtoDF(file:str, table:str, N_events_total:int, N_events:int = None) -> pd.DataFrame:
    """
    Read every row belonging to the first ``N_events`` distinct events of a
    SQLite table into a DataFrame.

    Parameters:
    - file (str): Path of the SQLite database file.
    - table (str): Table to read. The name is interpolated into the SQL text
                   (identifiers cannot be bound as parameters), so it must be
                   a trusted identifier.
    - N_events_total (int): Upper bound on the number of events to read.
    - N_events (int, optional): Number of events wanted. ``None`` or a value
                   larger than ``N_events_total`` falls back to ``N_events_total``.

    Returns:
    - pd.DataFrame: All columns of ``table`` for the selected events.
    """
    if N_events is None or N_events > N_events_total:
        N_events = N_events_total

    con = sql.connect(file)
    try:
        # Step 1: pick the event numbers to keep.
        event_no_query = f'SELECT DISTINCT event_no FROM {table} LIMIT {N_events}'
        event_nos = pd.read_sql_query(event_no_query, con)['event_no'].tolist()

        # Step 2: fetch all rows of those events via an SQL IN clause.
        event_filter = ','.join(map(str, event_nos))
        query = f'SELECT * FROM {table} WHERE event_no IN ({event_filter})'
        df = pd.read_sql_query(query, con)
    finally:
        # Always release the connection, including when a query raises
        # (e.g. unknown table); previously it was only closed on success.
        con.close()

    return df

In [None]:
def convertParquetToDF(file: str, 
                    N_events_total: int = None,
                    N_events: int = None) -> pd.DataFrame:
    """
    Load a Parquet file and keep only the rows of its first ``N_events``
    distinct events.

    Parameters:
    - file (str): Path of the Parquet file; it must contain an 'event_no' column.
    - N_events_total (int, optional): Upper bound on the number of events to
                    keep. ``None`` (default) means "as many as the file holds",
                    so the function can be called with the path alone.
    - N_events (int, optional): Number of events wanted. ``None`` or a value
                    above the upper bound falls back to the upper bound.

    Returns:
    - pd.DataFrame: Rows of the file whose 'event_no' is among the first
                    selected events (order of first appearance in the file).
    """
    # Load the full Parquet file into a DataFrame
    df = pd.read_parquet(file)

    # Distinct event numbers in order of first appearance.
    unique_events = df['event_no'].unique()

    # Without an explicit bound, the file's own event count is the bound.
    upper_bound = len(unique_events) if N_events_total is None else N_events_total
    if N_events is None or N_events > upper_bound:
        N_events = upper_bound

    selected_events = unique_events[:N_events]
    df_filtered = df[df['event_no'].isin(selected_events)]

    return df_filtered


In [None]:
# Intent: load every event from both files. Only the file path is passed, so
# these calls rely on convertParquetToDF accepting its event-count arguments
# as optional.
pmtfied_df = convertParquetToDF(PMTfied_data)
truth_df = convertParquetToDF(truth_data)

In [22]:
pmtfied_df.columns

Index(['event_no', 'dom_string', 'dom_number', 'dom_x', 'dom_y', 'dom_z',
       'dom_x_rel', 'dom_y_rel', 'dom_z_rel', 'pmt_area', 'rde',
       'saturation_status', 'q1', 'q2', 'q3', 'Q25', 'Q75', 'Qtotal', 'hlc1',
       'hlc2', 'hlc3', 't1', 't2', 't3', 'T10', 'T50', 'sigmaT'],
      dtype='object')

In [24]:
truth_df.columns

Index(['energy', 'position_x', 'position_y', 'position_z', 'azimuth', 'zenith',
       'pid', 'event_time', 'interaction_type', 'elasticity', 'RunID',
       'SubrunID', 'EventID', 'SubEventID', 'dbang_decay_length',
       'track_length', 'stopped_muon', 'energy_track', 'energy_cascade',
       'inelasticity', 'DeepCoreFilter_13', 'CascadeFilter_13',
       'MuonFilter_13', 'OnlineL2Filter_17', 'L3_oscNext_bool',
       'L4_oscNext_bool', 'L5_oscNext_bool', 'L6_oscNext_bool',
       'L7_oscNext_bool', 'Homogenized_QTot', 'MCLabelClassification',
       'MCLabelCoincidentMuons', 'MCLabelBgMuonMCPE',
       'MCLabelBgMuonMCPECharge', 'GNLabelTrackEnergyDeposited',
       'GNLabelTrackEnergyOnEntrance', 'GNLabelTrackEnergyOnEntrancePrimary',
       'GNLabelTrackEnergyDepositedPrimary', 'GNLabelEnergyPrimary',
       'GNLabelCascadeEnergyDepositedPrimary', 'GNLabelCascadeEnergyDeposited',
       'GNLabelEnergyDepositedTotal', 'GNLabelEnergyDepositedPrimary',
       'GNHighestEInIceParticl

In [37]:
# primitive slicer function that will be used in DataLoader
def get_sliced_event(dataframe, group_col='event_no'):
    """
    Split a DataFrame into independent per-event DataFrames.

    Parameters:
    - dataframe (pd.DataFrame): Table to split.
    - group_col (str): Column whose values define the groups
                       (default 'event_no').

    Returns:
    - dict: Maps each distinct value of ``group_col`` to a copy of the rows
            carrying that value.
    - list: The dictionary keys, in the order the groups were produced.
    """
    event_sliced_data = {}
    for key, rows in dataframe.groupby(group_col):
        # Copy so later edits to one slice never touch the source frame.
        event_sliced_data[key] = rows.copy()
    return event_sliced_data, list(event_sliced_data)

In [61]:
# Split each table into per-event DataFrames (dict keyed by event_no) and keep
# the list of event numbers alongside, for indexing by position.
event_wise_feature, events_feature = get_sliced_event(pmtfied_df)
event_wise_truth, events_truth = get_sliced_event(truth_df)

In [62]:
event_wise_feature[events_feature[0]].columns

Index(['event_no', 'dom_string', 'dom_number', 'dom_x', 'dom_y', 'dom_z',
       'dom_x_rel', 'dom_y_rel', 'dom_z_rel', 'pmt_area', 'rde',
       'saturation_status', 'q1', 'q2', 'q3', 'Q25', 'Q75', 'Qtotal', 'hlc1',
       'hlc2', 'hlc3', 't1', 't2', 't3', 'T10', 'T50', 'sigmaT'],
      dtype='object')

In [59]:
event_wise_truth[events_truth[0]].columns

Index(['energy', 'position_x', 'position_y', 'position_z', 'azimuth', 'zenith',
       'pid', 'event_time', 'interaction_type', 'elasticity', 'RunID',
       'SubrunID', 'EventID', 'SubEventID', 'dbang_decay_length',
       'track_length', 'stopped_muon', 'energy_track', 'energy_cascade',
       'inelasticity', 'DeepCoreFilter_13', 'CascadeFilter_13',
       'MuonFilter_13', 'OnlineL2Filter_17', 'L3_oscNext_bool',
       'L4_oscNext_bool', 'L5_oscNext_bool', 'L6_oscNext_bool',
       'L7_oscNext_bool', 'Homogenized_QTot', 'MCLabelClassification',
       'MCLabelCoincidentMuons', 'MCLabelBgMuonMCPE',
       'MCLabelBgMuonMCPECharge', 'GNLabelTrackEnergyDeposited',
       'GNLabelTrackEnergyOnEntrance', 'GNLabelTrackEnergyOnEntrancePrimary',
       'GNLabelTrackEnergyDepositedPrimary', 'GNLabelEnergyPrimary',
       'GNLabelCascadeEnergyDepositedPrimary', 'GNLabelCascadeEnergyDeposited',
       'GNLabelEnergyDepositedTotal', 'GNLabelEnergyDepositedPrimary',
       'GNHighestEInIceParticl

In [53]:
class PMTfiedDataset(Dataset):
    """
    Map-style dataset that yields one (features, truth) tensor pair per event.

    Both input tables are split into per-event DataFrames keyed by 'event_no'.
    Construction fails unless the two tables contain exactly the same list of
    event numbers.
    """

    def __init__(self, df_features, df_truth):
        """
        Parameters:
        - df_features (pd.DataFrame): Feature rows; must contain 'event_no'
                                      and every name in ``feature_columns``.
        - df_truth (pd.DataFrame): Truth rows; must contain 'event_no' and
                                   every name in ``truth_columns``.

        Raises:
        - AssertionError: If the event lists of the two tables differ.
        """
        self.dataframe = df_features
        self.event_wise_features, self.events_features = self.__get_event_wise_dic__(df_features)
        self.event_wise_truth, self.events_truth = self.__get_event_wise_dic__(df_truth)
        self.__check_events__()
        self.feature_columns = ['dom_string', 'dom_number', 
                                'dom_x', 'dom_y', 'dom_z',
                                'dom_x_rel', 'dom_y_rel', 'dom_z_rel', 
                                'pmt_area', 'rde', 'saturation_status', 
                                'q1', 'q2', 'q3', 
                                'Q25', 'Q75', 'Qtotal', 
                                'hlc1','hlc2', 'hlc3', 
                                't1', 't2', 't3', 
                                'T10', 'T50', 'sigmaT']
        self.truth_columns = ['energy', 
                            'position_x', 'position_y', 'position_z', 
                            'azimuth', 'zenith',
                            'pid', 'event_time', 'interaction_type', 'elasticity']

    def __len__(self):
        """
        Number of events in the dataset.

        Needed by DataLoader samplers, which call ``len(dataset)``.
        """
        return len(self.events_features)

    def __getitem__(self, idx):
        """
        Return the flattened feature and truth tensors of the idx-th event.

        Parameters:
        - idx (int): Position in the event list.

        Returns:
        - torch.Tensor: float32, all feature rows of the event flattened
                        row by row (length = rows * len(feature_columns)).
        - torch.Tensor: float32, the event's truth values flattened.
        """
        # Get the event number for the current index
        event_no = self.events_features[idx]
        
        # Get features and truth data for the event
        features_df = self.event_wise_features[event_no][self.feature_columns]
        truth_df = self.event_wise_truth[event_no][self.truth_columns]

        # Convert to numpy arrays and flatten them.
        # NOTE(review): x grows with the number of feature rows of the event,
        # so events with different row counts give tensors of different
        # length; batching them needs a padding/custom collate function.
        x = features_df.values.flatten()
        y = truth_df.values.flatten()  # Assuming each event has one set of truth values
        
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)
    
    def __get_event_wise_dic__(self, df: pd.DataFrame, group_col: str = 'event_no') -> tuple[dict, list]:
        """
        Groups the input DataFrame by a specified column (default is 'event_no')
        and returns a dictionary of DataFrames for each unique event, as well as a
        list of unique event identifiers.

        Parameters:
        - df (pd.DataFrame): The input DataFrame to group.
        - group_col (str): The column name to group by (typically 'event_no').
                            Default is 'event_no'.

        Returns:
        - dict: A dictionary where each key is a unique event identifier, and the 
                corresponding value is a DataFrame containing data specific to that event.
        - list: A list of unique event identifiers found in the specified column.
        """
        grouped_data = df.groupby(group_col)
        # Copy each group so the per-event frames are independent of ``df``.
        event_wise_data = {event_no: group.copy() for event_no, group in grouped_data}
        events = list(event_wise_data.keys())
        return event_wise_data, events
    
    def __check_events__(self):
        """
        Verify that features and truth describe the same events.

        Raises:
        - AssertionError: If the two event lists are not equal.
        """
        if self.events_features != self.events_truth:
            print("[PMTfiedDataset]Events are not identical")
            raise AssertionError("[PMTfiedDataset]Events are not identical")
        else:
            print(f"[PMTfiedDataset]Events lists are identical, {len(self.events_features)} events")
                

In [54]:
# Build the dataset once as a smoke test: the constructor prints a confirmation
# when features and truth list the same events and raises AssertionError otherwise.
d = PMTfiedDataset(pmtfied_df, truth_df)

[PMTfiedDataset]Events lists are identical, 73 events


In [None]:
class PMTfiedDataModule(LightningDataModule):
    """
    Lightning data module wrapping a single PMTfiedDataset.

    NOTE(review): the train, validation and test loaders all read the same
    dataset — there is no split, so validation/test metrics are measured on
    the training events.
    """

    def __init__(self, dataframe, batch_size=32, num_workers=4, df_truth=None):
        """
        Parameters:
        - dataframe (pd.DataFrame): Feature table handed to PMTfiedDataset.
        - batch_size (int): Batch size used by every loader.
        - num_workers (int): Worker processes used by every loader.
        - df_truth (pd.DataFrame, optional): Truth table handed to
                    PMTfiedDataset. Optional only to keep the constructor
                    signature backward-compatible; it must be set before
                    ``setup`` runs.
        """
        super().__init__()
        self.dataframe = dataframe
        self.df_truth = df_truth
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """
        Build the dataset from the stored feature and truth tables.

        Raises:
        - ValueError: If no truth table was supplied.
        """
        # PMTfiedDataset needs both tables; passing only the features (as this
        # method used to) always failed with a TypeError.
        if self.df_truth is None:
            raise ValueError("[PMTfiedDataModule]df_truth is required to build PMTfiedDataset")
        self.dataset = PMTfiedDataset(self.dataframe, self.df_truth)

    def _make_loader(self, shuffle=False):
        """Create a DataLoader over the dataset with the module's batch settings."""
        return DataLoader(self.dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle)

    def train_dataloader(self):
        """Loader for training; the only one that shuffles."""
        return self._make_loader(shuffle=True)

    def val_dataloader(self):
        """Loader for validation (no shuffling)."""
        return self._make_loader()

    def test_dataloader(self):
        """Loader for testing (no shuffling)."""
        return self._make_loader()
