In [1]:
import sqlite3 as sql
import pandas as pd
import sys
import os
import matplotlib.pyplot as plt
import numpy as np

import scipy.optimize as optimize
from scipy import stats as sci
import math
from pprint import pprint

from collections import defaultdict
from typing import List
from tqdm import tqdm

import pyarrow as pa
import pyarrow.parquet as pq

In [2]:
sys.path.append('/groups/icecube/cyan/Utils')
from PlotUtils import setMplParam, getColour, getHistoParam 
# getHistoParam:
# Nbins, binwidth, bins, counts, bin_centers  = 
from DB_lister import list_content, list_tables
from ExternalFunctions import nice_string_output, add_text_to_ax
setMplParam()

In [3]:
def get_table_row_count(conn: sql.Connection, table: str) -> int:
    cursor = conn.cursor()
    cursor.execute(f"SELECT COUNT(*) FROM {table}")
    row_count = cursor.fetchone()[0]
    return row_count

In [4]:
def convertDBtoDF(file:str, table:str, Nlines_model:int = None) -> pd.DataFrame:
    con = sql.connect(file)
    if Nlines_model is None:
        Nlines_model = get_table_row_count(con, table)
    query = f'SELECT * FROM {table} LIMIT {Nlines_model}'
    df = pd.read_sql_query(query, con)
    con.close()
    return df

In [5]:
def convertDFtoDB(file:str, table:str, df: pd.DataFrame) -> None:
    con = sql.connect(file)
    df.to_sql(table, con, if_exists='replace', index=False)
    con.close()

In [6]:
def load_reference_data(filepath: str) -> np.ndarray:
    df = pd.read_csv(filepath)
    return df.values  # Convert the DataFrame to a NumPy array

In [7]:
def addStringAndDOMtoDB(con_source: sql.Connection, 
                        source_table: str,
                        reference_data: np.ndarray,
                        Nlines_model: int,
                        tolerance_xy: float = 10,
                        tolerance_z: float = 2) -> None:
    cur_source = con_source.cursor()
    cur_source.execute(f"PRAGMA table_info({source_table})")
    existing_columns = [col[1] for col in cur_source.fetchall()]
    
    if 'string' not in existing_columns:
        cur_source.execute(f"ALTER TABLE {source_table} ADD COLUMN string INTEGER")
    if 'dom_number' not in existing_columns:
        cur_source.execute(f"ALTER TABLE {source_table} ADD COLUMN dom_number INTEGER")
    
    cur_source.execute(f"SELECT rowid, dom_x, dom_y, dom_z FROM {source_table} WHERE string IS NULL OR dom_number IS NULL LIMIT {Nlines_model}")
    rows_to_update = cur_source.fetchall()
    
    for row in rows_to_update:
        row_id, dom_x, dom_y, dom_z = row
        
        matches_xy = reference_data[
            (np.abs(reference_data[:, 2] - dom_x) <= tolerance_xy) &
            (np.abs(reference_data[:, 3] - dom_y) <= tolerance_xy)
        ]
        
        if len(matches_xy) > 0:
            match_z = matches_xy[np.abs(matches_xy[:, 4] - dom_z) <= tolerance_z]
            
            if len(match_z) > 0:
                string_val = int(match_z[0, 0])
                dom_number_val = int(match_z[0, 1])
                
                cur_source.execute(f"UPDATE {source_table} SET string = ?, dom_number = ? WHERE rowid = ?", (string_val, dom_number_val, row_id))

    con_source.commit()

In [8]:
def getTruthTableNameDB(con_source: sql.Connection) -> str:
    cur_source = con_source.cursor()
    cur_source.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cur_source.fetchall()]
    if 'truth' in tables:
        truth_table = 'truth'
    elif 'Truth' in tables:
        truth_table = 'Truth'
    else:
        raise ValueError("Neither 'truth' nor 'Truth' table exists in the source database.")
    return truth_table

In [9]:
def addTruthTableDB(con_in: sql.Connection, 
                con_out: sql.Connection, 
                Nlines_model: int) -> None:
    truth_table_name = getTruthTableNameDB(con_in)
    cur_in = con_in.cursor()
    cur_out = con_out.cursor()

    query = f"SELECT * FROM {truth_table_name} LIMIT {Nlines_model}"
    cur_in.execute(query)
    rows = cur_in.fetchall()

    cur_in.execute(f"PRAGMA table_info({truth_table_name})")
    schema_info = cur_in.fetchall()
    create_table_query = f"CREATE TABLE IF NOT EXISTS {truth_table_name} ({', '.join([f'{col[1]} {col[2]}' for col in schema_info])})"
    cur_out.execute(create_table_query)

    placeholders = ', '.join(['?'] * len(schema_info))
    insert_query = f"INSERT INTO {truth_table_name} VALUES ({placeholders})"
    cur_out.executemany(insert_query, rows)

    con_out.commit()

* `PMTficationCore_DB_DB` logical PMTfication core
  * get reference dom positiom
  * add string and dom number to the source
  * 

In [21]:
def PMTficationCore_DB_DB(con_source: sql.Connection, con_target: sql.Connection,
                source_table: str, target_table: str, 
                dom_ref_pos_file: str, Nlines_model: int) -> None:
    # if the source misses the string and dom_number columns, add them
    # HACK add if condition to check if the columns are already present
    dom_ref_pos = load_reference_data(dom_ref_pos_file)
    addStringAndDOMtoDB(con_source, source_table, dom_ref_pos, Nlines_model)
    
    query = f'SELECT * FROM {source_table} LIMIT {Nlines_model}'
    cur_source = con_source.cursor()
    cur_source.execute(query)
    rows = cur_source.fetchall()
    columns = [description[0] for description in cur_source.description]

    event_no_idx = columns.index('event_no')
    dom_string_idx = columns.index('string')
    dom_number_idx = columns.index('dom_number')
    dom_x_idx = columns.index('dom_x')
    dom_y_idx = columns.index('dom_y')
    dom_z_idx = columns.index('dom_z')
    dom_time_idx = columns.index('dom_time')
    dom_hlc_idx = columns.index('hlc')
    dom_charge_idx = columns.index('charge')
    pmt_area_idx = columns.index('pmt_area')
    rde_idx = columns.index('rde')
    saturation_status_idx = columns.index('is_saturated_dom')

    cur_target = con_target.cursor()
    cur_target.execute(f"DROP TABLE IF EXISTS {target_table}")
    
    cur_target.execute(f'''
    CREATE TABLE IF NOT EXISTS {target_table} (
        event_no INTEGER,
        dom_string INTEGER,
        dom_number INTEGER,
        dom_x REAL,
        dom_y REAL,
        dom_z REAL,
        dom_x_rel REAL,
        dom_y_rel REAL,
        dom_z_rel REAL,
        pmt_area REAL,
        rde REAL,
        saturation_status INTEGER,
        q1 REAL,
        q2 REAL,
        q3 REAL,
        Q25 REAL,
        Q75 REAL,
        Qtotal REAL,
        hlc1 REAL,
        hlc2 REAL,
        hlc3 REAL,
        t1 REAL,
        t2 REAL,
        t3 REAL,
        T10 REAL,
        T50 REAL,
        sigmaT REAL 
    )
    ''')
    
    def getMaxQtotal(all_pulses_event: List[List[List[float]]]) -> float:
        Qsums = [sum([pulse[dom_charge_idx] for pulse in pulses]) for pulses in all_pulses_event]
        return max(Qsums)
    
    def getQweightedAverageDOMposition(all_pulses_event: List[List[List[float]]], maxQtotal: float) -> List[float]:
        dom_x = [pulse[dom_x_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]
        dom_y = [pulse[dom_y_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]
        dom_z = [pulse[dom_z_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]
        charge_sums = [pulse[dom_charge_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]

        weighted_x = np.mean([x * charge / maxQtotal for x, charge in zip(dom_x, charge_sums)])
        weighted_y = np.mean([y * charge / maxQtotal for y, charge in zip(dom_y, charge_sums)])
        weighted_z = np.mean([z * charge / maxQtotal for z, charge in zip(dom_z, charge_sums)])

        return [weighted_x, weighted_y, weighted_z]
        
    def getRelativeDOMposition(dom_x: float, dom_y: float, dom_z: float, avg_dom_position: List[float]) -> List[float]:
        return [dom_x - avg_dom_position[0], dom_y - avg_dom_position[1], dom_z - avg_dom_position[2]]
    
    # NOTE pulses_dom: [pulse, ...]
    def getDOMposition(pulses_dom: List[List[float]]) -> List[float]:
        return [pulses_dom[0][dom_x_idx], pulses_dom[0][dom_y_idx], pulses_dom[0][dom_z_idx]]
    
    def getDOMstring(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][dom_string_idx]
    
    def getDOMnumber(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][dom_number_idx]
    
    def getPmtArea(pulses_dom: List[List[float]]) -> float:
        return pulses_dom[0][pmt_area_idx]
    
    def getRDE(pulses_dom: List[List[float]]) -> float:
        return pulses_dom[0][rde_idx]
    
    def getSaturationStatus(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][saturation_status_idx]
    
    def getFirstHlc(pulses_dom: List[List[float]]) -> List[int]:
        n = 3
        _fillIncomplete = -1
        if len(pulses_dom) < n:
            hlc = [pulse[dom_hlc_idx] for pulse in pulses_dom]
            hlc.extend([_fillIncomplete] * (n - len(hlc)))
        else:
            hlc = [pulse[dom_hlc_idx] for pulse in pulses_dom[:n]]
        return hlc
    
    def getFirstPulseTime(pulses_dom: List[List[float]], saturationStatus: int) -> List[float]:
        n = 3
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        
        if saturationStatus == 1:
            pulse_times = [_fillSaturated] * n
        elif len(pulses_dom) < n:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom]
            pulse_times.extend([_fillIncomplete] * (n - len(pulse_times)))
        else:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom[:n]]
        return pulse_times
    
    # HACK necessary?
    def getFirstHlcPulseTime(pulses_dom: List[List[float]], saturationStatus: int) -> List[float]:
        n = 3
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            pulse_times = [_fillSaturated] * n
        elif len(pulses_dom) < n:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom if pulse[dom_hlc_idx] == 1]
            pulse_times.extend([_fillIncomplete] * (n - len(pulse_times)))
        else:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom if pulse[dom_hlc_idx] == 1][:n]
        return pulse_times
        
    def getElapsedTimeUntilChargeFraction(pulses_dom: List[List[float]], saturationStatus: int, percentile1 = 10, percentile2 = 50) -> List[float]:
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            times = [_fillSaturated] * 2
        elif len(pulses_dom) < 2:
            times = [_fillIncomplete] * 2
        else:
            Qtotal = sum([pulse[dom_charge_idx] for pulse in pulses_dom])
            t_0 = pulses_dom[0][dom_time_idx]
            Qcum = 0
            T_first, T_second = -1, -1 # if these are not -1, then they are assigned
            for pulse in pulses_dom:
                Qcum += pulse[dom_charge_idx]
                if Qcum > percentile1 / 100 * Qtotal and T_first == -1:
                    T_first = pulse[dom_time_idx] - t_0
                if Qcum > percentile2 / 100 * Qtotal:
                    T_second = pulse[dom_time_idx] - t_0
                    break
            times = [T_first, T_second]
        return times
    
    def getStandardDeviation(pulse_times: List[float], saturationStatus: int) -> float:
        # HACK consider changing the fill values
        _fillSaturated = 0
        _fillIncomplete = 0
        if saturationStatus == 1:
            sigmaT = _fillSaturated
        elif len(pulse_times) < 2:
            sigmaT = _fillIncomplete
        else:
            sigmaT = np.std(pulse_times)
        return sigmaT
    
    def getFirstChargeReadout(pulses: List[List[float]], saturationStatus: int) -> List[float]:
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        n = 3
        if saturationStatus == 1:
            charge_readouts = [_fillSaturated] * n
        elif len(pulses) < n:
            charge_readouts = [pulse[dom_charge_idx] for pulse in pulses]
            charge_readouts.extend([_fillIncomplete] * (n - len(charge_readouts)))
        else:
            charge_readouts = [pulse[dom_charge_idx] for pulse in pulses[:n]]
        return charge_readouts
    
    def getAccumulatedChargeAfterNanoSec(pulses: List[List[float]], saturationStatus: int, interval1 = 25, interval2 = 75) -> List[float]:
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            Qs = [_fillSaturated] * 3
        elif len(pulses) < 1:
            Qs = [_fillIncomplete] * 3
        else:
            Qtotal = sum([pulse[dom_charge_idx] for pulse in pulses])
            t_0 = pulses[0][dom_time_idx]
            Qinterval1 = sum([pulse[dom_charge_idx] for pulse in pulses if pulse[dom_time_idx] - t_0 < interval1])
            Qinterval2 = sum([pulse[dom_charge_idx] for pulse in pulses if pulse[dom_time_idx] - t_0 < interval2])
            Qs = [Qinterval1, Qinterval2, Qtotal]
        return Qs
    
    def processDOM(pulses: List[List[float]], avg_dom_position: List[float]):
        dom_string = getDOMstring(pulses)
        dom_number = getDOMnumber(pulses)
        dom_x, dom_y, dom_z = getDOMposition(pulses)
        dom_x_rel, dom_y_rel, dom_z_rel = getRelativeDOMposition(dom_x, dom_y, dom_z, avg_dom_position)
        pmt_area = getPmtArea(pulses)
        rde = getRDE(pulses)
        saturation_status = getSaturationStatus(pulses)
        
        # Get remaining features
        first_three_charge_readout = getFirstChargeReadout(pulses, saturation_status)
        accumulated_charge_after_nano_sec = getAccumulatedChargeAfterNanoSec(pulses, saturation_status)
        first_three_pulse_time = getFirstPulseTime(pulses, saturation_status)
        # first_three_hlc_pulse_time = getFirstHlcPulseTime(pulses, saturation_status)
        first_three_hlc = getFirstHlc(pulses)
        elapsed_time_until_charge_fraction = getElapsedTimeUntilChargeFraction(pulses, saturation_status)
        standard_deviation = getStandardDeviation([pulse[dom_time_idx] for pulse in pulses], saturation_status)
        
        data_dom = ([dom_string, dom_number]            # dom_number
                    + [dom_x, dom_y, dom_z]             # dom_x, dom_y, dom_z
                    + [dom_x_rel, dom_y_rel, dom_z_rel] # dom_x_rel, dom_y_rel, dom_z_rel 
                    + [pmt_area, rde, saturation_status]# pmt_area, rde, saturationStatus
                    + first_three_charge_readout        # q1, q2, q3
                    + accumulated_charge_after_nano_sec # Q25, Q75, Qtotal
                    + first_three_hlc                   # hlc1, hlc2, hlc3
                    + first_three_pulse_time            # t1, t2, t3
                    # + first_three_hlc_pulse_time        # t1_hlc, t2_hlc, t3_hlc
                    + elapsed_time_until_charge_fraction# T10, T50
                    + [standard_deviation]              # sigmaT
                    )
        return data_dom
    
    events_doms_pulses = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for row in rows:
        event_no = row[event_no_idx]
        string = row[dom_string_idx]
        dom_number = row[dom_number_idx]
        events_doms_pulses[event_no][string][dom_number].append(row)
    
    # print(f"Processing {len(events_doms_pulses)} unique events")
    # NOTE data structure
    # events_doms_pulses  : {event_no: {string: {dom_number: [pulse, ...], ...}, ...}, ...}
    # strings_doms_pulses :            {string: {dom_number: [pulse, ...], ...}, ...}
    # doms_pulses         :                     {dom_number: [pulse, ...], ...}
    # pulses              :                                  [pulse, ...]
    for event_no, strings_doms_pulses in events_doms_pulses.items():
        for doms_pulses in strings_doms_pulses.values():
            all_pulses_event = list(doms_pulses.values())
            maxQtotal = getMaxQtotal(all_pulses_event)
            avg_dom_position = getQweightedAverageDOMposition(all_pulses_event, maxQtotal)
            for pulses in doms_pulses.values():
                dom_data = [event_no] + processDOM(pulses, avg_dom_position)
                cur_target.execute(f'INSERT INTO {target_table} VALUES ({",".join(["?"]*len(dom_data))})', dom_data)
    con_target.commit()
    addTruthTableDB(con_source, con_target, Nlines_model)

* `runPMTfication_DB_DB`: additional layer for sqlite connection and close

In [22]:
def runPMTfication_DB_DB(source_file: str, target_file: str, 
                        source_table: str, target_table: str, 
                        dom_ref_pos_file: str, 
                        Nlines_model: int = None) -> None:
    con_source = None
    con_target = None
    try:
        con_source = sql.connect(source_file)
        con_target = sql.connect(target_file)
        
        # limit the number of rows
        if Nlines_model is None:
            Nlines_model = get_table_row_count(con_source, source_table)
        
        PMTficationCore_DB_DB(con_source, con_target, source_table, target_table, dom_ref_pos_file, Nlines_model)
        
    except Exception as e:
        print(f"An error occurred during the PMTfication: {e}")
        raise e
    finally:
        if con_source:
            con_source.close()
        if con_target:
            con_target.close()

In [10]:
test_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/"
test_source_DB_dir = test_dir + "testSource/"
test_PMTfied_DB_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedDB/"
test_PMTfied_parquet_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedParquet/"

QA_DB = test_source_DB_dir + "Level2_NuE_NuGenCCNC.022015.000110.db"
QA_DB_PMTfied = test_PMTfied_DB_dir + "voila.db"

ref_str_dom_pos = "/groups/icecube/cyan/factory/DOMification/unique_string_dom_completed.csv"

In [11]:
list_tables(QA_DB)

['truth', 'SRTInIcePulses', 'event_no_SRTInIcePulses', 'OnlineL2_BestFit', 'OnlineL2_SplineMPE', 'LineFit', 'GNHighestEInIceParticle', 'GNHighestEDaughter', 'MCWeightDict', 'SnowStormParameters']


In [89]:
runPMTfication_DB_DB(
    source_file=QA_DB, 
    target_file=QA_DB_PMTfied, 
    source_table="SRTInIcePulses",
    target_table="PMTsummarised",
    dom_ref_pos_file=ref_str_dom_pos,
)

In [90]:
def batch_PMTfication_DB_DB(source_DB_dir, target_DB_dir):
    db_files = [f for f in os.listdir(source_DB_dir) if f.endswith('.db')]
    for db_file in tqdm(db_files, desc="Processing..."):
        source_DB = os.path.join(source_DB_dir, db_file)
        base_name = os.path.splitext(db_file)[0]
        target_DB = os.path.join(target_DB_dir, f"{base_name}_PMTfied.db")
        source_table = "SRTInIcePulses"
        target_table = "PMTsummarised"
        ref_str_dom_pos = "/groups/icecube/cyan/factory/DOMification/unique_string_dom_completed.csv"
        runPMTfication_DB_DB(
            source_file=source_DB, 
            target_file=target_DB, 
            source_table=source_table,
            target_table=target_table,
            dom_ref_pos_file=ref_str_dom_pos,
        )

In [91]:
batch_PMTfication_DB_DB(source_DB_dir, test_PMTfied_DB_dir)
# 2 min 3.9 sec

Processing...: 100%|██████████| 30/30 [02:03<00:00,  4.13s/it]


In [15]:
# def addStringAndDOMtoDF(df: pd.DataFrame, 
#                         reference_data: np.ndarray,
#                         Nlines_model: int,
#                         tolerance_xy: float = 10,
#                         tolerance_z: float = 2,) -> pd.DataFrame:
#     if 'string' not in df.columns:
#         df['string'] = np.nan
#     if 'dom_number' not in df.columns:
#         df['dom_number'] = np.nan
    
#     # Filter rows where 'string' or 'dom_number' is NaN (null in DataFrame terms)
#     rows_to_update = df[df['string'].isna() | df['dom_number'].isna()].iloc[:Nlines_model]
    
#     for idx, row in rows_to_update.iterrows():
#         dom_x, dom_y, dom_z = row['dom_x'], row['dom_y'], row['dom_z']
        
#         # Find matches in reference_data within the xy tolerance
#         matches_xy = reference_data[
#             (np.abs(reference_data[:, 2] - dom_x) <= tolerance_xy) &
#             (np.abs(reference_data[:, 3] - dom_y) <= tolerance_xy)
#         ]
        
#         if len(matches_xy) > 0:
#             match_z = matches_xy[np.abs(matches_xy[:, 4] - dom_z) <= tolerance_z]
            
#             if len(match_z) > 0:
#                 string_val = int(match_z[0, 0])
#                 dom_number_val = int(match_z[0, 1])
                
#                 df.at[idx, 'string'] = string_val
#                 df.at[idx, 'dom_number'] = dom_number_val
    
#     return df

In [25]:
def getTruthPA(con_source: sql.Connection, Nlines_model: int) -> pa.Table:
    truth_table_name = getTruthTableNameDB(con_source)
    cur_source = con_source.cursor()
    cur_source.execute(f"PRAGMA table_info({truth_table_name})")
    schema_info = cur_source.fetchall()
    query = f"SELECT * FROM {truth_table_name} LIMIT {Nlines_model}"
    cur_source.execute(query)
    rows = cur_source.fetchall()
    truth_data_dict = {col[1]: [row[i] for row in rows] for i, col in enumerate(schema_info)}
    truth_table_pa = pa.Table.from_pydict(truth_data_dict)
    return truth_table_pa

* this conversion from db to parquet inevitably requires use of pandas dataframe.
* it would be much desirable if the data is directly converted from I3 to PMTfied parquet.

In [26]:
def getPMTfiedPA(con_source: sql.Connection, 
                source_table: str,
                dom_ref_pos_file: str,
                Nlines_model: int) -> pa.Table:
    dom_ref_pos = load_reference_data(dom_ref_pos_file)
    addStringAndDOMtoDB(con_source, source_table, dom_ref_pos, Nlines_model)
    
    query = f'SELECT * FROM {source_table} LIMIT {Nlines_model}'
    cur_source = con_source.cursor()
    cur_source.execute(query)
    rows = cur_source.fetchall()
    columns = [description[0] for description in cur_source.description]
    
    event_no_idx = columns.index('event_no')
    dom_string_idx = columns.index('string')
    dom_number_idx = columns.index('dom_number')
    dom_x_idx = columns.index('dom_x')
    dom_y_idx = columns.index('dom_y')
    dom_z_idx = columns.index('dom_z')
    dom_time_idx = columns.index('dom_time')
    dom_hlc_idx = columns.index('hlc')
    dom_charge_idx = columns.index('charge')
    pmt_area_idx = columns.index('pmt_area')
    rde_idx = columns.index('rde')
    saturation_status_idx = columns.index('is_saturated_dom')
    
    def getMaxQtotal(all_pulses_event: List[List[List[float]]]) -> float:
        Qsums = [sum([pulse[dom_charge_idx] for pulse in pulses]) for pulses in all_pulses_event]
        return max(Qsums)
    
    def getQweightedAverageDOMposition(all_pulses_event: List[List[List[float]]], maxQtotal: float) -> List[float]:
        dom_x = [pulse[dom_x_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]
        dom_y = [pulse[dom_y_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]
        dom_z = [pulse[dom_z_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]
        charge_sums = [pulse[dom_charge_idx] for pulses_dom in all_pulses_event for pulse in pulses_dom]

        weighted_x = np.mean([x * charge / maxQtotal for x, charge in zip(dom_x, charge_sums)])
        weighted_y = np.mean([y * charge / maxQtotal for y, charge in zip(dom_y, charge_sums)])
        weighted_z = np.mean([z * charge / maxQtotal for z, charge in zip(dom_z, charge_sums)])

        return [weighted_x, weighted_y, weighted_z]
        
    def getRelativeDOMposition(dom_x: float, dom_y: float, dom_z: float, avg_dom_position: List[float]) -> List[float]:
        return [dom_x - avg_dom_position[0], dom_y - avg_dom_position[1], dom_z - avg_dom_position[2]]
    
    # NOTE pulses_dom: [pulse, ...]
    def getDOMposition(pulses_dom: List[List[float]]) -> List[float]:
        return [pulses_dom[0][dom_x_idx], pulses_dom[0][dom_y_idx], pulses_dom[0][dom_z_idx]]
    
    def getDOMstring(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][dom_string_idx]
    
    def getDOMnumber(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][dom_number_idx]
    
    def getPmtArea(pulses_dom: List[List[float]]) -> float:
        return pulses_dom[0][pmt_area_idx]
    
    def getRDE(pulses_dom: List[List[float]]) -> float:
        return pulses_dom[0][rde_idx]
    
    def getSaturationStatus(pulses_dom: List[List[float]]) -> int:
        return pulses_dom[0][saturation_status_idx]
    
    def getFirstHlc(pulses_dom: List[List[float]]) -> List[int]:
        n = 3
        _fillIncomplete = -1
        if len(pulses_dom) < n:
            hlc = [pulse[dom_hlc_idx] for pulse in pulses_dom]
            hlc.extend([_fillIncomplete] * (n - len(hlc)))
        else:
            hlc = [pulse[dom_hlc_idx] for pulse in pulses_dom[:n]]
        return hlc
    
    def getFirstPulseTime(pulses_dom: List[List[float]], saturationStatus: int) -> List[float]:
        n = 3
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        
        if saturationStatus == 1:
            pulse_times = [_fillSaturated] * n
        elif len(pulses_dom) < n:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom]
            pulse_times.extend([_fillIncomplete] * (n - len(pulse_times)))
        else:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom[:n]]
        return pulse_times
    
    # HACK necessary?
    def getFirstHlcPulseTime(pulses_dom: List[List[float]], saturationStatus: int) -> List[float]:
        n = 3
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            pulse_times = [_fillSaturated] * n
        elif len(pulses_dom) < n:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom if pulse[dom_hlc_idx] == 1]
            pulse_times.extend([_fillIncomplete] * (n - len(pulse_times)))
        else:
            pulse_times = [pulse[dom_time_idx] for pulse in pulses_dom if pulse[dom_hlc_idx] == 1][:n]
        return pulse_times
        
    def getElapsedTimeUntilChargeFraction(pulses_dom: List[List[float]], saturationStatus: int, percentile1 = 10, percentile2 = 50) -> List[float]:
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            times = [_fillSaturated] * 2
        elif len(pulses_dom) < 2:
            times = [_fillIncomplete] * 2
        else:
            Qtotal = sum([pulse[dom_charge_idx] for pulse in pulses_dom])
            t_0 = pulses_dom[0][dom_time_idx]
            Qcum = 0
            T_first, T_second = -1, -1 # if these are not -1, then they are assigned
            for pulse in pulses_dom:
                Qcum += pulse[dom_charge_idx]
                if Qcum > percentile1 / 100 * Qtotal and T_first == -1:
                    T_first = pulse[dom_time_idx] - t_0
                if Qcum > percentile2 / 100 * Qtotal:
                    T_second = pulse[dom_time_idx] - t_0
                    break
            times = [T_first, T_second]
        return times
    
    def getStandardDeviation(pulse_times: List[float], saturationStatus: int) -> float:
        # HACK consider changing the fill values
        _fillSaturated = 0
        _fillIncomplete = 0
        if saturationStatus == 1:
            sigmaT = _fillSaturated
        elif len(pulse_times) < 2:
            sigmaT = _fillIncomplete
        else:
            sigmaT = np.std(pulse_times)
        return sigmaT
    
    def getFirstChargeReadout(pulses: List[List[float]], saturationStatus: int) -> List[float]:
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        n = 3
        if saturationStatus == 1:
            charge_readouts = [_fillSaturated] * n
        elif len(pulses) < n:
            charge_readouts = [pulse[dom_charge_idx] for pulse in pulses]
            charge_readouts.extend([_fillIncomplete] * (n - len(charge_readouts)))
        else:
            charge_readouts = [pulse[dom_charge_idx] for pulse in pulses[:n]]
        return charge_readouts
    
    def getAccumulatedChargeAfterNanoSec(pulses: List[List[float]], saturationStatus: int, interval1 = 25, interval2 = 75) -> List[float]:
        # HACK consider changing the fill values
        _fillSaturated = -1
        _fillIncomplete = -1
        if saturationStatus == 1:
            Qs = [_fillSaturated] * 3
        elif len(pulses) < 1:
            Qs = [_fillIncomplete] * 3
        else:
            Qtotal = sum([pulse[dom_charge_idx] for pulse in pulses])
            t_0 = pulses[0][dom_time_idx]
            Qinterval1 = sum([pulse[dom_charge_idx] for pulse in pulses if pulse[dom_time_idx] - t_0 < interval1])
            Qinterval2 = sum([pulse[dom_charge_idx] for pulse in pulses if pulse[dom_time_idx] - t_0 < interval2])
            Qs = [Qinterval1, Qinterval2, Qtotal]
        return Qs
    
    def processDOM(pulses: List[List[float]], avg_dom_position: List[float]):
        dom_string = getDOMstring(pulses)
        dom_number = getDOMnumber(pulses)
        dom_x, dom_y, dom_z = getDOMposition(pulses)
        dom_x_rel, dom_y_rel, dom_z_rel = getRelativeDOMposition(dom_x, dom_y, dom_z, avg_dom_position)
        pmt_area = getPmtArea(pulses)
        rde = getRDE(pulses)
        saturation_status = getSaturationStatus(pulses)
        
        # Get remaining features
        first_three_charge_readout = getFirstChargeReadout(pulses, saturation_status)
        accumulated_charge_after_nano_sec = getAccumulatedChargeAfterNanoSec(pulses, saturation_status)
        first_three_pulse_time = getFirstPulseTime(pulses, saturation_status)
        # first_three_hlc_pulse_time = getFirstHlcPulseTime(pulses, saturation_status)
        first_three_hlc = getFirstHlc(pulses)
        elapsed_time_until_charge_fraction = getElapsedTimeUntilChargeFraction(pulses, saturation_status)
        standard_deviation = getStandardDeviation([pulse[dom_time_idx] for pulse in pulses], saturation_status)
        
        data_dom = ([dom_string, dom_number]            # dom_number
                    + [dom_x, dom_y, dom_z]             # dom_x, dom_y, dom_z
                    + [dom_x_rel, dom_y_rel, dom_z_rel] # dom_x_rel, dom_y_rel, dom_z_rel
                    + [pmt_area, rde, saturation_status]# pmt_area, rde, saturationStatus
                    + first_three_charge_readout        # q1, q2, q3
                    + accumulated_charge_after_nano_sec # Q25, Q75, Qtotal
                    + first_three_hlc                   # hlc1, hlc2, hlc3
                    + first_three_pulse_time            # t1, t2, t3
                    # + first_three_hlc_pulse_time        # t1_hlc, t2_hlc, t3_hlc
                    + elapsed_time_until_charge_fraction# T10, T50
                    + [standard_deviation]              # sigmaT
                    )
        return data_dom            
    # original data
    events_doms_pulses = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    # new data
    processed_data = []
    for row in rows:
        event_no = row[event_no_idx]
        string = row[dom_string_idx]
        dom_number = row[dom_number_idx]
        events_doms_pulses[event_no][string][dom_number].append(row)
    
    # print(f"Processing {len(events_doms_pulses)} unique events")
    # NOTE data structure
    # events_doms_pulses  : {event_no: {string: {dom_number: [pulse, ...], ...}, ...}, ...}
    # strings_doms_pulses :            {string: {dom_number: [pulse, ...], ...}, ...}
    # doms_pulses         :                     {dom_number: [pulse, ...], ...}
    # pulses              :                                  [pulse, ...]
    for event_no, strings_doms_pulses in events_doms_pulses.items():
        for doms_pulses in strings_doms_pulses.values():
            # Convert the values to a list of pulses (rows)
            all_pulses_event = list(doms_pulses.values())
            maxQtotal = getMaxQtotal(all_pulses_event)
            avg_dom_position = getQweightedAverageDOMposition(all_pulses_event, maxQtotal)
            for pulses in doms_pulses.values():
                dom_data = [event_no] + processDOM(pulses, avg_dom_position)
                processed_data.append(dom_data)

    # Convert the processed data into a DataFrame for easier handling
    df_processed = pd.DataFrame(processed_data, columns=[
        'event_no', 'dom_string', 'dom_number', 'dom_x', 'dom_y', 'dom_z',
        'dom_x_rel', 'dom_y_rel', 'dom_z_rel', 'pmt_area', 'rde',
        'saturation_status', 'q1', 'q2', 'q3', 'Q25', 'Q75', 'Qtotal',
        'hlc1', 'hlc2', 'hlc3', 't1', 't2', 't3', 'T10', 'T50', 'sigmaT'
    ])
    pa_processed = pa.Table.from_pandas(df_processed)
    return pa_processed  

* `runPMTfication_DB_Parquet` layer is intended to intervene the process of sqlite connection and close

In [27]:
def runPMTfication_DB_Parquet(source_file: str, 
                            source_table: str, 
                            dom_ref_pos_file: str,
                            Nlines_model: int = None) -> (pa.Table, pa.Table):
    con_source = None
    try:
        base_name = os.path.splitext(os.path.basename(source_file))[0]
        
        con_source = sql.connect(source_file)
        
        if Nlines_model is None:
            Nlines_model = get_table_row_count(con_source, source_table)
        
        # Get the PyArrow tables for PMTfied and truth data
        pa_pmtfied = getPMTfiedPA(con_source, source_table, dom_ref_pos_file, Nlines_model)
        pa_truth = getTruthPA(con_source, Nlines_model)
        
    except Exception as e:
        print(f"An error occurred during the PMTfication: {e}")
        raise e
    
    finally:
        if con_source:
            con_source.close()

    return pa_pmtfied, pa_truth

```python
test_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/"
test_source_DB_dir = test_dir + "testSource/"
test_PMTfied_DB_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedDB/"
test_PMTfied_parquet_dir = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedParquet/"

QA_DB = test_source_DB_dir + "Level2_NuE_NuGenCCNC.022015.000110.db"
QA_DB_PMTfied = test_PMTfied_DB_dir + "voila.db"

ref_str_dom_pos = "/groups/icecube/cyan/factory/DOMification/unique_string_dom_completed.csv"
```

In [34]:
runPMTfication_DB_Parquet(
    source_file = QA_DB,
    source_table = "SRTInIcePulses",
    dom_ref_pos_file = ref_str_dom_pos,
)

(pyarrow.Table
 event_no: int64
 dom_string: double
 dom_number: double
 dom_x: double
 dom_y: double
 dom_z: double
 dom_x_rel: double
 dom_y_rel: double
 dom_z_rel: double
 pmt_area: double
 rde: double
 saturation_status: double
 q1: double
 q2: double
 q3: double
 Q25: double
 Q75: double
 Qtotal: double
 hlc1: double
 hlc2: double
 hlc3: double
 t1: double
 t2: double
 t3: double
 T10: double
 T50: double
 sigmaT: double
 ----
 event_no: [[550,550,550,550,550,...,571,571,571,571,571]]
 dom_string: [[7,7,14,14,14,...,74,74,74,74,74]]
 dom_number: [[56,57,46,47,49,...,26,27,28,30,32]]
 dom_x: [[-334.8,-334.8,-413.46,-413.46,-413.46,...,338.44,338.44,338.44,338.44,338.44]]
 dom_y: [[-424.5,-424.5,-327.27,-327.27,-327.27,...,463.72,463.72,463.72,463.72,463.72]]
 dom_z: [[-435.55,-452.57,-266.34,-283.36,-317.4,...,80.21,63.19,46.16,12.12,-21.92]]
 dom_x_rel: [[-59.08235363168285,-59.08235363168285,-121.01267852870848,-121.01267852870848,-121.01267852870848,...,337.2229517790636,337.222

In [37]:
def batch_PMTfication_DB_parquet(source_DB_dir: str, target_dir: str) -> None:
    db_files = [f for f in os.listdir(source_DB_dir) if f.endswith('.db')]
    
    pmtfied_tables = []
    truth_tables = []
    
    for db_file in tqdm(db_files, desc="Processing..."):
        source_DB = os.path.join(source_DB_dir, db_file)
        
        pa_pmtfied, pa_truth = runPMTfication_DB_Parquet(
            source_file=source_DB, 
            source_table="SRTInIcePulses",
            dom_ref_pos_file=ref_str_dom_pos,
        )
        
        pmtfied_tables.append(pa_pmtfied)
        truth_tables.append(pa_truth)
    
    # Combine all PMTfied tables and write to a single Parquet file
    combined_pmtfied_table = pa.concat_tables(pmtfied_tables)
    combined_pmtfied_file = os.path.join(target_dir, "combined_PMTfied.parquet")
    pq.write_table(combined_pmtfied_table, combined_pmtfied_file)

    combined_truth_table = pa.concat_tables(truth_tables)
    combined_truth_file = os.path.join(target_dir, "combined_truth.parquet")
    pq.write_table(combined_truth_table, combined_truth_file)

In [39]:
batch_PMTfication_DB_parquet(test_source_DB_dir, test_PMTfied_parquet_dir)
# individual: 1 min 58 sec
# combining: 5 min 40 sec

Processing...:   0%|          | 0/30 [00:00<?, ?it/s]

Processing...: 100%|██████████| 30/30 [05:38<00:00, 11.28s/it]


In [59]:
seeThis = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/testSource/PMTfiedParquet/event_no=0/276c6d50570445a984e54548bb4dcbe5-0.parquet"
seeThat = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/testSource/PMTfiedParquet/event_no=0/c0e0546be44940d9bd6dc39391a627b8-0.parquet"
compareThis = "/groups/icecube/cyan/factory/DOMification/PMTfied/test/PMTfiedDB/Level2_NuE_NuGenCCNC.022015.000110.db"
# df_seeThis = pq.read_table(seeThis).to_pandas()
# df_seeThat = pq.read_table(seeThat).to_pandas()
df_compareThis = convertDBtoDF(compareThis, "PMTsummarised")

In [49]:
df_seeThis.columns

Index(['energy', 'position_x', 'position_y', 'position_z', 'azimuth', 'zenith',
       'pid', 'event_time', 'interaction_type', 'elasticity', 'RunID',
       'SubrunID', 'EventID', 'SubEventID', 'dbang_decay_length',
       'track_length', 'stopped_muon', 'energy_track', 'energy_cascade',
       'inelasticity', 'DeepCoreFilter_13', 'CascadeFilter_13',
       'MuonFilter_13', 'OnlineL2Filter_17', 'L3_oscNext_bool',
       'L4_oscNext_bool', 'L5_oscNext_bool', 'L6_oscNext_bool',
       'L7_oscNext_bool', 'Homogenized_QTot', 'MCLabelClassification',
       'MCLabelCoincidentMuons', 'MCLabelBgMuonMCPE',
       'MCLabelBgMuonMCPECharge', 'GNLabelTrackEnergyDeposited',
       'GNLabelTrackEnergyOnEntrance', 'GNLabelTrackEnergyOnEntrancePrimary',
       'GNLabelTrackEnergyDepositedPrimary', 'GNLabelEnergyPrimary',
       'GNLabelCascadeEnergyDepositedPrimary', 'GNLabelCascadeEnergyDeposited',
       'GNLabelEnergyDepositedTotal', 'GNLabelEnergyDepositedPrimary',
       'GNHighestEInIceParticl

In [52]:
df_seeThat.columns

Index(['dom_string', 'dom_number', 'dom_x', 'dom_y', 'dom_z', 'dom_x_rel',
       'dom_y_rel', 'dom_z_rel', 'pmt_area', 'rde', 'saturation_status', 'q1',
       'q2', 'q3', 'Q25', 'Q75', 'Qtotal', 'hlc1', 'hlc2', 'hlc3', 't1', 't2',
       't3', 'T10', 'T50', 'sigmaT'],
      dtype='object')

In [60]:
df_compareThis.columns

Index(['event_no', 'dom_string', 'dom_number', 'dom_x', 'dom_y', 'dom_z',
       'dom_x_rel', 'dom_y_rel', 'dom_z_rel', 'pmt_area', 'rde',
       'saturation_status', 'q1', 'q2', 'q3', 'Q25', 'Q75', 'Qtotal', 'hlc1',
       'hlc2', 'hlc3', 't1', 't2', 't3', 'T10', 'T50', 'sigmaT'],
      dtype='object')

In [23]:
seeAndreas = '/groups/icecube/petersen/GraphNetDatabaseRepository/high_energy_data_converted_by_Andreas_M/parquet_db/NuE/Level2_IC86.2020_NuE.021870.000000.parquet'

# read parquet file to df
df = pd.read_parquet(seeAndreas)

In [24]:
df['SplitInIcePulses'].iloc[0]


{'charge': array([0.77499998, 0.57499999, 1.17499995, 1.32500005, 1.125     ,
        1.125     , 1.07500005, 1.42499995, 1.22500002, 0.82499999,
        2.875     , 0.32499999, 1.125     , 0.92500001, 0.97500002,
        3.32500005, 0.72500002, 0.42500001, 0.72500002, 1.17499995,
        0.625     , 0.92500001, 0.32499999, 1.17499995, 1.42499995,
        1.02499998, 0.875     , 0.82499999, 0.82499999, 1.02499998,
        0.875     , 0.67500001, 0.67500001, 0.67500001, 0.875     ,
        1.07500005, 1.02499998, 0.52499998, 1.375     , 0.47499999,
        1.02499998, 0.32499999, 1.17499995, 1.22500002, 1.17499995,
        1.32500005, 0.22499999, 0.42500001, 0.57499999, 0.22499999,
        0.92500001, 0.625     , 1.27499998, 0.97500002, 1.17499995,
        1.02499998, 0.625     , 1.02499998, 1.07500005, 0.77499998,
        1.27499998, 0.82499999, 0.72500002, 0.875     , 1.07500005,
        0.92500001, 0.625     , 2.2249999 , 0.77499998, 0.875     ,
        0.875     , 0.57499999, 0.375 

In [25]:
df['SplitInIcePulses'].iloc[0].keys()

dict_keys(['charge', 'dom_time', 'width', 'dom_x', 'dom_y', 'dom_z', 'pmt_area', 'rde', 'is_bright_dom', 'is_bad_dom', 'is_saturated_dom', 'is_errata_dom', 'event_time', 'hlc', 'awtd', 'fadc', 'string', 'pmt_number', 'dom_number', 'dom_type', 'event_no'])

In [26]:
df['SplitInIcePulses'].iloc[0].items()

dict_items([('charge', array([0.77499998, 0.57499999, 1.17499995, 1.32500005, 1.125     ,
       1.125     , 1.07500005, 1.42499995, 1.22500002, 0.82499999,
       2.875     , 0.32499999, 1.125     , 0.92500001, 0.97500002,
       3.32500005, 0.72500002, 0.42500001, 0.72500002, 1.17499995,
       0.625     , 0.92500001, 0.32499999, 1.17499995, 1.42499995,
       1.02499998, 0.875     , 0.82499999, 0.82499999, 1.02499998,
       0.875     , 0.67500001, 0.67500001, 0.67500001, 0.875     ,
       1.07500005, 1.02499998, 0.52499998, 1.375     , 0.47499999,
       1.02499998, 0.32499999, 1.17499995, 1.22500002, 1.17499995,
       1.32500005, 0.22499999, 0.42500001, 0.57499999, 0.22499999,
       0.92500001, 0.625     , 1.27499998, 0.97500002, 1.17499995,
       1.02499998, 0.625     , 1.02499998, 1.07500005, 0.77499998,
       1.27499998, 0.82499999, 0.72500002, 0.875     , 1.07500005,
       0.92500001, 0.625     , 2.2249999 , 0.77499998, 0.875     ,
       0.875     , 0.57499999, 0.375   

In [40]:
print(f"event number of the first: {df['SplitInIcePulses'].iloc[0]['event_no']}")
print(f"event number of the second: {df['SplitInIcePulses'].iloc[1]['event_no']}")

event number of the first: 498120
event number of the second: 498122


In [41]:
print(f"the number of pulses of the first event: {len(df['SplitInIcePulses'].iloc[0]['dom_x'])}")
print(f"the number of pulses of the second event: {len(df['SplitInIcePulses'].iloc[1]['dom_x'])}")
print(f"the number of pulses of the third event: {len(df['SplitInIcePulses'].iloc[2]['dom_x'])}")


the number of pulses of the first event: 94
the number of pulses of the second event: 46
the number of pulses of the third event: 83
