In [53]:
import pickle

import numpy as np
import plotly.express as px
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers

# Loading the data

In [54]:
# Deserialize the training split. Only unpickle files from a trusted source:
# pickle can execute arbitrary code while loading.
with open("data/randomized/train_data", mode="rb") as train_file:
    train_data = pickle.load(train_file)

# Element 0 is bound to the features, element 1 to the targets.
train_features, train_targets = train_data[0], train_data[1]

In [55]:
# Deserialize the validation split. Only unpickle files from a trusted source:
# pickle can execute arbitrary code while loading.
with open("data/randomized/val_data", mode="rb") as val_file:
    val_data = pickle.load(val_file)

# Element 0 is bound to the features, element 1 to the targets.
val_features, val_targets = val_data[0], val_data[1]

In [56]:
# Deserialize the held-out test split. Only unpickle files from a trusted source:
# pickle can execute arbitrary code while loading.
with open("data/randomized/test_data", mode="rb") as test_file:
    test_data = pickle.load(test_file)

# Element 0 is bound to the features, element 1 to the targets.
test_features, test_targets = test_data[0], test_data[1]

# Training a baseline model

In [57]:
def reshape_to_train(unshaped: list) -> np.ndarray:
    """Flatten every sample and stack the results into one 2-D array.

    Args:
        unshaped: Sequence of samples. Each sample may be a scalar, a nested
            list or an array; it is flattened in row-major order.

    Returns:
        Array of shape ``(n_samples, n_values_per_sample)``. An empty input
        yields an empty array of shape ``(0, 0)`` instead of raising.

    Raises:
        ValueError: If the samples do not all flatten to the same length.
    """
    flat_samples = [np.asarray(sample).reshape(-1) for sample in unshaped]

    # With no samples there is no second dimension to infer, so return an
    # explicit empty 2-D array rather than failing on a missing axis.
    if not flat_samples:
        return np.empty((0, 0))

    return np.stack(flat_samples)

In [58]:
# Baseline multilayer perceptron: each (21, 2) int32 sample is flattened and fed
# through four ReLU hidden layers into a 10-unit softmax output.
mlp_model = Sequential([
    layers.Input(shape=(21, 2), dtype="int32"),
    layers.Flatten(),
    layers.Dense(64, activation="relu"),
    layers.Dense(256, activation="relu"),
    layers.Dense(128, activation="relu"),
    layers.Dense(64, activation="relu"),
    layers.Dense(10, activation="softmax"),
])

# The sparse categorical loss takes integer class ids rather than one-hot rows
# (assumes the pickled targets are integer labels in 0..9 -- TODO confirm).
mlp_model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# The validation split is evaluated after every epoch so the history also holds
# val_loss / val_accuracy; it is never used for weight updates.
baseline_history = mlp_model.fit(
    x=train_features,
    y=train_targets,
    validation_data=(val_features, val_targets),
    epochs=50,
    verbose=0,
)

# Testing the model

In [59]:
# One row of model outputs per test sample.
class_probabilities = mlp_model.predict(test_features)

# Reduce each row to the index of its largest entry, i.e. the predicted class.
predictions = list(map(np.argmax, class_probabilities))

# Metrics

In [62]:
# Overall accuracy on the test split, rounded to four decimals for reporting.
acc = accuracy_score(y_true=test_targets, y_pred=predictions)
accuracy = round(acc, 4)

print(f"The accuracy is {round(accuracy*100, 4)}%")

# Per-class precision / recall / F1 and support.
print(classification_report(y_true=test_targets, y_pred=predictions, digits=4))

# confusion_matrix puts true classes on rows and predicted classes on columns,
# so the heatmap's y axis is the true class and its x axis the predicted class.
conf_matrix = confusion_matrix(y_true=test_targets, y_pred=predictions)
px.imshow(
    conf_matrix,
    color_continuous_scale="blues",
    labels={"x": "Predicted class", "y": "True class", "color": "Samples"},
    title="Baseline MLP confusion matrix (test set)",
)

The accuracy is 75.93%
              precision    recall  f1-score   support

           0     0.9845    0.9922    0.9883       128
           1     0.8873    0.4632    0.6087       136
           2     0.9856    1.0000    0.9928       137
           3     0.3644    0.8036    0.5014       112
           4     0.7009    0.5467    0.6142       150
           5     0.4896    0.4196    0.4519       112
           6     1.0000    1.0000    1.0000       166
           7     0.6984    0.3308    0.4490       133
           8     0.9862    1.0000    0.9931       143
           9     0.7119    0.9474    0.8129       133

    accuracy                         0.7593      1350
   macro avg     0.7809    0.7503    0.7412      1350
weighted avg     0.7978    0.7593    0.7556      1350

