In [9]:
import sqlite3
import pandas as pd
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import DataLoader, Dataset, random_split
import pytorch_lightning as pl
from typing import List

In [3]:
mini_db_path = "/groups/icecube/moust/storage/140021_db/db_out_mini/merged_140021_mini.db"
query_all ="SELECT * FROM truth"
with sqlite3.connect(mini_db_path) as conn:
    db_tables = pd.read_sql_query("SELECT name FROM sqlite_master WHERE type = 'table'", conn)
    print('Tables within the SQL database is:')
    print(db_tables)
    print()
    mini_db = {name:  pd.read_sql_query("SELECT * FROM "+name, conn) for name in db_tables.name}

Tables within the SQL database is:
                  name
0                retro
1     SplitInIcePulses
2  I3MCTree__primaries
3  I3MCTree__particles
4   I3TriggerHierarchy
5                truth



In [13]:
def pad_collate(batch):
  (xx, y) = zip(*batch)
  x_lens = [len(x) for x in xx]
  xx_pad = pad_sequence(xx, batch_first=True, padding_value=0)

  pad_mask = torch.zeros_like(xx_pad[:, :, 0]).type(torch.bool)
  for i, length in enumerate(x_lens):
    pad_mask[i, length:] = True

  return xx_pad, torch.tensor(y), pad_mask

class SimpleDataset(Dataset):
  def __init__(self, 
               db_path: str, 
               event_no_list: List[int], #event_no_list_path: str,
               pulsemap: str,
               input_cols: List[str],
               target_cols: List[str],
               truth_table: str = "truth"
               ):
    self.db_path = db_path
    self.event_no_list = event_no_list #self.event_no_list_path = event_no_list_path
    self.pulsemap = pulsemap
    self.input_cols = input_cols
    self.target_cols = target_cols
    self.truth_table = truth_table


    if isinstance(list(input_cols), list):
      self.input_cols_str = ", ".join(input_cols)
    else:

      self.input_cols_str = input_cols

    if isinstance(target_cols, list):
      self.target_cols_str = ", ".join(target_cols)
    else:
      self.target_cols_str = target_cols
    
    # self.event_no_list = np.genfromtxt(self.event_no_list_path,dtype=int)

    self.data_len = len(self.event_no_list)
    

  def __getitem__(self, index):
    event_no = self.event_no_list[index]
    with sqlite3.connect(self.db_path) as conn:
      features = Tensor(conn.execute(f"SELECT {self.input_cols_str} FROM {self.pulsemap} WHERE event_no == {event_no}").fetchall())
      truth = Tensor(conn.execute(f"SELECT {self.target_cols_str} FROM {self.truth_table} WHERE event_no == {event_no}").fetchall())
    return features, truth
  
  def __len__(self):
    return self.data_len

In [14]:
simpledataset = SimpleDataset( 
               db_path = mini_db_path, 
               event_no_list = np.arange(10),
               pulsemap = "SplitInIcePulses",
               input_cols = ["charge","dom_time","dom_x","dom_y","dom_z"],
               target_cols = ["energy","inelasticity"],
               )
dataloader = DataLoader(dataset=simpledataset, batch_size = 4, collate_fn = pad_collate)

In [12]:
dataiter = iter(dataloader)
data = next(dataiter)
x1, truth , lengths= data
print(x1.shape, truth, lengths)

ValueError: only one element tensors can be converted to Python scalars

In [15]:
for i, (features, truth,pad) in enumerate(dataloader):
    print(truth.shape)
    print(features.shape)
    print(pad.shape)
    # pred = model(features)
    # print(pred.shape)
    if i == 1:
        break

ValueError: only one element tensors can be converted to Python scalars