# This is the final edition of a SQLite dataset with an 'efficient' DataLoader

In [2]:
# Directory and file name of the SQLite database opened by the dataset classes below.
# NOTE(review): this is an absolute path on one user's machine -- adjust before running elsewhere.
filepath = r"C:\Users\jv97\Desktop\github\Neutrino-Machine-Learning\raw_data"
filename = "rasmus_classification_muon_3neutrino_3mio.db"

In [None]:
import sqlite3
import os
import torch
import numpy as np
from pandas import read_sql
from torch_geometric.data import Data, Batch

class custom_db_dataset(torch.utils.data.Dataset):
    """Dataset that reads events from a SQLite database and returns ``Data`` objects.

    Features are read from the ``features`` table and targets from the ``truth``
    table; both are selected by ``event_no``. The database is opened read-only.

    Args:
        filepath: Directory that contains the database file.
        filename: Name of the database file.
        features: Comma-separated column list selected from ``features``,
            e.g. ``"charge_log10, time, pulse_width, SRTInIcePulses, dom_x, dom_y, dom_z"``.
        targets: Comma-separated column list selected from ``truth``,
            e.g. ``"azimuth, zenith, energy_log10"``.
        TrTV: Three cumulative fractions marking the end of the train, test and
            validation slices of ``event_nos`` (each is multiplied by ``len(self)``).
        event_nos: Optional event numbers to restrict the dataset to. When
            ``None`` every ``event_no`` in ``truth`` is used.
        x_transform: Optional callable applied to the feature tensor before it
            is wrapped in ``Data``.
        y_transform: Optional callable applied to the target tensor before it
            is wrapped in ``Data``.
    """

    def __init__(self, filepath, filename, features, targets, TrTV, event_nos = None, x_transform = None, y_transform = None):
        self.filepath = filepath
        self.filename = filename
        self.features = features
        self.targets = targets
        self.TrTV = TrTV
        # Stored so that return_self() can hand them on to the split datasets.
        self.x_transform = x_transform
        self.y_transform = y_transform
        # Optional extra SQL appended to the WHERE clause of the feature / truth
        # queries (e.g. " and SRTInIcePulses = 1"); empty means "no extra filter".
        self.where_extra_x = ""
        self.where_extra_y = ""

        # '?mode=ro' together with uri=True opens the database read-only.
        self.con = sqlite3.connect('file:'+os.path.join(self.filepath,self.filename+'?mode=ro'),uri=True)

        if event_nos is None:
            self.event_nos = np.asarray(read_sql("SELECT event_no FROM truth",self.con)).reshape(-1)
        else:
            # Coerce so that list-style indexing in get_list() also works when
            # the caller passed a plain Python sequence.
            self.event_nos = np.asarray(event_nos)

    def __len__(self):
        """Return the number of events in this dataset."""
        return len(self.event_nos)

    def __getitem__(self, index):
        """Return one event (integer index) or several events (sequence of indices).

        Raises:
            TypeError: If ``index`` is neither an integer nor a list, tuple or
                NumPy array of integers.
        """
        if isinstance(index, (int, np.integer)):
            return self.get_single(index)
        if isinstance(index, (list, tuple, np.ndarray)):
            return self.get_list(index)
        raise TypeError(f"Unsupported index type: {type(index).__name__}")

    @staticmethod
    def _sql_id_list(event_nos):
        """Format event numbers as a SQL ``IN`` list, e.g. ``(1, 2, 3)``.

        Built by hand because ``str(tuple(...))`` gives ``(5,)`` for a single
        element and embeds NumPy scalar reprs on NumPy >= 2, both invalid SQL.
        """
        return "(" + ", ".join(str(int(event_no)) for event_no in event_nos) + ")"

    def _to_data(self, x_query, y_query):
        """Run both queries, apply the optional transforms and wrap the result in ``Data``."""
        x = torch.tensor(read_sql(x_query,self.con).values)
        y = torch.tensor(read_sql(y_query,self.con).values)
        if self.x_transform is not None:
            x = self.x_transform(x)
        if self.y_transform is not None:
            y = self.y_transform(y)
        return Data(x=x, y=y)

    def get_single(self,index):
        """Return the ``Data`` object for the event at position ``index``."""
        event_no = self.event_nos[index]
        x_query = f"SELECT {self.features} FROM features WHERE event_no = {event_no} {self.where_extra_x}"
        y_query = f"SELECT {self.targets} FROM truth WHERE event_no = {event_no} {self.where_extra_y}"
        return self._to_data(x_query, y_query)

    def get_list(self,index):
        """Return ONE ``Data`` object holding the rows of all events in ``index``.

        NOTE(review): rows come back in whatever order SQLite returns them, not
        necessarily in the order of ``index`` -- confirm this is acceptable.
        """
        # dtype=int keeps an empty index list usable as a NumPy index.
        id_list = self._sql_id_list(self.event_nos[np.asarray(index, dtype=int)])
        x_query = f"SELECT {self.features} FROM features WHERE event_no IN {id_list} {self.where_extra_x}"
        y_query = f"SELECT {self.targets} FROM truth WHERE event_no IN {id_list} {self.where_extra_y}"
        return self._to_data(x_query, y_query)

    def return_self(self,event_nos):
        """Return a new dataset on the same database restricted to ``event_nos``."""
        subset = custom_db_dataset(self.filepath,
                                   self.filename,
                                   self.features,
                                   self.targets,
                                   self.TrTV,
                                   event_nos,
                                   self.x_transform,
                                   self.y_transform)
        # Keep any extra WHERE filters that were set on this instance.
        subset.where_extra_x = self.where_extra_x
        subset.where_extra_y = self.where_extra_y
        return subset

    def _split(self, start_fraction, stop_fraction):
        """Return the sub-dataset covering ``[start_fraction, stop_fraction)`` of the events."""
        n_events = len(self)
        first = int(start_fraction*n_events)
        last = int(stop_fraction*n_events)
        return self.return_self(self.event_nos[first:last])

    def train(self):
        """Return the training split: events up to ``TrTV[0]``."""
        return self._split(0, self.TrTV[0])

    def test(self):
        """Return the test split: events between ``TrTV[0]`` and ``TrTV[1]``."""
        return self._split(self.TrTV[0], self.TrTV[1])

    def val(self):
        """Return the validation split: events between ``TrTV[1]`` and ``TrTV[2]``."""
        return self._split(self.TrTV[1], self.TrTV[2])

    def return_dataloaders(self, batch_size):
        """Build DataLoaders for the train, test and validation splits.

        Each loader draws whole index batches from a ``BatchSampler``, so every
        batch is fetched through ``get_list`` in one pair of queries. Train and
        validation batches are drawn in random order, test batches sequentially.

        Args:
            batch_size: Number of events per batch.

        Returns:
            Tuple ``(train_loader, test_loader, val_loader)``.
        """
        from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, RandomSampler

        def collate(batch):
            return Batch.from_data_list(batch)

        # Build each split once; every call opens its own database connection.
        train_set = self.train()
        test_set = self.test()
        val_set = self.val()

        train_loader = DataLoader(dataset = train_set,
                                  collate_fn = collate,
                                  sampler = BatchSampler(RandomSampler(train_set),
                                                         batch_size=batch_size,
                                                         drop_last=False))

        test_loader = DataLoader(dataset = test_set,
                                 collate_fn = collate,
                                 sampler = BatchSampler(SequentialSampler(test_set),
                                                        batch_size=batch_size,
                                                        drop_last=False))

        val_loader = DataLoader(dataset = val_set,
                                collate_fn = collate,
                                sampler = BatchSampler(RandomSampler(val_set),
                                                       batch_size=batch_size,
                                                       drop_last=False))

        return train_loader, test_loader, val_loader


In [145]:
# Look at every feature row stored for the sixth event of dataset `a`.
sixth_event_query = f"SELECT * FROM features WHERE event_no = {a.event_nos[5]}"
df = read_sql(sixth_event_query, a.con)
df.values

array([[1.97000000e+02, 1.20102893e+08, 6.00000000e+00, ...,
        8.00000000e+00, 1.00000000e+00, 0.00000000e+00],
       [1.98000000e+02, 1.20102893e+08, 5.00000000e+01, ...,
        8.00000000e+00, 1.00000000e+00, 0.00000000e+00],
       [1.99000000e+02, 1.20102893e+08, 5.80000000e+01, ...,
        8.00000000e+00, 1.00000000e+00, 0.00000000e+00],
       ...,
       [2.47000000e+02, 1.20102893e+08, 4.50000000e+01, ...,
        8.00000000e+00, 1.00000000e+00, 1.00000000e+00],
       [2.48000000e+02, 1.20102893e+08, 4.70000000e+01, ...,
        8.00000000e+00, 1.00000000e+00, 0.00000000e+00],
       [2.49000000e+02, 1.20102893e+08, 8.60000000e+01, ...,
        8.00000000e+00, 1.00000000e+00, 1.00000000e+00]])

# Below here is experimentation:

In [133]:
class custom_db_dataset(torch.utils.data.Dataset):
    """Experimental dataset that fetches an event's rows by row offset.

    On construction the number of ``features`` rows per ``event_no`` is counted
    once; ``get_single`` then reads an event with ``LIMIT offset,count`` instead
    of filtering on ``event_no``.

    Args:
        filepath: Directory that contains the database file.
        filename: Name of the database file.
    """

    def __init__(self, filepath, filename):
        # '?mode=ro' together with uri=True opens the database read-only.
        self.con = sqlite3.connect('file:'+os.path.join(filepath,filename+'?mode=ro'),uri=True)
#         self.event_nos, self.event_lengths = np.unique(read_sql("SELECT event_no FROM features",self.con).event_no.values,return_counts=True)
#         self.event_nos = np.asarray(read_sql("SELECT event_no FROM truth",self.con)).reshape(-1)
        df = read_sql("SELECT COUNT(*), event_no FROM features GROUP BY event_no",self.con)
        print("Memory usage: ",df.memory_usage())
        # Column 0 is the row count per event, column 1 the event number.
        self.event_lengths = np.asarray(df.iloc[:,0])
        # cumsum[i] is the number of feature rows that precede event i.
        self.cumsum = np.append(0,self.event_lengths.cumsum())
        self.event_nos = np.asarray(df.iloc[:,1])
        del df
        self.features = "charge_log10, time, pulse_width, SRTInIcePulses, dom_x, dom_y, dom_z"
        self.targets = "energy_log10"
#         self.where_extra_x = " and SRTInIcePulses = 1"
#         self.where_extra_y = ""

    def __len__(self):
        """Return the number of events."""
        return len(self.event_nos)

    def __getitem__(self, index):
        """Return the ``Data`` object of the event at integer position ``index``.

        Raises:
            TypeError: If ``index`` is not an integer (list indexing is not
                implemented in this variant).
        """
        if isinstance(index, (int, np.integer)):
            return self.get_single(index)
        raise TypeError(f"Unsupported index type: {type(index).__name__}")
#         if isinstance(index, list):
#             return self.get_list(index)

#     def get_single(self,index):
#         query = f"SELECT {self.features} FROM features  {self.where_extra_x}"
#         x = torch.tensor(read_sql(query,self.con).to_numpy())

#         query = f"SELECT {self.targets} FROM truth WHERE event_no = {self.event_nos[index]} {self.where_extra_y}"
#         y = torch.tensor(read_sql(query,self.con).to_numpy())
#         return Data(x=x, y=y)

    def get_single(self,index):
        """Read one event: features by row offset, target by ``event_no``.

        NOTE(review): the LIMIT query has no ORDER BY, so it assumes the rows of
        ``features`` are stored contiguously per event and in the same order as
        the GROUP BY result used to build ``cumsum`` -- TODO confirm.
        """
        query = f"SELECT {self.features} FROM features LIMIT {self.cumsum[index]},{self.event_lengths[index]}"
        x = torch.tensor(read_sql(query, self.con).to_numpy())

        query = f"SELECT {self.targets} FROM truth WHERE event_no = {self.event_nos[index]}"
        y = torch.tensor(read_sql(query,self.con).to_numpy())
        return Data(x=x,y=y)


#     def get_list(self,index):
#         query = f"SELECT event_no, {self.features} FROM features WHERE event_no in {tuple(self.event_nos[index])} {self.where_extra_x}"
#         events = read_sql(query,self.con)
#         x = torch.tensor(events.iloc[:,1:].to_numpy())

#         query = f"SELECT {self.targets} FROM truth WHERE event_no in {tuple(self.event_nos[index])} {self.where_extra_y}"
#         y = torch.tensor(read_sql(query,self.con).to_numpy())

#         data_list = []
#         _, events = np.unique(events.event_no.values.flatten(), return_counts = True)
#         for tmp_x, tmp_y in zip(torch.split(x, events.tolist()), y):
#             data_list.append(Data(x=tmp_x,y=tmp_y))
#         return data_list

#     def query(self, query_string):
#         """run a query and return the result"""
#         self.cursor.execute(query_string)
#         return self.cursor.fetchall()

#     def process_query(self, items):
#         return read_sql(items)

class custom_db_dataset1(torch.utils.data.Dataset):
    """Experimental dataset limited to the first 50 events of ``truth``.

    Supports integer indexing (one event) and list indexing (all listed events
    fetched with one ``IN`` query and returned as a single ``Data`` object).
    ``targets`` includes ``event_no`` so the returned ``y`` identifies the events.

    Args:
        filepath: Directory that contains the database file.
        filename: Name of the database file.
        event_nos: Optional event numbers to restrict the dataset to. When
            ``None`` the first 50 ``event_no`` values of ``truth`` are used.
    """

    def __init__(self, filepath, filename, event_nos = None):
        self.filepath = filepath
        self.filename = filename
        # '?mode=ro' together with uri=True opens the database read-only.
        self.con = sqlite3.connect('file:'+os.path.join(self.filepath,self.filename+'?mode=ro'),uri=True)
        if event_nos is None:
            self.event_nos = np.asarray(read_sql("SELECT event_no FROM truth LIMIT 50",self.con)).reshape(-1)
        else:
            # Coerce so that list-style indexing in get_list() also works when
            # the caller passed a plain Python sequence.
            self.event_nos = np.asarray(event_nos)
        self.features = "charge_log10, time, pulse_width, SRTInIcePulses, dom_x, dom_y, dom_z"
        self.targets = "event_no, energy_log10"
        # Optional extra SQL appended to the WHERE clauses; empty means no filter.
        self.where_extra_x = ""#" and SRTInIcePulses = 1"
        self.where_extra_y = ""

    def __len__(self):
        """Return the number of events."""
        return len(self.event_nos)

    def __getitem__(self, index):
        """Return one event (integer index) or several events (sequence of indices).

        Raises:
            TypeError: If ``index`` is neither an integer nor a list, tuple or
                NumPy array of integers.
        """
        if isinstance(index, (int, np.integer)):
            return self.get_single(index)
        if isinstance(index, (list, tuple, np.ndarray)):
            return self.get_list(index)
        raise TypeError(f"Unsupported index type: {type(index).__name__}")

    @staticmethod
    def _sql_id_list(event_nos):
        """Format event numbers as a SQL ``IN`` list, e.g. ``(1, 2, 3)``.

        Built by hand because ``str(tuple(...))`` gives ``(5,)`` for a single
        element and embeds NumPy scalar reprs on NumPy >= 2, both invalid SQL.
        """
        return "(" + ", ".join(str(int(event_no)) for event_no in event_nos) + ")"

    def get_single(self,index):
        """Return the ``Data`` object for the event at position ``index``."""
        query = f"SELECT {self.features} FROM features WHERE event_no = {self.event_nos[index]} {self.where_extra_x}"
        x = torch.tensor(read_sql(query,self.con).to_numpy())

        query = f"SELECT {self.targets} FROM truth WHERE event_no = {self.event_nos[index]} {self.where_extra_y}"
        y = torch.tensor(read_sql(query,self.con).to_numpy())
        return Data(x=x, y=y)

    def get_list(self,index):
        """Return ONE ``Data`` object holding the rows of all events in ``index``.

        NOTE(review): rows come back in whatever order SQLite returns them, not
        necessarily in the order of ``index`` -- confirm this is acceptable.
        """
        # dtype=int keeps an empty index list usable as a NumPy index.
        id_list = self._sql_id_list(self.event_nos[np.asarray(index, dtype=int)])
        query = f"SELECT {self.features} FROM features WHERE event_no IN {id_list} {self.where_extra_x}"
        x = torch.tensor(read_sql(query,self.con).to_numpy())

        query = f"SELECT {self.targets} FROM truth WHERE event_no IN {id_list} {self.where_extra_y}"
        y = torch.tensor(read_sql(query,self.con).to_numpy())
        return Data(x=x, y=y)

    def train(self):
        """Return a dataset over the first 25 events."""
        return custom_db_dataset1(self.filepath,self.filename,event_nos = self.event_nos[:25])

    def val(self):
        """Return a dataset over the events after the first 25."""
        return custom_db_dataset1(self.filepath,self.filename,event_nos = self.event_nos[25:])

class custom_db_dataset2(torch.utils.data.Dataset):
    """Experimental dataset over every event in ``truth``, indexed by integer only.

    Args:
        filepath: Directory that contains the database file.
        filename: Name of the database file.
    """

    def __init__(self, filepath, filename):
        # '?mode=ro' together with uri=True opens the database read-only.
        self.con = sqlite3.connect('file:'+os.path.join(filepath,filename+'?mode=ro'),uri=True)
        self.event_nos = np.asarray(read_sql("SELECT event_no FROM truth",self.con)).reshape(-1)

        # Only reported, not enforced: construction continues either way.
        if not np.array_equal(self.event_nos, np.sort(self.event_nos)):
            print('ERROR: indexing for this database is not yet supported!!')

        self.features = "charge_log10, time, pulse_width, SRTInIcePulses, dom_x, dom_y, dom_z"
        self.targets = "energy_log10"
        # NOTE(review): these filters are defined but not applied by get_single().
        self.where_extra_x = " and SRTInIcePulses = 1"
        self.where_extra_y = ""

    def __len__(self):
        """Return the number of events."""
        return len(self.event_nos)

    def __getitem__(self, index):
        """Return the ``Data`` object of the event at integer position ``index``.

        Raises:
            TypeError: If ``index`` is not an integer (list indexing is not
                implemented in this variant).
        """
        if isinstance(index, (int, np.integer)):
            return self.get_single(index)
        raise TypeError(f"Unsupported index type: {type(index).__name__}")
#         if isinstance(index, list):
#             return self.get_list(index)

    def _feature_query(self, index):
        """Build the feature query for exactly the event at position ``index``.

        The earlier ``BETWEEN event_nos[index-1] AND event_nos[index+1]`` form
        wrapped around at index 0, raised IndexError at the last index and
        pulled in the neighbouring events, so the features did not match the
        single target row.
        """
        return f"SELECT {self.features} FROM features WHERE event_no = {self.event_nos[index]}"

    def get_single(self,index):
        """Return features and target of the event at position ``index``."""
        x = torch.tensor(read_sql(self._feature_query(index),self.con).to_numpy())

        query = f"SELECT {self.targets} FROM truth WHERE event_no = {self.event_nos[index]}"
        y = torch.tensor(read_sql(query,self.con).to_numpy())
        return Data(x=x, y=y)

In [134]:
# a = custom_db_dataset(filepath,filename)
# Use the 50-event variant for the loader experiments below.
a = custom_db_dataset1(filepath=filepath, filename=filename)

In [140]:
def collate(batch):
    """Combine a list of ``Data`` objects into a single ``Batch``."""
    merged = Batch.from_data_list(batch)
    return merged

from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, RandomSampler

# Build each split once and share it between the loader and its sampler;
# every a.train() / a.val() call creates a new dataset with its own
# database connection.
train_set = a.train()
b = DataLoader(dataset = train_set,
               sampler = BatchSampler(RandomSampler(train_set),
                                      batch_size=5,
                                      drop_last=False),
               collate_fn = collate)
val_set = a.val()
b2 = DataLoader(dataset = val_set,
                sampler = BatchSampler(RandomSampler(val_set),
                                       batch_size=5,
                                       drop_last=False),
                collate_fn = collate)

In [141]:
import time

# Time each batch of the training loader `b`, then of the validation loader
# `b2`, stopping a loader after 101 batches. For every batch the elapsed
# seconds and the event numbers (offset by the smallest one) are printed.
for loader in (b, b2):
    start_time = time.time()
    i = 0
    for i, batch in enumerate(loader, start=1):
        print(time.time() - start_time)
        print(batch.y[:,0]- a.event_nos.min())
        start_time = time.time()
        if i > 100:
            break

0.008000612258911133
tensor([15., 22., 24., 25., 29.], dtype=torch.float64)
0.0049991607666015625
tensor([ 4., 17., 19., 21., 23.], dtype=torch.float64)
0.003998756408691406
tensor([ 1.,  5.,  9., 13., 16.], dtype=torch.float64)
0.0050013065338134766
tensor([ 3.,  6.,  8., 18., 20.], dtype=torch.float64)
0.003998517990112305
tensor([ 0.,  2., 10., 14., 27.], dtype=torch.float64)
0.005000114440917969
tensor([32., 37., 44., 47., 55.], dtype=torch.float64)
0.003002166748046875
tensor([30., 31., 49., 50., 51.], dtype=torch.float64)
0.004001617431640625
tensor([35., 45., 46., 48., 57.], dtype=torch.float64)
0.0030014514923095703
tensor([38., 39., 40., 53., 58.], dtype=torch.float64)
0.003999948501586914
tensor([34., 42., 54., 56., 59.], dtype=torch.float64)


In [32]:
import time

# Time fetching 512 events one at a time and collating them into one batch.
start_time = time.time()
data_list = []
for i in range(512):
    sample = a[i]
    data_list.append(sample)
b = collate(data_list)
elapsed = time.time() - start_time
print(elapsed)

1.2828373908996582


In [8]:
# Display the collated batch built in the previous cell.
b

Batch(batch=[21091], x=[21091, 7], y=[512, 1])

In [18]:
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that creates its iterator once and reuses it for every epoch.

    The batch sampler is wrapped in ``_RepeatSampler`` so that the single
    iterator created in ``__init__`` keeps producing batches; each call to
    ``__iter__`` yields ``len(self)`` batches from it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): the name-mangled ``__initialized`` flag of DataLoader is
        # switched off around the assignment -- presumably because DataLoader
        # refuses to rebind ``batch_sampler`` after construction; confirm
        # against the installed PyTorch version.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # The one iterator shared by all subsequent __iter__ calls.
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the length of the original (wrapped) batch sampler."""
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Yield ``len(self)`` batches from the shared iterator."""
        for i in range(len(self)):
            yield next(self.iterator)

class _RepeatSampler(object):
    """ Sampler that repeats forever.
    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

In [19]:
# Wrap dataset `a` in the iterator-reusing loader, in order, 512 events per batch.
b = MultiEpochsDataLoader(
    a,
    batch_size=512,
    shuffle=False,
    collate_fn=collate,
)

In [13]:
# Display the current value of the loop counter `i`.
i

0