In [1]:
from graphnet.data.sqlite.sqlite_utilities import create_table
import pandas as pd
import sqlite3
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split

[1;34mgraphnet[0m: [32mINFO    [0m 2023-02-25 11:53:43 - get_logger - Writing log to [1mlogs/graphnet_20230225-115343.log[0m


In [2]:
def make_selection(save_path, df: pd.DataFrame, pulse_threshold: int = 200) -> None:
    """Randomly split events into train/validate selections and pickle each one.

    Events are split 99% / 1% with ``train_test_split`` (fixed
    ``random_state=42``, so the split is reproducible). Events with more than
    ``pulse_threshold`` pulses are then removed from both selections. Each
    selection is written to
    ``<save_path>/<selection>_selection_max_<pulse_threshold>_pulses.pkl``.

    Args:
        save_path: Directory the selection pickles are written to.
        df: Per-event table; must contain an ``n_pulses`` column. It is
            modified in place: integer ``train`` and ``validate`` flag
            columns (0/1) are added or overwritten.
        pulse_threshold: Events with ``n_pulses`` strictly greater than this
            are excluded from both selections.
    """
    import os

    n_events = len(df)
    # Split *positions* (0..n-1), not index labels.
    train_selection, validate_selection = train_test_split(
        np.arange(n_events),
        shuffle=True,
        random_state=42,
        test_size=0.01,
    )

    # Build the flags positionally in NumPy and assign whole columns at once.
    # Chained assignment such as ``df['train'][idx] = 1`` is unreliable: it is
    # a silent no-op under pandas Copy-on-Write. Also, ``[]`` on a Series is
    # label-based, which misfires when df's index is not 0..n-1.
    train_flags = np.zeros(n_events, dtype=np.int64)
    validate_flags = np.zeros(n_events, dtype=np.int64)
    train_flags[train_selection] = 1
    validate_flags[validate_selection] = 1
    df['train'] = train_flags
    df['validate'] = validate_flags

    # Sanity check: every split position was flagged exactly once.
    assert len(train_selection) == df['train'].sum()
    assert len(validate_selection) == df['validate'].sum()

    # Drop events with too many pulses from both selections.
    df.loc[df['n_pulses'] > pulse_threshold, ['train', 'validate']] = 0

    for selection in ['train', 'validate']:
        out_file = os.path.join(
            save_path, f'{selection}_selection_max_{pulse_threshold}_pulses.pkl'
        )
        df.loc[df[selection] == 1, :].to_pickle(out_file)

def get_number_of_pulses(db: str, event_id: int, pulsemap: str, limit: int = 20000) -> int:
    """Count the pulses (rows) belonging to one event in a pulse table.

    Args:
        db: Path to the SQLite database file.
        event_id: Event to count pulses for. NumPy integers are accepted; the
            value is converted with ``int()`` because sqlite3 cannot bind
            ``numpy.int64``.
        pulsemap: Name of the pulse table to query.
        limit: Upper bound on the returned count. The default of 20000 matches
            the original query's ``limit 20000``. A negative value removes the
            cap, following SQLite's ``LIMIT`` semantics.

    Returns:
        Number of rows in ``pulsemap`` whose ``event_id`` matches, capped at
        ``limit``.
    """
    from contextlib import closing

    # Identifiers cannot be bound as parameters, so quote the table name
    # (doubling any embedded quotes). Values are bound with placeholders.
    table = '"' + pulsemap.replace('"', '""') + '"'
    # Count inside SQLite rather than fetching rows into Python; the inner
    # LIMIT preserves the original cap on the result.
    query = f'SELECT COUNT(*) FROM (SELECT 1 FROM {table} WHERE event_id = ? LIMIT ?)'
    # ``with sqlite3.connect(...)`` only commits/rolls back; ``closing``
    # actually releases the connection.
    with closing(sqlite3.connect(db)) as con:
        (n_pulses,) = con.execute(query, (int(event_id), int(limit))).fetchone()
    return n_pulses

def count_pulses(database: str, save_path, pulsemap: str, limit: int = 20000) -> pd.DataFrame:
    """Count the pulses in every event listed in ``meta_table``.

    Reads all ``event_id`` values from ``meta_table`` and counts, for each,
    the matching rows in ``pulsemap``. One connection is reused for all
    events, instead of opening a new one per event.

    Args:
        database: Path to the SQLite database file.
        save_path: Directory where ``counts.pkl`` is written.
        pulsemap: Name of the pulse table to count rows in.
        limit: Per-event cap on the count. The default of 20000 matches the
            original per-event query. A negative value removes the cap, per
            SQLite ``LIMIT`` semantics.

    Returns:
        DataFrame with columns ``event_id`` and ``n_pulses``, one row per event
        in ``meta_table`` order. The same frame is pickled to
        ``<save_path>/counts.pkl``.
    """
    import os
    from contextlib import closing

    # Quote the table identifier (identifiers cannot be bound as parameters).
    table = '"' + pulsemap.replace('"', '""') + '"'
    count_query = f'SELECT COUNT(*) FROM (SELECT 1 FROM {table} WHERE event_id = ? LIMIT ?)'

    # ``closing`` guarantees the connection is released; the sqlite3
    # connection context manager alone does not close it.
    with closing(sqlite3.connect(database)) as con:
        events = pd.read_sql('select event_id from meta_table', con)
        # int(): sqlite3 cannot bind numpy.int64 values coming from pandas.
        n_pulses = [
            con.execute(count_query, (int(event_id), int(limit))).fetchone()[0]
            for event_id in tqdm(events['event_id'])
        ]

    df = pd.DataFrame({'event_id': events['event_id'].to_numpy(),
                       'n_pulses': n_pulses})
    df.to_pickle(os.path.join(save_path, 'counts.pkl'))
    return df


In [3]:
# Inputs: the SQLite database, the pulse table to count rows in, and the
# directory that receives counts.pkl and the selection pickles.
database = './data/big_batch_5.db'
pulsemap = 'pulse_table'
save_path = './data/'

# Counting pulses per event is the slow step (about 22 minutes for the
# 13.2M events in this database, per the recorded output below).
df = count_pulses(database=database, save_path=save_path, pulsemap=pulsemap)
make_selection(save_path, df=df, pulse_threshold=200)

100%|██████████| 13200000/13200000 [21:54<00:00, 10041.66it/s]
