In [1]:
import gc
import os
import pickle
import sqlite3
import sys
from contextlib import closing
from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import sqlalchemy
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from graphnet.data.sqlite.sqlite_utilities import create_table

[1;34mgraphnet[0m: [32mINFO    [0m 2023-03-18 09:41:33 - get_logger - Writing log to [1mlogs/graphnet_20230318-094133.log[0m


In [2]:
# Locations of the raw data, relative to the notebook's working directory.
DATA_ROOT = './data'
input_data_folder = f'{DATA_ROOT}/train'
meta_data_path = f'{DATA_ROOT}/train_meta.parquet'
# Sensor coordinates; load_input looks them up by each row's sensor_id.
geometry_table = pd.read_csv(f'{DATA_ROOT}/sensor_geometry.csv')

In [3]:
def load_input(batch_id: int, input_data_folder: str, event_ids: Any) -> pd.DataFrame:
    """Load one batch parquet file, keep the requested events and attach sensor x/y/z.

    Args:
        batch_id: number of the file ``{input_data_folder}/batch_{batch_id}.parquet``.
        input_data_folder: folder holding the batch parquet files.
        event_ids: list-like of event ids to keep. They are matched against the
            parquet index (presumably named ``event_id`` -- callers group on that
            column after ``reset_index``; TODO confirm).

    Returns:
        The kept rows with the index moved into a column, three new columns
        ``x``, ``y``, ``z`` looked up in the module-level ``geometry_table`` via
        ``sensor_id``, and ``auxiliary`` converted to 0/1 ``int64``.
    """
    readings = pd.read_parquet(path=f'{input_data_folder}/batch_{batch_id}.parquet')
    # .copy() so the columns added below go into an independent frame rather
    # than a slice of the parquet data (no SettingWithCopyWarning).
    readings = readings.loc[readings.index.isin(event_ids)].copy()

    # Key the geometry by sensor_id when the CSV provides that column. Otherwise
    # fall back to the row labels, which is what the lookup used previously.
    geometry = geometry_table
    if 'sensor_id' in geometry.columns:
        geometry = geometry.set_index('sensor_id')
    # Assign by position (numpy values). The event-id index repeats once per
    # pulse, so relying on label alignment here would be fragile.
    positions = geometry.loc[readings['sensor_id'], ['x', 'y', 'z']].to_numpy()
    readings = readings.assign(x=positions[:, 0], y=positions[:, 1], z=positions[:, 2])

    # Plain column assignment. Series.replace(..., inplace=True) on a selected
    # column is chained assignment, which pandas deprecates and which does
    # nothing under copy-on-write.
    readings['auxiliary'] = readings['auxiliary'].astype('int64')

    return readings.reset_index()

In [4]:
def add_to_table(database_path: str,
                 df: pd.DataFrame,
                 table_name: str,
                 is_primary_key: bool,
                 engine: sqlalchemy.engine.base.Engine) -> None:
    """Append ``df`` to ``table_name``, creating the table first if necessary.

    The table is created with graphnet's ``create_table`` using ``event_id`` as
    the index column. An ``sqlite3.OperationalError`` whose message contains
    "already exists" is ignored, so every batch after the first simply appends.
    Any other ``OperationalError`` propagates. The rows are then written with
    ``DataFrame.to_sql`` through ``engine`` in chunks of 200000, after which
    ``engine.dispose()`` is called.

    Args:
        database_path: SQLite file passed to ``create_table``.
        df: rows to append; its columns define the table schema on creation.
        table_name: target table.
        is_primary_key: forwarded to ``create_table`` as ``integer_primary_key``.
        engine: SQLAlchemy engine used for the insert. Presumably it points at
            ``database_path``; this is not checked here.
    """
    schema_args = {
        'columns': df.columns,
        'database_path': database_path,
        'table_name': table_name,
        'integer_primary_key': is_primary_key,
        'index_column': 'event_id',
    }
    try:
        create_table(**schema_args)
    except sqlite3.OperationalError as err:
        if 'already exists' not in str(err):
            raise

    df.to_sql(table_name, con=engine, index=False, if_exists="append", chunksize=200000)
    engine.dispose()

In [7]:
# focus_dict holds one collection of event ids per focus group 'f0'..'f3'.
# Together the four groups must account for every event.
# NOTE(review): pickle.load can execute arbitrary code; only load a
# focus_dict.pkl produced by this project.
EXPECTED_TOTAL_EVENTS = 131953924

with open('focus_dict.pkl', 'rb') as f:
    focus_dict = pickle.load(f)

n_focus_events = sum(len(focus_dict[f'f{group}']) for group in range(4))
assert n_focus_events == EXPECTED_TOTAL_EVENTS, "focus_dict does not contain all events"

In [10]:
# Pick focus group `idx`: its event ids and the SQLite file they are written to.
idx = 0
focus_key = f'f{idx}'
event_ids = focus_dict[focus_key]
database_path = f'./data/F{idx}/focus_batch_{idx}.db'
engine_url = f"sqlite:///{database_path}"
engine = sqlalchemy.create_engine(engine_url)

In [11]:
# Dry run of the conversion on the first metadata batch only, to check the
# pipeline before the full convert_to_sqlite run.
# NOTE(review): not idempotent. Running this cell again appends the same rows;
# the meta_table insert (created with integer_primary_key=True) will
# presumably then fail on duplicate event ids.
meta_data_iter = pq.ParquetFile(meta_data_path).iter_batches(batch_size=200000)
batch_id = 1
print(batch_id, end=',')

meta_data_batch = (
    next(meta_data_iter)
    .to_pandas()
    .drop(columns=['first_pulse_index', 'last_pulse_index'])
)
meta_data_batch = meta_data_batch.loc[meta_data_batch['event_id'].isin(event_ids)].reset_index(drop=True)

if meta_data_batch.shape[0] > 0:
    # Filter the pulse file with this batch's selected ids rather than the whole
    # focus list. This is much cheaper, and the selection is the same assuming
    # batch_1.parquet only holds events of this metadata batch (TODO confirm).
    pulses = load_input(batch_id=batch_id,
                        input_data_folder=input_data_folder,
                        event_ids=meta_data_batch['event_id'])
    pulses = pulses.groupby('event_id').head(500).reset_index(drop=True)

    add_to_table(database_path=database_path,
                 df=meta_data_batch,
                 table_name='meta_table',
                 is_primary_key=True,
                 engine=engine)
    add_to_table(database_path=database_path,
                 df=pulses,
                 table_name='pulse_table',
                 is_primary_key=False,
                 engine=engine)


1,

In [15]:
# Inspect the filtered metadata of the dry-run batch (event ids plus the
# azimuth/zenith columns that were kept).
meta_data_batch

Unnamed: 0,batch_id,event_id,azimuth,zenith
0,1,79,3.533397,2.479947
1,1,140,4.486290,1.655948
2,1,406,6.261226,0.910476
3,1,448,4.161056,1.427407
4,1,663,0.240543,2.739548
...,...,...,...,...
46474,1,3266035,5.798148,2.174658
46475,1,3266043,4.897595,1.746368
46476,1,3266078,5.004502,1.893823
46477,1,3266175,5.091808,2.732550


In [25]:
# Read the ids from the integer 'event_id' column itself. `.iloc[0].event_id`
# goes through a row Series, which pandas upcasts to float64 because the frame
# also has float columns (azimuth, zenith), so it yields 79.0 instead of 79.
first_event_id_meta = int(meta_data_batch['event_id'].iloc[0])
last_event_id_meta = int(meta_data_batch['event_id'].iloc[-1])

In [26]:
# Round-trip check: the first and last event of the dry-run batch must be in
# meta_table. sqlite3's own context manager only commits or rolls back and
# never closes the connection, so wrap it in closing(). The ids are bound as
# parameters rather than formatted into the SQL string.
with closing(sqlite3.connect(database_path)) as con:
    query = 'select event_id from meta_table where event_id in (?, ?)'
    events_df = pd.read_sql(query, con,
                            params=(int(first_event_id_meta), int(last_event_id_meta)))

In [27]:
# Both queried ids should come back, so the count is expected to be 2.
len(events_df)

2

In [29]:
# Round-trip check for pulse_table. It has one row per stored pulse, so each
# id comes back many times. Take the ids from the integer 'event_id' column:
# a row via `.iloc[0]` is upcast to float64 by the float pulse columns.
first_event_id_pulses = int(pulses['event_id'].iloc[0])
last_event_id_pulses = int(pulses['event_id'].iloc[-1])

# closing() releases the connection, which sqlite3's own context manager does
# not do; the ids are bound as query parameters.
with closing(sqlite3.connect(database_path)) as con:
    query = 'select event_id from pulse_table where event_id in (?, ?)'
    events_df = pd.read_sql(query, con,
                            params=(first_event_id_pulses, last_event_id_pulses))

events_df.shape[0]

Unnamed: 0,event_id
0,79
1,79
2,79
3,79
4,79
...,...
447,3266196
448,3266196
449,3266196
450,3266196


In [12]:
def _convert_meta_record_batch(record_batch,
                               batch_id: int,
                               database_path: str,
                               input_data_folder: str,
                               wanted_events: np.ndarray,
                               engine,
                               max_pulses_per_event: int) -> None:
    """Filter one metadata record batch to ``wanted_events`` and append it and its pulses to SQLite."""
    # Always start from the untouched record batch so a retry sees the same input.
    meta = record_batch.to_pandas().drop(columns=['first_pulse_index', 'last_pulse_index'])
    meta = meta.loc[meta['event_id'].isin(wanted_events)].reset_index(drop=True)
    if meta.shape[0] == 0:
        return

    # Filter the pulse file with this batch's selected ids only. This avoids
    # re-hashing the whole selection per batch and keeps pulse_table limited to
    # events written to meta_table.
    pulses = load_input(batch_id=batch_id,
                        input_data_folder=input_data_folder,
                        event_ids=meta['event_id'])
    pulses = pulses.groupby('event_id').head(max_pulses_per_event).reset_index(drop=True)

    add_to_table(database_path=database_path,
                 df=meta,
                 table_name='meta_table',
                 is_primary_key=True,
                 engine=engine)
    add_to_table(database_path=database_path,
                 df=pulses,
                 table_name='pulse_table',
                 is_primary_key=False,
                 engine=engine)


def convert_to_sqlite(meta_data_path: str,
                      database_path: str,
                      input_data_folder: str,
                      batch_size: int = 200000,
                      batch_ids: Optional[list] = None,
                      event_ids: Optional[list] = None,
                      engine: Optional[sqlalchemy.engine.base.Engine] = None,
                      max_pulses_per_event: int = 500,
                      max_retries: int = 3,
                      ) -> None:
    """Copy selected events from the parquet dataset into a SQLite database.

    The metadata parquet file is streamed in chunks of ``batch_size`` rows. The
    n-th chunk (1-based) is paired with ``{input_data_folder}/batch_{n}.parquet``.
    This assumes each pulse file holds exactly the events of the matching
    metadata chunk (TODO confirm if ``batch_size`` is changed from the default).

    For every chunk number in ``batch_ids``, the rows whose ``event_id`` is in
    ``event_ids`` are appended to ``meta_table``. At most
    ``max_pulses_per_event`` pulses of each such event are appended to
    ``pulse_table``.

    Args:
        meta_data_path: metadata parquet file.
        database_path: SQLite file the tables are created in.
        input_data_folder: folder holding ``batch_{n}.parquet`` pulse files.
        batch_size: rows per metadata chunk.
        batch_ids: chunk numbers to convert; ``None`` converts nothing.
        event_ids: event ids to keep; ``None`` keeps nothing.
        engine: SQLAlchemy engine used for the inserts.
        max_pulses_per_event: pulses kept per event (first ones in file order).
        max_retries: attempts per chunk before the error is re-raised.
            NOTE: if a failure happens after a table write committed, the retry
            appends those rows again (or hits the meta_table primary key).
    """
    wanted_batches = set(batch_ids) if batch_ids is not None else set()
    # Convert once. isin() re-converts a plain Python list on every call, which
    # is costly for tens of millions of ids.
    wanted_events = np.asarray(list(event_ids) if event_ids is not None else [])
    attempts = max(1, max_retries)

    meta_data_iter = pq.ParquetFile(meta_data_path).iter_batches(batch_size=batch_size)
    for batch_id, record_batch in enumerate(meta_data_iter, start=1):
        print(batch_id, end=',')
        if batch_id not in wanted_batches:
            continue

        for attempt in range(1, attempts + 1):
            try:
                _convert_meta_record_batch(record_batch, batch_id, database_path,
                                           input_data_folder, wanted_events, engine,
                                           max_pulses_per_event)
                break
            except Exception as e:
                # Bounded retry. Previously this looped forever on any
                # persistent error.
                print(f'\nbatch {batch_id}: attempt {attempt}/{attempts} failed: {e!r}')
                if attempt == attempts:
                    raise
        gc.collect()

In [10]:
# Reload the focus groups so the conversion section below can be run on its own
# (this repeats the earlier load of the same file).
# NOTE(review): pickle.load can execute arbitrary code; trusted files only.
with open('focus_dict.pkl', 'rb') as pickle_file:
    focus_dict = pickle.load(pickle_file)

focus_group_sizes = map(len, (focus_dict[key] for key in ('f0', 'f1', 'f2', 'f3')))
assert sum(focus_group_sizes) == 131953924, "focus_dict does not contain all events"

In [13]:
# Full conversion of focus group `idx` over all 650 metadata batches.
idx = 0
n_batches = 650
event_id_list = focus_dict[f'f{idx}']
database_path = f'./data/F{idx}/focus_batch_{idx}.db'
engine = sqlalchemy.create_engine(f"sqlite:///{database_path}")

convert_to_sqlite(
    meta_data_path,
    database_path=database_path,
    input_data_folder=input_data_folder,
    batch_size=200000,
    batch_ids=list(range(1, n_batches + 1)),
    event_ids=event_id_list,
    engine=engine,
)


1,2,3,4,5,

KeyboardInterrupt: 

In [None]:
# Split the stored events into train (99%) / validate (1%) and save both id lists.
# closing() actually releases the connection; sqlite3's own context manager
# does not. ORDER BY pins the row order, and with it the seeded split, instead
# of relying on SQLite's unspecified scan order.
with closing(sqlite3.connect(database_path)) as con:
    query = 'select event_id from meta_table order by event_id'
    events_df = pd.read_sql(query, con)

train_selection, validate_selection = train_test_split(np.arange(0, events_df.shape[0], 1),
                                                        shuffle=True,
                                                        random_state=42,
                                                        test_size=0.01)

# The split returns row positions. Sorting them keeps each list in table order,
# and indexing the id array directly avoids an isin() pass per split.
all_event_ids = events_df['event_id'].to_numpy()
train_selection_events = all_event_ids[np.sort(train_selection)].tolist()
validate_selection_events = all_event_ids[np.sort(validate_selection)].tolist()
event_dict = {'train': train_selection_events, 'validate': validate_selection_events}
with open(f'data/F{idx}/event_dict.pkl', 'wb') as f:
    pickle.dump(event_dict, f)