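# Setup

Train a CNN image classifier on Fashion-MNIST or CIFAR-10 and tune its hyperparameters with a Weights & Biases (wandb) Bayesian sweep. Choose the dataset with the `dataset` flag in the next cell.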

In [None]:
import wandb
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, GlobalMaxPooling2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
import os

# Dataset to train on: 'mnist' (Fashion-MNIST) or 'cifar' (CIFAR-10)
dataset = 'cifar'

# Data

In [None]:
# Load the selected dataset and scale pixel values from [0, 255] to [0, 1].
# Cast to float32 (Keras' default float type) before dividing: dividing the
# uint8 arrays by 255.0 directly would produce float64 arrays twice the size.
if dataset == 'mnist':
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    # Fashion-MNIST images have no channel axis; Conv2D needs one.
    X_train, X_test = X_train[..., np.newaxis], X_test[..., np.newaxis]
elif dataset == 'cifar':
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
else:
    raise ValueError(f'Dataset could not be determined: {dataset!r}')

X_train = X_train.astype(np.float32) / 255.0
X_test = X_test.astype(np.float32) / 255.0
# CIFAR-10 labels are a column vector; flatten so both datasets give 1-D labels
# (a no-op for Fashion-MNIST).
y_train, y_test = y_train.flatten(), y_test.flatten()

# Number of classes
K = len(np.unique(y_train))

# Model and hyperparameter sweep

In [None]:
# Configure the wandb sweep
sweep_config = {
    'method': 'bayes',
    'metric': {
      'name': 'val_loss',
      'goal': 'minimize'   
    },
    'parameters': {
        'learning_rate': {
            'values': [1, 0.1, 0.01, 0.001, 0.0001]
        },
        'epochs': {
            'values': [5, 10, 20]
        },
        'batch_size': {
            'values': [16, 32, 64, 128, 256]
        },
        'optimizer': {
            'values': ['adam', 'nadam', 'sgd', 'rmsprop']
        },
        'conv_layers': {
            'values': [4, 6, 8]
        },
        'kernel_size': {
            'values': [3, 5, 7]
        },
        'dense_layers': {
            'values': [1, 2, 3]
        },
        'dropout': {
            'values': [0.3, 0.2, 0.1]
        },
    }
}

In [None]:
# Register the sweep with wandb; the returned id is what wandb.agent runs against
sweep_id = wandb.sweep(sweep_config, entity='kavp', project='tensorflow-test')

In [None]:
# Model construction, training and wandb logging for a single sweep run.

def _make_optimizer(name, learning_rate):
    """Return a new Keras optimizer ``name`` configured with ``learning_rate``.

    A bare string passed to ``compile`` would use the optimizer's default
    learning rate, so the swept ``learning_rate`` would have no effect.

    Raises:
        ValueError: if ``name`` is not one of the optimizers offered by the sweep.
    """
    optimizers = {
        'adam': tf.keras.optimizers.Adam,
        'nadam': tf.keras.optimizers.Nadam,
        'sgd': tf.keras.optimizers.SGD,
        'rmsprop': tf.keras.optimizers.RMSprop,
    }
    if name not in optimizers:
        raise ValueError(f'Unsupported optimizer {name!r}; expected one of {sorted(optimizers)}')
    return optimizers[name](learning_rate=learning_rate)


def _build_model(config, input_shape, num_classes):
    """Build the CNN described by ``config`` with the Keras functional API.

    ``config['conv_layers'] // 2`` stages (the first stage is always built) of
    two same-padded Conv2D + BatchNormalization layers and a 2x2 max-pool, with
    32 filters in the first stage doubling each stage; then Flatten,
    ``config['dense_layers'] - 1`` hidden Dense layers of 1024, 512, 341, ...
    units each preceded by Dropout, and a final Dropout + softmax output layer.
    """
    kernel_size = config['kernel_size']
    inp = Input(shape=input_shape)
    x = inp
    for stage in range(max(config['conv_layers'] // 2, 1)):
        filters = 32 * 2 ** stage
        x = Conv2D(filters, kernel_size, activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = Conv2D(filters, kernel_size, activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D((2, 2))(x)
    x = Flatten()(x)
    for i in range(config['dense_layers'] - 1):
        x = Dropout(config['dropout'])(x)
        # Integer division: Dense expects an integer number of units.
        x = Dense(1024 // (i + 1), activation='relu')(x)
    x = Dropout(config['dropout'])(x)
    out = Dense(num_classes, activation='softmax')(x)
    return Model(inp, out)


def _log_history(history):
    """Log final-epoch metrics plus one per-epoch line chart per metric to wandb.

    ``history`` is a Keras ``History.history`` dict with 'loss', 'val_loss',
    'accuracy' and 'val_accuracy' lists.
    """
    metrics = [
        ('loss', 'Training loss'),
        ('val_loss', 'Validation loss'),
        ('accuracy', 'Training accuracy'),
        ('val_accuracy', 'Validation accuracy'),
    ]
    # val_loss here is the metric the sweep optimises.
    wandb.log({key: history[key][-1] for key, _ in metrics})
    for key, title in metrics:
        table = wandb.Table(data=[[epoch, value] for epoch, value in enumerate(history[key])],
                            columns=["epoch", key])
        wandb.log({f"{key}_against_epochs": wandb.plot.line(table, "epoch", key, title=title)})


def sweep_func():
    """Train, evaluate and save one model for a wandb sweep run.

    Hyperparameters come from ``wandb.config``; sweep values override the
    defaults below. Trains on the augmented training set, uses the test set
    as validation data, logs metrics and charts, and saves ``model.h5`` to the
    run directory. Reads the notebook globals ``X_train``, ``y_train``,
    ``X_test``, ``y_test`` and ``K``.

    Raises:
        ValueError: if ``eager_mode`` is not a boolean or ``optimizer`` is unknown.
    """
    # Default hyperparameter values
    config_defaults = {
        'learning_rate': 0.001,
        'epochs': 50,
        'batch_size': 128,
        'optimizer': 'adam',
        'conv_layers': 6,
        'kernel_size': 3,
        'dense_layers': 2,
        'dropout': 0.2,
        'eager_mode': False,
    }

    # Initialise run
    wandb.init(config=config_defaults)

    # Variable to hold the sweep values
    config = wandb.config

    # eager_mode only selects Keras' per-model execution (run_eagerly). TF's
    # process-wide tf.compat.v1.enable/disable_eager_execution are startup-only
    # switches: they cannot be toggled per run, and disabling eager here would
    # leave every later notebook cell in graph mode.
    if config['eager_mode'] not in (True, False):
        raise ValueError('eager_mode property of wandb config could not be determined.')

    model = _build_model(config, input_shape=X_train[0].shape, num_classes=K)
    model.compile(
        optimizer=_make_optimizer(config['optimizer'], config['learning_rate']),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
        run_eagerly=bool(config['eager_mode']),
    )

    # Data augmentation: small random shifts and horizontal flips
    batch_size = config['batch_size']
    data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
        width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)
    train_generator = data_generator.flow(X_train, y_train, batch_size)
    steps_per_epoch = X_train.shape[0] // batch_size

    # Model.fit accepts the generator directly (fit_generator is deprecated).
    r = model.fit(train_generator, validation_data=(X_test, y_test),
                  epochs=config['epochs'], steps_per_epoch=steps_per_epoch)
    _log_history(r.history)

    # Save model
    model.save(os.path.join(wandb.run.dir, 'model.h5'))

In [None]:
# Start a sweep agent that calls sweep_func for each run the sweep schedules
wandb.agent(sweep_id, function=sweep_func)

In [None]:
# Mark the current wandb run (if any) as finished
wandb.finish()

# Predictions

In [None]:
# Load a good run from the sweep
good_run_path = 'kavp/tensorflow-test/gacty1xz/'
run = wandb.Api().run(good_run_path)
# Download the run's saved model file. Only its local path is needed, so the
# file object returned by wandb.restore is closed straight away.
with wandb.restore('model.h5', run_path=good_run_path) as best_model:
    weights_path = best_model.name
# Load the hyperparameters the run was trained with
config = run.config

# Rebuild the architecture used in sweep_func, creating layers in the same
# order so the saved weights line up: conv_layers // 2 stages (the first is
# always built) of two Conv2D + BatchNormalization layers and a 2x2 max-pool,
# with filters doubling from 32.
num_layers = config['conv_layers'] // 2
inp = Input(shape=X_train[0].shape)
x = inp
for i in range(max(num_layers, 1)):
    filters = 32 * 2 ** i
    x = Conv2D(filters, config['kernel_size'], activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = Conv2D(filters, config['kernel_size'], activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
x = Flatten()(x)
# Hidden Dense layers of 1024, 512, 341, ... units (integer unit counts)
for i in range(config['dense_layers'] - 1):
    x = Dropout(config['dropout'])(x)
    x = Dense(1024 // (i + 1), activation='relu')(x)
x = Dropout(config['dropout'])(x)
x = Dense(K, activation='softmax')(x)

model = Model(inp, x)

# Compile with the run's training settings; prediction itself does not use the optimizer.
model.compile(
    optimizer=config['optimizer'],
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
    run_eagerly=config['eager_mode'],
)

# Load the trained weights into the rebuilt model
model.load_weights(weights_path)

In [None]:
# Predict a class for every test image and show the confusion matrix
# (rows: true class, columns: predicted class).
# model.predict returns per-class softmax probabilities; argmax picks the most
# likely class. np.argmax keeps the result a NumPy array, so later cells can
# compare and index it whether or not TF eager execution is enabled.
predicted_class = np.argmax(model.predict(X_test), axis=1)
tf.math.confusion_matrix(y_test, predicted_class)

In [None]:
# Human-readable class names, indexed by integer label, for the active dataset
try:
    labels = {
        'mnist': ['t-shirt', 'trousers', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'],
        'cifar': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
    }[dataset]
except KeyError:
    raise ValueError('Dataset could not be determined') from None

In [None]:
# Show a random misclassified test example, if there is one.
# np.asarray accepts either a NumPy array or an eager tf.Tensor of predictions.
predicted = np.asarray(predicted_class)
misclassified = np.flatnonzero(predicted != y_test)
if misclassified.size == 0:
    # np.random.choice raises on an empty array, so report instead of crashing.
    print('Every test example was classified correctly.')
else:
    index = np.random.choice(misclassified)
    plt.imshow(X_test[index].squeeze(), cmap='gray')
    plt.title('Ground truth: %s    Predicted: %s' % (labels[y_test[index]], labels[int(predicted[index])]))