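# From images to coordinates

Train a convolutional model that regresses ball coordinates directly from the input images, then evaluate it on a held-out test split.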

### Initial imports

In [1]:
import keras
from keras.callbacks import EarlyStopping

from src.experiment.data_loading import (
    load_and_process_coordinates,
    load_image_set,
    prepare_train_val_test_splits,
)
from src.experiment.models import build_img_to_coordinates_model
from src.experiment.regression_evaluation import (
    regression_evaluation,
)

### Loading images

In [2]:
# Parameters that depend on which dataset is being loaded
COLORED_BALLS = True
NUM_BALLS = 3
PIXELS_PER_AXIS = 32
DATA_FOLDER = "data/3ball_color"

# Load the image data from DATA_FOLDER; the helper returns three arrays.
X1, X2, Y = load_image_set(
    data_folder=DATA_FOLDER,
    colored_balls=COLORED_BALLS,
)

100%|██████████| 10000/10000 [00:04<00:00, 2362.65it/s]
100%|██████████| 10000/10000 [00:03<00:00, 2619.58it/s]


### Coordinate loading and train/validation/test split

In [3]:
coordinates = load_and_process_coordinates(
    data_folder=DATA_FOLDER,
    num_balls=NUM_BALLS,
    pixels_per_axis=PIXELS_PER_AXIS,
)

# The split helper returns nine arrays, unpacked here row by row:
# model inputs (x_*), Y targets (y_*) and coordinate targets (coordinates_*),
# each in train / validation / test order.
(
    x_train, x_val, x_test,
    y_train, y_val, y_test,
    coordinates_train, coordinates_val, coordinates_test,
) = prepare_train_val_test_splits(X1, X2, Y, coordinates)

### Model definition and training

In [4]:
# Reset Keras' global state *before* building the model. Re-running this cell
# then releases models and layers left over from earlier runs. Calling it after
# the build (as before) reset the state the new model had just been created in,
# instead of clearing out the old ones.
keras.backend.clear_session()

model = build_img_to_coordinates_model(
    num_balls=NUM_BALLS, pixels_per_axis=PIXELS_PER_AXIS, colored_balls=COLORED_BALLS
)
model.summary()

Model: "ImgToCoordinates"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 6)]       0         
                                                                 
 conv2d (Conv2D)             (None, 30, 30, 1)         55        
                                                                 
 conv2d_1 (Conv2D)           (None, 28, 28, 1)         10        
                                                                 
 conv2d_2 (Conv2D)           (None, 1, 1, 6)           4710      
                                                                 
 flatten (Flatten)           (None, 6)                 0         
                                                                 
 dense (Dense)               (None, 36)                252       
                                                                 
 dense_1 (Dense)             (None, 12)           

In [5]:
# Training hyperparameters, named so they can be tuned in one place.
LOSS = "mae"
EPOCHS = 100
BATCH_SIZE = 16
EARLY_STOPPING_PATIENCE = 10
EARLY_STOPPING_MIN_DELTA = 0.0001

model.compile(loss=LOSS, optimizer="adam")

# Stop once validation loss has not improved by at least `min_delta` for
# `patience` consecutive epochs. When stopping early, roll back to the best
# weights seen during training.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=EARLY_STOPPING_PATIENCE,
    restore_best_weights=True,
    min_delta=EARLY_STOPPING_MIN_DELTA,
)

# Keep the returned History so the per-epoch train/validation losses can be
# inspected or plotted later, rather than echoing its bare repr as cell output.
history = model.fit(
    x_train,
    coordinates_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_val, coordinates_val),
    callbacks=[early_stopping],
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100


<keras.callbacks.History at 0x25b000951c0>

### Model results

In [6]:
# Compare the trained model's predictions with the true coordinates on the
# held-out test split. The call is the cell's last expression, so its return
# value (if any) is displayed.
regression_evaluation(
    y_test=coordinates_test,
    x_test=x_test,
    regressor_model=model,
)

0
Prediction: [ 4.7458196e-01  3.4244683e-01  4.5479989e-01  3.3794156e-01
  4.5714450e-01  3.2935190e-01 -3.9898101e-03 -1.6008335e-01
  4.5656506e-04 -1.5499625e-01 -1.6034730e-03 -1.4070415e-01]
True values: [ 0.4516129   0.61290323  0.22580645  0.22580645  0.80645161  0.29032258
  0.         -0.5        -0.5        -0.25        0.25        0.5       ]
1
Prediction: [4.2474294e-01 5.8809561e-01 4.1419682e-01 5.8947599e-01 3.9806744e-01
 5.8337849e-01 3.4215380e-03 1.1874834e-01 3.8462207e-03 1.3342339e-01
 5.1869079e-04 1.4195505e-01]
True values: [ 0.4516129   0.74193548  0.32258065  0.58064516  0.64516129  0.22580645
  0.          0.25        0.          0.          0.         -0.25      ]
2
Prediction: [ 4.8165262e-01  4.8015001e-01  4.6362936e-01  4.7157559e-01
  4.5890650e-01  4.6778053e-01 -2.3308038e-03 -2.1976262e-01
  1.0068873e-03 -2.0865187e-01 -3.5719760e-04 -2.1667551e-01]
True values: [ 0.74193548  0.80645161  0.4516129   0.41935484  0.70967742  0.5483871
  0.         