# Train Neural Network for Classification of Basic Blocks

In [None]:
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam

from sklearn.model_selection import train_test_split
from includes.helpers import get_bb_train_test_set

In [None]:
# Which GPUs can TensorFlow see? An empty list means no GPU is visible to TensorFlow.
visible_gpus = tf.config.list_physical_devices('GPU')
visible_gpus  # bare expression so the notebook still displays the device list

In [None]:
# Load the basic-block dataset, already split into train/test features and labels.
# The split strategy and ratio live in includes.helpers.get_bb_train_test_set;
# NOTE(review): unpacking assumes it returns (X_train, X_test, y_train, y_test),
# i.e. sklearn train_test_split order — confirm against the helper.
(
    X_train,
    X_test,
    y_train,
    y_test,
) = get_bb_train_test_set()

In [None]:
# Build the binary classification model: a fully connected network whose hidden
# layers narrow 128 -> 64 -> 32 units, ending in a single sigmoid unit.
model = Sequential([
    Input(shape=(X_train.shape[1],)),  # Input layer: one feature vector per sample (width taken from X_train)
    Dense(128, activation='relu'),  # Hidden layer 1 (128 neurons)
    Dense(64, activation='relu'),  # Hidden layer 2 (64 neurons)
    Dense(32, activation='relu'),  # Hidden layer 3 (32 neurons)
    Dense(1, activation='sigmoid')  # Output layer: single neuron, probability of class 1 for binary classification
])


# Compile the model: Adam with its default settings, binary cross-entropy to
# match the single sigmoid output, and accuracy reported during fit/evaluate.
model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])

In [None]:
# Train the model.
# Hyperparameters are module-level constants because the save cell embeds them
# in the output filename.
MODEL_EPOCHS = 21
MODEL_BATCH = 32
# NOTE(review): the test split doubles as validation data here, so the per-epoch
# val_* metrics in `history` are test-set scores. With no callbacks they do not
# affect the weights, but choosing MODEL_EPOCHS by watching them leaks test-set
# information into model selection; consider a separate validation split.
history = model.fit(
    X_train,
    y_train,
    epochs=MODEL_EPOCHS,
    batch_size=MODEL_BATCH,
    validation_data=(X_test, y_test),
)

In [None]:
# Score the trained model on the held-out test split. Because the model was
# compiled with metrics=['accuracy'], evaluate() yields the loss followed by the accuracy.
loss, accuracy = model.evaluate(X_test, y_test)
test_accuracy_pct = 100 * accuracy
print(f'Test Accuracy: {test_accuracy_pct:.2f}%')

In [None]:
# Save the entire model (architecture, weights, optimizer state and training
# configuration). The ".keras" extension selects the native Keras archive format,
# not HDF5; use a ".h5" path for the legacy HDF5 format.
# The filename records epochs, batch size and the rounded test accuracy in percent.
model_path = f'testBBpredict_TF_{MODEL_EPOCHS}_{MODEL_BATCH}_{round(accuracy * 100)}.keras'
# Alternative filename with a "MAC" suffix, kept from the original notebook
# (presumably for models trained on a Mac — confirm):
# model_path = f'testBBpredict_TF_{MODEL_EPOCHS}_{MODEL_BATCH}_{round(accuracy * 100)}MAC.keras'
model.save(model_path)

In [None]:
# Run inference on the test features: the single sigmoid output unit yields one
# probability per test sample (shape (n_samples, 1)).
y_pred = model.predict(X_test)

# Turn probabilities into hard labels: strictly above 0.5 -> 1, otherwise -> 0.
positive_mask = y_pred > 0.5
y_pred_binary = positive_mask.astype(int)

### Model History (Accuracy per Epoch)

In [None]:
import matplotlib.pyplot as plt

# Plot training vs. validation accuracy for every epoch recorded by model.fit.
# Keras numbers epochs from 1 ("Epoch 1/21"), so the x axis does too, instead
# of defaulting to the 0-based list index.
# NOTE: the "val" curve is computed on the test split (see the training cell).
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epoch_numbers = range(1, len(train_acc) + 1)

plt.plot(epoch_numbers, train_acc, label='train accuracy')
plt.plot(epoch_numbers, val_acc, label='val accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
