In [None]:
import numpy as np
import pandas as pd
import polars as pl
import os
import time
from tqdm.auto import tqdm
import numba as nb

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset, DataLoader
import math

import json

from sklearn.ensemble import HistGradientBoostingClassifier
import joblib

In [None]:
def angular_dist_score(az_true, zen_true, az_pred, zen_pred, mean=True):
    """Angular distance, in radians, between true and predicted directions.

    Reference: https://www.kaggle.com/code/sohier/mean-angular-error

    Args:
        az_true, zen_true: true azimuth / zenith angles (radians), array-like.
        az_pred, zen_pred: predicted azimuth / zenith angles (radians), array-like.
        mean: when True return the average over all events, otherwise the
            per-event angular distances.

    Raises:
        ValueError: if any of the four inputs contains a non-finite value.
    """
    for angles in (az_true, zen_true, az_pred, zen_pred):
        if not np.all(np.isfinite(angles)):
            raise ValueError("All arguments must be finite")

    # Trigonometric terms of the true (t) and predicted (p) directions.
    sin_az_t, cos_az_t = np.sin(az_true), np.cos(az_true)
    sin_zn_t, cos_zn_t = np.sin(zen_true), np.cos(zen_true)
    sin_az_p, cos_az_p = np.sin(az_pred), np.cos(az_pred)
    sin_zn_p, cos_zn_p = np.sin(zen_pred), np.cos(zen_pred)

    # Dot product of the two unit vectors expressed in spherical coordinates.
    dot = sin_zn_t*sin_zn_p*(cos_az_t*cos_az_p + sin_az_t*sin_az_p) + (cos_zn_t*cos_zn_p)

    # Rounding can push the product marginally outside [-1, 1], where arccos is undefined.
    dot = np.clip(dot, -1, 1)

    angle = np.abs(np.arccos(dot))
    if mean:
        return np.average(angle)
    return angle

In [None]:
# Competition data root; the batch parquet folders and test_meta.parquet are read from here.
DATA_DIR = "/kaggle/input/icecube-neutrinos-in-deep-ice/"
# Preprocessed artefacts: the sensor-geometry CSV and the angle-bin JSON are loaded from here.
PREP_DIR = "/kaggle/input/icecube-preprocessed-data/"
# Per-batch training metadata files; formatted with the integer batch id.
TRAIN_META_FORMAT = "/kaggle/input/train-meta-parquet/train_meta_{:d}.parquet"
# Directory holding the model checkpoints resolved by load_model / load_stack_model.
MODEL_DIR = "/kaggle/input/icecube-models-final/"

In [None]:
# False: run inference on the test batches. True: score on labelled training batches.
VALIDATE = False 

if not VALIDATE:
    # Single-argument join: this relies on DATA_DIR ending with a path separator.
    PARQUETS_DIR = os.path.join(DATA_DIR + 'test')
    # Every file found in the test folder is treated as a batch, in sorted order.
    BATCH_LIST = list(sorted(os.listdir(PARQUETS_DIR)))
    metadata = pl.read_parquet(f'{DATA_DIR}/test_meta.parquet')
    CHECK_PREDICTION = False
else:
    PARQUETS_DIR = os.path.join(DATA_DIR + 'train')
    # Training batch ids used as the validation set.
    vbatches = [655]
    BATCH_LIST = [f'batch_{vb}.parquet' for vb in vbatches]
    META_FILES = [f'train_meta_{vb}.parquet' for vb in vbatches]
    def read_metadata():
        """Load the per-batch training metadata for `vbatches` and concatenate it."""
        meta = []
        # `mf` is not used in the loop body; the path is rebuilt from TRAIN_META_FORMAT and `vb`.
        for mf, vb in zip(META_FILES, vbatches):
            bmeta = pl.read_parquet(TRAIN_META_FORMAT.format(vb))

            # Add (or overwrite) a batch_id column so later code can filter on it.
            bmeta = bmeta.with_columns(pl.lit(vb).alias('batch_id'))
            meta.append(bmeta)
        return pl.concat(meta)
    metadata = read_metadata()
    CHECK_PREDICTION = True
    
# Lazily scanned sensor table; joined onto each batch by 'sensor_id' in preprocess_data.
GEOMETRY = os.path.join(PREP_DIR, "sensor_geometry_with_transparency.csv")
geometry = pl.scan_csv(GEOMETRY).with_columns(
                [pl.col('sensor_id').cast(pl.Int16)]
            )
    
# Number of angle bins; also selects which angle_bins_<N>.json file is loaded below.
NUM_BINS = 128
# Per-pulse columns selected (in this order) as model input features.
FEATURE_NAMES = ['time', 'charge', 'auxiliary', 'x', 'y', 'z', 'qe', 'scatter', 'absorp']
CHARGE_IDX = FEATURE_NAMES.index('charge')
TIME_IDX = FEATURE_NAMES.index('time')
AUX_IDX = FEATURE_NAMES.index('auxiliary')
N_FEATURES = len(FEATURE_NAMES)
# Sequence-length limits and loader batch sizes for the short and long sequence paths.
MAX_SEQUENCE_LENGTH = 256
LONG_SEQ_MAX_LENGTH = 3072
BATCH_SIZE = 500
LONG_SEQ_BATCH_SIZE = 10
# Not referenced in the cells visible here; presumably the number of extra resampled
# passes over ultra-long events — TODO confirm against the prediction loop.
ULTTRA_LONG_RESAMPLE = 0

# Upper bound on events per batch (preprocess_data slices the pulse-index table with it);
# effectively unlimited outside validation.
MAX_EVENTS = 200_000 if VALIDATE else 1000000000000

device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open(os.path.join(PREP_DIR, f'angle_bins_{NUM_BINS}.json')) as fp:
    bin_data = json.load(fp)
    
azimuth_bin_centers = torch.tensor(bin_data['azimuth_bin_centers']).type(torch.float32).to(device)

zenith_centers = np.array(bin_data['zenith_bin_centers'])
kernel_length = 15
# (kernel_length - 1) // 2 = 7 extra bins are added on each side of the zenith range.
num_zenith_padding = (kernel_length - 1)//2
padded_zenith_bins = np.concatenate([
        # First `num_zenith_padding` centers, reversed and negated (mirrored about 0).
        -zenith_centers[num_zenith_padding-1 : : -1],
        zenith_centers,
        # Last `num_zenith_padding` centers, reversed and mapped to 2*pi - z (mirrored about pi).
        2 * np.pi - zenith_centers[-1 : -num_zenith_padding-1 : -1],
])
zenith_bin_centers = torch.tensor(padded_zenith_bins).type(torch.float32).to(device)
# Number of zenith classes including the padding bins on both sides.
ZENITH_NUM_BINS = len(zenith_bin_centers)

# Model configs

# 15 layer
class m3_15l_vmf:
    """StackModel (xyz regression head) config referenced by `m3_15l.stack_model_config`."""
    # Checkpoint file name, resolved against MODEL_DIR by load_stack_model.
    name = 'v43vmf_ep349.ckpt'
    # Input width of StackNeck. predict_model feeds it cat([pool, azimuth logits, zenith logits]);
    # presumably 512 pooled + 128 azimuth + (128 + 14 padding) zenith — TODO confirm bin counts.
    n_embd = 512 + 128 + 128 + 14
    # Hidden and output widths of the two Linear layers in StackNeck.
    middle_features = 8192
    neck_features = 2048
    bias = True
    dropout = 0.0
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'
    # True selects BatchNorm1d in StackNeck; False selects LayerNorm.
    batchnorm = True

class m3_15l_seq3072:
    """Config referenced by `m3_15l.long_seq_model_config` (15 attention blocks).

    Its use is outside the cells visible here; presumably the variant applied to
    events longer than MAX_SEQUENCE_LENGTH — TODO confirm.
    """
    # Checkpoint file name inside MODEL_DIR.
    name = 'v99_v13_l3072_ep2.ckpt'
    # One entry per attention block: embedding width and head count.
    n_embd = [512]*15
    n_heads = [8]*15
    bias = False
    dropout = 0.0
    neck_dropout = 0.0
    # Output width of each Neck (input width of the azimuth / zenith classifiers).
    neck_features = 3072
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'

class m3_15l:
    """Base MultiLabelClassifier config with 15 attention blocks, plus its companion models."""
    # Checkpoint file name, resolved against MODEL_DIR by load_model.
    name = 'v13r2_m3_s256_ep48.ckpt'
    # One entry per attention block: embedding width and head count.
    n_embd = [512]*15
    n_heads = [8]*15
    bias = False
    dropout = 0.0
    neck_dropout = 0.0
    # Output width of each Neck (input width of the azimuth / zenith classifiers).
    neck_features = 3072
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'
    # Companion configs: long-sequence classifier and the stack (xyz) model.
    long_seq_model_config = m3_15l_seq3072()
    stack_model_config = m3_15l_vmf()
    # Pickled gradient-boosting models; they are loaded outside the cells visible here.
    mixer_model = 'gboost_mixer_m315l_v13v43_sk1.0.2.pkl'
    selector_model = 'gboost_selector_m315l_v13v43_sk1.0.2.pkl'


class m3_18l_vmf:
    """StackModel (xyz regression head) config referenced by `m3_18l.stack_model_config`."""
    # Checkpoint file name, resolved against MODEL_DIR by load_stack_model.
    name = 'v49vmf_m3_18l_ep349.ckpt'
    # Input width of StackNeck. predict_model feeds it cat([pool, azimuth logits, zenith logits]);
    # presumably 512 pooled + 128 azimuth + (128 + 14 padding) zenith — TODO confirm bin counts.
    n_embd = 512 + 128 + 128 + 14
    # Hidden and output widths of the two Linear layers in StackNeck.
    middle_features = 8192
    neck_features = 2048
    bias = True
    dropout = 0.0
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'
    # False selects LayerNorm (instead of BatchNorm1d) in StackNeck.
    batchnorm = False

class m3_18l_seq3072:
    """Config referenced by `m3_18l.long_seq_model_config` (18 attention blocks).

    Its use is outside the cells visible here; presumably the variant applied to
    events longer than MAX_SEQUENCE_LENGTH — TODO confirm.
    """
    # Checkpoint file name inside MODEL_DIR.
    name = 'v98_v97ep187_l3072_ep6.ckpt'
    # One entry per attention block: embedding width and head count.
    n_embd = [512]*18
    n_heads = [8]*18
    bias = False
    dropout = 0.0
    neck_dropout = 0.0
    # Output width of each Neck (input width of the azimuth / zenith classifiers).
    neck_features = 3072
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'

class m3_18l:
    """Base MultiLabelClassifier config with 18 attention blocks, plus its companion models."""
    # Checkpoint file name, resolved against MODEL_DIR by load_model.
    name = 'v97r5_m3_s256_ep187.ckpt'
    # One entry per attention block: embedding width and head count.
    n_embd = [512]*18
    n_heads = [8]*18
    bias = False
    dropout = 0.0
    neck_dropout = 0.0
    # Output width of each Neck (input width of the azimuth / zenith classifiers).
    neck_features = 3072
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'
    # Companion configs: long-sequence classifier and the stack (xyz) model.
    long_seq_model_config = m3_18l_seq3072()
    stack_model_config = m3_18l_vmf()
    # Pickled gradient-boosting models; they are loaded outside the cells visible here.
    mixer_model = 'gboost_mixer_m318l_n2_v49_sk1.0.2.pkl'
    selector_model = 'gboost_selector_m318l_n2_v49_sk1.0.2.pkl'


class m3_r6_vmf:
    """StackModel (xyz regression head) config referenced by `m3_r6.stack_model_config`."""
    # Checkpoint file name, resolved against MODEL_DIR by load_stack_model.
    name = 'v3_vmfzn_m3r6_ep699.ckpt'
    # Input width of StackNeck. predict_model feeds it cat([pool, azimuth logits, zenith logits]);
    # presumably 512 pooled + 128 azimuth + (128 + 14 padding) zenith — TODO confirm bin counts.
    n_embd = 512 + 128 + 128 + 14
    # Hidden and output widths of the two Linear layers in StackNeck.
    middle_features = 8192
    neck_features = 2048
    bias = True
    dropout = 0.0
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'
    # False selects LayerNorm (instead of BatchNorm1d) in StackNeck.
    batchnorm = False

class m3_r6_seq3072:
    """Config referenced by `m3_r6.long_seq_model_config` (18 attention blocks).

    Its use is outside the cells visible here; presumably the variant applied to
    events longer than MAX_SEQUENCE_LENGTH — TODO confirm.
    """
    # Checkpoint file name inside MODEL_DIR.
    name = 'v103_v99r6_l3072_ep7.ckpt'
    # One entry per attention block: embedding width and head count.
    n_embd = [512]*18
    n_heads = [8]*18
    bias = False
    dropout = 0.0
    neck_dropout = 0.0
    # Output width of each Neck (input width of the azimuth / zenith classifiers).
    neck_features = 3072
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'

class m3_r6:
    """Base MultiLabelClassifier config ("r6" checkpoint, 18 attention blocks) plus companions."""
    # Checkpoint file name, resolved against MODEL_DIR by load_model.
    name = 'v99r6_m3_s256_ep26.ckpt'
    # One entry per attention block: embedding width and head count.
    n_embd = [512]*18
    n_heads = [8]*18
    bias = False
    dropout = 0.0
    neck_dropout = 0.0
    # Output width of each Neck (input width of the azimuth / zenith classifiers).
    neck_features = 3072
    # Substring used to strip the wrapper prefix from checkpoint state-dict keys.
    unwanted_prefix = 'model'
    # Companion configs: long-sequence classifier and the stack (xyz) model.
    long_seq_model_config = m3_r6_seq3072()
    stack_model_config = m3_r6_vmf()
    # Pickled gradient-boosting models; they are loaded outside the cells visible here.
    mixer_model = 'gboost_mixer_v99r6_sk1.0.2.pkl'
    selector_model = 'gboost_selector_v99r6_sk1.0.2.pkl'

# Base model configs, iterated in this order by the prediction loop.
MODEL_CONFIGS = [m3_18l(), m3_15l(), m3_r6()]
# Pickled merger models; the file names suggest one per consecutive pair of configs
# (18l+15l, then 15l+r6) — TODO confirm where they are loaded.
MERGER_MODELS = ['gboost_merger_18ln2_15l.pkl', 'gboost_merger_15l_r6.pkl']

In [None]:
def set_seed(value):
    """Seed NumPy's global RNG, which drives the pulse subsampling in sample_and_pad."""
    np.random.seed(seed=value)

def sample_and_pad(data, pulse_indexes, padded_sequence_length):
    """Normalise pulse features and pack each event into a fixed-length, zero-padded row.

    `data` is modified in place: the charge column becomes log10(charge) / 3, the
    auxiliary column is shifted by -0.5, and the time column of every event that
    is not subsampled is rewritten relative to the event's first pulse.

    Events longer than `padded_sequence_length` are randomly subsampled with
    NumPy's global RNG. Pulses whose shifted auxiliary value is -0.5 are kept
    first and the remainder is drawn from pulses whose value is 0.5. The chosen
    pulses keep their original order.

    Args:
        data: 2-D array of pulses, one row per pulse, columns as in FEATURE_NAMES.
        pulse_indexes: (n_events, 2) array of inclusive first/last pulse rows.
        padded_sequence_length: length of the padded time axis.

    Returns:
        (data_x, sequence_lengths): float32 array of shape
        (n_events, padded_sequence_length, n_features) and an int32 array with
        the number of real pulses per event.
    """
    data[:, CHARGE_IDX] = np.log10(data[:, CHARGE_IDX]) / 3.0
    data[:, AUX_IDX] = data[:, AUX_IDX] - 0.5

    n_events = len(pulse_indexes)
    n_features = data.shape[-1]
    padded = np.zeros((n_events, padded_sequence_length, n_features), dtype=np.float32)
    lengths = np.zeros(n_events, dtype=np.int32)

    for row in range(n_events):
        first, last = pulse_indexes[row, 0], pulse_indexes[row, 1]
        pulses = data[first : last + 1]
        n_pulses = len(pulses)

        if n_pulses > padded_sequence_length:
            aux_col = pulses[:, AUX_IDX]
            non_aux = np.where(aux_col == -0.5)[0]
            aux = np.where(aux_col == 0.5)[0]
            if len(non_aux) >= padded_sequence_length:
                # Enough non-auxiliary pulses: sample only among them.
                chosen = np.random.choice(non_aux, size=padded_sequence_length, replace=False)
            else:
                # Keep every non-auxiliary pulse and fill up with auxiliary ones.
                n_fill = min(padded_sequence_length, n_pulses) - len(non_aux)
                fill = np.random.choice(aux, size=n_fill, replace=False)
                chosen = np.concatenate((non_aux, fill))
            # Fancy indexing copies, so the time rewrite below no longer touches `data`.
            pulses = pulses[np.sort(chosen)]

        times = pulses[:, TIME_IDX]
        pulses[:, TIME_IDX] = ( times - times.min() ) / 3e4
        assert np.all(np.isfinite(pulses))

        n_kept = len(pulses)
        padded[row, :n_kept, :] = pulses
        lengths[row] = n_kept
    return padded, lengths

In [None]:
def preprocess_data(bfile, seed, 
                    min_sequence_length, max_sequence_length, 
                    padded_sequence_length, print_time=False):
    """Load one batch file and turn its events into padded model inputs.

    Args:
        bfile: batch file name; the batch id is the integer after the last '_'.
        seed: seed for NumPy's global RNG (affects subsampling of long events).
        min_sequence_length: events must have strictly more pulses than this.
        max_sequence_length: events must have at most this many pulses.
        padded_sequence_length: length of the padded time axis.
        print_time: when True, print progress and timing messages.

    Returns:
        (data_x, seq_lens, event_to_keep): padded features, real lengths, and the
        positions (within the batch metadata, after the MAX_EVENTS cut) of the
        events that passed the length filter.
    """
    def log(*args):
        if print_time:
            print(*args)

    set_seed(seed)
    log("Reading", bfile)
    t_start = time.perf_counter()

    batch_id = int(bfile.split('.')[0].split('_')[-1])
    pulses = (
        pl.scan_parquet(f'{PARQUETS_DIR}/{bfile}')
        .join(geometry, on='sensor_id', how='left')
        .select(FEATURE_NAMES)
        .collect()
        .to_numpy()
    )

    spans = (
        metadata.filter(pl.col('batch_id') == batch_id)
        .select(['first_pulse_index', 'last_pulse_index'])
        .to_numpy()[:MAX_EVENTS]
    )
    n_pulses = spans[:, 1] - spans[:, 0] + 1
    in_range = (n_pulses > min_sequence_length) & (n_pulses <= max_sequence_length)
    event_to_keep = np.flatnonzero(in_range)
    spans = spans[event_to_keep]
    log("Read and merge", bfile, "in", time.perf_counter() - t_start, "s")

    t_start = time.perf_counter()
    data_x, seq_lens = sample_and_pad(pulses, spans, padded_sequence_length)
    log("Processed", bfile, "in", time.perf_counter() - t_start, "s")

    return data_x, seq_lens, event_to_keep

In [None]:
# Smoke test of the loading path; only defined and run in validation mode.
if VALIDATE:
    def check_data_loading():
        """Preprocess the first batch for the short and long length ranges and print the shapes."""
        # Short events: 1..MAX_SEQUENCE_LENGTH pulses, padded to MAX_SEQUENCE_LENGTH.
        dx, dl, _ = preprocess_data(BATCH_LIST[0], seed=42, 
                                    min_sequence_length=0,
                                    max_sequence_length=MAX_SEQUENCE_LENGTH,
                                    padded_sequence_length=MAX_SEQUENCE_LENGTH,
                                    print_time=True)
        print(dx.shape)
        # All longer events, subsampled / padded to LONG_SEQ_MAX_LENGTH.
        dx, dl, _ = preprocess_data(BATCH_LIST[0], seed=42, 
                                    min_sequence_length=MAX_SEQUENCE_LENGTH,
                                    max_sequence_length=np.inf,
                                    padded_sequence_length=LONG_SEQ_MAX_LENGTH,
                                    print_time=True)
        print(dx.shape)

    check_data_loading()

In [None]:
class IceCubeDataset(Dataset):
    """Events of one batch file within a pulse-count range, served in length-sorted order.

    Attributes:
        x: padded feature tensor, one row per kept event.
        l: number of real pulses per event (stored as a float tensor).
        sort_idx: permutation that orders the events by increasing length.
        reverse_sort_idx: inverse permutation, restoring the original order.
        events: positions of the kept events within the batch metadata.
    """

    def __init__(self, bfile, seed, min_sequence_lenth, max_sequence_length, padded_sequence_length):
        super().__init__()
        padded, lengths, kept_events = preprocess_data(
            bfile, seed, min_sequence_lenth, max_sequence_length, padded_sequence_length
        )
        self.events = kept_events

        self.x = torch.Tensor(padded)
        self.l = torch.Tensor(lengths)

        # Item i is the i-th shortest event, so neighbouring items have similar lengths.
        self.sort_idx = np.argsort(lengths)
        self.reverse_sort_idx = np.argsort(self.sort_idx)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, index):
        source = self.sort_idx[index]
        return self.x[source], self.l[source]

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and optional bias."""

    def __init__(self, ndim, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            # No shift term: the attribute exists but holds no parameter.
            self.bias = None

    def forward(self, input):
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )

def mlp(n_embd, bias=False, dropout=0.0, out_embd=None):
    """Build a Linear -> GELU(tanh) -> Linear -> Dropout feed-forward block.

    The hidden layer is four times as wide as the input. `out_embd` defaults to
    `n_embd`, so the block preserves the feature width unless told otherwise.
    """
    if out_embd is None:
        out_embd = n_embd
    hidden = 4 * n_embd
    layers = [
        nn.Linear(n_embd, hidden, bias=bias),
        nn.GELU(approximate='tanh'),
        nn.Linear(hidden, out_embd, bias=bias),
        nn.Dropout(dropout),
    ]
    return nn.Sequential(*layers)


class SelfAttention(nn.Module):
    """Multi-head self-attention without causal masking.

    Args:
        prev_emdb: feature width of the input tokens.
        n_embd: width of the attention output (split evenly across heads).
        n_heads: number of attention heads.
        bias: whether the projection layers carry a bias term.
        dropout: dropout probability, applied to the attention weights and to
            the projected output while the module is in training mode.
    """

    def __init__(self, prev_emdb, n_embd, n_heads, bias=False, dropout=0.0):
        super().__init__()
        self.prev_embd = prev_emdb
        self.n_embd = n_embd
        self.n_heads = n_heads
        
        # One projection produces queries, keys and values side by side.
        self.c_attn = nn.Linear(prev_emdb, 3 * n_embd, bias=bias)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask):
        """Attend over `x` of shape (B, T, prev_embd); returns a (B, T, n_embd) tensor.

        `attn_mask` is forwarded unchanged to the attention kernel.
        """
        B, T, _ = x.shape
        C = self.n_embd
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_heads, C // self.n_heads).transpose(1, 2) # (B, nh, T, hs)

        # The functional attention kernel applies dropout whenever dropout_p > 0,
        # regardless of module mode, so it has to be switched off explicitly in eval().
        dropout_p = self.dropout if self.training else 0.0
        if hasattr(F, 'scaled_dot_product_attention'):
            y = F.scaled_dot_product_attention(q, k, v, 
                    attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False)
        else:
            # Older PyTorch: private variant that returns (output, attention weights).
            y = F._scaled_dot_product_attention(q, k, v, 
                    attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False)[0]
        # Merge the heads back into a single feature axis.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        y = self.resid_dropout(y)
        return y
    
class AttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each wrapped in a residual add."""

    def __init__(self, prev_embd, n_embd, n_heads, bias=False, dropout=0.0):
        super().__init__()
        assert n_embd % prev_embd == 0, f"{prev_embd} {n_embd} should be divisble"
        self.ln_1 = LayerNorm(prev_embd, bias)
        self.attn = SelfAttention(prev_embd, n_embd, n_heads, bias, dropout)
        self.ln_2 = LayerNorm(n_embd, bias)
        self.mlp = mlp(n_embd, bias, dropout)

    def forward(self, x, attn_mask):
        attended = self.attn(self.ln_1(x), attn_mask)
        x = x + attended
        return x + self.mlp(self.ln_2(x))

class AttentionEncoder(nn.Module):
    """Stack of AttentionBlocks described by `config.n_embd` / `config.n_heads`.

    Each block takes the previous block's width as its input width; the first
    block uses `config.n_embd[0]` for both input and output.
    """

    def __init__(self, config):
        super().__init__()
        out_dims = list(config.n_embd)
        in_dims = [config.n_embd[0]] + out_dims[:-1]
        blocks = [
            AttentionBlock(d_in, d_out, heads, config.bias, config.dropout)
            for d_in, d_out, heads in zip(in_dims, out_dims, config.n_heads)
        ]
        self.attn = nn.ModuleList(blocks)

    def forward(self, x, attn_mask):
        hidden = x
        for block in self.attn:
            hidden = block(hidden, attn_mask)
        return hidden

    
class SequencePool(nn.Module):
    """Masked mean pooling over the sequence axis."""

    def __init__(self):
        super().__init__()
        # Number of pooled vectors concatenated per event (mean pooling only).
        self.num_pools = 1

    def forward(self, x, sequence_lengths, padding_mask):
        # Zero out padded positions so they do not contribute to the sum.
        masked = x * padding_mask.unsqueeze(2)
        totals = masked.sum(dim=1)
        # Divide by the real sequence length, not by the padded length.
        return totals / sequence_lengths.view(-1, 1)
    
class Neck(nn.Module):
    """Widening residual MLP placed between the pooled features and a classifier.

    The input is tiled `out_features // in_features` times along the feature
    axis and added to the MLP output.
    """

    def __init__(self, in_features, out_features, bias, dropout):
        super().__init__()
        hidden = 4 * in_features
        self.n_repeats = out_features // in_features
        self.mlp = nn.Sequential(
            LayerNorm(in_features, bias=bias),
            nn.Linear(in_features, hidden, bias=bias),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden, out_features, bias=bias),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        skip = x.repeat(1, self.n_repeats)
        return skip + self.mlp(x)

    
class MultiLabelClassifier(nn.Module):
    """Attention encoder with separate azimuth and zenith classification heads.

    Args:
        n_features: number of per-pulse input features.
        max_block_size: accepted for interface compatibility; not used here.
        num_classes: number of azimuth classes.
        zenith_num_classes: number of zenith classes.
        config: object providing n_embd, n_heads, bias, dropout, neck_dropout
            and neck_features.
    """

    def __init__(self, n_features, max_block_size, num_classes, zenith_num_classes, config):
        super().__init__()

        embed_dim = config.n_embd[0]
        self.inp = nn.Linear(n_features, embed_dim)
        self.drop_inputs = nn.Dropout(config.dropout)

        self.encoder = AttentionEncoder(config)
        self.pool = SequencePool()

        pooled_dim = config.n_embd[-1] * self.pool.num_pools
        neck_dim = config.neck_features

        self.neck_az = Neck(pooled_dim, neck_dim, config.bias, config.neck_dropout)
        self.neck_zn = Neck(pooled_dim, neck_dim, config.bias, config.neck_dropout)

        self.azimuth = nn.Linear(neck_dim, num_classes)
        self.zenith = nn.Linear(neck_dim, zenith_num_classes)

    def get_masks(self, x, l):
        """Return (key_padding_mask, attn_mask) for padded input `x` with lengths `l`.

        key_padding_mask is True at real positions. attn_mask has shape
        (B, 1, T, T) and is True where query and key are of the same kind
        (both real or both padding).
        """
        positions = torch.arange(x.shape[1]).view(1, -1).to(l.device)
        key_padding_mask = positions < l.view(-1, 1)
        same_kind = key_padding_mask.unsqueeze(1) == key_padding_mask.unsqueeze(2)
        return key_padding_mask, same_kind.unsqueeze(1)

    def forward(self, x):
        """Take `(inputs, seq_lengths)`; return (azimuth logits, zenith logits, pooled features)."""
        inputs, seq_lengths = x
        key_padding_mask, attn_mask = self.get_masks(inputs, seq_lengths)

        hidden = self.drop_inputs(self.inp(inputs))
        hidden = self.encoder(hidden, attn_mask)
        pool = self.pool(hidden, seq_lengths, key_padding_mask)

        az_out = self.azimuth(self.neck_az(pool))
        zn_out = self.zenith(self.neck_zn(pool))
        return az_out, zn_out, pool

In [None]:
class StackNeck(nn.Module):
    """Two Linear -> Norm -> GELU(tanh) -> Dropout stages.

    `batchnorm=True` uses BatchNorm1d as the normalisation layer, otherwise
    LayerNorm. The layer order inside `self.mlp` is the same in both cases.
    """

    def __init__(self, in_features, middle_features, out_features, bias, dropout, batchnorm):
        super().__init__()
        norm_cls = nn.BatchNorm1d if batchnorm else nn.LayerNorm
        layers = []
        stage_dims = ((in_features, middle_features), (middle_features, out_features))
        for d_in, d_out in stage_dims:
            layers.append(nn.Linear(d_in, d_out, bias=bias))
            layers.append(norm_cls(d_out))
            layers.append(nn.GELU(approximate='tanh'))
            layers.append(nn.Dropout(dropout))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
    
class StackModel(nn.Module):
    """StackNeck followed by a linear layer that outputs a 3-component (x, y, z) vector.

    `num_classes` and `zenith_num_classes` are accepted but not used.
    """

    def __init__(self, num_classes, zenith_num_classes, config):
        super().__init__()

        self.neck = StackNeck(
            in_features=config.n_embd,
            middle_features=config.middle_features,
            out_features=config.neck_features,
            bias=config.bias,
            dropout=config.dropout,
            batchnorm=config.batchnorm,
        )
        self.xyz = nn.Linear(config.neck_features, 3)

    def forward(self, x):
        features = self.neck(x)
        return self.xyz(features)

In [None]:
def prepare_submission(event_ids, azimuth, zenith, validate_ids=False):
    """Assemble the submission table indexed by event_id.

    Args:
        event_ids: event identifiers, one per prediction.
        azimuth: tensor of predicted azimuth angles.
        zenith: tensor of predicted zenith angles; clipped to [0, pi].
        validate_ids: when True, assert that `event_ids` equals the event ids
            of the sample submission file.

    Returns:
        pandas DataFrame with columns 'azimuth' and 'zenith', indexed by 'event_id'.
    """
    if validate_ids:
        sample_path = os.path.join(DATA_DIR, 'sample_submission.parquet')
        expected_ids = pd.read_parquet(sample_path).event_id.values
        assert np.array_equal(event_ids, expected_ids)

    columns = {
        'event_id': event_ids,
        'azimuth': azimuth.cpu().numpy(),
        'zenith': torch.clip(zenith, 0.0, np.pi).cpu().numpy(),
    }
    return pd.DataFrame(columns).set_index('event_id')

In [None]:
def _load_checkpoint_state_dict(model_config):
    """Read `model_config.name` from MODEL_DIR and return its cleaned state dict.

    Every key containing `model_config.unwanted_prefix` is replaced by the part
    that follows the first occurrence of that prefix and its separator
    character, e.g. 'model.inp.weight' -> 'inp.weight'.
    """
    checkpoint_path = os.path.join(MODEL_DIR, model_config.name)
    # map_location lets a checkpoint saved on GPU load when `device` is 'cpu';
    # without it torch.load tries to restore tensors onto their original device.
    state_dict = torch.load(checkpoint_path, map_location=device)['state_dict']
    # Snapshot the keys first: the dict is modified while renaming.
    for key in list(state_dict.keys()):
        if model_config.unwanted_prefix in key:
            new_key = key.split(model_config.unwanted_prefix)[1][1:]
            state_dict[new_key] = state_dict.pop(key)
    return state_dict


def load_model(model_config):
    """Build a MultiLabelClassifier for `model_config`, load its weights, return it in eval mode on `device`."""
    model = MultiLabelClassifier(n_features=N_FEATURES, 
                                max_block_size=MAX_SEQUENCE_LENGTH,
                                num_classes=NUM_BINS,
                                zenith_num_classes=ZENITH_NUM_BINS,
                                config=model_config)
    model.load_state_dict(_load_checkpoint_state_dict(model_config))
    model.eval()
    model.to(device)
    return model

def load_stack_model(model_config):
    """Build a StackModel for `model_config`, load its weights, return it in eval mode on `device`."""
    stack_model = StackModel(num_classes=NUM_BINS, zenith_num_classes=ZENITH_NUM_BINS, config=model_config)
    stack_model.load_state_dict(_load_checkpoint_state_dict(model_config))
    stack_model.eval()
    stack_model.to(device)
    return stack_model

In [None]:



def simple_average(pred_mbs):
    """Element-wise mean of a list of equally shaped 2-D prediction tensors."""
    stacked = torch.stack(pred_mbs, dim=2)
    return stacked.mean(dim=2)

def argmax_average(pred_mbs, centers):
    """Average, over predictions, of the bin center selected by each prediction's argmax."""
    picked = [centers[pred.argmax(axis=1)] for pred in pred_mbs]
    return torch.stack(picked).mean(dim=0)

def simpleavg_argmax(pred_mbs, centers):
    """Average the predictions element-wise, then return the center of the arg-max bin."""
    averaged = torch.mean(torch.stack(pred_mbs, dim=2), dim=2)
    best_bin = averaged.argmax(dim=1)
    return centers[best_bin]

class AzimuthXY:
    """Weights the unit-circle coordinates of the azimuth bin centers by bin probabilities."""

    def __init__(self):
        self.azx = torch.cos(azimuth_bin_centers)
        self.azy = torch.sin(azimuth_bin_centers)

    def __call__(self, az_softmax):
        # Keep the cached coordinates on the same device as the incoming probabilities.
        target = az_softmax.device
        self.azx = self.azx.to(target)
        self.azy = self.azy.to(target)
        weighted_x = az_softmax * self.azx
        weighted_y = az_softmax * self.azy
        return weighted_x, weighted_y

az_xy = AzimuthXY()


def azimuth_vectorsum(azx, azy):
    """Azimuth in [0, 2*pi) of the vector obtained by summing weighted x/y components.

    Args:
        azx, azy: (N, n_bins) tensors of probability-weighted x and y components.

    Returns:
        (N,) tensor of azimuth angles.

    atan2 is used instead of `arccos(x / norm) * sign(y)`, as avg_az already
    does. The arccos form returns 0 instead of pi when the summed y component
    is exactly zero and x is negative (sign(0) == 0), and NaN when the summed
    vector has zero length. atan2 gives pi and 0 in those cases.
    """
    azmx, azmy = torch.sum(azx, dim=1), torch.sum(azy, dim=1)
    # atan2 lies in (-pi, pi]; the modulo maps negative angles into [0, 2*pi).
    az_pred = torch.atan2(azmy, azmx) % (np.pi * 2)
    return az_pred

def az_simpleavg_vectorsum(pred_mbs):
    """Average the azimuth probability tensors, then reduce them to one angle per event."""
    mean_probs = simple_average(pred_mbs)
    weighted_x, weighted_y = az_xy(mean_probs)
    return azimuth_vectorsum(weighted_x, weighted_y)


def zn_argmax_average(pred_mbs):
    """Zenith angle as the mean of each prediction's arg-max (padded) zenith bin center."""
    return argmax_average(pred_mbs, centers=zenith_bin_centers)


def ensemble_predictions(az_pred_mbs, zn_pred_mbs):
    """Combine lists of azimuth / zenith probability tensors into (azimuth, zenith) angles.

    Runs without gradient tracking.
    """
    with torch.no_grad():
        azimuth = az_simpleavg_vectorsum(az_pred_mbs)
        zenith = zn_argmax_average(zn_pred_mbs)
        return azimuth, zenith

def topksum(pred, centers, k=3):
    """Softmax-weighted average of the centers of the `k` highest-scoring bins per row."""
    top_values, top_bins = pred.topk(k, dim=1)
    weights = top_values.softmax(dim=1)
    return (weights * centers[top_bins]).sum(dim=1)


def discrete_to_angle(az_softmax,
                      zn_pred,
                      azimuth_bin_centers,
                      zenith_bin_centers):
    """Turn per-bin azimuth / zenith predictions into one angle pair per event.

    Args:
        az_softmax: (N, n_az_bins) azimuth bin probabilities.
        zn_pred: (N, n_zn_bins) zenith bin scores.
        azimuth_bin_centers: (n_az_bins,) azimuth bin centers in radians.
        zenith_bin_centers: (n_zn_bins,) zenith bin centers in radians.

    Returns:
        (azimuth, zenith): azimuth is the direction in [0, 2*pi) of the
        probability-weighted sum of the bin-center unit vectors; zenith is the
        center of the highest-scoring bin.

    atan2 replaces `arccos(x / norm) * sign(y)`. The arccos form returns 0
    instead of pi when the y component is exactly zero with negative x, and NaN
    when the weighted vectors cancel out.
    """
    azx, azy = torch.cos(azimuth_bin_centers), torch.sin(azimuth_bin_centers)
    azmx, azmy = az_softmax @ azx, az_softmax @ azy
    # atan2 lies in (-pi, pi]; the modulo maps negative angles into [0, 2*pi).
    az_pred_center = torch.atan2(azmy, azmx) % (np.pi * 2)
    
    zn_pred_center = zenith_bin_centers[zn_pred.argmax(1)]
    
    return az_pred_center, zn_pred_center


def xyz_to_angle(xyz):
    """Convert (N, 3) direction vectors to (azimuth, zenith) angles in radians.

    Azimuth is atan2(y, x) mapped into [0, 2*pi); zenith is arccos(z / |xyz|).

    atan2 replaces `arccos(x / |xy|) * sign(y)`. The arccos form gives 0 instead
    of pi for vectors on the negative x axis (sign(0) == 0) and NaN for purely
    vertical vectors, where |xy| is zero.
    """
    z_normed = xyz[:, 2] / torch.norm(xyz ,dim=1)
    azimuth = torch.atan2(xyz[:, 1], xyz[:, 0]) % (np.pi * 2)
    # Guard against rounding pushing the ratio marginally outside arccos' domain.
    zenith = torch.arccos(torch.clamp(z_normed, -1.0, 1.0))
    return azimuth, zenith

def avg_az_np(az1, az2):
    """Circular mean of two azimuth arrays (NumPy), returned in [0, 2*pi)."""
    sum_y = np.sin(az1) + np.sin(az2)
    sum_x = np.cos(az1) + np.cos(az2)
    return np.arctan2(sum_y, sum_x) % (np.pi * 2)

def avg_az(az1, az2):
    """Circular mean of two azimuth tensors (PyTorch), returned in [0, 2*pi)."""
    sum_y = torch.sin(az1) + torch.sin(az2)
    sum_x = torch.cos(az1) + torch.cos(az2)
    return torch.atan2(sum_y, sum_x) % (np.pi * 2)

def az_diff(az1, az2):
    """Smallest absolute angular separation between two azimuth arrays, in [0, pi]."""
    full_turn = np.pi * 2
    res = ( az1 - az2 ) % full_turn
    # Differences beyond half a turn are shorter going the other way round.
    wraps = res > np.pi
    res[wraps] = full_turn - res[wraps]
    return res

In [None]:
def check_score(batch_id, az_pred, zn_pred, text, events_to_keep=None, convert_to_angle=True):
    """Print and return the mean angular error against ground truth; no-op unless VALIDATE.

    Args:
        batch_id: batch whose metadata rows hold the true angles.
        az_pred, zn_pred: bin probabilities (convert_to_angle=True) or angles.
        text: label appended to the printed message.
        events_to_keep: positions of the predicted events inside the batch
            metadata; defaults to the first len(az_pred) events.
        convert_to_angle: whether to run ensemble_predictions on the inputs first.

    Returns:
        The mean angular distance, or None when VALIDATE is False.
    """
    if not VALIDATE:
        return
    if events_to_keep is None:
        events_to_keep = np.arange(len(az_pred), dtype=np.int32)

    if convert_to_angle:
        az, zn = ensemble_predictions([az_pred], [zn_pred])
    else:
        az, zn = az_pred, zn_pred

    truth = metadata.filter(pl.col('batch_id') == batch_id)
    az_gt = truth.select('azimuth').to_numpy().squeeze()
    zen_gt = truth.select('zenith').to_numpy().squeeze()

    angular_dist = angular_dist_score(
        az_gt[events_to_keep],
        zen_gt[events_to_keep],
        az.cpu().numpy(),
        zn.cpu().numpy(),
    )
    print(f"ang_dist {text}", angular_dist, "\n")
    return angular_dist

In [None]:
def model_load_check():
    """Load every base model and its stack model once, to fail early on a bad checkpoint."""
    for cfg in MODEL_CONFIGS:
        base = load_model(cfg)
        stack = load_stack_model(cfg.stack_model_config)
        # Drop the references straight away; only the successful load matters.
        del base, stack
model_load_check()

In [None]:
# Columns used for the hand-crafted event features. 'string' is not part of
# FEATURE_NAMES; presumably it comes from the joined geometry table — TODO confirm.
FNAMES = ['time', 'charge', 'auxiliary', 'x', 'y', 'z', 'string']

# These rebind the CHARGE_IDX / TIME_IDX / AUX_IDX globals that sample_and_pad also reads.
# The values are unchanged because both column lists start with time, charge, auxiliary.
CHARGE_IDX = FNAMES.index('charge')
TIME_IDX = FNAMES.index('time')
AUX_IDX = FNAMES.index('auxiliary')
X_IDX = FNAMES.index('x')
Y_IDX = FNAMES.index('y')
Z_IDX = FNAMES.index('z')
STRING_IDX = FNAMES.index('string')

# Thresholds used by generate_features when following consecutive pulses:
# maximum time gap and maximum |z| gap (in the units of the input columns).
TDIFF = 1000
ZDIFF = 60/500

def generate_features(data, pulse_indexes):
    """Compute 17 summary features per event from its non-auxiliary pulses.

    Args:
        data: 2-D pulse array with columns ordered as FNAMES. Modified in place
            (see the review note in the loop).
        pulse_indexes: (n_events, 2) array of inclusive first/last pulse rows.

    Returns:
        (n_events, 17) float array. Rows of events without any pulse whose
        auxiliary value equals 0 are left as zeros. Column order: num_naux,
        pct_naux, total_charge, charge_std, time_min, time_mean, time_max,
        time_std, z_mean, z_min, z_max, z_std, time_ratio, tdiff, zdiff, dist,
        speed.
    """
    feats = np.zeros((len(pulse_indexes), 17))
    for ii in tqdm(range(len(pulse_indexes))):
        # Basic slicing: event_data is a view, so writes below reach `data`.
        event_data = data[pulse_indexes[ii, 0] : pulse_indexes[ii, 1] + 1]
        # NOTE(review): `event_data[TIME_IDX]` selects ROW TIME_IDX (the event's first
        # pulse, all columns), not the time column. This subtracts the minimum time from
        # every field of that pulse and leaves the time column unshifted. It looks like
        # `event_data[:, TIME_IDX]` was intended. Not changed here because the pickled
        # boosting models consume these features — confirm how they were trained first.
        event_data[TIME_IDX] -= event_data[:, TIME_IDX].min()
        naux_event_data = event_data[event_data[:, AUX_IDX] == 0]

        event_length = len(event_data)
        num_naux = len(naux_event_data)
        if num_naux == 0:
            # No non-auxiliary pulses: keep the all-zero feature row.
            continue
        pct_naux = num_naux/event_length

        total_charge = naux_event_data[:, CHARGE_IDX].sum()
        charge_std = naux_event_data[:, CHARGE_IDX].std()

        time_min = naux_event_data[:, TIME_IDX].min()
        time_mean = naux_event_data[:, TIME_IDX].mean()
        time_max = naux_event_data[:, TIME_IDX].max()
        time_std = naux_event_data[:, TIME_IDX].std()
        # Not guarded against time_max == 0.
        time_ratio = time_min / time_max

        z_mean = naux_event_data[:, Z_IDX].mean()
        z_min = naux_event_data[:, Z_IDX].min()
        z_max = naux_event_data[:, Z_IDX].max()
        z_std = naux_event_data[:, Z_IDX].std()

        # Walk the non-auxiliary pulses in order while each one stays close to the
        # previous one: time gap below TDIFF, |z| gap below ZDIFF, same string.
        last_tval = naux_event_data[0, TIME_IDX].copy()
        last_xyz = naux_event_data[0, [X_IDX, Y_IDX, Z_IDX]].copy()
        last_string = naux_event_data[0, STRING_IDX].copy()
        for pulse in naux_event_data:
            tdiff = (pulse[TIME_IDX] - last_tval)
            zdiff = (pulse[Z_IDX] - last_xyz[2])
            # Squared Euclidean distance to the previous pulse (no square root is taken).
            dist = ((pulse[[X_IDX, Y_IDX, Z_IDX]] - last_xyz)**2).sum()
            speed = dist / (tdiff + 1e-6)
            same_pillar = ( ( tdiff < TDIFF ) &
                            ( abs(zdiff) < ZDIFF) &
                            ( pulse[STRING_IDX] == last_string ) 
            )
            if same_pillar:
                last_tval = pulse[TIME_IDX].copy()
                last_xyz = pulse[[X_IDX, Y_IDX, Z_IDX]].copy()
                last_string = pulse[STRING_IDX].copy()
            else:
                # First pulse that breaks the chain: its tdiff / zdiff / dist / speed are kept.
                break
        else:
            # The chain never broke: tdiff and zdiff are reset to 0, while dist and speed
            # keep the values computed for the last pulse.
            tdiff = 0
            zdiff = 0

        feats[ii] = np.array([num_naux, pct_naux,
            total_charge, charge_std,
            time_min, time_mean, time_max, time_std,
            z_mean, z_min, z_max, z_std, time_ratio,
            tdiff, zdiff, dist, speed
        ])
    return feats

def generate_batch_features(bfile, events_to_keep):
    """Load one batch file and compute the hand-crafted features of the selected events.

    Args:
        bfile: batch file name; the batch id is the integer after the last '_'.
        events_to_keep: positions of the wanted events within the batch metadata.

    Returns:
        (event_lengths, feats): pulse count per event and the feature matrix
        produced by generate_features.
    """
    batch_id = int(bfile.split('.')[0].split('_')[-1])

    pulses = (
        pl.scan_parquet(f'{PARQUETS_DIR}/{bfile}')
        .join(geometry, on='sensor_id', how='left')
        .select(FNAMES)
        .collect()
        .to_numpy()
    )
    spans = (
        metadata.filter(pl.col('batch_id') == batch_id)
        .select(['first_pulse_index', 'last_pulse_index'])
        .to_numpy()[events_to_keep]
    )

    event_lengths = spans[:, 1] - spans[:, 0] + 1
    return event_lengths, generate_features(pulses, spans)

def generate_boosting_features(event_lengths, batch_features, 
                          az_base, zn_base, az_stack, zn_stack,
                          prediction_data):
    """Assemble the per-event feature matrix for the gradient-boosting models.

    Args:
        event_lengths: pulse count per event.
        batch_features: hand-crafted features from generate_features.
        az_base, zn_base: angles derived from the classifier outputs.
        az_stack, zn_stack: angles derived from the stack model's xyz output.
        prediction_data: (azimuth probabilities, zenith probabilities, xyz) tensors.

    Returns:
        (all_feats, pred_features): all_feats concatenates event length,
        prediction features, the base-vs-stack angular distance and the batch
        features. pred_features is the 12-column prediction block on its own.
    """
    az_prob = prediction_data[0].cpu().numpy()
    zn_prob = prediction_data[1].cpu().numpy()
    xyz = prediction_data[2].cpu().numpy()

    base_az = az_base.cpu().numpy()
    base_zn = zn_base.cpu().numpy()
    stack_az = az_stack.cpu().numpy()
    stack_zn = zn_stack.cpu().numpy()

    # Bin that the stack model's angle falls into. The zenith bins are already the
    # padded ones and num_zenith_padding is added on top; kept exactly as written.
    az_cls = np.digitize(stack_az, bins=azimuth_bin_centers.cpu().numpy(), right=False) - 1
    zn_cls = np.digitize(stack_zn, bins=zenith_bin_centers.cpu().numpy(), right=False) - 1 + num_zenith_padding
    rows = np.arange(len(az_cls))

    az_prob_at_stack = az_prob[rows, az_cls]
    zn_prob_at_stack = zn_prob[rows, zn_cls]

    pred_features = np.stack([
        base_az,
        base_zn,
        az_prob.max(1),
        zn_prob.max(1),
        az_prob_at_stack,
        zn_prob_at_stack,
        az_prob_at_stack/az_prob.max(1),
        zn_prob_at_stack/zn_prob.max(1),
        stack_az,
        stack_zn,
        1/np.sqrt(np.linalg.norm(xyz, axis=1)),
        1/np.sqrt(np.linalg.norm(xyz[:, :2], axis=1))
    ]).T

    # Per-event disagreement between the classifier angles and the stack angles.
    adists = angular_dist_score(base_az, base_zn, stack_az, stack_zn, mean=False)[:, None]

    all_feats = np.concatenate(
            [event_lengths[:, None], pred_features, adists, batch_features], axis=1
        )
    return all_feats, pred_features

In [None]:
@torch.no_grad()
def predict_model(dataset, batch_size, model, stack_model):
    """Run the classifier and the stack model over `dataset`.

    Returns:
        (azimuth probabilities, zenith probabilities, xyz predictions), each
        restored to the dataset's original (unsorted) event order.
    """
    az_chunks, zn_chunks, xyz_chunks = [], [], []
    for x, l in tqdm(DataLoader(dataset, batch_size=batch_size)):
        # The dataset is length-sorted, so trimming to the batch maximum removes most padding.
        longest = int(l.max())
        az_logits, zn_logits, pooled = model((x[:, :longest].to(device), l.to(device)))
        stack_in = torch.cat([pooled, az_logits, zn_logits], dim=1)
        az_chunks.append(az_logits)
        zn_chunks.append(zn_logits)
        xyz_chunks.append(stack_model(stack_in))

    unsort = dataset.reverse_sort_idx
    az_logits_all = torch.cat(az_chunks, dim=0)[unsort]
    zn_logits_all = torch.cat(zn_chunks, dim=0)[unsort]
    xyz_all = torch.cat(xyz_chunks, dim=0)[unsort]

    return az_logits_all.softmax(dim=1), zn_logits_all.softmax(dim=1), xyz_all

@torch.no_grad()
def predict_on_batch(model, stack_model, dataset, batch_size, mutlisample=0, dataset_fn=None):
    """Predict angles for one dataset, optionally averaging over re-sampled copies.

    Args:
        model, stack_model: classifier and stack model.
        dataset: dataset to predict on. An empty dataset yields all-None results.
        batch_size: loader batch size.
        mutlisample: number of extra passes over datasets rebuilt by `dataset_fn`
            with seeds 43, 44, ...
        dataset_fn: callable taking `seed=` and returning a fresh dataset.

    Returns:
        (az_base, zn_base, az_stack, zn_stack, (az_prob, zn_prob, xyz)).
    """
    if len(dataset) == 0:
        return None, None, None, None, (None, None, None)

    az_prob, zn_prob, xyz = predict_model(dataset, batch_size, model, stack_model)

    if mutlisample > 0:
        az_runs, zn_runs, xyz_runs = [az_prob], [zn_prob], [xyz]
        for extra in range(1, mutlisample + 1):
            resampled = dataset_fn(seed=42 + extra)
            azp, znp, xyzp = predict_model(resampled, batch_size, model, stack_model)
            az_runs.append(azp)
            zn_runs.append(znp)
            xyz_runs.append(xyzp)
        az_base, zn_base = ensemble_predictions(az_runs, zn_runs)
        az_prob = torch.stack(az_runs).mean(0)
        zn_prob = torch.stack(zn_runs).mean(0)
        xyz = torch.stack(xyz_runs).mean(0)
    else:
        az_base, zn_base = discrete_to_angle(az_prob, zn_prob, azimuth_bin_centers, zenith_bin_centers)

    az_stack, zn_stack = xyz_to_angle(xyz)

    return az_base, zn_base, az_stack, zn_stack, (az_prob, zn_prob, xyz)

In [None]:
def join_predictions(preds, long_preds, ultralong_preds, ds, long_ds, ultralong_ds):
    """Scatter the per-dataset predictions into a single tensor indexed by event.

    Each dataset's `.events` values are used directly as row indices into the
    result, so row `e` of the output holds the prediction for event `e`.
    # NOTE(review): this assumes the events across the three datasets are the
    # contiguous indices 0..n_events-1 — confirm against IceCubeDataset.

    Args:
        preds, long_preds, ultralong_preds: prediction tensors whose first
            dimension matches the corresponding dataset's `.events`, or `None`
            when that dataset produced no predictions (e.g. it was empty).
        ds, long_ds, ultralong_ds: datasets exposing an `.events` sequence.

    Returns:
        Tensor with `n_events` rows (number of distinct events across the three
        datasets) and the trailing shape, dtype and device of the first
        non-None input. Rows not covered by any input stay zero.

    Raises:
        ValueError: if all three prediction inputs are `None`.
    """
    n_events = len( set(ds.events).union(long_ds.events).union(ultralong_ds.events) )

    # Any of the three parts may be None, not just the ultralong one.
    parts = [(preds, ds), (long_preds, long_ds), (ultralong_preds, ultralong_ds)]
    available = [(part, part_ds) for part, part_ds in parts if part is not None]
    if not available:
        raise ValueError("join_predictions needs at least one non-None prediction tensor")

    # Output layout is taken from the first available part.
    template = available[0][0]
    joint_preds = torch.zeros((n_events,) + tuple(template.shape[1:]),
                              dtype=template.dtype, device=template.device)
    for part, part_ds in available:
        rows = torch.as_tensor(part_ds.events, dtype=torch.long)
        # Indexed assignment copies the values, so no explicit clone is needed.
        joint_preds[rows] = part
    return joint_preds

In [None]:
def _merge_model_pair(merger_file, first, second, event_lengths, batch_features,
                      pred_features, az_preds, zn_preds):
    """Choose per event between two models' angles using a saved merger classifier.

    Loads the classifier `merger_file` from MODEL_DIR, builds its feature matrix
    from the event lengths, both models' prediction features and the batch
    features, and returns copies of model `first`'s azimuth/zenith in which the
    rows where the classifier predicts 1 are replaced by model `second`'s values.
    """
    merger = joblib.load(f'{MODEL_DIR}/{merger_file}')
    merger_feats = np.concatenate(
        [event_lengths[:, None], pred_features[first], pred_features[second], batch_features],
        axis=1
    )
    use_second = merger.predict(merger_feats) == 1
    merged_az = az_preds[first].clone()
    merged_zn = zn_preds[first].clone()
    merged_az[use_second] = az_preds[second][use_second].clone()
    merged_zn[use_second] = zn_preds[second][use_second].clone()
    return merged_az, merged_zn


# Per-batch final predictions, appended in BATCH_LIST order.
az_pred, zn_pred = [], []
for bfile in BATCH_LIST:
    batch_id = int(bfile.split('.')[0].split('_')[-1])

    # Three datasets over the same file, split by sequence-length bounds:
    # [0, MAX_SEQUENCE_LENGTH], [MAX_SEQUENCE_LENGTH, LONG_SEQ_MAX_LENGTH] and
    # [LONG_SEQ_MAX_LENGTH, inf). The last one is padded to LONG_SEQ_MAX_LENGTH and
    # is re-created with different seeds for multi-sample prediction.
    dataset = IceCubeDataset(bfile, seed=42,
                             min_sequence_lenth=0,
                             max_sequence_length=MAX_SEQUENCE_LENGTH,
                             padded_sequence_length=MAX_SEQUENCE_LENGTH,
    )
    long_seq_dataset = IceCubeDataset(bfile, seed=42,
                             min_sequence_lenth=MAX_SEQUENCE_LENGTH,
                             max_sequence_length=LONG_SEQ_MAX_LENGTH,
                             padded_sequence_length=LONG_SEQ_MAX_LENGTH,
    )
    ul_ds_create_fn = lambda seed: IceCubeDataset(bfile, seed=seed,
                             min_sequence_lenth=LONG_SEQ_MAX_LENGTH,
                             max_sequence_length=np.inf,
                             padded_sequence_length=LONG_SEQ_MAX_LENGTH,
    )
    ultralong_ds = ul_ds_create_fn(seed=42)

    # Sorted so the order is deterministic instead of depending on set iteration.
    # NOTE(review): downstream helpers presumably expect ascending event order
    # (join_predictions places rows by event index) — confirm.
    events_to_keep = sorted(set(dataset.events).union(long_seq_dataset.events).union(ultralong_ds.events))
    event_lengths, batch_features = generate_batch_features(bfile, events_to_keep)

    az_pred_mbs, zn_pred_mbs = [], []   # final angles, one entry per model config
    model_pred_features = []            # per-model features consumed by the mergers
    for model_config in MODEL_CONFIGS:
        model = load_model(model_config)
        stack_model = load_stack_model(model_config.stack_model_config)
        long_seq_model = load_model(model_config.long_seq_model_config)

        # Not referenced again in this cell; kept in case a helper reads it as a
        # global — TODO confirm and remove.
        str_id = f'batch_{batch_id} model_{model_config.name}'

        # Short sequences use `model`; long and ultra-long use `long_seq_model`.
        az_b, zn_b, az_s, zn_s, pdt = predict_on_batch(model, stack_model, dataset, BATCH_SIZE)
        az_bls, zn_bls, az_sls, zn_sls, pdt_ls = predict_on_batch(long_seq_model, stack_model,
                                                          long_seq_dataset, LONG_SEQ_BATCH_SIZE)
        az_buls, zn_buls, az_suls, zn_suls, pdt_uls = predict_on_batch(long_seq_model, stack_model,
                                                          ultralong_ds, LONG_SEQ_BATCH_SIZE,
                                                          mutlisample=ULTTRA_LONG_RESAMPLE,
                                                          dataset_fn=ul_ds_create_fn)

        # Re-assemble the three length buckets into tensors indexed by event.
        az_base = join_predictions(az_b, az_bls, az_buls, dataset, long_seq_dataset, ultralong_ds)
        zn_base = join_predictions(zn_b, zn_bls, zn_buls, dataset, long_seq_dataset, ultralong_ds)
        az_stack = join_predictions(az_s, az_sls, az_suls, dataset, long_seq_dataset, ultralong_ds)
        zn_stack = join_predictions(zn_s, zn_sls, zn_suls, dataset, long_seq_dataset, ultralong_ds)
        prediction_data = []
        for p, p_ls, p_uls in zip(pdt, pdt_ls, pdt_uls):
            prediction_data.append(
                join_predictions(p, p_ls, p_uls, dataset,
                                 long_seq_dataset, ultralong_ds)
            )

        ## combine classifier and vmf predictions
        boosting_feats, pred_features = generate_boosting_features(event_lengths, batch_features,
                                                       az_base, zn_base,
                                                       az_stack, zn_stack,
                                                       prediction_data
                                                    )
        model_pred_features.append(pred_features)

        # Mixer: where it predicts 1, take the stack-model zenith instead of the base one.
        mixer_clf = joblib.load(f'{MODEL_DIR}/{model_config.mixer_model}')
        mixer_preds = mixer_clf.predict(boosting_feats)
        composite_zenith = zn_base.clone()
        v = mixer_preds == 1
        composite_zenith[v] = zn_stack[v].clone()

        # Selector: where it predicts 1, replace the direction by
        # (azimuth + pi) mod 2*pi and (-zenith) mod pi.
        selector_clf = joblib.load(f'{MODEL_DIR}/{model_config.selector_model}')
        selector_preds = selector_clf.predict(boosting_feats)
        v = selector_preds == 1
        rep_azimuth = az_base.clone()
        rep_zenith = composite_zenith.clone()
        rep_azimuth[v] = ( az_base[v] + np.pi ) % (np.pi * 2)
        rep_zenith[v] = ( - composite_zenith[v] ) % np.pi

        az_pred_mb, zn_pred_mb = rep_azimuth, rep_zenith

        az_pred_mbs.append(az_pred_mb)
        zn_pred_mbs.append(zn_pred_mb)

    # Combine the per-model results; the strategy depends on how many models ran.
    if len(MODEL_CONFIGS) == 2:
        az_pred_batch, zn_pred_batch = _merge_model_pair(
            MERGER_MODELS[0], 0, 1, event_lengths, batch_features,
            model_pred_features, az_pred_mbs, zn_pred_mbs)
    elif len(MODEL_CONFIGS) == 1:
        az_pred_batch, zn_pred_batch = az_pred_mbs[0], zn_pred_mbs[0]
    elif len(MODEL_CONFIGS) == 3:
        # merge 18l and 15l
        merge01_az, merge01_zn = _merge_model_pair(
            MERGER_MODELS[0], 0, 1, event_lengths, batch_features,
            model_pred_features, az_pred_mbs, zn_pred_mbs)

        # merge 15l and r6
        merge12_az, merge12_zn = _merge_model_pair(
            MERGER_MODELS[1], 1, 2, event_lengths, batch_features,
            model_pred_features, az_pred_mbs, zn_pred_mbs)

        # average and vote on results of merge: use the average of the two merges,
        # except where they differ by more than 0.4, where the 0/1 merge wins.
        vote_az = avg_az(merge01_az, merge12_az)
        az_hdiff = torch.abs(az_diff(merge01_az, merge12_az)) > 0.4
        vote_az[az_hdiff] = merge01_az[az_hdiff].clone()

        vote_zn = ( merge01_zn + merge12_zn ) / 2
        zn_hdiff = torch.abs(merge01_zn - merge12_zn) > 0.4
        vote_zn[zn_hdiff] = merge01_zn[zn_hdiff].clone()

        az_pred_batch, zn_pred_batch  = vote_az, vote_zn
    else:
        raise NotImplementedError

    az_pred.append(az_pred_batch)
    zn_pred.append(zn_pred_batch)

    check_score(batch_id, az_pred_batch, zn_pred_batch,
                events_to_keep=events_to_keep,
                text=f'batch_{batch_id}', convert_to_angle=False)

In [None]:
# Concatenate the per-batch predictions (appended in BATCH_LIST order) into flat tensors.
azimuth = torch.cat(az_pred, axis=0)
zenith = torch.cat(zn_pred, axis=0)

# reshape(-1) instead of squeeze(): squeeze() would collapse a one-row id column
# to a 0-d scalar, whereas reshape(-1) always yields a 1-D array of ids.
event_ids = metadata.select('event_id').to_numpy().reshape(-1)
submission_df = prepare_submission(event_ids, azimuth, zenith, validate_ids=(not CHECK_PREDICTION))
submission_df.to_csv('submission.csv')
print('Saved submission')

In [None]:
if CHECK_PREDICTION:
    # Ground-truth angles come from the metadata table.
    az_gt, zen_gt = metadata['azimuth'].to_numpy(), metadata['zenith'].to_numpy()

    # Score what was actually written to disk rather than the in-memory tensors.
    submission = pd.read_csv('submission.csv')
    # NOTE: this rebinds `az_pred`, which until now held the list of per-batch
    # tensors; the concatenation cell cannot be re-run after this one.
    az_pred = submission.azimuth.values
    zen_pred = submission.zenith.values

    angular_dist = angular_dist_score(az_gt, zen_gt, az_pred, zen_pred)
    print("Angular Distance Score", angular_dist)