In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [None]:
import gc

# Force a garbage-collection pass; the cell displays the number of unreachable
# objects found.
# NOTE(review): on a fresh kernel only library imports have run before this
# cell, so there is little to reclaim here — this is more useful right after
# large arrays have been released with `del`.
gc.collect()

In [3]:
# All paths are relative to the kernel's working directory and end with a
# trailing slash, because later cells build file paths with plain `+`.
main_folder = '../'
dataset_folder, logs_folder, checkpoints_folder = (
    f'{main_folder}{sub}/' for sub in ('dataset', 'logs', 'checkpoints')
)

In [4]:
# Load the pre-built .npy arrays for both splits: canvas images, the target
# coordinates (used as regression targets by model.train / model.evaluate),
# and a `y_*` array that is not used in the visible cells (presumably class
# labels — confirm against the dataset builder).
X_train_canvas, coords, y_train = (
    np.load(f'{dataset_folder}{stem}.npy')
    for stem in ('X_train_canvas', 'coords', 'y_train')
)
X_test_canvas, coords_test, y_test = (
    np.load(f'{dataset_folder}{stem}.npy')
    for stem in ('X_test_canvas', 'coords_test', 'y_test')
)

In [5]:
# Give every sample a (128, 128) image layout; each stored canvas must contain
# exactly 128 * 128 values or numpy raises a ValueError.
X_train_2d, X_test_2d = (
    canvas.reshape(canvas.shape[0], 128, 128)
    for canvas in (X_train_canvas, X_test_canvas)
)

In [6]:
# Project-local module providing the model wrapper (not part of this notebook).
import mymodels

# Build the `sect1` model and compile it. `compile()` is called with no
# arguments, so any optimizer/loss/metric configuration presumably lives inside
# `mymodels` — confirm there before changing training behaviour.
model = mymodels.sect1()
model.compile()

In [None]:
# Training configuration consumed by the project-defined `model.train`.
# NOTE(review): `epochs` must stay in sync with the checkpoint filename
# ('model_epoch_30.weights.h5') loaded in the evaluation cell.
params = {'epochs': 30,
          'batch_size': 1024,
          'tensorboard': True,
          'cp_callback': True}

# NOTE(review): the test split is passed to `train` and the history contains
# `val_*` metrics, so the test set appears to double as the validation set.
# Test-set scores reported later are then not independent of training —
# consider carving a separate validation split out of the training data.
run = model.train(X_train_2d, coords, X_test_2d, coords_test, params, logs_folder, checkpoints_folder)
history_model = run.history
print("The history has the following data: ", history_model.keys())

# Plot the per-epoch mean absolute error on the training and validation data.
# The tracked metric is MAE (coordinate regression), not accuracy.
fig, ax = plt.subplots()
sns.lineplot(
    x=run.epoch, y=history_model["mean_absolute_error"], color="blue", label="Training set", ax=ax
)
sns.lineplot(
    x=run.epoch,
    y=history_model["val_mean_absolute_error"],
    color="red",
    label="Validation set",
    ax=ax,
)
ax.set(title="Mean absolute error per epoch", xlabel="epochs", ylabel="mean absolute error")
plt.show()

In [None]:
# Weights file expected after the final training epoch (presumably written by
# the checkpoint callback inside mymodels — confirm its filename pattern).
# The epoch number is hard-coded; keep it in sync with params['epochs'].
cp_dir = f'{checkpoints_folder}model_epoch_30.weights.h5'

model.evaluate(X_test_2d, coords_test, cp_dir)

In [None]:
# Visual sanity check: overlay predicted and true coordinates on one test image.
# A fixed seed keeps the chosen sample (and so this figure) reproducible under
# "Restart & Run All"; change SAMPLE_SEED to inspect a different example.
SAMPLE_SEED = 0
sample_rng = np.random.default_rng(SAMPLE_SEED)
random_sample = int(sample_rng.integers(0, X_test_2d.shape[0]))

image = X_test_2d[random_sample]
actual = coords_test[random_sample]
# Add batch and channel axes: (H, W) -> (1, H, W, 1). This matches the previous
# reshape(1, 128, 128, 1) without repeating the image size.
prediction = model.predict(image[np.newaxis, :, :, np.newaxis])

fig, ax = plt.subplots()
ax.imshow(image, cmap='gray')
# NOTE(review): points are drawn as (x, y) = (coord[0], coord[1]); if the dataset
# stores (row, col) the markers are transposed relative to the image — confirm.
ax.scatter(prediction[0][0], prediction[0][1], color='red', label='Predicted')
ax.scatter(actual[0], actual[1], color='green', label='Actual')
# Label-based legend cannot drift out of sync with the plotting order.
ax.legend()
ax.set_title(f"Predicted: {prediction[0].round()}, Actual: {actual}")
plt.show()