In [None]:
"""
Be sure to run the demo/grb_data_exploration.ipynb notebook first to get the preprocessed Fermi data.
"""
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import CosineSimilarity
from tensorflow.linalg import norm
from tensorflow.math import reduce_mean

from gw_grb_correlation.Fermi.data_preprocessing import create_dataframe_and_name_column_from_data_files
from gw_grb_correlation.Fermi.util import process_data, cosine_similarity_loss

from gw_grb_correlation.Fermi.visualization import evaluate_model_and_plot_accurracy
from gw_grb_correlation.Fermi.util import convert_cartesian_to_spherical

In [None]:
# ---- Load the preprocessed Fermi data ----
fermi_data = create_dataframe_and_name_column_from_data_files(data_type='fermi')

# ---- Detector names: n0 ... n9, then na, nb, b0, b1 ----
detectors = [*(f"n{idx}" for idx in range(10)), "na", "nb", "b0", "b1"]

# ---- Model input column groups ----
PH_CNT_columns = [detector + "_PH_CNT" for detector in detectors]
TRIG_columns = [detector + "_TRIG" for detector in detectors]
Orientation_columns = ['QSJ_1', 'QSJ_2', 'QSJ_3', 'QSJ_4']

# Multiply each detector's photon count by its trigger column (element-wise),
# overwriting the PH_CNT columns. Presumably TRIG is a 0/1 flag, so this keeps
# counts only from triggered detectors -- TODO confirm against the preprocessing.
# Re-running this cell is safe because the data are reloaded above.
fermi_data[PH_CNT_columns] = (
    fermi_data[PH_CNT_columns].values.astype(np.float64)
    * fermi_data[TRIG_columns].values.astype(np.float64)
)

# Full feature set: orientation, then trigger, then photon-count columns.
all_data_columns = [*Orientation_columns, *TRIG_columns, *PH_CNT_columns]

In [None]:
# ---- Helpers for training the GRB localization model ----

# Custom cosine similarity loss. NOTE: this definition shadows the
# `cosine_similarity_loss` imported from gw_grb_correlation.Fermi.util above.
def cosine_similarity_loss(y_true, y_pred):
    """
    Negative cosine similarity between target and predicted vectors.

    Parameters:
    y_true: Target vectors; the last axis holds the vector components.
    y_pred: Predicted vectors with the same trailing dimension as ``y_true``.

    Returns:
    Tensor with the last axis reduced: -cos(angle) per sample, in [-1, 1].
    -1 means perfectly aligned. A zero-length vector yields 0 rather than NaN.
    """
    import tensorflow as tf

    # Targets may arrive in a different float dtype than the predictions
    # (e.g. float64 NumPy targets vs float32 model outputs); align them so the
    # element-wise product below is well-typed.
    y_true = tf.cast(y_true, y_pred.dtype)

    # l2_normalize clamps the squared norm from below, so zero vectors do not
    # produce 0/0 in either the forward or the backward pass.
    unit_true = tf.math.l2_normalize(y_true, axis=-1)
    unit_pred = tf.math.l2_normalize(y_pred, axis=-1)
    return -tf.reduce_sum(unit_true * unit_pred, axis=-1)

def train_GRB_localization_model(
    fermi_data,
    input_columns,
    model_path="model.h5",
    dropout_rate=0.05,
    learning_rate=0.00002,
    epochs=400,
    batch_size=16,
    seed=None,
):
    """
    Train a GRB localization model using the provided Fermi data.

    The network outputs a 3-component (Cartesian) vector and is optimized with
    ``cosine_similarity_loss``, so only the direction of the output is
    constrained, not its length.

    Parameters:
    fermi_data (DataFrame): The Fermi data containing the necessary columns.
    input_columns (list): List of input columns to be used for training.
    model_path (str): File the trained model is saved to. Defaults to
        "model.h5". Pass a distinct path per experiment, otherwise each call
        overwrites the previously saved model.
    dropout_rate (float): Dropout rate applied after each hidden Dense layer.
    learning_rate (float): Adam learning rate.
    epochs (int): Number of training epochs.
    batch_size (int): Mini-batch size.
    seed (int or None): If given, seeds Python, NumPy and TensorFlow RNGs via
        ``tf.keras.utils.set_random_seed`` for reproducible training.

    Returns:
    tuple: ``(model, X_scaled, val_acc)`` where ``model`` is the trained Keras
        model, ``X_scaled`` is the first value returned by ``process_data``
        (presumably the full scaled feature matrix), and ``val_acc`` is the
        second value returned by ``evaluate_model_and_plot_accurracy``
        (presumably the per-epoch validation cosine similarity).
    """
    if seed is not None:
        import tensorflow as tf
        tf.keras.utils.set_random_seed(seed)

    # Scale the inputs and split them into training and testing sets
    # (unpacking order as returned by process_data).
    X_scaled, X_train_scaled, X_test_scaled, _y, y_train, y_test = process_data(fermi_data, input_columns)

    # Fully connected network with dropout after every hidden layer.
    model = Sequential([
        Dense(128, input_dim=X_train_scaled.shape[1], activation='relu'),
        Dropout(dropout_rate),
        Dense(512, activation='relu'),
        Dropout(dropout_rate),
        Dense(512, activation='relu'),
        Dropout(dropout_rate),
        # Linear 3-unit head: an unnormalized direction vector.
        Dense(3, activation=None),
    ])

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=cosine_similarity_loss,
        metrics=[CosineSimilarity(name='cosine_similarity')],
    )
    model.summary()

    history = model.fit(
        X_train_scaled, y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(X_test_scaled, y_test),
    )

    model.save(model_path)
    print(f"Model trained and saved as {model_path}")

    # Evaluate model and plot accuracy; only the validation curve is returned.
    _train_acc, val_acc = evaluate_model_and_plot_accurracy(model, history, X_test_scaled, y_test)
    return model, X_scaled, val_acc

In [None]:
# ---- Train the model on every input column group ----
model, X_scaled, val_acc_all_data = train_GRB_localization_model(fermi_data, all_data_columns)

# ---- Predict a direction vector for every row of X_scaled ----
# NOTE(review): X_scaled presumably contains both training and test rows, so
# these predictions include events the model was trained on.
predictions = model.predict(X_scaled)

# ---- Cartesian -> spherical (RA, DEC), one prediction at a time ----
RA_DEC_predictions = np.array(list(map(convert_cartesian_to_spherical, predictions)))

# Copy the catalog and write the predicted RA/DEC into it.
fermi_predict_data = fermi_data.copy()
fermi_predict_data['RA'], fermi_predict_data['DEC'] = RA_DEC_predictions[:, 0], RA_DEC_predictions[:, 1]

# ---- Persist the predictions ----
fermi_predict_data.to_csv("fermi_predict_data.csv", index=False)

In [None]:
# ---- Time-correlation analysis with model-predicted localizations ----
# Keep only short GRBs and pair them with GW events that are close in time,
# using the new location data predicted by the model.
from gw_grb_correlation.Fermi.util import (
    compare_time_within_range,
    filtering,
    read_GW_data,
    remove_duplicate_times_in_gw_data,
)

# Short GRBs: T90 <= 2.1.
short_GRB_data = filtering(fermi_predict_data, criteria={'T90': lambda x: x <= 2.1})

# GW catalog, plus its de-duplicated event times.
gw_data = read_GW_data("./gw_data/totalgwdata.csv")
gw_times = remove_duplicate_times_in_gw_data(gw_data)

# Pair GRB/GW events within a 3-day window (3 * 24 * 3600 = 259200 s).
match = compare_time_within_range(short_GRB_data, gw_times, time_range_seconds=3 * 24 * 3600)
filtered_gw_events = gw_data.loc[gw_data['times'].isin(match['gw_time'])]

# ---- Save matched pairs and the GW events that took part in a match ----
match.to_csv("GRB_GW_event_pairs_predict.csv", index=False)
filtered_gw_events.to_csv("Filtered_GW_events_predict.csv", index=False)


In [None]:
# ---- Ablation: train without spacecraft orientation columns ----
no_spacecraft_columns = TRIG_columns + PH_CNT_columns
# Results get run-specific names so the all-data `model` / `X_scaled` are not
# silently replaced by this ablation run.
# NOTE: the model file is written to the same path as the all-data run
# (model.h5), so the saved file is overwritten.
model_no_ori, X_scaled_no_ori, val_acc_no_ori = train_GRB_localization_model(fermi_data, no_spacecraft_columns)

In [None]:
# ---- Ablation: train without trigger columns ----
no_trigger_columns = PH_CNT_columns + Orientation_columns
# Results get run-specific names so the all-data `model` / `X_scaled` are not
# silently replaced by this ablation run.
# NOTE: the model file is written to the same path as the all-data run
# (model.h5), so the saved file is overwritten.
model_no_trig, X_scaled_no_trig, val_acc_no_trig = train_GRB_localization_model(fermi_data, no_trigger_columns)

In [None]:
# ---- Ablation: train without photon-count columns ----
no_pht_cnt_columns = TRIG_columns + Orientation_columns
# Results get run-specific names so the all-data `model` / `X_scaled` are not
# silently replaced by this ablation run.
# NOTE: the model file is written to the same path as the all-data run
# (model.h5), so the saved file is overwritten.
model_no_pht_cnt, X_scaled_no_pht_cnt, val_acc_no_pht_cnt = train_GRB_localization_model(fermi_data, no_pht_cnt_columns)

In [None]:
# ---- Compare validation cosine similarity across input-column combinations ----
fig, ax = plt.subplots(figsize=(10, 6))
for curve, curve_label in (
    (val_acc_all_data, 'Trained with all data'),
    (val_acc_no_ori, 'Trained without spacecraft orientation'),
    (val_acc_no_trig, 'Trained without trigger data'),
    (val_acc_no_pht_cnt, 'Trained without photon count data'),
):
    ax.plot(curve, label=curve_label)
ax.set_xlabel('Epochs')
ax.set_ylabel('Cosine Similarity')
ax.set_title('Validation Cosine Similarity')
ax.legend()
ax.grid(True)
plt.show()