## Saving keras model checkpoints directly to Google Cloud Storage during training

One of the biggest problems i have faced while running my keras model training jobs on Google Cloud AI platform is that model checkpoint callback cannot write checkpoints directly to GCS.

Lets see how this problem can be solved by creating a custom tensorflow keras callback. 

#### We will use the very popular IRIS dataset in this notebook and build a super simple model to classify the flower types.

In [1]:
import os

import pandas as pd
import numpy as np

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from sklearn.utils import shuffle
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

from tensorflow.keras.callbacks import Callback
from google.cloud import storage

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, CSVLogger

In [2]:
# Disable tensorflow debugging information
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["KMP_SETTINGS"] = "false"

In [3]:
# Read the csv file
data = pd.read_csv('/kaggle/input/iris-flower-dataset/IRIS.csv')
data.head(n=5)

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa


In [4]:
# Separate input and labels
X, y = data[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].values, data['species'].values.reshape(-1, 1)

In [5]:
# One hot encode the labels
ohc = OneHotEncoder(handle_unknown='ignore')
ohc.fit(y)
y = ohc.transform(y).toarray()

In [6]:
# Train test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

### Define a simple DNN architecture

We will keep the model very simple since we are not concentrating on accuracy of the model.

In [7]:
model = Sequential()
model.add(layers.Dense(4, input_shape=(None, 4), activation='relu'))
model.add(layers.Dense(64, activation='relu'))                                                                                                                                                                
model.add(layers.Dense(3, activation='softmax'))

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=["accuracy"])
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, None, 4)           20        
_________________________________________________________________
dense_1 (Dense)              (None, None, 64)          320       
_________________________________________________________________
dense_2 (Dense)              (None, None, 3)           195       
Total params: 535
Trainable params: 535
Non-trainable params: 0
_________________________________________________________________


## General callbacks

In [8]:
# Create directories for tensorboard and checkpoints if it doesn't exist
os.makedirs('/tmp/tensorboard', exist_ok=True)
os.makedirs('/tmp/checkpoints', exist_ok=True)

# Tensorboard is a tool to monitor training.
tensorboard = TensorBoard(log_dir='/tmp/tensorboard')

# Callback to save a copy of model after every epoch.
cp_callback = ModelCheckpoint(filepath='/tmp/checkpoints/model.{epoch:02d}-{val_loss:.2f}.hdf5',
                             monitor='val_accuracy',
                             save_freq='epoch',
                             verbose=1,
                             period=1,
                             save_best_only=False,
                             save_weights_only=True)

# Keeps monitoring training and terminates it if model starts to overfit or value of specified monitoring metric is stationary.
es_callback = EarlyStopping(monitor='val_accuracy',
                            mode='min',
                            verbose=1,
                            patience=5)

callbacks = [tensorboard, cp_callback, es_callback]

## Creating a custom GCS callback

Model checkpoint callback will write checkpoints to /tmp/checkpoints at the end of every epoch. We will create our own custom callback which also runs at the end of each epoch after model checkpoint callback runs, and uploads the newly created checkpoint to GCS.

In [9]:
# Google Cloud Credentials

# Project ID
gcp_project_id = 'gcp-demos-341918'

# GCS bucket name
bucket_name = "callback-demo-bucket"

### Create a bucket to store the checkpoints

In [10]:
storage_client = storage.Client(project=gcp_project_id)

buckets = [b.name for b in storage_client.list_buckets()]

if bucket_name not in buckets:
    bucket = storage_client.bucket(bucket_name)
    bucket.storage_class = "COLDLINE"
    new_bucket = storage_client.create_bucket(bucket, location="us", )
    print(f"Created new bucket: {new_bucket.name}")

Created new bucket: callback-demo-bucket


### Implement the GCS callback

Custom callbacks can be implemented by inherting the Callback class from tensorflow.keras.callbacks module. There are multiple functions which can be overridden to achieve desired tasks in multiple stages during training like:

* Beginning or ending of training|evaluation|prediction.
* Beginning or ending of epochs.
* Beginning or ending of batches.

To solve our problem, we will override the on_epoch_end function, which will copy the checkpoint to the specified GCS bucket at end of each epoch using Google Cloud Storage python client.

In [11]:
class GCSCallback(Callback):
    """ A custom callback to copy checkpoints from local file system directory to Google Cloud Storage directory"""
    
    def __init__(self, cp_path: str, bucket_name: str):
        """ init method
        Args:
            cp_path (str): gcs directory path to store checkpoints
            bucket_name (str): name of GCS bucket
        """
        super(GCSCallback, self).__init__()
        self.checkpoint_path = cp_path
        self.bucket_name = bucket_name
        
        client = storage.Client(project=gcp_project_id)
        self.bucket = client.get_bucket(bucket_name)
        
    def upload_file_to_gcs(self, src_path: str, dest_path: str):
        """ Uploads file to Google Cloud Storage
        Args:
            src_path (str): absolute path of source file
            dest_path (str): gcs directory path beginning with 'gs://<bucket-name>'
        Returns:
        """
        # blob needs only the path inside the bucket. we need to remove gs://<bucket-name> part
        dest_path = dest_path.split(f'{self.bucket_name}/')[1]
        
        # Create a complete destination path. This is basically self.cp_path + file_name.
        dest_path = os.path.join(dest_path, os.path.basename(src_path))
        
        blob = self.bucket.blob(dest_path)
        blob.upload_from_filename(src_path)
        
    def on_epoch_end(self, epoch, logs=None):
        
        # ModelCheckpoint callback will write checkpoints to /checkpoints directory
        for cp_file in os.listdir('/tmp/checkpoints'):
            src_path = os.path.join('/tmp/checkpoints', cp_file)
            self.upload_file_to_gcs(src_path=src_path, dest_path=self.checkpoint_path)
        print(f"Epoch {str(epoch+1).zfill(5)}: Uploaded saved model to {self.checkpoint_path}\n")


In [12]:
# Create the callback object and append it to the callback list
gcs_callback = GCSCallback(cp_path=f'gs://{bucket_name}/checkpoints', bucket_name=bucket_name)

callbacks.append(gcs_callback)

## Training

In [13]:
history = model.fit(X_train, y_train, validation_split=0.1, epochs=10, batch_size=8, callbacks=callbacks)

Epoch 1/10

Epoch 00001: saving model to /tmp/checkpoints/model.01-0.96.hdf5
Epoch 00001: Uploaded saved model to gs://callback-demo-bucket/checkpoints

Epoch 2/10

Epoch 00002: saving model to /tmp/checkpoints/model.02-0.86.hdf5
Epoch 00002: Uploaded saved model to gs://callback-demo-bucket/checkpoints

Epoch 3/10

Epoch 00003: saving model to /tmp/checkpoints/model.03-0.78.hdf5
Epoch 00003: Uploaded saved model to gs://callback-demo-bucket/checkpoints

Epoch 4/10

Epoch 00004: saving model to /tmp/checkpoints/model.04-0.71.hdf5
Epoch 00004: Uploaded saved model to gs://callback-demo-bucket/checkpoints

Epoch 5/10

Epoch 00005: saving model to /tmp/checkpoints/model.05-0.62.hdf5
Epoch 00005: Uploaded saved model to gs://callback-demo-bucket/checkpoints

Epoch 6/10

Epoch 00006: saving model to /tmp/checkpoints/model.06-0.55.hdf5
Epoch 00006: Uploaded saved model to gs://callback-demo-bucket/checkpoints

Epoch 00006: early stopping


### Verifying that checkpoints are saved to GCS

In [14]:
# Get GCS bucket
bucket = storage_client.get_bucket(bucket_name)

elements = list(bucket.list_blobs(prefix='checkpoints'))

for element in elements:
    print(element.name)

checkpoints/model.01-0.96.hdf5
checkpoints/model.02-0.86.hdf5
checkpoints/model.03-0.78.hdf5
checkpoints/model.04-0.71.hdf5
checkpoints/model.05-0.62.hdf5
checkpoints/model.06-0.55.hdf5


#### The checkpoints were copied to GCS bucket ! This can also be seen at the cloud storage UI in google cloud console to verify that checkpoints are getting copied at the end of every epoch.

## Resources

* https://www.tensorflow.org/guide/keras/custom_callback
* https://cloud.google.com/storage/docs/how-to