In [1]:
import os, sys
from graphnet.data.sqlite.sqlite_utilities import create_table
import pandas as pd
import sqlite3
import pyarrow.parquet as pq
import sqlalchemy
from tqdm import tqdm
from typing import Any, Dict, List, Optional
import numpy as np
import pickle
import gc

[1;34mgraphnet[0m: [32mINFO    [0m 2023-02-25 06:59:22 - get_logger - Writing log to [1mlogs/graphnet_20230225-065922.log[0m


In [2]:
# Locations of the raw training inputs (pulse batch files and the meta table).
input_data_folder, meta_data_path = './data/train', './data/train_meta.parquet'
# Sensor geometry; load_input looks rows up by sensor_id label and reads x/y/z.
geometry_table = pd.read_csv('./data/sensor_geometry.csv')

In [3]:
def load_input(meta_batch: pd.DataFrame,
               input_data_folder: str,
               geometry: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Load the pulses of one training batch and attach sensor x/y/z positions.

    Args:
        meta_batch: Slice of the meta table. Only its ``batch_id`` column is
            used, and it must contain exactly one distinct value.
        input_data_folder: Folder holding the ``batch_<id>.parquet`` files.
        geometry: Sensor geometry table with ``x``, ``y`` and ``z`` columns,
            looked up by label with each pulse's ``sensor_id`` (so its index
            must be the sensor id). Defaults to the notebook-level
            ``geometry_table`` for backward compatibility.

    Returns:
        The pulse table with ``x``, ``y`` and ``z`` columns added (only where
        not already present), ``auxiliary`` as 0/1 int64, and the original
        index moved into a column by ``reset_index``.

    Raises:
        AssertionError: if ``meta_batch`` holds zero or several batch ids.
    """
    if geometry is None:
        # Fall back to the global loaded in the config cell (original behaviour).
        geometry = geometry_table

    batch_id = meta_batch['batch_id'].unique()
    assert len(batch_id) == 1, "contains multiple batch_ids. Did you set the batch_size correctly?"

    detector_readings = pd.read_parquet(path=f'{input_data_folder}/batch_{batch_id[0]}.parquet')

    # One geometry row per pulse, in pulse order (label lookup on sensor_id).
    sensor_positions = geometry.loc[detector_readings['sensor_id'], ['x', 'y', 'z']]
    for column in sensor_positions.columns:
        if column not in detector_readings.columns:
            # Assign positionally: the pulse index can repeat, so relying on
            # index alignment is fragile.
            detector_readings[column] = sensor_positions[column].to_numpy()

    # Explicit cast gives the same 0/1 values as replace({True: 1, False: 0})
    # without depending on replace's deprecated dtype downcasting.
    detector_readings['auxiliary'] = detector_readings['auxiliary'].astype('int64')
    return detector_readings.reset_index()


In [4]:
def add_to_table(database_path: str,
                 df: pd.DataFrame,
                 table_name: str,
                 is_primary_key: bool,
                 engine: sqlalchemy.engine.base.Engine,
                 index_column: str = 'event_id',
                 chunksize: int = 200000) -> None:
    """Append ``df`` to ``table_name``, creating the table first if needed.

    Args:
        database_path: SQLite file path, forwarded to graphnet's ``create_table``.
        df: Rows to append; its columns are used as the schema when the table
            is created.
        table_name: Destination table.
        is_primary_key: Forwarded to ``create_table`` as ``integer_primary_key``.
        engine: SQLAlchemy engine used by ``DataFrame.to_sql``. It is disposed
            after the write.
        index_column: Forwarded to ``create_table`` as ``index_column``
            (default ``'event_id'``, the previously hard-coded value).
        chunksize: Rows per insert batch for ``DataFrame.to_sql``
            (default 200000, the previously hard-coded value).

    Raises:
        sqlite3.OperationalError: re-raised from ``create_table`` unless its
            message says the table already exists.
    """
    try:
        create_table(columns=df.columns,
                     database_path=database_path,
                     table_name=table_name,
                     integer_primary_key=is_primary_key,
                     index_column=index_column)
    except sqlite3.OperationalError as e:
        # Every call after the first targets an existing table; only that
        # specific failure is tolerated. Bare ``raise`` keeps the traceback.
        if 'already exists' not in str(e):
            raise

    df.to_sql(table_name, con=engine, index=False, if_exists="append", chunksize=chunksize)
    # Dispose after each write, as before, so pooled connections are released.
    engine.dispose()

In [5]:
def convert_to_sqlite(meta_data_path: str,
                      database_path: str,
                      input_data_folder: str,
                      batch_size: int = 200000,
                      batch_ids: Optional[list] = None,
                      engine: Optional[sqlalchemy.engine.base.Engine] = None
                      ) -> None:
    """Write the meta rows and pulses of selected training batches into one SQLite file.

    The meta parquet is streamed in chunks of ``batch_size`` rows, and chunk
    ``i`` (counting from 1) is treated as batch ``i``. That only holds when
    ``batch_size`` equals the number of events per batch file, so each
    selected chunk's ``batch_id`` column is checked against ``i`` before
    anything is written.

    Args:
        meta_data_path: Path of the meta parquet file.
        database_path: Destination SQLite file.
        input_data_folder: Folder holding the ``batch_<id>.parquet`` pulse files.
        batch_size: Rows per meta chunk.
        batch_ids: Batch ids to convert (list or numpy array). ``None`` or
            empty converts nothing.
        engine: Engine to write through. If ``None``, one is created for
            ``database_path``.

    Raises:
        ValueError: if a selected chunk does not consist of exactly the
            expected batch id (i.e. ``batch_size`` is misaligned).
    """
    if engine is None:
        engine = sqlalchemy.create_engine("sqlite:///" + database_path)

    # Set for O(1) lookups; int() normalises numpy integers from the pickled split.
    remaining = {int(b) for b in batch_ids} if batch_ids is not None else set()

    if remaining:
        meta_data_iter = pq.ParquetFile(meta_data_path).iter_batches(batch_size=batch_size)
        for batch_id, record_batch in enumerate(meta_data_iter, start=1):
            if batch_id not in remaining:
                continue

            meta_data_batch = record_batch.to_pandas()
            found_ids = meta_data_batch['batch_id'].unique()
            if len(found_ids) != 1 or found_ids[0] != batch_id:
                # Fail before writing anything rather than storing the wrong batch.
                raise ValueError(
                    f'Meta chunk {batch_id} contains batch_id(s) {list(found_ids)}; '
                    f'batch_size={batch_size} does not align with the batch files.')

            add_to_table(database_path=database_path,
                         df=meta_data_batch,
                         table_name='meta_table',
                         is_primary_key=True,
                         engine=engine)
            pulses = load_input(meta_batch=meta_data_batch, input_data_folder=input_data_folder)
            del meta_data_batch
            add_to_table(database_path=database_path,
                         df=pulses,
                         table_name='pulse_table',
                         is_primary_key=False,
                         engine=engine)
            del pulses
            # Batches are large; free them before reading the next chunk.
            gc.collect()

            remaining.discard(batch_id)
            if not remaining:
                # Stop without reading further meta chunks.
                break

    print(f'Conversion Complete! Database available at\n {database_path}')

In [6]:
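# Provenance of big_batch_indx.pkl (commented out, presumably so that re-running
# the notebook does not overwrite the saved split). It splits the 660 train
# batch ids (1-660) into 10 disjoint groups of 66 ids each.
# NOTE: no random seed was set, so re-running it would produce a different split.
# list_dict = {}
# list_train_ids = range(1,661)
# for batch_number in range(0,10):
#     list_dict[batch_number] = np.random.choice(list_train_ids, 66, replace=False)
#     list_train_ids = [x for x in list_train_ids if x not in list_dict[batch_number]]
#     print(f'Batch {batch_number} contains {len(list_dict[batch_number])} train batch ids')


# with open('big_batch_indx.pkl', 'wb') as f:
#     pickle.dump(list_dict, f)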


In [7]:
# Load the saved split: big-batch number -> array of train batch ids.
# Unpickling can execute arbitrary code, so only open a file you produced yourself.
with open('big_batch_indx.pkl', 'rb') as split_file:
    list_dict = pickle.load(split_file)

In [8]:
# Inspect the loaded split: each key is a big-batch number, each value the
# array of train batch ids assigned to it (see the output below).
list_dict

{0: array([496, 303, 133,  61,   6, 547,  47,  20, 245, 464,  15, 148, 414,
        230, 648,  84, 313, 258,  45, 165, 598, 277, 359,  33, 294,  60,
        579,  72, 495, 617,  76, 139, 224, 343, 432, 349, 296, 244, 621,
        260, 606,  74, 613, 342, 442, 158, 363, 307, 611, 480, 169, 346,
         58, 115, 316,  73, 512, 118, 459, 433,  68, 630, 529, 614, 178,
        491]),
 1: array([ 83, 634, 341, 107, 543, 386, 513, 304, 179, 559, 589, 632, 636,
        257, 151, 518,  90, 526, 227, 544, 379, 527, 242, 620,  17, 655,
          2, 571, 328, 300, 546, 114, 237, 213, 440, 149, 403, 504, 436,
        602,  77, 633, 574, 508,  29, 472, 154,  57, 159, 448, 385, 184,
        171,  54, 401, 576, 295, 475, 657, 649,  75,  70, 145, 284, 365,
        625]),
 2: array([ 59, 506, 201, 212, 319, 132, 484, 357, 468, 534, 573, 441, 569,
        396,  82, 545, 536, 537,  69, 298, 280, 530, 194, 120, 437, 270,
          3, 420, 350, 137,  53, 498, 105,  78, 645, 155, 556, 608, 364,
        218,

In [9]:
# Keys of list_dict (big-batch numbers) to convert in this run.
BATCH_NUMBERS_TO_CONVERT = [5]

for batch_number in tqdm(BATCH_NUMBERS_TO_CONVERT):
    database_path = f'./data/big_batch_{batch_number}.db'
    if os.path.exists(database_path):
        # One conversion took over two hours (see output), and re-running would
        # append meta rows to the existing tables (likely colliding with the
        # event_id primary key). Delete the file to rebuild it.
        print(f'Skipping {database_path}: file already exists')
        continue
    engine = sqlalchemy.create_engine("sqlite:///" + database_path)
    try:
        convert_to_sqlite(meta_data_path,
                          database_path=database_path,
                          input_data_folder=input_data_folder,
                          batch_size=200000,
                          batch_ids=list_dict[batch_number],
                          engine=engine)
    finally:
        # Release the SQLite connection pool even if the conversion fails.
        engine.dispose()

100%|██████████| 1/1 [2:14:32<00:00, 8072.95s/it]

Conversion Complete! Database available at
 ./data/big_batch_5.db



