In [15]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

In [20]:
class PopulationBasedModel:
    
    def __init__(self, model, population_size, learning_rates, loss):
        if len(learning_rates) != population_size:
            raise ValueError("Population size and number of learning rates must be equal")
        self.models = [tf.keras.models.clone_model(model, input_tensors=None)
                        for _ in range(population_size)]
        self.learning_rates = learning_rates
        self.loss = loss
        self.numberOfModels = len(self.models)
        for i in range(self.numberOfModels):
            self.models[i].compile(optimizer=tf.keras.optimizers.Adam(learning_rate=float(self.learning_rates[i])),
                                   loss=(self.loss),
                                   metrics=['accuracy'])

    def train(self, x_train, y_train, epochs, batch_size, NumBatchesBeforeUpdate, x_test, y_test):

        batches_xtrain = []
        batches_ytrain = []
        for i in range(0, len(x_train), batch_size):
            batches_xtrain.append(x_train[i:i+batch_size])
            batches_ytrain.append(y_train[i:i+batch_size])

        for i in range(epochs):
            batchoffset = 0
            while(batchoffset < len(batches_xtrain)):
                for model in self.models:
                    model.train_on_batch(batches_xtrain[batchoffset], batches_ytrain[batchoffset])
                batchoffset += 1
                if batchoffset % NumBatchesBeforeUpdate == 0:
                    self.update(x_test, y_test)
            print("Epoch", i, "done")

    def update(self, x_test, y_test):

        key=lambda x: x.evaluate(x_test, y_test)[0]
        indexes = np.argsort([key(self.models[i]) for i in range(self.numberOfModels)])
        self.models = [self.models[i] for i in indexes]
        self.learning_rates = [self.learning_rates[i] for i in indexes]

        # Replace weights and biases of the worst 20 percent of the models with the best model 20 percent of the models
        # and add some noise to the learning rate (very little close to mean = 0 and std = 0.001)
        # also update the corresponding hyperparameters with some  random noise as described above
        for i in range(int(self.numberOfModels * 0.2)):
            self.models[-i].set_weights(self.models[i].get_weights())
            self.models[-i].compile(optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rates[-i] + np.random.normal(0, 0.001)),
                                   loss=self.loss,
                                   metrics=['accuracy'])
            self.learning_rates[-i] += np.random.normal(0, 0.001)
        




In [21]:
# Trying out PBT on MNIST dataset
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0
X_train, X_test = X_train.reshape(-1, 28, 28, 1), X_test.reshape(-1, 28, 28, 1)

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

hyperparameters = [0.001, 0.01, 0.1, 0.5, 1, 2, 3, 4, 5, 6]
hyperparameters = np.array(hyperparameters, dtype=np.float32)
num_epochs = 1
num_population = 10
num_generations = 3
loss = 'sparse_categorical_crossentropy'

pbt = PopulationBasedModel(model, num_population, hyperparameters, loss)
pbt.train(X_train, y_train, num_epochs, 32, 10, X_test, y_test)
print(pbt.get_best_agent().evaluate(X_test, y_test))

model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
model.fit(X_train, y_train, epochs=num_epochs, validation_data=(X_test, y_test))
print(model.evaluate(X_test, y_test))

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.4747 - loss: 2.0382
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.3626 - loss: 2.1421
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.1027 - loss: 2.3405
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.0921 - loss: 213.7387
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.0941 - loss: 82.5303
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.1052 - loss: 70642.6016
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.1009 - loss: 12757.7627
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - accuracy: 0.0924 - loss: 382123.7812
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - accuracy: 0.1046 - loss: 19038.0645

KeyboardInterrupt: 