In [1]:
import os
import numpy as np
import tensorflow as tf
from keras import models, layers

from loop import TrainingLoop
from batch_selection import windowed_batch_selector, sorting_batch_selector

In [2]:
# Load the train/test split saved by the preprocessing notebook.
path = 'data/iris'
prefix = 'iris_'


def _load_array(name):
    """Return the array stored in ``<path>/<prefix><name>.npy``."""
    return np.load(os.path.join(path, prefix + name + '.npy'))


X_train = _load_array('train_vectors')
y_train = _load_array('train_labels')
X_test = _load_array('test_vectors')
y_test = _load_array('test_labels')

# Fail fast on a stale or mismatched preprocessing output instead of deep inside
# the training loop: every feature row needs a label, and train/test must agree
# on feature width (the model's input layer) and label width (its softmax output).
assert len(X_train) == len(y_train), f'train split: {len(X_train)} vectors vs {len(y_train)} labels'
assert len(X_test) == len(y_test), f'test split: {len(X_test)} vectors vs {len(y_test)} labels'
assert X_train.shape[1:] == X_test.shape[1:], f'feature shape mismatch: {X_train.shape[1:]} vs {X_test.shape[1:]}'
assert y_train.shape[1:] == y_test.shape[1:], f'label shape mismatch: {y_train.shape[1:]} vs {y_test.shape[1:]}'

In [3]:
def build_model(input_dim=None, num_classes=None, hidden_units=(50, 40, 20)):
    """Build an uncompiled fully connected softmax classifier.

    With no arguments this reproduces the original architecture:
    Dense(50, relu) -> Dense(40, relu) -> Dense(20, relu) -> Dense(num_classes, softmax).

    Args:
        input_dim: Number of input features. Defaults to ``X_train.shape[1]``,
            the width of the notebook-level training matrix the model is fitted on.
        num_classes: Number of softmax output units. Defaults to
            ``y_train.shape[1]``, the width of the notebook-level label array.
        hidden_units: Widths of the ReLU hidden layers, in order. May be empty,
            in which case the model is a single softmax layer.

    Returns:
        A ``Sequential`` model. It is not compiled here: ``train`` hands the
        optimizer and loss to ``TrainingLoop``, and the experiment cells
        compile it afterwards only for evaluation.
    """
    if input_dim is None:
        input_dim = X_train.shape[1]
    if num_classes is None:
        num_classes = y_train.shape[1]

    model = models.Sequential()
    # Only the first layer added declares the input shape; reset after use.
    first_layer_kwargs = {'input_shape': (input_dim,)}
    for units in hidden_units:
        model.add(layers.Dense(units, activation='relu', **first_layer_kwargs))
        first_layer_kwargs = {}
    # With no hidden layers, the output layer is the first layer and takes the input shape.
    model.add(layers.Dense(num_classes, activation='softmax', **first_layer_kwargs))
    return model

In [4]:
def train(model, X_train, y_train, batch_selection, epochs,
          batch_size=4, validation_split=0.1, log_name='iris.csv'):
    """Train ``model`` with the custom ``TrainingLoop``.

    The model is modified in place; the experiment cells evaluate it afterwards.
    A fresh Adam optimizer, categorical cross-entropy loss and categorical-accuracy
    metrics are created on every call, so separate runs never share optimizer
    or metric state.

    Args:
        model: Keras model to train; it does not need to be compiled.
        X_train: Training feature matrix.
        y_train: Training labels, presumably one-hot, since the loss is
            ``CategoricalCrossentropy``.
        batch_selection: ``windowed_batch_selector``, ``sorting_batch_selector``,
            ``None`` for the loop's default batch selection, or any other
            selector ``TrainingLoop`` accepts.
        epochs: Number of epochs passed to ``TrainingLoop.train``.
        batch_size: Mini-batch size (default 4, as in the original runs).
        validation_split: Passed to ``TrainingLoop``. Presumably the fraction of
            ``X_train`` held out for validation.
        log_name: File name of the log; ``TrainingLoop`` receives
            ``logs/<strategy>/<log_name>`` as ``log_file``.
    """
    known_log_subdirs = {
        windowed_batch_selector: 'windowed',
        sorting_batch_selector: 'sorting',
        None: 'original',
    }
    if batch_selection in known_log_subdirs:
        log_subdir = known_log_subdirs[batch_selection]
    else:
        # Custom selectors are logged under their own name instead of raising KeyError.
        log_subdir = getattr(batch_selection, '__name__', type(batch_selection).__name__)

    log_file = os.path.join('logs', log_subdir, log_name)
    # Create logs/<strategy>/ if missing so the notebook also runs on a fresh checkout.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    # Put the model in our custom training loop and train it.
    TrainingLoop(
        model,
        X_train,
        y_train,
        validation_split=validation_split,
        batch_size=batch_size,
        optimizer=tf.keras.optimizers.Adam(),
        loss_function=tf.keras.losses.CategoricalCrossentropy(),
        train_metrics=tf.keras.metrics.CategoricalAccuracy(),
        val_metrics=tf.keras.metrics.CategoricalAccuracy(),
        batch_selection=batch_selection,
        log_file=log_file,
    ).train(epochs)

In [5]:
# Re-seed NumPy and TensorFlow before building the model so that differences
# between the runs in this notebook come from the batch-selection strategy,
# not from random initialisation.
np.random.seed(42)
tf.random.set_seed(42)

# Baseline: no custom batch selector (logged under logs/original/).
model = build_model()
train(model, X_train, y_train, epochs=20, batch_selection=None)

# The model still needs compiling before evaluate(); only the loss and the
# accuracy metric matter here, the optimizer is never stepped.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.evaluate(x=X_test, y=y_test)

Epoch 1/20	Loss: 0.7039	Metrics: 0.4351: 	Validation metrics: 0.4166: 	100% | 27/27 [00:04<00:00,  6.47it/s]
Epoch 2/20	Loss: 0.3617	Metrics: 0.6666: 	Validation metrics: 0.5833: 	100% | 27/27 [00:00<00:00, 39.76it/s]
Epoch 3/20	Loss: 0.2019	Metrics: 0.7685: 	Validation metrics: 0.75: 	100% | 27/27 [00:01<00:00, 26.47it/s]
Epoch 4/20	Loss: 0.1377	Metrics: 0.8796: 	Validation metrics: 0.8333: 	100% | 27/27 [00:00<00:00, 50.28it/s]
Epoch 5/20	Loss: 0.0952	Metrics: 0.8611: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 42.45it/s]
Epoch 6/20	Loss: 0.0581	Metrics: 0.8703: 	Validation metrics: 1.0: 	100% | 27/27 [00:00<00:00, 57.20it/s]
Epoch 7/20	Loss: 0.0372	Metrics: 0.8888: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 52.84it/s]
Epoch 8/20	Loss: 0.0240	Metrics: 0.9259: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 61.64it/s]
Epoch 9/20	Loss: 0.0159	Metrics: 0.9444: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 54.77it/s]
Epoch 10/20	Loss: 0.0105

[0.11605016142129898, 0.9666666388511658]

In [6]:
# Re-seed NumPy and TensorFlow with the same value as the baseline run so the
# comparison isolates the effect of the batch-selection strategy.
np.random.seed(42)
tf.random.set_seed(42)

# Windowed batch selection (logged under logs/windowed/).
model = build_model()
train(model, X_train, y_train, epochs=20, batch_selection=windowed_batch_selector)

# Compile only so evaluate() can report loss and accuracy on the test set.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.evaluate(x=X_test, y=y_test)

Epoch 1/20	Loss: 0.8344	Metrics: 0.5370: 	Validation metrics: 0.75: 	100% | 27/27 [00:04<00:00,  5.69it/s]
Epoch 2/20	Loss: 0.5805	Metrics: 0.8888: 	Validation metrics: 0.9166: 	100% | 27/27 [00:01<00:00, 20.45it/s]
Epoch 3/20	Loss: 0.3109	Metrics: 0.8518: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 28.51it/s]
Epoch 4/20	Loss: 0.1584	Metrics: 0.8611: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 35.25it/s]
Epoch 5/20	Loss: 0.0972	Metrics: 0.8425: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 29.00it/s]
Epoch 6/20	Loss: 0.0658	Metrics: 0.8240: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 33.09it/s]
Epoch 7/20	Loss: 0.0433	Metrics: 0.8518: 	Validation metrics: 0.9166: 	100% | 27/27 [00:01<00:00, 17.90it/s]
Epoch 8/20	Loss: 0.0342	Metrics: 0.8333: 	Validation metrics: 0.8333: 	100% | 27/27 [00:01<00:00, 20.30it/s]
Epoch 9/20	Loss: 0.0228	Metrics: 0.8518: 	Validation metrics: 1.0: 	100% | 27/27 [00:00<00:00, 31.25it/s]
Epoch 10/20	Loss: 0.0177

[0.1510007530450821, 0.9666666388511658]

In [7]:
# Re-seed NumPy and TensorFlow with the same value as the baseline run so the
# comparison isolates the effect of the batch-selection strategy.
np.random.seed(42)
tf.random.set_seed(42)

# Sorting batch selection (logged under logs/sorting/).
model = build_model()
train(model, X_train, y_train, epochs=20, batch_selection=sorting_batch_selector)

# Compile only so evaluate() can report loss and accuracy on the test set.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.evaluate(x=X_test, y=y_test)

Epoch 1/20	Loss: 0.8131	Metrics: 0.4907: 	Validation metrics: 0.75: 	100% | 27/27 [00:04<00:00,  5.56it/s]
Epoch 2/20	Loss: 0.2476	Metrics: 0.8333: 	Validation metrics: 0.75: 	100% | 27/27 [00:00<00:00, 32.57it/s]
Epoch 3/20	Loss: 0.0591	Metrics: 0.8518: 	Validation metrics: 0.8333: 	100% | 27/27 [00:01<00:00, 25.99it/s]
Epoch 4/20	Loss: 0.0206	Metrics: 0.8703: 	Validation metrics: 0.9166: 	100% | 27/27 [00:02<00:00, 13.22it/s]
Epoch 5/20	Loss: 0.0098	Metrics: 0.8518: 	Validation metrics: 0.9166: 	100% | 27/27 [00:01<00:00, 25.62it/s]
Epoch 6/20	Loss: 0.0054	Metrics: 0.8611: 	Validation metrics: 1.0: 	100% | 27/27 [00:00<00:00, 28.42it/s]
Epoch 7/20	Loss: 0.0033	Metrics: 0.9166: 	Validation metrics: 1.0: 	100% | 27/27 [00:00<00:00, 38.41it/s]
Epoch 8/20	Loss: 0.0021	Metrics: 0.9166: 	Validation metrics: 0.9166: 	100% | 27/27 [00:00<00:00, 28.75it/s]
Epoch 9/20	Loss: 0.0015	Metrics: 0.9351: 	Validation metrics: 1.0: 	100% | 27/27 [00:01<00:00, 26.95it/s]
Epoch 10/20	Loss: 0.0011	Metrics

[0.11316245049238205, 0.9666666388511658]