# IceCube: Creating a dataset

This notebook builds the dataset used to train the model and saves it to Google Drive.
The original competition training batches are also stored on Google Drive and are copied from there.

Every 5 training batches are combined into one pack of about 1M examples.<br>
Within a pack, the examples are grouped by sequence length (number of pulses), so that every group contains sequences of the same length.

## Load data from disk

In [None]:
%%time
CREATE_DATASET      = True          # run the dataset generation in the last cell
LOAD_FROM_DISK      = True          # copy the Kaggle source batches from Google Drive to /content

DROP_AUX            = False         # drop pulses with aux==True
DOMS_AGG            = False         # aggregate pulses per sensor (keeps the time of the first pulse)
T_MAX               = 512           # maximum number of pulses kept per event
IS_AGG              = True          # generate aggregated per-event features

LINES_FILTER        = False         # filter events by the number of strings in the event
LINES_MIN           = 0             # minimum number of strings (if LINES_FILTER==True)
LINES_MAX           = 100           # maximum number of strings (if LINES_FILTER==True)
#-------------------------------------------------------------------------------

FIRST_BATCH_ID   = 1                # id of the first batch used for training and validation
NUM_BATCHES      = 10               # total number of batches for training and validation
BATCHES_IN_PACK  = 5                # number of batches saved together in one pack file

# folder for the resulting dataset (doms.pt and pack_XX.pt are written here):
DATASET_FOLDER = "/content/drive/MyDrive/IceCube/IceCube-Dataset/ALL_512"

#===============================================================================

import os, gc, sys, time, datetime, math, random,  psutil
import numpy as np
import matplotlib.pyplot as plt
from   pathlib   import Path
from   tqdm.auto import tqdm
import pandas as pd
import pyarrow, pyarrow.parquet as pq     # read by chunks
import torch

from psutil import virtual_memory
# NOTE: virtual_memory().total is the total RAM of the runtime, not the amount that is currently free.
print(f'Your runtime has {(virtual_memory().total / 1024**3):.1f} gigabytes of available RAM\n')
#===============================================================================
# Copy the competition files from Google Drive to the local /content folder
if LOAD_FROM_DISK:
    # archive with the event meta data; it is unpacked and the local zip is removed afterwards
    !cp /content/drive/MyDrive/IceCube/IceCube-Dataset/train_meta_splitted.zip /content/
    !unzip -q /content/train_meta_splitted.zip
    !rm       /content/train_meta_splitted.zip

    # tables read later by get_sensors()
    !cp /content/drive/MyDrive/IceCube/IceCube-Dataset/sensor_geometry.csv /content/
    !cp /content/drive/MyDrive/IceCube/IceCube-Dataset/scattering_and_absorption.csv /content/

    # only the batches FIRST_BATCH_ID ... FIRST_BATCH_ID + NUM_BATCHES - 1 are copied
    for batch_id in tqdm(range(FIRST_BATCH_ID, FIRST_BATCH_ID + NUM_BATCHES)):
        !cp    /content/drive/MyDrive/IceCube/IceCube-Dataset/train/batch_{batch_id}.parquet /content/

## Prepare Data Functions

In [None]:
#===============================================================================
PATH      = Path("/content")                  # folder with the batch_*.parquet files and sensor_geometry.csv
PATH_PHYS = Path("/content")                  # folder with scattering_and_absorption.csv
PATH_META = Path("/content/content/icecube-neutrinos-in-deep-ice/train_meta")  # batch_{id}_meta.parquet files
# The batches are copied straight into PATH (there is no "train" sub-folder) and
# get_files() reads them from PATH as well, so they are counted in the same place.
files_trn = sorted(PATH.glob("batch_*.parquet"))  # all train files
print(f"{len(files_trn):3d} train files")
#===============================================================================

def info(text, pref="", end="\n"):
    """
    Print a progress line: minutes since info.beg, seconds since the previous
    call, the RAM currently in use (Gb) and the message text.

    text: message to print
    pref: string printed in front of the line
    end:  line terminator passed to print()
    """
    gc.collect()                                          # free garbage before measuring memory
    used_gb = psutil.virtual_memory().used / 1024**3
    now     = time.time()
    total_min  = (now - info.beg) / 60
    since_last = now - info.last
    print(f"{pref}{total_min:5.1f}m[{since_last:+5.1f}s] {used_gb:6.3f}Gb > {text}",end=end)
    info.last = time.time()                               # the next call measures from here


# the timers are stored as attributes of the function itself
info.beg = info.last = time.time()

#-------------------------------------------------------------------------------

def get_sensors():
    """
    Load the sensor table and add per-sensor features.

    return: DataFrame with columns sensor_id, line_id, core, x, y, z, a, r
            line_id - string id (sensor_id // 60 + 1)
            core    - 1.0 for sensors with line_id > 78 (DeepCore), else 0.0
            x,y,z   - coordinates multiplied by 1e-3 (kilometers, presumably meters in the csv)
            a       - column `a` of scattering_and_absorption.csv (times 1e-2)
                      interpolated at the sensor depth z
            r       - distance from the vertical axis, sqrt(x^2 + y^2)
    """
    sensors = pd.read_csv(PATH / "sensor_geometry.csv")
    sensors['line_id'] = sensors.sensor_id // 60 + 1
    sensors['core']    = (sensors.line_id > 78).astype(np.float32)
    for axis in ('x', 'y', 'z'):
        sensors[axis] = (sensors[axis] * 1e-3).astype(np.float32)

    from scipy.interpolate import interp1d                 # only needed for the absorption profile
    ice = pd.read_csv(PATH_PHYS / "scattering_and_absorption.csv")
    ice['z'] = (ice['z'] * 1e-3).astype(np.float32)
    ice['a'] = (ice['a'] * 1e-2).astype(np.float32)
    absorption = interp1d(ice.z, ice.a)                    # a as a function of depth
    sensors['a'] = absorption(sensors.z)

    sensors['r'] = np.sqrt(sensors.x**2 + sensors.y**2)

    columns = ['sensor_id', 'line_id', 'core', 'x', 'y', 'z', 'a', 'r']
    return sensors[columns]

#-------------------------------------------------------------------------------

def get_target_angles(batch_id=1):
    """
    Read the target angles of one batch from its meta file.

    batch_id: batch number, must be in 1 ... 660
    return:   DataFrame with columns event_id (int64), azimuth, zenith (float32)
    """
    assert 0 < batch_id < 661, "Wrong batch_id"
    meta = pd.read_parquet(PATH_META / f"batch_{batch_id}_meta.parquet")
    meta['event_id'] = meta['event_id'].astype(np.int64)
    for angle in ('azimuth', 'zenith'):
        meta[angle] = meta[angle].astype(np.float32)
    return meta[ ['event_id','azimuth','zenith'] ]

#-------------------------------------------------------------------------------

def prepare_batch(df, verbose=True, drop_aux = DROP_AUX, doms_agg = DOMS_AGG):
    """
    Prepare a freshly loaded batch of pulses.

    df:       raw batch; its index is used as event_id and it has the columns
              sensor_id, time, charge, auxiliary
    verbose:  report progress through info()
    drop_aux: keep only the pulses with aux == False
    doms_agg: merge the pulses of one sensor inside an event into a single row
              (mean aux, summed charge, earliest time)
    return:   DataFrame with columns event_id, sensor_id, aux, q, t, where t is
              counted from the first pulse of the event and multiplied by
              0.299792458e-3 (presumably ns -> km at the speed of light - TODO confirm)
    """
    # the event id lives in the index of the raw batch: make it a regular column
    df['event_id'] = df.index.astype(np.int64)
    df = df.reset_index(drop=True)
    df.rename(columns={"time": "t", "auxiliary": "aux", 'charge': 'q'}, inplace=True)
    df['q'] = df['q'].astype(np.float32)

    if drop_aux:
        df = df[ ~df.aux ]

    if doms_agg:
        per_sensor = df.groupby(['event_id', 'sensor_id']).agg(
            aux = ( 'aux', "mean"),
            q   = ( 'q',   "sum"),
            t   = ( 't',   "min"),
        )
        df = per_sensor.reset_index()

    if verbose:
        info(f"load_batch: loaded  {df.shape}")

    # time of the first pulse of every event, attached to each of its pulses
    first_pulse = df.groupby('event_id').agg( t_min = ('t', 'min') )
    df = df.merge(first_pulse, left_on='event_id', right_index=True, how='left')
    shifted = ( df.t - df.t_min ) * 0.299792458e-3
    df['t'] = shifted.astype(np.float32)

    if verbose:
        info("load_batch: shift_times")

    columns = ['event_id', 'sensor_id', 'aux', 'q', 't' ]
    return df[columns]
        
#-------------------------------------------------------------------------------

def cut_pulses(df, max_pulses = 128, verbose=True):
    """
    Keep at most max_pulses pulses per event.

    Inside every event the pulses are ranked by (aux, t), so the pulses with
    aux == False and the earliest times are kept first; the late and the
    auxiliary pulses are thrown out when the event is too long.

    df:         pulses with the columns event_id, aux, t (other columns are carried along)
    max_pulses: maximum number of pulses kept per event
    verbose:    report the share of removed pulses through info()
    return:     DataFrame with the same columns as df and a fresh 0..N-1 index,
                sorted by (event_id, t) when DROP_AUX is False, otherwise left
                in the (event_id, aux, t) order
    """
    tot = len(df)
    df = df.sort_values(['event_id','aux','t'])
    df = df.groupby('event_id').head(max_pulses)         # cut pulses by event

    if not DROP_AUX:
        # the selection above was ordered by aux first: restore the time order
        df = df.sort_values(['event_id','t'])

    # drop=True: the old index must not leak into the result as an 'index' column
    df = df.reset_index(drop=True)

    if verbose:
        removed = 100*(tot-len(df))/tot if tot else 0.0  # an empty batch removes nothing
        info(f"cut_pulses (max={max_pulses}): removed {removed:.2f}%")
    return df

#-------------------------------------------------------------------------------

def lines_filter(df, lines_min=1, lines_max=1):
    """
    Keep only the events whose number of strings lies in [lines_min, lines_max].

    The strings are counted over the pulses with aux == 0 only, so an event
    without such pulses is always removed.

    df:     pulses with the columns event_id, aux, line_id
    return: (pulses of the kept events with a fresh index,
             table of the kept events: event_id, lines0)
    """
    reliable = df[df.aux == 0]
    agg = reliable.groupby('event_id').agg( lines0 = ( 'line_id',   'nunique') ).reset_index()
    agg = agg[agg.lines0.between(lines_min, lines_max)]   # both bounds are inclusive
    kept = df[df.event_id.isin(agg.event_id)].reset_index(drop=True)
    return kept, agg
#-------------------------------------------------------------------------------

def angles2vector(df):
    """
    Add the columns nx, ny, nz - the unit vector given by the columns
    (azimuth, zenith) - to df. The DataFrame is changed in place and returned.
    """
    sin_zenith = np.sin(df.zenith)
    df['nx'] = sin_zenith * np.cos(df.azimuth)
    df['ny'] = sin_zenith * np.sin(df.azimuth)
    df['nz'] = np.cos(df.zenith)
    return df

#-------------------------------------------------------------------------------

def delta_angle(n1, n2, eps=1e-8):
    """
    Angles between the rows of two arrays of vectors.
    n1, n2: (B,3); eps is added to the norms to avoid division by zero
    return: (B,) angles in radians
    """
    norm1 = np.linalg.norm(n1, axis=1, keepdims=True) + eps
    norm2 = np.linalg.norm(n2, axis=1, keepdims=True) + eps
    unit1 = n1 / norm1
    unit2 = n2 / norm2
    dot = (unit1*unit2).sum(axis=1)
    return np.arccos( dot.clip(-1,1) )                   # clip guards against rounding outside [-1,1]

#-------------------------------------------------------------------------------

def get_event_features(df, target_df, suf = "", aux=True):
    """
    Aggregated features characterizing the entire event.

    df:        pulses with the columns event_id, line_id, sensor_id, core, aux, q, t, x, y, z
    target_df: not used (kept so that existing calls keep working)
    suf:       suffix appended to the name of every output column (event_id included)
    aux:       if True, every pulse column except event_id, sensor_id, line_id is
               multiplied by (1 - aux) before the aggregation, so a pulse with
               aux == 1 contributes with zero weight
    return:    DataFrame with one row per event; NaN values are replaced by 0

    With aux=False the helper columns xt, yt, zt, tt are added to df in place.
    With aux=True the weighting is done on a copy, the caller's df is not changed.
    """
    if aux:
        df = df.copy()        # the weighting below must not leak into the caller's pulses

    df['xt'] = df.x*df.t
    df['yt'] = df.y*df.t
    df['zt'] = df.z*df.t
    df['tt'] = df.t**2

    if aux:
        # The weight is taken once, before the loop: the loop also rewrites the
        # aux column itself, and reading df.aux again after that would give the
        # remaining columns a wrong weight.
        weight = 1 - df.aux
        for col in df.columns:
            if col not in ['event_id', 'sensor_id', 'line_id']:
                df[col] = df[col] * weight

    df = df.groupby('event_id').agg(          # this is for all pulses with any aux
        tot      = ('t',        'count'),
        t_med    = ('t',        'median'),
        t        = ('t',        'mean'),
        x        = ('x',        'mean'),
        y        = ('y',        'mean'),
        z        = ('z',        'mean'),
        stdT     = ('t',        'std'),
        stdX     = ('x',        'std'),
        stdY     = ('y',        'std'),
        stdZ     = ('z',        'std'),
        xt       = ('xt',       'mean'),
        yt       = ('yt',       'mean'),
        zt       = ('zt',       'mean'),
        tt       = ('tt',       'mean'),
        q        = ('q',        'mean' ),
        q_min    = ('q',        'min' ),
        q_max    = ('q',        'max' ),
        q_med    = ('q',        'median' ),
        aux      = ('aux',      'mean' ),
        core     = ('core',     'mean' ),
        lines    = ('line_id',  'nunique' ),
        doms     = ('sensor_id','nunique' ),
    )
    df.reset_index(inplace=True)

    for col in ('aux', 'lines', 'doms', 'stdT', 'stdX', 'stdY', 'stdZ'):
        df[col] = df[col].astype(np.float32)

    # average number of pulses per string / per sensor (log scale)
    df['p_lines'] = np.log10(df.tot / df.lines).astype(np.float32)
    df['p_doms']  = np.log10(df.tot / df.doms ).astype(np.float32)

    # compress the ranges of the charges and of the counters
    df.q         = np.log(1+df.q)
    df.q_med     = np.log(1+df.q_med)
    df.q_min     = np.log(1+df.q_min)
    df.q_max     = np.log(1+df.q_max)
    df.lines     = np.log10(df.lines)  / 10
    df.doms      = np.log10(df.doms)   / 10
    df['pulses'] =(np.log10(df.tot)    / 10).astype(np.float32)

    df = df.fillna(0.0)   # std is NaN for an event with a single pulse

    if suf:               # add a suffix to the column names
        df.columns = [ col + suf for col in df.columns]

    return df

#-------------------------------------------------------------------------------

def get_pulse_features(df):
    """
    Prepare the per-pulse features in place and return df:
    the column line_id is removed, q is replaced by log(1+q) and every column
    except sensor_id, event_id, line_id, tot is cast to float32.
    """
    df.drop(columns=['line_id'], inplace=True)   # !!!! (embedding ?)
    df.q = np.log(1+df.q)

    keep_dtype = ['sensor_id', 'event_id', 'line_id', 'tot']
    float_cols = [name for name in df.columns if name not in keep_dtype]
    for name in float_cols:
        df[name] = df[name].astype(np.float32)

    return df

#===============================================================================
#                     Create dataset for train and validation
#===============================================================================

def get_files(batch_ids):
    """
    Return (paths of the batch files inside PATH, batch_ids) for the given batch ids.
    """
    return [PATH / f"batch_{number}.parquet" for number in batch_ids], batch_ids

#-------------------------------------------------------------------------------

def append_dict(data, T, df, agg_df):
    """
    Add the events with exactly T pulses to the dictionary data (changed in place).

    data[T] is the list [ID, SENS, FEAT, AGG, Y] of tensors:
        ID   (B,1)    long     event ids
        SENS (B,T)    long     sensor ids of the pulses
        FEAT (B,T,F)  float32  pulse features (columns of df between sensor_id and tot)
        AGG  (B,A)    float32  event features (columns of agg_df after the fifth)
        Y    (B,3)    float32  target vector nx, ny, nz
    New events are appended along the first dimension.

    df:     event_id	sensor_id	aux	q	t  tot   (T consecutive rows per event)
    agg_df: event_id, nx, ny, nz, tot, t_aver, ...., ux, uy, uz, qx, qy, qz
    """
    assert len(df) % T == 0,  f"wait len(df) = T*B, got len={len(df)}, T={T}"
    B = len(df) // T
    F = df.shape[-1] - 3                                   # without event_id, sensor_id, tot

    ID   = agg_df[['event_id']].to_numpy()
    Y    = agg_df[['nx','ny','nz']].to_numpy()
    AGG  = agg_df.iloc[:, 5:].to_numpy()
    SENS = df.sensor_id.to_numpy().reshape(B,T)
    FEAT = df.iloc[:, 2: -1].to_numpy().reshape(B,T,F)     # (B*T, F) -> (B, T, F), tot is dropped

    same_size = len(ID) == len(Y) == len(AGG) == len(SENS) == len(FEAT)
    assert same_size, \
           f"{ID.shape}, {Y.shape}, {AGG.shape}, {SENS.shape} {FEAT.shape} from df={df.shape} agg_df={agg_df.shape} (T={T},F={F})"

    fresh = [torch.tensor(ID,   dtype=torch.long   ),
             torch.tensor(SENS, dtype=torch.long   ),
             torch.tensor(FEAT, dtype=torch.float32),
             torch.tensor(AGG,  dtype=torch.float32),
             torch.tensor(Y,    dtype=torch.float32) ]

    if T not in data:                                      # first events of this length
        data[T] = fresh
        return

    stored = data[T]
    for slot, block in enumerate(fresh):                   # grow every tensor along dim 0
        stored[slot] = torch.cat((stored[slot], block), dim=0)
                       
#-------------------------------------------------------------------------------

def create_dataset(batch_ids, sensors_df, verbose):
    """
    Build the dataset from the given batches.

    batch_ids:  ids of the batches to process (their files are taken from get_files)
    sensors_df: sensor table returned by get_sensors()
    verbose:    show the statistics of the first batch (only when IS_AGG is True)
    return:     (data, events_df, cols_df, cols_agg_df)
                data        - dict {number of pulses T: [ID, SENS, FEAT, AGG, Y]}, see append_dict
                events_df   - empty DataFrame with the column event_id (nothing is added to it)
                cols_df     - columns of the pulse table of the last batch
                cols_agg_df - columns of the event table of the last batch
                both cols_* are None when batch_ids is empty
    """
    files, batch_ids = get_files(batch_ids)
    data, events_df  = {}, pd.DataFrame({'event_id': []})
    cols_df = cols_agg_df = None                 # defined even when there is nothing to process
    for i, (batch_id, fname) in tqdm(enumerate(zip(batch_ids, files)), total=len(files)):
        info(f"******  batch_id: {batch_id:3d}")
        df = pd.read_parquet(fname)

        df = prepare_batch(df)
        df = cut_pulses(df, max_pulses=T_MAX)

        df = df.merge(sensors_df, left_on="sensor_id", right_on="sensor_id", how="left")
        df = df[['event_id', 'line_id', 'sensor_id', 'core', 'aux', 'q', 't', 'x', 'y', 'z']]
        info(f"merged batch with sensors {df.shape}")

        target_df = get_target_angles(batch_id=batch_id)
        target_df = angles2vector(target_df).drop(columns=['azimuth','zenith'])
        info("loaded target angles")

        if LINES_FILTER:
            df, agg = lines_filter(df, lines_min = LINES_MIN, lines_max = LINES_MAX)
            target_df = target_df[target_df.event_id.isin(agg.event_id)]
            del agg
            info(f"lines filter done: {df.shape}")
            #if verbose and i == 0: display(df)

        if IS_AGG:
            agg_df = get_event_features(df, target_df, suf="", aux=False)
            if not DROP_AUX:
                if DOMS_AGG:  # after the aggregation some sensors have a non-integer aux (multiply by 1-aux)!
                    agg2_df = get_event_features(df, target_df, suf="_aux", aux=True)
                else:
                    agg2_df = get_event_features(df[ ~df.aux ].copy(), target_df, suf="_aux", aux=False)
                # NOTE(review): an event without aux==False pulses has no row in agg2_df
                # here, so this left merge leaves NaN in its *_aux columns - confirm that
                # this is handled downstream.
                agg_df = agg_df.merge(agg2_df, left_on='event_id',  right_on='event_id_aux', how='left')
                agg_df = agg_df.drop(columns = ['event_id_aux','tot_aux'] )

            agg_df = target_df.merge(agg_df, left_on="event_id", right_on="event_id", how="left")
            info('get_event_features done')
        else:
            # Without the aggregated features the number of pulses per event is
            # still needed below to group the events by sequence length.
            counts = df.groupby('event_id').size().rename('tot').reset_index()
            agg_df = target_df.merge(counts, left_on="event_id", right_on="event_id", how="left")

        df = get_pulse_features(df)
        df = df[['event_id', 'sensor_id', 'aux', 'q', 't']]  #   'core', 'x', 'y', 'z'

        info('get_pulse_features done')

        # attach the number of pulses of the event to each of its pulses
        df = df.merge(agg_df[['event_id', 'tot']], left_on='event_id', right_on='event_id', how='left')
        if IS_AGG and verbose and i == 0: show_stats(df, agg_df)

        tots = df.tot.unique()
        info(f"count pulses:  {tots.mean():.0f} [{tots.min()} ... {tots.max()}]")
        for n in tqdm(tots):
            # first pulse will be last (for RNN)
            d1 = df    [df.    tot == n].sort_values(['event_id','t'], ascending=[True,False])
            d2 = agg_df[agg_df.tot == n].sort_values(['event_id'])
            append_dict(data, n, d1, d2)
        cols_df, cols_agg_df = df.columns, agg_df.columns
        del df, agg_df
    info("collected data for dataset")

    return data, events_df.reset_index(drop=True), cols_df, cols_agg_df

#===============================================================================
#                                Diagnostic
#===============================================================================

def show_stats(df, agg_df):
    """
    Display information about the pulse table df and the event table agg_df:
    the first rows, the transposed describe() table and the info() summary.

    The pandas option display.float_format is set to two decimals and stays
    set after the call.
    """
    pd.set_option('display.float_format', lambda x: '%.2f' % x)
    for frame, rows in ((df, 5), (agg_df, 2)):
        display(frame.head(rows))
        display(frame.describe(percentiles=[]).transpose())
        # info() prints the summary itself and returns None, so it is not
        # wrapped in display() (that only added a stray "None" to the output)
        frame.info()

#-------------------------------------------------------------------------------

def plot_metric(err, prefix="", bins = 200):
    """
    Plot the histogram of the angular errors err (radians, range 0...pi).

    w is twice the share of the histogram that lies in the upper half of the
    range (the share of 'bad examples'). The red curve is w*0.5*sin(x), the blue
    curve is the histogram minus the red curve. The title shows the mean, the
    median and w.
    """
    plt.figure(figsize=(6,4), facecolor ='w')
    plt.axes().set_facecolor("ivory")
    plt.autoscale(tight=True)

    density, _, _ = plt.hist(err, bins=bins, range=(0,np.pi), fc="lightblue", density=True, alpha=0.5)
    upper_half = density[len(density)//2: ].sum()
    w = 2*upper_half*np.pi/bins

    x = np.linspace(0,np.pi,bins)
    background = w * 0.5*np.sin(x)
    plt.plot(x, background,         c="darkred")
    plt.plot(x, density-background, c="darkblue")

    plt.title(f"{prefix}mean={np.mean(err):.3f}, median={np.median(err):.3f}, w={w:.3f}")
    plt.ylabel("Density")
    plt.xlabel(r"$\Delta \Psi$ (rad)")
    plt.grid()
    plt.show()


## Prepare Data and save to disk

In [None]:
%%time
# restart the timers used by info()
info.beg = info.last = time.time()
info("begin")

if CREATE_DATASET:
    os.makedirs(DATASET_FOLDER, exist_ok=True)      # torch.save does not create a missing folder

    sensors_df = get_sensors()
    info(f"loaded sensors pos: tot={len(sensors_df)}")
    display(sensors_df.head(3))

    # sensor table as a tensor, saved together with its column names
    doms_cols = ['x','y','z','core','a','r']
    doms = torch.tensor(sensors_df[doms_cols].astype(np.float32).to_numpy())
    torch.save( { 'cols': doms_cols, 'data': doms }, f"{DATASET_FOLDER}/doms.pt")

    end_batch_id = FIRST_BATCH_ID + NUM_BATCHES     # first batch id that is NOT processed
    num_packs    = math.ceil(NUM_BATCHES / BATCHES_IN_PACK)
    for i,batch_id in tqdm(enumerate(range(FIRST_BATCH_ID, end_batch_id,  BATCHES_IN_PACK)), total=num_packs):
        pack_id = batch_id // BATCHES_IN_PACK + 1
        # The last pack is shorter when NUM_BATCHES is not a multiple of
        # BATCHES_IN_PACK: never go past the batches that were requested.
        pack_batches = range(batch_id, min(batch_id + BATCHES_IN_PACK, end_batch_id))
        data, _, cols_df, cols_agg_df = create_dataset(pack_batches, sensors_df, i==0)
        torch.save({'cols_df':      cols_df,
                    'cols_agg_df':  cols_agg_df,
                    'data': data },   f"{DATASET_FOLDER}/pack_{pack_id:02d}.pt")
        del data; gc.collect()                      # free the pack before building the next one
        info(f"created pack {pack_id:2d}")