In [1]:
import getpass
from pathlib import Path
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.interpolate import interp1d
from sklearn.preprocessing import RobustScaler
from torch import LongTensor, Tensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

COMP_NAME = "icecube-neutrinos-in-deep-ice"
# Return the “login name” of the user
KERNEL = False if getpass.getuser() == "anjum" else True
if not KERNEL:  # in personal computer
    INPUT_PATH = Path(f"/mnt/storage_dimm2/kaggle_data/{COMP_NAME}")
    OUTPUT_PATH = Path(f"/mnt/storage_dimm2/kaggle_output/{COMP_NAME}")
    MODEL_CACHE = Path("/mnt/storage/model_cache/torch")
    TRANSPARENCY_PATH = INPUT_PATH / "ice_transparency.txt"
else:           # in kaggle
    INPUT_PATH = Path(f"/kaggle/input/{COMP_NAME}")
    MODEL_CACHE = None
    TRANSPARENCY_PATH = "/kaggle/input/icecubetransparency/ice_transparency.txt"

    # Install packages
    import subprocess

    if torch.cuda.is_available():
        whls = [
            "/kaggle/input/pytorchgeometric/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorchgeometric/torch_scatter-2.1.0-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorchgeometric/torch_sparse-0.6.16-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorchgeometric/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorchgeometric/torch_geometric-2.2.0-py3-none-any.whl",
            "/kaggle/input/pytorchgeometric/ruamel.yaml-0.17.21-py3-none-any.whl",
        ]
    else:
        whls = [
            "/kaggle/input/pytorch-geometric/PyTorch-Geometric/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorch-geometric/PyTorch-Geometric/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorch-geometric/PyTorch-Geometric/torch_sparse-0.6.15-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorch-geometric/PyTorch-Geometric/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl",
            "/kaggle/input/pytorch-geometric/PyTorch-Geometric/torch_geometric-2.1.0.post1-py3-none-any.whl",
            "/kaggle/input/pytorchgeometric/ruamel.yaml-0.17.21-py3-none-any.whl",
        ]

    for w in whls:
        print("Installing", w)
        subprocess.call(["pip", "install", w, "--no-deps", "--upgrade"])

    import sys
#     sys.path.append("/kaggle/input/graphnet/graphnet-main/src")

# from graphnet.models.graph_builders import KNNGraphBuilder
# from graphnet.models.task.reconstruction import (
#     AzimuthReconstructionWithKappa,
#     ZenithReconstruction,
# )
# from graphnet.training.loss_functions import VonMisesFisher2DLoss, CosineLoss
# from graphnet.models.gnn.gnn import GNN
# from graphnet.models.utils import calculate_xyzt_homophily
# from graphnet.utilities.config import save_model_config

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
# from torch_geometric.nn import EdgeConv
from torch_geometric.nn import EdgeConv, SAGEConv, ChebConv, GCNConv
# from torch_geometric.nn.pool import knn_graph
from torch_geometric.nn import knn_graph
from torch_geometric.typing import Adj

import torch_scatter
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

GLOBAL_POOLINGS = {
    "min": scatter_min,
    "max": scatter_max,
    "sum": scatter_sum,
    "mean": scatter_mean,
}

_dtype = {
    "batch_id": "int16",
    "event_id": "int64",
}

Installing /kaggle/input/pytorchgeometric/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl
Processing /kaggle/input/pytorchgeometric/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.6.0




Installing /kaggle/input/pytorchgeometric/torch_scatter-2.1.0-cp37-cp37m-linux_x86_64.whl
Processing /kaggle/input/pytorchgeometric/torch_scatter-2.1.0-cp37-cp37m-linux_x86_64.whl
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.0




Installing /kaggle/input/pytorchgeometric/torch_sparse-0.6.16-cp37-cp37m-linux_x86_64.whl
Processing /kaggle/input/pytorchgeometric/torch_sparse-0.6.16-cp37-cp37m-linux_x86_64.whl
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.16




Installing /kaggle/input/pytorchgeometric/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl
Processing /kaggle/input/pytorchgeometric/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl
Installing collected packages: torch-spline-conv
Successfully installed torch-spline-conv-1.2.1




Installing /kaggle/input/pytorchgeometric/torch_geometric-2.2.0-py3-none-any.whl
Processing /kaggle/input/pytorchgeometric/torch_geometric-2.2.0-py3-none-any.whl
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.2.0




Installing /kaggle/input/pytorchgeometric/ruamel.yaml-0.17.21-py3-none-any.whl
Processing /kaggle/input/pytorchgeometric/ruamel.yaml-0.17.21-py3-none-any.whl
Installing collected packages: ruamel.yaml
Successfully installed ruamel.yaml-0.17.21




https://github.com/graphnet-team/graphnet

## Dataset

In [2]:
!cat /kaggle/input/icecubetransparency/ice_transparency.txt

depth scattering_len absorption_len
1398.4 13.2 45.1
1408.4 14.0 48.6
1418.4 14.7 53.2
1428.4 17.0 57.6
1438.4 16.0 57.6
1448.4 14.4 52.2
1458.4 16.0 60.1
1468.4 20.8 74.6
1478.4 26.7 96.6
1488.4 34.7 110.5
1498.4 39.7 135.6
1508.5 38.7 134.7
1518.6 27.8 98.2
1528.7 16.6 64.7
1538.8 13.7 48.5
1548.7 13.5 44.3
1558.7 15.7 54.4
1568.5 15.7 56.7
1578.5 14.7 52.1
1588.5 17.6 60.7
1598.5 21.6 72.7
1608.5 24.0 78.9
1618.5 20.0 68.7
1628.5 17.8 66.6
1638.5 28.9 100.0
1648.4 36.9 128.6
1658.4 42.1 148.2
1668.4 46.5 165.7
1678.5 45.4 156.0
1688.5 39.1 138.5
1698.5 30.6 113.9
1708.5 26.5 90.2
1718.5 19.3 73.5
1728.5 20.8 75.9
1738.5 20.1 67.8
1748.5 20.3 68.6
1758.5 24.5 83.8
1768.5 33.5 119.5
1778.5 36.2 121.6
1788.5 35.4 108.3
1798.5 32.3 113.4
1808.5 40.2 139.1
1818.4 44.7 148.1
1828.4 34.5 122.8
1838.4 30.6 113.8
1848.4 27.5 89.9
1858.4 19.7 71.7
1868.5 21.4 70.6
1878.5 28.8 95.9
1888.5 38.3 116.5
1898.5 38.4 143.6
1908.5 44.2 169.4
1918.5 50.5 178.0
1928.5 46.6 156.5
1938.5 36.8 135.3
1948.

In [3]:
# datasets.py
def ice_transparency(data_path, datum=1950):
    # Data from page 31 of https://arxiv.org/pdf/1301.5361.pdf
    # Datum is from footnote 8 of page 29
    df = pd.read_csv(data_path, delim_whitespace=True)
    df["z"] = df["depth"] - datum
    df["z_norm"] = df["z"] / 500
    df[["scattering_len_norm", "absorption_len_norm"]] = RobustScaler().fit_transform(
        df[["scattering_len", "absorption_len"]]
    )

    # These are both roughly equivalent after scaling
    f_scattering = interp1d(df["z_norm"], df["scattering_len_norm"])
    f_absorption = interp1d(df["z_norm"], df["absorption_len_norm"])
    return f_scattering, f_absorption

In [4]:
class IceCubeDataset(Dataset):
    def __init__(
        self,
        batch_id,
        event_ids,
        sensor_df,
        mode="test",
        y=None,
        pulse_limit=300,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        super().__init__(transform, pre_transform, pre_filter)
        self.y = y
        self.event_ids = event_ids
        self.batch_df = pd.read_parquet(INPUT_PATH / mode / f"batch_{batch_id}.parquet")
        self.sensor_df = sensor_df
        self.pulse_limit = pulse_limit
        self.f_scattering, self.f_absorption = ice_transparency(TRANSPARENCY_PATH)

        self.batch_df["time"] = (self.batch_df["time"] - 1.0e04) / 3.0e4
        self.batch_df["charge"] = np.log10(self.batch_df["charge"]) / 3.0
        self.batch_df["auxiliary"] = self.batch_df["auxiliary"].astype(int) - 0.5

    def len(self):
        return len(self.event_ids)

    def get(self, idx):
        event_id = self.event_ids[idx]
        event = self.batch_df.loc[event_id]
    
        # represent each event by a single graph
        event = pd.merge(event, self.sensor_df, on="sensor_id")
        col = ["x", "y", "z", "time", "charge", "qe", "auxiliary"]

        x = event[col].values
        x = torch.tensor(x, dtype=torch.float32)
        data = Data(x=x, n_pulses=torch.tensor(x.shape[0], dtype=torch.int32))

        # Add ice transparency data
        z = data.x[:, 2].numpy()
        scattering = torch.tensor(self.f_scattering(z), dtype=torch.float32).view(-1, 1)
        # absorption = torch.tensor(self.f_absorption(z), dtype=torch.float32).view(-1, 1)

        data.x = torch.cat([data.x, scattering], dim=1)

        # Downsample the large events
        if data.n_pulses > self.pulse_limit:
            data.x = data.x[np.random.choice(data.n_pulses, self.pulse_limit)]
            data.n_pulses = torch.tensor(self.pulse_limit, dtype=torch.int32)
    
        # Builds graph from the k-nearest neighbours.
        data.edge_index = knn_graph(
            data.x[:, [0, 1, 2]],  # x, y, z
            k=8,
            batch=None,
            loop=False
        )

        if self.y is not None:
            y = self.y.loc[idx, :].values
            y = torch.tensor(y, dtype=torch.float32)
            data.y = y

        return data

In [5]:
# preprocessing.py
def prepare_sensors():
    sensors = pd.read_csv(INPUT_PATH / "sensor_geometry.csv").astype(
        {
            "sensor_id": np.int16,
            "x": np.float32,
            "y": np.float32,
            "z": np.float32,
        }
    )
    sensors["string"] = 0
    sensors["qe"] = 1

    for i in range(len(sensors) // 60):
        start, end = i * 60, (i * 60) + 60
        sensors.loc[start:end, "string"] = i

        # High Quantum Efficiency in the lower 50 DOMs - https://arxiv.org/pdf/2209.03042.pdf (Figure 1)
        if i in range(78, 86):
            start_veto, end_veto = i * 60, (i * 60) + 10
            start_core, end_core = end_veto + 1, (i * 60) + 60
            sensors.loc[start_core:end_core, "qe"] = 1.35

    # https://github.com/graphnet-team/graphnet/blob/b2bad25528652587ab0cdb7cf2335ee254cfa2db/src/graphnet/models/detector/icecube.py#L33-L41
    # Assume that "rde" (relative dom efficiency) is equivalent to QE
    sensors["x"] /= 500
    sensors["y"] /= 500
    sensors["z"] /= 500
    sensors["qe"] -= 1.25
    sensors["qe"] /= 0.25

    return sensors

In [6]:
sensors = prepare_sensors()
sensors

Unnamed: 0,sensor_id,x,y,z,string,qe
0,0,-0.51228,-1.04216,0.99206,0,-1.0
1,1,-0.51228,-1.04216,0.95802,0,-1.0
2,2,-0.51228,-1.04216,0.92398,0,-1.0
3,3,-0.51228,-1.04216,0.88994,0,-1.0
4,4,-0.51228,-1.04216,0.85590,0,-1.0
...,...,...,...,...,...,...
5155,5155,-0.02194,0.01344,-0.94478,85,0.4
5156,5156,-0.02194,0.01344,-0.95878,85,0.4
5157,5157,-0.02194,0.01344,-0.97280,85,0.4
5158,5158,-0.02194,0.01344,-0.98682,85,0.4


In [7]:
meta = pd.read_parquet(
    INPUT_PATH / f"train_meta.parquet", columns=["batch_id", "event_id", "azimuth", "zenith"]
).astype(_dtype)
meta

Unnamed: 0,batch_id,event_id,azimuth,zenith
0,1,24,5.029555,2.087498
1,1,41,0.417742,1.549686
2,1,59,1.160466,2.401942
3,1,67,5.845952,0.759054
4,1,72,0.653719,0.939117
...,...,...,...,...
131953919,660,2147483597,5.895612,0.333071
131953920,660,2147483603,3.273695,1.503301
131953921,660,2147483617,2.945376,1.723253
131953922,660,2147483626,1.616582,1.937025


In [8]:
batch_ids = meta["batch_id"].unique()
batch_ids

array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104,
       105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,
       118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
       131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
       144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,
       157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
       170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 18

In [9]:
# for i, b in enumerate(batch_ids):
#     event_ids = meta[meta["batch_id"] == b]["event_id"].tolist()
#     y = meta[meta["batch_id"] == b][['zenith', 'azimuth']].reset_index(drop=True)
#     dataset = IceCubeDataset(
#         b, event_ids, sensors, mode='train', y=y,
#     )
#     print(f'batch {i}')
#     print("num of graph:", len(dataset), '\t', dataset[0], '\t', dataset[1])
#     if i >= 3:
#         break

In [10]:
# dataset[0].edge_index

## Dynamic graph

In [11]:
def calculate_distance_matrix(xyz_coords: Tensor) -> Tensor:
    """Calculate the matrix of pairwise distances between pulses.
    Args:
        xyz_coords: (x,y,z)-coordinates of pulses, of shape [nb_doms, 3].
    Returns:
        Matrix of pairwise distances, of shape [nb_doms, nb_doms]
    """
    diff = xyz_coords.unsqueeze(dim=2) - xyz_coords.T.unsqueeze(dim=0)
    return torch.sqrt(torch.sum(diff**2, dim=1))


class EuclideanGraphBuilder(nn.Module):
    """Builds graph according to Euclidean distance between nodes.
    See https://arxiv.org/pdf/1809.06166.pdf.
    """
    def __init__(
        self,
        sigma: float,
        threshold: float = 0.0,
        columns: List[int] = None,
    ):
        """Construct `EuclideanGraphBuilder`."""
        # Base class constructor
        super().__init__()

        # Check(s)
        if columns is None:
            columns = [0, 1, 2]

        # Member variable(s)
        self._sigma = sigma
        self._threshold = threshold
        self._columns = columns

    def forward(self, data: Data) -> Data:
        """Forward pass."""
        # Constructs the adjacency matrix from the raw, DOM-level data and
        # returns this matrix
        xyz_coords = data.x[:, self._columns]

        # Construct block-diagonal matrix indicating whether pulses belong to
        # the same event in the batch
        batch_mask = data.batch.unsqueeze(dim=0) == data.batch.unsqueeze(dim=1)

        distance_matrix = calculate_distance_matrix(xyz_coords)
        affinity_matrix = torch.exp(
            -0.5 * distance_matrix**2 / self._sigma**2
        )

        # Use softmax to normalise all adjacencies to one for each node
        exp_row_sums = torch.exp(affinity_matrix).sum(axis=1)
        weighted_adj_matrix = torch.exp(
            affinity_matrix
        ) / exp_row_sums.unsqueeze(dim=1)

        # Only include edges with weights that exceed the chosen threshold (and
        # are part of the same event)
        sources, targets = torch.where(
            (weighted_adj_matrix > self._threshold) & (batch_mask)
        )
        edge_weights = weighted_adj_matrix[sources, targets]

        data.edge_index = torch.stack((sources, targets))
        data.edge_weight = edge_weights

        return data

## Model

In [12]:
class DenseDynBlock(nn.Module):
    """
    Dense Dynamic graph convolution block
    """
    def __init__(self, in_channels, out_channels=64, sigma=0.5):
        super(DenseDynBlock, self).__init__()
        self.GraphBuilder = EuclideanGraphBuilder(sigma=sigma)
        self.gnn = SAGEConv(in_channels, out_channels)

    def forward(self, data):
        data1 = self.GraphBuilder(data)
        x, edge_index, batch = data1.x, data1.edge_index, data1.batch
        x = self.gnn(x, edge_index)
        data1.x = torch.cat((x, data.x), 1)
        return data1

In [13]:
class MyGNN(nn.Module):
    """
    Dynamic graph convolution layer
    """
    def __init__(self, in_channels, hidden_channels, out_channels, n_blocks):
        super().__init__()
        self.n_blocks = n_blocks
        self.head = SAGEConv(in_channels, hidden_channels)
        c_growth  = hidden_channels
        self.gnn = nn.Sequential(*[DenseDynBlock(hidden_channels+i*c_growth, c_growth)
                                    for i in range(n_blocks-1)])
        fusion_dims = int(hidden_channels * self.n_blocks + c_growth * ((1 + self.n_blocks - 1) * (self.n_blocks - 1) / 2))
        self.linear = nn.Linear(fusion_dims, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        data.x = self.head(x, edge_index)
        feats = [data.x]
        for i in range(self.n_blocks-1):
            data = self.gnn[i](data)
            feats.append(data.x)
        feats = torch.cat(feats, 1)
        x = pyg_nn.global_mean_pool(feats, data.batch)
        out = F.relu(self.linear(x))
        return out

In [14]:
model = MyGNN(8, 16, 2, 3)
model

MyGNN(
  (head): SAGEConv(8, 16, aggr=mean)
  (gnn): Sequential(
    (0): DenseDynBlock(
      (GraphBuilder): EuclideanGraphBuilder()
      (gnn): SAGEConv(16, 16, aggr=mean)
    )
    (1): DenseDynBlock(
      (GraphBuilder): EuclideanGraphBuilder()
      (gnn): SAGEConv(32, 16, aggr=mean)
    )
  )
  (linear): Linear(in_features=96, out_features=2, bias=True)
)

In [15]:
# train_loader = DataLoader(dataset[0:500], batch_size=32, num_workers=1)
# for d in train_loader:
#     print(d)
#     break
# for sample_batched in train_loader:
#     outputs = model(sample_batched)
#     print(outputs.shape, sample_batched.x.shape, sample_batched.y.shape)
#     break

In [16]:
epochs = 10
batchsize = 32
criterion = nn.L1Loss()
opt = torch.optim.AdamW(model.parameters(), lr=0.3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('using ', device)
model = model.to(device)

using  cuda


In [17]:
for i, b in enumerate(batch_ids):
    event_ids = meta[meta["batch_id"] == b]["event_id"].tolist()
    y = meta[meta["batch_id"] == b][['zenith', 'azimuth']].reset_index(drop=True)
    dataset = IceCubeDataset(
        b, event_ids, sensors, mode='train', y=y,
    )
    train_len = int(0.7*len(dataset[0:3000]))
    train_loader = DataLoader(dataset[0:train_len], batch_size=batchsize)
    val_loader = DataLoader(dataset[train_len:3000], batch_size=batchsize)
    
    print(f'batch {i}')
    for epoch_num in range(epochs):
        total_loss_train = 0
        model.train()
        for sample_batched in tqdm(train_loader, desc='train'):
            opt.zero_grad()
            sample_batched = sample_batched.to(device)
            outputs = model(sample_batched)
            label = sample_batched.y.reshape(-1, 2).to(device)
            loss = criterion(outputs, label)
            total_loss_train += loss.cpu().item()
            loss.backward()
            opt.step()
#             break
        
        total_loss_val = 0
        model.eval()
        with torch.no_grad():
            for sample_batched in tqdm(val_loader, desc='val'):
                sample_batched = sample_batched.to(device)
                outputs = model(sample_batched)
                label = sample_batched.y.reshape(-1, 2).to(device)
                loss = criterion(outputs, label)
                total_loss_val += loss.cpu().item()
#                 break

        print(f'epoch[{epoch_num}]', total_loss_train / train_len, total_loss_train / (len(dataset[0:3000]) - train_len))

    # just three batch dataset
    if i >= 2:
        break

batch 0


train: 100%|██████████| 66/66 [00:11<00:00,  5.64it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.16it/s]


epoch[0] 0.15883770136606126 0.37062130318747627


train: 100%|██████████| 66/66 [00:09<00:00,  6.93it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.34it/s]


epoch[1] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  6.97it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


epoch[2] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  7.04it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.36it/s]


epoch[3] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  7.01it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.27it/s]


epoch[4] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  6.76it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.30it/s]


epoch[5] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  7.01it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.39it/s]


epoch[6] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  6.73it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.46it/s]


epoch[7] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  6.98it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.18it/s]


epoch[8] 0.07365136612029302 0.17185318761401705


train: 100%|██████████| 66/66 [00:09<00:00,  7.09it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  6.65it/s]


epoch[9] 0.07365136612029302 0.17185318761401705
batch 1


train: 100%|██████████| 66/66 [00:10<00:00,  6.24it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.09it/s]


epoch[0] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  7.05it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  6.43it/s]


epoch[1] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  6.96it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.28it/s]


epoch[2] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  7.04it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.27it/s]


epoch[3] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  6.69it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.22it/s]


epoch[4] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  7.09it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.05it/s]


epoch[5] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  6.75it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.15it/s]


epoch[6] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  7.00it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.39it/s]


epoch[7] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  6.80it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  6.80it/s]


epoch[8] 0.07298439525422595 0.17029692225986057


train: 100%|██████████| 66/66 [00:09<00:00,  7.11it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.15it/s]


epoch[9] 0.07298439525422595 0.17029692225986057
batch 2


train: 100%|██████████| 66/66 [00:10<00:00,  6.03it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.15it/s]


epoch[0] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  7.11it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.18it/s]


epoch[1] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  7.05it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  6.70it/s]


epoch[2] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  7.04it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.43it/s]


epoch[3] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  7.01it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  7.20it/s]


epoch[4] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  6.80it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.25it/s]


epoch[5] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  7.05it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.39it/s]


epoch[6] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  6.73it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.36it/s]


epoch[7] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  7.01it/s]
val: 100%|██████████| 29/29 [00:03<00:00,  7.44it/s]


epoch[8] 0.0736443441254752 0.17183680295944215


train: 100%|██████████| 66/66 [00:09<00:00,  7.06it/s]
val: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]

epoch[9] 0.0736443441254752 0.17183680295944215





## Infer

In [18]:
def infer(model, loader, device="cpu"):
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            pred_angles = model(batch)
            predictions.append(pred_angles.cpu())

    return torch.cat(predictions, 0)

In [19]:
def make_predictions(model, device="cpu", mode="test", batch_size=32):
    sensors = prepare_sensors()

    meta = pd.read_parquet(
        INPUT_PATH / f"{mode}_meta.parquet", columns=["batch_id", "event_id"]
    ).astype(_dtype)
    batch_ids = meta["batch_id"].unique()

    if mode == "train":
        batch_ids = batch_ids[:6]

    batch_preds = []
    for b in batch_ids:
        event_ids = meta[meta["batch_id"] == b]["event_id"].tolist()
        dataset = IceCubeDataset(
            b, event_ids, sensors, mode=mode,
        )
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=1)
        batch_preds.append(infer(model, loader, device=device))
        print("Finished batch", b)

        if mode == "train" and b == 6:
            break

    output = torch.cat(batch_preds, 0)

    event_id_labels = []
    for b in batch_ids:
        event_id_labels.extend(meta[meta["batch_id"] == b]["event_id"].tolist())

    sub = {
        "event_id": event_id_labels,
        "azimuth": output[:, 0],
        "zenith": output[:, 1],
    }

    sub = pd.DataFrame(sub)
    sub.to_csv("submission.csv", index=False)

In [20]:
make_predictions(model, device="cuda", mode="test", batch_size=32)

Finished batch 661


In [21]:
pd.read_csv("submission.csv")

Unnamed: 0,event_id,azimuth,zenith
0,2092,0.0,0.0
1,7344,0.0,0.0
2,9482,0.0,0.0
