In [1]:
#|default_exp dataset

In [13]:
#| export
from copy import deepcopy
from pathlib import Path
from typing import Sequence, Tuple

import numpy as np
import pandas as pd
import torch
from datasets import load_from_disk
from torch.utils.data import Dataset, DataLoader

In [14]:
from pdb import set_trace
class CFG:
    """Notebook-level configuration constants."""

    # Root of the cache directory; being relative, it resolves against the
    # current working directory.
    CACHE_PATH = Path("..") / "data" / "cache"

In [15]:
# Sorted so that a dataset index refers to the same cache file on every run;
# glob() yields entries in whatever order the filesystem reports them.
fns = sorted((CFG.CACHE_PATH / "batch_3").glob("*.pth"))

In [25]:
# | export
# Load the event data and the labels from a .pth file and return both as pandas DataFrames.
def _column_key(columns):
    """Turn a column specification into a key usable with ``DataFrame[...]``.

    A single name (``str``) is returned unchanged; any other sequence is
    materialised as a list, because pandas would treat a tuple as one key.
    """
    return columns if isinstance(columns, str) else list(columns)


def load_data(
    fn: Path,
    columns_event: Sequence[str] = ("time", "charge", "auxiliary", "x", "y", "z"),
    columns_label: Sequence[str] = ("azimuth", "zenith"),
    keep_auxiliary_event: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load one cached ``.pth`` file and return its events and labels.

    Args:
        fn: Path of the file to read with ``torch.load``. The loaded object
            must expose ``"event"`` and ``"target"`` entries that
            ``pd.DataFrame.from_records`` can consume.
        columns_event: Names of the event columns to keep, in output order.
            Lists and tuples are both accepted.
        columns_label: Names of the label columns to keep, in output order.
        keep_auxiliary_event: When true, only the rows matching the query
            ``auxiliary == True`` are kept; when false no rows are filtered.

    Returns:
        ``(event, label)``: the selected event columns cast to ``float32``
        and the selected label columns with the dtype they were loaded with.

    Raises:
        KeyError: If a requested column is missing from the loaded records.
    """
    # NOTE(security): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects -- only read cache files you created yourself.
    data = torch.load(fn)
    event = pd.DataFrame.from_records(data["event"])[_column_key(columns_event)]
    if keep_auxiliary_event:
        # NOTE(review): despite the flag's name this keeps *only* the rows whose
        # ``auxiliary`` value is true -- confirm that this is the intent.
        event = event.query("auxiliary == True")
    label = pd.DataFrame.from_records(data["target"])[_column_key(columns_label)]
    return event.astype(np.float32), label


class IceCubeCasheDatasetV0(Dataset):
    """Map-style dataset that yields one normalised sample per cached file.

    Each item is read with ``load_data`` (default arguments), optionally cut
    to its first ``max_events`` rows and rescaled: ``time`` is divided by its
    maximum over the kept rows, ``x``/``y``/``z`` are divided by 500 and
    ``charge`` is replaced by its base-10 logarithm.

    Args:
        fns: Indexable collection of file paths; its length is the dataset length.
        max_events: Maximum number of rows kept per item. A falsy value
            (``None`` or ``0``) disables the truncation.
    """

    def __init__(self, fns, max_events=100):
        self.fns, self.max_events = fns, max_events

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        """Return a dict with the ``"event"``, ``"mask"`` and ``"label"`` tensors."""
        pulses, target = load_data(self.fns[idx])

        if self.max_events:
            pulses = pulses.iloc[: self.max_events]

        # Build the rescaled frame in one step instead of mutating columns in place.
        rescaled = {axis: pulses[axis] / 500 for axis in ("x", "y", "z")}
        pulses = pulses.assign(
            time=pulses["time"] / pulses["time"].max(),
            charge=np.log10(pulses["charge"]),
            **rescaled,
        )

        return {
            "event": torch.tensor(pulses.values),
            # Every kept row is valid, so the mask is all True.
            "mask": torch.ones(len(pulses), dtype=torch.bool),
            "label": torch.tensor(target.values, dtype=torch.float32),
        }


class IceCubeCasheDatasetV1(Dataset):
    """Variant of ``IceCubeCasheDatasetV0`` with a configurable auxiliary flag.

    Each item is read with ``load_data``, optionally cut to its first
    ``max_events`` rows and rescaled: ``time`` is divided by its maximum over
    the kept rows, ``x``/``y``/``z`` are divided by 500 and ``charge`` is
    replaced by its base-10 logarithm.

    Args:
        fns: Indexable collection of file paths; its length is the dataset length.
        max_events: Maximum number of rows kept per item. A falsy value
            (``None`` or ``0``) disables the truncation.
        keep_auxiliary_event: Forwarded unchanged to ``load_data`` as its
            ``keep_auxiliary_event`` argument.
    """

    def __init__(self, fns, max_events=100, keep_auxiliary_event: bool = True):
        self.fns, self.max_events = fns, max_events
        self.keep_auxiliary_event = keep_auxiliary_event

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        """Return a dict with the ``"event"``, ``"mask"`` and ``"label"`` tensors."""
        pulses, target = load_data(
            self.fns[idx], keep_auxiliary_event=self.keep_auxiliary_event
        )

        if self.max_events:
            # Positional slice: the index may have gaps after filtering.
            pulses = pulses.iloc[: self.max_events]

        # Build the rescaled frame in one step instead of mutating columns in place.
        rescaled = {axis: pulses[axis] / 500 for axis in ("x", "y", "z")}
        pulses = pulses.assign(
            time=pulses["time"] / pulses["time"].max(),
            charge=np.log10(pulses["charge"]),
            **rescaled,
        )

        return {
            "event": torch.tensor(pulses.values),
            # Every kept row is valid, so the mask is all True.
            "mask": torch.ones(len(pulses), dtype=torch.bool),
            "label": torch.tensor(target.values, dtype=torch.float32),
        }


# TODO: collate_fn that pads the event and mask to the max length in the batch using PyTorch pad_sequence
# (no such function is defined in this cell; the class below is a Dataset, not a collate_fn)
class HuggingFaceDatasetV0(Dataset):
    """Map-style dataset over an indexable collection of per-event records.

    Every record ``ds[idx]`` must provide the per-row columns ``time``,
    ``charge``, ``auxiliary``, ``x``, ``y`` and ``z`` plus the ``azimuth`` and
    ``zenith`` values used as the label.

    Args:
        ds: Indexable collection of records; its length is the dataset length.
        max_events: Maximum number of rows kept per item. A falsy value
            (``None`` or ``0``) disables the truncation.
    """

    def __init__(self, ds, max_events=100):
        self.ds, self.max_events = ds, max_events

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return a dict with the ``"event"``, ``"mask"`` and ``"label"`` tensors."""
        record = self.ds[idx]

        feature_names = ["time", "charge", "auxiliary", "x", "y", "z"]
        pulses = pd.DataFrame(record)[feature_names].astype(np.float32)
        if self.max_events:
            pulses = pulses.iloc[: self.max_events]

        # time -> fraction of the largest kept time, x/y/z -> divided by 500,
        # charge -> log10(charge).
        rescaled = {axis: pulses[axis] / 500 for axis in ("x", "y", "z")}
        pulses = pulses.assign(
            time=pulses["time"] / pulses["time"].max(),
            charge=np.log10(pulses["charge"]),
            **rescaled,
        )

        features = pulses.values
        valid = np.ones(len(features), dtype=bool)
        direction = np.array([record["azimuth"], record["zenith"]], dtype=np.float32)

        sample = {
            "event": torch.tensor(features),
            "mask": torch.tensor(valid),
            "label": torch.tensor(direction),
        }
        return deepcopy(sample)


class HuggingFaceDatasetV1(Dataset):
    """Like ``HuggingFaceDatasetV0`` but each sample also carries ``sensor_id``.

    The returned ids are shifted by +1 so that they start at 1 rather than 0;
    the original notes describe 0 as the ignore index.

    Every record ``ds[idx]`` must provide the per-row columns ``sensor_id``,
    ``time``, ``charge``, ``auxiliary``, ``x``, ``y`` and ``z`` plus the
    ``azimuth`` and ``zenith`` values used as the label.

    Args:
        ds: Indexable collection of records; its length is the dataset length.
        max_events: Maximum number of rows kept per item. A falsy value
            (``None`` or ``0``) disables the truncation.
    """

    def __init__(self, ds, max_events=100):
        self.ds, self.max_events = ds, max_events

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return a dict with ``"sensor_id"``, ``"event"``, ``"mask"`` and ``"label"``."""
        record = self.ds[idx]

        feature_names = ["time", "charge", "auxiliary", "x", "y", "z"]
        # sensor_id travels through float32 with the other columns and is
        # converted to an integer tensor when the sample is assembled.
        pulses = pd.DataFrame(record)[["sensor_id", *feature_names]].astype(np.float32)
        if self.max_events:
            pulses = pulses.iloc[: self.max_events]

        # time -> fraction of the largest kept time, x/y/z -> divided by 500,
        # charge -> log10(charge).
        rescaled = {axis: pulses[axis] / 500 for axis in ("x", "y", "z")}
        pulses = pulses.assign(
            time=pulses["time"] / pulses["time"].max(),
            charge=np.log10(pulses["charge"]),
            **rescaled,
        )

        shifted_ids = pulses["sensor_id"].values + 1
        features = pulses[feature_names].values
        valid = np.ones(len(features), dtype=bool)
        direction = np.array([record["azimuth"], record["zenith"]], dtype=np.float32)

        sample = {
            "sensor_id": torch.tensor(shifted_ids, dtype=torch.long),
            "event": torch.tensor(features),
            "mask": torch.tensor(valid),
            "label": torch.tensor(direction),
        }
        return deepcopy(sample)


In [26]:
# Cache-backed dataset over the files collected in ``fns``, using the class defaults.
ds = IceCubeCasheDatasetV1(fns=fns)
# Disabled DataLoader smoke test, kept for reference.
# NOTE(review): ``collate_fn`` is not defined in the visible part of this notebook.
#dl = DataLoader(ds, batch_size=64, shuffle=True, num_workers=4, collate_fn=collate_fn)
# for x in dl:
#     break

In [27]:
# NOTE(review): absolute, machine-specific path -- consider deriving it from a
# configurable data directory so the notebook runs elsewhere.
ds = HuggingFaceDatasetV1(
    ds=load_from_disk('/opt/slh/icecube/data/hf_cashe/batch_1.parquet'),
)



In [28]:
# Fetch the first sample to inspect the dict of tensors the dataset returns.
ds[0]

{'sensor_id': tensor([3919, 4158, 3521, 5042, 2949,  861, 2441, 1744, 3610, 5058, 5058, 2978,
         5060, 3497, 3162, 2960, 1398, 1971, 3388, 1584, 1941, 1242,  559,  558,
         1406,  558,  559,  558,  558, 3051,  554,  973,  974, 2262,  976,  561,
          555, 3277, 4832, 4572, 3521, 3700,  301,  614, 3439, 2422, 3610, 3116,
         5058, 4529, 3497, 2449, 3290, 3051, 4905, 1971, 3453,   49, 3268, 3268,
          105]),
 'event': tensor([[ 0.3115,  0.1222,  1.0000,  0.6068,  0.6713,  0.4132],
         [ 0.3213,  0.0700,  1.0000, -0.2909,  0.7485,  0.4255],
         [ 0.3411, -0.0339,  1.0000,  1.0105,  0.5158, -0.3492],
         [ 0.3502, -0.6478,  1.0000, -0.0194, -0.1590,  0.3620],
         [ 0.4232,  0.1973,  1.0000,  1.1527,  0.3418,  0.7158],
         [ 0.4269, -0.1707,  1.0000, -0.5813, -0.6148,  0.3272],
         [ 0.4353,  0.2109,  1.0000, -1.0533, -0.0312, -0.3563],
         [ 0.4455, -0.1107,  1.0000,  1.0009, -0.1169,  0.9016],
         [ 0.4504,  0.0107,  1.0000,

In [12]:
# Display the sensor geometry table (the output lists sensor_id, x, y, z).
# NOTE(review): the "Unnamed: 0" column in the output suggests the CSV was
# saved together with its index -- confirm whether index_col=0 is wanted here.
pd.read_csv('../data/sensor_geometry.csv')

Unnamed: 0,sensor_id,x,y,z
0,0,-256.14,-521.08,496.03
1,1,-256.14,-521.08,479.01
2,2,-256.14,-521.08,461.99
3,3,-256.14,-521.08,444.97
4,4,-256.14,-521.08,427.95
...,...,...,...,...
5155,5155,-10.97,6.72,-472.39
5156,5156,-10.97,6.72,-479.39
5157,5157,-10.97,6.72,-486.40
5158,5158,-10.97,6.72,-493.41


In [7]:
#| export
def good_luck() -> bool:
    """Return ``True`` unconditionally."""
    return True

In [8]:
#|hide
#|eval: false
# Run nbdev's export step; presumably this writes the cells marked
# ``#| export`` to the module named by ``#|default_exp`` -- see the nbdev docs.
from nbdev.doclinks import nbdev_export
nbdev_export()

In [None]:
# NOTE(review): looks like a leftover scratch / tab-completion expression in a
# cell that was never executed. ``Dataset`` probably has no attribute ``S``, so
# this would raise AttributeError on "Run All" -- consider removing the cell.
Dataset.S