In [1]:
#|default_exp modelsgraph

In [2]:
#| export
import sys
sys.path.append('/opt/slh/graphnet-main/src')
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder, Decoder
from torch import nn
from graphnet.training.loss_functions import VonMisesFisher2DLoss
from graphnet.models.task.reconstruction import (
    AzimuthReconstructionWithKappa,
    ZenithReconstruction,
)


[1;34mgraphnet[0m: [32mINFO    [0m 2023-02-05 15:47:09 - get_logger - Writing log to [1mlogs/graphnet_20230205-154709.log[0m


In [3]:
# Exploratory: list the directory two levels above the notebook (IPython shell escape;
# output shown below). Not exported and not used by the model code.
# TODO(review): consider removing before sharing; the listing is machine-specific.
!ls ../../

EDA.ipynb  graphnet-main  leetcode.ipynb  packages
geom	   icecube	  logs		  test.py


In [4]:
#| export
class MeanPoolingWithMask(nn.Module):
    """Average a padded sequence over its valid (unmasked) positions only.

    ``forward(x, mask)`` zeroes out padded positions of ``x``, sums along
    dim 1 and divides by the number of valid positions in each row.

    Shapes as used in this notebook (TODO confirm for other callers):
        x:    (batch, seq_len, dim) float tensor
        mask: (batch, seq_len) tensor, True/1 for real entries, False/0 for padding
        returns: (batch, dim)

    A row whose mask is entirely False pools to zeros rather than NaN.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        # Zero out the padded positions, then sum along the sequence dimension.
        summed = torch.sum(x * mask.unsqueeze(-1), dim=1)

        # Number of valid positions per row, kept as (batch, 1) for broadcasting.
        counts = torch.sum(mask, dim=1, keepdim=True)

        # A fully padded row has count 0, and 0/0 = NaN would poison the loss and
        # gradients. Divide by 1 instead so such rows pool to zeros. Rows with at
        # least one valid position are unaffected.
        counts = counts.masked_fill(counts == 0, 1)

        return summed / counts
    
# Module-level loss instances: the azimuth loss is built first, then the zenith loss.
# They serve as CombineLossV0's constructor defaults and are handed to the
# reconstruction heads of EncoderWithReconstructionLossV0.
loss_fn_azi, loss_fn_zen = VonMisesFisher2DLoss(), nn.L1Loss()

class CombineLossV0(nn.Module):
    """Sum of an azimuth loss and a zenith loss computed from one model output.

    ``output`` is cut along dim 1 into chunks of width 2:
      * the first chunk goes to ``loss_fn_azi`` together with the full
        ``batch['label']`` tensor;
      * the second chunk goes to ``loss_fn_zen`` together with the last label
        column, reshaped to a trailing singleton dimension.

    NOTE(review): unpacking into exactly two chunks means ``output`` must have
    3 or 4 columns along dim 1. With 3 columns, the azimuth head presumably
    contributes (azimuth, kappa) and the zenith head one column. Confirm this
    against the model's output width.
    """

    def __init__(self, loss_fn_azi=loss_fn_azi, loss_fn_zen=loss_fn_zen):
        super().__init__()
        self.loss_fn_azi = loss_fn_azi
        self.loss_fn_zen = loss_fn_zen

    def forward(self, batch, output):
        labels = batch['label']
        azimuth_part, zenith_part = output.split(2, dim=1)

        azimuth_term = self.loss_fn_azi(azimuth_part, labels)

        # Last label column as a (batch, 1) column, to line up with zenith_part.
        zenith_target = labels[:, -1].unsqueeze(dim=-1)
        zenith_term = self.loss_fn_zen(zenith_part, zenith_target)

        return azimuth_term + zenith_term

    
class EncoderWithReconstructionLossV0(nn.Module):
    """Transformer encoder + masked mean pooling + graphnet azimuth/zenith heads.

    Args:
        dim_in: number of input features per sequence element fed to the encoder.
        dim: model width. It is used as the encoder width, the encoder output
            size, and the ``hidden_size`` of both reconstruction heads, so these
            values can no longer disagree.
        depth: number of transformer layers.
        heads: number of attention heads.
        max_seq_len: maximum sequence length accepted by the encoder.

    The defaults reproduce the previously hard-coded configuration
    (8 -> 128, depth 6, 8 heads, max_seq_len 150), so ``EncoderWithReconstructionLossV0()``
    builds the same model as before.
    """

    def __init__(self, dim_in=8, dim=128, depth=6, heads=8, max_seq_len=150):
        super().__init__()
        self.encoder = ContinuousTransformerWrapper(
            dim_in=dim_in,
            dim_out=dim,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=dim, depth=depth, heads=heads),
        )

        self.pool = MeanPoolingWithMask()
        self.az = AzimuthReconstructionWithKappa(
            hidden_size=dim,
            loss_function=loss_fn_azi,
            target_labels=["azimuth", "kappa"],
        )
        self.zn = ZenithReconstruction(
            hidden_size=dim,
            loss_function=loss_fn_zen,
            target_labels=["zenith"],
        )

    def forward(self, batch):
        """Encode, pool and reconstruct one batch.

        Reads only ``batch['event']`` and ``batch['mask']``; other keys (e.g.
        ``sensor_id``, ``label``) are ignored here. The same mask is given to
        the encoder and to the pooling step. Assumes ``event`` is
        (batch, seq_len, dim_in) and ``mask`` is a boolean (batch, seq_len)
        tensor with True for real entries, as in the smoke test. TODO confirm
        for the real data loader.

        Returns the azimuth head's output concatenated with the zenith head's
        output along dim 1.
        """
        x, mask = batch['event'], batch['mask']
        x = self.encoder(x, mask=mask)
        x = self.pool(x, mask)
        az = self.az(x)
        zn = self.zn(x)
        return torch.concat([az, zn], 1)

In [5]:
# Smoke test: build the model and run one forward pass on a random batch.
# Seed first so the model's initial weights and the random inputs (and therefore
# the loss displayed further down) are reproducible under Restart & Run All.
torch.manual_seed(0)

BATCH_SIZE, SEQ_LEN = 10, 100
N_FEATURES = 8      # must match the encoder's input width (dim_in=8 by default)
N_SENSORS = 5161    # sensor ids are drawn uniformly from [0, N_SENSORS)

model = EncoderWithReconstructionLossV0().eval()
event = torch.rand(BATCH_SIZE, SEQ_LEN, N_FEATURES)
mask = torch.ones(BATCH_SIZE, SEQ_LEN, dtype=torch.bool)  # no padding: all positions valid
sensor_id = torch.randint(0, N_SENSORS, (BATCH_SIZE, SEQ_LEN))  # not read by the model's forward
label = torch.rand(BATCH_SIZE, 2)  # CombineLossV0 uses label[:, -1] as the zenith target
# NOTE: `input` shadows the builtin; the name is kept because the loss cell below uses it.
input = dict(event=event, mask=mask, sensor_id=sensor_id, label=label)
with torch.no_grad():
    y = model(input)

In [9]:
# Evaluate the combined azimuth + zenith loss of the smoke-test predictions `y`
# against the labels stored in `input`; the value is the cell's displayed output.
criterion = CombineLossV0()
criterion(input, y)

tensor(3.5121)

In [7]:
#|hide
#|eval: false
# Export this notebook's `#| export` cells to the library module declared by
# `#|default_exp` at the top of the notebook (`modelsgraph`).
# The directives above must stay the first lines of the cell for nbdev to see them:
# `hide` keeps the cell out of the rendered docs, and `eval: false` presumably stops
# nbdev from executing it when the notebook is run by nbdev tooling.
from nbdev.doclinks import nbdev_export
nbdev_export()