# Install

In [None]:
# Mount Google Drive: the dataset, graphnet zip and checkpoint paths used in this
# notebook all live under /content/drive/MyDrive/work/projects/icecube.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Show the GPU attached to this runtime; the PyTorch/PyG wheels installed below are CUDA 11.5 (cu115) builds.
!nvidia-smi

In [None]:
#@title Load data

!rm -r content
!rm -r cache
!rm -r /content/icecube-neutrinos-in-deep-ice


DATASET_PATH = '/content/drive/MyDrive/work/projects/icecube/datasets/icecube-neutrinos-in-deep-ice'

!mkdir icecube-neutrinos-in-deep-ice/
!mkdir icecube-neutrinos-in-deep-ice/train
!mkdir icecube-neutrinos-in-deep-ice/train_meta

!cp {DATASET_PATH}/sensor_geometry.csv icecube-neutrinos-in-deep-ice/

!unzip {DATASET_PATH}/train_meta_splitted.zip 


In [None]:
%%time

#@title Install modules

!rm  -r software
!mkdir software/
!mkdir software/graphnet

!unzip /content/drive/MyDrive/work/projects/icecube/libs/graphnet-main-20230216.zip -d software/graphnet

!pip install -q wandb
!pip install -q polars
!pip install -q torchinfo
!pip install -q colorlog>=6.6
!pip install -q ruamel.yaml
!pip install -q torch==1.11+cu115 --find-links https://download.pytorch.org/whl/torch_stable.html
!pip install -q torch-cluster==1.6.0 -f https://data.pyg.org/whl/torch-1.11.0+cu115.html
!pip install -q torch-scatter==2.0.9 -f https://data.pyg.org/whl/torch-1.11.0+cu115.html
!pip install -q torch-sparse==0.6.13 -f https://data.pyg.org/whl/torch-1.11.0+cu115.html
!pip install -q pytorch-lightning>=1.6
!pip install -q "awkward>=1.8,<2.0"
!pip install -q torch_geometric==2.0.4
!pip install -q dill
#!pip install -q cudf-cu11==22.12.0 --extra-index-url=https://pypi.nvidia.com # last vrsion depends on new protobuf with conflicts with colab


# Config

In [None]:
class CFG:
  """Experiment settings, stored and read as class attributes (CFG.NAME)."""
  MODE          = 'train' # 'train' | 'test' | 'submit'
  EXP_ID        = 73
  EXP_COMMENT   = "classific_24_bins" # l5_post_mlp336_256_256 l5_post_mlp336_256_tr_l4_h4_f512
  RESUME        = False
  RESUME_MODEL  = '/content/drive/MyDrive/work/projects/icecube/models/ice_gnn_v4_exp_73_classific_24_bins/model_exp_73_classific_24_bins_last.pt'

  def to_dict():
    # Called as CFG.to_dict(): collect every public, non-callable class attribute.
    settings = {}
    for name, value in vars(CFG).items():
      if name.startswith('_') or callable(value):
        continue
      settings[name] = value
    return settings

import os

# --- Data locations ---
CFG.REMOTE_DATASET_PATH    = '/content/drive/MyDrive/work/projects/icecube/datasets/icecube-neutrinos-in-deep-ice'
CFG.DATASET_PATH           = 'icecube-neutrinos-in-deep-ice'
CFG.INPUT_DATA_PATH        = f'{CFG.REMOTE_DATASET_PATH}/{CFG.MODE}'
CFG.GEOMETRY_TABLE         = f'{CFG.DATASET_PATH}/sensor_geometry.csv'
CFG.SCATTER_ABSORT_TABLE   = f'{CFG.REMOTE_DATASET_PATH}/scattering_and_absorption.csv'
CFG.DOMS_EFF_TABLE         = f'{CFG.REMOTE_DATASET_PATH}/doms_eff.csv'
CFG.META_PATH              = f'/content/content/icecube-neutrinos-in-deep-ice/{CFG.MODE}_meta/'
# --- Batches / events ---
CFG.BATCH_RANGE            = (1,658)   # (first, stop): stop is exclusive, like split_meta_data's result
CFG.VAL_BATCH              = 659
CFG.VAL_EVENTS             = 100000
CFG.MAX_PULSES_PER_EVENT   = 1000      # keep at most this many (earliest) pulses per event; falsy = no cap
CFG.CACHE_DIR              = 'cache'
# W&B key is read from the environment (WANDB_API_KEY) so no real key lives in the notebook;
# falls back to the previous placeholder when the variable is not set.
CFG.WANDBAPIKEY            = os.environ.get('WANDB_API_KEY', '****')
CFG.USE_WANDB              = False
CFG.LOADER                 = 'pl'  # pl,pd,cudf
CFG.SAMPLE_FILTER          = False
CFG.ANGLES                 = 'az,ze'  # az,ze/az/ze
CFG.ZENITH_RANGE           = None # (0.0, math.pi/2.0) # (math.pi/2.0, math.pi) # None
CFG.FROZEN                 = False
CFG.UNFOROZEN_LAYERS       = []   # (sic) spelling kept in case other code reads this name

CFG.CHECKPOINTS_PATH = f"/content/drive/MyDrive/work/projects/icecube/models/ice_gnn_v4_exp_{CFG.EXP_ID}_{CFG.EXP_COMMENT}"

# Lib

In [None]:
#@title helpers

# Append to PATH
import sys
import gc
sys.path.append('software/graphnet/src')

import random
import pyarrow.parquet as pq
import sqlite3
import pandas as pd
import sqlalchemy
from tqdm import tqdm
import os
from typing import Any, Dict, List, Optional, Union
import numpy as np
import math
import torch
from torch.optim.adam import Adam
import pandas as pd
import matplotlib.pyplot as plt
from inspect import getfullargspec
import shutil
from os import path
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import Data
from torch_geometric.data import Batch, Data
from torch.utils.data import Subset
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from scipy.interpolate import interp1d                
import polars as pl

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Global sensor-feature embedding read by collate_fn; None until it is assigned
# elsewhere in the notebook (presumably from get_sensors).
sensors_data = None

# Optional dependencies, imported only when the configuration needs them.
if CFG.LOADER == 'cudf':
    import cudf

if CFG.MODE=='train':
    from torchinfo import summary

if CFG.USE_WANDB:
    import wandb

    def wandb_init():
        """Log in to Weights & Biases and start this experiment's run in project "icecube"."""
        wandb.login(key=CFG.WANDBAPIKEY)
        run_name = f"{CFG.EXP_ID}_{CFG.EXP_COMMENT}"
        wandb.init(
            project="icecube",
            name=run_name,
            config=CFG.TRAIN_CFG,  # pass resume="must" to continue an existing run
        )


def get_sensors(sensor_features):
    """Build a frozen embedding table that maps sensor_id to per-sensor features.

    Columns selectable through `sensor_features` are those of CFG.GEOMETRY_TABLE
    (sensor_id, x, y, z are used here) plus these derived ones:
      line_id : string number, sensor_id // 60 + 1
      core    : 1.0 when line_id > 78 (DeepCore sensor), else 0.0
      sc, abs : column `a` (x1e-3) and column `b` (x1e-2) of CFG.SCATTER_ABSORT_TABLE,
                linearly interpolated at the sensor depth z
      r       : sqrt(x**2 + y**2) * 1e-3
      eff     : contents of CFG.DOMS_EFF_TABLE

    Args:
        sensor_features: list of column names; their order is the feature order.

    Returns:
        torch.nn.Embedding on `device` with gradients disabled. Row i holds the
        features of row i of the geometry table (presumably sensor_id == i — TODO
        confirm the CSV is sorted by sensor_id).
    """
    df = pd.read_csv(CFG.GEOMETRY_TABLE)
    df['line_id'] = df.sensor_id // 60 + 1                 # string id
    df['core']    = (df.line_id > 78).astype(np.float32)   # sensor from DeepCore
    # No unit conversion here: x/y/z keep the CSV's units (only `r` below is scaled by 1e-3).
    df.x = df.x.astype(np.float32)
    df.y = df.y.astype(np.float32)
    df.z = df.z.astype(np.float32)

    phys = pd.read_csv(CFG.SCATTER_ABSORT_TABLE)
    phys.z = (phys.z).astype(np.float32)
    phys.a = (phys.a * 1e-3).astype(np.float32)
    phys.b = (phys.b * 1e-2).astype(np.float32)
    # interp1d (default bounds_error=True) raises ValueError if a sensor depth lies
    # outside the table's z range.
    interp_scatter = interp1d(phys.z, phys.a)
    interp_absort  = interp1d(phys.z, phys.b)
    df['sc']  = interp_scatter(df.z)
    df['abs'] = interp_absort(df.z)
    df['r']   = np.sqrt(df.x**2 + df.y**2)*1e-3

    eff = pd.read_csv(CFG.DOMS_EFF_TABLE)
    # NOTE(review): assigning a DataFrame to one column only works when the CSV has a
    # single column; values are presumably aligned on the row index — confirm file layout.
    df['eff'] = eff.astype(np.float32)

    sensors_tensor = torch.tensor(df[sensor_features].values, dtype=torch.float32, device=device)

    # The table size follows the geometry file (5160 sensors for IceCube) instead of a
    # hard-coded count; freeze=True disables gradients like requires_grad_(False).
    return torch.nn.Embedding.from_pretrained(sensors_tensor, freeze=True)

def split_meta_data():
    """Split CFG.META_TABLE into one parquet file per batch under CFG.META_PATH.

    The table is read in chunks of 200_000 rows. Each chunk is written to
    `batch_{batch_id}_meta.parquet`, using the first batch_id found in the chunk
    (presumably every batch has exactly 200_000 events — confirm). In 'test'
    mode only batches with CFG.BATCH_RANGE[0] <= batch_id < CFG.BATCH_RANGE[1]
    are written.

    Returns:
        (smallest written batch_id, largest written batch_id + 1).
    """
    # exist_ok=True already makes this a no-op for an existing directory.
    os.makedirs(CFG.META_PATH, exist_ok=True)

    written_ids = []
    chunks = pq.ParquetFile(CFG.META_TABLE).iter_batches(batch_size = 200_000)
    for chunk in tqdm(chunks):
        chunk_df = chunk.to_pandas()
        batch_id = pd.unique(chunk_df['batch_id'])[0]
        if CFG.MODE == 'test':
            if batch_id < CFG.BATCH_RANGE[0]:
                continue
            if batch_id == CFG.BATCH_RANGE[1]:
                break
        chunk_df.to_parquet(f'{CFG.META_PATH}/batch_{batch_id}_meta.parquet')
        written_ids.append(batch_id)

    return (min(written_ids), max(written_ids) + 1)


def load_batch(batch_id, max_events, doms_agg=False, bin_num=None):
    """Load one batch of pulses plus its meta data as tensors on `device`.

    Args:
        batch_id: Batch number. Reads `{CFG.META_PATH}/batch_{batch_id}_meta.parquet`
            and `{CFG.INPUT_DATA_PATH}/batch_{batch_id}.parquet`.
        max_events: If truthy, keep only the first `max_events` events of the meta table.
        doms_agg: If True, aggregate pulses per (event_id, sensor_id): mean auxiliary,
            summed charge and earliest time.
        bin_num: If truthy, also return azimuth/zenith class codes on a
            `bin_num` x `bin_num` grid.

    Returns:
        batch_indexes: long tensor (n_events, 3) of [event_id, first_row, end_row];
            `end_row` is an exclusive slice end into `batch_features`, with at most
            CFG.MAX_PULSES_PER_EVENT rows per event when that option is set.
        batch_directions: target directions encoded according to CFG.ANGLES.
        batch_classes: class codes, or None when `bin_num` is falsy.
        batch_features: float32 tensor (n_pulses, 4) with columns
            [time, charge, auxiliary, sensor_id], sorted by (event_id, time).

    Raises:
        ValueError: on an unsupported CFG.ANGLES value, or when the number of events
            found in the pulses differs from the number of meta rows.
    """
    meta_file = f'{CFG.META_PATH}/batch_{batch_id}_meta.parquet'
    if CFG.MODE == 'submit':
        # No labels at submission time: add zero angle columns so the code below is uniform.
        batch_meta_df = (
            pl.read_parquet(meta_file)
            .select(['event_id','first_pulse_index','last_pulse_index'])
            .with_columns([
                pl.lit(0.0).alias('azimuth').cast(pl.Float32),
                pl.lit(0.0).alias('zenith').cast(pl.Float32)])
        )
    else:
        batch_meta_df = (
            pl.read_parquet(meta_file)
            .select(['event_id','azimuth','zenith','first_pulse_index','last_pulse_index'])
        )

    if CFG.ZENITH_RANGE:
        batch_meta_df = batch_meta_df.filter((pl.col("zenith") > CFG.ZENITH_RANGE[0]) & (pl.col("zenith") <= CFG.ZENITH_RANGE[1]))

    batch_df = (
        pl.read_parquet(f'{CFG.INPUT_DATA_PATH}/batch_{batch_id}.parquet')
        .select(['event_id','time','charge','auxiliary','sensor_id'])
        .with_columns([
            pl.col("time").cast(pl.Float32),
            pl.col("charge").cast(pl.Float32),
            pl.col("auxiliary").cast(pl.Float32),
            pl.col("sensor_id").cast(pl.Float32)])
    )

    if max_events:
        batch_meta_df = batch_meta_df[:max_events]
        # Pulse indexes in the meta table are row positions in the unfiltered pulse table.
        batch_df = batch_df[:batch_meta_df[-1]['last_pulse_index'][0]+1]

    if CFG.ZENITH_RANGE:
        # Keep only pulses of events that passed the zenith filter; otherwise the
        # per-event indexes rebuilt below would contain extra events and no longer
        # line up with the angles taken from the meta table.
        batch_df = batch_df.filter(pl.col('event_id').is_in(batch_meta_df['event_id']))

    if doms_agg:
        batch_df = batch_df.groupby(['event_id', 'sensor_id']).agg([
            pl.col("auxiliary").mean(),
            pl.col("charge").sum(),
            pl.col("time").min()
        ])

    # Event boundaries are always rebuilt from the pulses sorted by (event_id, time):
    # the sequence (GRU) consumer needs time-ordered pulses.
    batch_df = batch_df.sort(['event_id', 'time']).with_row_count()
    batch_indexes_df = batch_df.groupby(['event_id']).agg([
        pl.col("row_nr").min().alias('first_pulse_index'),
        pl.col("row_nr").max().alias('last_pulse_index'),
    ]).sort(['event_id'])
    batch_indexes = torch.tensor(batch_indexes_df.to_numpy(), device=device, dtype=torch.long)

    # Turn the inclusive last row into an exclusive end (first + n_pulses) so that
    # features[first:end] covers every pulse, optionally capping n_pulses per event.
    n_pulses = batch_indexes[:,2] - batch_indexes[:,1] + 1
    if CFG.MAX_PULSES_PER_EVENT:
        n_pulses = torch.clip(n_pulses, min=0, max=CFG.MAX_PULSES_PER_EVENT)
    batch_indexes[:,2] = batch_indexes[:,1] + n_pulses

    # NOTE(review): angles are taken in meta-table order while batch_indexes is sorted by
    # event_id; this assumes the meta table is ordered by event_id — confirm.
    batch_angles = torch.tensor(batch_meta_df.select(['azimuth','zenith']).to_numpy(), device=device, dtype=torch.float32)
    if batch_indexes.shape[0] != batch_angles.shape[0]:
        raise ValueError(
            f'batch {batch_id}: {batch_indexes.shape[0]} events with pulses '
            f'but {batch_angles.shape[0]} meta rows')

    if CFG.ANGLES == 'az,ze':
        batch_directions = angles_to_vectors(batch_angles)
    elif CFG.ANGLES == 'az':
        batch_directions = angles_to_vectors_2D(batch_angles[:,0])
    elif CFG.ANGLES == 'ze':
        batch_directions = angles_to_vectors_2D(batch_angles[:,1])
    else:
        raise ValueError(f"Unsupported CFG.ANGLES: {CFG.ANGLES!r} (expected 'az,ze', 'az' or 'ze')")

    batch_classes = None
    if bin_num:
        azimuth_edges, zenith_edges = build_az_ze_edges(bin_num)
        batch_classes = angles_to_code(batch_angles, bin_num, azimuth_edges, zenith_edges)

    batch_features = torch.tensor(batch_df.select(['time','charge','auxiliary','sensor_id']).to_numpy(), device=device, dtype=torch.float32)

    return batch_indexes, batch_directions, batch_classes, batch_features

def build_dataloader(batch_id, shuffle, config, max_events=None, indexes=None):
  """Create a DataLoader over the events of one batch.

  `config` must provide 'doms_agg', 'bin_num' and 'batch_size'. When `indexes`
  is given, only those dataset positions are used (via a Subset).
  """
  print(f'build loader for batch {batch_id} with limit: {max_events}')
  dataset = BatchDataset(batch_id, max_events, config['doms_agg'], config['bin_num'])
  if indexes is not None:
    dataset = Subset(dataset, indexes)
  return DataLoader(
      dataset,
      batch_size=config['batch_size'],
      num_workers=0,   # load_batch already places the tensors on `device`
      shuffle=shuffle,
      collate_fn=collate_fn,
  )

def collate_fn(graphs: List[Data]) -> Batch:
    """Merge event graphs into one Batch and prepend per-sensor features to `x`.

    Looks the features up in the module-level `sensors_data` embedding (which must
    be set before use) by sensor_id; the resulting `x` is
    [sensor features | original node features].
    """
    merged = Batch.from_data_list(graphs)
    sensor_ids = merged.sensor_id.long()
    merged.sensor_id = sensor_ids
    merged.x = torch.cat([sensors_data(sensor_ids), merged.x], dim=1)
    return merged

class BatchDataset(Dataset):
  """Map-style dataset returning one torch_geometric `Data` graph per event of a batch.

  All tensors are produced once by `load_batch`; each graph is built lazily on
  first access and memoised in `self.cache`.
  """

  def __init__(self, batch_id, max_events=None, doms_agg=False, bin_num=None):
    (self.batch_indexes, self.batch_directions,
     self.batch_classes, self.batch_features) = load_batch(batch_id, max_events, doms_agg, bin_num)
    self.cache = {}
    self.bin_num = bin_num

  def __len__(self):
    return len(self.batch_indexes)

  def __getitem__(self, idx):
    """Return the graph of event `idx`.

    `x` holds every feature column except the last (sensor_id), which is stored
    separately as `sensor_id` for the lookup done in collate_fn.
    """
    # Explicit membership test instead of relying on the truthiness of a `Data` object.
    if idx in self.cache:
      return self.cache[idx]

    event_id, event_first_pulse, event_last_pulse = self.batch_indexes[idx]
    # The third index is used as an exclusive slice end.
    x = self.batch_features[event_first_pulse:event_last_pulse, :-1]
    sensor_id = self.batch_features[event_first_pulse:event_last_pulse, -1]

    graph = Data(x=x, edge_index=None)
    # NOTE(review): with an exclusive end this equals x.shape[0] + 1 (off by one); left
    # unchanged in case downstream models were trained on this value — confirm intent.
    graph.n_pulses  = event_last_pulse - event_first_pulse + 1
    graph.event_ids = event_id
    graph.sensor_id = sensor_id
    graph.direction = self.batch_directions[idx].unsqueeze(0)
    if self.bin_num:
      graph.class_id = self.batch_classes[idx].unsqueeze(0)
    self.cache[idx] = graph
    return graph

def angles_to_vectors(angles):
    """Convert (azimuth, zenith) pairs to 3D unit direction vectors.

    Args:
        angles: tensor (N, 2+) whose columns 0 and 1 are azimuth and zenith in radians.

    Returns:
        tensor (N, 3) with columns [x, y, z], same dtype and device as `angles`.
    """
    azimuth, zenith = angles[:, 0], angles[:, 1]
    sin_zenith = torch.sin(zenith)
    return torch.stack(
        [torch.cos(azimuth) * sin_zenith,
         torch.sin(azimuth) * sin_zenith,
         torch.cos(zenith)],
        dim=1,
    )

def vectors_to_angles(vectors):
    """Convert direction vectors (N, 3) into (azimuth, zenith) angles (N, 2).

    Vectors need not be normalised; the input is not modified. Azimuth lies in
    [0, 2*pi] and comes from the x-y components alone; zenith lies in [0, pi].
    Degenerate inputs: a vector with no x-y component gets azimuth pi/2
    (arccos of 0), a zero vector gets zenith pi/2, and any non-finite result
    falls back to azimuth 0 / zenith pi/2.
    """
    squares = vectors.pow(2.0)
    xy_norm_sq = torch.sum(squares[:, 0:2], axis=1)
    xy_norm = torch.sqrt(xy_norm_sq)
    full_norm = torch.sqrt(xy_norm_sq + squares[:, 2])

    # Azimuth only needs the x-y direction, so x and y are normalised by the 2D norm;
    # z uses the 3D norm. A zero norm maps the component to 0 instead of dividing by it.
    ux = torch.where(xy_norm == 0, xy_norm, vectors[:, 0] / xy_norm)
    uy = torch.where(xy_norm == 0, xy_norm, vectors[:, 1] / xy_norm)
    uz = torch.where(full_norm == 0, full_norm, vectors[:, 2] / full_norm)

    # Clip solely to absorb floating-point error before arccos.
    ux, uy, uz = (torch.clip(c, -1, 1) for c in (ux, uy, uz))

    azimuth = torch.arccos(ux)
    # y < 0: move from quadrants 1/2 to quadrants 3/4.
    azimuth = torch.where(uy >= 0, azimuth, 2 * torch.pi - azimuth)
    azimuth = torch.where(torch.isfinite(azimuth), azimuth,
                          torch.tensor(0.0, dtype=azimuth.dtype, device=azimuth.device))

    zenith = torch.arccos(uz)
    # Zenith angles are not uniformly distributed, so the error case uses pi/2 rather than 0.
    zenith = torch.where(torch.isfinite(zenith), zenith,
                         torch.tensor(math.pi / 2, dtype=zenith.dtype, device=azimuth.device))

    return torch.stack([azimuth, zenith], dim=1)


def angular_dist_score(all_true, all_pred):
    """Angle between true and predicted directions given as (azimuth, zenith).

    Args:
        all_true, all_pred: tensors (N, 2) with columns [azimuth, zenith] in radians.

    Returns:
        (mean distance, per-event distances in [0, pi]).
    """
    az_t, zen_t = all_true[:, 0], all_true[:, 1]
    az_p, zen_p = all_pred[:, 0], all_pred[:, 1]

    # Dot product of the two unit vectors, written in spherical coordinates.
    azimuth_term = torch.cos(az_t) * torch.cos(az_p) + torch.sin(az_t) * torch.sin(az_p)
    dot = torch.sin(zen_t) * torch.sin(zen_p) * azimuth_term + (torch.cos(zen_t) * torch.cos(zen_p))

    distances = torch.abs(torch.arccos(torch.clip(dot, -1, 1)))
    return torch.mean(distances), distances

def angle_errors(n1, n2, eps=1e-8):
    """Mean angular errors between two sets of direction vectors.

    Args:
        n1, n2: tensors (B, 3); normalised here, so any length is accepted.
        eps: guard added to norms and used as the "no x-y component" threshold.

    Returns:
        (mean 3D angle, mean azimuth angle in the x-y plane, mean |zenith difference|).
        Rows where n1 has (almost) no x-y component get a random azimuth error drawn
        from [0, pi), so results are not deterministic for such rows.
    """
    u1 = n1 / (torch.linalg.vector_norm(n1, dim=1, keepdim=True) + eps)
    u2 = n2 / (torch.linalg.vector_norm(n2, dim=1, keepdim=True) + eps)

    total_err = torch.arccos((u1 * u2).sum(dim=1).clip(-1, 1))

    rho1 = u1[:, 0] * u1[:, 0] + u1[:, 1] * u1[:, 1]
    rho2 = u2[:, 0] * u2[:, 0] + u2[:, 1] * u2[:, 1]
    cos_xy = (u1[:, 0] * u2[:, 0] + u1[:, 1] * u2[:, 1]) / (torch.sqrt(rho1 * rho2) + eps)
    az_err = torch.arccos(cos_xy.clip(-1, 1))

    undefined = rho1 < eps   # azimuth of n1 is undefined
    az_err[undefined] = torch.rand((len(u1[undefined]),), dtype=u1.dtype, device=u1.device) * np.pi

    ze_err = torch.abs(torch.arccos(u2[:, 2].clip(-1, 1)) - torch.arccos(u1[:, 2].clip(-1, 1)))

    return total_err.mean(), az_err.mean(), ze_err.mean()

def angles_to_vectors_2D(angles):
    """Map 1D angles (N,) in radians to 2D unit vectors (N, 2) = [cos, sin]."""
    return torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)

def vectors_to_angles_2D(vectors, azimuth=False):
    """Convert 2D vectors (N, 2) to angles (N,).

    Without `azimuth` the result is arccos of the normalised first component, in
    [0, pi]. With `azimuth=True`, vectors with a negative second component are
    mirrored into (pi, 2*pi]. Zero vectors give pi/2 (arccos of 0); non-finite
    results fall back to 0 before the mirroring step.
    """
    norm = torch.sqrt(torch.sum(vectors.pow(2.0), axis=1))[:, None]
    unit = torch.clip(torch.where(norm == 0, norm, vectors / norm), -1, 1)

    result = torch.arccos(unit[:, 0])
    result = torch.where(torch.isfinite(result), result,
                         torch.tensor(0.0, dtype=result.dtype, device=result.device))
    if not azimuth:
        return result
    # [0, 2pi] range
    return torch.where(unit[:, 1] >= 0, result, 2 * torch.pi - result)

def angular_dist_score_2D(ang_true, ang_pred):
    """Angle between 2D directions given as angles in radians.

    Returns:
        (mean distance, per-element distances in [0, pi]).
    """
    cos_delta = torch.cos(ang_true) * torch.cos(ang_pred) + torch.sin(ang_true) * torch.sin(ang_pred)
    distances = torch.abs(torch.arccos(torch.clip(cos_delta, -1, 1)))
    return torch.mean(distances), distances

def get_rot_matrix(phi):
    """2D rotation matrix [[cos, -sin], [sin, cos]] for angle tensor `phi`.

    Dimensions of `phi` are appended after the two matrix axes, and a trailing
    size-1 axis is squeezed, so both a 0-d and a (1,)-shaped `phi` give (2, 2).
    """
    cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)
    entries = torch.stack([cos_phi, -sin_phi, sin_phi, cos_phi])
    return entries.reshape(2, 2, *phi.shape).squeeze(-1)

def build_az_ze_edges(bin_num):
    """Bin edges for the azimuth/zenith classification grid.

    Azimuth: `bin_num` equal-width bins over [0, 2*pi].
    Zenith: `bin_num` bins whose edges are equally spaced in cos(zenith) (steps of
    2 / bin_num), i.e. bands of equal solid angle.

    Returns:
        (azimuth_edges, zenith_edges): float32 tensors of length bin_num + 1 on `device`.
    """
    az_edges = torch.tensor(np.linspace(0, 2 * np.pi, bin_num + 1), dtype=torch.float32).to(device)

    ze_values = [0]
    for _ in range(bin_num - 1):
        # each interior edge lowers cos(zenith) by 2 / bin_num
        ze_values.append(np.arccos(np.cos(ze_values[-1]) - 2 / bin_num))
    ze_values.append(np.pi)
    ze_edges = torch.tensor(np.array(ze_values), dtype=torch.float32).to(device)

    return az_edges, ze_edges

def build_angle_bin_vector(azimuth_edges, zenith_edges, bin_num):
    """Mean direction vector of every azimuth/zenith bin.

    Bins are laid out azimuth-major: flat index = azimuth_bin * bin_num + zenith_bin,
    the same layout `angles_to_code` produces. Each entry is the unit vector
    averaged over the bin's spherical patch (integrated over solid angle), so its
    norm is below 1.

    Args:
        azimuth_edges, zenith_edges: 1D edges of length bin_num + 1, as numpy arrays
            or torch tensors on any device (`build_az_ze_edges` returns tensors on
            `device`, which may be CUDA).
        bin_num: number of bins per angle.

    Returns:
        float32 tensor of shape (1, bin_num * bin_num, 3) on `device`.
    """
    def _as_numpy(edges):
        # np.tile / np.repeat cannot consume CUDA tensors; move tensors to host memory.
        if torch.is_tensor(edges):
            return edges.detach().cpu().numpy()
        return edges

    azimuth_edges = _as_numpy(azimuth_edges)
    zenith_edges = _as_numpy(zenith_edges)

    angle_bin_zenith0 = np.tile(zenith_edges[:-1], bin_num)
    angle_bin_zenith1 = np.tile(zenith_edges[1:], bin_num)
    angle_bin_azimuth0 = np.repeat(azimuth_edges[:-1], bin_num)
    angle_bin_azimuth1 = np.repeat(azimuth_edges[1:], bin_num)

    # Solid angle of each bin.
    angle_bin_area = (angle_bin_azimuth1 - angle_bin_azimuth0) * (np.cos(angle_bin_zenith0) - np.cos(angle_bin_zenith1))
    # Integral of sin^2(zenith) over the zenith range, shared by the x and y components.
    zenith_sin2_integral = (angle_bin_zenith1 - angle_bin_zenith0) / 2 - (np.sin(2 * angle_bin_zenith1) - np.sin(2 * angle_bin_zenith0)) / 4
    angle_bin_vector_sum_x = (np.sin(angle_bin_azimuth1) - np.sin(angle_bin_azimuth0)) * zenith_sin2_integral
    angle_bin_vector_sum_y = (np.cos(angle_bin_azimuth0) - np.cos(angle_bin_azimuth1)) * zenith_sin2_integral
    angle_bin_vector_sum_z = (angle_bin_azimuth1 - angle_bin_azimuth0) * ((np.cos(2 * angle_bin_zenith0) - np.cos(2 * angle_bin_zenith1)) / 4)

    angle_bin_vector = np.zeros((1, bin_num * bin_num, 3))
    angle_bin_vector[:, :, 0] = angle_bin_vector_sum_x / angle_bin_area
    angle_bin_vector[:, :, 1] = angle_bin_vector_sum_y / angle_bin_area
    angle_bin_vector[:, :, 2] = angle_bin_vector_sum_z / angle_bin_area

    return torch.tensor(angle_bin_vector, dtype=torch.float32).to(device)

def angles_to_code(angles,bin_num,azimuth_edges,zenith_edges):
    """Map (azimuth, zenith) angles (N, 2) to flat class ids on the bin_num x bin_num grid.

    A bin index is the number of upper edges (edges[1:]) strictly below the angle,
    so the final edge never increments it. Flat id = azimuth_bin * bin_num + zenith_bin.
    """
    upper_az = azimuth_edges[1:].reshape((-1, 1))
    upper_ze = zenith_edges[1:].reshape((-1, 1))
    az_bin = (upper_az < angles[:, 0]).sum(dim=0)
    ze_bin = (upper_ze < angles[:, 1]).sum(dim=0)
    return az_bin * bin_num + ze_bin

def code_to_vector(pred, bin_num, angle_bin_vector, max=False, epsilon=1e-8):
    """Turn per-bin scores into unit direction vectors.

    Args:
        pred: (N, bin_num * bin_num) scores per bin.
        bin_num: bins per angle.
        angle_bin_vector: (1, bin_num * bin_num, 3) mean direction of every bin.
        max: if True use the direction of the arg-max bin; otherwise the
            score-weighted sum of all bin directions.
        epsilon: vectors shorter than this are treated as undefined.

    Returns:
        (N, 3) unit vectors; undefined (badly predicted) ones become <1, 0, 0>.
    """
    if max:
        raw = angle_bin_vector[0, pred.argmax(axis=1)]
    else:
        weights = pred.reshape((-1, bin_num * bin_num, 1))
        raw = (weights * angle_bin_vector).sum(axis=1)

    lengths = torch.sqrt((raw ** 2).sum(axis=1))
    too_short = lengths < epsilon
    lengths = torch.where(too_short, torch.ones_like(lengths), lengths)
    unit = raw / lengths.reshape((-1, 1))
    unit[too_short] = torch.tensor([1., 0., 0.], dtype=pred.dtype, device=pred.device)
    return unit

def code_to_angle(pred, bin_num, angle_bin_vector, max=False, epsilon=1e-8):
    """Turn per-bin scores into (azimuth, zenith) angles of shape (N, 2).

    The direction comes from `code_to_vector` (same arguments). Negative atan2
    azimuths are shifted by 2*pi, so azimuth is non-negative.
    """
    direction = code_to_vector(pred, bin_num, angle_bin_vector, max, epsilon)

    azimuth = torch.arctan2(direction[:, 1], direction[:, 0])
    azimuth = torch.where(azimuth < 0, azimuth + 2 * torch.pi, azimuth)
    zenith = torch.arccos(direction[:, 2])

    return torch.stack([azimuth, zenith], dim=1)

class Lion(Optimizer):
  r"""Lion optimizer (sign-based momentum).

  Per step and parameter: apply decoupled weight decay p *= 1 - lr * wd, move
  p by -lr * sign(beta1 * m + (1 - beta1) * g), then update the momentum
  m <- beta2 * m + (1 - beta2) * g.
  """

  def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
    """Validate and store the hyperparameters.

    Args:
      params (iterable): parameters to optimize, or dicts defining parameter groups.
      lr (float, optional): learning rate (default: 1e-4).
      betas (Tuple[float, float], optional): interpolation factor for the update
        direction and decay of the momentum average (default: (0.9, 0.99)).
      weight_decay (float, optional): decoupled weight decay coefficient (default: 0).

    Raises:
      ValueError: if lr is not >= 0 or a beta is outside [0, 1).
    """
    if not 0.0 <= lr:
      raise ValueError('Invalid learning rate: {}'.format(lr))
    for index in (0, 1):
      beta = betas[index]
      if not 0.0 <= beta < 1.0:
        raise ValueError(f'Invalid beta parameter at index {index}: {beta}')
    super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay))

  @torch.no_grad()
  def step(self, closure=None):
    """Perform a single optimization step.

    Args:
      closure (callable, optional): re-evaluates the model and returns the loss.
    Returns:
      the closure's loss, or None without a closure.
    """
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      lr = group['lr']
      for param in (p for p in group['params'] if p.grad is not None):
        # decoupled (step) weight decay
        param.data.mul_(1 - lr * group['weight_decay'])

        grad = param.grad
        state = self.state[param]
        if not state:
          # lazily create the momentum buffer
          state['exp_avg'] = torch.zeros_like(param)
        exp_avg = state['exp_avg']
        beta1, beta2 = group['betas']

        direction = exp_avg * beta1 + grad * (1 - beta1)
        param.add_(torch.sign(direction), alpha=-lr)
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

    return loss

class PiecewiseLinearLR(_LRScheduler):
    """Interpolate learning rate linearly between milestones."""

    def __init__(
        self,
        optimizer: Optimizer,
        milestones: List[int],
        factors: List[float],
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        """Construct `PiecewiseLinearLR`.

        For each milestone, denoting a specified number of steps, a factor
        multiplying the base learning rate is specified. For steps between two
        milestones, the learning rate is interpolated linearly between the two
        closest milestones. For steps before the first milestone, the factor
        for the first milestone is used; vice versa for steps after the last
        milestone.

        Args:
            optimizer: Wrapped optimizer.
            milestones: List of step indices. Must be increasing.
            factors: List of multiplicative factors. Must be same length as
                `milestones`.
            last_epoch: The index of the last epoch.
            verbose: If ``True``, prints a message to stdout for each update.
        """
        # Check(s)
        if milestones != sorted(milestones):
            raise ValueError("Milestones must be increasing")
        if len(milestones) != len(factors):
            raise ValueError(
                "Only multiplicative factor must be specified for each milestone."
            )

        # Must be set before the base constructor, which already calls get_lr().
        self.milestones = milestones
        self.factors = factors
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            # Newer PyTorch versions deprecate `verbose`; only forward it when requested.
            super().__init__(optimizer, last_epoch)

    def _get_factor(self) -> np.ndarray:
        # Linearly interpolate multiplicative factor between milestones.
        return np.interp(self.last_epoch, self.milestones, self.factors)

    def get_lr(self) -> List[float]:
        """Get effective learning rate(s) for each optimizer."""
        # `warnings` is not imported at notebook level; without this, a direct call
        # (outside scheduler.step()) raised NameError instead of warning.
        import warnings

        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        return [base_lr * self._get_factor() for base_lr in self.base_lrs]


In [None]:
#@title Conv

"""Class(es) implementing layers to be used in `graphnet` models."""

from typing import Any, Callable, Optional, Sequence, Union

from torch.functional import Tensor
from torch_geometric.nn import EdgeConv, TransformerConv
from torch_geometric.nn.pool import knn_graph
from torch_geometric.typing import Adj
from pytorch_lightning import LightningModule

class DynEdgeConv(EdgeConv, LightningModule):
    """Dynamical edge convolution layer.

    Applies a standard `EdgeConv` and then, unless `last_layer` is set,
    recomputes the k-nearest-neighbour graph from the updated node features.
    """

    def __init__(
        self,
        nn: Callable,
        aggr: str = "max",
        nb_neighbors: int = 8,
        features_subset: Optional[Union[Sequence[int], slice]] = None,
        last_layer: bool = False,
        **kwargs: Any,
    ):
        """Construct `DynEdgeConv`.

        Args:
            nn: The MLP/torch.Module to be used within the `EdgeConv`.
            aggr: Aggregation method to be used with `EdgeConv`.
            nb_neighbors: Number of neighbours to be clustered after the
                `EdgeConv` operation.
            features_subset: Subset of features in `Data.x` that should be used
                when dynamically performing the new graph clustering after the
                `EdgeConv` operation. Defaults to all features. Must be a list
                or a slice; other sequences such as tuples fail the check below.
            last_layer: If True, the k-NN graph is not recomputed after the
                convolution and `forward` returns the input `edge_index`.
            **kwargs: Additional features to be passed to `EdgeConv`.
        """
        # Check(s)
        if features_subset is None:
            features_subset = slice(None)  # Use all features
        assert isinstance(features_subset, (list, slice))

        # Base class constructor
        super().__init__(nn=nn, aggr=aggr, **kwargs)

        # Additional member variables
        self.nb_neighbors = nb_neighbors
        self.features_subset = features_subset
        self.last_layer = last_layer

    def forward(
        self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None
    ) -> "Tuple[Tensor, Adj]":
        """Forward pass.

        Returns:
            `(x, edge_index)`: the convolved node features and the adjacency for
            the next layer (the recomputed k-NN graph, or the input `edge_index`
            when `last_layer` is set).
        """
        # Standard EdgeConv forward pass
        x = super().forward(x, edge_index)

        if not self.last_layer:   # a new graph after the last layer would be unused
            # Recompute adjacency
            edge_index = knn_graph(
                x=x[:,self.features_subset],
                k=self.nb_neighbors,
                batch=batch,
            ).to(self.device)

        return x, edge_index


class DynTransformerConv(TransformerConv, LightningModule):
    """Dynamical transformer convolution layer.

    Applies a standard `TransformerConv` (PyG graph transformer operator) and
    then, unless `last_layer` is set, recomputes the k-nearest-neighbour graph
    from the updated node features.
    """

    def __init__(
        self, 
        in_channels, 
        out_channels: int, 
        nb_neighbors: int = 8,
        features_subset: Optional[Union[Sequence[int], slice]] = None,
        last_layer: bool = False,
        heads: int = 1, 
        concat: bool = True, 
        beta: bool = False, 
        dropout: float = 0.0, 
        edge_dim: Optional[int] = None, 
        bias: bool = True, 
        root_weight: bool = True, **kwargs,
    ):
        """Construct `DynTransformerConv`.

        Args:
            in_channels, out_channels, heads, concat, beta, dropout, edge_dim,
            bias, root_weight, **kwargs: Passed unchanged to `TransformerConv`.
                NOTE(review): `forward` never passes `edge_attr`, so a non-None
                `edge_dim` presumably cannot be used with this layer — confirm.
            nb_neighbors: Number of neighbours used when recomputing the k-NN
                graph after the convolution.
            features_subset: Output feature columns used for the k-NN search;
                defaults to all. Must be a list or a slice.
            last_layer: If True, skip the k-NN recomputation and return the
                input `edge_index` from `forward`.
        """
        # Check(s)
        if features_subset is None:
            features_subset = slice(None)  # Use all features
        assert isinstance(features_subset, (list, slice))

        # Base class constructor
        super().__init__(in_channels=in_channels, out_channels=out_channels, 
                         heads=heads, concat=concat, beta=beta, dropout=dropout, edge_dim=edge_dim, bias=bias, 
                         root_weight=root_weight, **kwargs)

        # Additional member variables
        self.nb_neighbors = nb_neighbors
        self.features_subset = features_subset
        self.last_layer = last_layer

    def forward(
        self, x: Tensor, edge_index: Adj, batch: Optional[Tensor] = None
    ) -> "Tuple[Tensor, Adj]":
        """Forward pass.

        Returns:
            `(x, edge_index)`: the convolved node features and the adjacency for
            the next layer (the recomputed k-NN graph, or the input `edge_index`
            when `last_layer` is set).
        """
        # Standard TransformerConv forward pass
        x = super().forward(x, edge_index)

        if not self.last_layer:   # a new graph after the last layer would be unused
            # Recompute adjacency
            edge_index = knn_graph(
                x=x[:,self.features_subset],
                k=self.nb_neighbors,
                batch=batch,
            ).to(self.device)

        return x, edge_index


In [None]:
#@title DynEdge

"""Implementation of the DynEdge GNN model architecture."""
from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor, LongTensor
from torch_geometric.data import Data
from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum

# from graphnet.models.components.layers import DynEdgeConv  # QuData: overridden by the local DynEdgeConv class defined in the Conv cell
from graphnet.utilities.config import save_model_config
from graphnet.models.gnn.gnn import GNN
from torch_geometric.utils.homophily import homophily
from torch_geometric.nn.pool import TopKPooling
from torch_geometric.utils import to_dense_batch

# Name -> torch_scatter reduction used to pool node features per graph.
GLOBAL_POOLINGS = dict(
    min=scatter_min,
    max=scatter_max,
    sum=scatter_sum,
    mean=scatter_mean,
)

def calculate_xyzt_homophily(
    x: Tensor, edge_index: LongTensor, batch: Tensor
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Calculate xyzt-homophily from a batch of graphs.

    Homophily is a graph scalar quantity that measures the likeness of
    variables in nodes. Notice that this calculator assumes a special order of
    input features in x: columns 0, 1, 2 are used as x, y, z and column -3 as t.
    With this notebook's collate_fn, x ends with [time, charge, auxiliary], so
    x[:, -3] is the pulse time (column position counted from the end so it
    survives a change in the number of sensor features); columns 0-2 are the
    first three sensor features (presumably x, y, z — confirm `sensor_features`).

    Args:
        x: node feature matrix.
        edge_index: graph connectivity passed to `homophily`.
        batch: per-node graph index vector (a tensor, not a `Batch` object),
            passed to `homophily` to get one value per graph.

    Returns:
        Tuple (hx, hy, hz, ht), each element with shape [batch_size,1].
    """
    hx, hy, hz, ht = (
        homophily(edge_index, x[:, column], batch).reshape(-1, 1)
        for column in (0, 1, 2, -3)
    )
    return hx, hy, hz, ht


class DynEdge(GNN):
    """DynEdge (dynamical edge convolutional) model.

    Pipeline, as implemented in `forward`:

    1. Optional per-sensor affine embedding of the first four node features.
    2. Per-event global variables: feature means, xyzt homophily and
       log10 of the pulse count.
    3. A stack of dynamic graph convolutions whose outputs are
       skip-concatenated.
    4. A post-processing MLP and/or transformer.
    5. Optional local pooling.
    6. Global pooling.
    7. A read-out MLP.
    """

    # Layer configuration used when `dynedge_layers` is None. It reproduces the
    # documented DynEdge default sizes [(128, 256), (336, 256), (336, 256), (336, 256)].
    _DEFAULT_DYNEDGE_LAYERS = (
        {"type": "DynEdgeConv", "sizes": (128, 256), "nb_out": 256},
        {"type": "DynEdgeConv", "sizes": (336, 256), "nb_out": 256},
        {"type": "DynEdgeConv", "sizes": (336, 256), "nb_out": 256},
        {"type": "DynEdgeConv", "sizes": (336, 256), "nb_out": 256},
    )

    # Global pooling types understood by `_global_pooling`.
    _GLOBAL_POOLING_TYPES = ("simple", "GRU")

    @save_model_config
    def __init__(
        self,
        nb_inputs: int,
        *,
        nb_neighbours: int = 8,
        features_subset: Optional[Union[List[int], slice]] = None,
        dynedge_layers = None,
        post_processing_layer_sizes: Optional[List[int]] = None,
        post_processing_transformer: Optional[dict] = None,
        readout_layer_sizes: Optional[List[int]] = None,
        global_pooling: Optional[dict] = None,
        add_global_variables_after_pooling: bool = False,
        sensor_embedding = False,
        local_pooling = None,
    ):
        """Construct `DynEdge`.

        Args:
            nb_inputs: Number of input features on each node.
            nb_neighbours: Number of neighbours used in the k-nearest
                neighbour graph that is recomputed after each (dynamical)
                convolution.
            features_subset: The subset of latent features on each node that
                are used as metric dimensions when performing the k-nearest
                neighbours clustering. Defaults to ``slice(0, 3)``.
            dynedge_layers: List of per-layer configuration dicts with a
                ``'type'`` key, one of:

                - ``'DynEdgeConv'``: needs ``'sizes'``, the layer sizes of
                  the MLP inside the layer. The last size is the layer's
                  output width.
                - ``'TransformerConv'``: needs ``'nb_out'`` and ``'heads'``.

                Defaults to four ``DynEdgeConv`` layers with sizes
                (128, 256), (336, 256), (336, 256), (336, 256).
            post_processing_layer_sizes: Hidden layer sizes of the MLP applied
                to the skip-concatenation of the input and every convolution
                output. ``None`` or empty means no MLP.
            post_processing_transformer: Optional dict with ``d_model``,
                ``nhead``, ``dim_feedforward`` and ``num_layers``. It builds a
                ``torch.nn.TransformerEncoder`` appended after the MLP.
            readout_layer_sizes: Hidden layer sizes in the MLP following the
                post-processing and global pooling. The last size is the
                model's output width. Defaults to [128].
            global_pooling: Dict with a ``'type'`` (``'simple'`` or ``'GRU'``)
                and ``'nb_out'`` (the pooled width fed to the read-out).

                - ``'simple'`` also needs ``'schemes'`` (any of "min", "max",
                  "mean", "sum").
                - ``'GRU'`` needs ``'nb_in'`` and optionally
                  ``'bidirectional'`` (default True).

                Defaults to simple min/max/mean pooling with ``nb_out=768``.
            add_global_variables_after_pooling: Whether to add global variables
                after global pooling. The alternative is to add (distribute)
                them to the individual nodes before any convolutional
                operations.
            sensor_embedding: If true, learn a per-sensor (5160 sensors) scale
                and offset applied to the first four node features.
            local_pooling: Optional dict. ``{"type": "TopKPooling", "k": ...}``
                applies `TopKPooling` after post-processing.

        Raises:
            ValueError: If ``global_pooling['type']`` is not supported.
        """
        # Latent feature subset for computing nearest neighbours in DynEdge.
        if features_subset is None:
            features_subset = slice(0, 3)

        if dynedge_layers is None:
            dynedge_layers = [dict(conf) for conf in self._DEFAULT_DYNEDGE_LAYERS]

        self._dynedge_layers = dynedge_layers
        self._post_processing_layer_sizes = post_processing_layer_sizes
        self._post_processing_transformer = post_processing_transformer
        self._local_pooling_conf = local_pooling

        # Read-out layer sizes
        if readout_layer_sizes is None:
            readout_layer_sizes = [
                128,
            ]

        assert isinstance(readout_layer_sizes, list)
        assert len(readout_layer_sizes)
        assert all(size > 0 for size in readout_layer_sizes)

        self._readout_layer_sizes = readout_layer_sizes

        # Global pooling scheme(s)
        if global_pooling is None:
            global_pooling = {"type": "simple", "schemes": ["min","max","mean"], "nb_out": 768}
        if global_pooling["type"] not in self._GLOBAL_POOLING_TYPES:
            raise ValueError(
                f"Unknown global pooling type {global_pooling['type']!r}; "
                f"expected one of {self._GLOBAL_POOLING_TYPES}"
            )

        self._global_pooling_conf = global_pooling
        self._global_pooling_model = None

        self._add_global_variables_after_pooling = (
            add_global_variables_after_pooling
        )

        # Base class constructor
        super().__init__(nb_inputs, self._readout_layer_sizes[-1])

        # Remaining member variables
        self._activation = torch.nn.LeakyReLU()
        self._nb_inputs = nb_inputs
        # Node feature means (nb_inputs) + 4 homophilies + log10(n_pulses).
        self._nb_global_variables = 5 + nb_inputs
        self._nb_neighbours = nb_neighbours
        self._features_subset = features_subset

        self._sensor_embed = None
        if sensor_embedding:
            # Columns 0-3 start at 1 (scales) and columns 4-7 at 0 (offsets),
            # so the embedding is the identity transform at initialisation.
            sensor_embed_init_weights = torch.zeros((5160, 8), dtype=torch.float32, device=device)
            sensor_embed_init_weights[:, :4] = 1.0
            self._sensor_embed = torch.nn.Embedding(5160, 8, _weight=sensor_embed_init_weights)

        # Sub-module construction order is kept fixed: it determines how the
        # RNG is consumed by parameter initialisation.
        self._construct_layers()

        self._local_pooling = None
        if self._local_pooling_conf:
            if self._local_pooling_conf["type"] == "TopKPooling":
                self._local_pooling = TopKPooling(self._post_processing_layer_sizes[-1], self._local_pooling_conf["k"])

        if self._global_pooling_conf["type"] == "GRU":
            nb_pool_out = self._global_pooling_conf["nb_out"]
            bidirectional = self._global_pooling_conf.get("bidirectional", True)
            # A bidirectional GRU concatenates both directions, so each
            # direction gets half of the pooled width. A unidirectional one
            # must produce the full width expected by the read-out.
            hidden_size = nb_pool_out // 2 if bidirectional else nb_pool_out
            self._global_pooling_model = torch.nn.GRU(
                input_size=self._global_pooling_conf["nb_in"],
                hidden_size=hidden_size,
                batch_first=True,
                bidirectional=bidirectional,
            )

    def build_dyn_edge_conv_layer(self, conf, nb_latent_features, last_layer):
        """Build a `DynEdgeConv` whose inner MLP has layer sizes ``conf['sizes']``.

        The first linear layer takes ``2 * nb_latent_features`` inputs, since
        each edge feature combines a node with its neighbour.

        Returns:
            ``(layer, nb_out)``, where ``nb_out`` is the last MLP width.

        Raises:
            ValueError: If ``conf['sizes']`` is empty.
        """
        sizes = conf['sizes']
        if not sizes:
            raise ValueError("DynEdgeConv layer config needs a non-empty 'sizes'")
        layers = []
        layer_sizes = [nb_latent_features] + list(sizes)
        for ix, (nb_in, nb_out) in enumerate(
            zip(layer_sizes[:-1], layer_sizes[1:])
        ):
            if ix == 0:
                nb_in *= 2
            layers.append(torch.nn.Linear(nb_in, nb_out))
            layers.append(self._activation)

        conv_layer = DynEdgeConv(
                torch.nn.Sequential(*layers),
                aggr="add",
                nb_neighbors=self._nb_neighbours,
                features_subset=self._features_subset,
                last_layer=last_layer
            )
        return conv_layer, nb_out

    def build_transformer_conv_layer(self, conf, nb_latent_features, last_layer):
        """Build a `DynTransformerConv`. Returns ``(layer, conf['nb_out'])``."""
        conv_layer = DynTransformerConv(
                in_channels=nb_latent_features,
                out_channels=conf['nb_out'],
                heads=conf['heads'],
                nb_neighbors=self._nb_neighbours,
                features_subset=self._features_subset,
                last_layer=last_layer
            )
        return conv_layer, conf['nb_out']

    def build_layer(self, layer_conf, nb_latent_features, last_layer):
        """Build one convolution layer from its config dict.

        Returns:
            ``(layer, nb_out)``.

        Raises:
            ValueError: If ``layer_conf['type']`` is unknown.
        """
        builders = {
            'DynEdgeConv': self.build_dyn_edge_conv_layer,
            'TransformerConv': self.build_transformer_conv_layer,
        }
        layer_type = layer_conf['type']
        if layer_type not in builders:
            raise ValueError(
                f"Unknown conv layer type {layer_type!r}; "
                f"expected one of {sorted(builders)}"
            )
        return builders[layer_type](layer_conf, nb_latent_features, last_layer)

    def _construct_layers(self) -> None:
        """Construct layers (torch.nn.Modules)."""
        # Convolutional operations
        nb_input_features = self._nb_inputs
        if not self._add_global_variables_after_pooling:
            nb_input_features += self._nb_global_variables

        self._conv_layers = torch.nn.ModuleList()
        nb_latent_features = nb_input_features
        # Width of the skip-concatenation: the input plus every conv output,
        # using the width each layer actually produces.
        nb_skip_features = nb_input_features
        nb_layers = len(self._dynedge_layers)
        for layer_idx, layer_conf in enumerate(self._dynedge_layers):
            # Flag the final conv layer so it can skip recomputing a k-NN
            # graph that nothing downstream would use.
            last_layer = layer_idx == nb_layers - 1
            conv_layer, nb_out = self.build_layer(layer_conf, nb_latent_features, last_layer)
            self._conv_layers.append(conv_layer)
            nb_latent_features = nb_out
            nb_skip_features += nb_out

        # Post-processing operations
        nb_latent_features = nb_skip_features

        post_processing_layers = []
        if self._post_processing_layer_sizes:
            layer_sizes = [nb_latent_features] + list(
                self._post_processing_layer_sizes
            )
            for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
                post_processing_layers.append(torch.nn.Linear(nb_in, nb_out))
                post_processing_layers.append(self._activation)

        if self._post_processing_transformer:
            encoder_layer = torch.nn.TransformerEncoderLayer(
                d_model=self._post_processing_transformer["d_model"],
                nhead=self._post_processing_transformer["nhead"],
                dim_feedforward=self._post_processing_transformer["dim_feedforward"],
            )
            post_processing_layers.append(torch.nn.TransformerEncoder(encoder_layer, self._post_processing_transformer["num_layers"]))

        self._post_processing = torch.nn.Sequential(*post_processing_layers)

        # Read-out operations
        nb_latent_features = self._global_pooling_conf["nb_out"]

        if self._add_global_variables_after_pooling:
            nb_latent_features += self._nb_global_variables

        readout_layers = []
        layer_sizes = [nb_latent_features] + list(self._readout_layer_sizes)
        for nb_in, nb_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            readout_layers.append(torch.nn.Linear(nb_in, nb_out))
            readout_layers.append(self._activation)

        self._readout = torch.nn.Sequential(*readout_layers)

    def _global_pooling_simple(self, x: Tensor, batch: LongTensor) -> Tensor:
        """Concatenate the configured scatter reductions of ``x`` per graph."""
        pooled = []
        for pooling_scheme in self._global_pooling_conf["schemes"]:
            pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
            pooled_x = pooling_fn(x, index=batch, dim=0)
            if isinstance(pooled_x, tuple) and len(pooled_x) == 2:
                # `scatter_{min,max}` also return the arg-indices, unlike
                # `scatter_{mean,sum}`.
                pooled_x, _ = pooled_x
            pooled.append(pooled_x)
        pooled = torch.cat(pooled, dim=1)
        return pooled

    def _global_pooling_gru(self, x: Tensor, batch: LongTensor) -> Tensor:
        """Run the GRU over each event's nodes and return its last output step.

        NOTE(review): `to_dense_batch` pads events to the longest one in the
        batch, so for shorter events the last step is a padding position.
        Using the mask or packed sequences would avoid this, but would change
        results of trained models.
        """
        x, _ = to_dense_batch(x, batch)
        pooled = self._global_pooling_model(x)[0][:, -1]
        return pooled

    def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor:
        """Dispatch to the configured global pooling.

        Raises:
            ValueError: If the configured type is unknown.
        """
        pooling_type = self._global_pooling_conf["type"]
        if pooling_type == "simple":
            return self._global_pooling_simple(x, batch)
        if pooling_type == "GRU":
            return self._global_pooling_gru(x, batch)
        raise ValueError(f"Unknown global pooling type {pooling_type!r}")

    def _calculate_global_variables(
        self,
        x: Tensor,
        edge_index: LongTensor,
        batch: LongTensor,
        *additional_attributes: Tensor,
    ) -> Tensor:
        """Calculate per-graph global variables.

        The columns are the per-graph feature means, the xyzt homophilies and
        each entry of ``additional_attributes``, each added as one column.
        """
        # Calculate homophily (scalar variables)
        h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)

        # Calculate mean features
        global_means = scatter_mean(x, batch, dim=0)

        # Add global variables
        global_variables = torch.cat(
            [
                global_means,
                h_x,
                h_y,
                h_z,
                h_t,
            ]
            + [attr.unsqueeze(dim=1) for attr in additional_attributes],
            dim=1,
        )

        return global_variables

    def forward(self, data: Data) -> Tensor:
        """Apply learnable forward pass.

        When ``CFG.FROZEN`` is set, autograd is globally disabled on entry. It
        is re-enabled, in training mode only:

        - from the first conv layer whose index is in ``CFG.UNFOROZEN_LAYERS``;
        - before the read-out.

        Note that the sensor embedding modifies ``data.x`` in place.
        """
        if CFG.FROZEN: torch.set_grad_enabled(False)

        # Convenience variables
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Sensor embeddings: per-sensor affine transform of columns 0-3.
        if self._sensor_embed is not None:
            s_emb = self._sensor_embed(data.sensor_id)
            x[:, :4] = x[:, :4] * s_emb[:, :4] + s_emb[:, 4:]  # x,y,z,t * wx,wy,wz,wt + ax,ay,az,at

        global_variables = self._calculate_global_variables(
            x,
            edge_index,
            batch,
            torch.log10(data.n_pulses),
        )

        # Distribute global variables out to each node
        if not self._add_global_variables_after_pooling:
            # Row g of `global_variables` describes graph g, so gathering by
            # the node->graph index gives every node its event's variables.
            # This replaces a one-hot [nodes x graphs x features] product that
            # gave the same result for PyG's contiguous graph ids 0..B-1, at a
            # fraction of the memory.
            x = torch.cat((x, global_variables[batch]), dim=1)

        # DynEdge-convolutions
        skip_connections = [x]
        for conv_layer_idx, conv_layer in enumerate(self._conv_layers):
            if CFG.FROZEN and self.training and conv_layer_idx in CFG.UNFOROZEN_LAYERS: torch.set_grad_enabled(True)
            x, edge_index = conv_layer(x, edge_index, batch)
            skip_connections.append(x)

        # Skip-cat
        x = torch.cat(skip_connections, dim=1)

        # Post-processing
        x = self._post_processing(x)

        if self._local_pooling is not None:
            x, edge_index, _, batch, _, _ = self._local_pooling(x, edge_index, None, batch)

        # (Optional) Global pooling
        if self._global_pooling_conf:
            x = self._global_pooling(x, batch=batch)
            if self._add_global_variables_after_pooling:
                x = torch.cat(
                    [
                        x,
                        global_variables,
                    ],
                    dim=1,
                )

        if CFG.FROZEN and self.training: torch.set_grad_enabled(True)

        # Read-out
        x = self._readout(x)

        return x


In [None]:
#@title GraphNet

from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import StandardModel
from graphnet.models.detector.icecube import IceCubeKaggle
#from graphnet.models.gnn import DynEdge # QuData: owerwrite
from graphnet.models.graph_builders import KNNGraphBuilder
from graphnet.models.task.reconstruction import DirectionReconstructionWithKappa, ZenithReconstructionWithKappa, AzimuthReconstructionWithKappa
from graphnet.training.callbacks import ProgressBar
from graphnet.training.loss_functions import VonMisesFisher3DLoss, VonMisesFisher2DLoss
from graphnet.training.labels import Direction
from graphnet.utilities.logging import get_logger
from graphnet.models.graph_builders import GraphBuilder
from graphnet.models.detector.detector import Detector

from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import ModuleList
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch_geometric.data import Data

from graphnet.models.coarsening import Coarsening
from graphnet.utilities.config import save_model_config
from graphnet.models.detector.detector import Detector
from graphnet.models.gnn.gnn import GNN
from graphnet.models.model import Model
from graphnet.models.task import Task
from torch.nn.functional import cosine_similarity, normalize
from graphnet.training.loss_functions import LossFunction
from graphnet.utilities.maths import eps_like


class CosineLoss(LossFunction):
    """Per-event cosine distance between predicted and true 3D directions.

    The loss is ``1 - cos(angle)`` between the first three prediction columns
    and the target reshaped to ``[-1, 3]``. It is 0 for a perfect prediction
    and 2 for an antiparallel one. Prediction columns beyond the third are
    ignored.
    """

    def _forward(self, preds, target):
        predicted_dirs = preds[:, :3]
        true_dirs = target.reshape(-1, 3)
        similarity = cosine_similarity(predicted_dirs, true_dirs, dim=1, eps=1e-8)
        return 1 - similarity

class CosineLoss2D(LossFunction):
    """Per-event cosine distance between predicted and true 2D vectors.

    Same as `CosineLoss`, but on the first two prediction columns against the
    target reshaped to ``[-1, 2]``.
    """

    def _forward(self, preds, target):
        predicted_vecs = preds[:, :2]
        true_vecs = target.reshape(-1, 2)
        similarity = cosine_similarity(predicted_vecs, true_vecs, dim=1, eps=1e-8)
        return 1 - similarity

class VMFCustomLoss(LossFunction):
    """3D von Mises-Fisher negative log-likelihood, up to an additive constant.

    Column 3 of ``preds`` is used as the concentration kappa, and columns 0-2
    are scaled by it. This assumes columns 0-2 are a unit direction and
    kappa > 0 (TODO confirm against the producing task).

    The log-normaliser is computed two ways:

    - for kappa < ``kappa0``: exactly, as ``log(kappa / (e^kappa - e^-kappa))``;
    - otherwise: with the large-kappa form ``log(kappa) - kappa``, which
      avoids overflowing ``exp``.
    """

    def _forward(self, preds, target, eps = 1e-8, kappa0=10):
        unit_target = target.reshape(-1, 3)
        concentration = preds[:, 3]
        scaled_dir = preds[:, :3] * concentration.unsqueeze(1)

        # Start from the large-kappa approximation everywhere...
        log_norm = -concentration + torch.log(concentration + eps)
        # ...then overwrite the small-kappa entries with the exact expression.
        small = concentration < kappa0
        k_small = concentration[small]
        log_norm[small] = torch.log(k_small / (torch.exp(k_small) - torch.exp(-k_small) + eps))

        alignment = (scaled_dir * unit_target).sum(axis=1)
        return -alignment - log_norm

class VMFCustomLoss2(LossFunction):
    """3D von Mises-Fisher negative log-likelihood, up to an additive constant.

    Column 3 of ``preds`` is the concentration kappa, and columns 0-2 are
    scaled by it. This assumes columns 0-2 are a unit direction and kappa > 0
    (TODO confirm against the producing task).

    The log-normaliser is
    ``log(kappa / (1 - exp(-2 kappa))) - kappa``, which equals
    ``log(kappa / (e^kappa - e^-kappa))`` and stays finite for large kappa.
    """

    def _forward(self, preds, target, eps = 1e-8, kappa0=10):
        # `kappa0` is unused; it keeps the signature parallel to VMFCustomLoss.
        target = target.reshape(-1, 3)
        kappa = preds[:, 3]
        scaled_dir = preds[:, :3] * kappa.unsqueeze(1)
        logC = torch.log(kappa / (1 - torch.exp(-2 * kappa)) + eps) - kappa
        # NLL = -(kappa * mu . t + log C). The previous version referenced the
        # undefined names `n_true` / `n_pred` here and raised NameError.
        return -((scaled_dir * target).sum(axis=1) + logC)

class IceCubeCustom(Detector):
    """`Detector` for the Kaggle IceCube data with a caller-supplied feature list.

    Preprocessing happens in place on ``data.x`` and assumes this column
    layout:

    - columns 0-2 are the x/y/z positions;
    - the third-from-last column is time;
    - the second-from-last column is charge.
    """

    def __init__(
        self, graph_builder: GraphBuilder, scalers: List[dict] = None, features = None
    ):
        # Stored before the base constructor runs, in case it already reads
        # `features` (presumably it does; confirm in graphnet's Detector).
        self._features = features
        super().__init__(graph_builder, scalers)

    @property
    def features(self) -> List[str]:
        """Feature names passed to the constructor."""
        return self._features

    def _forward(self, data: Data) -> Data:
        """Normalise node features in place and return the same ``data``.

        - x, y, z are divided by 500.
        - time is shifted by 1e4 and divided by 3e4.
        - charge is mapped to log10(charge) / 3.

        Feature validation (`_validate_features`) is disabled in this
        notebook.
        """
        node_features = data.x
        node_features[:, :3] /= 500.0                                   # x, y, z
        node_features[:, -3] = (node_features[:, -3] - 1.0e04) / 3.0e4  # time
        node_features[:, -2] = torch.log10(node_features[:, -2]) / 3.0  # charge
        return data

class DirectionReconstructionWithKappa2D(Task):
    """Reconstructs a 2D (x, y) direction together with a kappa concentration.

    The two task inputs are treated as an unnormalised vector. Its norm (plus
    a dtype epsilon) is the kappa estimate, and the vector divided by that
    norm is the predicted direction. Output columns are ``[x, y, kappa]``.
    """

    # Two inputs: the unnormalised (x, y) vector.
    nb_inputs = 2

    def _forward(self, x: Tensor) -> Tensor:
        norm = torch.linalg.vector_norm(x, dim=1) + eps_like(x)
        unit_x, unit_y = x[:, 0] / norm, x[:, 1] / norm
        return torch.stack((unit_x, unit_y, norm), dim=1)

class DirectionReconstructionWithBins(Task):
    """Direction classification over a grid of (azimuth, zenith) bins.

    `_forward` is the identity: the task emits one score per bin, which is
    fed to a classification loss (``CrossEntropyLoss`` in this notebook).
    `nb_inputs` defaults to 24 * 24 = 576 bins; `build_model` overwrites it
    with ``bin_num ** 2``.
    """

    # Number of direction bins (one score per bin).
    nb_inputs = 576

    def _forward(self, x: Tensor) -> Tensor:
        # Identity: the per-bin scores are the prediction.
        return x

"""Standard model class(es)."""
class StandardCustomModel(Model):
    """Main class for standard models in graphnet.

    This class chains together the different elements of a complete GNN-based
    model (detector read-in, GNN architecture, and task-specific read-outs).
    """

    @save_model_config
    def __init__(
        self,
        *,
        detector: Detector,
        gnn: GNN,
        tasks: Union[Task, List[Task]],
        coarsening: Optional[Coarsening] = None,
        optimizer_class: type = Adam,
        optimizer_kwargs: Optional[Dict] = None,
        scheduler_class: Optional[type] = None,
        scheduler_kwargs: Optional[Dict] = None,
        scheduler_config: Optional[Dict] = None     
    ) -> None:
        """Construct `StandardModel`."""
        # Base class constructor
        super().__init__()

        # Check(s)
        if isinstance(tasks, Task):
            tasks = [tasks]
        assert isinstance(tasks, (list, tuple))
        assert all(isinstance(task, Task) for task in tasks)
        assert isinstance(detector, Detector)
        assert isinstance(gnn, GNN)
        assert coarsening is None or isinstance(coarsening, Coarsening)

        # Member variable(s)
        self._detector = detector
        self._gnn = gnn
        self._tasks = ModuleList(tasks)
        self._coarsening = coarsening

    def forward(self, data: Data) -> List[Union[Tensor, Data]]:
        """Forward pass, chaining model components."""
        if CFG.FROZEN: torch.set_grad_enabled(False) 
        if self._coarsening:
            data = self._coarsening(data)
        data = self._detector(data)
        x = self._gnn(data)
        preds = [task(x) for task in self._tasks]
        return preds

    def compute_loss(
        self, preds: Tensor, data: Data, verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""
        losses = [
            task.compute_loss(pred, data)
            for task, pred in zip(self._tasks, preds)
        ]
        if verbose:
            self.info(f"{losses}")
        assert all(
            loss.dim() == 0 for loss in losses
        ), "Please reduce loss for each task separately"
        return torch.sum(torch.stack(losses))

    def _get_batch_size(self, data: Data) -> int:
        return torch.numel(torch.unique(data.batch))

def build_model(config: Dict[str,Any]) -> StandardCustomModel:
    """Build the detector, DynEdge GNN and task head described by ``config``.

    Args:
        config: Training configuration (e.g. ``CFG.TRAIN_CFG``). Keys read:
            ``neighbours``, ``features``, ``features_subset``,
            ``dynedge_layers``, ``post_processing_layer_sizes``,
            ``post_processing_transformer``, ``readout_layer_sizes``,
            ``loss_function``, ``target``. Optional keys: ``global_pooling``,
            ``sensor_embedding``, ``local_pooling``; ``bin_num`` is needed
            for ``target == 'class_id'``. ``config["mode"]`` is set to
            ``"regression"`` in place when missing.

    Returns:
        A `StandardCustomModel`, with ``prediction_columns`` and
        ``additional_attributes`` attached as attributes.

    Raises:
        ValueError: If ``loss_function`` or ``target`` is not recognised.
    """
    # In-place default: other cells read config["mode"] afterwards.
    config.setdefault("mode", "regression")

    detector = IceCubeCustom(
        graph_builder=KNNGraphBuilder(
              nb_nearest_neighbours=config["neighbours"][0],
              columns=config["features_subset"]
            ),
        features=config["features"]
    )
    gnn = DynEdge(
        nb_inputs=len(config["features"]),
        nb_neighbours=config["neighbours"][1],
        dynedge_layers=config["dynedge_layers"],
        post_processing_layer_sizes=config["post_processing_layer_sizes"],
        post_processing_transformer=config["post_processing_transformer"],
        readout_layer_sizes=config["readout_layer_sizes"],
        global_pooling=config.get("global_pooling"),
        features_subset=config["features_subset"],
        sensor_embedding=config.get("sensor_embedding", False),
        local_pooling=config.get("local_pooling"),
    )

    loss_factories = {
        "VonMisesFisher3DLoss": VonMisesFisher3DLoss,
        "CosineLoss": CosineLoss,
        "VMFCustomLoss": VMFCustomLoss,
        "CosineLoss2D": CosineLoss2D,
        "CrossEntropyLoss": torch.nn.CrossEntropyLoss,
    }
    loss_name = config["loss_function"]
    if loss_name not in loss_factories:
        raise ValueError(
            f"Unknown loss_function {loss_name!r}; expected one of {sorted(loss_factories)}"
        )
    loss_function = loss_factories[loss_name]()

    target = config["target"]
    additional_attributes = ['zenith', 'azimuth', 'event_id', 'sensor_id']

    if target == 'direction':
        if CFG.ANGLES == 'az,ze':
            task = DirectionReconstructionWithKappa(
                hidden_size=gnn.nb_outputs,
                target_labels=target,
                loss_function=loss_function,
            )
            column_suffixes = ("_x", "_y", "_z", "_kappa")
        else:
            task = DirectionReconstructionWithKappa2D(
                hidden_size=gnn.nb_outputs,
                target_labels=target,
                loss_function=loss_function,
            )
            column_suffixes = ("_x", "_y", "_kappa")
        prediction_columns = [target + suffix for suffix in column_suffixes]
    elif target == 'class_id':
        # The head width is a class attribute read by the Task constructor,
        # so it is set before instantiation.
        DirectionReconstructionWithBins.nb_inputs = config["bin_num"]**2
        task = DirectionReconstructionWithBins(
            hidden_size=gnn.nb_outputs,
            target_labels=target,
            loss_function=loss_function,
        )
        prediction_columns = []
    else:
        raise ValueError(
            f"Unknown target {target!r}; expected 'direction' or 'class_id'"
        )

    model = StandardCustomModel(
        detector=detector,
        gnn=gnn,
        tasks=[task],
    )
    model.prediction_columns = prediction_columns
    model.additional_attributes = additional_attributes

    return model

# Model

In [None]:
#@title cfg

# Configuration
CFG.NUM_EPOCHS = 10000
CFG.NUM_STEPS = -1  # batches per epoch; -1 never matches a step index, so the whole loader runs

CFG.TRAIN_CFG = {
    "mode": "classification",  # regression/classification
    # Alternative: ['x', 'y', 'z', 'core', 'sc', 'abs', 'r', 'eff', 'time', 'charge', 'auxiliary']
    "features": ['x', 'y', 'z', 'time', 'charge', 'auxiliary'],
    "doms_agg": False,
    "features_subset": slice(0, 3),  # k-NN metric columns (x, y, z)
    "neighbours": (8, 16),           # (graph-builder k, DynEdge k); was (8,8)
    "sensor_embedding": False,       # False/True
    # One (128, 256) input block followed by four identical (336, 256) blocks.
    "dynedge_layers": [
        {'type': 'DynEdgeConv', 'sizes': (128, 256), 'nb_out': 256},
    ] + [
        {'type': 'DynEdgeConv', 'sizes': (336, 256), 'nb_out': 256}
        for _ in range(4)
    ],
    "post_processing_layer_sizes": [2048, 256],  # [336,256]
    # None or {"d_model": 256, "nhead": 4, "num_layers": 4, "dim_feedforward": 512}
    "post_processing_transformer": None,
    "local_pooling": None,  # {"type": 'TopKPooling', "k": 4}
    # Alternative: {"type": "GRU", "nb_in": 256, "nb_out": 512}
    "global_pooling": {"type": "simple", "schemes": ["min", "max", "mean"], "nb_out": 768},
    "readout_layer_sizes": [512],  # 128
    "batch_size": 400,             # 500
    "batch_train_size": 400,       # events per optimizer step (gradient accumulation)
    "optimizer": {
        "type": 'adamW',  # lion/adam/adamW
        "lr": 1e-05,      # 1e-03 1e-04 1e-05 1e-06
        "eps": 1e-08,     # 1e-03
    },
    "augmentation": {
        "az_rot": False,
    },
}

if CFG.TRAIN_CFG["mode"] == "regression":
    CFG.TRAIN_CFG["bin_num"]       = None
    CFG.TRAIN_CFG["target"]        = "direction"
    CFG.TRAIN_CFG["loss_function"] = "VonMisesFisher3DLoss"

if CFG.TRAIN_CFG["mode"] == "classification":
    CFG.TRAIN_CFG["bin_num"]       = 24
    CFG.TRAIN_CFG["target"]        = "class_id"
    CFG.TRAIN_CFG["loss_function"] = "CrossEntropyLoss"

CFG.TRAIN_CFG["scheduler"] = {
                "type": 'PiecewiseLinearLR',
                "milestones": [0, (200_000 / CFG.TRAIN_CFG['batch_train_size']) / 2, ((200_000 / CFG.TRAIN_CFG['batch_train_size']) * 2000)],
                "factors":    [1e-02, 1, 1e-02] #[1e-02, 1, 1e-02]
        }

sensors_data = get_sensors(CFG.TRAIN_CFG['features'][:-3])

if CFG.USE_WANDB:
    wandb_init()


In [None]:
#@title loaders

# Only the validation loader is built in this cell.
val_dataloader = build_dataloader(
    CFG.VAL_BATCH,
    False,          # presumably the "train" flag (a train loader would pass True) — confirm
    CFG.TRAIN_CFG,
    CFG.VAL_EVENTS,
)

n_val_batches = len(val_dataloader)
print('valid batches: ', n_val_batches)


In [None]:
#batch = next(iter(val_dataloader))
#batch

In [None]:
#for batch in tqdm(val_dataloader):
#  batch = batch

In [None]:
#@title init

model = build_model(config=CFG.TRAIN_CFG)
# torch.optim recommends moving a model to its device *before* constructing
# its optimizer, so the optimizer is built over the final parameters.
model = model.to(device)

optimizer_cfg = CFG.TRAIN_CFG['optimizer']
if optimizer_cfg['type'] == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=optimizer_cfg["lr"], eps=optimizer_cfg["eps"])
elif optimizer_cfg['type'] == 'adamW':
    optimizer = torch.optim.AdamW(model.parameters(), lr=optimizer_cfg["lr"], eps=optimizer_cfg["eps"])
elif optimizer_cfg['type'] == 'lion':
    # Lion is configured with the learning rate only.
    optimizer = Lion(model.parameters(), lr=optimizer_cfg["lr"])
else:
    # Without this, `optimizer` would be undefined and `fit` would fail later.
    raise ValueError(
        f"Unknown optimizer type {optimizer_cfg['type']!r}; expected 'adam', 'adamW' or 'lion'"
    )

scheduler_cfg = CFG.TRAIN_CFG['scheduler']
if scheduler_cfg['type'] == 'PiecewiseLinearLR':
    scheduler = PiecewiseLinearLR(optimizer, milestones=scheduler_cfg['milestones'], factors=scheduler_cfg['factors'])
else:
    raise ValueError(f"Unknown scheduler type {scheduler_cfg['type']!r}; expected 'PiecewiseLinearLR'")

summary(model)

In [None]:
#@title test

#batch = next(iter(val_dataloader))

#batch.to(device)
#out = model(batch)

# Train

In [None]:
#@title lib

# Direction-bin lookup used by `preds_to_vectors` to decode class scores.
# It is only built (and `bin_num` only defined) in classification mode.
if CFG.TRAIN_CFG['mode'] == 'classification':
    bin_num = CFG.TRAIN_CFG['bin_num']
    azimuth_edges, zenith_edges = build_az_ze_edges(bin_num)
    angle_bin_vector = build_angle_bin_vector(
        azimuth_edges.cpu().numpy(),
        zenith_edges.cpu().numpy(),
        bin_num,
    )

def augmentation(batch):
    """Apply the training augmentations enabled in ``CFG.TRAIN_CFG['augmentation']``.

    ``az_rot``: draw one random angle in [0, 2*pi) per call. The matrix from
    `get_rot_matrix` (presumably a 2x2 rotation about the z axis — confirm)
    is applied in place to columns 0-1 of ``batch.x`` and
    ``batch.direction``. ``batch.azimuth`` / ``batch.zenith`` are not
    updated.

    Raises:
        ValueError: If ``az_rot`` is enabled in classification mode. The loss
            there is computed against ``batch.class_id``, which this rotation
            does not update, so training would silently use labels of the
            un-rotated event.
    """
    augm_cfg = CFG.TRAIN_CFG['augmentation']

    # azimuth rotate
    if augm_cfg['az_rot']:
        if CFG.TRAIN_CFG['mode'] == 'classification':
            raise ValueError(
                "augmentation 'az_rot' does not update batch.class_id; "
                "disable it in classification mode"
            )
        rot_phi = torch.rand((1), device=device, dtype=torch.float32) * torch.pi * 2.0
        rot_matrix = get_rot_matrix(rot_phi)
        batch.x[:,:2] = batch.x[:,:2] @ rot_matrix
        batch.direction[:,:2] = batch.direction[:,:2] @ rot_matrix

    return batch

def preds_to_vectors(preds):
    """Convert model outputs to direction vectors (and kappa where available).

    Args:
        preds: List of task outputs as returned by the model. Only
            ``preds[0]`` is used, detached from the graph.

    Returns:
        ``(vectors_avg, kappa_avg, vectors_max, kappa_max)``.

        - Regression: both pairs are the same tensors, columns 0-2 and
          column 3.
        - Classification: the softmaxed bin scores are decoded by
          `code_to_vector` with ``max=False`` (avg) and ``max=True`` (max),
          and the kappas are None.

    Raises:
        ValueError: If ``CFG.TRAIN_CFG['mode']`` is neither mode. Previously
            None was returned and the caller failed while unpacking it.
    """
    first_task_preds = preds[0].detach()
    mode = CFG.TRAIN_CFG['mode']

    if mode == 'regression':
        vectors_pred = first_task_preds[:, :3]
        kappa_preds = first_task_preds[:, 3]
        return vectors_pred, kappa_preds, vectors_pred, kappa_preds

    if mode == 'classification':
        probs = torch.nn.functional.softmax(first_task_preds, dim=1)
        vectors_pred_avg = code_to_vector(probs, bin_num, angle_bin_vector, max=False)
        vectors_pred_max = code_to_vector(probs, bin_num, angle_bin_vector, max=True)
        return vectors_pred_avg, None, vectors_pred_max, None

    raise ValueError(f"Unknown TRAIN_CFG mode {mode!r}; expected 'regression' or 'classification'")

def _cat_float64(parts):
    """Concatenate per-batch tensors into one float64 tensor on `device`.

    Returns an empty float64 tensor when ``parts`` is empty (e.g. kappa in
    classification mode), so ``.mean()`` on it yields NaN.
    """
    if not parts:
        return torch.tensor([], dtype=float).to(device)
    return torch.cat(parts).to(device=device, dtype=torch.float64)


def fit(model, loader, epoch, train=True, events_batch=None, batch_repeat=-1):
    """Run one training or validation epoch over ``loader``.

    In training mode gradients are accumulated. The module-level `optimizer`
    and `scheduler` step every ``batch_train_size / batch_size`` batches. In
    both modes the predictions are decoded by `preds_to_vectors` and compared
    with ``batch.direction`` via `angle_errors`.

    Args:
        model: Model returned by `build_model`.
        loader: Iterable of PyG batches.
        epoch: Epoch number; used only as a progress-bar label.
        train: Train (True) or evaluate under ``torch.no_grad`` (False).
        events_batch: Progress-bar label only.
        batch_repeat: Progress-bar label only.

    Returns:
        Tuple of Python floats: ``(mean_loss, error_avg, az_error_avg,
        ze_error_avg, error_max, az_error_max, ze_error_max, kappa_mean)``.
        ``kappa_mean`` is NaN in classification mode. If no batch is
        processed (empty loader or ``NUM_STEPS == 0``), every value is NaN.
    """
    # Collect per-batch tensors and concatenate once at the end. Growing a
    # tensor with torch.cat inside the loop re-copies it every step (O(n^2)).
    pred_avg_parts, pred_max_parts = [], []
    kappa_avg_parts = []
    target_parts = []

    tot_loss = 0.0
    count = 0
    total = CFG.NUM_STEPS if CFG.NUM_STEPS>0 and train else len(loader)
    pbar = tqdm(enumerate(loader), total=total)
    model.train(train)
    if train:
        # Clear gradients
        optimizer.zero_grad()
    # Loader batches per optimizer step (gradient accumulation).
    batch_train_ratio = CFG.TRAIN_CFG['batch_train_size'] / CFG.TRAIN_CFG['batch_size']
    for steps, batch in pbar:
        running_loss = float(tot_loss) / steps if steps > 0 else 0.0
        pbar.set_description(f"epoch {epoch} (batch:{events_batch}:{batch_repeat}): {running_loss:.6f}")
        # NOTE: this limit also applies to validation, although `total` only
        # uses NUM_STEPS when training.
        if steps == CFG.NUM_STEPS:
            break

        batch = batch.to(device)

        # Forward propagation
        if train:
            batch = augmentation(batch)
            preds = model(batch)
        else:
            with torch.no_grad():
                preds = model(batch)

        if CFG.TRAIN_CFG['mode'] == 'regression':
            L = model.compute_loss(preds, batch)
        else:
            L = model._tasks[0]._loss_function(preds[0], batch.class_id)

        if train:
            # Calculating gradients
            L.backward()
            if (steps+1) % batch_train_ratio == 0:
                # Update parameters, then the LR schedule, then clear gradients.
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        loss_val = L.detach()

        vectors_pred_avg, kappa_pred_avg, vectors_pred_max, _ = preds_to_vectors(preds)

        pred_avg_parts.append(vectors_pred_avg)
        pred_max_parts.append(vectors_pred_max)
        if CFG.TRAIN_CFG['mode'] == 'regression':
            kappa_avg_parts.append(kappa_pred_avg)
        target_parts.append(batch.direction)

        tot_loss += loss_val
        count += 1

    if count == 0:
        return (float('nan'),) * 8

    all_vectors_pred_avg = _cat_float64(pred_avg_parts)
    all_vectors_pred_max = _cat_float64(pred_max_parts)
    all_kappa_pred_avg = _cat_float64(kappa_avg_parts)
    all_vectors_target = _cat_float64(target_parts)

    error_avg, az_error_avg, ze_error_avg = angle_errors(all_vectors_pred_avg, all_vectors_target)
    error_max, az_error_max, ze_error_max = angle_errors(all_vectors_pred_max, all_vectors_target)

    kappa_mean = all_kappa_pred_avg.mean()

    return float(tot_loss)/count, error_avg.item(), az_error_avg.item(), ze_error_avg.item(), error_max.item(), az_error_max.item(), ze_error_max.item(), kappa_mean.item()

def save_checkpoint(model, optimizer, epoch, val_loss, val_err_avg, val_err_max, val_kappa, best_epoch, best_val_loss, best_val_ads_avg, best_val_ads_max, batch_offset, batch_repeat, history, last):
    """Serialize the full training state into ``CFG.CHECKPOINTS_PATH``.

    The saved dict contains the model / optimizer / scheduler state dicts, the
    best-so-far trackers, the current validation metrics, the batch-rotation
    position (``batch_offset`` / ``batch_repeat``), the metric ``history`` and
    ``CFG.to_dict()`` — everything the resume cell reads back.

    Args:
        model, optimizer: objects whose ``state_dict()`` is stored.
        epoch: epoch that just finished.
        val_loss, val_err_avg, val_err_max, val_kappa: validation metrics of this epoch.
        best_epoch, best_val_loss, best_val_ads_avg, best_val_ads_max: best-so-far trackers.
        batch_offset, batch_repeat: position in the training-batch rotation.
        history: list of per-epoch metric rows.
        last: True  -> write only the rolling ``..._last.pt`` file.
              False -> write a metric-stamped "best" file, then copy it to
                       ``..._last.pt`` so the rolling file always holds the most
                       recent epoch (it is what resuming loads).

    Note:
        On ``epoch == 1`` the whole checkpoint directory is deleted first, so a
        fresh run wipes checkpoints left by a previous run of the same experiment.
        The scheduler state is read from the module-level ``scheduler`` global.
    """
    exp_path = CFG.CHECKPOINTS_PATH
    if epoch == 1 and path.isdir(exp_path):
        shutil.rmtree(exp_path, ignore_errors=True)
    os.makedirs(exp_path, exist_ok=True)

    prefix = f"{exp_path}/model_exp_{CFG.EXP_ID}_{CFG.EXP_COMMENT}"
    last_fname = f"{prefix}_last.pt"
    if last:
        fname = last_fname
    else:
        fname = f"{prefix}_val_ads_avg_{val_err_avg:.4f}_val_ads_max_{val_err_max:.4f}_val_loss_{val_loss:.4f}_epoch_{epoch:04d}.pt"

    print(f'save checkpoint: epoch: {epoch}, val_ads_avg: {val_err_avg:.4f}, val_ads_max: {val_err_max:.4f}, val_loss: {val_loss:.4f} to {fname}')
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_epoch': best_epoch,
            'best_val_loss': best_val_loss,
            'best_val_ads_avg': best_val_ads_avg,
            'best_val_ads_max': best_val_ads_max,
            'val_loss': val_loss,
            'val_err_avg': val_err_avg,
            'val_err_max': val_err_max,
            'val_kappa': val_kappa,
            'batch_offset': batch_offset,
            'batch_repeat': batch_repeat,
            'history': history,
            'cfg': CFG.to_dict()
            }, fname)

    # Keep the rolling checkpoint current on improving epochs too; otherwise
    # resuming from ..._last.pt would restart from an older epoch.
    if fname != last_fname:
        shutil.copyfile(fname, last_fname)

def plot_train(history, loss_ylim=(4.7, 6.5), ads_ylim=(0.54, 1.10)):
    """Plot training curves from the per-epoch ``history`` rows.

    Each row is expected in the layout the training loop appends (18 columns):
    ``[epoch, trn_loss, trn_err_avg, trn_az_err_avg, trn_ze_err_avg, trn_err_max,
    trn_az_err_max, trn_ze_err_max, trn_kappa, val_loss, val_err_avg,
    val_az_err_avg, val_ze_err_avg, val_err_max, val_az_err_max, val_ze_err_max,
    val_kappa, lr]``.

    Panels (blue = train, green = validation):
      1. loss, with the learning rate on a log-scaled twin axis (dotted black);
      2. "avg" errors: total (solid), azimuth (dotted), zenith (dashed), with
         validation kappa on a twin axis (dotted red);
      3. the same errors for the "max" prediction.

    Args:
        history: sequence of 18-element rows as above. Nothing is drawn if empty.
        loss_ylim: y-limits of the loss panel.
        ads_ylim: y-limits of both angular-error panels.
    """
    if len(history) == 0:
        # Nothing to plot; the 2-D indexing below would fail on an empty array.
        return

    h = np.array(history)
    epochs = h[:, 0]
    last = h[-1]
    xlim = (0, last[0])

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(21, 7), facecolor='w')

    # Panel 1: loss + learning rate.
    ax1.set(ylim=loss_ylim, xlim=xlim, xlabel='epoch', ylabel='loss')
    ax1.grid(color='gray', linestyle='--', alpha=0.6)
    # lr is typically far below 1e-4, so .4f would print 0.0000; use scientific notation.
    ax1.set_title(f'trn_loss: {last[1]:.4f}, val_loss: {last[9]:.4f}, lr: {last[17]:.2e}')
    ax1.plot(epochs, h[:, 1], "-b")
    ax1.plot(epochs, h[:, 9], "-g")
    ax11 = ax1.twinx()
    ax11.plot(epochs, h[:, 17], ":k")
    ax11.set_yscale('log')
    ax11.set_ylabel('lr')

    # Panel 2: errors of the averaged prediction + validation kappa.
    ax2.set(ylim=ads_ylim, xlim=xlim, xlabel='epoch', ylabel='angular error (avg)')
    ax2.grid(color='gray', linestyle='--', alpha=0.6)
    ax2.set_title(f'trn_ads_avg: {last[2]:.4f}, val_ads_avg: {last[10]:.4f}, val_kappa: {last[16]:.4f}')
    ax2.plot(epochs, h[:, 2],  "-b")
    ax2.plot(epochs, h[:, 10], "-g")
    ax2.plot(epochs, h[:, 3],  ":b")
    ax2.plot(epochs, h[:, 11], ":g")
    ax2.plot(epochs, h[:, 4],  "--b")
    ax2.plot(epochs, h[:, 12], "--g")
    ax21 = ax2.twinx()
    ax21.plot(epochs, h[:, 16], ":r")
    ax21.set_ylabel('val kappa')

    # Panel 3: errors of the max prediction.
    ax3.set(ylim=ads_ylim, xlim=xlim, xlabel='epoch', ylabel='angular error (max)')
    ax3.grid(color='gray', linestyle='--', alpha=0.6)
    ax3.set_title(f'trn_ads_max: {last[5]:.4f}, val_ads_max: {last[13]:.4f}')
    ax3.plot(epochs, h[:, 5],  "-b")
    ax3.plot(epochs, h[:, 13], "-g")
    ax3.plot(epochs, h[:, 6],  ":b")
    ax3.plot(epochs, h[:, 14], ":g")
    ax3.plot(epochs, h[:, 7],  "--b")
    ax3.plot(epochs, h[:, 15], "--g")

    plt.show()



In [None]:
#@title fit

# Consecutive epochs trained on the same training batch before the rotation advances.
BATCH_REPEAT_CNT = 1

# Training-state defaults for a fresh run; the 1000 values are "not seen yet"
# sentinels for the best-metric trackers. Replaced from the checkpoint when resuming.
history          = []
epoch            = 1
best_epoch       = 0
best_val_loss    = 1000
best_val_ads_avg = 1000
best_val_ads_max = 1000
batch_offset     = 0
batch_repeat     = 0

if CFG.RESUME:
    # map_location='cpu' keeps a GPU-saved checkpoint loadable on any runtime;
    # the load_state_dict calls below copy the tensors onto the devices the
    # model / optimizer parameters already live on.
    state = torch.load(CFG.RESUME_MODEL, map_location='cpu')
    epoch = state['epoch']
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    scheduler.load_state_dict(state['scheduler_state_dict'])
    best_epoch = state['best_epoch']
    best_val_loss = state['best_val_loss']
    best_val_ads_avg = state['best_val_ads_avg']
    best_val_ads_max = state['best_val_ads_max']
    val_loss = state['val_loss']
    val_err_avg = state['val_err_avg']
    val_err_max = state['val_err_max']
    val_kappa = state['val_kappa']
    batch_offset = state['batch_offset']
    batch_repeat = state['batch_repeat']
    history = state['history']
    cfg = state['cfg']
    print(f'resume model: {CFG.RESUME_MODEL} with config: {cfg}')
    print(f'from lr: {scheduler.get_last_lr()[0]}')
    print(f'from epoch: {epoch}')

    print(f'\tloss:     val: {val_loss:.4f} best: {best_val_loss:.4f}')
    print(f'\tads_avg:  val: {val_err_avg:.4f} best: {best_val_ads_avg:.4f}')
    print(f'\tads_max:  val: {val_err_max:.4f} best: {best_val_ads_max:.4f}')
    print(f'\tkappa:    val: {val_kappa:.4f}')
    print(f'\tbest epoch: {best_epoch}')

    plot_train(history)
    # The checkpoint stores the last finished epoch; continue with the next one.
    epoch += 1
else:
    print(f'train from scratch')

In [None]:
# Optimizer hyper-parameters for manual re-creation of the optimizer.
# Nothing in this cell rebuilds it (the Adam line below stays commented out),
# so these values have no effect on training unless that line is re-enabled.
lr, eps, l2 = 1e-06, 1e-03, 0
#optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=eps, weight_decay=l2)

In [None]:
# fit() breaks when its step counter equals CFG.NUM_STEPS; -1 presumably never
# matches, so every batch of each dataloader is consumed.
CFG.NUM_STEPS = -1

# The sample-filtering experiment unpacked fit() into 5 values (including
# per-sample kappa/metrics), but fit() returns eight aggregate scalars, so that
# path raised ValueError and would also leave the trn_* metrics undefined.
# Fail up front instead of after a full pass over the first training batch.
if CFG.SAMPLE_FILTER:
    raise NotImplementedError(
        'CFG.SAMPLE_FILTER is not supported: fit() returns aggregate metrics only, '
        'not the per-sample kappa/errors the filter needs')

train_dataloader = None
train_batch = None

for epoch in range(epoch, CFG.NUM_EPOCHS + 1):
    # Batch rotation: reuse each training batch for BATCH_REPEAT_CNT epochs,
    # then move to the next one in the half-open range CFG.BATCH_RANGE, wrapping.
    batch_repeat += 1
    if batch_repeat > BATCH_REPEAT_CNT:
        batch_repeat = 1
        batch_offset += 1

    train_batch = CFG.BATCH_RANGE[0] + batch_offset
    if train_batch > CFG.BATCH_RANGE[1] - 1:
        batch_offset = 0
        train_batch = CFG.BATCH_RANGE[0]

    # Release the previous epoch's loader before building a new one.
    if train_dataloader is not None:
        del train_dataloader
        gc.collect()

    train_dataloader = build_dataloader(train_batch, True, CFG.TRAIN_CFG)
    trn_loss, trn_err_avg, trn_az_err_avg, trn_ze_err_avg, trn_err_max, trn_az_err_max, trn_ze_err_max, trn_kappa = fit(model, train_dataloader, epoch, True, train_batch, batch_repeat)
    val_loss, val_err_avg, val_az_err_avg, val_ze_err_avg, val_err_max, val_az_err_max, val_ze_err_max, val_kappa = fit(model, val_dataloader, epoch, False, CFG.VAL_BATCH, -1)

    # Row layout must stay in sync with the column indices used by plot_train.
    lr = scheduler.get_last_lr()[0]
    history.append([epoch,
                    trn_loss, trn_err_avg, trn_az_err_avg, trn_ze_err_avg, trn_err_max, trn_az_err_max, trn_ze_err_max, trn_kappa,
                    val_loss, val_err_avg, val_az_err_avg, val_ze_err_avg, val_err_max, val_az_err_max, val_ze_err_max, val_kappa,
                    lr])

    if CFG.USE_WANDB:
        wandb.log({
                   "trn_loss": trn_loss,
                   "trn_err_avg": trn_err_avg, "trn_az_err_avg": trn_az_err_avg, "trn_ze_err_avg": trn_ze_err_avg,
                   "trn_err_max": trn_err_max, "trn_az_err_max": trn_az_err_max, "trn_ze_err_max": trn_ze_err_max,
                   "trn_kappa": trn_kappa,
                   "val_loss": val_loss,
                   "val_err_avg": val_err_avg, "val_az_err_avg": val_az_err_avg, "val_ze_err_avg": val_ze_err_avg,
                   "val_err_max": val_err_max, "val_az_err_max": val_az_err_max, "val_ze_err_max": val_ze_err_max,
                   "val_kappa": val_kappa,
                   "lr": lr
                   })

    print(f'epoch {epoch} (batch: {train_batch}, lr: {lr}):')

    # An improvement in ANY of the three tracked validation metrics marks this epoch as best.
    if best_val_loss > val_loss:
        print(f'\t !!! val_loss improved from {best_val_loss:.4f} to {val_loss:.4f}')
        best_val_loss = val_loss
        best_epoch = epoch
    if best_val_ads_avg > val_err_avg:
        print(f'\t !!! val_ads_avg improved from {best_val_ads_avg:.4f} to {val_err_avg:.4f}')
        best_val_ads_avg = val_err_avg
        best_epoch = epoch
    if best_val_ads_max > val_err_max:
        print(f'\t !!! val_ads_max improved from {best_val_ads_max:.4f} to {val_err_max:.4f}')
        best_val_ads_max = val_err_max
        best_epoch = epoch

    print(f'\tloss:     trn: {trn_loss:.4f} val: {val_loss:.4f} best: {best_val_loss:.4f}')
    print(f'\tads_avg:  trn: {trn_err_avg:.4f} val: {val_err_avg:.4f} best: {best_val_ads_avg:.4f}')
    print(f'\taz_avg:   trn: {trn_az_err_avg:.4f} val: {val_az_err_avg:.4f}')
    print(f'\tze_avg:   trn: {trn_ze_err_avg:.4f} val: {val_ze_err_avg:.4f}')
    print(f'\tads_max:  trn: {trn_err_max:.4f} val: {val_err_max:.4f} best: {best_val_ads_max:.4f}')
    print(f'\taz_max:   trn: {trn_az_err_max:.4f} val: {val_az_err_max:.4f}')
    print(f'\tze_max:   trn: {trn_ze_err_max:.4f} val: {val_ze_err_max:.4f}')
    print(f'\tkappa:    trn: {trn_kappa:.4f} val: {val_kappa:.4f}')
    print(f'\tbest epoch: {best_epoch}')

    # Best epochs get a metric-stamped checkpoint (last=False); others go to the rolling _last file.
    is_best = best_epoch == epoch
    save_checkpoint(model, optimizer, epoch, val_loss, val_err_avg, val_err_max, val_kappa,
                    best_epoch, best_val_loss, best_val_ads_avg, best_val_ads_max,
                    batch_offset, batch_repeat, history, not is_best)

    if epoch % 10 == 0:
        plot_train(history)


# Validate

In [None]:
# Load weights for validation; defaults to the checkpoint configured for resuming.
resume_model = CFG.RESUME_MODEL

# map_location='cpu' makes GPU-saved checkpoints loadable on any runtime;
# load_state_dict then copies the weights onto the model's own device.
resume_state = torch.load(resume_model, map_location='cpu')
model.load_state_dict(resume_state['model_state_dict'])

In [None]:
# Quick validation on the first VALIDATE_STEPS batches of val_dataloader.
VALIDATE_STEPS = 10

# fit() stops once its step counter equals CFG.NUM_STEPS. Restore the previous
# cap afterwards so this cell does not leak a 10-step limit into later cells.
prev_num_steps = getattr(CFG, 'NUM_STEPS', -1)
CFG.NUM_STEPS = VALIDATE_STEPS
try:
    # 100 is passed as the epoch argument (in the visible part of fit() it only
    # appears in the progress-bar text); -1 as batch_repeat matches the
    # training loop's own validation call.
    val_loss, val_err_avg, val_az_err_avg, val_ze_err_avg, val_err_max, val_az_err_max, val_ze_err_max, val_kappa = fit(model, val_dataloader, 100, False, CFG.VAL_BATCH, -1)
finally:
    CFG.NUM_STEPS = prev_num_steps

print(f'\tloss:    val: {val_loss:.4f}')
print(f'\tads_avg: val: {val_err_avg:.4f}')
print(f'\taz_avg:  val: {val_az_err_avg:.4f}')
print(f'\tze_avg:  val: {val_ze_err_avg:.4f}')
print(f'\tads_max: val: {val_err_max:.4f}')
print(f'\taz_max:  val: {val_az_err_max:.4f}')
print(f'\tze_max:  val: {val_ze_err_max:.4f}')
print(f'\tkappa:   val: {val_kappa:.4f}')
