In [31]:
from typing import Any, Dict, Optional, Tuple, List
import numpy as np

import pandas as pd
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, SequentialSampler #random_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch import Tensor
import sqlite3
import math
from torch import default_generator, randperm
from torch._utils import _accumulate
from torch.utils.data.dataset import Subset


def pad_collate(batch):
  (xx, y) = zip(*batch)
  x_lens = [len(x) for x in xx]
  print(x_lens)
  xx_pad = pad_sequence(xx, batch_first=True, padding_value=0)

  pad_mask = torch.zeros_like(xx_pad[:, :, 0]).type(torch.bool)
  for i, length in enumerate(x_lens):
    pad_mask[i, length:] = True

  return xx_pad, torch.tensor(y), pad_mask

class SimpleDataset(Dataset):
  def __init__(self, 
               db_path: str, 
               event_no_list: List[int], #event_no_list_path: str,
               pulsemap: str,
               input_cols: List[str],
               target_cols: List[str],
               truth_table: str = "truth"
               ):
    self.db_path = db_path
    self.event_no_list = event_no_list #self.event_no_list_path = event_no_list_path
    self.pulsemap = pulsemap
    self.input_cols = input_cols
    self.target_cols = target_cols
    self.truth_table = truth_table


    if isinstance(list(input_cols), list):
      self.input_cols_str = ", ".join(input_cols)
    else:

      self.input_cols_str = input_cols

    if isinstance(target_cols, list):
      self.target_cols_str = ", ".join(target_cols)
    else:
      self.target_cols_str = target_cols
    
    # self.event_no_list = np.genfromtxt(self.event_no_list_path,dtype=int)

    self.data_len = len(self.event_no_list)
    

  def __getitem__(self, index):
    event_no = self.event_no_list[index]
    with sqlite3.connect(self.db_path) as conn:
      features = Tensor(conn.execute(f"SELECT {self.input_cols_str} FROM {self.pulsemap} WHERE event_no == {event_no}").fetchall())
      truth = Tensor(conn.execute(f"SELECT {self.target_cols_str} FROM {self.truth_table} WHERE event_no == {event_no}").fetchall())
    return features, truth
  
  def __len__(self):
    return self.data_len

class SimpleIceCubeSQLDatamodule(LightningDataModule):
    """Example of LightningDataModule for MNIST dataset.

    A DataModule implements 5 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...
        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...
        def train_dataloader(self):
            # return train dataloader
        def val_dataloader(self):
            # return validation dataloader
        def test_dataloader(self):
            # return test dataloader
        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """

    def __init__(
        self,
        db_path: str = "/groups/icecube/petersen/GraphNetDatabaseRepository/Upgrade_Data/sqlite3/dev_step4_upgrade_028_with_noise_dynedge_pulsemap_v3_merger_aftercrash.db",
        event_no_list_path: str = "/groups/icecube/moust/storage/event_selections/event_no_numu_track_energy_15_200_nhits_4_400_sorted.csv",
        pulsemap: str = "SplitInIcePulses_dynedge_v2_Pulses",
        input_cols: List[str] = ["charge", "dom_time", "dom_x", "dom_y", "dom_z", "pmt_dir_x", "pmt_dir_y", "pmt_dir_z" ],
        target_cols: List[str] = "energy",
        truth_table: str = "truth",
        data_dir: str = "data/",
        # train_val_test_split: Tuple[float, float, float] = (0.8, 0.1, 0.1),# train_val_test_split_rate: Tuple[float, float, float] = (0.8, 0.1, 0.1),
        batch_size: int = 2,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()

        self.db_path = db_path
        self.event_no_list_path = event_no_list_path
        self.pulsemap = pulsemap
        self.input_cols = input_cols
        self.target_cols = target_cols
        self.truth_table = truth_table
        self.data_dir = data_dir
        # train_val_test_split: Tuple[float, float, float] = (0.8, 0.1, 0.1),# train_val_test_split_rate: Tuple[float, float, float] = (0.8, 0.1, 0.1),
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        # self.save_hyperparameters(logger=False)

        # data transformations here if any
        self.event_no_list = pd.read_csv(event_no_list_path,header=0,names=["event_no"],index_col=None)["event_no"].to_numpy()

        # self.event_no_list = np.genfromtxt(self.hparams.event_no_list_path,dtype=int)
    
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    # @property
    # def num_classes(self):
    #     return 10

    def prepare_data(self):
        """Download data if needed.

        Do not use it to assign state (self.x = y).
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute things like random split twice!
        """
        # sampler = SequentialSampler()
        if not self.data_train and not self.data_val and not self.data_test:
            self.data_train= SimpleDataset(
                db_path = self.db_path, 
                event_no_list = self.event_no_list[self.event_no_list % 10 > 1], #event_no_list_path = self.hparams.event_no_list_path,
                pulsemap = self.pulsemap,
                input_cols = self.input_cols,
                target_cols = self.target_cols,
                truth_table = self.truth_table,
            )
            self.data_val= SimpleDataset(
                db_path = self.db_path,
                event_no_list = self.event_no_list[self.event_no_list % 10 == 1], #event_no_list_path = self.hparams.event_no_list_path,
                pulsemap = self.pulsemap,
                input_cols = self.input_cols,
                target_cols = self.target_cols,
                truth_table = self.truth_table,
            )
            self.data_test= SimpleDataset(
                db_path = self.db_path,
                event_no_list = self.event_no_list[self.event_no_list% 10 == 0], #event_no_list_path = self.hparams.event_no_list_path,
                pulsemap = self.pulsemap,
                input_cols = self.input_cols,
                target_cols = self.target_cols,
                truth_table = self.truth_table,
            )


    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn= pad_collate,
            sampler=SequentialSampler(self.data_train)
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn= pad_collate,
            sampler=SequentialSampler(self.data_val)
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn= pad_collate,
            sampler=SequentialSampler(self.data_test)
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass


dm=SimpleIceCubeSQLDatamodule()
dm.setup()


In [20]:
len(dm.data_test)

53916

In [32]:
train_dl = dm.train_dataloader()
train_dl

<torch.utils.data.dataloader.DataLoader at 0x7fbfc15f2940>

In [33]:
features, truth, pad_mask = next(iter(train_dl))

[200, 200]


In [24]:
features.shape

torch.Size([2, 200, 8])

In [25]:
truth

tensor([ 66.5481, 112.6614])

In [29]:
pad_mask.sum()

tensor(0)

In [30]:
pad_mask

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F