In [9]:
# disable warnings
import warnings
warnings.filterwarnings("ignore")

# basic imports
import torch
import timm
import tfimm
import os
import numpy as np
import datetime
import tensorflow as tf 
# import tensorflow_hub as hub

from keras.utils import dataset_utils
import matplotlib.pyplot as plt

import baseline_config

In [10]:
def paths_and_labels_to_dataset(image_paths,labels,num_classes):
    """Build a dataset of ``(raw file contents, label)`` pairs.

    Args:
        image_paths: sequence of file paths, one per sample.
        labels: sequence of class indices aligned with ``image_paths``.
        num_classes: number of distinct classes, forwarded to
            ``dataset_utils.labels_to_dataset`` in 'categorical' mode.

    Returns:
        A ``tf.data.Dataset`` zipping the unparsed contents of every file with
        its encoded label.  Decoding of the file contents is left to the caller.
    """
    # read the files lazily and in parallel; nothing is decoded at this stage
    file_bytes_ds = tf.data.Dataset.from_tensor_slices(image_paths).map(
        tf.io.read_file,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    encoded_label_ds = dataset_utils.labels_to_dataset(labels, 'categorical', num_classes)
    return tf.data.Dataset.zip((file_bytes_ds, encoded_label_ds))

def create_dataset(subset):
    """Index one split of the dataset directory and wrap it in a ``tf.data.Dataset``.

    Args:
        subset: name of the split sub-directory below
            ``baseline_config.dataset_path`` (e.g. ``'TRAIN/'``).  Each class is
            expected to live in its own sub-folder, since labels are inferred
            from the directory structure.

    Returns:
        A ``(dataset, class_names)`` tuple: the un-batched dataset of
        ``(raw file contents, label)`` pairs and the list of class names found
        for this split.
    """
    # os.path.join works whether or not dataset_path ends with a separator
    split_dir = os.path.join(baseline_config.dataset_path, subset)

    image_paths, labels, class_names = dataset_utils.index_directory(
            split_dir,
            labels="inferred",
            # one-element tuple: ('.pt') without the comma is just the string '.pt'
            formats=('.pt',),
            class_names=None,
            shuffle=False,
            seed=42,
            follow_links=False)

    dataset = paths_and_labels_to_dataset(
        image_paths=image_paths,
        labels=labels,
        num_classes=len(class_names))

    return dataset, class_names

# Build the three splits.  Class names (and therefore label indices) are
# inferred separately for every split directory.
train_dataset, class_names            = create_dataset('TRAIN/')
test_dataset, test_class_names        = create_dataset('TEST/')
validation_dataset, val_class_names   = create_dataset('VALIDATION/')

# If a split is missing a class folder (or has an extra one) its label indices
# would silently mean something different from the training labels.
assert test_class_names == class_names, (
    f"TEST classes {test_class_names} differ from TRAIN classes {class_names}")
assert val_class_names == class_names, (
    f"VALIDATION classes {val_class_names} differ from TRAIN classes {class_names}")

print("class names: ", class_names)

Found 12384 files belonging to 5 classes.
Found 487 files belonging to 5 classes.
Found 384 files belonging to 5 classes.
class names:  ['brant', 'jabwar', 'sheowl', 'spodov', 'wiltur']


In [11]:
def dataset_transforms(image,label):
    """Decode one serialized sample into a normalised 3-channel image.

    Args:
        image: scalar string tensor holding a serialized float32 tensor
            (the raw contents of one dataset file).
        label: label tensor; passed through unchanged.

    Returns:
        ``(image, label)`` where ``image`` is a float32 tensor of shape
        ``[313, 128, 3]`` divided by its own maximum value.
    """
    # decode the serialized tensor; its static shape is unknown after parsing,
    # so pin it down immediately (the final shape check below implies [313, 128])
    image = tf.io.parse_tensor(image, tf.float32)
    image = tf.ensure_shape(image, [313, 128])

    # add a channel axis, then replicate it:
    # most pre-trained models expect 3 color channels
    image = tf.expand_dims(image, -1)
    image = tf.repeat(image, 3, 2)

    # scale by the per-sample maximum (range [0,1] for non-negative inputs).
    # divide_no_nan returns 0 instead of NaN when the sample is all zeros,
    # so a single empty sample cannot poison the loss.
    image = tf.math.divide_no_nan(image, tf.reduce_max(image))

    image = tf.ensure_shape(image, [313, 128, 3])

    return image,label

# Training pipeline.
# cache() is placed BEFORE shuffle(): caching after shuffle/batch stores the
# order and batch composition of the first full pass and replays exactly those
# batches on every later iteration.
train_dataset_b = (
    train_dataset
    .map(dataset_transforms, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()                      # keep the decoded samples in memory
    .shuffle(20000)               # reshuffled each time the dataset is iterated
    .batch(baseline_config.batch_size)
    .prefetch(tf.data.AUTOTUNE)   # overlap input preparation with training
)

# Validation and test pipelines: no shuffling, so the batched result can be
# cached as-is.
validation_dataset_b = (
    validation_dataset
    .map(dataset_transforms, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(baseline_config.batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

test_dataset_b = (
    test_dataset
    .map(dataset_transforms, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(baseline_config.batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

In [12]:
# sanity check: shapes and value range of a single training batch
for batch_images, batch_labels in train_dataset_b.take(1):
    print(batch_images.shape, batch_labels.shape)
    pixel_values = batch_images.numpy()
    print(pixel_values.min(), pixel_values.max())

(32, 313, 128, 3) (32, 5)
0.0 1.0


In [13]:
# uncomment to list all available tfimm model names
# tfimm.list_models()

In [14]:
class CustomTimmModel(tf.keras.Model):
    """Pretrained Swin-Tiny backbone (loaded through tfimm) with a small dense head.

    The forward pass resizes the input to 224x224, extracts features with the
    backbone, and maps them through ``Dense -> Dropout -> Dense``.  The last
    layer has no activation, i.e. the model returns raw logits with one value
    per entry of the module-level ``class_names`` list, so it must be paired
    with a loss configured with ``from_logits=True``.
    """

    def __init__(self, *args, **kwargs):
        """Create the layers; ``args``/``kwargs`` are forwarded to ``tf.keras.Model``."""
        super(CustomTimmModel, self).__init__(*args, **kwargs)
        n_classes = len(class_names)
        # bring inputs to the resolution named in the backbone identifier
        self.rs = tf.keras.layers.Resizing(224, 224, interpolation="lanczos5", crop_to_aspect_ratio=False)
        self.fm = tfimm.create_model("swin_tiny_patch4_window7_224", pretrained=True)
        self.flat = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(n_classes*20, activation="relu")
        # Output layer emits logits.  A ReLU here would clip every negative
        # logit to zero before the softmax that the loss applies internally.
        self.fc2 = tf.keras.layers.Dense(n_classes*1, activation=None)
        self.do1 = tf.keras.layers.Dropout(0.5)

    def call(self, inputs, training=False):
        """Forward pass.

        Args:
            inputs: batch of images.
            training: whether the call happens during training; forwarded to
                the backbone and to the dropout layer so that stochastic
                layers are only active while training.

        Returns:
            Logits tensor with one column per class.
        """
        x = self.rs(inputs)
        # forward_features is a plain method call (not __call__), so the
        # training flag has to be handed over explicitly
        x = self.fm.forward_features(x, training=training)
        x = self.flat(x)
        x = self.fc1(x)
        x = self.do1(x, training=training)
        x = self.fc2(x)
        return x

    def train_step(self, data):
        """Run one optimisation step on a single batch.

        Args:
            data: ``(x, y)`` tuple as yielded by the dataset passed to ``fit()``.

        Returns:
            Dict mapping each metric name to its current value.
        """
        # Unpack the data. Its structure depends on the model and
        # on what is passed to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def compile(self, *args, **kwargs):
        """Pass-through to ``tf.keras.Model.compile``."""
        super(CustomTimmModel, self).compile(*args, **kwargs)

    def build(self, *args, **kwargs):
        """Pass-through to ``tf.keras.Model.build``."""
        super(CustomTimmModel, self).build(*args, **kwargs)
        
# instantiate the network
model = CustomTimmModel()

# a subclassed model has no declared input, so it is built explicitly with the
# batched input shape: [batch (unspecified), 313, 128, 3]
model_input_shape = [None, 313, 128, 3]
model.build(model_input_shape)

All PyTorch model weights were used when initializing SwinTransformer.
All the weights of SwinTransformer were initialized from the PyTorch model.



In [15]:
# make sure the checkpoint directory exists (no-op if it is already there)
os.makedirs('models/', exist_ok=True)

# from_logits=True means the loss function has the softmax built in.  This approach is
# numerically more stable than including the softmax activation on the last layer of the classifier
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3),
              # Labels are categorical vectors and predictions are per-class scores, so
              # accuracy has to compare argmax positions.  The plain `Accuracy` metric
              # checks element-wise equality of the raw values, which is meaningless here.
              metrics=[tf.keras.metrics.CategoricalAccuracy(name="accuracy")],
              )

# tensorboard for visualisation of results
log_dir = "tensorboard_logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                      histogram_freq=1)

# halve the learning rate after 4 epochs without val_loss improvement
# to avoid overshooting local minima
lr_reduce_plateau = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                         factor=0.5,
                                                         patience=4,
                                                         verbose=1,
                                                         mode='min',
                                                         cooldown=0,
                                                         min_lr=1e-8)

# end the training if no improvement for 8 epochs in a row, then restore best model weights
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=8,
    verbose=0,
    mode="min",
    baseline=None,
    restore_best_weights=True,
)

# save the best weights (lowest val_loss) as the model trains
mcp_save = tf.keras.callbacks.ModelCheckpoint('models/checkpoint_timm_model.hdf5',
                                              save_weights_only=True,
                                              save_best_only=True,
                                              monitor='val_loss',
                                              mode='min')

# fit the model to the training set
# NOTE(review): steps_per_epoch=10 means every "epoch" only consumes 10 batches of
# train_dataset_b, not a full pass over the training data - confirm this is intended.
history = model.fit(train_dataset_b,
                    validation_data=validation_dataset_b,
                    steps_per_epoch=10,
                    callbacks=[lr_reduce_plateau, early_stopping, tensorboard_callback, mcp_save],
                    epochs=1000)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 5: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 9: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.


<keras.callbacks.History at 0x26118a43580>

In [18]:
# persist the trained weights in HDF5 format
baseline_weights_path = 'models/baseline_timm_model.hdf5'
model.save_weights(baseline_weights_path, save_format='h5')

# round-trip check: a freshly constructed and built model must accept the saved weights
test_model = CustomTimmModel()
test_model.build([None, 313, 128, 3])
test_model.load_weights(baseline_weights_path)

All PyTorch model weights were used when initializing SwinTransformer.
All the weights of SwinTransformer were initialized from the PyTorch model.

