In [1]:
import sys

sys.path.append("..")
from icecube.dataset import IceCubeCasheDatasetV0
from icecube.dataset import collate_fn
from icecube.utils import fit
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import torch.nn.functional as F
from transformers.optimization import (
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from torch import nn
from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def angular_dist_score(
    az_true: torch.Tensor,
    zen_true: torch.Tensor,
    az_pred: torch.Tensor,
    zen_pred: torch.Tensor,
) -> torch.Tensor:
    sa1 = torch.sin(az_true)
    ca1 = torch.cos(az_true)
    sz1 = torch.sin(zen_true)
    cz1 = torch.cos(zen_true)

    sa2 = torch.sin(az_pred)
    ca2 = torch.cos(az_pred)
    sz2 = torch.sin(zen_pred)
    cz2 = torch.cos(zen_pred)

    scalar_prod = sz1 * sz2 * (ca1 * ca2 + sa1 * sa2) + cz1 * cz2
    scalar_prod = torch.clamp(scalar_prod, -1, 1)
    return torch.mean(torch.abs(torch.acos(scalar_prod)))





In [3]:
class CFG:
    DATA_CACHE_DIR = Path("../data/cache")
    BATCH_SIZE = 1024
    NUM_WORKERS = 16
    PRESISTENT_WORKERS = True
    LR = 1e-3
    WD = 1e-5
    WARM_UP_PCT = 0.1
    EPOCHS = 10
    FOLDER = 'EXP_03'
    EXP_NAME = 'FIRST_EXP'


def get_batch_paths(start: int, end: int, extension: str = "*.pth"):
    trn_path = []
    for i in range(start, end+1):
        path = (CFG.DATA_CACHE_DIR / f"batch_{i}").glob(extension)
        trn_path.extend(list(path))
    return trn_path

In [4]:
class LogCoshLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        ey_t = y_t - y_prime_t
        return torch.mean(torch.log(torch.cosh(ey_t + 1e-12)))



class MeanPoolingWithMask(nn.Module):
    def __init__(self):
        super(MeanPoolingWithMask, self).__init__()

    def forward(self, x, mask):
        # Multiply the mask with the input tensor to zero out the padded values
        x = x * mask.unsqueeze(-1)

        # Sum the values along the sequence dimension
        x = torch.sum(x, dim=1)

        # Divide the sum by the number of non-padded values (i.e. the sum of the mask)
        x = x / torch.sum(mask, dim=1, keepdim=True)

        return x

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out = None, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class IceCubeModelEncoderV0(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=6,
            dim_out=128,
            max_seq_len=150,
            attn_layers=Encoder(dim=128,
                        depth=3, 
                        heads=8),
        )

        #self.pool = MeanPoolingWithMask()
        self.head = FeedForward(128, 2)

    def forward(self, x, mask):
        x = self.encoder(x, mask = mask)
        x = x.mean(dim=1)
        x = self.head(x)
        return x

#calculte metric based on angular distance
def get_score(y_hat, y):
    return angular_dist_score(y[:, 0], y[:, 1], y_hat[:, 0], y_hat[:, 1]).detach().cpu().numpy()

In [None]:
vld_path = get_batch_paths(3, 6)
trn_path = get_batch_paths(7, 100)

print(len(trn_path), len(vld_path))

trn_ds = IceCubeCasheDatasetV0(trn_path)
vld_ds = IceCubeCasheDatasetV0(vld_path)

trn_dl = DataLoader(
    trn_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=True,
    num_workers=CFG.NUM_WORKERS,
    persistent_workers=CFG.PRESISTENT_WORKERS,
    drop_last=True,
    collate_fn=collate_fn,
)
vld_dl = DataLoader(
    vld_ds,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=CFG.NUM_WORKERS,
    persistent_workers=CFG.PRESISTENT_WORKERS,
    drop_last=False,
    collate_fn=collate_fn,
)


custom_model = IceCubeModelEncoderV0()
opt = torch.optim.AdamW(custom_model.parameters(), lr=CFG.LR, weight_decay=CFG.WD)
loss_func = LogCoshLoss()
warmup_steps = int(len(trn_dl) * int(CFG.WARM_UP_PCT * CFG.EPOCHS))
total_steps = int(len(trn_dl) * CFG.EPOCHS)
sched = get_linear_schedule_with_warmup(
    opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)

fit(
    epochs=CFG.EPOCHS,
    model=custom_model,
    train_dl=trn_dl,
    valid_dl=vld_dl,
    loss_fn=loss_func,
    opt=opt,
    metric=get_score,
    folder=CFG.FOLDER,
    exp_name=f"{CFG.EXP_NAME}",
    device="cuda:0",
    sched=sched,
)

18800000 800000


epoch,train_loss,valid_loss,val_metric


Better model found at epoch 0 with value: 1.2266074419021606.
   epoch  train_loss  valid_loss     metric
0      0     0.52132    0.484775  1.2266074
Better model found at epoch 1 with value: 1.197077751159668.
   epoch  train_loss  valid_loss     metric
0      1    0.479438    0.475394  1.1970778
Better model found at epoch 2 with value: 1.181969165802002.
   epoch  train_loss  valid_loss     metric
0      2    0.472418    0.470404  1.1819692
Better model found at epoch 3 with value: 1.1610087156295776.
   epoch  train_loss  valid_loss     metric
0      3    0.468921    0.467391  1.1610087


In [None]:
35
