In [15]:
import numpy as np

import torch

from torch.utils.data import Dataset
from torch_geometric.data import Data

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

def feature_preprocessing(col_name, value) -> np.ndarray:
    """
    Preprocess the input features when creating the dataset.
    
    Args:
    - col_name: The name of the column to preprocess.
    - value: The value array of the column to preprocess [numpy array].

    Returns:
    - The preprocessed value array [numpy array].
    """
  
    if col_name in ['dom_x', 'dom_y', 'dom_z', 'dom_x_rel', 'dom_y_rel', 'dom_z_rel']:
        value = value / 500
    elif col_name in ['rde']:
        value = (value - 1.25) / 0.25
    elif col_name in ['pmt_area']:
        value = value / 0.05
    elif col_name in ['q1', 'q2', 'q3', 'q4', 'q5', 'Q25', 'Q75', 'Qtotal']:
        mask = value > 0
        value[mask] = np.log10(value[mask])
    elif col_name in ['t1', 't2', 't3','t4', 't5']:
        mask = value > 0
        value[mask] = (value[mask] - 1.0e04) / 3.0e04
    elif col_name in ['T10', 'T50', 'sigmaT']:
        mask = value > 0
        value[mask] = value[mask] / 1.0e04

    return value


class PMTfiedDatasetPyArrow(Dataset):
    def __init__(
            self, 
            truth_paths_1,
            truth_paths_2,
            truth_paths_3,
            sample_weights = [1, 1, 1],
            selection=None,
            transform=feature_preprocessing,
    ):
        '''
        Args:
        - truth_paths_1: List of paths to the truth files of type 1
        - truth_paths_2: List of paths to the truth files of type 2
        - truth_paths_3: List of paths to the truth files of type 3
        - sample_weights: Ratio to sample from file type 1, 2, 3 respectively
        - selection: List of event numbers to select from the corresponding truth files
        - transform: Function to apply to the features as preprocessing
        '''

        self.truth_paths_1 = truth_paths_1
        self.truth_paths_2 = truth_paths_2
        self.truth_paths_3 = truth_paths_3
        self.selection = selection
        self.transform = transform

        # Metadata variables
        self.event_counts = [0,0,0]
        self.cumulative_event_counts_1 = []
        self.cumulative_event_counts_2 = []
        self.cumulative_event_counts_3 = []

        self.current_file_idx_1 = None
        self.current_file_idx_2 = None
        self.current_file_idx_3 = None

        self.current_truth_1 = None
        self.current_truth_2 = None
        self.current_truth_3 = None

        self.current_feature_path_1 = None
        self.current_feature_path_2 = None
        self.current_feature_path_3 = None
     
        self.current_features_1 = None
        self.current_features_2 = None
        self.current_features_3 = None

        # Scan the truth files to get the event 
        total_events = 0
        for path in self.truth_paths_1:
            truth = pq.read_table(path)
            if self.selection is not None:
                mask = pc.is_in(truth['event_no'], value_set=pa.array(self.selection))
                truth = truth.filter(mask)
            n_events = len(truth)
            self.event_counts[0] += n_events
            total_events += n_events
            self.cumulative_event_counts_1.append(total_events)


        total_events = 0
        for path in self.truth_paths_2:
            truth = pq.read_table(path)
            if self.selection is not None:
                mask = pc.is_in(truth['event_no'], value_set=pa.array(self.selection))
                truth = truth.filter(mask)
            n_events = len(truth)
            self.event_counts[1] += n_events
            total_events += n_events
            self.cumulative_event_counts_2.append(total_events)


        total_events = 0
        for path in self.truth_paths_3:
            truth = pq.read_table(path)
            if self.selection is not None:
                mask = pc.is_in(truth['event_no'], value_set=pa.array(self.selection))
                truth = truth.filter(mask)
            n_events = len(truth)
            self.event_counts[2] += n_events
            total_events += n_events
            self.cumulative_event_counts_3.append(total_events)

        self.sample_weights = sample_weights
        self.total_weights = sum(self.sample_weights)

        print('Total events:', self.event_counts)
        print('Cumulative event counts 1:', self.cumulative_event_counts_1)
        print('Cumulative event counts 2:', self.cumulative_event_counts_2)
        print('Cumulative event counts 3:', self.cumulative_event_counts_3)

    def __len__(self):
        # Devide the event counts per file type by the sample weights, take the minimum times the sample weights
        return min([count // weight for count, weight in zip(self.event_counts, self.sample_weights)]) * self.total_weights

    def __getitem__(self, idx):
        # Find the file index for the given event index, sampling from different truth lists
        
        set_idx = idx // self.total_weights
        mod_idx = idx % self.total_weights

        if mod_idx < self.sample_weights[0]:
            file_idx = np.searchsorted(self.cumulative_event_counts_1, self.sample_weights[0]*set_idx + mod_idx, side='right')
            local_idx = self.sample_weights[0]*set_idx + mod_idx if file_idx == 0 else self.sample_weights[0]*set_idx + mod_idx - self.cumulative_event_counts_1[file_idx - 1]
            truth_path = self.truth_paths_1[file_idx]

            if file_idx != self.current_file_idx_1:
                self.current_file_idx_1 = file_idx

                truth = pq.read_table(truth_path)
                if self.selection is not None:
                    mask = pc.is_in(truth['event_no'], value_set=pa.array(self.selection))
                    self.current_truth_1 = truth.filter(mask)
                else:
                    self.current_truth_1 = truth

            truth = self.current_truth_1


            # Get the event details
            event_no = torch.tensor(int(truth.column('event_no')[local_idx].as_py()), dtype=torch.long)
            energy = torch.tensor(truth.column('energy')[local_idx].as_py(), dtype=torch.float32)
            # azimuth = torch.tensor(truth.column('azimuth')[local_idx].as_py(), dtype=torch.float32)
            # zenith = torch.tensor(truth.column('zenith')[local_idx].as_py(), dtype=torch.float32)
            pid = torch.tensor(truth.column('pid')[local_idx].as_py(), dtype=torch.float32)
            
            abs_pid = int(torch.abs(pid))
            
            if abs_pid == 12:
                one_hot_pid = torch.tensor([1, 0, 0], dtype=torch.float32)
            elif abs_pid == 14:
                one_hot_pid = torch.tensor([0, 1, 0], dtype=torch.float32)
            elif abs_pid == 16:
                one_hot_pid = torch.tensor([0, 0, 1], dtype=torch.float32)
            else:
                one_hot_pid = torch.tensor([-1, -1, -1], dtype=torch.float32)

            # Calculate a 3D unit-vector from the zenith and azimuth angles
            # x_dir = torch.sin(zenith) * torch.cos(azimuth)
            # y_dir = torch.sin(zenith) * torch.sin(azimuth)
            # z_dir = torch.cos(zenith)


            offset = int(truth.column('offset')[local_idx].as_py())
            n_doms = int(truth.column('N_doms')[local_idx].as_py())
            part_no = int(truth.column('part_no')[local_idx].as_py())
            shard_no = int(truth.column('shard_no')[local_idx].as_py())

            # Define the feature path based on the truth path
            feature_path = truth_path.replace('truth_{}.parquet'.format(part_no), '' + str(part_no) + '/PMTfied_{}.parquet'.format(shard_no))

            # x from rows (offset-n_doms) to offset
            start_row = offset - n_doms

            # Load the features and apply preprocessing
            if feature_path != self.current_feature_path_1:
                self.current_feature_path_1 = feature_path
                self.current_features_1 = pq.read_table(feature_path)

            features = self.current_features_1

            x = features.slice(start_row, n_doms)
            # drop the first two columns (event_no and original_event_no)
            x = x.drop_columns(['event_no', 'original_event_no'])
            num_columns = x.num_columns

            x_tensor = torch.full((n_doms, num_columns), fill_value=torch.nan, dtype=torch.float32)

            for i, col_name in enumerate(x.column_names):
                value = x.column(i).to_numpy()
                value = value.copy()
                value = self.transform(col_name, value)
                # convert to torch tensor
                value_tensor = torch.from_numpy(value)
                x_tensor[:, i] = value_tensor

            return Data(x=x_tensor, n_doms=n_doms, event_no=event_no, feature_path=feature_path, energy=energy, pid=pid, one_hot_pid=one_hot_pid)

        elif mod_idx < self.sample_weights[0] + self.sample_weights[1]:
            file_idx = np.searchsorted(self.cumulative_event_counts_2, self.sample_weights[1]*set_idx + mod_idx - self.sample_weights[0], side='right')
            local_idx = self.sample_weights[1]*set_idx + mod_idx - self.sample_weights[0] if file_idx == 0 else self.sample_weights[1]*set_idx + mod_idx - self.sample_weights[0] - self.cumulative_event_counts_2[file_idx - 1]
            truth_path = self.truth_paths_2[file_idx]

            if file_idx != self.current_file_idx_2:
                self.current_file_idx_2 = file_idx

                truth = pq.read_table(truth_path)
                if self.selection is not None:
                    mask = pc.is_in(truth['event_no'], value_set=pa.array(self.selection))
                    self.current_truth_2 = truth.filter(mask)
                else:
                    self.current_truth_2 = truth

            truth = self.current_truth_2

            # Get the event details
            event_no = torch.tensor(int(truth.column('event_no')[local_idx].as_py()), dtype=torch.long)
            energy = torch.tensor(truth.column('energy')[local_idx].as_py(), dtype=torch.float32)
            # azimuth = torch.tensor(truth.column('azimuth')[local_idx].as_py(), dtype=torch.float32)
            # zenith = torch.tensor(truth.column('zenith')[local_idx].as_py(), dtype=torch.float32)
            pid = torch.tensor(truth.column('pid')[local_idx].as_py(), dtype=torch.float32)

            # Calculate a 3D unit-vector from the zenith and azimuth angles
            # x_dir = torch.sin(zenith) * torch.cos(azimuth)
            # y_dir = torch.sin(zenith) * torch.sin(azimuth)
            # z_dir = torch.cos(zenith)

            # Stack to dir3vec tensor
            # dir3vec = torch.stack([x_dir, y_dir, z_dir], dim=-1)
            
            abs_pid = int(torch.abs(pid))
            
            if abs_pid == 12:
                one_hot_pid = torch.tensor([1, 0, 0], dtype=torch.float32)
            elif abs_pid == 14:
                one_hot_pid = torch.tensor([0, 1, 0], dtype=torch.float32)
            elif abs_pid == 16:
                one_hot_pid = torch.tensor([0, 0, 1], dtype=torch.float32)
            else:
                one_hot_pid = torch.tensor([-1, -1, -1], dtype=torch.float32)

            offset = int(truth.column('offset')[local_idx].as_py())
            n_doms = int(truth.column('N_doms')[local_idx].as_py())
            part_no = int(truth.column('part_no')[local_idx].as_py())
            shard_no = int(truth.column('shard_no')[local_idx].as_py())

            # Define the feature path based on the truth path
            feature_path = truth_path.replace('truth_{}.parquet'.format(part_no), '' + str(part_no) + '/PMTfied_{}.parquet'.format(shard_no))

            # x from rows (offset-n_doms) to offset
            start_row = offset - n_doms

            # Load the features and apply preprocessing
            if feature_path != self.current_feature_path_2:
                self.current_feature_path_2 = feature_path
                self.current_features_2 = pq.read_table(feature_path)

            features = self.current_features_2

            x = features.slice(start_row, n_doms)
            # drop the first two columns (event_no and original_event_no)
            x = x.drop_columns(['event_no', 'original_event_no'])
            num_columns = x.num_columns

            x_tensor = torch.full((n_doms, num_columns), fill_value=torch.nan, dtype=torch.float32)

            for i, col_name in enumerate(x.column_names):
                value = x.column(i).to_numpy()
                value = value.copy()
                value = self.transform(col_name, value)
                # convert to torch tensor
                value_tensor = torch.from_numpy(value)
                x_tensor[:, i] = value_tensor

            return Data(x=x_tensor, n_doms=n_doms, event_no=event_no, feature_path=feature_path, energy=energy, pid=pid, one_hot_pid=one_hot_pid)

        elif mod_idx < self.sample_weights[0] + self.sample_weights[1] + self.sample_weights[2]:
            file_idx = np.searchsorted(self.cumulative_event_counts_3, set_idx*self.sample_weights[2] + mod_idx - self.sample_weights[0] - self.sample_weights[1], side='right')
            local_idx = self.sample_weights[2]*set_idx + mod_idx - self.sample_weights[0] - self.sample_weights[1] if file_idx == 0 else self.sample_weights[2]*set_idx + mod_idx - self.sample_weights[0] - self.sample_weights[1] - self.cumulative_event_counts_3[file_idx - 1]
            truth_path = self.truth_paths_3[file_idx]

            if file_idx != self.current_file_idx_3:
                self.current_file_idx_3 = file_idx

                truth = pq.read_table(truth_path)
                if self.selection is not None:
                    mask = pc.is_in(truth['event_no'], value_set=pa.array(self.selection))
                    self.current_truth_3 = truth.filter(mask)
                else:
                    self.current_truth_3 = truth

            truth = self.current_truth_3

            # Get the event details
            event_no = torch.tensor(int(truth.column('event_no')[local_idx].as_py()), dtype=torch.long)
            energy = torch.tensor(truth.column('energy')[local_idx].as_py(), dtype=torch.float32)
            # azimuth = torch.tensor(truth.column('azimuth')[local_idx].as_py(), dtype=torch.float32)
            # zenith = torch.tensor(truth.column('zenith')[local_idx].as_py(), dtype=torch.float32)
            pid = torch.tensor(truth.column('pid')[local_idx].as_py(), dtype=torch.float32)

            # Calculate a 3D unit-vector from the zenith and azimuth angles
            # x_dir = torch.sin(zenith) * torch.cos(azimuth)
            # y_dir = torch.sin(zenith) * torch.sin(azimuth)
            # z_dir = torch.cos(zenith)

            # Stack to dir3vec tensor
            # dir3vec = torch.stack([x_dir, y_dir, z_dir], dim=-1)
            abs_pid = int(torch.abs(pid))
            
            if abs_pid == 12:
                one_hot_pid = torch.tensor([1, 0, 0], dtype=torch.float32)
            elif abs_pid == 14:
                one_hot_pid = torch.tensor([0, 1, 0], dtype=torch.float32)
            elif abs_pid == 16:
                one_hot_pid = torch.tensor([0, 0, 1], dtype=torch.float32)
            else:
                one_hot_pid = torch.tensor([-1, -1, -1], dtype=torch.float32)

            offset = int(truth.column('offset')[local_idx].as_py())
            n_doms = int(truth.column('N_doms')[local_idx].as_py())
            part_no = int(truth.column('part_no')[local_idx].as_py())
            shard_no = int(truth.column('shard_no')[local_idx].as_py())

            # Define the feature path based on the truth path
            feature_path = truth_path.replace('truth_{}.parquet'.format(part_no), '' + str(part_no) + '/PMTfied_{}.parquet'.format(shard_no))

            # x from rows (offset-n_doms) to offset
            start_row = offset - n_doms

            # Load the features and apply preprocessing
            if feature_path != self.current_feature_path_3:
                self.current_feature_path_3 = feature_path
                self.current_features_3 = pq.read_table(feature_path)

            features = self.current_features_3

            x = features.slice(start_row, n_doms)
            # drop the first two columns (event_no and original_event_no)

            x = x.drop_columns(['event_no', 'original_event_no'])
            num_columns = x.num_columns

            x_tensor = torch.full((n_doms, num_columns), fill_value=torch.nan, dtype=torch.float32)

            for i, col_name in enumerate(x.column_names):
                value = x.column(i).to_numpy()
                value = value.copy()
                value = self.transform(col_name, value)
                # convert to torch tensor
                value_tensor = torch.from_numpy(value)
                x_tensor[:, i] = value_tensor

            return Data(x=x_tensor, n_doms=n_doms, event_no=event_no, feature_path=feature_path, energy=energy, pid=pid, one_hot_pid=one_hot_pid)

In [16]:
truth_1 = ["/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/Snowstorm/22011/truth_1.parquet"]
truth_2 = ["/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/Snowstorm/22014/truth_1.parquet"]
truth_3 = ["/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/Snowstorm/22017/truth_1.parquet"]

In [17]:
dataset = PMTfiedDatasetPyArrow(
    truth_paths_1=truth_1,
    truth_paths_2=truth_2,
    truth_paths_3=truth_3,
    sample_weights=[1, 1, 1],
)

Total events: [400561, 222514, 331040]
Cumulative event counts 1: [400561]
Cumulative event counts 2: [222514]
Cumulative event counts 3: [331040]


In [18]:
print(f" len(dataset) = {len(dataset)}")

 len(dataset) = 667542


In [19]:
for i in range(len(dataset)):
    # print(dataset[i])
    print(f"{dataset[i].pid} {dataset[i].one_hot_pid}")
    if i > 20:
        break
    
    

-14.0 tensor([0., 1., 0.])
12.0 tensor([1., 0., 0.])
16.0 tensor([0., 0., 1.])
14.0 tensor([0., 1., 0.])
-12.0 tensor([1., 0., 0.])
-16.0 tensor([0., 0., 1.])
14.0 tensor([0., 1., 0.])
12.0 tensor([1., 0., 0.])
-16.0 tensor([0., 0., 1.])
14.0 tensor([0., 1., 0.])
12.0 tensor([1., 0., 0.])
16.0 tensor([0., 0., 1.])
14.0 tensor([0., 1., 0.])
12.0 tensor([1., 0., 0.])
-16.0 tensor([0., 0., 1.])
14.0 tensor([0., 1., 0.])
-12.0 tensor([1., 0., 0.])
16.0 tensor([0., 0., 1.])
14.0 tensor([0., 1., 0.])
12.0 tensor([1., 0., 0.])
16.0 tensor([0., 0., 1.])
-14.0 tensor([0., 1., 0.])


In [20]:
import tqdm

for i in tqdm.tqdm(range(len(dataset))):
    dataset[i]
    if i > 100_000:
        break

  7%|▋         | 45045/667542 [01:28<20:28, 506.66it/s]


KeyboardInterrupt: 

In [None]:
from dataloader import custom_collate_fn

In [24]:
from torch.utils.data import DataLoader

In [25]:
train_dataloader = DataLoader(
    dataset=dataset,
    collate_fn = custom_collate_fn,
    batch_size=64,
    shuffle=False,
    num_workers=1,
    persistent_workers=True, # necessary for caching 
    pin_memory=True, # necessary for caching
    )

In [None]:
# for i, batch in enumerate(train_dataloader):
#     print(f"{batch[0].shape}, {batch[1].shape}, {batch[2]}")
#     if i > 10:
#         break
# # 

torch.Size([64, 256, 32]), torch.Size([64, 3]), tensor([105,  92,  29,  22,  98,  42,  32,  44,  70, 114,  40,  34,  56,  14,
         19,  19, 256,  65, 222,  62,  38,  24, 208,  46, 104,  33,  34,  25,
         89,  34, 128, 256, 163,  28,  42,  29, 179,  19,  25,  44, 241,  37,
         48,  46,  57,  19,  91,  14, 232, 242, 117,  16,  90, 232,  89,  85,
         12,  41,  17, 111, 128, 140,  33,  55])
torch.Size([64, 256, 32]), torch.Size([64, 3]), tensor([ 83,  28,  73,  37,  39, 177,  60, 164,  17,  30,  30,  36,  93,  68,
         41, 256, 127, 152,  92,  10,  18,  46,  24,  69, 256,  24, 128,  23,
        180,  39,  58, 103,  16, 256,  72, 256,  37, 256,  14,  49,  12, 101,
         41,  17, 256,  14,  14,  48,  78,  96,  18,  78, 256,  17,  48,  27,
        148,  66, 256, 201,  32,  29,  95,  56])
torch.Size([64, 256, 32]), torch.Size([64, 3]), tensor([  8,  75,  30,   9,  46,  66,  57, 256,  59, 167,   7,  71,  44,  93,
         44, 101, 256,  44,  53, 204,  48,  89,  65,  26