# Keras Train BNN Classifier for FANET Reliability
Date: 17/03/2023
Desc: To train a BNN classifier to predict FANET reliability 

## Import libs

In [4]:
import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt
import sklearn

# Import necessary modules
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from math import sqrt

# Keras specific
import tensorflow as tf
import tensorflow_probability as tfp
import keras
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical 
import pickle

## Load and pre-process data

In [5]:
# TODO: Compile all data
delay_threshold = 0.04

dl_df = pd.read_hdf("/media/research-student/One Touch/FANET Datasets/Dataset_NP10000_BPSK_6-5Mbps/Dataset_NP10000_BPSK_6-5Mbps_8UAVs_processed_downlink.h5", '8_UAVs')

data_df = dl_df[["U2G_H_Dist", "Height", "Num_Members", "Bytes", "Sending_Interval", "Incorrectly_Received", "Queue_Overflow"]].copy()
data_df["Reliable"] = np.where(dl_df['Packet_State'] == "Reliable" , 1, 0)
data_df["Delay_Exceeded"] = np.where(dl_df['Delay'] > delay_threshold, 1, 0)

# Normalize data
max_h_dist = 550
max_height = 120
max_num_members = 39
max_bytes = 1145 # Should be 1144, but put 1145 just in case
data_df["U2G_H_Dist"] = data_df["U2G_H_Dist"].div(max_h_dist)
data_df["Height"] = data_df["Height"].div(max_height)
data_df["Num_Members"] = data_df["Num_Members"].div(max_num_members)
data_df["Bytes"] = data_df["Bytes"].div(max_bytes)

# Split to train and test
data_df_train, data_df_test = train_test_split(data_df, test_size=0.10, random_state=40, shuffle=False)
X_train = data_df_train[["U2G_H_Dist", "Height", "Num_Members", "Bytes", "Sending_Interval"]].values
X_test = data_df_test[["U2G_H_Dist", "Height", "Num_Members", "Bytes", "Sending_Interval"]].values
reliability_train = data_df_train["Reliable"].values
reliability_test = data_df_test["Reliable"].values
# incr_rcvd_train = data_df_train["Incorrectly_Received"].values
# incr_rcvd_test = data_df_test["Incorrectly_Received"].values
# delay_excd_train = data_df_train["Delay_Exceeded"].values
# delay_excd_test = data_df_test["Delay_Exceeded"].values
# queue_overflow_train = data_df_train["Queue_Overflow"].values
# queue_overflow_test = data_df_test["Queue_Overflow"].values

reliability_train = to_categorical(reliability_train) 
reliability_test = to_categorical(reliability_test)
# incr_rcvd_train = to_categorical(incr_rcvd_train) 
# incr_rcvd_test = to_categorical(incr_rcvd_test)
# delay_excd_train = to_categorical(delay_excd_train) 
# delay_excd_test = to_categorical(delay_excd_test)
# queue_overflow_train = to_categorical(queue_overflow_train) 
# queue_overflow_test = to_categorical(queue_overflow_test)

## Train the model

In [13]:
inputs = Input(shape=(5,))
base = tfp.layers.DenseFlipout(50, activation='relu')(inputs)
base = tfp.layers.DenseFlipout(25, activation='relu')(base)
base = tfp.layers.DenseFlipout(10, activation='relu')(base)
reliability_out = Dense(2, activation='softmax', name='reliability')(base)
# incr_rcvd_out = Dense(8, activation='softmax', name='incorrectly_received')(base)
# delay_excd_out = Dense(2, activation='softmax', name='delay_exceeded')(base)
# queue_overflow_out = Dense(2, activation='softmax', name='queue_overflow')(base)
# model = Model(inputs=inputs, outputs = [reliability_out, incr_rcvd_out, delay_excd_out, queue_overflow_out])
model = Model(inputs=inputs, outputs=reliability_out)
              
# Compile the model
model.compile(optimizer='adam', 
              loss={'reliability': 'binary_crossentropy'},
              metrics={'reliability': 'accuracy'})
# model.compile(optimizer='adam', 
#               loss={'reliability': 'binary_crossentropy',
#                     'incorrectly_received': 'categorical_crossentropy',
#                     'delay_exceeded': 'binary_crossentropy',
#                     'queue_overflow': 'binary_crossentropy'},
#               metrics={'reliability': 'accuracy',
#                     'incorrectly_received': 'accuracy',
#                     'delay_exceeded': 'accuracy',
#                     'queue_overflow': 'accuracy'},)

EPOCHS = 1
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='auto',
    save_best_only=True)

  loc = add_variable_fn(
  untransformed_scale = add_variable_fn(


In [14]:
# Y_train = [reliability_train, incr_rcvd_train, delay_excd_train, queue_overflow_train]
# Y_test = [reliability_test, incr_rcvd_test, delay_excd_test, queue_overflow_test]
Y_train = reliability_train
Y_test = reliability_test
history = model.fit(X_train, Y_train, epochs=EPOCHS, callbacks=[model_checkpoint_callback], validation_data=(X_test, Y_test))
# with open('/trainHistoryDict', 'wb') as file_pi:
#     pickle.dump(history.history, file_pi)

 22146/175873 [==>...........................] - ETA: 13:09 - loss: 299.5222 - accuracy: 0.5658

KeyboardInterrupt: 