# Training an Autoencoder for Image Segmentation

Mount Google Drive and set up the paths to the raw training data, the cached datasets and the checkpoint directory.

In [None]:
import os
import sys
import h5py  # !pip install pyyaml h5py
import tensorflow as tf
from tensorflow import keras

from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# Automatically reload imported programmes
%load_ext autoreload
%autoreload 2

# Locate data
data_file = 'tomograms2D/tf1'  # No leading/trailing `/`

# Directories (ammend as necessary)
root_dir = '/content/gdrive/MyDrive/IDSAI/PROOF/filament-segmentation'
os.chdir(root_dir)  # Move to root_dir
sys.path.insert(0, root_dir)

# Add data to root directory and locate JSON file
data_dir = os.path.join(root_dir, 'data/' + data_file)
image_path = os.path.join(data_dir, 'png-original/*.png')
masks_path = os.path.join(data_dir, 'png-masks/semantic')

# New training and validation files
train_dir = os.path.join(root_dir, 'databases/' + data_file + '/train')
valid_dir = os.path.join(root_dir, 'databases/' + data_file + '/valid')

# Checkpoints
checkpoint_dir = os.path.join(root_dir, 'checkpoints')
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
ckpt_path = os.path.join(checkpoint_dir, 'cp-{epoch:04d}.h5')

Choose the training hyperparameters.

In [None]:
# Training hyperparameters
lr = 1e-4                # initial learning rate handed to the model's decay schedule
batch_size = 10          # samples per batch
num_epochs = 1_000       # full passes over the training set
num_internal_layers = 3  # NOTE(review): not referenced by the visible cells

Check that a GPU is available and report how much RAM the runtime has.

In [None]:
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
    raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
from psutil import virtual_memory

# NOTE(review): `.total` is total physical memory, although the message says
# "available"; kept as-is to preserve the printed output.
ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

# Treat 20 GB or more as a high-RAM runtime.
print('You are using a high-RAM runtime!' if ram_gb >= 20
      else 'Not using a high-RAM runtime')

## Data Loading

In [None]:
from loader import augment_data, get_data

print('\nLoading data...')
# Rebuild the cached datasets unless BOTH are present. If only one had been
# saved (e.g. an interrupted run), the load branch would fail on the other.
if not (os.path.exists(train_dir) and os.path.exists(valid_dir)):
    # The last two return values are unused here (presumably the held-out
    # split left over by train_frac/valid_frac -- confirm in `loader`).
    train_imgs, train_msks, valid_imgs, valid_msks, _, _ = \
        get_data(path_train_imgs=image_path,
                 path_train_msks=masks_path,
                 path_valid_imgs='',
                 path_valid_msks='',
                 train_frac=0.8,
                 valid_frac=0.1,
                 image_size=[256, 256],
                 num_images_per_original=20,
                 )
    train_set, valid_set = augment_data(
        train_imgs, train_msks, valid_imgs, valid_msks, batch_size,
    )

    # Cache the batched datasets so later sessions can skip preprocessing.
    tf.data.experimental.save(train_set, train_dir)
    tf.data.experimental.save(valid_set, valid_dir)
    print('Data processed, loaded and saved.')
else:
    train_set = tf.data.experimental.load(train_dir)
    valid_set = tf.data.experimental.load(valid_dir)
    print('Data loaded from file.')
print('Training set length: ', len(train_set))
print('Validation set length: ', len(valid_set))

## Model instantiation

In [None]:
class Autoencoder(keras.Model):
    """Convolutional encoder-decoder that maps an image to a mask.

    The encoder applies four stride-2 convolutions (spatial size / 16). The
    decoder mirrors it with four stride-2 transposed convolutions (x 16), so
    the prediction has the same height/width as the input whenever those
    are divisible by 16. The last layer is a sigmoid, so predictions lie in
    [0, 1].

    Args:
        autoencoder_lr: Initial Adam learning rate. It is multiplied by 0.75
            every 1000 optimisation steps (staircase exponential decay).
    """

    def __init__(self, autoencoder_lr: float = 0.001):
        super().__init__()
        # Encoder: 1 -> 256 -> 128 -> 64 -> 32 -> 8 channels, /16 spatially.
        self.encoder_tup = (self.encoding_conv(1, 256, 1),
                            keras.layers.BatchNormalization(),
                            self.encoding_conv(256, 128, 2),
                            keras.layers.BatchNormalization(),
                            self.encoding_conv(128, 64, 2),
                            keras.layers.BatchNormalization(),
                            self.encoding_conv(64, 32, 2),
                            keras.layers.BatchNormalization(),
                            self.encoding_conv(32, 8, 2),
                            keras.layers.BatchNormalization(),
                            )
        # Decoder: exactly four stride-2 upsamplings to undo the encoder's
        # four downsamplings. (Previously there were six, which produced
        # predictions 4x larger than the masks.)
        self.decoder_tup = (self.decoding_conv(8, 32, 2),
                            keras.layers.BatchNormalization(),
                            self.decoding_conv(32, 64, 2),
                            keras.layers.BatchNormalization(),
                            self.decoding_conv(64, 128, 2),
                            keras.layers.BatchNormalization(),
                            self.decoding_conv(128, 256, 2),
                            keras.layers.BatchNormalization(),
                            # No BatchNormalization after the sigmoid: it
                            # would rescale predictions out of [0, 1].
                            self.decoding_conv(256, 1, 1, activation='sigmoid'),
                            )

        self.lr = keras.optimizers.schedules.ExponentialDecay(
            autoencoder_lr, decay_steps=1000, decay_rate=0.75, staircase=True
        )
        self.optimiser = keras.optimizers.Adam(learning_rate=self.lr)

    def encoding_conv(self, in_feats, out_feats, stride):
        """Return a 5x5 'same'-padded ReLU Conv2D with ``out_feats`` filters.

        ``in_feats`` only populates ``input_shape``; Keras infers the actual
        input channels when the layer is first called.
        """
        return keras.layers.Conv2D(filters=out_feats,
                                   kernel_size=(5, 5),
                                   input_shape=(None, None, in_feats),
                                   strides=(stride, stride),
                                   padding='same',
                                   data_format='channels_last',
                                   activation='relu',
                                   )

    def decoding_conv(self, in_feats, out_feats, stride, activation='relu'):
        """Return a 5x5 'same'-padded Conv2DTranspose (upsamples by ``stride``)."""
        return keras.layers.Conv2DTranspose(filters=out_feats,
                                            kernel_size=(5, 5),
                                            input_shape=(None, None, in_feats),
                                            strides=(stride, stride),
                                            padding='same',
                                            data_format='channels_last',
                                            activation=activation,
                                            )

    def call_encoder(self, input, training=False):
        """Pass ``input`` through every encoder layer in order."""
        state = input
        for layer in self.encoder_tup:
            state = layer(state, training=training)
        return state

    def call_decoder(self, input, training=False):
        """Pass an encoded tensor through every decoder layer in order."""
        state = input
        for layer in self.decoder_tup:
            state = layer(state, training=training)
        return state

    def call_both(self, input, training=False):
        """Encode then decode ``input``; returns the predicted mask."""
        encoded = self.call_encoder(input, training=training)
        decoded = self.call_decoder(encoded, training=training)
        return decoded

    def call(self, inputs, training=False):
        """Keras entry point so ``model(x)``, ``build`` and ``summary`` work."""
        return self.call_both(inputs, training=training)

    def optimise_autoencoder(self, x, y, loss_func):
        """Take one optimiser step on batch ``(x, y)``; return the mean loss.

        ``loss_func(y_true, y_pred)`` may return per-element losses; they
        are averaged to a scalar before differentiation.
        """
        with tf.GradientTape() as g:
            x_predict = self.call_both(x, training=True)
            loss = tf.math.reduce_mean(loss_func(y, x_predict))
        # Flat list of trainable variables from every encoder/decoder layer
        # (BatchNorm moving statistics are non-trainable and excluded).
        trainable = [weight
                     for layer in self.encoder_tup + self.decoder_tup
                     for weight in layer.trainable_weights]
        self.optimiser.minimize(loss, trainable, tape=g)
        return loss


In [None]:
# Instantiate model
model = Autoencoder(autoencoder_lr=lr)


@tf.function
def train_step(data):
    """Run one optimisation step on an ``(images, masks)`` batch.

    Returns the batch's mean MSE loss as a scalar tensor.
    """
    images, masks = data
    return model.optimise_autoencoder(images, masks, keras.losses.MSE)


@tf.function
def test_step(data):
    """Evaluate the model on an ``(images, masks)`` batch without training.

    Returns:
        ``(predictions, loss, images, masks)``. ``loss`` holds one value per
        batch element: the MSE averaged over the two spatial axes.
    """
    images, masks = data
    evals = model.call_decoder(model.call_encoder(images, training=False),
                               training=False)
    # keras.losses.MSE(y_true, y_pred) averages the last (channel) axis,
    # leaving (batch, H, W); reduce the spatial axes to get one value per
    # sample. (Python print() inside tf.function only runs while tracing,
    # so the former debug prints showed symbolic tensors; they are removed.)
    loss = tf.math.reduce_mean(keras.losses.MSE(masks, evals), axis=[1, 2])
    return evals, loss, images, masks


In [None]:
# Plot model architecture. The path is relative to the working directory,
# so make sure its parent directory exists before writing the image.
# NOTE(review): for a subclassed keras.Model, plot_model can only draw a
# single node (and may require the model to be built first) -- confirm.
os.makedirs('graphics', exist_ok=True)
tf.keras.utils.plot_model(model, to_file='graphics/architecture_ae.png')

## Iterate training

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Directory for saved figures. `graphics_dir` was previously used below
# without being defined (NameError at the first checkpoint).
graphics_dir = os.path.join(root_dir, 'graphics')
os.makedirs(graphics_dir, exist_ok=True)


def _show_sample(header, label, imgs, msks, evls, idx, path):
    """Print a prediction's range and plot input / mask / prediction.

    Args:
        header: Heading printed before the statistics.
        label: Variable name shown in the printed max/min lines.
        imgs, msks, evls: 4-D batches indexed ``[sample, :, :, channel]``.
        idx: Which sample of the batch to display.
        path: Output file path for the figure.
    """
    print(header)
    print('np.max({}[i, :, :, 0]): '.format(label), np.max(evls[idx, :, :, 0]))
    print('np.min({}[i, :, :, 0]): '.format(label), np.min(evls[idx, :, :, 0]))
    fig, ax = plt.subplots(1, 3, constrained_layout=True)
    for axis, batch, title in zip(ax, (imgs, msks, evls),
                                  ('Input', 'Mask', 'Prediction')):
        axis.imshow(batch[idx, :, :, 0])
        axis.set_title(title)
    # Save before show(): some backends release the figure once shown.
    fig.savefig(path)
    plt.show()
    plt.close(fig)  # don't keep one open figure per checkpoint


# Iterate training session
losses = []   # training loss of every step
eval_x = []   # step count (len(losses)) at each checkpoint
eval_y = []   # mean validation loss at each checkpoint
for epoch in range(num_epochs):
    print('Epoch: ', epoch)

    end_data = None
    for data in train_set:
        losses.extend(np.array(train_step(data)).flatten())
        end_data = data  # last batch, reused for the training preview
    if end_data is None:
        raise ValueError('Training set is empty; nothing to train on.')

    if epoch % 10 != 0:
        continue

    # Checkpoint every 10 epochs: save weights, then evaluate.
    model.save_weights(ckpt_path.format(epoch=epoch), save_format='h5')
    eval_x.append(len(losses))
    evls_train, _, imgs_train, msks_train = test_step(end_data)

    accum_valid_loss = []
    for x_valid in valid_set:
        evls_valid, valid_loss, imgs_valid, msks_valid = test_step(x_valid)
        accum_valid_loss.extend(np.array(valid_loss).flatten())
    if not accum_valid_loss:
        raise ValueError('Validation set is empty; cannot evaluate.')
    valid_loss = np.mean(accum_valid_loss)
    eval_y.append(valid_loss)
    print('Checkpoint at epoch ', epoch)
    print('losses[-1]: ', losses[-1])
    print('\n')

    # Graphical output: min(batch_size, 1) == 1, i.e. only sample 0 of the
    # last training batch and of the last validation batch is shown.
    for i in range(min(batch_size, 1)):
        _show_sample('\nTraining Results:', 'evals_train',
                     imgs_train, msks_train, evls_train, i,
                     os.path.join(graphics_dir,
                                  'training_output_bnum{:02d}_epoch{:04d}'
                                  .format(i, epoch)))
        # Distinct filename: the validation figure used to overwrite the
        # training figure saved just above.
        _show_sample('\nValidation Results:', 'evls_valid',
                     imgs_valid, msks_valid, evls_valid, i,
                     os.path.join(graphics_dir,
                                  'validation_output_bnum{:02d}_epoch{:04d}'
                                  .format(i, epoch)))

# Loss history: per-step training loss vs. checkpoint validation loss
fig, ax = plt.subplots(1, 1)
ax.plot(losses, color='g', label='Training loss')
ax.plot(eval_x, eval_y, color='r', label='Validation loss')
ax.set_xlabel('Training step')
ax.set_ylabel('MSE loss')
ax.legend()
fig.savefig(os.path.join(graphics_dir, 'loss_record'))
plt.show()
