<a href="https://colab.research.google.com/github/ayulockin/LossLandscape/blob/master/ResNet20v1_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set up and imports

In [None]:
# TensorFlow Imports
import tensorflow as tf
print(tf.__version__)

In [None]:
# Which GPU?
!nvidia-smi

In [None]:
%%capture
!pip install wandb

In [None]:
import wandb
from wandb.keras import WandbCallback

wandb.login()

In [None]:
# Other imports
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import time

# Random seed fixation: seed TensorFlow's global generator and NumPy's global
# generator with the same constant so that repeated runs start from the same
# random state.
tf.random.set_seed(666)
np.random.seed(666)

## Get the model from [keras-idiomatic-programmer](https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer)

In [None]:
!wget https://raw.githubusercontent.com/GoogleCloudPlatform/keras-idiomatic-programmer/master/zoo/resnet/resnet_cifar10.py

In [None]:
import resnet_cifar10

## Utils

In [None]:
def get_training_model():
    """Assemble the ResNet20 classifier from the `resnet_cifar10` helpers.

    The network is wired as stem -> learner -> classifier on a
    (32, 32, 3) input, with a 10-class classifier head.

    Returns:
        An uncompiled Keras `Model` mapping the input tensor to the
        classifier output.
    """
    # ResNet20: depth = 9n + 2 with n = 2
    n = 2
    depth = n * 9 + 2
    n_blocks = ((depth - 2) // 9) - 1

    # Input tensor
    inputs = Input(shape=(32, 32, 3))

    # Stem convolution group followed by the learner
    features = resnet_cifar10.learner(resnet_cifar10.stem(inputs), n_blocks)

    # Classifier head for 10 classes
    outputs = resnet_cifar10.classifier(features, 10)

    return Model(inputs, outputs)

## Construct data loaders

In [None]:
# Load CIFAR10: both the training split and the test split. The test split is
# also what the training cells below pass to `fit` as validation data.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [None]:
# Number of examples per batch for both the training and the test pipeline
BATCH_SIZE = 128

def normalize(image, label):
    """Convert `image` to float32 and pass `label` through untouched.

    The conversion is delegated to `tf.image.convert_image_dtype`
    (NOTE(review): per the TF docs this rescales integer images into
    [0, 1] -- confirm for the TF version in use).
    """
    image_f32 = tf.image.convert_image_dtype(image, tf.float32)
    return image_f32, label

# Training pipeline: shuffle (buffer of 1024) -> normalize -> batch -> prefetch
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(1024)
train_ds = train_ds.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

# Test pipeline: same steps minus the shuffle, so the element order follows
# the order of x_test / y_test
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)

## Model sanity checks

In [None]:
# Build the network once and print its layer summary as a sanity check.
# The training cells below build their own fresh model and rebind `model`.
model = get_training_model()
model.summary()

## Callbacks

In [None]:
# Custom LR schedule as mentioned in the LossLandscape paper
# Step-wise learning-rate schedule: the rate is halved at epochs 9, 19 and 29.
# This table is the single source of truth; `lr_schedule` reads both columns.
LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples, sorted by ascending start epoch
    (0, 1.6*1e-3),
    (9, (1.6*1e-3)/2),
    (19, (1.6*1e-3)/4),
    (29, (1.6*1e-3)/8),
]

def lr_schedule(epoch):
    """Return the learning rate that applies to `epoch`.

    The rate is taken from the last `LR_SCHEDULE` entry whose start epoch
    is less than or equal to `epoch`.

    Args:
        epoch: Zero-based epoch index.

    Returns:
        The learning rate (float) for that epoch. An epoch that precedes
        the first entry gets the initial rate.
    """
    # Default to the initial rate, then walk the (sorted) table and keep the
    # rate of every stage that has already started.
    rate = LR_SCHEDULE[0][1]
    for start_epoch, stage_rate in LR_SCHEDULE:
        if epoch < start_epoch:
            break
        rate = stage_rate
    return rate

# Keras callback that sets the learning rate from `lr_schedule` for each epoch
lr_callback = tf.keras.callbacks.LearningRateScheduler(
    schedule=lr_schedule,
    verbose=True,
)

In [None]:
# Preview the learning-rate schedule over the 40 epochs used for training
rng = list(range(40))
plt.plot(rng, [lr_schedule(x) for x in rng])
plt.xlabel('Epoch')
plt.ylabel('Learning rate')
plt.show()

In [None]:
# Prefix for the per-epoch checkpoints. It has no trailing '/', so it is glued
# straight onto the file name:
#   /content/ResNet20v1_CIFAR10resnet20v1_checkpoint_<epoch>.h5
# The gsutil upload cell at the end of the notebook globs for exactly that name.
SAVE_PATH = '/content/ResNet20v1_CIFAR10'

def save_model(epoch, logs):
    """Save the global `model` to an .h5 file named after `epoch`.

    Args:
        epoch: Epoch index, substituted into the checkpoint file name.
        logs: Unused; present because the function is registered as an
            `on_epoch_end` hook.
    """
    checkpoint_name = 'resnet20v1_checkpoint_{}.h5'.format(epoch)
    model.save(SAVE_PATH + checkpoint_name)

# Run `save_model` at the end of every epoch
save_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=save_model,
    verbose=True,
)

A custom callback that logs the confusion matrix at the end of every epoch (adapted from this [tutorial](https://www.tensorflow.org/tensorboard/image_summaries)).

In [None]:
from sklearn.metrics import confusion_matrix
import itertools
import io

In [None]:
# Tick labels for the confusion-matrix plot, one per class.
# NOTE(review): position i is assumed to correspond to label id i -- confirm
# against the dataset's label encoding.
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

In [None]:
def plot_confusion_matrix(cm, class_names):
    """Draw a confusion matrix as a heat map annotated with per-row fractions.

    Cell colours come from the raw counts in `cm`. The number printed in
    each cell is the count divided by its row total, rounded to two
    decimals. The text colour is picked from the raw counts as well, so it
    always contrasts with the cell it sits on.

    Args:
        cm: 2-D array of counts. Rows are drawn along the y axis
            ('True label'), columns along the x axis ('Predicted label').
        class_names: Sequence of tick labels, one per class.

    Returns:
        The matplotlib figure holding the plot.
    """
    cm = np.asarray(cm)

    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Row-normalise for the cell labels only. A row without any samples is
    # shown as 0.0 instead of dividing by zero and printing NaN.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    fractions = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype='float'),
        where=row_totals != 0,
    )
    labels = np.around(fractions, decimals=2)

    # Use white text if squares are dark; otherwise black. The threshold is
    # taken on the raw counts because those are what imshow coloured.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

    return figure

def plot_to_image(figure):
    """Render `figure` to an in-memory PNG and return the decoded pixels.

    The supplied figure is closed and inaccessible after this call.

    Args:
        figure: The matplotlib figure to rasterise.

    Returns:
        A NumPy array holding the decoded 3-channel image.
    """
    # Save the plot to a PNG in memory. Call savefig on the figure that was
    # passed in; plt.savefig would save whatever pyplot considers the
    # "current" figure, which is not guaranteed to be this one.
    buf = io.BytesIO()
    figure.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Decode the PNG bytes with TF, then hand back a plain NumPy array
    image = tf.image.decode_png(buf.getvalue(), channels=3)
    return image.numpy()

def log_confusion_matrix(epoch, logs):
    """Log a confusion-matrix image for the test set to wandb.

    Reads the global `model`, so it reports on whichever model is bound to
    that name when the hook fires.

    Args:
        epoch: Unused; present because the function is registered as an
            `on_epoch_end` hook.
        logs: Unused, same reason.
    """
    # Predict through `test_ds` rather than the raw `x_test` array: the
    # pipeline applies the same `normalize` step the model is trained and
    # validated with. `test_ds` is not shuffled, so predictions line up
    # with `y_test`.
    test_pred_raw = model.predict(test_ds)
    test_pred = np.argmax(test_pred_raw, axis=1)

    # Calculate the confusion matrix (labels flattened to a 1-D vector)
    cm = confusion_matrix(np.ravel(y_test), test_pred)
    # Log the confusion matrix as an image to wandb
    figure = plot_confusion_matrix(cm, class_names=CLASS_NAMES)
    cm_image = plot_to_image(figure)
    wandb.log({'confusion_matrix': wandb.Image(cm_image)})

# Run `log_confusion_matrix` at the end of every epoch
cm_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix, verbose=True)

## Model training

In [None]:
# Start the W&B run for the experiment without data augmentation
wandb.init(project='loss-landscape', id='resnet20v1-no-aug-1')

# Train model. `model` is rebound here; the checkpoint and confusion-matrix
# hooks read this global when they fire.
model = get_training_model()
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

start = time.time()
h = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=40,
    callbacks=[lr_callback, WandbCallback(), cm_callback, save_callback],
)
end = time.time()

# Report wall-clock training time and the parameter count
print("Network takes {:.3f} seconds to train".format(end - start))
wandb.log({'training_time': end - start})
wandb.log({'nb_model_params': model.count_params()})

In [None]:
# Serialize the model
# Save the model after the 40-epoch run to a single .h5 file (relative path,
# i.e. the notebook's working directory)
model.save('resnet20v1_cifar10_40epochs.h5')

In [None]:
# Adding data augmentation
def augment(image,label):
    """Apply a random crop and random brightness change to one example.

    Steps: crop-or-pad the image to 40x40, take a random 32x32x3 crop,
    shift brightness by a random delta (max_delta=0.5), then clip the
    values to [0, 1]. `label` is returned unchanged.
    """
    padded = tf.image.resize_with_crop_or_pad(image, 40, 40)  # grow to 40x40
    cropped = tf.image.random_crop(padded, size=[32, 32, 3])  # back to 32x32
    jittered = tf.image.random_brightness(cropped, max_delta=0.5)
    clipped = tf.clip_by_value(jittered, 0., 1.)

    return clipped, label

# Training pipeline with augmentation:
# shuffle (buffer of 1024) -> normalize -> augment -> batch -> prefetch
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(1024)
train_ds = train_ds.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

# Test pipeline: normalized only, no shuffle and no augmentation
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Checkpoint prefix for the augmented run. It has no trailing '/', so files are
# written as /content/ResNet20v1_CIFAR10_Augresnet20v1_checkpoint_<epoch>.h5
# (the gsutil upload cell globs for exactly that name).
SAVE_PATH = '/content/ResNet20v1_CIFAR10_Aug'

def save_model(epoch, logs):
    """Save the global `model` to an .h5 file named after `epoch`.

    Redefines the earlier `save_model` so that it is paired with the new
    `SAVE_PATH`.

    Args:
        epoch: Epoch index, substituted into the checkpoint file name.
        logs: Unused; present because the function is registered as an
            `on_epoch_end` hook.
    """
    checkpoint_name = 'resnet20v1_checkpoint_{}.h5'.format(epoch)
    model.save(SAVE_PATH + checkpoint_name)

# Run `save_model` at the end of every epoch
save_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=save_model,
    verbose=True,
)

In [None]:
# Start the W&B run for the experiment with data augmentation
wandb.init(project='loss-landscape', id='resnet20v1-aug')

# Train model. `model` is rebound to a fresh network; the checkpoint and
# confusion-matrix hooks read this global when they fire.
model = get_training_model()
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

start = time.time()
h = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=40,
    callbacks=[lr_callback, WandbCallback(), cm_callback, save_callback],
)
end = time.time()

# Report wall-clock training time and the parameter count
print("Network takes {:.3f} seconds to train".format(end - start))
wandb.log({'training_time': end - start})
wandb.log({'nb_model_params': model.count_params()})

In [None]:
# Serialize the model
# Save the model from the augmented 40-epoch run to a single .h5 file
# (relative path, i.e. the notebook's working directory)
model.save('resnet20v1_cifar10_40epochs_data_aug.h5')

## Put the model weights in a GCS bucket

In [None]:
# Authenticate the Colab user; the gsutil uploads to the gs:// bucket in the
# following cells are run after this step.
from google.colab import auth as google_auth
google_auth.authenticate_user()

In [None]:
!gsutil -m cp -r ResNet20v1_CIFAR10resnet20v1_checkpoint_*.h5 gs://losslandscape/ResNet20v1_CIFAR10/
!gsutil cp resnet20v1_cifar10_40epochs.h5 gs://losslandscape/ResNet20v1_CIFAR10/

In [None]:
!gsutil -m cp -r ResNet20v1_CIFAR10_Augresnet20v1_checkpoint_*.h5 gs://losslandscape/ResNet20v1_CIFAR10_Aug/
!gsutil cp resnet20v1_cifar10_40epochs_data_aug.h5 gs://losslandscape/ResNet20v1_CIFAR10_Aug/