In [2]:
import pandas as pd
import sys
import os
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats
import scipy.optimize as opt
import math

from tqdm import tqdm

import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyarrow.csv as csv
import pyarrow.json as json

from enum import Enum
import abc
from typing import List, Dict, Tuple, Any, Union

In [3]:
sys.path.append('/groups/icecube/cyan/Utils')
from PlotUtils import setMplParam, getColour, getHistoParam 
# getHistoParam:
# Nbins, binwidth, bins, counts, bin_centers = 
from DB_lister import list_content, list_tables
from ExternalFunctions import nice_string_output, add_text_to_ax
setMplParam()

In [4]:
sys.path.append('/groups/icecube/cyan/factory/DOMification')
from Enum.Flavour import Flavour
from Enum.EnergyRange import EnergyRange

In [5]:
# Base directory of the filtered Snowstorm sample; every path below is derived from it.
root_dir = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied_filtered_second_round/Snowstorm/CC_CRclean_IntraTravelDistance_250/"

# One directory per source dataset (22011/22012, 22014/22015, 22017/22018).
root_dir_11 = f"{root_dir}22011/"
root_dir_12 = f"{root_dir}22012/"

root_dir_14 = f"{root_dir}22014/"
root_dir_15 = f"{root_dir}22015/"

root_dir_17 = f"{root_dir}22017/"
root_dir_18 = f"{root_dir}22018/"

# Destination directory for newly produced output.
root_dir_31 = f"{root_dir}22031/"

# Part directories holding the PMTfied shard files of dataset 22012.
# pmt_dir_11_1 = root_dir_11 + "1/"
pmt_dir_12_1 = f"{root_dir_12}1/"
pmt_dir_12_2 = f"{root_dir_12}2/"

# First truth file of each source dataset.
truth_11_1 = f"{root_dir_11}truth_1.parquet"
truth_12_1 = f"{root_dir_12}truth_1.parquet"

truth_14_1 = f"{root_dir_14}truth_1.parquet"
truth_15_1 = f"{root_dir_15}truth_1.parquet"

truth_17_1 = f"{root_dir_17}truth_1.parquet"
truth_18_1 = f"{root_dir_18}truth_1.parquet"

In [6]:
# Scratch calculation: geometric mean of 2,000 and 20,000.
np.sqrt(2000*20000)

6324.555320336759

N_events_per_part = 30000  
N_events_per_shard = 3000  
n_combined  = min(1PeV-100PeV x2, 100TeV-1PeV)
| Flavour | 100TeV-1PeV | 1PeV-100PeV | total   | n_combined | n_parts | n_shards |
|---------|-------------|-------------|---------|-----------|-----------|-----|
| νₑ      | 27,498      | 24,857      | 52,355  | 49,714    | 49,714 // 30,000 = 2 | 49,714 // 3000 = 17 |
| νμ      | 165,226     | 96,138      | 261,364 | 165,226   | 165,226 // 30,000 = 5 | 165,226 // 3000 = 55 |
| ντ      | 31,413      | 30,123      | 61,536  | 60,246    | 60,246 // 30,000 = 2 | 60,246 // 3000 = 20 |

22011: 165,226
22012: 96,138
22014: 27,498
22015: 24,857
22017: 31,413
22018: 30,123

nu_e = 27,498 + 24,857  = 52,355
nu_mu = 96,138 + 165,226 = 261,364
nu_tau = 31,413 + 30,123 = 61,536

m ratio e: 27,498 / 52,355 = 0.525
m ratio mu: 96,138 / 261,364 = 0.368
m ratio tau: 31,413 / 61,536 = 0.510

h ratio e: 24,857 / 52,355 = 0.475
h ratio mu: 165,226 / 261,364 = 0.632
h ratio tau: 30,123 / 61,536 = 0.490

In [7]:
# Show the sub-directories found directly under the 22011 dataset directory.
print([
    entry
    for entry in os.listdir(root_dir_11)
    if os.path.isdir(os.path.join(root_dir_11, entry))
])

['1', '2', '8', '5', '7', '4', '10', '9', '3']


In [8]:
# Show the parquet files found directly under the 22011 dataset directory.
print([
    entry
    for entry in os.listdir(root_dir_11)
    if entry.endswith(".parquet")
])

['truth_8.parquet', 'truth_4.parquet', 'truth_10.parquet', 'truth_1.parquet', 'truth_5.parquet', 'truth_9.parquet', 'truth_3.parquet', 'truth_7.parquet', 'truth_2.parquet']


In [9]:
def convertParquetToDF(file:str) -> pd.DataFrame:
    """Read a parquet file with pyarrow and return its content as a pandas DataFrame.

    Args:
        file: Path of the parquet file to read.

    Returns:
        The whole file converted with ``pyarrow.Table.to_pandas``.
    """
    return pq.read_table(file).to_pandas()

In [10]:
# Load the first truth file of dataset 22012 into pandas and display it for inspection.
df_12_truth_1 = convertParquetToDF(truth_12_1)
df_12_truth_1

Unnamed: 0,event_no,original_event_no,subdirectory_no,part_no,shard_no,N_doms,offset,energy,azimuth,zenith,...,TotalColumnDepthCGS,TotalPrimaryWeight,TotalWeight,TotalXsectionCGS,TrueActiveLengthAfter,TrueActiveLengthBefore,TypeWeight,max_interPMT_distance,isWithinIceCube,lepton_intra_distance
0,112000100000006,6,12,1,1,1071,1071,3.670788e+06,4.457279,0.701368,...,3.653102e+05,1.0,3.750276e-04,1.717568e-33,1105.433838,2687.593018,0.5,1023.409241,1,383.836761
1,112000100000010,10,12,1,1,252,1323,2.928983e+07,2.781825,1.309062,...,6.062522e+05,1.0,1.493867e-03,4.122190e-33,863.243591,5847.835449,0.5,1026.477051,0,736.449341
2,112000100000017,17,12,1,1,334,1657,4.015983e+07,3.086867,1.549579,...,5.323734e+06,1.0,1.470618e-02,4.676198e-33,943.018860,56868.835938,0.5,1145.506836,0,1018.752136
3,112000100000044,44,12,1,1,298,1955,6.502050e+06,1.396962,1.382517,...,9.476354e+05,1.0,1.249633e-03,2.206342e-33,867.015686,9598.517578,0.5,1074.671997,0,801.945312
4,112000100000054,54,12,1,1,26,1981,1.615565e+07,0.233710,1.616700,...,4.881714e+06,1.0,9.355243e-03,3.225694e-33,935.913879,38541.757812,0.5,619.220886,0,859.519714
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3766,112000100029357,29357,12,1,15,1073,96540,1.014017e+07,2.067528,0.235023,...,2.870003e+05,1.0,4.543297e-04,2.648076e-33,992.672058,1987.235962,0.5,1243.462524,0,1066.704224
3767,112000100029362,29362,12,1,15,392,96932,1.794552e+06,2.906128,2.910149,...,2.903096e+06,1.0,2.065588e-08,4.170369e-34,1018.672180,11150.628906,0.5,1179.754517,0,1065.375122
3768,112000100029369,29369,12,1,15,124,97056,3.091663e+06,2.050184,1.184187,...,4.249452e+05,1.0,4.037663e-04,1.589538e-33,752.795044,3949.443604,0.5,698.764526,0,338.311218
3769,112000100029384,29384,12,1,15,32,97088,1.297350e+07,2.180361,2.612760,...,2.340163e+06,1.0,2.999776e-09,2.323477e-34,1032.625610,9167.044922,0.5,915.793213,0,926.694275


In [11]:
def get_n_selected_events(parquet_file:str, energy_cutoff: float = 1e5) -> int:
    """Count the rows of a parquet file whose ``energy`` exceeds a cutoff.

    Only the ``energy`` column is read, so the other columns of the file are
    never loaded.

    Args:
        parquet_file: Path of the parquet file; it must contain an ``energy`` column.
        energy_cutoff: Rows are counted when ``energy`` is strictly greater than
            this value. Defaults to 1e5, the value that was previously
            hard-coded (presumably GeV, i.e. 100 TeV -- confirm the unit).

    Returns:
        The number of selected rows. Rows with a null energy are not counted,
        because ``Table.filter`` drops rows whose mask entry is null.
    """
    table = pq.read_table(parquet_file, columns=["energy"])
    above_cutoff = pc.greater(table["energy"], energy_cutoff)
    return table.filter(above_cutoff).num_rows
    

In [12]:
def get_n_selected_events_in_dir(dir:str) -> int:
    """Sum the selected-event counts of every parquet file directly inside a directory.

    Args:
        dir: Directory to scan (not recursive); only entries whose name ends
            with ``.parquet`` are considered.

    Returns:
        The total of ``get_n_selected_events`` over those files, 0 if there are none.
    """
    return sum(
        get_n_selected_events(os.path.join(dir, entry))
        for entry in os.listdir(dir)
        if entry.endswith(".parquet")
    )

In [13]:
# Print the number of selected events of each source dataset, one line per dataset.
print(
    *(
        f"{dataset_id}: {get_n_selected_events_in_dir(dataset_dir):,}"
        for dataset_id, dataset_dir in (
            ("22011", root_dir_11),
            ("22012", root_dir_12),
            ("22014", root_dir_14),
            ("22015", root_dir_15),
            ("22017", root_dir_17),
            ("22018", root_dir_18),
        )
    ),
    sep="\n",
)

22011: 165,226
22012: 96,138
22014: 27,498
22015: 24,857
22017: 31,413
22018: 30,123


In [14]:
# dont use pandas dataframe, use pyarrow table
class DataBlender:
    """Plain container for the settings of a low/high energy blending job.

    The constructor only stores its arguments as attributes of the same name;
    the class defines no further behaviour.
    """

    def __init__(
        self,
        source_dir: str,
        energy_range_low: EnergyRange,
        energy_range_high: EnergyRange,
        energy_range_combined: EnergyRange,
        flavour: Flavour,
        energy_cutoff: float = 1e5,
    ):
        # Where the data lives and which flavour is handled.
        self.source_dir = source_dir
        self.flavour = flavour

        # The two input energy ranges and the range of the blended output.
        self.energy_range_low = energy_range_low
        self.energy_range_high = energy_range_high
        self.energy_range_combined = energy_range_combined

        # Energy threshold of the selection (default 1e5).
        self.energy_cutoff = energy_cutoff
        

In [15]:
# class TruthBlender(abc.ABC):
#     def __init__(self, 
#                 source_dir: str, 
#                 energy_range_low: EnergyRange,
#                 energy_range_high: EnergyRange,
#                 energy_range_combined: EnergyRange,
#                 flavour: Flavour,
#                 energy_cutoff: float = 1e5):
#         self.source_dir = source_dir
#         self.energy_range_low = energy_range_low
#         self.energy_range_high = energy_range_high
#         self.energy_range_combined = energy_range_combined
#         self.flavour = flavour
#         self.energy_cutoff = energy_cutoff

#     @abc.abstractmethod
#     def _get_n_events_low(self)

In [16]:
class TruthBlender:
    """Blend the truth tables of a low- and a high-energy sample into one balanced sample.

    Processing steps (see ``blend``):

    1. read every ``truth_*.parquet`` file of the low and the high sub-directory,
    2. keep only the low-sample rows with ``energy > energy_cutoff``,
    3. truncate both samples to the size of the smaller one,
    4. interleave the rows (low, high, low, high, ...) and overwrite
       ``subdirectory_no`` with the tag of the combined sub-directory,
    5. split the result into parts of at most ``n_events_per_part`` rows,
       overwrite ``part_no`` and write ``truth_<part_no>.parquet`` files into
       the combined sub-directory.

    All other columns (``shard_no``, ``offset``, ``N_doms``, ...) are copied
    unchanged from the source truth tables.
    """

    def __init__(self,
                 source_dir: str,
                 energy_range_low: EnergyRange,
                 energy_range_high: EnergyRange,
                 energy_range_combined: EnergyRange,
                 flavour: Flavour,
                 n_events_per_part: int = 30_000,
                 energy_cutoff: float = 1e5,
                 ):
        """Resolve the three sub-directories and store the blending settings.

        Args:
            source_dir: Directory containing the per-dataset sub-directories.
            energy_range_low: Energy range of the low-energy input sample.
            energy_range_high: Energy range of the high-energy input sample.
            energy_range_combined: Energy range identifying the output sub-directory.
            flavour: Neutrino flavour, passed to ``EnergyRange.get_subdir``.
            n_events_per_part: Maximum number of rows per written truth part.
            energy_cutoff: Low-sample rows are kept when ``energy`` is strictly
                greater than this value.

        Raises:
            ValueError: If ``n_events_per_part`` is not positive.
        """
        if n_events_per_part <= 0:
            raise ValueError(f"n_events_per_part must be positive, got {n_events_per_part}")

        self.source_dir = source_dir
        self.flavour = flavour
        self.energy_cutoff = energy_cutoff
        self.subdir_low = os.path.join(source_dir, EnergyRange.get_subdir(energy_range_low, flavour))
        self.subdir_high = os.path.join(source_dir, EnergyRange.get_subdir(energy_range_high, flavour))
        self.subdir_combined = os.path.join(source_dir, EnergyRange.get_subdir(energy_range_combined, flavour))
        self.n_events_per_part = n_events_per_part
        self.n_events_combined = None  # set by _set_balanced_event_limit
        self.n_parts = None  # set by _split_table_into_parts

    def __call__(self):
        """Run ``blend``; lets an instance be used as ``TruthBlender(...)()``."""
        self.blend()

    def blend(self) -> None:
        """Read, balance, interleave, split and write the truth tables.

        Raises:
            FileNotFoundError: If either input sub-directory has no truth file.
        """
        truth_files_low = self._get_truth_file_list(self.subdir_low)
        truth_files_high = self._get_truth_file_list(self.subdir_high)
        if not truth_files_low:
            raise FileNotFoundError(f"No truth parquet files found in {self.subdir_low}")
        if not truth_files_high:
            raise FileNotFoundError(f"No truth parquet files found in {self.subdir_high}")

        # Only the low-energy sample is cut at energy_cutoff; the high-energy
        # sample is used as is.
        tables_low = [self._filter_truth_table(self._get_truth_table(f)) for f in truth_files_low]
        tables_high = [self._get_truth_table(f) for f in truth_files_high]

        full_table_low = pa.concat_tables(tables_low)
        full_table_high = pa.concat_tables(tables_high)

        # Fixes the total event count of the blended table (2 * min(low, high)).
        self._set_balanced_event_limit(full_table_low, full_table_high)
        combined_table = self._update_subdir_and_combine(full_table_low, full_table_high)
        table_parts = self._split_table_into_parts(combined_table)

        for part_no, table_part in table_parts:
            self._write_combined_table(table_part, part_no)

    def _get_truth_file_list(self, subdir: str) -> List[str]:
        """Return the ``truth_*.parquet`` paths directly inside ``subdir``, sorted.

        ``os.listdir`` returns entries in arbitrary order; sorting makes the
        row order of the concatenated table, and therefore the selection made
        by the later truncation, reproducible between runs.
        """
        return sorted(
            os.path.join(subdir, f)
            for f in os.listdir(subdir)
            if f.startswith("truth_") and f.endswith(".parquet")
        )

    def _get_truth_table(self, truth_file: str) -> pa.Table:
        """Read one truth parquet file completely."""
        return pq.read_table(truth_file)

    def _filter_truth_table(self, table: pa.Table) -> pa.Table:
        """Keep the rows whose ``energy`` is strictly greater than ``energy_cutoff``."""
        return table.filter(pc.greater(table["energy"], self.energy_cutoff))

    def _get_n_events_truth(self, table: pa.Table) -> int:
        """Return the number of rows (events) of a truth table."""
        return table.num_rows

    def _set_balanced_event_limit(self, table_low: pa.Table, table_high: pa.Table) -> None:
        """Set ``n_events_combined`` to twice the size of the smaller sample."""
        n_low = self._get_n_events_truth(table_low)
        n_high = self._get_n_events_truth(table_high)
        self.n_events_combined = 2 * min(n_low, n_high)

    def _update_subdir_and_combine(self, table_low: pa.Table, table_high: pa.Table) -> pa.Table:
        """Interleave the first ``n_events_combined // 2`` rows of both tables.

        Row ``2*i`` of the result is row ``i`` of ``table_low`` and row
        ``2*i + 1`` is row ``i`` of ``table_high``. The ``subdirectory_no``
        column, if present, is overwritten with the integer formed by the last
        two characters of the combined sub-directory name.

        Returns:
            A table with the schema of ``table_low``.
        """
        n = self.n_events_combined // 2
        schema = table_low.schema

        sliced_low = table_low.slice(0, n)
        # Columns are matched by name, so align the high sample to the low
        # sample's column order before stacking.
        sliced_high = table_high.slice(0, n).select(schema.names)
        stacked = pa.concat_tables([sliced_low, sliced_high])

        # Row i of the low block sits at index i, row i of the high block at
        # index n + i; reading the (2, n) index grid column by column yields
        # 0, n, 1, n+1, ... in a single vectorised gather.
        order = np.arange(2 * n, dtype=np.int64).reshape(2, n).T.ravel()
        combined = stacked.take(pa.array(order))

        tag_index = schema.get_field_index("subdirectory_no")
        if tag_index != -1:
            # normpath guards against a trailing separator, which would make
            # basename return an empty string.
            subdir_name = os.path.basename(os.path.normpath(self.subdir_combined))
            subdir_tag = int(subdir_name[-2:])
            tag_field = schema.field(tag_index)
            tag_column = pa.array([subdir_tag] * combined.num_rows).cast(tag_field.type)
            combined = combined.set_column(tag_index, tag_field, tag_column)

        return combined.replace_schema_metadata(schema.metadata)

    def _split_table_into_parts(self, table: pa.Table) -> List[Tuple[int, pa.Table]]:
        """Cut ``table`` into consecutive parts of at most ``n_events_per_part`` rows.

        Parts are numbered from 1; the ``part_no`` column of every part, if
        present, is overwritten with its part number. Also records the number
        of parts in ``self.n_parts``.

        Returns:
            A list of ``(part_no, part_table)`` tuples; empty for an empty table.
        """
        parts = []
        part_index = table.schema.get_field_index("part_no")
        for part_no, start in enumerate(range(0, table.num_rows, self.n_events_per_part), start=1):
            part = table.slice(start, self.n_events_per_part)
            if part_index != -1:
                part_field = table.schema.field(part_index)
                part_column = pa.array([part_no] * part.num_rows).cast(part_field.type)
                part = part.set_column(part_index, part_field, part_column)
            parts.append((part_no, part))
        self.n_parts = len(parts)
        return parts

    def _write_combined_table(self, combined_table: pa.Table, part_no: int) -> None:
        """Write one part as ``truth_<part_no>.parquet`` into the combined sub-directory."""
        # subdir_combined already contains source_dir; joining the two again
        # would duplicate the prefix whenever source_dir is a relative path.
        output_dir = self.subdir_combined
        os.makedirs(output_dir, exist_ok=True)

        output_file = os.path.join(output_dir, f"truth_{part_no}.parquet")
        pq.write_table(combined_table, output_file)
        print(f"[✔] Wrote part {part_no}: {output_file}")

In [17]:
# TruthBlender(
#     source_dir=root_dir,
#     energy_range_low=EnergyRange.ER_10_TEV_1_PEV,
#     energy_range_high=EnergyRange.ER_1_PEV_100_PEV,
#     energy_range_combined=EnergyRange.ER_100_TEV_100_PEV,
#     flavour=Flavour.E,
# )()

In [18]:
# Load truth part 1 from the 22032 directory and display it for inspection.
check_this_out=convertParquetToDF("/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied_filtered_second_round/Snowstorm/CC_CRclean_IntraTravelDistance_250/22032/truth_1.parquet")
check_this_out

Unnamed: 0,event_no,original_event_no,subdirectory_no,part_no,shard_no,N_doms,offset,energy,azimuth,zenith,...,TotalColumnDepthCGS,TotalPrimaryWeight,TotalWeight,TotalXsectionCGS,TrueActiveLengthAfter,TrueActiveLengthBefore,TypeWeight,max_interPMT_distance,isWithinIceCube,lepton_intra_distance
0,114000400665696,665696,32,1,1,170,264,1.312122e+05,0.045589,2.148992,...,1.844175e+06,1.0,2.700705e-05,8.891057e-35,1286.661743,7780.790527,0.5,456.220154,0,1018.384155
1,115000600138754,138754,32,1,1,93,93,2.677757e+06,3.056381,1.249584,...,1.828938e+05,1.0,1.629861e-04,1.490586e-33,1012.475159,972.081787,0.5,681.626709,0,1059.699829
2,114000400665728,665728,32,1,1,215,713,3.105628e+05,4.643663,2.557549,...,2.281205e+05,1.0,1.880685e-05,5.151188e-34,1215.146362,1062.376099,0.5,672.177246,0,1158.643921
3,115000600138769,138769,32,1,1,394,487,7.291569e+07,1.376329,1.323334,...,1.762801e+05,1.0,6.241383e-04,5.923431e-33,858.934814,1053.857788,0.5,1042.778198,1,793.928467
4,114000400665735,665735,32,1,1,64,777,1.391510e+05,2.550350,1.309242,...,1.812613e+05,1.0,3.725267e-05,3.437636e-34,837.095276,1129.747925,0.5,503.852081,1,291.993256
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29995,115000100009693,9693,32,1,5,104,50008,1.748889e+06,0.177174,2.352001,...,1.459540e+05,1.0,1.973203e-06,3.076522e-34,806.530273,777.197937,0.5,570.022400,0,413.552460
29996,114000501073150,1073150,32,1,10,40,53231,3.005953e+05,4.963473,0.332415,...,2.952687e+05,1.0,8.519232e-05,4.826198e-34,920.709412,2145.045410,0.5,410.646149,1,458.765808
29997,115000100009705,9705,32,1,5,459,50467,2.131569e+06,4.971455,1.715046,...,1.765080e+05,1.0,3.550451e-05,1.232124e-33,976.119202,939.145996,0.5,792.833679,0,989.811951
29998,114000501073211,1073211,32,1,10,73,53942,2.214562e+05,1.991865,2.969661,...,1.931556e+05,1.0,8.019437e-08,3.930128e-35,977.781128,950.651611,0.5,408.476654,0,1009.899963


In [19]:
# Load truth part 2 from the 22032 directory and display it for inspection.
check_that_out=convertParquetToDF("/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied_filtered_second_round/Snowstorm/CC_CRclean_IntraTravelDistance_250/22032/truth_2.parquet")
check_that_out

Unnamed: 0,event_no,original_event_no,subdirectory_no,part_no,shard_no,N_doms,offset,energy,azimuth,zenith,...,TotalColumnDepthCGS,TotalPrimaryWeight,TotalWeight,TotalXsectionCGS,TrueActiveLengthAfter,TrueActiveLengthBefore,TypeWeight,max_interPMT_distance,isWithinIceCube,lepton_intra_distance
0,114000501073264,1073264,32,2,10,17,54461,2.834743e+05,2.282027,0.516831,...,218898.625000,1.0,0.000066,5.034060e-34,945.175781,1240.266602,0.5,369.885345,0,630.471191
1,115000100009730,9730,32,2,5,420,51009,2.675333e+07,5.297457,1.750105,...,173427.281250,1.0,0.000014,1.802094e-33,986.350708,895.486450,0.5,1059.482910,0,884.719482
2,114000501073314,1073314,32,2,10,129,54845,1.204046e+05,0.772038,0.325529,...,294574.843750,1.0,0.000048,2.750934e-34,924.179382,2134.377686,0.5,549.604492,1,856.163269
3,115000100009734,9734,32,2,5,441,51450,1.970146e+07,1.448613,1.814812,...,165956.421875,1.0,0.000024,3.112299e-33,928.880920,871.890747,0.5,751.713745,0,822.114319
4,114000501073321,1073321,32,2,10,76,54921,1.597037e+05,3.924697,0.531129,...,323689.625000,1.0,0.000064,3.289485e-34,1242.887939,2118.064453,0.5,478.327332,1,262.240906
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19709,115001100305006,305006,32,2,14,498,50743,2.097274e+07,1.342468,1.875471,...,173964.093750,1.0,0.000026,3.185825e-33,943.437805,944.224182,0.5,894.077026,0,887.045410
19710,114000200371206,371206,32,2,8,990,107760,2.570231e+05,2.880535,2.515745,...,213615.656250,1.0,0.000014,4.144147e-34,1175.330566,938.955994,0.5,972.591125,0,782.181641
19711,115001100305096,305096,32,2,14,32,50775,2.049648e+06,0.876865,1.011502,...,358686.625000,1.0,0.000305,1.420908e-33,878.156494,3078.761475,0.5,270.964844,0,765.137695
19712,114000200371210,371210,32,2,8,202,107962,6.073934e+05,0.327570,0.898622,...,181530.000000,1.0,0.000080,7.397876e-34,862.950500,1106.808105,0.5,574.502625,1,439.953613


In [20]:
class ShardBlender:
    """Re-assemble PMTfied shard files so that they follow the blended truth parts.

    For every ``truth_<part>.parquet`` file in the combined sub-directory the
    rows of each listed event are copied out of the source file
    ``<source subdir>/<part_no>/PMTfied_<shard_no>.parquet`` and written, in
    truth order, to ``<combined subdir>/<part>/PMTfied_<n>.parquet`` with
    ``n_events_per_shard`` events per new shard.

    The rows of an event are located as ``[offset - N_doms, offset)`` inside
    its source shard, using the values recorded in the source truth files.

    NOTE(review): this class does not rewrite ``shard_no``/``offset`` in the
    combined truth files; confirm whether downstream readers need them to
    match the new shards.
    """

    # Upper bound on the number of source PMTfied tables kept in memory at
    # the same time while one truth part is processed.
    _MAX_CACHED_TABLES = 8

    def __init__(self,
                 source_dir: str,
                 energy_range_low: EnergyRange,
                 energy_range_high: EnergyRange,
                 energy_range_combined: EnergyRange,
                 flavour: Flavour,
                 n_events_per_shard: int = 3000,
                 max_shards_per_part: Union[int, None] = None):
        """Resolve the sub-directories and index every source event.

        Args:
            source_dir: Directory containing the per-dataset sub-directories.
            energy_range_low: Energy range of the low-energy input sample.
            energy_range_high: Energy range of the high-energy input sample.
            energy_range_combined: Energy range identifying the combined
                sub-directory (read for truth parts, written for new shards).
            flavour: Neutrino flavour, passed to ``EnergyRange.get_subdir``.
            n_events_per_shard: Number of events per newly written shard.
            max_shards_per_part: Optional cap on the number of full shards
                produced per truth part, useful for quick trial runs. ``None``
                (default) processes every event.

        Raises:
            ValueError: If ``n_events_per_shard`` or ``max_shards_per_part``
                is not positive.
        """
        if n_events_per_shard <= 0:
            raise ValueError(f"n_events_per_shard must be positive, got {n_events_per_shard}")
        if max_shards_per_part is not None and max_shards_per_part <= 0:
            raise ValueError(f"max_shards_per_part must be positive or None, got {max_shards_per_part}")

        self.source_dir = source_dir
        self.flavour = flavour

        # Directories holding the original PMTfied files (low/high energy)
        self.subdir_low = os.path.join(source_dir, EnergyRange.get_subdir(energy_range_low, flavour))
        self.subdir_high = os.path.join(source_dir, EnergyRange.get_subdir(energy_range_high, flavour))
        # Combined output directory from truth blender
        self.subdir_combined = os.path.join(source_dir, EnergyRange.get_subdir(energy_range_combined, flavour))

        self.n_events_per_shard = n_events_per_shard
        self.max_shards_per_part = max_shards_per_part
        self._build_event_identification()
        # Reserved for tracking the usage state of each PMTfied file,
        # e.g. {"PMTfied_1.parquet": {"used": 0, "total": 1200}, ...};
        # no method of this class fills or reads it yet.
        self.file_usage: Dict[str, Dict[str, int]] = {}

    def __call__(self):
        """Run ``blend``; lets an instance be used as ``ShardBlender(...)()``."""
        self.blend()

    def blend(self) -> None:
        """Build and write the new PMTfied shards for every combined truth part."""
        truth_files = self._get_truth_file_list(self.subdir_combined)
        for truth_file in tqdm(truth_files, desc="Processing truth parts"):
            truth_table = self._get_truth_table(truth_file)
            event_nos = truth_table.column("event_no").to_pylist()

            new_shard_tables = self._assemble_shard_from_truth(event_nos)

            for shard_no, shard_table in new_shard_tables:
                self._write_new_pmtfied(shard_table, truth_file, shard_no)

    def _get_truth_file_list(self, subdir: str) -> List[str]:
        """Return the ``truth_*.parquet`` paths directly inside ``subdir``, sorted."""
        return sorted(
            os.path.join(subdir, f)
            for f in os.listdir(subdir)
            if f.startswith("truth_") and f.endswith(".parquet")
        )

    def _get_truth_table(self, truth_file: str) -> pa.Table:
        """Read one truth parquet file completely."""
        return pq.read_table(truth_file)

    def _assemble_shard_from_truth(self, event_nos: List[int]) -> List[Tuple[int, pa.Table]]:
        """
        For a given list of event numbers (from one truth part),
        extract corresponding rows from PMTfied files and reassemble into new shards.
        Returns a list of (new_shard_no, shard_table) tuples.

        Events that are unknown to the source truth files, or whose PMTfied
        file does not exist, are reported and skipped.
        """
        new_shards = []
        current_rows = []
        current_event_nos = set()
        shard_index = 1
        # Source tables already read for this truth part, keyed by file path.
        table_cache: Dict[str, pa.Table] = {}

        for event_no in tqdm(event_nos, desc="Processing events for shard assembly"):
            identification = self.event_identification.get(event_no)
            if identification is None:
                # Without this guard the offset arithmetic below would fail on None.
                print(f"[!] Event {event_no} not found in any source truth file; skipping")
                continue
            part_no, shard_no, offset, n_dom = identification
            # offset marks the end of the event's rows, so the event starts
            # n_dom rows earlier.
            start_index = offset - n_dom

            candidate_file = self._locate_pmtfied_file(event_no, part_no, shard_no)
            if candidate_file is None:
                print(f"[!] PMTfied file not found for event {event_no}: part {part_no}, shard {shard_no}")
                continue

            pmt_table = self._read_pmtfied_cached(candidate_file, table_cache)
            current_rows.append(pmt_table.slice(start_index, n_dom))
            current_event_nos.add(event_no)

            if len(current_event_nos) >= self.n_events_per_shard:
                new_shards.append((shard_index, self._finalize_shard(current_rows)))
                shard_index += 1
                current_rows = []
                current_event_nos = set()
                if self.max_shards_per_part is not None and shard_index > self.max_shards_per_part:
                    break

        # Final shard (if leftover)
        if current_rows:
            new_shards.append((shard_index, self._finalize_shard(current_rows)))

        return new_shards

    def _locate_pmtfied_file(self, event_no: int, part_no: int, shard_no: int) -> Union[str, None]:
        """Return the path of the source PMTfied file of an event, or None.

        The low and the high sub-directory both use the same part and shard
        numbering, so the file must be looked up in the sub-directory whose
        truth file listed the event. Both sub-directories are probed only when
        that origin was not recorded.
        """
        origin = getattr(self, "_event_source_subdir", {}).get(event_no)
        search_order = [origin] if origin is not None else [self.subdir_low, self.subdir_high]
        for subdir in search_order:
            candidate_file = os.path.join(subdir, str(part_no), f"PMTfied_{shard_no}.parquet")
            if os.path.exists(candidate_file):
                return candidate_file
        return None

    def _read_pmtfied_cached(self, path: str, cache: Dict[str, pa.Table]) -> pa.Table:
        """Read a PMTfied file once and serve later requests from ``cache``.

        At most ``_MAX_CACHED_TABLES`` tables are kept; the entry inserted
        first is dropped when the limit is reached.
        """
        table = cache.get(path)
        if table is None:
            if len(cache) >= self._MAX_CACHED_TABLES:
                # dicts keep insertion order, so the first key is the oldest entry.
                cache.pop(next(iter(cache)))
            table = pq.read_table(path)
            cache[path] = table
        return table

    def _finalize_shard(self, event_tables: List[pa.Table]) -> pa.Table:
        """Concatenate per-event slices into one shard table with fresh offsets."""
        # combine_chunks copies the selected rows into contiguous arrays, so
        # the shard no longer keeps the full source tables alive.
        shard_table = pa.concat_tables(event_tables).combine_chunks()
        return self._recalculate_offsets(shard_table)

    def _build_event_identification(self):
        """
        Build a mapping from event_no to (part_no, shard_no, offset, N_doms)
        using the `truth_{part}.parquet` files found in subdir_low and subdir_high.

        The sub-directory each event was found in is stored separately in
        ``_event_source_subdir`` so that the 4-tuple layout of
        ``event_identification`` stays unchanged.
        """
        self.event_identification = {}
        self._event_source_subdir = {}

        for subdir in [self.subdir_low, self.subdir_high]:
            for truth_path in self._get_truth_file_list(subdir):
                # Extract part number from filename (e.g., "truth_2.parquet" → 2)
                part_no = int(os.path.basename(truth_path).split("_")[1].split(".")[0])
                # Only the four columns used below are read.
                table = pq.read_table(truth_path, columns=["event_no", "offset", "N_doms", "shard_no"])

                event_nos = table["event_no"].to_pylist()
                offsets = table["offset"].to_pylist()
                n_doms = table["N_doms"].to_pylist()
                shard_nos = table["shard_no"].to_pylist()

                for event_no, offset, n_dom, shard_no in zip(event_nos, offsets, n_doms, shard_nos):
                    self.event_identification[event_no] = (part_no, shard_no, offset, n_dom)
                    self._event_source_subdir[event_no] = subdir

    def _recalculate_offsets(self, shard_table: pa.Table) -> pa.Table:
        """
        Given a new PMTfied shard table, recalculate the offset column as the
        running sum of its ``N_doms`` column (computed as int64).

        The ``offset`` column is replaced if it exists and appended otherwise;
        all other columns keep their original type.

        NOTE(review): this is a row-wise cumulative sum and requires the shard
        table to contain an ``N_doms`` column (KeyError otherwise); confirm
        that this matches the PMTfied file layout.
        """
        n_doms = pc.cast(shard_table.column("N_doms"), pa.int64())
        cum_offsets = pc.cumulative_sum(n_doms)
        offset_index = shard_table.schema.get_field_index("offset")
        if offset_index == -1:
            return shard_table.append_column("offset", cum_offsets)
        return shard_table.set_column(offset_index, "offset", cum_offsets)

    def _write_new_pmtfied(self, shard_table: pa.Table, truth_file: str, shard_no: int) -> None:
        """
        Write a newly assembled PMTfied shard to disk with correct directory nesting: subdir_no/part_no/PMTfied_<n>.parquet
        """
        # Extract part_no from file name (e.g., "truth_2.parquet" → 2)
        part_no = int(os.path.basename(truth_file).split("_")[1].split(".")[0])

        # subdir_combined already is <source_dir>/<subdir_no>; using it
        # directly also works when the path carries a trailing separator.
        output_dir = os.path.join(self.subdir_combined, str(part_no))
        os.makedirs(output_dir, exist_ok=True)

        output_file = os.path.join(output_dir, f"PMTfied_{shard_no}.parquet")
        pq.write_table(shard_table, output_file)
        print(f"[✔] Wrote PMTfied shard {shard_no} to {output_file}")



In [21]:
# df_12_1_1 = convertParquetToDF(pmt_dir_12_1 + "PMTfied_1.parquet")
# df_12_1_2 = convertParquetToDF(pmt_dir_12_1 + "PMTfied_2.parquet")
# df_12_truth_1 = convertParquetToDF(truth_12_1)


In [22]:
# ShardBlender(
#     source_dir=root_dir,
#     energy_range_low=EnergyRange.ER_10_TEV_1_PEV,
#     energy_range_high=EnergyRange.ER_1_PEV_100_PEV,
#     energy_range_combined=EnergyRange.ER_100_TEV_100_PEV,
#     flavour=Flavour.E,
#     n_events_per_shard=3000
# )()