In [None]:
import sqlite3 as sql
import pandas as pd
import sys
import os
import matplotlib.pyplot as plt
import numpy as np

import scipy.optimize as optimize
from scipy import stats as sci
import math
from pprint import pprint

from collections import defaultdict
from typing import List, Tuple
from tqdm import tqdm

import pyarrow as pa
import pyarrow.parquet as pq
# 2 min 22 sec in HEP04

In [2]:
# Make the project-local helper modules in the Utils directory importable.
sys.path.append('/groups/icecube/cyan/Utils')
from PlotUtils import setMplParam, getColour, getHistoParam 
# getHistoParam's return values, unpacked in this order:
# Nbins, binwidth, bins, counts, bin_centers
from DB_lister import list_content, list_tables
from ExternalFunctions import nice_string_output, add_text_to_ax
# Apply the shared plotting settings defined in PlotUtils (presumably global matplotlib rcParams — TODO confirm).
setMplParam()

In [3]:
def get_table_event_count(conn: sql.Connection, table: str) -> int:
    """Return how many distinct `event_no` values `table` contains.

    Args:
        conn: open SQLite connection.
        table: table name; it is interpolated into the SQL, so it must be trusted.
    """
    (n_distinct_events,) = conn.cursor().execute(
        f"SELECT COUNT(DISTINCT event_no) FROM {table}"
    ).fetchone()
    return n_distinct_events

In [4]:
def convertDBtoDF(file: str, table: str, N_events_total: int = None, N_events: int = None) -> pd.DataFrame:
    """Load every row belonging to the first `N_events` distinct events of `table`.

    Args:
        file: path to the SQLite database.
        table: table to read; it must have an `event_no` column.
        N_events_total: number of distinct events in `table`. Optional: when omitted it is
            counted from the table, so callers no longer need to pass it.
        N_events: how many events to load. None, or a value above `N_events_total`, loads all.

    Returns:
        DataFrame with all columns of `table` for the selected events.
    """
    con = sql.connect(file)
    try:
        if N_events_total is None:
            N_events_total = con.execute(
                f'SELECT COUNT(DISTINCT event_no) FROM {table}'
            ).fetchone()[0]
        if N_events is None or N_events > N_events_total:
            N_events = N_events_total

        # Query to fetch the first `N_events` unique event_no values
        event_no_query = f'SELECT DISTINCT event_no FROM {table} LIMIT {N_events}'
        event_nos = pd.read_sql_query(event_no_query, con)['event_no'].tolist()

        # Use the selected event_no values to filter the main data (SQLite accepts an empty IN list)
        event_filter = ','.join(map(str, event_nos))
        query = f'SELECT * FROM {table} WHERE event_no IN ({event_filter})'
        return pd.read_sql_query(query, con)
    finally:
        # Close the connection even if a query fails
        con.close()

In [5]:
def convertDFtoDB(file: str, table: str, df: pd.DataFrame) -> None:
    """Write `df` to `table` in the SQLite database at `file`.

    Any existing table with that name is replaced, and the DataFrame index is not written.
    The connection is closed even if the write fails.
    """
    con = sql.connect(file)
    try:
        df.to_sql(table, con, if_exists='replace', index=False)
    finally:
        con.close()

In [6]:
def load_reference_data(filepath: str) -> np.ndarray:
    """Read the CSV at `filepath` and return its body (header row excluded) as a NumPy array.

    Callers index the columns positionally (string, dom_number, x, y, z in addStringAndDOMtoDB).
    """
    return pd.read_csv(filepath).to_numpy()

In [7]:
def addStringAndDOMtoDB(con_source: sql.Connection, 
                        source_table: str,
                        reference_data: np.ndarray,
                        tolerance_xy: float = 10,
                        tolerance_z: float = 2) -> None:
    """Fill `string` / `dom_number` in `source_table` by matching DOM positions to a reference.

    The two INTEGER columns are added if missing. Only rows where either column is still NULL
    are examined. A row is assigned the FIRST reference entry whose x and y are both within
    `tolerance_xy` and whose z is within `tolerance_z`. Rows with no match, or with a NULL
    dom_x/dom_y/dom_z, are left NULL. Changes are committed.

    Args:
        con_source: open connection to the source database (modified in place).
        source_table: pulse table with dom_x, dom_y, dom_z columns.
        reference_data: array whose columns are, positionally, string, dom_number, x, y, z.
        tolerance_xy: maximum |dx| and |dy| for a match.
        tolerance_z: maximum |dz| for a match.
    """
    cur_source = con_source.cursor()
    cur_source.execute(f"PRAGMA table_info({source_table})")
    existing_columns = {col[1] for col in cur_source.fetchall()}

    # Add `string` and `dom_number` columns if they don't exist
    if 'string' not in existing_columns:
        cur_source.execute(f"ALTER TABLE {source_table} ADD COLUMN string INTEGER")
    if 'dom_number' not in existing_columns:
        cur_source.execute(f"ALTER TABLE {source_table} ADD COLUMN dom_number INTEGER")

    # Select rows where `string` or `dom_number` is NULL
    cur_source.execute(f"SELECT rowid, dom_x, dom_y, dom_z FROM {source_table} WHERE string IS NULL OR dom_number IS NULL")
    rows_to_update = cur_source.fetchall()

    # Hoist the reference columns out of the per-row loop (as float so the arithmetic is vectorised)
    ref_string = reference_data[:, 0]
    ref_dom_number = reference_data[:, 1]
    ref_x = reference_data[:, 2].astype(float)
    ref_y = reference_data[:, 3].astype(float)
    ref_z = reference_data[:, 4].astype(float)

    updates = []
    for row_id, dom_x, dom_y, dom_z in rows_to_update:
        # NULL coordinates cannot be matched (and would raise in the arithmetic below)
        if dom_x is None or dom_y is None or dom_z is None:
            continue
        is_match = (
            (np.abs(ref_x - dom_x) <= tolerance_xy)
            & (np.abs(ref_y - dom_y) <= tolerance_xy)
            & (np.abs(ref_z - dom_z) <= tolerance_z)
        )
        match_idx = np.flatnonzero(is_match)
        if match_idx.size > 0:
            first = match_idx[0]  # first matching reference entry, as before
            updates.append((int(ref_string[first]), int(ref_dom_number[first]), row_id))

    # Apply all updates in one batch, then commit
    if updates:
        cur_source.executemany(f"UPDATE {source_table} SET string = ?, dom_number = ? WHERE rowid = ?", updates)
    con_source.commit()

In [8]:
def getTruthTableNameDB(con_source: sql.Connection) -> str:
    """Return the name of the truth table in the database behind `con_source`.

    'truth' is preferred over 'Truth'.

    Raises:
        ValueError: if neither table exists.
    """
    listing = con_source.cursor().execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = {name for (name,) in listing}
    for candidate in ('truth', 'Truth'):
        if candidate in table_names:
            return candidate
    raise ValueError("Neither 'truth' nor 'Truth' table exists in the source database.")

In [9]:
def addTruthTableDB(con_in: sql.Connection, 
                    con_out: sql.Connection, 
                    N_events: int) -> None:
    """Copy the truth rows of the first `N_events` events from `con_in` into `con_out`.

    The output table has the same name and column names/declared types as the source truth
    table. It is created only if absent; rows are always appended and then committed.
    NOTE(review): running this twice on the same output DB inserts the rows twice.
    """
    truth_table_name = getTruthTableNameDB(con_in)
    reader = con_in.cursor()
    writer = con_out.cursor()

    # First N_events distinct event numbers, then every truth row for them
    reader.execute(f"SELECT DISTINCT event_no FROM {truth_table_name} LIMIT {N_events}")
    selected_events = ','.join(str(event_no) for (event_no,) in reader.fetchall())
    reader.execute(f"SELECT * FROM {truth_table_name} WHERE event_no IN ({selected_events})")
    truth_rows = reader.fetchall()

    # Recreate the column layout (name + declared type) in the output database
    reader.execute(f"PRAGMA table_info({truth_table_name})")
    column_info = reader.fetchall()
    column_defs = ', '.join(f'{info[1]} {info[2]}' for info in column_info)
    writer.execute(f"CREATE TABLE IF NOT EXISTS {truth_table_name} ({column_defs})")

    value_slots = ', '.join('?' for _ in column_info)
    writer.executemany(f"INSERT INTO {truth_table_name} VALUES ({value_slots})", truth_rows)

    con_out.commit()


In [10]:
# Test-data locations; every test directory is derived from `test_dir`.
test_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/"
test_source_DB_dir = f"{test_dir}testSource/"
test_PMTfied_DB_dir = f"{test_dir}PMTfiedDB/"
test_PMTfied_parquet_dir = f"{test_dir}PMTfiedParquet/"

QA_DB = f"{test_source_DB_dir}Level2_NuE_NuGenCCNC.022015.000110.db"
QA_DB_PMTfied = f"{test_PMTfied_DB_dir}voici.db"

# Reference table of string / DOM-number positions used for DOM matching
ref_str_dom_pos = "/groups/icecube/cyan/factory/DOMification/unique_string_dom_completed.csv"

In [11]:
# Inspect which tables the QA source database contains (the table names are shown in the output below).
list_tables(QA_DB)

['truth', 'SRTInIcePulses', 'event_no_SRTInIcePulses', 'OnlineL2_BestFit', 'OnlineL2_SplineMPE', 'LineFit', 'GNHighestEInIceParticle', 'GNHighestEDaughter', 'MCWeightDict', 'SnowStormParameters']


In [12]:
def getTruthPA(con_source: sql.Connection, N_events: int, event_source_table: str = None) -> pa.Table:
    """Return the truth rows of the first `N_events` events as a PyArrow table.

    Args:
        con_source: open connection to the source database.
        N_events: number of distinct events to include; None includes every event.
        event_source_table: optional table from which the event numbers are chosen. By default
            they come from the truth table itself. Pass the pulse table (e.g. "SRTInIcePulses")
            to get truth for exactly the events that getPMTfiedPA selects from that table.
            The truth table may contain events without pulses, so its first N events can differ.

    Returns:
        pa.Table with all columns of the truth table, in the order returned by SELECT *.
    """
    truth_table_name = getTruthTableNameDB(con_source)
    cur_source = con_source.cursor()

    # Select the first N_events unique event_no values (LIMIT -1 means "no limit" in SQLite)
    selection_table = truth_table_name if event_source_table is None else event_source_table
    limit = -1 if N_events is None else int(N_events)
    cur_source.execute(f"SELECT DISTINCT event_no FROM {selection_table} LIMIT {limit}")
    event_nos = [row[0] for row in cur_source.fetchall()]

    # Use these event numbers to filter the truth table (SQLite accepts an empty IN list)
    event_filter = ','.join(map(str, event_nos))
    cur_source.execute(f"SELECT * FROM {truth_table_name} WHERE event_no IN ({event_filter})")
    rows = cur_source.fetchall()

    # Take column names from the executed SELECT so they always line up with the row tuples
    column_names = [description[0] for description in cur_source.description]
    truth_data_dict = {name: [row[i] for row in rows] for i, name in enumerate(column_names)}
    return pa.Table.from_pydict(truth_data_dict)

* This conversion from DB to Parquet inevitably requires the use of a pandas DataFrame.
* It would be much more desirable to convert the data directly from I3 files to PMTfied Parquet.

In [13]:
def getPMTfiedPA(con_source: sql.Connection, 
                source_table: str,
                dom_ref_pos_file: str,
                N_events: int) -> pa.Table:
    """Summarise the pulses of the first `N_events` events of `source_table` into one row per DOM.

    Side effect: `addStringAndDOMtoDB` runs on `source_table` first, so the source database may
    gain (and commit) `string` / `dom_number` columns.

    Args:
        con_source: open connection to the source SQLite database.
        source_table: pulse table (e.g. "SRTInIcePulses"). It must contain event_no, dom_x/y/z,
            dom_time, hlc, charge, pmt_area, rde and is_saturated_dom.
        dom_ref_pos_file: CSV of reference DOM positions, read with `load_reference_data`.
        N_events: number of distinct event_no values to process (the first ones SQLite returns).
            None processes every event.

    Returns:
        pa.Table with one row per (event, DOM); the columns are listed at the bottom of this
        function. A fill value of -1 (0 for sigmaT) marks features that are undefined for
        saturated DOMs or for DOMs with too few pulses.
    """
    dom_ref_pos = load_reference_data(dom_ref_pos_file)
    # Use the default matching tolerances. The 4th positional parameter of addStringAndDOMtoDB
    # is `tolerance_xy`, so N_events must not be passed here.
    addStringAndDOMtoDB(con_source, source_table, dom_ref_pos)

    # Select unique event numbers to use (LIMIT -1 means "no limit" in SQLite)
    cur_source = con_source.cursor()
    limit = -1 if N_events is None else int(N_events)
    cur_source.execute(f'SELECT DISTINCT event_no FROM {source_table} LIMIT {limit}')
    event_nos = [row[0] for row in cur_source.fetchall()]

    # Query rows where event_no matches the selected event numbers
    event_filter = ','.join(map(str, event_nos))
    cur_source.execute(f'SELECT * FROM {source_table} WHERE event_no IN ({event_filter})')
    rows = cur_source.fetchall()

    # Get column names for indexing
    columns = [description[0] for description in cur_source.description]

    event_no_idx = columns.index('event_no')
    dom_string_idx = columns.index('string')
    dom_number_idx = columns.index('dom_number')
    dom_x_idx = columns.index('dom_x')
    dom_y_idx = columns.index('dom_y')
    dom_z_idx = columns.index('dom_z')
    dom_time_idx = columns.index('dom_time')
    dom_hlc_idx = columns.index('hlc')
    dom_charge_idx = columns.index('charge')
    pmt_area_idx = columns.index('pmt_area')
    rde_idx = columns.index('rde')
    saturation_status_idx = columns.index('is_saturated_dom')

    def getQweightedAverageDOMposition(all_pulses_event: List[List[List[float]]]) -> List[float]:
        """Charge-weighted mean position sum(q * r) / sum(q) over every pulse of the event.

        Weighting each pulse by its charge is equivalent to weighting each DOM by its total
        charge. Falls back to the unweighted mean if the total charge is not positive.
        """
        positions = np.array([[pulse[dom_x_idx], pulse[dom_y_idx], pulse[dom_z_idx]]
                              for pulses_dom in all_pulses_event for pulse in pulses_dom], dtype=float)
        charges = np.array([pulse[dom_charge_idx]
                            for pulses_dom in all_pulses_event for pulse in pulses_dom], dtype=float)
        total_charge = charges.sum()
        if total_charge > 0:
            centre = charges @ positions / total_charge
        else:
            centre = positions.mean(axis=0)
        return centre.tolist()

    def getRelativeDOMposition(dom_x: float, dom_y: float, dom_z: float, avg_dom_position: List[float]) -> List[float]:
        return [dom_x - avg_dom_position[0], dom_y - avg_dom_position[1], dom_z - avg_dom_position[2]]

    # NOTE pulses_dom: [pulse, ...]; per-DOM constants are read from the first pulse
    def getDOMposition(pulses_dom: List[List[float]]) -> List[float]:
        return [pulses_dom[0][dom_x_idx], pulses_dom[0][dom_y_idx], pulses_dom[0][dom_z_idx]]

    def getDOMstring(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][dom_string_idx]

    def getDOMnumber(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][dom_number_idx]

    def getPmtArea(pulses_dom: List[List[float]]) -> float:
        return pulses_dom[0][pmt_area_idx]

    def getRDE(pulses_dom: List[List[float]]) -> float:
        return pulses_dom[0][rde_idx]

    def getSaturationStatus(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][saturation_status_idx]

    def getFirstHlc(pulses_dom: List[List[float]]) -> List[int]:
        """HLC flags of the first 3 pulses, padded with -1."""
        n = 3
        _fillIncomplete = -1
        hlc = [pulse[dom_hlc_idx] for pulse in pulses_dom[:n]]
        hlc.extend([_fillIncomplete] * (n - len(hlc)))
        return hlc

    def getFirstPulseTime(pulses_dom: List[List[float]], saturationStatus: int) -> List[float]:
        """Times of the first 3 pulses, padded with -1; all -1 for a saturated DOM."""
        n = 3
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            return [_fillSaturated] * n
        pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom[:n]]
        pulse_times.extend([_fillIncomplete] * (n - len(pulse_times)))
        return pulse_times

    # HACK necessary? (currently unused)
    def getFirstHlcPulseTime(pulses_dom: List[List[float]], saturationStatus: int) -> List[float]:
        """Times of the first 3 HLC pulses, always padded to length 3 with -1."""
        n = 3
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            return [_fillSaturated] * n
        pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom if pulse[dom_hlc_idx] == 1][:n]
        pulse_times.extend([_fillIncomplete] * (n - len(pulse_times)))
        return pulse_times

    def getElapsedTimeUntilChargeFraction(pulses_dom: List[List[float]], saturationStatus: int, percentile1 = 10, percentile2 = 50) -> List[float]:
        """Time after the first pulse at which the cumulative charge first exceeds
        percentile1 % and percentile2 % of the DOM's total charge (-1 where undefined)."""
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            return [_fillSaturated] * 2
        if len(pulses_dom) < 2:
            return [_fillIncomplete] * 2
        Qtotal = sum(pulse[dom_charge_idx] for pulse in pulses_dom)
        t_0 = pulses_dom[0][dom_time_idx]
        Qcum = 0
        T_first, T_second = -1, -1  # stay -1 if the threshold is never crossed
        for pulse in pulses_dom:
            Qcum += pulse[dom_charge_idx]
            if Qcum > percentile1 / 100 * Qtotal and T_first == -1:
                T_first = pulse[dom_time_idx] - t_0
            if Qcum > percentile2 / 100 * Qtotal:
                T_second = pulse[dom_time_idx] - t_0
                break
        return [T_first, T_second]

    def getStandardDeviation(pulse_times: List[float], saturationStatus: int) -> float:
        """Population std of the pulse times; 0 for saturated DOMs or fewer than 2 pulses."""
        # HACK consider changing the fill values
        _fillSaturated = 0
        _fillIncomplete = 0
        if saturationStatus == 1:
            return _fillSaturated
        if len(pulse_times) < 2:
            return _fillIncomplete
        return np.std(pulse_times)

    def getFirstChargeReadout(pulses: List[List[float]], saturationStatus: int) -> List[float]:
        """Charges of the first 3 pulses, padded with -1; all -1 for a saturated DOM."""
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        n = 3
        if saturationStatus == 1:
            return [_fillSaturated] * n
        charge_readouts = [pulse[dom_charge_idx] for pulse in pulses[:n]]
        charge_readouts.extend([_fillIncomplete] * (n - len(charge_readouts)))
        return charge_readouts

    def getAccumulatedChargeAfterNanoSec(pulses: List[List[float]], saturationStatus: int, interval1 = 25, interval2 = 75) -> List[float]:
        """Charge collected within interval1 / interval2 of the first pulse, plus the total charge."""
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            return [_fillSaturated] * 3
        if len(pulses) < 1:
            return [_fillIncomplete] * 3
        Qtotal = sum(pulse[dom_charge_idx] for pulse in pulses)
        t_0 = pulses[0][dom_time_idx]
        Qinterval1 = sum(pulse[dom_charge_idx] for pulse in pulses if pulse[dom_time_idx] - t_0 < interval1)
        Qinterval2 = sum(pulse[dom_charge_idx] for pulse in pulses if pulse[dom_time_idx] - t_0 < interval2)
        return [Qinterval1, Qinterval2, Qtotal]

    def processDOM(pulses: List[List[float]], avg_dom_position: List[float]):
        """Build the feature row (without event_no) for one DOM's time-ordered pulses."""
        dom_string = getDOMstring(pulses)
        dom_number = getDOMnumber(pulses)
        dom_x, dom_y, dom_z = getDOMposition(pulses)
        dom_x_rel, dom_y_rel, dom_z_rel = getRelativeDOMposition(dom_x, dom_y, dom_z, avg_dom_position)
        pmt_area = getPmtArea(pulses)
        rde = getRDE(pulses)
        saturation_status = getSaturationStatus(pulses)

        first_three_charge_readout = getFirstChargeReadout(pulses, saturation_status)
        accumulated_charge_after_nano_sec = getAccumulatedChargeAfterNanoSec(pulses, saturation_status)
        first_three_pulse_time = getFirstPulseTime(pulses, saturation_status)
        first_three_hlc = getFirstHlc(pulses)
        elapsed_time_until_charge_fraction = getElapsedTimeUntilChargeFraction(pulses, saturation_status)
        standard_deviation = getStandardDeviation([pulse[dom_time_idx] for pulse in pulses], saturation_status)

        return ([dom_string, dom_number]            # dom_string, dom_number
                + [dom_x, dom_y, dom_z]             # dom_x, dom_y, dom_z
                + [dom_x_rel, dom_y_rel, dom_z_rel] # dom_x_rel, dom_y_rel, dom_z_rel
                + [pmt_area, rde, saturation_status]# pmt_area, rde, saturationStatus
                + first_three_charge_readout        # q1, q2, q3
                + accumulated_charge_after_nano_sec # Q25, Q75, Qtotal
                + first_three_hlc                   # hlc1, hlc2, hlc3
                + first_three_pulse_time            # t1, t2, t3
                + elapsed_time_until_charge_fraction# T10, T50
                + [standard_deviation]              # sigmaT
                )

    # NOTE data structure
    # events_doms_pulses  : {event_no: {string: {dom_key: [pulse, ...], ...}, ...}, ...}
    events_doms_pulses = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for row in rows:
        dom_number = row[dom_number_idx]
        # DOMs that could not be matched to the reference geometry have NULL string/dom_number.
        # Key them by position so that distinct unmatched DOMs are not merged into one.
        dom_key = dom_number if dom_number is not None else (row[dom_x_idx], row[dom_y_idx], row[dom_z_idx])
        events_doms_pulses[row[event_no_idx]][row[dom_string_idx]][dom_key].append(row)

    processed_data = []
    for event_no, strings_doms_pulses in events_doms_pulses.items():
        # Every DOM of the event across ALL strings, so the reference centre is per event
        all_pulses_event = [pulses
                            for doms_pulses in strings_doms_pulses.values()
                            for pulses in doms_pulses.values()]
        avg_dom_position = getQweightedAverageDOMposition(all_pulses_event)
        for pulses in all_pulses_event:
            # The features assume chronological order; the stable sort is a no-op if already sorted
            time_ordered = sorted(pulses, key=lambda pulse: pulse[dom_time_idx])
            processed_data.append([event_no] + processDOM(time_ordered, avg_dom_position))

    # Convert the processed data into a DataFrame for easier handling
    df_processed = pd.DataFrame(processed_data, columns=[
        'event_no', 'dom_string', 'dom_number', # indices
        'dom_x', 'dom_y', 'dom_z',  
        'dom_x_rel', 'dom_y_rel', 'dom_z_rel', 
        'pmt_area', 'rde', 'saturation_status', 
        'q1', 'q2', 'q3', 
        'Q25', 'Q75', 'Qtotal',
        'hlc1', 'hlc2', 'hlc3', 
        't1', 't2', 't3', 
        'T10', 'T50', 'sigmaT'
    ])
    pa_processed = pa.Table.from_pandas(df_processed)
    return pa_processed

* The `runPMTfication_DB_Parquet` layer wraps the PMTfication steps and owns the SQLite connection: it opens it once and always closes it, even when an error occurs.

In [20]:
def runPMTfication_DB_Parquet(source_file: str, 
                            source_table: str, 
                            dom_ref_pos_file: str,
                            N_events: int = None) -> Tuple[pa.Table, pa.Table]:
    """PMTfy the first `N_events` events of `source_table` in `source_file` and fetch their truth.

    This function owns the SQLite connection: it is opened once and closed even on error.
    Note that getPMTfiedPA may write `string` / `dom_number` columns back into the source DB.

    Args:
        source_file: path to the source SQLite database.
        source_table: pulse table to PMTfy (e.g. "SRTInIcePulses").
        dom_ref_pos_file: CSV with the reference DOM positions.
        N_events: number of events to process; None, or more than available, processes all.

    Returns:
        (pa_pmtfied, pa_truth) PyArrow tables.

    Raises:
        Any exception from the underlying steps; it is printed and re-raised unchanged.
    """
    con_source = None
    try:
        con_source = sql.connect(source_file)
        N_events_total = get_table_event_count(con_source, source_table)

        # Determine N_events if it is not provided or exceeds N_events_total
        if N_events is None or N_events > N_events_total:
            N_events = N_events_total

        # Get the PyArrow tables for PMTfied and truth data based on N_events
        pa_pmtfied = getPMTfiedPA(con_source, source_table, dom_ref_pos_file, N_events)
        pa_truth = getTruthPA(con_source, N_events)

    except Exception as e:
        print(f"An error occurred during the PMTfication: {e}")
        raise  # bare raise keeps the original traceback untouched

    finally:
        if con_source is not None:
            con_source.close()

    return pa_pmtfied, pa_truth


In [21]:
def writeParquet(source_file: str,
                source_table: str,
                dom_ref_pos_file: str,
                N_events: int = None,
                pmtfied_file: str = None,
                truth_file: str = None) -> None:
    """PMTfy `source_file` and write the feature and truth tables to two Parquet files.

    Args:
        source_file: path to the source SQLite database.
        source_table: pulse table to PMTfy.
        dom_ref_pos_file: CSV with the reference DOM positions.
        N_events: number of events to process (None = all).
        pmtfied_file: output path for the features; defaults to "<source_file>_PMTfied.parquet".
        truth_file: output path for the truth table; defaults to "<source_file>_truth.parquet".
            Previously None was passed straight to pq.write_table and failed.
    """
    pa_pmtfied, pa_truth = runPMTfication_DB_Parquet(source_file, source_table, dom_ref_pos_file, N_events)
    if pmtfied_file is None:
        pmtfied_file = f"{source_file}_PMTfied.parquet"
    if truth_file is None:
        truth_file = f"{source_file}_truth.parquet"
    pq.write_table(pa_pmtfied, pmtfied_file)
    pq.write_table(pa_truth, truth_file)

```python
test_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/"
test_source_DB_dir = test_dir + "testSource/"
test_PMTfied_DB_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedDB/"
test_PMTfied_parquet_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedParquet/"

QA_DB = test_source_DB_dir + "Level2_NuE_NuGenCCNC.022015.000110.db"
QA_DB_PMTfied = test_PMTfied_DB_dir + "voila.db"

ref_str_dom_pos = "/groups/icecube/cyan/factory/DOMification/unique_string_dom_completed.csv"
```

In [22]:
def print_table_event_count(file: str, table: str):
    """Print the number of distinct events in `table` of the SQLite database at `file`.

    The connection is closed even if the count query fails.
    """
    conn = sql.connect(file)
    try:
        event_count = get_table_event_count(conn, table)
    finally:
        conn.close()
    print(f"Table {table} has {event_count} unique events")

In [23]:
# How many distinct events does the pulse table of the QA database hold?
print_table_event_count(file=QA_DB, table="SRTInIcePulses")

Table SRTInIcePulses has 22 unique events


In [24]:
# Write to the DOMification directory explicitly so the read-back cells below find the files
# regardless of the kernel's working directory (they read these absolute paths).
writeParquet(source_file=QA_DB,
            source_table="SRTInIcePulses",
            dom_ref_pos_file=ref_str_dom_pos,
            N_events=20,
            pmtfied_file="/groups/icecube/cyan/factory/DOMification/test_eventwise_pmtfied_features.parquet",
            truth_file="/groups/icecube/cyan/factory/DOMification/test_eventwise_truth.parquet")

In [25]:
# Paths of the feature / truth Parquet files written by the writeParquet cell above
seeThisFeatureParquet, seeThisTruthParquet = (
    "/groups/icecube/cyan/factory/DOMification/" + f"test_eventwise_{kind}.parquet"
    for kind in ("pmtfied_features", "truth")
)

In [26]:
# Load both Parquet outputs back into pandas for inspection
seeThisFeature_df, seeThisTruth_df = (
    pq.read_table(parquet_path).to_pandas()
    for parquet_path in (seeThisFeatureParquet, seeThisTruthParquet)
)

In [30]:
# Sanity check: the feature and truth files should cover the same number of events
print(f"N events of feature: {seeThisFeature_df.event_no.nunique()}")
print(f"N events of truth: {seeThisTruth_df.event_no.nunique()}")

N events of feature: 20
N events of truth: 20


In [None]:
def batch_PMTfication_DB_parquet(source_DB_dir: str,
                                 target_dir: str,
                                 N_events: int = None,
                                 dom_ref_pos_file: str = None,
                                 source_table: str = "SRTInIcePulses") -> None:
    """PMTfy every `.db` file in `source_DB_dir` and write two combined Parquet files.

    Outputs: `<target_dir>/combined_PMTfied.parquet` and `<target_dir>/combined_truth.parquet`.

    Args:
        source_DB_dir: directory searched (non-recursively) for `.db` files.
        target_dir: output directory; created if missing.
        N_events: events per database (None = all).
        dom_ref_pos_file: reference DOM-position CSV; defaults to the notebook-level `ref_str_dom_pos`.
        source_table: pulse table to PMTfy in every database.
    """
    if dom_ref_pos_file is None:
        dom_ref_pos_file = ref_str_dom_pos
    # Sorted so the row order of the combined files is reproducible
    db_files = sorted(f for f in os.listdir(source_DB_dir) if f.endswith('.db'))
    if not db_files:
        # pa.concat_tables([]) would raise; report and do nothing instead
        print(f"No .db files found in {source_DB_dir}; nothing to write.")
        return

    pmtfied_tables = []
    truth_tables = []
    for db_file in tqdm(db_files, desc="Processing..."):
        source_DB = os.path.join(source_DB_dir, db_file)
        # runPMTfication_DB_Parquet counts the events itself and clamps N_events to that count
        pa_pmtfied, pa_truth = runPMTfication_DB_Parquet(
            source_file=source_DB,
            source_table=source_table,
            dom_ref_pos_file=dom_ref_pos_file,
            N_events=N_events,
        )
        pmtfied_tables.append(pa_pmtfied)
        truth_tables.append(pa_truth)

    os.makedirs(target_dir, exist_ok=True)

    # Combine all PMTfied tables and write to a single Parquet file
    combined_pmtfied_file = os.path.join(target_dir, "combined_PMTfied.parquet")
    pq.write_table(pa.concat_tables(pmtfied_tables), combined_pmtfied_file)

    # Combine all truth tables and write to a single Parquet file
    combined_truth_file = os.path.join(target_dir, "combined_truth.parquet")
    pq.write_table(pa.concat_tables(truth_tables), combined_truth_file)

In [39]:
# batch_PMTfication_DB_parquet(test_source_DB_dir, test_PMTfied_parquet_dir)
# individual: 1 min 58 sec
# combining: 5 min 40 sec

Processing...:   0%|          | 0/30 [00:00<?, ?it/s]

Processing...: 100%|██████████| 30/30 [05:38<00:00, 11.28s/it]


In [59]:
seeThis = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/testSource/PMTfiedParquet/event_no=0/276c6d50570445a984e54548bb4dcbe5-0.parquet"
seeThat = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/testSource/PMTfiedParquet/event_no=0/c0e0546be44940d9bd6dc39391a627b8-0.parquet"
compareThis = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedDB/Level2_NuE_NuGenCCNC.022015.000110.db"
# The following cells display df_seeThis / df_seeThat, so they must be loaded here
# (otherwise they only exist as leftover kernel state).
df_seeThis = pq.read_table(seeThis).to_pandas()
df_seeThat = pq.read_table(seeThat).to_pandas()
# Count the events explicitly and pass the count as convertDBtoDF's third argument (N_events_total).
con_compareThis = sql.connect(compareThis)
try:
    N_events_compareThis = get_table_event_count(con_compareThis, "PMTsummarised")
finally:
    con_compareThis.close()
df_compareThis = convertDBtoDF(compareThis, "PMTsummarised", N_events_compareThis)

In [49]:
df_seeThis.keys()  # DataFrame.keys() is the column Index

Index(['energy', 'position_x', 'position_y', 'position_z', 'azimuth', 'zenith',
       'pid', 'event_time', 'interaction_type', 'elasticity', 'RunID',
       'SubrunID', 'EventID', 'SubEventID', 'dbang_decay_length',
       'track_length', 'stopped_muon', 'energy_track', 'energy_cascade',
       'inelasticity', 'DeepCoreFilter_13', 'CascadeFilter_13',
       'MuonFilter_13', 'OnlineL2Filter_17', 'L3_oscNext_bool',
       'L4_oscNext_bool', 'L5_oscNext_bool', 'L6_oscNext_bool',
       'L7_oscNext_bool', 'Homogenized_QTot', 'MCLabelClassification',
       'MCLabelCoincidentMuons', 'MCLabelBgMuonMCPE',
       'MCLabelBgMuonMCPECharge', 'GNLabelTrackEnergyDeposited',
       'GNLabelTrackEnergyOnEntrance', 'GNLabelTrackEnergyOnEntrancePrimary',
       'GNLabelTrackEnergyDepositedPrimary', 'GNLabelEnergyPrimary',
       'GNLabelCascadeEnergyDepositedPrimary', 'GNLabelCascadeEnergyDeposited',
       'GNLabelEnergyDepositedTotal', 'GNLabelEnergyDepositedPrimary',
       'GNHighestEInIceParticl

In [52]:
df_seeThat.keys()  # DataFrame.keys() is the column Index

Index(['dom_string', 'dom_number', 'dom_x', 'dom_y', 'dom_z', 'dom_x_rel',
       'dom_y_rel', 'dom_z_rel', 'pmt_area', 'rde', 'saturation_status', 'q1',
       'q2', 'q3', 'Q25', 'Q75', 'Qtotal', 'hlc1', 'hlc2', 'hlc3', 't1', 't2',
       't3', 'T10', 'T50', 'sigmaT'],
      dtype='object')

In [60]:
df_compareThis.keys()  # DataFrame.keys() is the column Index

Index(['event_no', 'dom_string', 'dom_number', 'dom_x', 'dom_y', 'dom_z',
       'dom_x_rel', 'dom_y_rel', 'dom_z_rel', 'pmt_area', 'rde',
       'saturation_status', 'q1', 'q2', 'q3', 'Q25', 'Q75', 'Qtotal', 'hlc1',
       'hlc2', 'hlc3', 't1', 't2', 't3', 'T10', 'T50', 'sigmaT'],
      dtype='object')

In [None]:
class BatchPMTficationProcessor:
    """Convert every `.db` file under `source_dir` into batched PMTfied Parquet files.

    For each database, the events present in the PMTfied output are split into consecutive
    batches of `events_per_file` events. Each batch k is written under the mirrored
    sub-directory of `target_dir` as:
      * `<db name>_<k>.parquet`          PMTfied features
      * `<db name>_<k>_truth.parquet`    truth rows of the same events
      * `<db name>_<k>_receipt.csv`      original -> shifted event_no mapping
    Event numbers are made unique across databases by packing the database's index into
    the high bits (see `shift_event_no`).
    """

    def __init__(self, source_dir, target_dir, dom_ref_pos_file, events_per_file,
                 source_table: str = "SRTInIcePulses"):
        if events_per_file <= 0:
            raise ValueError(f"events_per_file must be positive, got {events_per_file}")
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.dom_ref_pos_file = dom_ref_pos_file
        self.events_per_file = events_per_file
        self.source_table = source_table

    def shift_event_no(self, event_no: int, subdirectory: int, shift_bits: int = 24) -> int:
        """Pack `subdirectory` into the bits above `shift_bits` of `event_no`.

        Also works element-wise on NumPy integer arrays. The result is only collision-free
        while 0 <= event_no < 2**shift_bits; process_file_in_batches checks this.
        """
        return (subdirectory << shift_bits) | event_no

    def generate_receipt(self, original_file, event_mappings, target_file):
        """Write `<target_file without extension>_receipt.csv` mapping original to new event numbers."""
        receipt_df = pd.DataFrame(event_mappings, columns=['original_event_no', 'event_no'])
        receipt_df['source_file'] = original_file
        # splitext only touches the extension (str.replace would rewrite every '.parquet' in the
        # path, and would overwrite the data file if the extension were missing)
        receipt_path = os.path.splitext(target_file)[0] + '_receipt.csv'
        receipt_df.to_csv(receipt_path, index=False)

    def _select_and_shift(self, table, event_nos, subdirectory):
        """Keep the rows of `table` whose event_no is in `event_nos`, with event_no shifted."""
        original = table.column('event_no').to_numpy()
        mask = np.isin(original, event_nos)
        selected = table.filter(pa.array(mask))
        shifted = self.shift_event_no(original[mask], subdirectory)
        field_idx = selected.schema.get_field_index('event_no')
        return selected.set_column(field_idx, 'event_no', pa.array(shifted))

    def process_file_in_batches(self, db_file, subdirectory):
        """PMTfy one database (path relative to `source_dir`) and write it out in batches.

        `subdirectory` is the integer index packed into the event numbers of this database.
        NOTE: the whole database is PMTfied in memory at once. runPMTfication_DB_Parquet always
        starts at the first event, so calling it once per batch would repeat the same events.
        """
        source_file = os.path.join(self.source_dir, db_file)
        target_subdir = os.path.join(self.target_dir, os.path.dirname(db_file))
        os.makedirs(target_subdir, exist_ok=True)
        db_name = os.path.basename(db_file)  # the relative sub-path is already in target_subdir

        pa_pmtfied, pa_truth = runPMTfication_DB_Parquet(
            source_file, self.source_table, self.dom_ref_pos_file, None)

        # Distinct event numbers in order of first appearance in the feature table
        event_nos = list(dict.fromkeys(pa_pmtfied.column('event_no').to_pylist()))
        if event_nos and max(event_nos) >= (1 << 24):
            raise ValueError(
                f"event_no {max(event_nos)} in {db_file} does not fit in 24 bits; "
                "shifted event numbers would collide")

        for batch_idx, start in enumerate(range(0, len(event_nos), self.events_per_file)):
            batch_event_nos = event_nos[start:start + self.events_per_file]

            target_file = os.path.join(target_subdir, f"{db_name}_{batch_idx}.parquet")
            truth_file = os.path.join(target_subdir, f"{db_name}_{batch_idx}_truth.parquet")
            pq.write_table(self._select_and_shift(pa_pmtfied, batch_event_nos, subdirectory), target_file)
            pq.write_table(self._select_and_shift(pa_truth, batch_event_nos, subdirectory), truth_file)

            # One receipt row per event (not per DOM row)
            event_mappings = [
                {'original_event_no': event_no, 'event_no': self.shift_event_no(event_no, subdirectory)}
                for event_no in batch_event_nos
            ]
            self.generate_receipt(db_file, event_mappings, target_file)

    def process_all_files(self):
        """Process every `.db` file under `source_dir`.

        Files are sorted so each database gets the same index (and thus the same event-number
        prefix) on every run.
        """
        db_files = sorted(
            os.path.relpath(os.path.join(root, file), self.source_dir)
            for root, _, files in os.walk(self.source_dir)
            for file in files
            if file.endswith('.db')
        )
        for subdirectory, db_file in enumerate(tqdm(db_files, desc="Processing database files")):
            self.process_file_in_batches(db_file, subdirectory)

# Usage (placeholder paths: point these at real directories before running)
processor_config = dict(
    source_dir='/path/to/source',
    target_dir='/path/to/target',
    dom_ref_pos_file='/path/to/ref_str_dom_pos',
    events_per_file=2000,
)
processor = BatchPMTficationProcessor(**processor_config)
processor.process_all_files()
