# Notebook for training and evaluating the Dev-ResNet model

This notebook outlines the training and evaluation procedure for Dev-ResNet. Note that it is written specifically for a dataset comprising developmental sequences of the great pond snail, *Lymnaea stagnalis*.

## Dependencies

The following are required dependencies for this script. We also set up mixed precision training for the speedup it provides in training time.

In [None]:
import glob
import vuba
import cv2
import numpy as np
import re
from tensorflow import keras
import tensorflow as tf
import pandas as pd
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import multiprocessing as mp
from tqdm import tqdm
from typing import Tuple
import atexit
import time
import os
import ujson
import math
import seaborn as sns
from mpl_toolkits.axes_grid1 import ImageGrid

from dev_resnet import DevResNet

from tensorflow.keras import mixed_precision
# Compute in float16 where possible for faster training.
mixed_precision.set_global_policy('mixed_float16')

# Let TensorFlow allocate GPU memory on demand rather than reserving it all
# at start-up. The setting is applied to every visible GPU because TensorFlow
# requires memory growth to be configured identically across devices; on a
# CPU-only machine the list is empty and the loop does nothing.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)
 
# Parameters ----------------------------------------------------------
# Optimisation settings.
batch_size = 32
epochs = 50

# Shape of a single sample handed to the model. The 128x128x1 part matches
# the resize/grayscale steps of the dataset pipeline; the leading 12 is
# presumably the number of frames per clip -- TODO confirm against the GIFs.
input_shape = (12, 128, 128, 1)

# Checkpoints are written to '<model_save_dir>/<model_name>_<seed>.h5'.
model_save_dir = './trained_models'
model_name = 'Dev-Resnet_lymnaea'

# Class names. The position of a name in this list is the integer label used
# for training, so the order must not be changed.
events = [
    'pre_gastrula',
    'gastrula',
    'trocophore',
    'veliger',
    'eye',
    'heart',
    'crawling',
    'radula',
    'hatch',
    'dead',
]

# Annotation tables for the three dataset splits.
train_data_path = './annotations_train_aug.csv'
val_data_path = './annotations_val.csv'
test_data_path = './annotations_test.csv'

# ---------------------------------------------------------------------

## Dataset pipeline

The following dataset pipeline is for an augmented dataset generated from manually annotated developmental sequences of *Lymnaea stagnalis*. Note that images are rescaled by default inside the model, so they can be supplied with unnormalised pixel values in the uint8 range (0–255).

If you wish to train Dev-ResNet on this video dataset, please download and extract the following dataset into the same folder as this notebook: https://zenodo.org/record/8214975

In [None]:
def read_data(fn, label):
    """Load one GIF clip from disk and prepare it as model input.

    Args:
        fn: Path of the GIF file to read.
        label: Label belonging to the clip; returned unchanged.

    Returns:
        A ``(clip, label)`` pair where ``clip`` holds every decoded frame,
        resized with padding to 128x128 and reduced to a single grayscale
        channel.
    """
    raw_bytes = tf.io.read_file(fn)
    rgb_frames = tf.image.decode_gif(raw_bytes)
    # Resize before the colour conversion, exactly as the frames come out of
    # the decoder; padding (rather than stretching) keeps the aspect ratio.
    padded_frames = tf.image.resize_with_pad(
        rgb_frames, target_height=128, target_width=128
    )
    return tf.image.rgb_to_grayscale(padded_frames), label

def dataset(images, labels, batch_size, drop_remainder=True):
    """Build a batched ``tf.data`` pipeline from GIF paths and their labels.

    Args:
        images: Sequence of GIF file paths, one per sample.
        labels: Sequence of labels aligned with ``images``.
        batch_size: Number of samples per batch.
        drop_remainder: If True (the default, matching the previous
            behaviour) a final batch smaller than ``batch_size`` is
            discarded. Pass False for evaluation splits when every sample
            should be scored.

    Returns:
        A ``tf.data.Dataset`` yielding ``(clips, labels)`` batches.
    """
    data = tf.data.Dataset.from_tensor_slices((images, labels))
    data = data.map(read_data, num_parallel_calls=tf.data.AUTOTUNE)
    data = data.batch(batch_size, drop_remainder=drop_remainder)
    # Prepare upcoming batches while the current one is being consumed so the
    # accelerator is not left waiting on file reads and GIF decoding.
    return data.prefetch(tf.data.AUTOTUNE)

def _load_annotations(path, shuffle=True):
    """Read an annotation CSV and add an integer ``categorical`` label column.

    Args:
        path: Location of the CSV file. It must provide a ``single_event``
            column whose values all appear in ``events``.
        shuffle: If True, the rows are put in random order and re-indexed.

    Returns:
        The annotation ``DataFrame`` with the extra ``categorical`` column.

    Raises:
        ValueError: If a ``single_event`` value is not listed in ``events``.
    """
    annotations = pd.read_csv(path)
    if shuffle:
        # frac=1 draws every row exactly once, i.e. a random permutation.
        annotations = annotations.sample(frac=1).reset_index(drop=True)
    annotations['categorical'] = [events.index(e) for e in annotations.single_event]
    return annotations


annotations_train = _load_annotations(train_data_path)
annotations_val = _load_annotations(val_data_path)
# The test split keeps the row order of its CSV file.
annotations_test = _load_annotations(test_data_path, shuffle=False)

# Training data pipeline
train_files = list(annotations_train.out_file)
train_labels = list(annotations_train.categorical)

val_files = list(annotations_val.out_file)
val_labels = list(annotations_val.categorical)

# Test data pipeline
test_files = list(annotations_test.out_file)
test_labels = list(annotations_test.categorical)

train_data = dataset(train_files, train_labels, batch_size)
val_data = dataset(val_files, val_labels, batch_size)
test_data = dataset(test_files, test_labels, batch_size)

# Pull a single training batch to sanity-check its shape and labels.
images, labels = next(iter(train_data))
print(images.shape)
print(labels)

# Preview grid: 4x4 axes, so only the first 16 clips of the batch are drawn
# (zip stops as soon as the grid runs out of axes).
fig = plt.figure(figsize=(8., 8.))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=(4, 4),
    axes_pad=0.3,
)

for clip, label, ax in zip(images, labels, grid):
    ax.set_title(events[int(label)])
    # Show the first frame of the clip; the last axis is the single
    # grayscale channel.
    ax.imshow(clip[0, :, :, 0], cmap='gray')

plt.show()

## Training and evaluation

This is the main training loop for constructing, training and computing summary metrics for Dev-ResNet. 

In [None]:
class EvaluateCallback(keras.callbacks.Callback):
    """Record test-set loss and accuracy at the end of every epoch.

    The values are appended to ``self.loss`` and ``self.accuracy`` so they
    can be plotted next to the training and validation curves.
    """

    def __init__(self):
        super().__init__()
        self.loss = []
        self.accuracy = []

    def on_epoch_end(self, epoch, logs=None):
        loss, acc = self.model.evaluate(test_data, verbose=0)
        print('-', 'test_loss:', round(loss, 4), 'test_accuracy:', round(acc, 4))
        self.loss.append(loss)
        self.accuracy.append(acc)


# Create the checkpoint directory up front instead of relying on the
# checkpoint callback to do it.
os.makedirs(model_save_dir, exist_ok=True)

n_classes = len(events)

# Train and evaluate with three different seeds for computing metrics
for i in range(3):
    np.random.seed(i)
    tf.random.set_seed(i)

    weights_path = f'{model_save_dir}/{model_name}_{i}.h5'

    model = DevResNet(input_shape, n_classes=n_classes)

    model.compile(
        # Fixed learning rate of 1e-6 works particularly well for convergence after 50 epochs
        optimizer=keras.optimizers.Adam(learning_rate=0.000001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    evaluate_callback = EvaluateCallback()
    callbacks = [
        # Keep only the weights of the epoch with the best validation accuracy.
        keras.callbacks.ModelCheckpoint(
            filepath=weights_path,
            save_best_only=True,
            monitor='val_accuracy',
            save_weights_only=True
        ),
        evaluate_callback
    ]

    start = time.time()
    history = model.fit(
        train_data,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=val_data)
    end = time.time()
    print(f'Training time (seed {i}): {end - start:.1f} s')

    # Loss curves: training, validation, test (in plotting order).
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.plot(evaluate_callback.loss)
    plt.show()

    # Accuracy curves: training, validation, test (in plotting order).
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.plot(evaluate_callback.accuracy)
    plt.show()

    # Score the best checkpoint rather than the weights of the final epoch.
    model.load_weights(weights_path)
    test_loss, test_accuracy = model.evaluate(test_data)

    # Accumulate the confusion matrix batch by batch. Starting from zeros
    # keeps `cfm` defined even if the test set yields no batches.
    cfm = np.zeros((n_classes, n_classes), dtype=np.int64)
    for ims, labels in test_data:
        preds = tf.argmax(model.predict_on_batch(ims), 1)
        cfm += tf.math.confusion_matrix(labels, preds, num_classes=n_classes).numpy()

    # Normalise each row (true class) to fractions. A class without test
    # samples has an all-zero row; dividing by 1 instead of 0 leaves it at
    # zero rather than producing NaNs.
    row_totals = cfm.sum(axis=1, keepdims=True)
    cfm = cfm.astype('float') / np.maximum(row_totals, 1)

    fig = plt.figure(dpi=150)
    sns.heatmap(cfm, annot=True, fmt='.2f')
    plt.xticks([])
    plt.yticks([])
    plt.show()

    # Release the model and reset Keras state before the next seed.
    del model
    keras.backend.clear_session()
