In [1]:
# disable warnings
import warnings
warnings.filterwarnings("ignore")

# basic imports
import tensorflow as tf 
import torch
import io
import numpy as np
import datetime
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers

from keras.utils import dataset_utils
import matplotlib.pyplot as plt

import baseline_config

In [2]:
def paths_and_labels_to_dataset(image_paths,labels,num_classes):
    """Build a dataset of ``(raw_file_bytes, label)`` pairs.

    Args:
        image_paths: sequence of file paths, one entry per sample.
        labels: per-file labels, in the same order as ``image_paths``.
        num_classes: number of classes, forwarded to Keras'
            ``labels_to_dataset`` helper in ``'categorical'`` mode.

    Returns:
        A ``tf.data.Dataset`` zipping the unparsed contents of every file
        with its encoded label. Decoding the bytes is left to the caller.
    """
    # Only the file read happens here; the bytes are decoded later in the pipeline.
    file_bytes = tf.data.Dataset.from_tensor_slices(image_paths).map(
        tf.io.read_file,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    encoded_labels = dataset_utils.labels_to_dataset(labels, 'categorical', num_classes)
    return tf.data.Dataset.zip((file_bytes, encoded_labels))

def create_dataset(subset):
    """Index one split of the dataset and wrap it in a ``tf.data.Dataset``.

    Args:
        subset: sub-directory name appended verbatim to
            ``baseline_config.dataset_path`` (e.g. ``'TRAIN/'``).

    Returns:
        A ``(dataset, class_names)`` tuple: ``dataset`` yields
        ``(raw_file_bytes, label)`` pairs as produced by
        ``paths_and_labels_to_dataset`` and ``class_names`` is the list of
        classes reported by Keras' ``index_directory``.
    """
    image_paths, labels, class_names = dataset_utils.index_directory(
            baseline_config.dataset_path + subset,
            labels="inferred",
            # must be a real tuple: ('.pt') is just the string '.pt'
            formats=('.pt',),
            class_names=None,
            # files are indexed in a fixed order; shuffling happens later in
            # the training pipeline (seed presumably has no effect here)
            shuffle=False,
            seed=42,
            follow_links=False)

    dataset = paths_and_labels_to_dataset(
        image_paths=image_paths,
        labels=labels,
        num_classes=len(class_names))

    return dataset, class_names

# Index the three splits in a fixed order (train, test, validation).
# Only the class names reported for the training split are kept.
(train_dataset, class_names), (test_dataset, _), (validation_dataset, _) = (
    create_dataset(split) for split in ('TRAIN/', 'TEST/', 'VALIDATION/')
)
print("class names: ", class_names)

Found 12384 files belonging to 5 classes.
Found 487 files belonging to 5 classes.
Found 384 files belonging to 5 classes.
class names:  ['brant', 'jabwar', 'sheowl', 'spodov', 'wiltur']


In [3]:
def dataset_transforms(image,label):
    """Decode a serialized tensor into a 3-channel, max-normalised image.

    Args:
        image: scalar string tensor holding a serialized ``tf.float32``
            tensor (assumed to be 2-D, e.g. a spectrogram -- TODO confirm).
        label: passed through unchanged.

    Returns:
        ``(image, label)`` where ``image`` has a trailing axis of size 3 and
        has been divided by its own maximum value.
    """
    # reshape into standard 3 channels: add a trailing axis and copy it 3 times
    image = tf.io.parse_tensor(image, tf.float32)
    image = tf.expand_dims(image, -1)
    image = tf.repeat(image, 3, axis=2)

    # rescale by the maximum so the peak value is 1 (range is [0,1] for
    # non-negative inputs). divide_no_nan returns 0 instead of NaN when the
    # maximum is 0 (e.g. an all-zero input), so one bad sample cannot poison the loss.
    image = tf.math.divide_no_nan(image, tf.reduce_max(image))

    return image,label

# Training pipeline: cache the raw file bytes once, reshuffle every pass,
# decode in parallel, batch, repeat forever (fit() uses steps_per_epoch), and
# prefetch so input preparation overlaps with the training step.
train_dataset_b = (
                  train_dataset
                  .cache()
                  .shuffle(20000)
                  .map(dataset_transforms, num_parallel_calls=tf.data.AUTOTUNE)
                  .batch(baseline_config.batch_size)
                  .repeat()
                  .prefetch(tf.data.AUTOTUNE)
                )

# Evaluation pipelines: no shuffling, and the decoded batches themselves are
# cached because they are identical on every pass.
validation_dataset_b = (
                  validation_dataset
                  .map(dataset_transforms, num_parallel_calls=tf.data.AUTOTUNE)
                  .batch(baseline_config.batch_size)
                  .cache()
                  .prefetch(tf.data.AUTOTUNE)
                )

test_dataset_b = (
                  test_dataset
                  .map(dataset_transforms, num_parallel_calls=tf.data.AUTOTUNE)
                  .batch(baseline_config.batch_size)
                  .cache()
                  .prefetch(tf.data.AUTOTUNE)
                )

In [4]:
# Sanity check on one training batch: print the image/label batch shapes and
# the value range after normalisation.
for item, lbl in train_dataset_b.take(1):
    pixels = item.numpy()
    print(item.shape, lbl.shape)
    print(pixels.min(), pixels.max())

(32, 313, 128, 3) (32, 5)
0.0 1.0


In [5]:
def build_model(trainable, input_shape=(313, 128, 3), num_classes=None):
    """Build a simple classifier on top of a pre-trained ResNet V2 152.

    Args:
        trainable: whether the weights of the TF-Hub feature extractor are
            updated during training (``True``) or frozen (``False``).
        input_shape: ``(height, width, channels)`` of a single input image.
        num_classes: number of output logits. Defaults to
            ``len(class_names)`` from the notebook namespace.

    Returns:
        An uncompiled ``keras.Sequential`` model whose last layer outputs raw
        logits (no softmax). The model summary is printed as a side effect.
    """
    if num_classes is None:
        num_classes = len(class_names)

    def hidden(units):
        # every hidden layer of the head shares the same regularisation setup
        return layers.Dense(
            units,
            activation="relu",
            kernel_regularizer=keras.regularizers.L1L2(l1=1e-3, l2=1e-5),
        )

    model = keras.Sequential(
        [
            # resize up front to the 224x224 input used for the hub feature extractor
            layers.InputLayer(input_shape=input_shape),
            layers.Resizing(224, 224, interpolation="lanczos5", crop_to_aspect_ratio=False),

            # pre-trained ResNet V2 152 used as the feature layer.
            # A lighter alternative (fewer parameters) is MobileNet V2:
            # https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4
            hub.KerasLayer(
                "https://tfhub.dev/google/imagenet/resnet_v2_152/feature_vector/5",
                trainable=trainable,
            ),

            # classification head
            layers.Flatten(),
            hidden(128),
            hidden(64),
            hidden(32),
            layers.Dropout(0.50),
            # no activation: the loss is expected to be used with from_logits=True
            layers.Dense(num_classes, activation=None),
        ]
    )
    # need to tell the model what the input shape is
    model.build([None, *input_shape])

    # show the model
    model.summary()

    return model

In [6]:
# fine-tune end to end: the pre-trained feature-extractor weights are trainable too
model = build_model(True)

# from_logits=True means the loss function has the softmax built in. This approach
# is numerically more stable than a softmax activation on the last layer of the classifier
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4),
              metrics=["accuracy"],
              )

# tensorboard for visualisation of results
log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

# reduce learning rate to avoid overshooting local minima
lr_reduce_plateau = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                                      patience=4, verbose=1, mode='min',
                                                      cooldown=0, min_lr=1e-8)

# save the best model as it trains..
mcp_save = keras.callbacks.ModelCheckpoint('models/model_wts.hdf5', save_best_only=True, monitor='val_loss', mode='min')

# stop once val_loss has stopped improving; without this the run only ends by
# manual interrupt (epochs is effectively an upper bound). The patience is
# larger than the LR scheduler's so a reduced learning rate gets a chance first.
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=12,
                                           restore_best_weights=True, verbose=1)

# fit the model to the training set; the training dataset repeats forever, so an
# "epoch" here is steps_per_epoch batches rather than a full pass over the data
history = model.fit(train_dataset_b,
                    validation_data=validation_dataset_b,
                    steps_per_epoch=50,
                    callbacks=[lr_reduce_plateau, tensorboard_callback, mcp_save, early_stop],
                    epochs=10000)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resizing (Resizing)         (None, 224, 224, 3)       0         
                                                                 
 keras_layer (KerasLayer)    (None, 2048)              58331648  
                                                                 
 flatten (Flatten)           (None, 2048)              0         
                                                                 
 dense (Dense)               (None, 128)               262272    
                                                                 
 dense_1 (Dense)             (None, 64)                8256      
                                                                 
 dense_2 (Dense)             (None, 32)                2080      
                                                                 
 dropout (Dropout)           (None, 32)                0