In [1]:
import os, sys
from graphnet.data.sqlite.sqlite_utilities import create_table
import pandas as pd
import sqlite3
import pyarrow.parquet as pq
import sqlalchemy
from tqdm import tqdm
from typing import Any, Dict, List, Optional
import numpy as np
import matplotlib.pyplot as plt
import pickle
import gc

[1;34mgraphnet[0m: [32mINFO    [0m 2023-03-15 10:55:02 - get_logger - Writing log to [1mlogs/graphnet_20230315-105502.log[0m


In [2]:
# Locations of the raw data, all relative to the notebook's working directory.
data_root = './data'
input_data_folder = f'{data_root}/train'            # folder holding batch_<id>.parquet pulse files
meta_data_path = f'{data_root}/train_meta.parquet'  # event-level meta table
# Sensor geometry table; load_input selects its rows with .loc[sensor_id, ['x', 'y', 'z']].
geometry_table = pd.read_csv(f'{data_root}/sensor_geometry.csv')

In [3]:
def load_input(meta_batch: pd.DataFrame,
               input_data_folder: str,
               geometry: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Load one batch file of pulses and attach sensor x/y/z coordinates.

    Parameters
    ----------
    meta_batch : pd.DataFrame
        Slice of the meta table. Its ``batch_id`` column must hold exactly one
        unique value, which selects ``{input_data_folder}/batch_{batch_id}.parquet``.
    input_data_folder : str
        Folder containing the per-batch pulse parquet files.
    geometry : pd.DataFrame, optional
        Table with ``x``, ``y``, ``z`` columns whose rows are selected by label
        with the pulses' ``sensor_id`` values. Defaults to the notebook-level
        ``geometry_table`` so existing callers keep working; pass it explicitly
        to avoid depending on hidden global state.

    Returns
    -------
    pd.DataFrame
        The pulses with ``x``/``y``/``z`` added (only if not already present),
        ``auxiliary`` cast to int64, and the parquet index moved into a column
        via ``reset_index``.

    Raises
    ------
    AssertionError
        If ``meta_batch`` does not contain exactly one ``batch_id``.
    """
    if geometry is None:
        geometry = geometry_table

    batch_ids = meta_batch['batch_id'].unique()
    assert len(batch_ids) == 1, (
        f"expected exactly one batch_id, found {len(batch_ids)}. "
        "Did you set the batch_size correctly?"
    )

    detector_readings = pd.read_parquet(path=f'{input_data_folder}/batch_{batch_ids[0]}.parquet')

    # Label-based lookup: assumes geometry's index equals sensor_id (e.g. a
    # default RangeIndex over rows ordered by sensor_id) - TODO confirm.
    sensor_positions = geometry.loc[detector_readings['sensor_id'], ['x', 'y', 'z']]
    # Give the positions the pulses' index so column assignment lines up row-for-row.
    sensor_positions.index = detector_readings.index

    for column in sensor_positions.columns:
        if column not in detector_readings.columns:
            detector_readings[column] = sensor_positions[column]

    # Explicit cast instead of replace({True: 1, False: 0}), whose implicit
    # downcasting is deprecated in pandas 2.x.
    detector_readings['auxiliary'] = detector_readings['auxiliary'].astype('int64')
    return detector_readings.reset_index()


In [4]:
def add_to_table(database_path: str,
                 df: pd.DataFrame,
                 table_name: str,
                 is_primary_key: bool,
                 engine: sqlalchemy.engine.base.Engine,
                 index_column: str = 'event_id',
                 chunksize: int = 200000) -> None:
    """Append ``df`` to ``table_name``, creating the table first if needed.

    Parameters
    ----------
    database_path : str
        Path to the SQLite file; forwarded to graphnet's ``create_table``.
    df : pd.DataFrame
        Rows to append; its columns are passed to ``create_table`` as the schema.
    table_name : str
        Target table name.
    is_primary_key : bool
        Forwarded to ``create_table`` as ``integer_primary_key``.
    engine : sqlalchemy.engine.base.Engine
        Engine used by ``DataFrame.to_sql``; disposed after the write.
    index_column : str, default 'event_id'
        Forwarded to ``create_table`` as ``index_column``.
    chunksize : int, default 200000
        Number of rows per ``to_sql`` write batch.

    Raises
    ------
    sqlite3.OperationalError
        From ``create_table`` for any failure other than the table already existing.
    """
    try:
        create_table(columns=df.columns,
                     database_path=database_path,
                     table_name=table_name,
                     integer_primary_key=is_primary_key,
                     index_column=index_column)
    except sqlite3.OperationalError as e:
        # Appending across batches hits "already exists" after the first call;
        # every other OperationalError is a genuine failure and must surface.
        if 'already exists' not in str(e):
            raise

    df.to_sql(table_name, con=engine, index=False, if_exists="append", chunksize=chunksize)
    # Drop pooled connections after each write (SQLAlchemy recreates the pool,
    # so the engine remains usable for the next call).
    engine.dispose()

In [5]:
def _first_pulses_per_event(pulses: pd.DataFrame, max_pulses: int) -> pd.DataFrame:
    """Keep at most ``max_pulses`` leading rows of each event, ordered by event_id.

    Yields the same rows, in the same order, as
    ``groupby('event_id', group_keys=True).apply(lambda x: x.iloc[:max_pulses])``
    (groups ascending, original order within each group), but vectorised. It also
    avoids apply's deprecated handling of the grouping column.
    """
    return (pulses.groupby('event_id').head(max_pulses)
                  .sort_values('event_id', kind='stable'))


def convert_to_sqlite(meta_data_path: str,
                      database_path: str,
                      input_data_folder: str,
                      batch_size: int = 200000,
                      batch_ids: Optional[List[int]] = None,
                      event_ids: Optional[list] = None,
                      engine: sqlalchemy.engine.base.Engine = None,
                      max_pulses: int = 500,
                      ) -> None:
    """Copy meta rows and pulses of selected events into a SQLite database.

    The meta parquet is streamed in chunks of ``batch_size`` rows, and chunk k
    (1-based) is treated as batch k. For each chunk whose number is in
    ``batch_ids``:

    - its meta rows with ``event_id`` in ``event_ids`` are appended to ``meta_table``;
    - their pulses, at most ``max_pulses`` per event, are appended to ``pulse_table``.

    Progress is printed as comma-separated batch numbers, with a line break
    every 110 batches.

    Parameters
    ----------
    meta_data_path : str
        Path of the meta parquet file.
    database_path : str
        Target SQLite file.
    input_data_folder : str
        Folder with the per-batch pulse parquet files (see ``load_input``).
    batch_size : int, default 200000
        Meta rows per chunk. Assumed to equal the number of events per batch
        file, so that each chunk maps onto one batch (``load_input`` asserts a
        single batch_id) - TODO confirm for other datasets.
    batch_ids : list of int, optional
        1-based batch numbers to convert. None means none (only the first chunk
        is read before stopping).
    event_ids : list-like, optional
        Event ids to keep. None means none.
    engine : sqlalchemy Engine, optional
        Engine used for writing. Created from ``database_path`` when omitted.
    max_pulses : int, default 500
        Maximum number of pulses stored per event.
    """
    # None instead of mutable [] defaults; behaviour for [] is unchanged.
    batch_ids = [] if batch_ids is None else batch_ids
    event_ids = [] if event_ids is None else event_ids
    if engine is None:
        engine = sqlalchemy.create_engine("sqlite:///" + database_path)

    wanted_batches = set(batch_ids)
    converted_batches = set()
    meta_data_iter = pq.ParquetFile(meta_data_path).iter_batches(batch_size=batch_size)
    for batch_id, record_batch in enumerate(meta_data_iter, start=1):
        print(batch_id, end=',')
        # Previously `batch_ids % 110`, which raised TypeError on the list.
        if batch_id % 110 == 0:
            print()
        if batch_id in wanted_batches:
            meta_data_batch = record_batch.to_pandas()
            meta_data_batch = meta_data_batch.loc[meta_data_batch['event_id'].isin(event_ids)]
            if meta_data_batch.shape[0] > 0:
                add_to_table(database_path=database_path,
                             df=meta_data_batch,
                             table_name='meta_table',
                             is_primary_key=True,
                             engine=engine)
                pulses = load_input(meta_batch=meta_data_batch, input_data_folder=input_data_folder)
                del meta_data_batch
                pulses = pulses.loc[pulses['event_id'].isin(event_ids)]
                pulses = _first_pulses_per_event(pulses, max_pulses)
                if pulses.shape[0] > 0:
                    add_to_table(database_path=database_path,
                                 df=pulses,
                                 table_name='pulse_table',
                                 is_primary_key=False,
                                 engine=engine)
                del pulses
            converted_batches.add(batch_id)
        # Stop reading the meta file once every requested batch has been handled.
        if converted_batches == wanted_batches:
            break
        # Pulse frames are large; reclaim memory between batches.
        gc.collect()

In [6]:
# Event-id selections, presumably keyed by focus group (the next cell reads 'f0').
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open('focus_dict.pkl', mode='rb') as focus_file:
    focus_dict = pickle.load(focus_file)

In [7]:
# Convert the 'f0' focus events from batches 1..660 into a single SQLite file.
event_id_list = focus_dict['f0']
database_path = './data/F0/focus_batch_0.db'
# SQLite will not create missing parent directories, so make sure they exist.
os.makedirs(os.path.dirname(database_path), exist_ok=True)
engine = sqlalchemy.create_engine("sqlite:///" + database_path)
convert_to_sqlite(meta_data_path,
                  database_path=database_path,
                  input_data_folder=input_data_folder,
                  batch_size=200000,
                  batch_ids=list(range(1, 661)),
                  event_ids=event_id_list,
                  engine=engine)
engine.dispose()

1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,