In [1]:
import os
import numpy as np

# Function to load dataset from file. This is needed so we can easily load the two datasets without copy pasteing.
def load_data( name ):
    X_train = np.load(os.path.join('data', name, name + '_train_vectors.npy'))
    X_test = np.load(os.path.join('data', name, name + '_test_vectors.npy'))
    Y_train = np.load(os.path.join('data', name, name + '_train_labels.npy'))
    Y_test = np.load(os.path.join('data', name, name + '_test_labels.npy'))

    # The images need to have shape (28, 28, 1), we didn't take care of this in preprocessing.
    X_train = np.expand_dims(X_train, -1)
    X_test = np.expand_dims(X_test, -1)

    return X_train, Y_train, X_test, Y_test

# The same model is used for both datasets so it is more convenient to make them in a funtion.
def make_model(X_train, Y_train, batch_selection, log_file):

    # This is a simple convolutional neural network. It isn't the best possible network for MNIST
    # but the point here is to test how much batch selection methods will speed up a CNN, not the CNN itself.
    model = Sequential()
    model.add(layers.Input(shape = (28, 28, 1,)))
    model.add(layers.Conv2D(64, kernel_size = (3, 3), activation = "relu"))
    model.add(layers.MaxPooling2D( pool_size = (2, 2)))
    model.add(layers.Conv2D(64, kernel_size = (3, 3), activation = "relu"))
    model.add(layers.MaxPooling2D(pool_size = (2, 2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(10, activation = "softmax"))
    
    # Put the model in our custom training loop.
    training = TrainingLoop(
        model = model,
        X = X_train,
        y = Y_train,
        optimizer = keras.optimizers.Adam(),
        loss_function = keras.losses.CategoricalCrossentropy(from_logits=True),
        batch_size = 64,
        train_metrics = tf.keras.metrics.CategoricalAccuracy(),
        val_metrics = tf.keras.metrics.CategoricalAccuracy(),
        validation_split = 0.2,
        batch_selection = batch_selection,
        log_file = "logs/" + log_file + "/mnist.csv",
    )

    # We still have to compile the model for the test evaluation.
    model.compile(loss = "categorical_crossentropy", optimizer = "adam", metrics=["accuracy"])

    return model, training


In [2]:
from loop import TrainingLoop
import tensorflow as tf

# These lines will make the gpu not give errors.
gpus= tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential


In [3]:
from batch_selection import windowed_batch_selector, sorting_batch_selector

selector_list = [[None, 'original'], [windowed_batch_selector, 'windowed'], [sorting_batch_selector, 'sorting']]

X_train, Y_train, X_test, Y_test = load_data( "mnist" )
for selector in selector_list:
    print( "\n\n" + selector[1] + "\n")
    # Set random seed so the comparison of different solutions won't be affected by it.
    tf.random.set_seed(42)
    np.random.seed(42)
    
    model, training = make_model( X_train, Y_train, selector[0], selector[1] )

    training.train(epochs = 10)




original

Epoch 1/10	Loss: 1.5112	Metrics: 0.9152: 	Validation metrics: 0.9690: 	100% | 750/750 [00:03<00:00, 193.58it/s]
Epoch 2/10	Loss: 1.5065	Metrics: 0.9720: 	Validation metrics: 0.9724: 	100% | 750/750 [00:02<00:00, 281.61it/s]
Epoch 3/10	Loss: 1.4915	Metrics: 0.9783: 	Validation metrics: 0.9734: 	100% | 750/750 [00:02<00:00, 284.19it/s]
Epoch 4/10	Loss: 1.4825	Metrics: 0.9827: 	Validation metrics: 0.9823: 	100% | 750/750 [00:02<00:00, 284.60it/s]
Epoch 5/10	Loss: 1.4789	Metrics: 0.9853: 	Validation metrics: 0.9815: 	100% | 750/750 [00:02<00:00, 286.67it/s]
Epoch 6/10	Loss: 1.4890	Metrics: 0.9870: 	Validation metrics: 0.9844: 	100% | 750/750 [00:02<00:00, 285.87it/s]
Epoch 7/10	Loss: 1.4763	Metrics: 0.9883: 	Validation metrics: 0.9844: 	100% | 750/750 [00:02<00:00, 286.99it/s]
Epoch 8/10	Loss: 1.4783	Metrics: 0.9890: 	Validation metrics: 0.9829: 	100% | 750/750 [00:02<00:00, 286.33it/s]
Epoch 9/10	Loss: 1.4696	Metrics: 0.9895: 	Validation metrics: 0.9862: 	100% | 750/750 [00:02