# 2nd-stage training

* In the 2nd stage, the training data is divided into two parts, easy and hard, based on the prediction results of the 1st stage.
* The model is fine-tuned from the 1st-stage weights.
* Part A: train on the easy part.

In [None]:
import os

import wandb

# Read the W&B API key from the environment instead of hard-coding a secret in
# the notebook. The literal placeholder is kept only as a last-resort fallback.
# (On Kaggle the key could come from user_secrets.get_secret("wandb_api").)
wandb_api = os.environ.get("WANDB_API_KEY", "xxxxx")
wandb.login(key=wandb_api)

from pytorch_lightning.loggers import WandbLogger

# Run name: used for the W&B run and in checkpoint / state-dict file names.
NAME = 'split02-1'

wandb_logger = WandbLogger(
    project="ICECUBE",
    entity="yamsam",
    name=NAME,
    group='graphnet-tune',
    log_model='all',
)

In [None]:
# Training database (presumably batches 301-350, per the file name) and the
# database used for inference / evaluation below.
train_db = '../../../input/batch_301-350.db'
valid_db = '../../../input/dynedge-pretrained/batch_51.db'

error_file = '../split01/result_300-350.csv' # 1st-stage predictions incl. direction_kappa; used to split events into easy/hard
retrain_dict = '../split01/last_state_dict_split01_first.pth'  # 1st-stage weights to fine-tune from
first_half=True  # True: train on the "easy" subset; False: train on the "hard" subset

# Make the bundled graphnet sources importable (this is sys.path, not the shell PATH)
import sys
sys.path.append('../../../input/graphnet/src')

In [None]:
import graphnet

In [None]:
import pyarrow.parquet as pq
import sqlite3
import pandas as pd
import sqlalchemy
from tqdm import tqdm
import os
from typing import Any, Dict, List, Optional
import numpy as np

from graphnet.data.sqlite.sqlite_utilities import create_table

def load_input(meta_batch: pd.DataFrame,
               input_data_folder: str,
               geometry: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Load the detector pulses that belong to one meta-data batch.

    Reads ``{input_data_folder}/batch_{batch_id}.parquet`` and attaches the
    sensor coordinates ``x``, ``y``, ``z``, looked up by ``sensor_id``.

    Args:
        meta_batch: Meta-data rows. All rows must share a single ``batch_id``.
        input_data_folder: Folder containing the ``batch_*.parquet`` files.
        geometry: Sensor geometry table indexed by ``sensor_id`` with columns
            ``x``, ``y``, ``z``. If omitted, a module-level ``geometry_table``
            is used (the original behaviour).

    Returns:
        The pulses with geometry columns added, ``auxiliary`` mapped from
        True/False to 1/0, and the original index reset into a column.

    Raises:
        AssertionError: If ``meta_batch`` contains more than one ``batch_id``.
        ValueError: If no ``geometry`` is given and no global
            ``geometry_table`` exists.
    """
    batch_id = pd.unique(meta_batch['batch_id'])
    assert len(batch_id) == 1, "contains multiple batch_ids. Did you set the batch_size correctly?"

    if geometry is None:
        # Backward-compatible fallback to a global table. It is not defined in
        # the visible notebook, so report the problem clearly instead of
        # failing with a bare NameError.
        try:
            geometry = geometry_table  # noqa: F821
        except NameError as err:
            raise ValueError(
                "No sensor geometry available: pass `geometry` or define a "
                "module-level `geometry_table`."
            ) from err

    detector_readings = pd.read_parquet(path=f'{input_data_folder}/batch_{batch_id[0]}.parquet')
    sensor_positions = geometry.loc[detector_readings['sensor_id'], ['x', 'y', 'z']]
    # .loc returns rows labelled by sensor_id; realign them to the pulse rows.
    sensor_positions.index = detector_readings.index

    for column in sensor_positions.columns:
        # Keep any coordinate columns already present in the parquet file.
        if column not in detector_readings.columns:
            detector_readings[column] = sensor_positions[column]

    detector_readings['auxiliary'] = detector_readings['auxiliary'].replace({True: 1, False: 0})
    return detector_readings.reset_index()

def add_to_table(database_path: str,
                      df: pd.DataFrame,
                      table_name:  str,
                      is_primary_key: bool,
                      ) -> None:
    """Append a dataframe to an SQLite table, creating the table first if needed.

    Args:
        database_path (str): Path to the SQLite database file.
        df (pd.DataFrame): Rows to append. Its columns define the table schema
            when the table is created.
        table_name (str): Target table, e.g. 'meta_table' or 'pulse_table'.
        is_primary_key (bool): Passed to ``create_table`` as
            ``integer_primary_key``. Set True only when every row of ``df`` has
            a unique ``event_id``.

    Raises:
        sqlite3.OperationalError: For any table-creation error other than
            "table already exists".
    """
    try:
        create_table(columns=df.columns,
                     database_path=database_path,
                     table_name=table_name,
                     integer_primary_key=is_primary_key,
                     index_column='event_id')
    except sqlite3.OperationalError as e:
        # The table already existing is expected: convert_to_sqlite appends
        # every batch to the same tables.
        if 'already exists' not in str(e):
            raise
    engine = sqlalchemy.create_engine("sqlite:///" + database_path)
    try:
        df.to_sql(table_name, con=engine, index=False, if_exists="append", chunksize=200000)
    finally:
        # Release pooled connections even if the insert fails.
        engine.dispose()

def convert_to_sqlite(meta_data_path: str,
                      database_path: str,
                      input_data_folder: str,
                      batch_size: int = 200000,
                      batch_ids: Optional[List[int]] = None,) -> None:
    """Convert a selection of the competition's parquet files into one SQLite database.

    The meta data is read in chunks of ``batch_size`` rows, and chunk *i*
    (1-based) is treated as batch *i*. Each selected chunk is appended to
    ``meta_table``, and its pulses (loaded via ``load_input``) are appended
    to ``pulse_table``.

    Args:
        meta_data_path (str): Path to the meta data parquet file.
        database_path (str): Path to the database, e.g.
            '/my_folder/data/my_new_database.db'. '.db' is appended if missing.
        input_data_folder (str): Folder containing the parquet pulse files.
        batch_size (int): Number of meta data rows read at a time. Each chunk
            must contain a single batch_id (``load_input`` asserts this).
        batch_ids (List[int]): Batch ids to convert. Defaults to None, which
            converts batches 1-660.
    """
    if batch_ids is None:
        # ndarray provides ``tolist()``; the former ``.to_list()`` raised AttributeError.
        batch_ids = np.arange(1, 661, 1).tolist()
    else:
        assert isinstance(batch_ids, list), "Variable 'batch_ids' must be list."
    if not database_path.endswith('.db'):
        database_path = database_path + '.db'

    wanted = set(batch_ids)  # O(1) membership; duplicate ids are counted once
    meta_data_iter = pq.ParquetFile(meta_data_path).iter_batches(batch_size=batch_size)
    converted_batches = []
    progress_bar = tqdm(total=len(wanted))
    try:
        for batch_id, meta_data_batch in enumerate(meta_data_iter, start=1):
            if batch_id in wanted:
                meta_data_batch = meta_data_batch.to_pandas()
                add_to_table(database_path=database_path,
                             df=meta_data_batch,
                             table_name='meta_table',
                             is_primary_key=True)
                pulses = load_input(meta_batch=meta_data_batch, input_data_folder=input_data_folder)
                del meta_data_batch  # memory
                add_to_table(database_path=database_path,
                             df=pulses,
                             table_name='pulse_table',
                             is_primary_key=False)
                del pulses  # memory
                progress_bar.update(1)
                converted_batches.append(batch_id)
            # Stop reading the parquet file once every requested batch is done.
            if len(converted_batches) == len(wanted):
                break
    finally:
        progress_bar.close()
    del meta_data_iter  # memory
    print(f'Conversion Complete!. Database available at\n {database_path}')

## Defining A Selection

In [None]:
from sklearn.model_selection import train_test_split

def make_selection(df: pd.DataFrame, pulse_threshold: int = 200) -> None:
    """Split ``df`` 80/20 into train/validation selections and write both to CSV.

    Adds 0/1 columns ``train`` and ``validate`` to ``df`` in place (exactly one
    of them is 1 on each row). The selected rows are written to
    ``train_selection_max_{pulse_threshold}_pulses.csv`` and
    ``validate_selection_max_{pulse_threshold}_pulses.csv`` in the current
    directory. The split is reproducible (``random_state=42``).

    Note: ``pulse_threshold`` only appears in the output file names. This
    function does not filter events by ``n_pulses``.

    Args:
        df: Events to split. Rows are selected by position.
        pulse_threshold: Value embedded in the output file names.
    """
    n_events = np.arange(0, len(df), 1)
    train_selection, validate_selection = train_test_split(n_events,
                                                           shuffle=True,
                                                           random_state=42,
                                                           test_size=0.20)
    # Build the flags positionally in numpy and assign whole columns at once.
    # The former chained assignment (df['train'][idx] = 1) writes into a
    # temporary under pandas Copy-on-Write, leaving the flags at 0.
    train_flag = np.zeros(len(df), dtype=np.int64)
    train_flag[train_selection] = 1
    validate_flag = np.zeros(len(df), dtype=np.int64)
    validate_flag[validate_selection] = 1
    df['train'] = train_flag
    df['validate'] = validate_flag

    assert len(train_selection) == sum(df['train'])
    assert len(validate_selection) == sum(df['validate'])

    for selection in ['train', 'validate']:
        df.loc[df[selection] == 1, :].to_csv(f'{selection}_selection_max_{pulse_threshold}_pulses.csv')

def get_number_of_pulses(db: str, event_id: int, pulsemap: str) -> int:
    """Return the number of rows in table ``pulsemap`` for ``event_id``, capped at 20000.

    Args:
        db: Path to the SQLite database.
        event_id: Event whose pulses are counted.
        pulsemap: Name of the pulse table, e.g. 'pulse_table'.

    Returns:
        The pulse count, never larger than 20000 (the query's LIMIT).
    """
    from contextlib import closing

    # Table names cannot be bound as parameters, so quote the identifier
    # (doubling any embedded quotes) and bind the event id. int() also accepts
    # numpy integers, which sqlite3 cannot bind directly.
    table = '"' + pulsemap.replace('"', '""') + '"'
    query = f'SELECT COUNT(*) FROM (SELECT 1 FROM {table} WHERE event_id = ? LIMIT 20000)'
    # sqlite3's own context manager only commits or rolls back;
    # closing() actually closes the connection.
    with closing(sqlite3.connect(db)) as con:
        (count,) = con.execute(query, (int(event_id),)).fetchone()
    return count

def count_pulses(database: str, pulsemap: str) -> pd.DataFrame:
    """Count the pulses of every event listed in ``meta_table``.

    Runs one ``get_number_of_pulses`` query per event, so each count is capped
    at that helper's limit. The result is also written to ``counts.csv`` in
    the current directory.

    Args:
        database: Path to the SQLite database containing ``meta_table`` and
            the pulse table.
        pulsemap: Name of the pulse table.

    Returns:
        DataFrame with columns ``event_id`` and ``n_pulses``, in the order
        SQLite returns the ``meta_table`` rows.
    """
    from contextlib import closing

    # closing(): sqlite3's context manager does not close the connection itself.
    with closing(sqlite3.connect(database)) as con:
        events = pd.read_sql('select event_id from meta_table', con)
    event_ids = list(events['event_id'])
    n_pulses = [get_number_of_pulses(database, event_id, pulsemap) for event_id in tqdm(event_ids)]
    df = pd.DataFrame({'event_id': event_ids, 'n_pulses': n_pulses})
    df.to_csv('counts.csv')
    return df

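In the 2nd stage, the training data is divided into two parts, easy and hard, based on the prediction results of the 1st stage. An event is "easy" (first half) when the 1st-stage uncertainty estimate $1/\sqrt{\kappa}$ (from `direction_kappa`) is at most 0.5, and "hard" otherwise.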

In [None]:
def make_selection_half(df: pd.DataFrame, postfix) -> None:
    """Split one easy/hard subset 80/20 into train/validation selections and write them to CSV.

    Adds 0/1 columns ``train`` and ``validate`` to ``df`` in place (exactly one
    of them is 1 on each row). The selected rows are written to
    ``train_selection_{postfix}_pulses.csv`` and
    ``validate_selection_{postfix}_pulses.csv`` in the current directory.
    The split is reproducible (``random_state=42``).

    Args:
        df: Events of one subset. Rows are selected by position.
        postfix: Subset label used in the file names, e.g. 'first' or 'second'.
    """
    n_events = np.arange(0, len(df), 1)
    train_selection, validate_selection = train_test_split(n_events,
                                                           shuffle=True,
                                                           random_state=42,
                                                           test_size=0.20)
    # Build the flags positionally in numpy and assign whole columns at once.
    # The former chained assignment (df['train'][idx] = 1) writes into a
    # temporary under pandas Copy-on-Write, leaving the flags at 0.
    train_flag = np.zeros(len(df), dtype=np.int64)
    train_flag[train_selection] = 1
    validate_flag = np.zeros(len(df), dtype=np.int64)
    validate_flag[validate_selection] = 1
    df['train'] = train_flag
    df['validate'] = validate_flag

    assert len(train_selection) == sum(df['train'])
    assert len(validate_selection) == sum(df['validate'])

    for selection in ['train', 'validate']:
        df.loc[df[selection] == 1, :].to_csv(f'{selection}_selection_{postfix}_pulses.csv')

pulsemap = 'pulse_table'
database = train_db

# Per-event pulse counts of the training database (also saved to counts.csv).
df = count_pulses(database, pulsemap)

# 1st-stage predictions. An event is "easy" (first half) when the predicted
# uncertainty 1/sqrt(kappa) is at most 0.5.
edf = pd.read_csv(error_file)
edf['event_id'] = edf['event_id'].astype(int)
edf['first_half'] = 1 / np.sqrt(edf['direction_kappa']) <= 0.5

# merge defaults to an inner join: events missing from error_file are dropped.
df = df.merge(edf[['event_id', 'first_half']], on='event_id')
df1, df2 = (df[df.first_half == flag].reset_index(drop=True) for flag in (1, 0))
print(df1.shape, df2.shape)

# Write train/validate selection CSVs for each half.
make_selection_half(df1, 'first')
make_selection_half(df2, 'second')

## Training DynEdge

In [None]:
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor

from torch.optim.adam import Adam
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import StandardModel
from graphnet.models.detector.icecube import IceCubeKaggle
from graphnet.models.gnn import DynEdge
from graphnet.models.graph_builders import KNNGraphBuilder
from graphnet.models.task.reconstruction import DirectionReconstructionWithKappa, ZenithReconstructionWithKappa, AzimuthReconstructionWithKappa
from graphnet.training.callbacks import ProgressBar, PiecewiseLinearLR
from graphnet.training.loss_functions import VonMisesFisher3DLoss, VonMisesFisher2DLoss
from graphnet.training.labels import Direction
from graphnet.training.utils import make_dataloader
from graphnet.utilities.logging import Logger
from pytorch_lightning import Trainer
import pandas as pd

# graphnet logger; the training and inference functions below use it to log the configuration and output paths.
logger = Logger()

def build_model(config: Dict[str,Any], train_dataloader: Any) -> StandardModel:
    """Build a DynEdge direction-reconstruction model from ``config``.

    Args:
        config: Must provide ``target`` (only ``'direction'`` is supported) and
            ``fit['max_epochs']``, which sets the length of the LR schedule.
        train_dataloader: Training dataloader. Only its ``len()`` is used, to
            place the learning-rate milestones in optimizer steps.

    Returns:
        A ``StandardModel`` with ``prediction_columns`` and
        ``additional_attributes`` set for ``predict_as_dataframe``.

    Raises:
        ValueError: If ``config['target']`` is not ``'direction'``.
    """
    target = config["target"]
    if target != 'direction':
        # Fail early with a clear message. Previously this fell through to an
        # UnboundLocalError on `task`.
        raise ValueError(f"Unsupported target {target!r}; only 'direction' is implemented.")

    detector = IceCubeKaggle(
        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
    )
    gnn = DynEdge(
        nb_inputs=detector.nb_outputs,
        global_pooling_schemes=["min", "max", "mean"],
    )
    task = DirectionReconstructionWithKappa(
        hidden_size=gnn.nb_outputs,
        target_labels=target,
        loss_function=VonMisesFisher3DLoss(),
    )
    prediction_columns = [f"{target}_x", f"{target}_y", f"{target}_z", f"{target}_kappa"]
    additional_attributes = ['zenith', 'azimuth', 'event_id']

    steps_per_epoch = len(train_dataloader)
    model = StandardModel(
        detector=detector,
        gnn=gnn,
        tasks=[task],
        optimizer_class=Adam,
        # Learning rates 1e-05 and 1e-06 (same eps) were also tried.
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
        # LR factor goes 0.1 -> 1 over the first half epoch, then 1 -> 0.01 by
        # the end of training (piecewise-linear in optimizer steps).
        scheduler_kwargs={
            "milestones": [
                0,
                steps_per_epoch / 2,
                steps_per_epoch * config["fit"]["max_epochs"],
            ],
            "factors": [1e-01, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )
    model.prediction_columns = prediction_columns
    model.additional_attributes = additional_attributes

    return model

def load_pretrained_model(config: Dict[str,Any], state_dict_path: str = '../../../input/dynedge-pretrained/dynedge_pretrained_batch_1_to_50/state_dict.pth') -> StandardModel:
    """Rebuild the model architecture and load saved weights into it.

    The dataloaders are built only because ``build_model`` needs the training
    loader's length to lay out the learning-rate schedule.

    Args:
        config: Training configuration (see ``make_dataloaders`` / ``build_model``).
        state_dict_path: Weights file passed to the model's ``load_state_dict``.

    Returns:
        The model with weights loaded and prediction metadata set.
    """
    loader_for_schedule, _unused_validation = make_dataloaders(config=config)
    model = build_model(config=config, train_dataloader=loader_for_schedule)
    model.load_state_dict(state_dict_path)

    target = config["target"]
    model.prediction_columns = [target + suffix for suffix in ("_x", "_y", "_z", "_kappa")]
    model.additional_attributes = ['zenith', 'azimuth', 'event_id']
    return model

def make_dataloaders(config: Dict[str, Any]) -> List[Any]:
    """Build the training and validation dataloaders described by ``config``.

    Both loaders read ``config['path']``, restricted to the ids in column
    ``config['index_column']`` of the CSV files ``config['train_selection']``
    and ``config['validate_selection']``. Feature and truth columns come from
    the module-level ``features`` and ``truth``.

    Args:
        config: Must provide ``path``, ``pulsemap``, ``batch_size``,
            ``num_workers``, ``index_column``, ``truth_table``,
            ``sample_limit``, ``train_selection`` and ``validate_selection``.
            Optional ``shuffle_train`` (default False) shuffles the training
            loader.

    Returns:
        ``(train_dataloader, validate_dataloader)`` as a tuple.
    """
    def _loader(selection_csv: str, shuffle: bool) -> Any:
        # Only the index column of the selection CSV is used.
        selection = pd.read_csv(selection_csv)[config['index_column']].tolist()
        return make_dataloader(db=config['path'],
                               selection=selection,
                               pulsemaps=config['pulsemap'],
                               features=features,
                               truth=truth,
                               batch_size=config['batch_size'],
                               num_workers=config['num_workers'],
                               shuffle=shuffle,
                               labels={'direction': Direction()},
                               index_column=config['index_column'],
                               truth_table=config['truth_table'],
                               sample_limit=config['sample_limit'])

    # NOTE: the training loader is not shuffled by default, to match earlier
    # runs. The selection CSVs keep the source dataframe's row order, so
    # consider setting config['shuffle_train'] = True.
    train_dataloader = _loader(config['train_selection'], bool(config.get('shuffle_train', False)))
    validate_dataloader = _loader(config['validate_selection'], False)
    return train_dataloader, validate_dataloader

def train_dynedge_from_scratch(config: Dict[str, Any]) -> StandardModel:
    """Build a fresh DynEdge model and train it according to ``config``.

    Training stops early on ``val_loss`` (patience
    ``config['early_stopping_patience']``). The best and the last weights are
    checkpointed in the working directory (best as ``best_loss_{NAME}``), and
    metrics are logged to the module-level ``wandb_logger``.

    Args:
        config: Same keys as ``make_dataloaders`` / ``build_model``, plus
            ``early_stopping_patience``, ``fit`` (extra ``model.fit`` kwargs),
            and ``features`` / ``truth`` (logged only).

    Returns:
        The trained model.
    """
    logger.info(f"features: {config['features']}")
    logger.info(f"truth: {config['truth']}")

    train_dataloader, validate_dataloader = make_dataloaders(config=config)
    model = build_model(config, train_dataloader)

    loss_checkpoint = ModelCheckpoint(
        dirpath='./',
        filename=f"best_loss_{NAME}",
        monitor="val_loss",
        save_last=True,
        save_top_k=1,
        save_weights_only=True,
        mode="min",
    )
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=config["early_stopping_patience"],
        ),
        loss_checkpoint,
    ]

    model.fit(
        train_dataloader,
        validate_dataloader,
        callbacks=callbacks,
        **config["fit"],
        logger=wandb_logger
    )
    return model

def inference(model, config: Dict[str, Any]) -> pd.DataFrame:
    """Run ``model`` on the whole of ``config['inference_database_path']`` and save the predictions.

    The results are written to
    ``{base_dir}/train_model_without_configs/{db_name}/dynedge_{target}_{run_name_tag}/results.csv``.
    ``db_name`` is taken from ``config['path']`` (the training database), not
    from the inference database.

    Args:
        model: Model with ``prediction_columns`` and ``additional_attributes`` set.
        config: Needs ``inference_database_path``, ``path``, ``pulsemap``,
            ``batch_size``, ``num_workers``, ``index_column``, ``truth_table``,
            ``base_dir``, ``target`` and ``run_name_tag``.

    Returns:
        The dataframe returned by ``model.predict_as_dataframe``.
    """
    loader = make_dataloader(
        db=config['inference_database_path'],
        selection=None,  # entire database
        pulsemaps=config['pulsemap'],
        features=features,
        truth=truth,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        shuffle=False,
        labels={'direction': Direction()},
        index_column=config['index_column'],
        truth_table=config['truth_table'],
    )

    results = model.predict_as_dataframe(
        gpus=[0],
        dataloader=loader,
        prediction_columns=model.prediction_columns,
        additional_attributes=model.additional_attributes,
    )

    db_name = config['path'].split("/")[-1].split(".")[0]
    path = os.path.join(
        config['base_dir'],
        "train_model_without_configs",
        db_name,
        f"dynedge_{config['target']}_{config['run_name_tag']}",
    )
    logger.info(f"Writing results to {path}")
    os.makedirs(path, exist_ok=True)

    results.to_csv(f"{path}/results.csv")
    return results

class MyProgressBar(TQDMProgressBar):
    """TQDM progress bar that disables the validation, predict and test bars when stdout is not a TTY.

    The training bar is unchanged. This keeps redirected output (e.g. a log
    file) free of the nested evaluation bars.
    """

    @staticmethod
    def _quiet_unless_tty(bar):
        """Disable ``bar`` when stdout is not an interactive terminal, then return it."""
        if not sys.stdout.isatty():
            bar.disable = True
        return bar

    def init_validation_tqdm(self):
        return self._quiet_unless_tty(super().init_validation_tqdm())

    def init_predict_tqdm(self):
        return self._quiet_unless_tty(super().init_predict_tqdm())

    def init_test_tqdm(self):
        return self._quiet_unless_tty(super().init_test_tqdm())

    
def train_dynedge_restart(model, config: Dict[str, Any]) -> StandardModel:
    """Continue training an already-built (e.g. pre-trained) model according to ``config``.

    Training stops early on ``val_loss`` (patience
    ``config['early_stopping_patience']``). The learning rate is logged every
    step. The best and the last weights are checkpointed in the working
    directory (best as ``best_loss_{NAME}``), and metrics go to the
    module-level ``wandb_logger``.

    Args:
        model: Model to fine-tune, e.g. from ``load_pretrained_model``.
        config: Same keys as ``make_dataloaders``, plus
            ``early_stopping_patience``, ``fit`` (extra ``model.fit`` kwargs),
            and ``features`` / ``truth`` (logged only).

    Returns:
        The same model object after training.
    """
    logger.info(f"features: {config['features']}")
    logger.info(f"truth: {config['truth']}")

    train_dataloader, validate_dataloader = make_dataloaders(config=config)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    loss_checkpoint = ModelCheckpoint(
        dirpath='./',
        filename=f"best_loss_{NAME}",
        monitor="val_loss",
        save_last=True,
        save_top_k=1,
        save_weights_only=True,
        mode="min",
    )
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=config["early_stopping_patience"],
        ),
        MyProgressBar(),
        lr_monitor,
        loss_checkpoint,
    ]

    model.fit(
        train_dataloader,
        validate_dataloader,
        callbacks=callbacks,
        **config["fit"],
        logger=wandb_logger
    )
    return model


In [None]:
# Feature and truth column sets for the Kaggle IceCube data
features = FEATURES.KAGGLE
truth = TRUTH.KAGGLE

# Training / inference configuration
config = {
    "path": train_db,
    "inference_database_path": valid_db,
    "pulsemap": 'pulse_table',
    "truth_table": 'meta_table',
    "features": features,
    "truth": truth,
    "index_column": 'event_id',
    "run_name_tag": 'my_example',
    "batch_size": 512,
    "sample_limit": 1000,
    "num_workers": 8,
    "target": 'direction',
    "early_stopping_patience": 5,
    "fit": {
        "max_epochs": 20,
        "gpus": [0],
        "distribution_strategy": None,
    },
    'test_selection': None,
    'base_dir': 'training',
}

# Use the selection CSVs of the easy ("first") or hard ("second") half.
config.update(
    (f'{split}_selection', f"{split}_selection_{'first' if first_half else 'second'}_pulses.csv")
    for split in ('train', 'validate')
)
 

In [None]:
# Train from scratch (slow) - remember to save it!
#model = train_dynedge_from_scratch(config = config)

# Fine-tune: load the 1st-stage weights (retrain_dict) into a freshly built
# model, then continue training on the selected half.
model = load_pretrained_model(config, retrain_dict)
model = train_dynedge_restart(model, config = config)

In [None]:
import torch

# Save the final weights; the file name encodes which half was trained.
torch.save(model.state_dict(), f"last_state_dict_{NAME}_{'first' if first_half else 'second'}.pth")


## Inference & Evaluation

With the trained weights saved, we reload them and apply the model to batch_51 (`valid_db`). The following cells run inference, compute the angular error for each event, and plot its distribution.

In [None]:
# Inference: release the training model and cached GPU memory, then reload the saved weights.
del model
import gc
gc.collect()
torch.cuda.empty_cache()

path_name = f"last_state_dict_{NAME}_{'first' if first_half else 'second'}.pth"
model = load_pretrained_model(config=config, state_dict_path=path_name)

# Smaller batches for the inference dataloader.
config["batch_size"] = 32
results = inference(model, config)

In [None]:
def convert_to_3d(df: pd.DataFrame) -> pd.DataFrame:
    """Add unit direction vectors ``true_x``, ``true_y``, ``true_z`` computed from ``zenith`` and ``azimuth``.

    ``df`` is modified in place and also returned.
    """
    zenith, azimuth = df['zenith'], df['azimuth']
    sin_zenith = np.sin(zenith)
    df['true_x'] = np.cos(azimuth) * sin_zenith
    df['true_y'] = np.sin(azimuth) * sin_zenith
    df['true_z'] = np.cos(zenith)
    return df

def calculate_angular_error(df : pd.DataFrame) -> pd.DataFrame:
    """Add ``angular_error``: the opening angle in radians between true and predicted directions.

    Expects ``true_x``, ``true_y``, ``true_z`` (e.g. from ``convert_to_3d``)
    and ``direction_x``, ``direction_y``, ``direction_z``, both treated as
    unit vectors. ``df`` is modified in place and also returned.
    """
    cos_angle = (df['true_x'] * df['direction_x']
                 + df['true_y'] * df['direction_y']
                 + df['true_z'] * df['direction_z'])
    # Rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1], where arccos returns NaN. Clip so those events score ~0 or ~pi.
    df['angular_error'] = np.arccos(np.clip(cos_angle, -1.0, 1.0))
    return df

In [None]:
# Add true direction vectors and the per-event angular error, then save.
# first_half is a bool, so the file is result_True.csv or result_False.csv.
results = calculate_angular_error(convert_to_3d(results))
results.to_csv(f'result_{first_half}.csv', index=False)

In [None]:
# Mean angular error over all inference events (lower is better).
score = results["angular_error"].mean()
print('score=', score)

In [None]:
from matplotlib import pyplot as plt

fig = plt.figure(figsize=(6, 6))
# The opening angle lies in [0, pi]; bins beyond pi (up to 2*pi) were always empty.
plt.hist(results['angular_error'],
         bins=np.arange(0, np.pi + 0.05, 0.05),
         histtype='step',
         label=f'mean angular error: {np.round(results["angular_error"].mean(), 2)}')
plt.xlabel('Angular Error [rad.]', size=15)
plt.ylabel('Counts', size=15)
plt.title('Angular Error Distribution (Batch 51)', size=15)
plt.legend(frameon=False, fontsize=15)

In [None]:
# Finish the W&B run used by wandb_logger so its logs and artifacts are finalized.
wandb.finish()