In [1]:
# disable warnings
import warnings
warnings.filterwarnings("ignore")

# basic imports
import torch
import timm
import tfimm
import os
import numpy as np
import datetime
import tensorflow as tf 
# import tensorflow_hub as hub

from keras.utils import dataset_utils
import matplotlib.pyplot as plt

import baseline_config

In [2]:
########################################################################################
# MODEL PARAMETERS
########################################################################################
MODEL_INPUT_IMAGE_WIDTH = 256
MODEL_INPUT_IMAGE_HEIGHT = 256

In [3]:
def paths_and_labels_to_dataset(image_paths,labels,num_classes):
    path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
    img_ds = path_ds.map(
        lambda path: tf.io.read_file(path), 
        num_parallel_calls=tf.data.AUTOTUNE
    )
    label_ds = dataset_utils.labels_to_dataset(
        labels, 
        'categorical', 
        num_classes)
    img_ds = tf.data.Dataset.zip((img_ds, label_ds))
    return img_ds

def create_dataset(subset):
    image_paths, labels, class_names = dataset_utils.index_directory(
            baseline_config.dataset_path + subset,
            labels="inferred",
            formats=('.pt'),
            class_names=None,
            shuffle=False,
            seed=42,
            follow_links=False)

    dataset = paths_and_labels_to_dataset(
        image_paths=image_paths,
        labels=labels,
        num_classes=len(class_names))
    
    return dataset, class_names

train_dataset, class_names = create_dataset('TRAIN/')
test_dataset, _            = create_dataset('TEST/')
validation_dataset, _      = create_dataset('VALIDATION/')
print("class names: ", class_names)

Found 12384 files belonging to 5 classes.
Found 487 files belonging to 5 classes.
Found 384 files belonging to 5 classes.
class names:  ['brant', 'jabwar', 'sheowl', 'spodov', 'wiltur']


In [4]:
def dataset_transforms(image,label):
  
    # reshape into standard 3 channels
    image = tf.io.parse_tensor(image, tf.float32)
    image = tf.expand_dims(image, -1)
    
    # most pre-trained model expect 3 color channels
    image = tf.repeat(image, 3, 2)
    
    image = tf.ensure_shape(image, [313, 128, 3])
    image = tf.image.resize(image, (MODEL_INPUT_IMAGE_WIDTH,MODEL_INPUT_IMAGE_HEIGHT), method=tf.image.ResizeMethod.LANCZOS5)
    
    # rescale to range [0,1]
    image = image - tf.reduce_min(image) 
    image = image / tf.reduce_max(image)
    
    return image,label

train_dataset_b = ( 
                  train_dataset       
                  .shuffle(20000)
                  .map(dataset_transforms)
                  .batch(baseline_config.batch_size)
                  .cache()          
                )

validation_dataset_b = ( 
                  validation_dataset
                  .map(dataset_transforms)
                  .batch(baseline_config.batch_size)
                  .cache()
                )

test_dataset_b = ( 
                  test_dataset
                  .map(dataset_transforms)
                  .batch(baseline_config.batch_size)
                  .cache()
                )

In [5]:
for item,lbl in train_dataset_b.take(1):
    print(item.shape, lbl.shape)
    print(np.min(item), np.max(item))

(64, 256, 256, 3) (64, 5)
0.0 1.0


In [6]:
# uncomment to list all possible timm models
tfimm.list_models()

['cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'convmixer_768_32',
 'convmixer_1024_20_ks9_p14',
 'convmixer_1536_20',
 'convnext_base',
 'convnext_base_384_in22ft1k',
 'convnext_base_in22ft1k',
 'convnext_base_in22k',
 'convnext_large',
 'convnext_large_384_in22ft1k',
 'convnext_large_in22ft1k',
 'convnext_large_in22k',
 'convnext_small',
 'convnext_small_384_in22ft1k',
 'convnext_small_in22ft1k',
 'convnext_small_in22k',
 'convnext_tiny',
 'convnext_tiny_384_in22ft1k',
 'convnext_tiny_in22ft1k',
 'convnext_tiny_in22k',
 'convnext_xlarge_384_in22ft1k',
 'convnext_xlarge_in22ft1k',
 'convnext_xlarge_in22k',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224',
 '

In [7]:
class CustomTimmModel(tf.keras.Model):
    
    def __init__(self, *args, **kwargs):  
        super(CustomTimmModel, self).__init__(*args, **kwargs)
        self.fm = tfimm.create_model("efficientnet_v2_l_in21ft1k", pretrained=True)
        self.flat = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(len(class_names)*20, activation=tf.keras.activations.swish)
        self.fc2 = tf.keras.layers.Dense(len(class_names)*20, activation=tf.keras.activations.swish)
        self.fc3 = tf.keras.layers.Dense(len(class_names)*1, activation=tf.keras.activations.linear)
        self.do1 = tf.keras.layers.Dropout(0.50)

    def call(self, inputs, training=False):
        x = self.fm.forward_features(inputs) 
        x = self.flat(x)
        x = self.fc1(x)
        x = self.fc2(x)               
        x = self.do1(x)
        x = self.fc3(x)
        return x
        
# construct the model
model = CustomTimmModel()

model.fm.trainable=False

# need to tell the model what the input shape is
model.build([None, MODEL_INPUT_IMAGE_WIDTH, MODEL_INPUT_IMAGE_HEIGHT, 3])

All PyTorch model weights were used when initializing EfficientNet.
All the weights of EfficientNet were initialized from the PyTorch model.



In [8]:
if not os.path.exists('models/'):
    os.mkdir('models/')
    
# the form_logits means the loss function has the 'softmax' buillt in.  This approach is numerically more stable
# than including the softmax activation on the last layer of the classifier
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), 
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4), 
              metrics=["accuracy"],
              )

# tensorboard for visualisation of results
log_dir = "tensorboard_logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, 
                                                      histogram_freq=1)

# reduce learning rate to avoid overshooting local minima
lr_reduce_plateau = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 
                                                      factor=0.5,
                                                      patience=4, 
                                                      verbose=1,
                                                      mode='min',
                                                      cooldown=0, 
                                                      min_lr=1e-8)

# end the training if no improvement for 8 epochs in a row, then restore best model weights
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=8,
    verbose=0,
    mode="min",
    baseline=None,
    restore_best_weights=True,
)

# save the best model as it trains..
mcp_save = tf.keras.callbacks.ModelCheckpoint('models/checkpoint_timm_model.hdf5', 
                                            save_weights_only=True,
                                           save_best_only=True, 
                                           monitor='val_loss', 
                                           mode='min')

# fit the model to the training set
history = model.fit(train_dataset_b, 
          validation_data=validation_dataset_b,
          callbacks=[lr_reduce_plateau, early_stopping, tensorboard_callback, mcp_save],
          epochs=10)

Epoch 1/10
Epoch 2/10

In [None]:
# fine tune
model.fm.trainable=True
model.make_train_function(force=True)

model.optimizer.learning_rate=1e-6
history = model.fit(train_dataset_b, 
          validation_data=validation_dataset_b,
          callbacks=[lr_reduce_plateau, early_stopping, tensorboard_callback, mcp_save],
          epochs=30)

Epoch 1/30


In [None]:
# save the model weights
model.save_weights('models/baseline_timm_model.hdf5', save_format='h5')

# test load the model for inference
test_model = CustomTimmModel()
test_model.build([None, 313, 128, 3])
test_model.load_weights('models/baseline_timm_model.hdf5')