In [69]:
import numpy as np
import os, sqlite3, pickle, sys, gzip, shutil
# Choose which tqdm progress bar to import: the `tqdm.notebook` variant when
# the builtins namespace carries an `__IPYTHON__` attribute, the plain `tqdm`
# one otherwise. The print records which branch was taken.
if hasattr(__builtins__,'__IPYTHON__'):
    print('Notebook')
    from tqdm.notebook import tqdm
else:
    print('Not notebook')
    from tqdm import tqdm
import os.path as osp

from pandas import read_sql, read_pickle, concat, read_csv, DataFrame
from sklearn.preprocessing import normalize, RobustScaler
from sklearn.neighbors import kneighbors_graph as knn
import matplotlib.pyplot as plt

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
from spektral.data import Dataset, Graph
from scipy.sparse import csr_matrix

# Number of chunks the truth event numbers are split into (np.array_split).
n_steps=50
# Columns selected from the `features` table.
features=["dom_x", "dom_y", "dom_z", "dom_time", "charge_log10", "width", "rqe"]
# Columns selected from the `truth` table; event_no is kept as the last entry.
targets= ["energy_log10", "zenith","azimuth", "event_no"]
# Pickle holding the fitted transformers, read under the keys 'features' and 'truth'.
transform_path='../db_files/dev_lvl7/transformers.pkl'
# SQLite database file containing the `features` and `truth` tables.
db_path= '../db_files/dev_lvl7/dev_lvl7_mu_nu_e_classification_v003.db'
# Pickle holding the 'train' and 'test' sets, each with an 'event_no' column.
set_path='../db_files/dev_lvl7/sets.pkl'
# Neighbour count for the kNN adjacency graph built per event.
n_neighbors = 30
# NOTE(review): traintest, i_train and i_test are not referenced in the
# visible cells - presumably leftovers from a Dataset class; confirm before removing.
traintest='train'
i_train=0
i_test=0

def get_event_no(path=None):
    """Load the train and test event numbers from the pickled split file.

    Parameters
    ----------
    path : str or path-like, optional
        Pickle file holding a mapping with 'train' and 'test' entries, each
        providing an 'event_no' column. When omitted, the module-level
        ``set_path`` is used, which is the original behaviour.

    Returns
    -------
    tuple of numpy.ndarray
        ``(train_event_numbers, test_event_numbers)``.
    """
    if path is None:
        path = set_path
    print('Reading sets')
    # NOTE: unpickling can execute arbitrary code - only read trusted files.
    sets = read_pickle(path)
    return tuple(sets[name]['event_no'].to_numpy() for name in ('train', 'test'))

Notebook


In [70]:
db_file   = db_path

# Only the first `max_chunks` of the `n_steps` event chunks are turned into graphs.
max_chunks = 10

tls=[]        # one array of test graphs per processed chunk
tals=[]       # one array of train graphs per processed chunk
mix_list=[]   # every 10th test graph of each processed chunk

print("Connecting to db-file")
# sqlite3's connection context manager only wraps a transaction and leaves the
# connection open, so the connection is closed explicitly in `finally`.
conn = sqlite3.connect(db_file)
try:
    # SQL queries format
    feature_call = ", ".join(features)
    target_call  = ", ".join(targets)

    # Load data from db-file
    print("Reading files")
    df_truth = read_sql("select event_no from truth", conn)
    # Cut the event numbers into n_steps chunks; each chunk is fetched later by
    # its first and last event number (both inclusive).
    splits    = np.array_split(df_truth['event_no'].to_numpy(), n_steps)
    start_ids = [chunk[0] for chunk in splits]
    stop_ids  = [chunk[-1] for chunk in splits]

    train_events, test_events = get_event_no()
    df_test  = df_truth[df_truth['event_no'].isin(test_events)]
    df_train = df_truth[df_truth['event_no'].isin(train_events)]
    # Per-chunk boolean masks marking which events belong to the test / train set.
    mask_test  = [np.isin(chunk, test_events) for chunk in splits]
    mask_train = [np.isin(chunk, train_events) for chunk in splits]

    print('Saving test/train IDs')
    # NOTE(review): no IDs are actually written to disk in this cell.

    # The transformers do not depend on the chunk: load them once and close the
    # file handle instead of re-opening the pickle on every iteration.
    # NOTE: unpickling can execute arbitrary code - only use a trusted file.
    with open(transform_path, 'rb') as transform_file:
        transformers = pickle.load(transform_file)
    trans_x      = transformers['features']
    trans_y      = transformers['truth']

    print('Starting loop')
    print(start_ids, stop_ids)
    chunk_bounds = zip(start_ids[:max_chunks], stop_ids[:max_chunks])
    for i, (start_id, stop_id) in enumerate(chunk_bounds):
        event_window = f"event_no >= {start_id} and event_no <= {stop_id}"
        df_event = read_sql(f"select event_no from features where {event_window}", conn)
        print('Events read')
        df_feat  = read_sql(f"select {feature_call} from features where {event_window}", conn)
        print('Features read')
        df_targ  = read_sql(f"select {target_call} from truth where {event_window}", conn)
        print('Targets read, transforming')

        # Undo the stored scaling; the positions are additionally divided by 1000.
        # NOTE(review): each column is passed as a single row of shape (1, n);
        # this presumably relies on the scaler broadcasting - confirm against
        # the transformer type stored in the pickle.
        for col in ["dom_x", "dom_y", "dom_z"]:
            df_feat[col] = trans_x[col].inverse_transform(np.array(df_feat[col]).reshape(1, -1)).T/1000

        for col in ["energy_log10", "zenith","azimuth"]:
            df_targ[col] = trans_y[col].inverse_transform(np.array(df_targ[col]).reshape(1, -1)).T

        # Cut indices
        print("Splitting data to events")
        idx_list    = np.array(df_event)
        x_not_split = np.array(df_feat)

        # Rows per event -> split points between consecutive events.
        # NOTE(review): assumes the feature rows come back grouped by ascending
        # event_no, in the same order as the truth rows - confirm.
        _, counts = np.unique(idx_list.flatten(), return_counts = True)
        xs          = np.split(x_not_split, np.cumsum(counts)[:-1])

        ys          = np.array(df_targ)
        if len(xs) != len(ys):
            # zip() below would silently truncate and pair features with the
            # wrong targets, so fail loudly instead.
            raise ValueError(
                f"chunk {i}: {len(xs)} feature groups but {len(ys)} truth rows"
            )
        print(df_feat.head())
        print(df_targ.head())

        graph_list=[]
        # Generate adjacency matrices
        for x, y in tqdm(zip(xs, ys), total = len(xs)):
            try:
                # Use the module-level n_neighbors: there is no `self` at
                # notebook level, so the old `self.n_neighbors` always raised
                # and the kNN graph was never built.
                a = knn(x[:, :3], n_neighbors)
            except ValueError:
                # Too few rows for a kNN graph: fall back to a fully connected
                # graph without self-loops.
                a = csr_matrix(np.ones(shape = (x.shape[0], x.shape[0])) - np.eye(x.shape[0]))
            graph_list.append(Graph(x = x, a = a, y = y))
        print('List->array')
        graph_list = np.array(graph_list, dtype = object)
        test_list = graph_list[mask_test[i]]
        tls.append(test_list)
        train_list = graph_list[mask_train[i]]
        tals.append(train_list)
        mix_list.append(test_list[::10])
finally:
    conn.close()

Connecting to db-file
Reading files
Reading sets
Saving test/train IDs
Starting loop
[0, 165837, 331674, 497511, 663348, 829184, 995020, 1160856, 1326692, 1492528, 3049864, 3215700, 3381536, 3547372, 3713208, 3879044, 4044880, 4210716, 4376552, 4542388, 6101874, 6267710, 6433546, 6599382, 6765218, 6931054, 7096890, 7262726, 7428562, 7594398, 9113884, 9279720, 9445556, 9611392, 9777228, 9943064, 10108900, 10274736, 10440572, 10606408, 10772244, 10938080, 12022539, 12188375, 12354211, 12520047, 12685883, 12851719, 13017555, 13183391] [165836, 331673, 497510, 663347, 829183, 995019, 1160855, 1326691, 1492527, 3049863, 3215699, 3381535, 3547371, 3713207, 3879043, 4044879, 4210715, 4376551, 4542387, 6101873, 6267709, 6433545, 6599381, 6765217, 6931053, 7096889, 7262725, 7428561, 7594397, 9113883, 9279719, 9445555, 9611391, 9777227, 9943063, 10108899, 10274735, 10440571, 10606407, 10772243, 10938079, 12022538, 12188374, 12354210, 12520046, 12685882, 12851718, 13017554, 13183390, 117104763]
E

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165837.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width   rqe
0  0.00171 -0.15063 -0.43889 -0.529963      0.166667    0.0 -0.35
1  0.00171 -0.15063 -0.45591 -0.430712      0.583333    0.0 -0.35
2  0.03125 -0.07293 -0.34836  2.157303     -0.666667    1.0  0.00
3  0.03125 -0.07293 -0.39742  1.387640      0.500000    1.0 -0.35
4  0.03125 -0.07293 -0.49554  2.084270     -0.416667    1.0 -0.35
   energy_log10    zenith   azimuth  event_no
0      1.632209  1.097939  2.654670    165837
1      1.048916  1.080239  3.165756    165838
2      0.742656  2.225398  4.262146    165839
3      1.256873  0.799320  1.277917    165840
4      1.514665  0.836001  1.552987    165841


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165837.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width   rqe
0  0.04629 -0.03488 -0.36708  1.044944     -0.416667    1.0  0.00
1  0.03125 -0.07293 -0.25725 -0.153558      0.583333    1.0 -0.35
2  0.04160  0.03549 -0.22799  0.046816      0.333333    1.0  0.00
3  0.11319 -0.06047 -0.26623  0.887640     -0.416667    1.0  0.00
4 -0.00968 -0.07950 -0.20547  1.780899      0.416667    1.0  0.00
   energy_log10    zenith   azimuth  event_no
0      1.683983  1.620525  5.277000    331674
1      1.755312  1.683496  0.531419    331675
2      0.821207  1.693390  4.711780    331676
3      1.568033  1.809342  0.906927    331677
4      1.210022  1.539435  3.951813    331678


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165837.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width  rqe
0  0.07237 -0.06660 -0.30701  0.653558      0.750000    1.0  0.0
1  0.07237 -0.06660 -0.32804  0.189139     -0.916667    0.0  0.0
2  0.07237 -0.06660 -0.34205  0.814607     -0.416667    0.0  0.0
3  0.07237 -0.06660 -0.38410  0.391386      0.500000    1.0  0.0
4  0.04160  0.03549 -0.38217 -0.404494     -0.333333    1.0  0.0
   energy_log10    zenith   azimuth  event_no
0      0.858973  1.155964  2.131158    497511
1      1.902080  0.637631  1.083060    497512
2      1.127586  1.028787  0.357357    497513
3      1.030833  1.164146  4.825155    497514
4      0.947284  2.262021  0.639454    497515


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165837.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width   rqe
0 -0.07780 -0.05433 -0.50152  1.915730      0.666667    1.0 -0.35
1  0.04629 -0.03488 -0.41815  0.531835     -0.333333    1.0  0.00
2  0.03125 -0.07293 -0.40443  0.059925      0.750000    1.0 -0.35
3  0.03125 -0.07293 -0.42545 -0.108614     -0.500000    1.0 -0.35
4  0.03125 -0.07293 -0.49554 -0.282772      0.000000    1.0 -0.35
   energy_log10    zenith   azimuth  event_no
0      1.331997  1.778370  4.517724    663348
1      1.248558  0.274977  6.280253    663349
2      1.125666  2.651380  2.772573    663350
3      1.093367  2.750240  5.057547    663351
4      1.778567  2.393655  2.645987    663352


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165836.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width   rqe
0  0.07941 -0.24824 -0.36915 -0.106742     -0.333333    1.0 -0.35
1  0.00171 -0.15063 -0.28570  1.029963     -0.583333    1.0 -0.35
2  0.00171 -0.15063 -0.33676  0.653558     -0.916667    1.0 -0.35
3  0.12497 -0.13125 -0.27228  1.041199      0.583333    1.0 -0.35
4  0.12497 -0.13125 -0.40845  1.445693      0.000000    1.0 -0.35
   energy_log10    zenith   azimuth  event_no
0      1.812697  2.022997  2.989641    829184
1      1.886296  2.300012  2.758065    829185
2      1.846559  3.039526  0.582286    829186
3      1.686604  1.774123  0.493501    829187
4      0.942275  2.005075  2.488980    829188


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165836.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width   rqe
0  0.04629 -0.03488 -0.33304  0.219101      0.166667    1.0  0.00
1  0.19434 -0.03092 -0.14269  3.153558     -0.500000    1.0 -0.35
2  0.03125 -0.07293 -0.29930  0.114232      0.750000    0.0  0.00
3  0.03125 -0.07293 -0.31332  0.001873      0.000000    0.0  0.00
4  0.03125 -0.07293 -0.32033 -0.091760     -1.083333    0.0  0.00
   energy_log10    zenith   azimuth  event_no
0      1.417009  2.443839  2.992347    995020
1      1.218718  1.868808  3.474688    995021
2      1.846489  2.620414  2.635571    995022
3      1.726707  2.221184  4.388748    995023
4      1.482853  1.756255  4.155227    995024


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165836.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width   rqe
0  0.12497 -0.13125 -0.37440 -0.147940      0.000000    0.0 -0.35
1  0.12497 -0.13125 -0.39142 -0.477528     -0.833333    0.0 -0.35
2  0.12497 -0.13125 -0.39142 -0.455056      0.583333    0.0 -0.35
3  0.12497 -0.13125 -0.39142 -0.436330     -0.916667    0.0 -0.35
4  0.24815 -0.11187 -0.31896  1.505618      0.250000    1.0 -0.35
   energy_log10    zenith   azimuth  event_no
0      1.572222  1.327346  3.893821   1160856
1      1.297468  1.795847  3.094765   1160857
2      0.944986  2.796375  1.960719   1160858
3      0.837844  1.325572  0.401802   1160859
4      1.487119  1.427525  4.169186   1160860


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165836.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width   rqe
0  0.04629 -0.03488 -0.36708  2.500000     -0.083333    1.0  0.00
1 -0.03296  0.06244 -0.45735  0.365169     -0.833333    0.0 -0.35
2 -0.03296  0.06244 -0.47437  0.838951      0.500000    0.0 -0.35
3 -0.03296  0.06244 -0.50841  1.483146      0.083333    0.0 -0.35
4  0.04160  0.03549 -0.45226  1.307116      0.416667    1.0  0.00
   energy_log10    zenith   azimuth  event_no
0      1.683881  0.285779  6.186256   1326692
1      1.412262  2.345044  1.782284   1326693
2      1.565177  2.706756  4.326798   1326694
3      0.822042  2.408181  2.249201   1326695
4      1.023354  2.239803  0.805475   1326696


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165836.0), HTML(value='')))


List->array
Events read
Features read
Targets read, transforming
Splitting data to events
     dom_x    dom_y    dom_z  dom_time  charge_log10  width  rqe
0  0.04629 -0.03488 -0.38410  0.653558     -0.083333    1.0  0.0
1  0.04629 -0.03488 -0.45219 -0.234082      0.750000    0.0  0.0
2  0.04629 -0.03488 -0.45219 -0.211610     -1.250000    0.0  0.0
3  0.04629 -0.03488 -0.45219 -0.007491     -0.250000    0.0  0.0
4  0.04629 -0.03488 -0.46921 -0.387640     -1.166667    0.0  0.0
   energy_log10    zenith   azimuth  event_no
0      1.260988  2.658666  2.334869   1492528
1      1.533949  1.456936  1.889017   1492529
2      1.461374  1.697669  4.717784   1492530
3      1.229914  1.059398  1.160550   1492531
4      1.553517  2.169752  2.341257   1492532


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=165836.0), HTML(value='')))


List->array


In [63]:
# Event number (4th target column) of every graph in the last test chunk.
tl = [graph.y[3] for graph in test_list]

In [64]:
# Test-set event numbers lying inside chunk 2 (both bounds inclusive).
test_event_no = df_test['event_no']
df1 = test_event_no[(test_event_no <= stop_ids[2]) & (test_event_no >= start_ids[2])]

In [67]:
# Event number (4th target column) of every graph in the last train chunk.
tal = [graph.y[3] for graph in train_list]

In [66]:
# Train-set event numbers lying inside chunk 2 (both bounds inclusive).
train_event_no = df_train['event_no']
df2 = train_event_no[(train_event_no <= stop_ids[2]) & (train_event_no >= start_ids[2])]

In [71]:
# Collect the event number (index 3 of the target vector) of each train graph.
tal = [train_graph.y[3] for train_graph in train_list]

In [81]:
# Event numbers of all graphs in mix_list, chunk after chunk.
ml = [graph.y[3] for chunk_graphs in mix_list for graph in chunk_graphs]

In [86]:
# Flatten the per-chunk arrays in mix_list into one list of graphs.
m_list = []
for chunk_graphs in mix_list:
    m_list.extend(chunk_graphs)

In [None]:
# Flatten a list of lists into a single list.
# NOTE(review): `t` is not defined in any visible cell and this cell carries no
# execution count - presumably one of tls / tals / mix_list was meant; confirm.
flat_list = []
for sublist in t:
    flat_list.extend(sublist)

In [18]:
# Re-read the event numbers of the truth table.
# sqlite3's connection context manager only wraps a transaction and does not
# close the connection, so close it explicitly once the query is done.
conn = sqlite3.connect(db_file)
try:
    # SQL queries format
    feature_call = ", ".join(features)
    target_call  = ", ".join(targets)

    # Load data from db-file
    print("Reading files")
    df_truth = read_sql("select event_no from truth", conn)
finally:
    conn.close()

Reading files


In [23]:
# Display the train and test event-number frames (rows of df_truth whose
# event_no is in the train / test set) as the cell output.
df_train, df_test

(          event_no
 0                0
 1                1
 2                2
 3                3
 4                4
 ...            ...
 8291799  117104759
 8291800  117104760
 8291801  117104761
 8291802  117104762
 8291803  117104763
 
 [6633442 rows x 1 columns],
           event_no
 8                8
 10              10
 15              15
 29              29
 32              32
 ...            ...
 8291786  117104746
 8291787  117104747
 8291790  117104750
 8291792  117104752
 8291798  117104758
 
 [1658362 rows x 1 columns])

In [29]:
# Spot check: is the sixth-from-last event of chunk 9 part of the test set?
mask_test[9][-6]

True