In [None]:
!pip install -U keras-tuner

# Load prerequisites

In [None]:
# Mount Google Drive at /content/gdrive; the dataset and model folders used
# by later cells live under this mount point (Colab-only import)
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Import packages
import os, numpy as np, matplotlib.pyplot as plt, tensorflow as tf, tensorflow_datasets as tfds, kerastuner as kt
from tensorflow.keras.layers import InputLayer, Conv2D, MaxPooling2D, Flatten, Dense
# Report how many GPUs TensorFlow can see (stable API; the `experimental`
# alias used previously is a deprecated spelling of the same call)
print('Num GPUs Available: ', len(tf.config.list_physical_devices('GPU')))

In [None]:
# Change into the directory that holds the custom TFDS dataset package
# (presumably so `import plankton_images_tfds` in the next cell resolves)
%cd /content/gdrive/My Drive/Colab/plankton/data/tfds_dataset
# One-time scaffold of the dataset package; kept commented out so it is not re-run
#!tfds new plankton_images_tfds

In [None]:
# Import the custom dataset package (presumably this registers the builder
# with TFDS so it can be loaded by name - TODO confirm against the package)
import plankton_images_tfds
# Load the dataset the first time; `ds` holds the loaded splits and `info`
# the dataset metadata requested via with_info=True
ds, info = tfds.load('plankton_images_tfds', as_supervised=True, with_info=True)

# Train prep

In [None]:
# Target image size used by the preprocessing step and the model input
img_height = 180
img_width = 180

def preprocess(img, label):
  """Resize, rescale and convert one example to grayscale.

  Args:
    img: image tensor; assumed to be RGB with values in [0, 255] - TODO confirm.
    label: label passed through untouched.

  Returns:
    A `(grayscale_image, label)` pair where the image has been resized to
    `img_height` x `img_width` and divided by 255.
  """
  resized = tf.image.resize(img, [img_height, img_width])
  scaled = resized / 255
  gray = tf.image.rgb_to_grayscale(scaled)
  return gray, label

# Fix the TensorFlow seed and the TFDS file-shuffle seed so runs are repeatable
seed = 42 #@param {type:"integer"}
tf.random.set_seed(seed)
read_config = tfds.ReadConfig(shuffle_seed=seed)

In [None]:
#@title Training hyperparams
k = 10 #@param {type:"slider", min:0, max:100, step:1}
epochs = 10 #@param {type:"integer"}
# Ensure shuffle consistency (this overrides the seed set in the preprocessing cell)
seed = 7 #@param {type:"integer"}
tf.random.set_seed(seed)
read_config = tfds.ReadConfig(shuffle_seed=seed)

# Configure dataset for performance
AUTOTUNE = tf.data.AUTOTUNE

def _load_split(split):
  """Load one slice of the plankton dataset as a batched, cached pipeline.

  Args:
    split: TFDS split string, e.g. 'train[0%:80%]'.

  Returns:
    A dataset of preprocessed (image, label) batches of size 32, cached and
    prefetched.
  """
  raw = tfds.load(name='plankton_images_tfds',
                  split=split,
                  shuffle_files=True,
                  as_supervised=True,
                  read_config=read_config)
  batched = raw.map(preprocess).batch(32)
  return batched.cache().prefetch(buffer_size=AUTOTUNE)

# First 80% for training during the search, next 10% for validation
train_ds = _load_split('train[0%:80%]')
val_ds = _load_split('train[80%:90%]')

# Model

In [None]:
# Check backend: print the Keras configuration file from the home directory
!cat ~/.keras/keras.json

In [None]:
# Hyperband optimization for model architecture and learning rate
class CNNHyperModel(kt.HyperModel):
  """Search space for a small CNN with a single sigmoid output.

  The network is a stack of `num_blocks` Conv2D + MaxPooling2D blocks followed
  by Flatten and a one-unit sigmoid Dense layer, compiled with Adam and
  binary cross-entropy.

  Args:
    input_shape: shape of one input image, e.g. (180, 180, 1).
    num_classes: number of classes. Stored on the instance only; the output
      head is always one sigmoid unit.
  """

  def __init__(self, input_shape, num_classes):
    # Run the base-class initialisation so the attributes KerasTuner sets up
    # on a HyperModel exist on this instance as well.
    super().__init__()
    self.input_shape = input_shape
    self.num_classes = num_classes

  def build(self, hp):
    """Build and compile one candidate model.

    Args:
      hp: KerasTuner HyperParameters object used to sample the configuration.

    Returns:
      A compiled `tf.keras.Sequential` model.
    """
    model = tf.keras.Sequential()
    # Use the shape given to the constructor instead of module-level globals.
    model.add(InputLayer(input_shape=self.input_shape))

    # The hyperparameter names below are the same on every loop iteration, so
    # all blocks share one sampled value per name (one filter count, one
    # kernel size, ... for the whole stack).
    for i in range(hp.Int('num_blocks', min_value=1, max_value=10, step=1)):

      model.add(Conv2D(
          filters=hp.Int('num_kernel', min_value=1, max_value=64, step=1),
          kernel_size=(hp.Int('kernel_width', min_value=1, max_value=3, step=1),
                       hp.Int('kernel_height', min_value=1, max_value=3, step=1)),
          strides=(hp.Int('conv_stride_width', min_value=1, max_value=3, step=1),
                   hp.Int('conv_stride_height', min_value=1, max_value=3, step=1)),
          padding='same',
          activation='relu'))

      model.add(MaxPooling2D(
          pool_size=(hp.Int('pool_width', min_value=1, max_value=3, step=1),
                     hp.Int('pool_height', min_value=1, max_value=3, step=1)),
          strides=(hp.Int('pool_stride_width', min_value=1, max_value=3, step=1),
                   hp.Int('pool_stride_height', min_value=1, max_value=3, step=1)),
          padding='same'))

    # A single Flatten is enough; the second one in the original was a no-op.
    model.add(Flatten())

    model.add(Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=hp.Float(
                'learning_rate',
                min_value=1e-4,
                max_value=1e-2,
            )
        ),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    return model

# Build the input shape from the preprocessing size so both stay in sync
hypermodel = CNNHyperModel(input_shape=(img_height, img_width, 1), num_classes=2)

In [None]:
# Hyperband tuner. Trials are ranked on validation accuracy: the search is
# given validation data, and ranking on training accuracy would reward
# configurations that merely memorise the training slice.
tuner = kt.Hyperband(
    hypermodel,
    objective='val_accuracy',
    max_epochs=epochs,
)

In [None]:
# Print the hyperparameters registered in the tuner's search space
tuner.search_space_summary()

In [None]:
# Run the hyperparameter search on the 80% training slice, validating each
# trial on the 80-90% slice
tuner.search(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    #class_weight=class_weight,
)

In [None]:
# Keep the top-ranked model from the search
model = tuner.get_best_models(num_models=1)[0]

In [None]:
# Save the tuned model; the %cd makes the relative path below resolve inside
# the Drive model folder
%cd /content/gdrive/My Drive/Colab/plankton/model/
model.save('tuned_model')

In [None]:
# Show the layer-by-layer summary of the tuned model
model.summary()

In [None]:
# Write an architecture diagram to model.png, annotated with tensor shapes,
# dtypes and layer names
tf.keras.utils.plot_model(model,
                          to_file='model.png',
                          show_shapes=True,
                          show_dtype=True,
                          show_layer_names=True,
                          )

# Train and evaluate

## Re-randomize model weights

In [None]:
# Reload the tuned model from Drive so this section can be run without
# repeating the hyperparameter search
%cd /content/gdrive/My Drive/Colab/plankton/model/
model = tf.keras.models.load_model('tuned_model')

# Check architecture
model.summary()

In [None]:
# Clone the model and print the first weight array of the original and of
# the clone for comparison
original_weights = model.get_weights()
print("Original weights", original_weights[0])
separator = "========================================================"
print("\n".join([separator] * 3))
model_cloned = tf.keras.models.clone_model(model)
new_weights = model_cloned.get_weights()
print("New weights", new_weights[0])

In [None]:
# Learning rate held by the loaded model's optimizer
print(model.optimizer.learning_rate)

In [None]:
# Compile the clone with the same loss and metric as the tuned model, reusing
# the learning rate read from the loaded model's optimizer
model_cloned.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=model.optimizer.learning_rate),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
# Learning rate now held by the cloned model's optimizer
print(model_cloned.optimizer.learning_rate)

In [None]:
# From here on `model` refers to the clone. Snapshot its current weights in
# `model_weights`; the training cell restores this snapshot between folds.
model = model_cloned
model_weights = model.get_weights()
print(model_weights[0])

In [None]:
# Make sure weights are randomized: evaluate the not-yet-trained clone on the
# 90-100% slice as a baseline
test_ds = tfds.load(name='plankton_images_tfds',
                    split='train[90%:100%]',
                    shuffle_files=True,
                    as_supervised=True,
                    read_config=read_config)
test_ds = test_ds.map(preprocess).batch(32)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

test_history = model.evaluate(
    test_ds
)

## Training

### Loss convergence

In [None]:
# Pickle is used to persist fold histories in case of a RAM outage
import pickle

#@title Training hyperparams
k = 10 #@param {type:"slider", min:0, max:100, step:1}
batch = 16 #@param {type:"slider", min:0, max:128, step:1}
epochs =  100000#@param {type:"integer"}

# Configure dataset for performance
AUTOTUNE = tf.data.AUTOTUNE

# One training history per cross-validation fold is appended here
acc_histories = []

def reset_model(model):
  """Restore `model` to the weight snapshot stored in `model_weights`."""
  return model.set_weights(model_weights)

# Start from the snapshot before the first fold
reset_model(model)

# Stop a fold once val_loss has not improved for 3 consecutive epochs; the
# weights from the last epoch are kept (restore_best_weights is off)
convergence_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-20,
    patience=3,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=False,
)

dataset = 'plankton_images_tfds'
# Percent ranges of the 'train' split: cross-validation uses 0-90, testing 90-100
crs_vld_begin, crs_vld_end = 0, 90
test_begin, test_end = 90, 100


def cross_validate(dataset, train_range, val_range):
  """Train the global `model` on one fold and record its history.

  Args:
    dataset: TFDS dataset name passed to `tfds.load`.
    train_range: TFDS split string for the training portion of the fold.
    val_range: TFDS split string for the validation portion of the fold.

  Side effects:
    Appends the Keras History of the fit to `acc_histories`, rewrites
    'acc_histories.pkl' with the per-fold metric dictionaries collected so
    far, and restores the model weights via `reset_model`.
  """
  def load_split(split):
    # Same pipeline for both portions: preprocess, batch, cache, prefetch.
    split_ds = tfds.load(name=dataset,
                         split=split,
                         shuffle_files=True,
                         as_supervised=True,
                         read_config=read_config)
    split_ds = split_ds.map(preprocess).batch(batch)
    return split_ds.cache().prefetch(buffer_size=AUTOTUNE)

  train_ds = load_split(train_range)
  val_ds = load_split(val_range)
  acc_history = model.fit(
      train_ds,
      epochs=epochs,
      callbacks=[convergence_callback],
      validation_data=val_ds,
  )
  acc_histories.append(acc_history)
  # Persist only the plain metric dicts. A History object references the
  # model it was attached to, so pickling it would try to serialise the whole
  # model instead of just the numbers needed for recovery.
  with open('acc_histories.pkl', 'wb') as f:
    pickle.dump([fold_history.history for fold_history in acc_histories], f)
  # NOTE(review): this restores layer weights only; optimizer state
  # presumably carries over to the next fold - confirm whether that is intended.
  reset_model(model)


# Split the cross-validation range [crs_vld_begin, crs_vld_end] into k
# contiguous folds with integer percent boundaries. The original width
# formula (100/k - 1) only matched the 0-90% range for k == 10, and its float
# arithmetic produced split strings such as 'train[81.0%:90.0%]'.
cv_span = crs_vld_end - crs_vld_begin
if not 2 <= k <= cv_span:
  raise ValueError(
      f"k must be between 2 and {cv_span} to split "
      f"{crs_vld_begin}%-{crs_vld_end}% into non-empty folds, got {k}")
fold_bounds = [crs_vld_begin + (i * cv_span) // k for i in range(k + 1)]

for fold in range(k):
  print(f"FOLD {fold + 1}", end=": ")

  # Folds are visited from the end of the range towards the beginning,
  # matching the original order (first fold validates on the last slice).
  val_start = fold_bounds[k - 1 - fold]
  val_end = fold_bounds[k - fold]
  val_range = f'train[{val_start}%:{val_end}%]'

  # Training data is whatever lies before and/or after the validation slice.
  train_parts = []
  if val_start > crs_vld_begin:
    train_parts.append(f'train[{crs_vld_begin}%:{val_start}%]')
  if val_end < crs_vld_end:
    train_parts.append(f'train[{val_end}%:{crs_vld_end}%]')
  train_range = '+'.join(train_parts)

  cross_validate(dataset, train_range, val_range)

In [None]:
# Final training before testing: fit on the full 0-90% cross-validation range
# NOTE(review): the validation data below is the 90-100% slice, which is the
# same slice the next cell evaluates as the test set. Early stopping on
# val_loss therefore uses the test data - confirm this is intended.
print('Fold',10)
train_ds = tfds.load(name='plankton_images_tfds',
                     split='train[0%:90%]',
                     shuffle_files=True,
                     as_supervised=True,
                     read_config=read_config)
train_ds = train_ds.map(preprocess).batch(batch)
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = tfds.load(name='plankton_images_tfds',
                   split='train[90%:100%]',
                   shuffle_files=True,
                   as_supervised=True,
                   read_config=read_config)
val_ds = val_ds.map(preprocess).batch(batch)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
acc_history = model.fit(
    train_ds,
    epochs=epochs,
    callbacks=[convergence_callback],
    validation_data=val_ds,
    #class_weight=class_weight,
)

In [None]:
# Evaluate the trained model on the held-out 90-100% slice
test_ds = tfds.load(name='plankton_images_tfds',
                    split='train[90%:100%]',
                    shuffle_files=True,
                    as_supervised=True,
                    read_config=read_config)
test_ds = test_ds.map(preprocess).batch(batch)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

test_history = model.evaluate(
    test_ds
)

In [None]:
# Reload the histories pickled during cross-validation (recovery path after a
# runtime restart). Note: the plotting cell below reuses the name `acc` as
# its loop variable, so the value loaded here is overwritten there.
with open('acc_histories.pkl', 'rb') as f:
  acc = pickle.load(f)

In [None]:
# List the metric names recorded for every fold
for fold_history in acc_histories:
  print(fold_history.history.keys())

# For each fold draw two figures: accuracy first, then loss, each showing the
# training curve and the validation curve
for fold_num, fold_history in enumerate(acc_histories, start=1):
  for metric in ('accuracy', 'loss'):
    plt.plot(fold_history.history[metric])
    plt.plot(fold_history.history['val_' + metric])
    plt.title('Fold ' + str(fold_num) + ': model ' + metric)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.grid()
    plt.show()