### Model architecture demonstration

#### Dataset: [Crustacea ZooScan image database (Kaggle)](https://www.kaggle.com/datasets/iandutoit/crustacea-zooscan-image-database)

In [None]:
import os
import pathlib

import numpy as np
import tensorflow as tf
from tensorflow import keras

from keras.models import Sequential
from keras import layers

In [None]:
# Root of the training images; each sub-directory holds the images of one class.
# Set the ZOOPLANKTON_DATA_DIR environment variable to use another copy of the
# dataset instead of editing this absolute path.
data_dir = pathlib.Path(os.environ.get('ZOOPLANKTON_DATA_DIR', '/zooplankton/train'))

In [None]:
# Extract class labels: `images[i]` is the file name of an image and
# `lables[i]` is the class it belongs to, i.e. the name of the sub-directory of
# data_dir it was found under. (The misspelled name `lables` is kept so any
# code that already uses it keeps working.)

# File extensions image_dataset_from_directory accepts, so these lists describe
# the same files the datasets below are built from (stray files such as
# .DS_Store are not counted as labelled images).
IMAGE_SUFFIXES = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')

images = []
lables = []

# Walk the class folders in sorted order (Keras also orders class names
# alphanumerically). Each label is taken from the folder the image was actually
# found in, so the two lists stay aligned even with nested folders or with
# files sitting directly in data_dir.
for class_dir in sorted(entry for entry in data_dir.iterdir() if entry.is_dir()):
    class_images = [
        path.name
        for path in sorted(class_dir.rglob('*'))
        if path.is_file() and path.name.lower().endswith(IMAGE_SUFFIXES)
    ]
    images.extend(class_images)
    lables.extend([class_dir.name] * len(class_images))

assert len(images) == len(lables), "every image needs exactly one label"

In [None]:
# Construct training and validation sets

batch_size = 32
img_height = 90
img_width = 90

# Both subsets must be requested with the same validation_split and seed so that
# Keras partitions the files identically and the two subsets do not overlap.
validation_split = 0.2
split_seed = 123


def _load_split(subset):
    """Return the `subset` ("training" or "validation") split of the images in data_dir.

    Images are loaded with Keras's default color_mode ('rgb'), which matches the
    3-channel input the model below is built for.
    """
    return tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=validation_split,
        subset=subset,
        seed=split_seed,
        image_size=(img_height, img_width),
        batch_size=batch_size)


train_ds = _load_split("training")
val_ds = _load_split("validation")

In [None]:
# Image augmentation
#
# Currently a placeholder: the only layer reshapes to the image size the
# datasets already produce, so images pass through unchanged. Augmentation
# layers previously listed here (put them in the list below to enable them):
#   layers.RandomFlip("horizontal"),
#   layers.RandomRotation(0.3),
#   layers.RandomZoom(0.2, 0.5),
# Pixel rescaling is done inside the model, not here.

data_augmentation = tf.keras.Sequential([
    # Derived from the configured image size rather than hard-coded, so changing
    # img_height/img_width does not break this layer.
    layers.Reshape(target_shape=(img_height, img_width, 3)),
])

In [None]:
def _conv_bn(filters, kernel_size=3):
    """Return [same-padded stride-1 ReLU Conv2D, BatchNormalization] for splicing into a layer list."""
    return [
        layers.Conv2D(filters, kernel_size, strides=1, padding='same', activation='relu'),
        layers.BatchNormalization(),
    ]


# One output unit per class directory found by image_dataset_from_directory,
# instead of a hard-coded count that silently goes stale if the data changes.
num_classes = len(train_ds.class_names)

model = Sequential([
    # Declaring the input shape up front builds the model immediately, so
    # model.summary() works. `input_shape=` on non-first layers (as before)
    # is ignored by Keras.
    layers.Input(shape=(img_height, img_width, 3)),
    data_augmentation,
    layers.Rescaling(1./255),

    *_conv_bn(16, kernel_size=4),
    layers.MaxPooling2D(2),

    *_conv_bn(32),
    *_conv_bn(64),
    layers.Dropout(0.1),
    layers.MaxPooling2D(2),

    *_conv_bn(128),
    layers.Dropout(0.2),

    *_conv_bn(64),
    *_conv_bn(32),

    layers.Flatten(),
    # ReLU is required here: two stacked Dense layers with no activation in
    # between collapse into a single linear map, making this layer pointless.
    # NOTE(review): at 90x90 input the flattened size is 22*22*32 = 15488, so
    # this layer alone holds ~63M weights; GlobalAveragePooling2D would be far
    # smaller.
    layers.Dense(4096, activation='relu'),
    # Raw logits: the loss below is built with from_logits=True.
    layers.Dense(num_classes),
])

model.summary()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
# Callback: remembers the weights from the epoch with the lowest validation loss
# and restores them when training ends; can optionally stop training early when
# the validation loss stops improving (see `patience`).

class zooplankton_callback(keras.callbacks.Callback):
    """Track the lowest `val_loss` seen and restore those weights at the end of training.

    Args:
        model: the model being trained. Its weights at construction time are the
            initial "best" snapshot, restored if no epoch ever improves on it.
        epochs: number of epochs planned for training (stored; not used here).
        patience: if set, stop training after this many consecutive epochs
            without a new lowest `val_loss`. None (default) never stops early.

    Attributes:
        lowest_vloss: lowest validation loss seen so far (np.inf before any).
        best_weights: weights captured at the epoch with `lowest_vloss`.
        best_epoch: 1-based index of that epoch.
    """

    def __init__(self, model, epochs, patience=None):
        super().__init__()
        # set_model() instead of `self.model = model`: recent Keras 3 releases
        # expose Callback.model as a read-only property, so plain assignment
        # raises AttributeError there. set_model() works on Keras 2 and 3.
        self.set_model(model)
        self.epochs = epochs
        self.patience = patience
        self.lowest_vloss = np.inf
        self.best_weights = model.get_weights()
        self.best_epoch = 1
        self.epochs_without_improvement = 0

    def on_train_end(self, logs=None):
        # Leave the model holding the best weights rather than the last epoch's.
        self.model.set_weights(self.best_weights)

    def on_epoch_end(self, epoch, logs=None):
        v_loss = (logs or {}).get('val_loss')
        # Without validation data there is no val_loss; comparing None with a
        # float would raise TypeError, so there is nothing to track this epoch.
        if v_loss is None:
            return
        if v_loss < self.lowest_vloss:
            self.lowest_vloss = v_loss
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch + 1  # `epoch` is 0-based
            self.epochs_without_improvement = 0
        else:
            # Also reached for a NaN loss, since NaN < x is always False.
            self.epochs_without_improvement += 1
            if (self.patience is not None
                    and self.epochs_without_improvement >= self.patience):
                self.model.stop_training = True

In [None]:
epochs = 10

callbacks = [zooplankton_callback(model, epochs)]

# Keep the History object so per-epoch loss/accuracy can be inspected or
# plotted afterwards (assigning it also stops the cell printing its repr).
# `shuffle` is omitted: Model.fit ignores it for tf.data.Dataset inputs, so
# shuffle=False was a no-op that only suggested the data was not shuffled.
# validation_steps=None and initial_epoch=0 were the defaults anyway.
history = model.fit(
    train_ds,
    epochs=epochs,
    verbose=1,
    callbacks=callbacks,
    validation_data=val_ds,
)

In [None]:
# Export model
#
# experimental_io_device="/job:localhost" routes all variable file I/O through
# the local host. This lets a distributed/TPU run (e.g. on Colab) write to a
# local directory.
# NOTE(review): this is the tf.keras 2.x saving API. Keras 3's model.save()
# expects a ".keras" path and accepts no `options`; confirm which version is
# installed.
io_option = tf.saved_model.SaveOptions(
    experimental_io_device="/job:localhost",
)
model.save(
    "model",
    options=io_option,
)