In [1]:
import sys

sys.path.append("..")
from icecube.dataset import IceCubeCasheDatasetV0
from icecube.dataset import collate_fn
from icecube.utils import fit
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import torch.nn.functional as F
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder
from datasets import load_dataset, load_from_disk, concatenate_datasets
import pandas as pd
import numpy as np
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def angular_dist_score(
    az_true: torch.Tensor,
    zen_true: torch.Tensor,
    az_pred: torch.Tensor,
    zen_pred: torch.Tensor,
) -> torch.Tensor:
    sa1 = torch.sin(az_true)
    ca1 = torch.cos(az_true)
    sz1 = torch.sin(zen_true)
    cz1 = torch.cos(zen_true)

    sa2 = torch.sin(az_pred)
    ca2 = torch.cos(az_pred)
    sz2 = torch.sin(zen_pred)
    cz2 = torch.cos(zen_pred)

    scalar_prod = sz1 * sz2 * (ca1 * ca2 + sa1 * sa2) + cz1 * cz2
    scalar_prod = torch.clamp(scalar_prod, -1, 1)
    return torch.mean(torch.abs(torch.acos(scalar_prod)))





In [3]:
class CFG:
    DATA_CACHE_DIR = Path("../data/cache")
    BATCH_SIZE = 1024 * 2
    NUM_WORKERS = 16
    PRESISTENT_WORKERS = True
    LR = 1e-3
    WD = 1e-5
    WARM_UP_PCT = 0.1
    EPOCHS = 3
    FOLDER = 'EXP_03_HF'
    EXP_NAME = 'FIRST_EXP'


def get_batch_paths(start: int, end: int, extension: str = "*.pth"):
    trn_path = []
    for i in range(start, end+1):
        path = (CFG.DATA_CACHE_DIR / f"batch_{i}").glob(extension)
        trn_path.extend(list(path))
    return trn_path

In [4]:
class LogCoshLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        ey_t = y_t - y_prime_t
        return torch.mean(torch.log(torch.cosh(ey_t + 1e-12)))



class MeanPoolingWithMask(nn.Module):
    def __init__(self):
        super(MeanPoolingWithMask, self).__init__()

    def forward(self, x, mask):
        # Multiply the mask with the input tensor to zero out the padded values
        x = x * mask.unsqueeze(-1)

        # Sum the values along the sequence dimension
        x = torch.sum(x, dim=1)

        # Divide the sum by the number of non-padded values (i.e. the sum of the mask)
        x = x / torch.sum(mask, dim=1, keepdim=True)

        return x

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out = None, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class IceCubeModelEncoderV0(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=6,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128,
                        depth=6, 
                        heads=8),
        )

        #self.pool = MeanPoolingWithMask()
        self.head = FeedForward(128, 2)

    def forward(self, x, mask):
        x = self.encoder(x, mask = mask)
        x = x.mean(dim=1)
        x = self.head(x)
        return x

#calculte metric based on angular distance
def get_score(y_hat, y):
    return angular_dist_score(y[:, 0], y[:, 1], y_hat[:, 0], y_hat[:, 1]).detach().cpu().numpy()

In [5]:
from copy import deepcopy
class HuggingFaceDatasetV0(torch.utils.data.Dataset):
    def __init__(self, ds, max_events=100):
        self.ds = ds
        self.max_events = max_events

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]

        event = pd.DataFrame(item)[
            [
                "time",
                "charge",
                "auxiliary",
                "x",
                "y",
                "z",
            ]
        ].astype(np.float32)
        if self.max_events:
            event = event[: self.max_events]
        event["time"] /= event["time"].max()
        event[["x", "y", "z"]] /= 500
        event["charge"] = np.log10(event["charge"])

        event = event.values
        mask = np.ones(len(event), dtype=bool)
        label = np.array([item["azimuth"], item["zenith"]], dtype=np.float32)

        batch = deepcopy(
            {
                "event": torch.tensor(event),
                "mask": torch.tensor(mask),
                "label": torch.tensor(label),
            }
        )
        return batch


def collate_fn(batch):
    event = [x["event"] for x in batch]
    mask = [x["mask"] for x in batch]
    label = [x["label"] for x in batch]

    event = torch.nn.utils.rnn.pad_sequence(event, batch_first=True)
    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
    batch = {"event": event, "mask": mask, "label": torch.stack(label)}
    return batch

In [6]:
PATH_DATAST = Path("../data")
nums = [i for i in range(1, 400)]
random.shuffle(nums)
trn_pth = [load_from_disk(PATH_DATAST / "hf_cashe" / f"batch_{i}.parquet") for i in nums]
trn_pth = concatenate_datasets(trn_pth)

vld_pth = [load_from_disk(PATH_DATAST / "hf_cashe" / f"batch_{i}.parquet") for i in range(600, 603)]
vld_pth = concatenate_datasets(vld_pth)


print(len(trn_pth), len(vld_pth))

trn_ds = HuggingFaceDataset(trn_pth)
vld_ds = HuggingFaceDataset(vld_pth)

trn_dl = DataLoader(
    trn_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS,
    persistent_workers=CFG.PRESISTENT_WORKERS,
    drop_last=True,
    collate_fn=collate_fn,
)
vld_dl = DataLoader(
    vld_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS,
    persistent_workers=CFG.PRESISTENT_WORKERS,
    drop_last=False,
    collate_fn=collate_fn,
)


custom_model = IceCubeModelEncoderV0()
opt = torch.optim.AdamW(custom_model.parameters(), lr=CFG.LR, weight_decay=CFG.WD)
loss_func = LogCoshLoss()
warmup_steps = int(len(trn_dl) * int(CFG.WARM_UP_PCT * CFG.EPOCHS))
total_steps = int(len(trn_dl) * CFG.EPOCHS)
sched = get_linear_schedule_with_warmup(
    opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

fit(
    epochs=CFG.EPOCHS,
    model=custom_model,
    train_dl=trn_dl,
    valid_dl=vld_dl,
    loss_fn=loss_func,
    opt=opt,
    metric=get_score,
    folder=CFG.FOLDER,
    exp_name=f"{CFG.EXP_NAME}",
    device="cuda:0",
    sched=sched,
)



79800000 600000


epoch,train_loss,valid_loss,val_metric


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

Better model found at epoch 1 with value: 1.2050968408584595.
   epoch  train_loss  valid_loss     metric
0      1    0.487884    0.483495  1.2050968


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

In [15]:
trn_pth = trn_pth.shuffle()

In [16]:
trn_pth[0]

{'event_id': 816953944,
 'sensor_id': [3819,
  4460,
  537,
  4057,
  4795,
  2920,
  4738,
  1918,
  4746,
  4531,
  2920,
  4077,
  4800,
  3963,
  3963,
  3964,
  3963,
  3964,
  3964,
  3965,
  3962,
  3963,
  3966,
  3429,
  3428,
  3427,
  3106,
  3428,
  3431,
  2833,
  2833,
  2832,
  2834,
  2240,
  4906,
  2293,
  4764,
  2239,
  2237,
  2240,
  1757,
  1703,
  2298,
  1564,
  1647,
  1645,
  1168,
  2099,
  1639,
  1108,
  1105,
  281,
  4811,
  1644,
  562,
  5029,
  2830,
  1704,
  3740,
  219,
  2144,
  3707,
  2738,
  4775,
  4664,
  3493,
  4564,
  2216,
  2946,
  537,
  1707,
  4005,
  1077],
 'time': [5974,
  6010,
  6234,
  6355,
  6422,
  6509,
  7064,
  7474,
  7529,
  7609,
  9060,
  9482,
  9733,
  9874,
  9911,
  9912,
  9927,
  9929,
  9938,
  10073,
  10133,
  10181,
  10368,
  10456,
  10483,
  10505,
  10541,
  10654,
  10784,
  11020,
  11077,
  11079,
  11299,
  11395,
  11612,
  11670,
  11682,
  11785,
  11831,
  11907,
  12040,
  12073,
  12181,
  12206

In [14]:
trn_pth[0]

{'event_id': 1028662962,
 'sensor_id': [1228,
  4451,
  1409,
  2536,
  3425,
  1905,
  2416,
  577,
  1698,
  1312,
  3888,
  1228,
  68,
  67,
  69,
  69,
  68,
  67,
  68,
  70,
  68,
  429,
  432,
  70,
  67,
  428,
  429,
  64,
  429,
  428,
  67,
  91,
  4426,
  434,
  432,
  65,
  62,
  3948,
  2833,
  880,
  2375,
  496,
  3458,
  2330,
  432,
  3425,
  3413,
  4195,
  4560,
  5096,
  1348,
  4643,
  857,
  1228,
  1178,
  1233,
  3425,
  3716,
  4422,
  5103],
 'time': [6950,
  7159,
  7527,
  7986,
  8512,
  8575,
  8864,
  8979,
  9261,
  9410,
  9665,
  9719,
  9888,
  9913,
  9915,
  9995,
  10028,
  10052,
  10102,
  10128,
  10135,
  10179,
  10264,
  10309,
  10340,
  10368,
  10368,
  10392,
  10421,
  10445,
  10483,
  10503,
  10649,
  10680,
  10779,
  10933,
  11035,
  11193,
  11264,
  11352,
  11616,
  11633,
  11700,
  12096,
  12122,
  12156,
  12836,
  12968,
  13061,
  13687,
  13721,
  14282,
  14502,
  14575,
  15909,
  16009,
  16180,
  16389,
  16425,
  1