# Training a U-Net (Encoder–Decoder) for Image Segmentation

Mount Google Drive and organise the directories used for the training data, checkpoints and figure outputs.

In [1]:
import os
import sys
import h5py  # !pip install pyyaml h5py
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# Automatically reload imported programmes
%load_ext autoreload
%autoreload 2

# Locate data
data_file = 'tomograms2D/all'  # No leading/trailing `/`
exp_name = 'all-2D-patch-unet'

# Directories (ammend as necessary)
root_dir = '/content/gdrive/MyDrive/IDSAI/PROOF/filament-segmentation'
os.chdir(root_dir)  # Move to root_dir
sys.path.insert(0, root_dir)

# Add data to root directory and locate JSON file
data_dir = os.path.join(root_dir, 'data/' + data_file)
image_path = os.path.join(data_dir, 'png-original')
masks_path = os.path.join(data_dir, 'png-masks/semantic/*.png')

# New training and validation files
train_dir = os.path.join(root_dir, 'data/databases/' + exp_name + '/train')
valid_dir = os.path.join(root_dir, 'data/databases/' + exp_name + '/valid')

# Checkpoints
checkpoint_dir = os.path.join(root_dir, 'checkpoints/' + exp_name)
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
checkpoint_path = os.path.join(checkpoint_dir, 'cp-{epoch:04d}.h5')
best_weights_path = os.path.join(checkpoint_dir, 'unet-best-weights')

# Figure Outputs
fig_dir = os.path.join(root_dir, 'outputs/unet-train-' + exp_name)
os.makedirs(fig_dir, exist_ok=True)

Mounted at /content/gdrive


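Check GPU/TPU and RAM availability (optional: the cells below are disabled with `%%script false`; remove that line from a cell to run it).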

In [2]:
%%script false
# GPU info
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
   print(gpu_info)

In [3]:
%%script false
# TPU initialisation for tensorflow 2.X
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

In [4]:
%%script false
## RAM availability
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## Load data and model

In [5]:
batch_size = 10

print('\nLoading data...')
# Rebuild the datasets unless BOTH cached copies exist. If only one of the two
# directories is present (e.g. an interrupted earlier save), loading would fail
# on the missing one, so fall back to regenerating and saving both.
if not (os.path.exists(train_dir) and os.path.exists(valid_dir)):

    # Imported here because the loader is only needed when the cache is missing.
    from loader import augment_data, get_data

    # No separate validation paths are given (empty strings); the split is
    # requested through train_frac / valid_frac instead.
    # The last two returned values are not used in this notebook.
    train_imgs, train_msks, valid_imgs, valid_msks, _, _ = \
        get_data(path_train_imgs=image_path,
                 path_train_msks=masks_path,
                 path_valid_imgs='',
                 path_valid_msks='',
                 train_frac=0.8,
                 valid_frac=0.1,
                 image_size=[256, 256],
                 num_images_per_original=20,
                 num_duplicates_before_augmenting=1,
                 )
    train_set, valid_set = augment_data(
        train_imgs, train_msks, valid_imgs, valid_msks, batch_size, one_hot=True
    )

    # Cache the processed datasets so later runs can skip the processing above.
    tf.data.experimental.save(train_set, train_dir)
    tf.data.experimental.save(valid_set, valid_dir)
    print('Data processed, loaded and saved.')
else:
    train_set = tf.data.experimental.load(train_dir)
    valid_set = tf.data.experimental.load(valid_dir)
    print('Data loaded from file.')
# NOTE(review): these lengths presumably count batches rather than images,
# since batch_size is passed to augment_data — confirm against loader.py.
print('Training set length: ', len(train_set))
print('Validation set length: ', len(valid_set))


Loading data...
Data loaded from file.
Training set length:  298
Validation set length:  38


## Iterate training

In [6]:
# Training hyperparameters
lr = 1e-4          # learning rate handed to the optimiser
num_epochs = 1000  # upper bound on the number of training epochs
batch_size = 10    # same value as used when the datasets were built

In [None]:
from models import get_unet_model

# Build the U-Net for 256x256 single-channel inputs with two output classes.
model = get_unet_model(
    (256, 256), num_classes=2, num_colour_channels=1
)

# Adam optimiser using the learning rate chosen in the hyperparameter cell.
opt = keras.optimizers.Adam(learning_rate=lr)

# Binary cross-entropy is the training loss; MSE and MAE are tracked as extra
# metrics and appear in `history.history` for the plots further down.
model.compile(
    optimizer=opt,
    loss='binary_crossentropy',
    metrics=['mean_squared_error', 'mean_absolute_error'],
)

# Single callback: write a checkpoint to best_weights_path, keeping only the
# best one (save_best_only=True) as judged by the callback's default monitor.
cps = [
    keras.callbacks.ModelCheckpoint(best_weights_path, save_best_only=True),
]

# Fit on the training set; the validation set is evaluated after every epoch.
history = model.fit(
    train_set,
    validation_data=valid_set,
    epochs=num_epochs,
    callbacks=cps,
)

Epoch 1/1000


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 2/1000


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 3/1000


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 4/1000
Epoch 5/1000
Epoch 6/1000


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 7/1000


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 8/1000
Epoch 9/1000


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 10/1000


  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
Epoch 73/1000
Epoch 74/1000
Epoch 75/1000
Epoch 76/1000
Epoch 77/1000
Epoch 78/1000
Epoch 79/1000
Epoch 80/1000
Epoch 81/1000
Epoch 

Analysis of training success.

In [None]:
# Plot history: Model loss
# Training and validation loss per epoch, read from the Keras History object.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='Model loss (training data)')
ax.plot(history.history['val_loss'], label='Model loss (validation data)')

ax.set_title('Model loss for U-Net training.')
ax.set_ylabel('Model loss value')
ax.set_xlabel('No. epoch')
ax.legend(loc="upper left")
plt.show()

In [None]:
# Plot history: MSE
# Training and validation mean squared error per epoch, from the History object.
fig, ax = plt.subplots()
ax.plot(history.history['mean_squared_error'], label='MSE (training data)')
ax.plot(history.history['val_mean_squared_error'], label='MSE (validation data)')

ax.set_title('MSE for U-Net training.')
ax.set_ylabel('MSE value')
ax.set_xlabel('No. epoch')
ax.legend(loc="upper left")
plt.show()

In [None]:
# Plot history: MAE
# Training and validation mean absolute error per epoch, from the History object.
fig, ax = plt.subplots()
ax.plot(history.history['mean_absolute_error'], label='MAE (training data)')
ax.plot(history.history['val_mean_absolute_error'], label='MAE (validation data)')

ax.set_title('MAE for U-Net training.')
ax.set_ylabel('MAE value')
ax.set_xlabel('No. epoch')
ax.legend(loc="upper left")
plt.show()