In [1]:
import numpy as np
import os
import math
import sys
import matplotlib.pyplot as plt
# Cell runtime: ~20 sec

In [2]:
import sqlite3 as sql
import pandas as pd
import pyarrow.parquet as pq

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningDataModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# Cell runtime: ~10 sec

In [5]:
# Extend the module search path so the project-local helper modules below can be imported.
# NOTE(review): hard-coded absolute cluster path — presumably required for the three
# project-local imports that follow; consider making it configurable.
sys.path.append('/groups/icecube/cyan/Utils')
from PlotUtils import setMplParam, getColour, getHistoParam 
# getHistoParam usage, per the original author's note (TODO confirm against PlotUtils):
# Nbins, binwidth, bins, counts, bin_centers = getHistoParam(...)
from DB_lister import list_content, list_tables
from ExternalFunctions import nice_string_output, add_text_to_ax
# Presumably applies the project's matplotlib defaults — verify in PlotUtils.
setMplParam()

In [None]:
# Root directory of the data set; it ends with "/" so names can be appended directly.
NuMu_PeV_root = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/Snowstorm/22012/"
# Sub-directory "1/" under the root and the first PMTfied parquet file inside it.
PMTfied_1 = f"{NuMu_PeV_root}1/"
PMTfied_1_1 = f"{PMTfied_1}PMTfied_1.parquet"
# Truth parquet file that sits directly under the root.
truth_1_parquet = f"{NuMu_PeV_root}truth_1.parquet"

In [None]:
def get_files_in_dir(directory, extension='.parquet'):
    """Return the names of the entries in ``directory`` that end with ``extension``.

    Only bare entry names are returned (they are not joined with
    ``directory``), in the order ``os.listdir`` yields them. The filter looks
    at the name suffix only, so a sub-directory whose name ends with
    ``extension`` is returned as well.
    """
    matching = []
    for entry in os.listdir(directory):
        if entry.endswith(extension):
            matching.append(entry)
    return matching
def get_subdir_in_dir(directory):
    """Return the names of the immediate sub-directories of ``directory``.

    Each entry from ``os.listdir`` is joined with ``directory`` and kept when
    ``os.path.isdir`` reports a directory. Bare names are returned, in
    ``os.listdir`` order.
    """
    subdirs = []
    for entry in os.listdir(directory):
        candidate = os.path.join(directory, entry)
        if os.path.isdir(candidate):
            subdirs.append(entry)
    return subdirs

In [28]:
def convertParquetToDF(file:str) -> pd.DataFrame:
    """Read the parquet file at ``file`` and return its content as a pandas DataFrame.

    The file is read through ``pq.read_table`` and the resulting table is
    converted with its ``to_pandas`` method.
    """
    return pq.read_table(file).to_pandas()

In [None]:
# Read the truth parquet file and show it as the cell output.
# NOTE(review): the returned DataFrame is not assigned to a name, so it is only displayed here.
convertParquetToDF(truth_1_parquet)
# 20 sec (presumably the cell's runtime)
# 29386 rows × 11 columns

Unnamed: 0,event_no,subdirectory_no,db_file_no,shard_index,file_no,N_doms,offset,energy,azimuth,zenith,pid
0,0,22012,1,1,1,28,0,8.238094e+06,5.843885,0.917283,-14.0
1,1,22012,1,1,1,132,28,5.375572e+07,5.431356,1.708568,14.0
2,2,22012,1,1,1,313,160,2.576900e+07,2.869423,0.965348,14.0
3,3,22012,1,1,1,279,473,6.547825e+06,5.131130,0.480436,14.0
4,4,22012,1,1,1,741,752,8.446194e+07,0.912671,0.921206,14.0
...,...,...,...,...,...,...,...,...,...,...,...
29381,29381,22012,1,15,15,264,493344,6.344052e+06,2.746214,0.611205,-14.0
29382,29382,22012,1,15,15,365,493608,4.438849e+06,0.573137,0.925718,-14.0
29383,29383,22012,1,15,15,232,493973,1.867177e+07,3.502039,0.844169,14.0
29384,29384,22012,1,15,15,32,494205,1.297350e+07,2.180361,2.612760,14.0


In [None]:
# Read the first PMTfied parquet file and show it as the cell output.
# NOTE(review): the returned DataFrame is not assigned to a name, so it is only displayed here.
convertParquetToDF(PMTfied_1_1)
# 16 sec (presumably the cell's runtime)
# 695403 rows × 24 columns

Unnamed: 0,dom_x,dom_y,dom_z,dom_x_rel,dom_y_rel,dom_z_rel,pmt_area,rde,saturation_status,q1,...,Qtotal,hlc1,hlc2,hlc3,t1,t2,t3,T10,T50,sigmaT
0,576.369995,170.919998,-271.890015,0.000000,0.000000,0.000000,0.0444,1.00,-1,0.675,...,0.675,0,-1,-1,14581.0,-1.0,-1.0,-1.0,-1.0,6874.020508
1,-234.949997,140.440002,312.839996,-76.781044,45.895424,111.693008,0.0444,1.00,-1,1.275,...,1.275,1,-1,-1,9863.0,-1.0,-1.0,-1.0,-1.0,4649.934082
2,-234.949997,140.440002,295.820007,-76.781044,45.895424,94.673004,0.0444,1.00,-1,0.475,...,0.475,1,-1,-1,10097.0,-1.0,-1.0,-1.0,-1.0,4760.242676
3,-234.949997,140.440002,278.790009,-76.781044,45.895424,77.643005,0.0444,1.00,-1,0.825,...,0.825,1,-1,-1,10338.0,-1.0,-1.0,-1.0,-1.0,4873.851562
4,-111.510002,159.979996,262.239990,-59.034706,84.695297,153.412582,0.0444,1.00,-1,0.775,...,0.775,1,-1,-1,10486.0,-1.0,-1.0,-1.0,-1.0,4943.619141
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
695398,72.370003,-66.599998,-440.170013,4.669030,-4.296772,-11.664500,0.0444,1.00,-1,0.775,...,0.775,0,-1,-1,12685.0,-1.0,-1.0,-1.0,-1.0,5980.237793
695399,72.370003,-66.599998,-461.200012,4.669030,-4.296772,-32.694500,0.0444,1.00,-1,0.725,...,0.725,1,-1,-1,13096.0,-1.0,-1.0,-1.0,-1.0,6173.984863
695400,72.370003,-66.599998,-475.220001,4.669030,-4.296772,-46.714500,0.0444,1.00,-1,0.675,...,0.675,1,-1,-1,12155.0,-1.0,-1.0,-1.0,-1.0,5730.393555
695401,113.190002,-60.470001,-469.480011,2.219410,-1.185685,0.892362,0.0444,1.35,-1,1.275,...,1.275,0,-1,-1,12839.0,-1.0,-1.0,-1.0,-1.0,6052.833984


In [53]:
class PMTfiedDataset(Dataset):
    """Map-style dataset that yields one ``(features, truth)`` tensor pair per event.

    Both input tables are grouped by their ``event_no`` column. Indexing the
    dataset with ``idx`` selects the ``idx``-th event (in the order produced by
    the grouping), picks ``feature_columns`` / ``truth_columns`` from that
    event's rows and flattens each selection into a 1-D ``float32`` tensor.

    Every feature row of an event is flattened into ``x``, so the length of
    ``x`` is ``rows_in_event * len(feature_columns)`` and differs between
    events with different row counts. Batching such samples with a
    ``DataLoader`` therefore needs ``batch_size=1`` or a custom ``collate_fn``.

    Parameters
    ----------
    df_features : pd.DataFrame
        Feature table; must contain an ``event_no`` column.
    df_truth : pd.DataFrame
        Truth table; must contain an ``event_no`` column with exactly the same
        set of events as ``df_features``.
    feature_columns : sequence of str, optional
        Columns read from the feature table. Defaults to
        ``DEFAULT_FEATURE_COLUMNS``.
    truth_columns : sequence of str, optional
        Columns read from the truth table. Defaults to
        ``DEFAULT_TRUTH_COLUMNS``. Pass a reduced list when the truth table
        does not provide every default column.

    Raises
    ------
    AssertionError
        If the events found in the two tables are not identical.
    """

    # Column selections used when the caller does not pass its own lists.
    DEFAULT_FEATURE_COLUMNS = ('dom_string', 'dom_number',
                               'dom_x', 'dom_y', 'dom_z',
                               'dom_x_rel', 'dom_y_rel', 'dom_z_rel',
                               'pmt_area', 'rde', 'saturation_status',
                               'q1', 'q2', 'q3',
                               'Q25', 'Q75', 'Qtotal',
                               'hlc1', 'hlc2', 'hlc3',
                               't1', 't2', 't3',
                               'T10', 'T50', 'sigmaT')
    DEFAULT_TRUTH_COLUMNS = ('energy',
                             'position_x', 'position_y', 'position_z',
                             'azimuth', 'zenith',
                             'pid', 'event_time', 'interaction_type', 'elasticity')

    def __init__(self, df_features, df_truth, feature_columns=None, truth_columns=None):
        self.dataframe = df_features
        self.event_wise_features, self.events_features = self.__get_event_wise_dic__(df_features)
        self.event_wise_truth, self.events_truth = self.__get_event_wise_dic__(df_truth)
        self.__check_events__()
        # Stored as lists: indexing a DataFrame with a tuple would be read as a single key.
        self.feature_columns = list(
            self.DEFAULT_FEATURE_COLUMNS if feature_columns is None else feature_columns
        )
        self.truth_columns = list(
            self.DEFAULT_TRUTH_COLUMNS if truth_columns is None else truth_columns
        )

    def __len__(self):
        """Return the number of events; needed by ``len()`` and the default ``DataLoader`` samplers."""
        return len(self.events_features)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for the ``idx``-th event as 1-D ``float32`` tensors.

        ``x`` holds the selected feature columns of all rows of the event,
        flattened row by row; ``y`` holds the selected truth columns, flattened
        the same way. pandas raises ``KeyError`` if a configured column is
        missing from the corresponding table.
        """
        # Map the positional index onto the event identifier.
        event_no = self.events_features[idx]

        features_df = self.event_wise_features[event_no][self.feature_columns]
        truth_df = self.event_wise_truth[event_no][self.truth_columns]

        x = features_df.values.flatten()
        y = truth_df.values.flatten()  # presumably one truth row per event — TODO confirm

        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

    def __get_event_wise_dic__(self, df: pd.DataFrame, group_col: str = 'event_no') -> tuple[dict, list]:
        """
        Groups the input DataFrame by a specified column (default is 'event_no')
        and returns a dictionary of DataFrames for each unique event, as well as a
        list of unique event identifiers.

        Parameters:
        - df (pd.DataFrame): The input DataFrame to group.
        - group_col (str): The column name to group by (typically 'event_no').
                            Default is 'event_no'.

        Returns:
        - dict: A dictionary where each key is a unique event identifier, and the 
                corresponding value is a DataFrame (a copy) containing data specific
                to that event.
        - list: A list of unique event identifiers found in the specified column.
        """
        grouped_data = df.groupby(group_col)
        # Copy each group so the per-event frames are independent of the input frame.
        event_wise_data = {event_no: group.copy() for event_no, group in grouped_data}
        events = list(event_wise_data.keys())
        return event_wise_data, events

    def __check_events__(self):
        """Raise ``AssertionError`` unless features and truth list exactly the same events."""
        if self.events_features != self.events_truth:
            print("[PMTfiedDataset]Events are not identical")
            raise AssertionError("[PMTfiedDataset]Events are not identical")
        print(f"[PMTfiedDataset]Events lists are identical, {len(self.events_features)} events")
                

In [54]:
# Build the dataset from a feature table and a truth table.
# NOTE(review): `pmtfied_df` and `truth_df` are not defined in any cell visible here —
# presumably they come from a cell that was removed or run out of order; define them
# explicitly so the notebook survives Restart Kernel -> Run All.
d = PMTfiedDataset(pmtfied_df, truth_df)

[PMTfiedDataset]Events lists are identical, 73 events


In [None]:
class PMTfiedDataModule(LightningDataModule):
    """Lightning data module that wraps a ``PMTfiedDataset``.

    Parameters
    ----------
    dataframe : pd.DataFrame
        Feature table handed to ``PMTfiedDataset`` as ``df_features``.
    batch_size : int, default 32
        Batch size used by every dataloader.
    num_workers : int, default 4
        Number of worker processes used by every dataloader.
    df_truth : pd.DataFrame, optional
        Truth table handed to ``PMTfiedDataset`` as ``df_truth``. It is
        keyword-optional only to keep the original positional signature;
        ``setup`` raises ``TypeError`` if it was not supplied, because the
        dataset cannot be built without it.
    collate_fn : callable, optional
        Passed through to every ``DataLoader``. ``None`` keeps PyTorch's
        default collation. Samples of ``PMTfiedDataset`` are flattened per
        event and so differ in length; the default collation presumably cannot
        stack them for ``batch_size > 1`` — supply a padding/packing function
        in that case.

    NOTE(review): the train, validation and test dataloaders all serve the
    same dataset; no split is performed here.
    """

    def __init__(self, dataframe, batch_size=32, num_workers=4, df_truth=None, collate_fn=None):
        super().__init__()
        self.dataframe = dataframe
        self.df_truth = df_truth
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn

    def setup(self, stage=None):
        """Build the dataset from the stored feature and truth tables."""
        if self.df_truth is None:
            # PMTfiedDataset requires both tables; fail with an actionable message
            # instead of the bare "missing positional argument" error.
            raise TypeError(
                "[PMTfiedDataModule]df_truth is required to build PMTfiedDataset; "
                "pass df_truth=... when constructing the data module"
            )
        self.dataset = PMTfiedDataset(self.dataframe, self.df_truth)

    def _build_loader(self, shuffle=False):
        """Create a ``DataLoader`` over the dataset with the module's shared settings."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        """Return the training dataloader (shuffled)."""
        return self._build_loader(shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (not shuffled)."""
        return self._build_loader()

    def test_dataloader(self):
        """Return the test dataloader (not shuffled)."""
        return self._build_loader()
