In [None]:
import os
import sys
import pandas as pd
import pyarrow as pa
import pyarrow.csv as pc
import pyarrow.parquet as pq
from abc import ABC, abstractmethod
import logging
from tqdm import tqdm

In [3]:
# Root of the PMTfied parquet tree. The trailing slash lets relative paths be appended directly.
PMTfied_dir = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/"

# Snowstorm 22018: truth table of part 1 and one PMTfied file from its "1/" folder.
truth_18_1 = f"{PMTfied_dir}Snowstorm/22018/truth_1.parquet"
pmtfied_18_1_1 = f"{PMTfied_dir}Snowstorm/22018/1/PMTfied_1.parquet"

# Truth tables from two further Snowstorm subdirectories (22011 part 2, 22014 part 1).
truth_11_2 = f"{PMTfied_dir}Snowstorm/22011/truth_2.parquet"
truth_14_1 = f"{PMTfied_dir}Snowstorm/22014/truth_1.parquet"

In [4]:

# Input tree (unfiltered PMTfied data) and output tree (filtered copies).
source_root = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/"
dest_root = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied_filtered/"

# Per-dataset input/output directories; every value keeps its trailing slash.
snowstorm_source_dir = f"{source_root}Snowstorm/"
snowstorm_dest_dir = f"{dest_root}Snowstorm/"
corsika_source_dir = f"{source_root}Corsika/"
corsika_dest_dir = f"{dest_root}Corsika/"

# "99999" folders below the Snowstorm output tree.
dir_99999 = f"{snowstorm_dest_dir}99999/"
dir_99999_98 = f"{dir_99999}98/"
dir_99999_99 = f"{dir_99999}99/"

# "9999999-9999999" folders below the Corsika output tree.
dir_99999_Corsika = f"{corsika_dest_dir}9999999-9999999/"
dir_99999_Corsika_96 = f"{dir_99999_Corsika}96/"
dir_99999_Corsika_97 = f"{dir_99999_Corsika}97/"

In [None]:
# Directory holding the "clean event" CSV specifiers, plus two example files from it.
pure_nu_specifier_dir = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/clean_events_dict/"
pure_nu_specifier_18_1 = f"{pure_nu_specifier_dir}2018/22018/clean_event_ids_0000000-0000999.csv"
pure_nu_specifier_11_1 = f"{pure_nu_specifier_dir}2011/22011/clean_event_ids_0000000-0000999.csv"


In [6]:
# Output tree for filtered data and the "PureNeutrinos/" folder inside it.
filtered_dir = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied_filtered/"
selected_events_dir = f"{filtered_dir}PureNeutrinos/"

In [7]:
def convertParquetToDF(file:str) -> pd.DataFrame:
    """
    Load a parquet file and return its contents as a pandas DataFrame.

    Args:
        file (str): Path to the parquet file.

    Returns:
        pd.DataFrame: The file read with pyarrow and converted via ``Table.to_pandas()``.
    """
    return pq.read_table(file).to_pandas()

In [8]:
# Load the three truth tables defined above into pandas DataFrames for inspection.
df_truth_18_1 = convertParquetToDF(truth_18_1)
df_truth_11_2 = convertParquetToDF(truth_11_2)
df_truth_14_1 = convertParquetToDF(truth_14_1)

In [9]:
# List the column names available in the 22011 part-2 truth table.
df_truth_11_2.columns   

Index(['event_no', 'original_event_no', 'subdirectory_no', 'part_no',
       'shard_no', 'N_doms', 'offset', 'energy', 'azimuth', 'zenith', 'pid',
       'event_time', 'interaction_type', 'elasticity', 'RunID', 'SubrunID',
       'EventID', 'SubEventID', 'dbang_decay_length', 'track_length',
       'stopped_muon', 'energy_track', 'energy_cascade', 'inelasticity',
       'DeepCoreFilter_13', 'CascadeFilter_13', 'MuonFilter_13',
       'OnlineL2Filter_17', 'L3_oscNext_bool', 'L4_oscNext_bool',
       'L5_oscNext_bool', 'L6_oscNext_bool', 'L7_oscNext_bool',
       'Homogenized_QTot', 'MCLabelClassification', 'MCLabelCoincidentMuons',
       'MCLabelBgMuonMCPE', 'MCLabelBgMuonMCPECharge',
       'GNLabelTrackEnergyDeposited', 'GNLabelTrackEnergyOnEntrance',
       'GNLabelTrackEnergyOnEntrancePrimary',
       'GNLabelTrackEnergyDepositedPrimary', 'GNLabelEnergyPrimary',
       'GNLabelCascadeEnergyDepositedPrimary', 'GNLabelCascadeEnergyDeposited',
       'GNLabelEnergyDepositedTotal', 'GN

In [10]:
# Show the 'event_no' column of the 22011 part-2 truth table.
df_truth_11_2['event_no']

0         111000200400561
1         111000200400562
2         111000200400563
3         111000200400564
4         111000200400565
               ...       
402680    111000200803241
402681    111000200803242
402682    111000200803243
402683    111000200803244
402684    111000200803245
Name: event_no, Length: 402685, dtype: int64

In [11]:
# Distinct RunID values in ascending order.
df_truth_11_2['RunID'].sort_values().unique()

array([2201100014, 2201100015, 2201100016, 2201100017, 2201100018,
       2201100019, 2201100020, 2201100024, 2201100045, 2201100049,
       2201100050, 2201100054, 2201100055, 2201100056, 2201100057,
       2201100058, 2201100084, 2201100085, 2201100086, 2201100087,
       2201100088, 2201100089, 2201100090, 2201100094, 2201100095,
       2201100096, 2201100097, 2201100098, 2201100099, 2201100100,
       2201100104, 2201100105, 2201100118, 2201100139, 2201100140,
       2201100144, 2201100145, 2201100146, 2201100147, 2201100148,
       2201100149, 2201100150, 2201100154, 2201100155, 2201100156,
       2201100157, 2201100159, 2201100160, 2201100196, 2201100197,
       2201100198, 2201100199, 2201100200, 2201100204, 2201100205,
       2201100206, 2201100207, 2201100208, 2201100209, 2201100210,
       2201100214, 2201100215, 2201100216, 2201100217, 2201100218,
       2201100219, 2201100220, 2201100254, 2201100255, 2201100256,
       2201100257, 2201100258, 2201100259, 2201100260, 2201100

In [12]:
# Number of distinct values in the RunID and EventID columns (counted per column, not per pair).
df_truth_11_2[['RunID', 'EventID']].nunique()

RunID       343
EventID    8000
dtype: int64

In [13]:
# Count how many rows carry each MuonFilter_13 value.
df_truth_11_2['MuonFilter_13'].value_counts()

MuonFilter_13
1    321550
0     81135
Name: count, dtype: int64

In [14]:
class EventFilter(ABC):
    """Abstract interface for objects that derive one pyarrow table from another."""

    @abstractmethod
    def sort(self, pa_table: pa.Table) -> pa.Table:
        """
        Produce the table selected by this filter.

        Despite the method name, the subclasses defined in this notebook select
        rows rather than reorder them.

        Args:
            pa_table (pa.Table): Input table.

        Returns:
            pa.Table: The resulting table, as defined by the subclass.
        """
    

In [15]:
class MuonFilter13(EventFilter):
    """Keep only the rows whose 'MuonFilter_13' column equals 1."""

    def sort(self, pa_table: pa.Table) -> pa.Table:
        """
        Select the rows that passed MuonFilter_13.

        Args:
            pa_table (pa.Table): Table containing a 'MuonFilter_13' column.

        Returns:
            pa.Table: Rows where 'MuonFilter_13' == 1. Rows with a null flag are
            dropped (default null handling of ``Table.filter``).
        """
        # Imported locally because the notebook binds `pc` to pyarrow.csv.
        import pyarrow.compute as pcompute

        # `column == 1` is not an element-wise comparison on a ChunkedArray; it
        # evaluates to a plain Python bool. compute.equal builds the row mask
        # that Table.filter needs.
        passed_mask = pcompute.equal(pa_table.column('MuonFilter_13'), 1)
        return pa_table.filter(passed_mask)

In [16]:
class RunIdEventIdFilter(EventFilter):
    """Keep only the rows whose (RunID, EventID) pair is listed in a CSV file."""

    def __init__(self, file:str):
        """
        Load the allowed (RunID, EventID) pairs.

        Args:
            file (str): Path to a CSV file with 'RunID' and 'EventID' columns.
        """
        frame = pd.read_csv(file)
        # Collect the pairs before indexing: once the two columns become the
        # index, frame['RunID'] / frame['EventID'] are no longer columns.
        self._pairs = set(zip(frame['RunID'].tolist(), frame['EventID'].tolist()))
        # Kept for backward compatibility with code that reads `self.df`.
        self.df = frame.set_index(['RunID', 'EventID'])

    def sort(self, pa_table: pa.Table) -> pa.Table:
        """
        Select the rows whose (RunID, EventID) pair appears in the CSV.

        The two IDs are matched together as a pair. Testing each column
        separately would also accept combinations that never occur in the CSV.

        Args:
            pa_table (pa.Table): Table containing 'RunID' and 'EventID' columns.

        Returns:
            pa.Table: The matching rows, in their original order.
        """
        run_ids = pa_table.column('RunID').to_pylist()
        event_ids = pa_table.column('EventID').to_pylist()
        keep = [pair in self._pairs for pair in zip(run_ids, event_ids)]
        # Explicit bool type so an empty table still yields a valid mask.
        return pa_table.filter(pa.array(keep, type=pa.bool_()))

In [17]:
# def read_pure_neutrino_event_specifiers(pure_nu_specifier_dir:str,
#                                         subdir_no:int):
#     dir_path = os.path.join(pure_nu_specifier_dir, str(subdir_no), str(subdir_no), 'reduced')
#     files = os.listdir(dir_path)
#     print(len(files))
#     print(files)
    

In [18]:
# read_pure_neutrino_event_specifiers(pure_nu_specifier_dir, 22018)

In [19]:
# def build_new_parquet(source_dir:str,
#                     subdir_no: int, 
#                     part_no: int,
#                     dest_dir:str,
#                     sorting_hat: RunIdEventIdFilter,
#                     ) -> None:
#     source_truth_file = source_dir + str(subdir_no) + "truth_" + str(part_no) + ".parquet"
#     source_pmtfied_dir = source_dir + str(subdir_no) + "/" + str(part_no) + "/"
#     source_pmtfied_files = os.listdir(source_pmtfied_dir)
#     dest_pmtfied_dir = dest_dir + str(subdir_no) + "/" + str(part_no) + "/"
#     print(f"source_truth_file: {source_truth_file}")
#     if not os.path.exists(dest_pmtfied_dir):
#         os.makedirs(dest_pmtfied_dir)
#     for file in source_pmtfied_files:
#         print(file)
    

In [20]:
# build_new_parquet(source_dir=snowstorm_source_dir,
#                 subdir_no=22018, 
#                 part_no=1,
#                 dest_dir=snowstorm_dest_dir,
#                 sorting_hat=None,
#                 )

In [27]:
def get_event_specifier_files(csv_path: str) -> list:
    """
    List the event-specifier CSV files in a directory with their ID ranges.

    Only files named ``clean_event_ids_<min>-<max>.csv`` (both parts purely
    numeric) are considered. Any other name is skipped, including one that has
    the right prefix and suffix but a malformed range.

    Args:
        csv_path (str): Directory to scan.

    Returns:
        list: ``(min_id, max_id, full_path)`` tuples sorted ascending, so the
        result does not depend on the order returned by ``os.listdir``.
    """
    prefix, suffix = "clean_event_ids_", ".csv"
    file_ranges = []

    for name in os.listdir(csv_path):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue

        low_text, dash, high_text = name[len(prefix):-len(suffix)].partition("-")
        if not (dash and low_text.isdecimal() and high_text.isdecimal()):
            # e.g. "clean_event_ids_backup.csv": not a range file, so ignore it
            # instead of letting int() abort the whole listing.
            continue

        file_ranges.append((int(low_text), int(high_text), os.path.join(csv_path, name)))

    # Tuple ordering sorts by min_id first, with max_id and path as tie-breakers.
    file_ranges.sort()
    return file_ranges


In [None]:
def get_relevant_csv_files(truth_table: pa.Table, 
                           file_ranges: list,
                           run_id_modulus: int = 100000) -> list:
    """
    Pick the specifier CSV files whose ID range overlaps the table's run IDs.

    Each RunID is reduced modulo ``run_id_modulus`` (e.g. 2201100014 -> 14 with
    the default) before it is compared with the file ranges.

    Args:
        truth_table (pa.Table): Table containing a 'RunID' column.
        file_ranges (list): ``(min_id, max_id, path)`` tuples.
        run_id_modulus (int): Modulus applied to every RunID. Defaults to
            100000, the value previously hard-coded here.

    Returns:
        list: Paths of the files whose ``[min_id, max_id]`` interval intersects
        the interval spanned by the reduced run IDs. Empty when the table has no
        non-null RunID.
    """
    effective_run_ids = [
        rid % run_id_modulus
        for rid in truth_table.column("RunID").to_pylist()
        if rid is not None
    ]

    if not effective_run_ids:
        # min()/max() would raise on an empty sequence; nothing can overlap anyway.
        return []

    lowest, highest = min(effective_run_ids), max(effective_run_ids)

    # Two closed intervals overlap when each one starts before the other ends.
    return [path for file_min, file_max, path in file_ranges
            if file_min <= highest and file_max >= lowest]

In [None]:
def load_event_ids_from_csvs(csv_files: list) -> set:
    """
    Collect every (RunID, EventID) pair found in the given CSV files.

    Args:
        csv_files (list): Paths of CSV files with 'RunID' and 'EventID' columns.

    Returns:
        set: Union of the ``(RunID, EventID)`` tuples read from all files.
    """
    collected_pairs = set()

    for path in csv_files:
        csv_table = pc.read_csv(path)
        collected_pairs |= set(zip(csv_table.column("RunID").to_pylist(),
                                   csv_table.column("EventID").to_pylist()))

    return collected_pairs

In [30]:
def get_event_no_range(pmt_file: str) -> tuple:
    """
    Get the minimum and maximum 'event_no' from a PMTfied parquet file.

    Only the 'event_no' column is read from disk.

    Args:
        pmt_file (str): Path to the PMTfied parquet file.

    Returns:
        tuple: (min_event_no, max_event_no), or None when the file has no rows
        or every 'event_no' is null.
    """
    # Imported locally: the notebook binds `pc` to pyarrow.csv, which provides
    # no min/max functions.
    import pyarrow.compute as pcompute

    table = pq.read_table(pmt_file, columns=["event_no"])

    # num_rows is a Table attribute; a column (ChunkedArray) does not have it.
    if table.num_rows == 0:
        return None  # Empty file

    bounds = pcompute.min_max(table.column("event_no"))
    min_event_no = bounds["min"].as_py()
    max_event_no = bounds["max"].as_py()

    if min_event_no is None or max_event_no is None:
        return None  # Column contains only nulls

    return min_event_no, max_event_no


In [31]:
def filter_pmtfied_files(source_pmtfied_dir: str,
                         dest_pmtfied_dir: str,
                         valid_event_nos: set) -> None:
    """
    Filter PMTfied files based on valid 'event_no' values from the truth file.

    Every entry of ``source_pmtfied_dir`` is read as a parquet file. Rows whose
    'event_no' is in ``valid_event_nos`` are written to a file of the same name
    in ``dest_pmtfied_dir``. Files with no matching rows produce no output file.

    Args:
        source_pmtfied_dir (str): Directory containing original PMTfied files.
        dest_pmtfied_dir (str): Directory where filtered PMTfied files will be saved
            (created if missing).
        valid_event_nos (set): Set of valid 'event_no' values. When empty, no
            source file is read and nothing is written.
    """
    os.makedirs(dest_pmtfied_dir, exist_ok=True)

    if not valid_event_nos:
        # min()/max() below would raise on an empty set, and no row could match.
        print("No valid event numbers supplied; no PMTfied files written.")
        return

    # Loop-invariant bounds: computed once rather than rescanning the set per file.
    lowest_valid = min(valid_event_nos)
    highest_valid = max(valid_event_nos)

    # Sorted so that processing order (and the log) is reproducible.
    for file in sorted(os.listdir(source_pmtfied_dir)):
        source_pmtfied_file = os.path.join(source_pmtfied_dir, file)
        dest_pmtfied_file = os.path.join(dest_pmtfied_dir, file)

        # Query min/max 'event_no' for fast rejection
        event_no_range = get_event_no_range(source_pmtfied_file)
        if event_no_range is None:
            print(f"Skipping empty PMTfied file: {file}")
            continue

        min_event_no, max_event_no = event_no_range

        # Quick rejection if there's no overlap
        if min_event_no > highest_valid or max_event_no < lowest_valid:
            print(f"Skipping {file}: No overlap with valid event numbers.")
            continue

        # Load PMTfied parquet file
        pmt_table = pq.read_table(source_pmtfied_file)
        pmt_event_nos = pmt_table.column("event_no").to_pylist()

        # Row positions whose event_no is one of the valid ones
        valid_indices = [i for i, eno in enumerate(pmt_event_nos) if eno in valid_event_nos]

        if valid_indices:
            filtered_pmt_table = pmt_table.take(valid_indices)
            pq.write_table(filtered_pmt_table, dest_pmtfied_file)
            print(f"Filtered PMTfied file saved to: {dest_pmtfied_file}")


In [None]:
def filter_truth_file(source_dir: str,
                      subdir_no: int,
                      part_no: int,
                      pure_nu_specifier_dir: str,
                      output_dir: str) -> None:
    """
    Write filtered copies of one truth file and of its PMTfied files.

    Rows of ``<source_dir>/<subdir_no>/truth_<part_no>.parquet`` are kept when
    their (RunID, EventID) pair appears in the specifier CSVs found under
    ``<pure_nu_specifier_dir>/<subdir_no>/<subdir_no>/reduced``. The surviving
    'event_no' values are then used to filter the PMTfied files in
    ``<source_dir>/<subdir_no>/<part_no>``.

    Nothing is written when no specifier CSV overlaps the truth file or when
    no truth row matches.

    Args:
        source_dir (str): Base directory of the unfiltered data.
        subdir_no (int): Subdirectory number.
        part_no (int): Part number.
        pure_nu_specifier_dir (str): Directory containing CSV event specifiers.
        output_dir (str): Directory to store the filtered files.
    """
    subdir_name = str(subdir_no)
    truth_name = f"truth_{part_no}.parquet"

    # Specifier CSVs available for this subdirectory, with their ID ranges
    specifier_ranges = get_event_specifier_files(
        os.path.join(pure_nu_specifier_dir, subdir_name, subdir_name, 'reduced'))

    truth_table = pq.read_table(os.path.join(source_dir, subdir_name, truth_name))

    # Narrow down to the CSVs that can contain events of this truth file
    matching_csvs = get_relevant_csv_files(truth_table, specifier_ranges)
    if not matching_csvs:
        print(f"No matching CSV files found for {subdir_no}/{part_no}")
        return

    wanted_pairs = load_event_ids_from_csvs(matching_csvs)

    # Row positions whose (RunID, EventID) pair is wanted; the unmodified RunID is compared
    id_pairs = zip(truth_table.column("RunID").to_pylist(),
                   truth_table.column("EventID").to_pylist())
    kept_rows = [row for row, pair in enumerate(id_pairs) if pair in wanted_pairs]
    if not kept_rows:
        print(f"No valid events in {subdir_no}/{part_no}. Skipping truth file saving.")
        return

    kept_truth = truth_table.take(kept_rows)
    kept_event_nos = set(kept_truth.column("event_no").to_pylist())

    # Save the reduced truth table under the same relative path in the output tree
    output_truth_file = os.path.join(output_dir, subdir_name, truth_name)
    os.makedirs(os.path.dirname(output_truth_file), exist_ok=True)
    pq.write_table(kept_truth, output_truth_file)
    print(f"Filtered truth file saved to: {output_truth_file}")

    # Apply the same event selection to the PMTfied files of this part
    part_name = str(part_no)
    filter_pmtfied_files(os.path.join(source_dir, subdir_name, part_name),
                         os.path.join(output_dir, subdir_name, part_name),
                         kept_event_nos)


In [33]:
# Filter Snowstorm subdirectory 22011, part 1, writing the result into the Snowstorm output tree.
filter_truth_file(source_dir=snowstorm_source_dir,
                    subdir_no=22011,
                    part_no=1,
                    pure_nu_specifier_dir=pure_nu_specifier_dir,
                    output_dir=snowstorm_dest_dir)

KeyError: 'Field "Run ID" does not exist in schema'