# From images to coordinates

### Initial imports

In [1]:
import keras
from keras.callbacks import EarlyStopping

from src.experiment.data_loading import (
    load_and_process_coordinates,
    load_image_set,
    prepare_train_val_test_splits,
)
from src.experiment.models import build_img_to_coordinates_model
from src.experiment.regression_evaluation import (
    regression_evaluation,
)

### Loading images

In [2]:
# Dataset-dependent parameters: change these together when switching datasets.
DATA_FOLDER = "data/3ball"  # folder the images and coordinates are read from
NUM_BALLS = 3
PIXELS_PER_AXIS = 32
COLORED_BALLS = False

# Load the image set from disk.
# NOTE(review): X1/X2/Y are presumably two input image stacks and the target
# images -- confirm against `load_image_set`.
X1, X2, Y = load_image_set(
    data_folder=DATA_FOLDER,
    colored_balls=COLORED_BALLS,
)

100%|██████████| 10000/10000 [00:02<00:00, 3387.21it/s]
100%|██████████| 10000/10000 [00:02<00:00, 3636.34it/s]


### Coordinates generation and data split

In [3]:
# Read the ball coordinates that accompany the images in DATA_FOLDER.
coordinates = load_and_process_coordinates(
    data_folder=DATA_FOLDER,
    num_balls=NUM_BALLS,
    pixels_per_axis=PIXELS_PER_AXIS,
)

# Split images and coordinates into train / validation / test subsets.
# The helper returns nine values, unpacked here three per row:
# inputs, targets, coordinates.
(
    x_train, x_val, x_test,
    y_train, y_val, y_test,
    coordinates_train, coordinates_val, coordinates_test,
) = prepare_train_val_test_splits(X1, X2, Y, coordinates)

### Model definition and training

In [4]:
# Reset Keras' global state *before* building the model. Doing it first means
# that re-running this cell releases whatever earlier runs left behind, and the
# model used by the following cells is created in a fresh session rather than
# having the session cleared underneath it.
keras.backend.clear_session()

# Build the image -> coordinates regressor for the current dataset settings
# and print its layer summary.
model = build_img_to_coordinates_model(
    num_balls=NUM_BALLS, pixels_per_axis=PIXELS_PER_AXIS, colored_balls=COLORED_BALLS
)
model.summary()

Model: "ImgToCoordinates"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 2)]       0         
                                                                 
 conv2d (Conv2D)             (None, 30, 30, 1)         19        
                                                                 
 conv2d_1 (Conv2D)           (None, 28, 28, 1)         10        
                                                                 
 conv2d_2 (Conv2D)           (None, 1, 1, 6)           4710      
                                                                 
 flatten (Flatten)           (None, 6)                 0         
                                                                 
 dense (Dense)               (None, 36)                252       
                                                                 
 dense_1 (Dense)             (None, 12)           

In [5]:
# Train with mean absolute error on the coordinates, using the Adam optimizer.
model.compile(optimizer="adam", loss="mae")

# Fit on the training split and monitor the validation split. Training stops
# early once `val_loss` has not improved by at least `min_delta` for `patience`
# consecutive epochs, and the best weights seen are restored.
model.fit(
    x_train,
    coordinates_train,
    validation_data=(x_val, coordinates_val),
    batch_size=16,
    epochs=100,
    shuffle=True,
    callbacks=[
        EarlyStopping(
            monitor="val_loss",
            min_delta=0.0001,
            patience=10,
            restore_best_weights=True,
        )
    ],
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100


<keras.callbacks.History at 0x1e99290f1c0>

### Model results

In [6]:
# Evaluate the trained regressor on the held-out test split.
regression_evaluation(
    regressor_model=model,
    x_test=x_test,
    y_test=coordinates_test,
)

0
Prediction: [ 0.48747224  0.3961269   0.4812074   0.40900043  0.5058669   0.42896134
 -0.00442183  0.5235008   0.00289989  0.49434146 -0.0021149   0.49652472]
True values: [ 0.32258065  0.67741935  0.70967742  0.19354839  0.51612903  0.48387097
  0.          0.5        -0.5         0.5         0.5         0.25      ]
1
Prediction: [ 0.4211915   0.5016501   0.389774    0.52944976  0.4289481   0.55260056
  0.00145395 -0.16636434  0.00613198 -0.2035939   0.0050637  -0.2148546 ]
True values: [ 0.61290323  0.80645161  0.25806452  0.41935484  0.35483871  0.4516129
  0.         -0.25        0.25       -0.25       -0.25        0.25      ]
2
Prediction: [ 0.46910405  0.45857638  0.49549076  0.4349512   0.45547453  0.45169327
 -0.00232383  0.17532575  0.00224506  0.21334204 -0.00318912  0.15318155]
True values: [ 0.4516129   0.25806452  0.38709677  0.32258065  0.67741935  0.70967742
 -0.75       -0.25       -0.25       -0.5         0.5         0.5       ]
3
Prediction: [0.40374753 0.6045468  0