# Notebook for training Dev-ResNet with Triplet Semi-Hard loss

A notebook outlining the training and inference process for Dev-ResNet with Triplet Semi-Hard loss, for the purpose of visualising the differences between detected features in high-dimensional space between developmental events. Note that the shape of the continuum will look slightly different each run because the random seeds used internally are not set.

In [None]:
import glob
import vuba
import cv2
import numpy as np
import re
from tensorflow import keras
import tensorflow as tf
import pandas as pd
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import multiprocessing as mp
from tqdm import tqdm
from typing import Tuple
import atexit
import time
import os
import ujson
import math
import seaborn as sns
from mpl_toolkits.axes_grid1 import ImageGrid

from dev_resnet import DevResNet

from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')

# Enable on-demand GPU memory allocation on every visible GPU.
# TensorFlow requires the memory-growth setting to be identical across all
# GPUs, so configuring only the first device fails on multi-GPU machines.
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)
    
# Parameters ----------------------------------------------------------
# Output location and file stem for the saved model weights
model_save_dir = './trained_models'
model_save_name = 'Dev-Resnet_lymnaea_TripLet'

# Annotation tables for the train / validation / test splits
train_data_path, val_data_path, test_data_path = (
    './annotations_train_aug.csv',
    './annotations_val.csv',
    './annotations_test.csv',
)

# Training configuration
batch_size, epochs = 16, 50
input_shape = (12, 128, 128, 1)

# Event names as they appear in the annotation tables; an event's position
# in this list is used as its integer class label.
events = [
    'pre_gastrula',
    'gastrula',
    'trocophore',
    'veliger',
    'eye',
    'heart',
    'crawling',
    'radula',
    'hatch',
    'dead',
]

# Labels used for plotting below (same order as `events`)
act_events = [
    'Pre-Gastrula',
    'Gastrula',
    'Trochophore',
    'Veliger',
    'Eye spots',
    'Heart beat',
    'Crawling',
    'Radula',
    'Hatch',
    'Dead',
]

# ---------------------------------------------------------------------

# Dataset handling

Here we use the same dataset pipelines as in the training notebook for training the original 3D-CNN, Dev-ResNet.

In [None]:
import tensorflow_addons as tfa

def read_data(fn, label):
    """Load one GIF clip and return it as a grayscale frame stack with its label.

    Args:
        fn: Path of the GIF file to read.
        label: Class label for the clip; passed through unchanged.

    Returns:
        A ``(frames, label)`` tuple where ``frames`` is the decoded GIF,
        resized with padding to 128x128 and converted to a single
        grayscale channel.
    """
    raw_bytes = tf.io.read_file(fn)
    frames = tf.image.decode_gif(raw_bytes)
    # Pad rather than stretch so the aspect ratio of each frame is kept
    padded = tf.image.resize_with_pad(frames, 128, 128)
    return tf.image.rgb_to_grayscale(padded), label

def dataset(images, labels, batch_size):
    """Build a batched ``tf.data`` pipeline from file paths and labels.

    Args:
        images: Sequence of GIF file paths.
        labels: Sequence of class labels, aligned with ``images``.
        batch_size: Number of samples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(frames, labels)`` batches. Files are
        decoded in parallel via ``read_data`` and the final, possibly smaller,
        batch is kept.
    """
    return (
        tf.data.Dataset.from_tensor_slices((images, labels))
        .map(read_data, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size, drop_remainder=False)
    )

def _load_annotations(csv_path, shuffle):
    """Read an annotation table and add an integer `categorical` label column.

    When `shuffle` is true the rows are randomly permuted and re-indexed
    before the label column is added.
    """
    frame = pd.read_csv(csv_path)
    if shuffle:
        frame = frame.sample(frac=1).reset_index(drop=True)
    # Encode each event name as its position in `events`
    frame['categorical'] = [events.index(name) for name in frame.single_event]
    return frame


# Training and validation tables are shuffled; the test table keeps file order
annotations_train = _load_annotations(train_data_path, shuffle=True)
annotations_val = _load_annotations(val_data_path, shuffle=True)
annotations_test = _load_annotations(test_data_path, shuffle=False)

# File paths and integer labels for each split
train_files = list(annotations_train.out_file)
train_labels = list(annotations_train.categorical)

val_files = list(annotations_val.out_file)
val_labels = list(annotations_val.categorical)

test_files = list(annotations_test.out_file)
test_labels = list(annotations_test.categorical)

# Batched input pipelines
train_data = dataset(train_files, train_labels, batch_size)
val_data = dataset(val_files, val_labels, batch_size)
test_data = dataset(test_files, test_labels, batch_size)

# Pull a single batch to sanity-check tensor shapes and labels
images, labels = next(iter(train_data))
print(images.shape)
print(labels)

# Lay the batch out on a near-square grid sized from the batch itself
# (4x4 for a batch of 16), so no clip is silently dropped if the batch
# size is changed.
n_clips = int(images.shape[0])
n_cols = math.ceil(math.sqrt(n_clips))
n_rows = math.ceil(n_clips / n_cols)

fig = plt.figure(figsize=(8., 8.))
grid = ImageGrid(
    fig,
    111,
    nrows_ncols=(n_rows, n_cols),
    axes_pad=0.3,
)

for clip, label, ax in zip(images, labels, grid):
    # Show only the first frame of each clip, titled with its event name
    ax.set_title(events[int(label)])
    ax.imshow(clip[0, :, :, 0], cmap='gray')

plt.show()

# Instantiate and compile model for training with Triplet Semi-Hard loss

Here we replace the final classification block in Dev-ResNet with a fully connected layer with L2 normalization for training with triplet semi-hard loss. We only train the model for 20 epochs rather than the original 50, since 20 is sufficient to reach convergence for this task.

In [None]:
# Set seed for reproducible results
np.random.seed(1)
tf.random.set_seed(1)

# Instantiate and compile modified Dev-ResNet network for training with Triplet loss
inputs = keras.Input(input_shape)
model = DevResNet(input_tensor=inputs, include_top=False)

# Replace the classification head with an embedding head:
# pooled features -> batch norm -> 512-unit linear projection -> L2 normalisation
x = layers.GlobalAveragePooling3D()(model.output)
x = layers.BatchNormalization()(x)
x = layers.Dense(512)(x)
# The global policy is mixed_float16, so the output layer is pinned to float32:
# the normalised embeddings, the triplet loss and everything downstream of
# `model.predict` then work in full precision.
x = layers.Lambda(
    lambda features: tf.math.l2_normalize(features, axis=1),
    dtype='float32',
)(x)

model = keras.Model(inputs=inputs, outputs=x)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss=tfa.losses.TripletSemiHardLoss()
)

# Keep only the weights of the best epoch seen so far
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=f'{model_save_dir}/{model_save_name}.h5',
        save_best_only=True,
        save_weights_only=True
    )
]

# Trained for a fixed 20 epochs (not the `epochs` parameter set at the top of the notebook)
history = model.fit(
    train_data,
    epochs=20,
    callbacks=callbacks,
    validation_data=val_data)


# Evaluate on test data and visualise using UMAP

Here we evaluate the trained model on the testing data before performing dimensionality reduction using UMAP for 2D visualisation.

In [None]:
# Run the trained network on the unseen testing data. `results` holds one
# L2-normalised 512-dimensional embedding per test clip, in the same order as
# `test_files` / `test_labels` (the test table is not shuffled).
results = model.predict(test_data)

In [None]:
from sklearn.preprocessing import StandardScaler
import umap

# UMAP with its default settings (no random_state is passed here)
reducer = umap.UMAP()

# Standardise each embedding dimension, fit UMAP on the result and
# project the same samples into the low-dimensional space
results_scaled = StandardScaler().fit_transform(results)
trans = reducer.fit(results_scaled)
embedding = trans.transform(results_scaled)
embedding.shape

In [None]:
import matplotlib.pyplot as plt
import matplotlib
plt.style.use('default')

# NOTE: `cmap` is not used below; point colours come from the default
# 'C0'..'C9' colour cycle, indexed by class label.
cmap = matplotlib.colormaps.get_cmap('plasma')

fig, ax1 = plt.subplots(dpi=150, figsize=(9, 9))

# Legend entries: empty proxy artists with an explicit colour per class, so
# no stray marker is drawn on the axes and each legend colour matches the
# data colour even when a class is absent from the test set.
for label in sorted(set(test_labels)):
    ax1.plot([], [], 'o', markersize=2, color=f'C{label}', label=act_events[label])

# One marker per test sample, coloured by its developmental event
for point, label in zip(embedding, test_labels):
    ax1.plot(point[0], point[1], 'o', markersize=2, color=f'C{label}')

ax1.legend(title='Developmental event:')
plt.show()