In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sqlite3
import sqlalchemy
import gc
import warnings
# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides pandas warnings (e.g. about chained assignment) — confirm that is intended.
warnings.filterwarnings('ignore')

def prepare_dataframe(df, angle_post_fix = '_pred', vec_post_fix = '') -> pd.DataFrame:
    """Derive zenith/azimuth angle columns from Cartesian direction components.

    Reads the columns ``direction_x``, ``direction_y`` and ``direction_z`` (each
    with ``vec_post_fix`` appended) and writes two columns back onto ``df``:
    ``'zenith' + angle_post_fix`` and ``'azimuth' + angle_post_fix``.

    Args:
        df: Frame containing the three direction columns and a ``kappa`` column.
            It is modified in place: the two angle columns are added or overwritten.
        angle_post_fix: Suffix appended to the generated angle column names.
        vec_post_fix: Suffix of the direction columns to read.

    Returns:
        A frame with the columns ``kappa``, ``azimuth<angle_post_fix>`` and
        ``zenith<angle_post_fix>``, in that order. Azimuth values that come out
        negative from ``arctan2`` are shifted up by ``2*pi``. A zero-length
        direction vector yields NaN zenith (0/0 inside ``arccos``).
    """
    x_col = 'direction_x' + vec_post_fix
    y_col = 'direction_y' + vec_post_fix
    z_col = 'direction_z' + vec_post_fix
    zenith_col = 'zenith' + angle_post_fix
    azimuth_col = 'azimuth' + angle_post_fix

    # Length of the direction vector; dividing by it normalises z before arccos.
    r = np.sqrt(df[x_col]**2 + df[y_col]**2 + df[z_col]**2)
    df[zenith_col] = np.arccos(df[z_col]/r)

    # Wrap negative angles on the Series itself and assign the column once.
    # The previous `df[col][mask] = ...` form is a chained-indexing write that
    # pandas does not guarantee to propagate back into `df`.
    azimuth = np.arctan2(df[y_col], df[x_col])
    df[azimuth_col] = azimuth.mask(azimuth < 0, azimuth + 2*np.pi)

    # Select the columns that were just written, so a non-default
    # angle_post_fix is honoured instead of always returning the '_pred' names.
    return df[['kappa', azimuth_col, zenith_col]]

import numpy as np


def angular_dist_score(az_true, zen_true, az_pred, zen_pred):
    """Return the angle (radians) between a true and a predicted direction.

    Each direction is given as an (azimuth, zenith) pair in radians. Inputs may
    be scalars or array-likes accepted by the numpy ufuncs used below; the
    result follows numpy's element-wise semantics.

    The cosine of the opening angle is evaluated as
    ``sin(zen_t)*sin(zen_p)*(cos(az_t)*cos(az_p) + sin(az_t)*sin(az_p))
    + cos(zen_t)*cos(zen_p)`` and clipped to [-1, 1] before ``arccos`` so that
    round-off cannot push it outside the function's domain.
    """
    # Trigonometric components of the true direction.
    sin_az_t, cos_az_t = np.sin(az_true), np.cos(az_true)
    sin_zen_t, cos_zen_t = np.sin(zen_true), np.cos(zen_true)

    # Trigonometric components of the predicted direction.
    sin_az_p, cos_az_p = np.sin(az_pred), np.cos(az_pred)
    sin_zen_p, cos_zen_p = np.sin(zen_pred), np.cos(zen_pred)

    # Dot product of the two unit vectors, split into azimuthal and polar parts.
    azimuthal_term = cos_az_t*cos_az_p + sin_az_t*sin_az_p
    cos_angle = sin_zen_t*sin_zen_p*azimuthal_term + (cos_zen_t*cos_zen_p)
    cos_angle = np.clip(cos_angle, -1, 1)

    return np.abs(np.arccos(cos_angle))

In [None]:
# Batch identifier; interpolated into the 'inference/pred_M*_B{b_id}.pkl' input paths
# and the 'process_events/pred_B{b_id}.pkl' output path used by later cells.
b_id = 5

In [None]:
# Load the true azimuth/zenith per event from the batch's SQLite file.
# The path is built from b_id so the truth table always belongs to the same
# batch as the predictions read later; for b_id = 5 it resolves to
# 'data/B5/extra_big_batch_5.db', the path that was previously hard-coded.
database_path = f'data/B{b_id}/extra_big_batch_{b_id}.db'
# SQLAlchemy engine for the same file. It is not used in the cells visible here;
# kept in case other cells rely on it.
engine = sqlalchemy.create_engine("sqlite:///" + database_path)

query = 'SELECT event_id, azimuth, zenith from meta_table'
# sqlite3's connection context manager only commits or rolls back the
# transaction and leaves the connection open, so close it explicitly.
con = sqlite3.connect(database_path)
try:
    # Index by event_id so the frame can be joined onto the predictions by index.
    df_database = pd.read_sql(query,con).set_index('event_id')
finally:
    con.close()

df_database.shape

In [None]:
# Build one wide table for batch b_id: for each model in m_list, add that model's
# kappa and its angular error against the truth in df_database as two columns.
m_list = [0,1,2,3]
# m_list.remove(b_id)
df_all = pd.DataFrame()
for m_id in m_list:
    # Predictions of model m_id for this batch. 'direction_kappa' is exposed as
    # 'kappa', the column name prepare_dataframe selects.
    pred_raw = pd.read_pickle(f'inference/pred_M{m_id}_B{b_id}.pkl').rename(
        columns={'direction_kappa': 'kappa'}
    )
    pred_angles = prepare_dataframe(pred_raw)

    # Keep only rows whose index appears in both the predictions and the truth table.
    # presumably the pickled frame is indexed by event_id like df_database — TODO confirm
    scored = pred_angles.join(df_database, how = 'inner')
    scored = scored.assign(
        error=angular_dist_score(
            scored['azimuth'], scored['zenith'], scored['azimuth_pred'], scored['zenith_pred']
        )
    )
    # Tag the two columns of interest with the model id so they stay distinct in df_all.
    scored = scored.rename(columns={'kappa': f'kappa_m{m_id}', 'error': f'error_m{m_id}'})

    # Outer join keeps rows that only some of the models have scored.
    df_all = df_all.join(scored[[f'kappa_m{m_id}', f'error_m{m_id}']], how='outer')

# Drop the intermediates left over from the last iteration and reclaim their memory.
del scored, pred_angles, pred_raw
gc.collect()

In [None]:
# Persist the combined per-model table for this batch.
# NOTE(review): nothing in the visible cells creates 'process_events/' — confirm the directory exists.
df_all.to_pickle(f'process_events/pred_B{b_id}.pkl')

In [None]:
# Show the combined table as this cell's output.
df_all