# Training a U-Net for Image Segmentation

Mount Google Drive, import dependencies and organise the directories for the training data, checkpoints and figure outputs.

In [1]:
import os
import sys
import h5py  # !pip install pyyaml h5py
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)


# Automatically reload imported programmes
%load_ext autoreload
%autoreload 2


# Database choice
batch_size = 10
num_patches = 1  # Subsample taining data
num_duplicates = 30  # Repeats of subsamples to augment
apply_augmentation = True
shuffle_on = True


# Locate data
data_file = 'tomograms2D/all'  # No leading/trailing `/`
database_name = 'all-2D-augment30'
exp_name = 'unet-' + database_name


# Directories (ammend as necessary)
root_dir = '/content/gdrive/MyDrive/IDSAI/PROOF/filament-segmentation'
os.chdir(root_dir)  # Move to root_dir
sys.path.insert(0, root_dir)


# Add data to root directory and locate JSON file
data_dir = os.path.join(root_dir, 'data/' + data_file)
image_path = os.path.join(data_dir, 'png-original')
masks_path = os.path.join(data_dir, 'png-masks/semantic/*.png')


# New training and validation files
train_dir = os.path.join(root_dir, 'data/databases/' + database_name + '/train')
valid_dir = os.path.join(root_dir, 'data/databases/' + database_name + '/valid')


# Checkpoints
checkpoint_dir = os.path.join(root_dir, 'checkpoints/' + exp_name)
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
checkpoint_path = os.path.join(checkpoint_dir, 'cp-{epoch:04d}.h5')
best_weights_path = os.path.join(checkpoint_dir, 'unet-best-weights')


# Figure Outputs
fig_dir = os.path.join(root_dir, 'outputs/unet-train-' + database_name)
os.makedirs(fig_dir, exist_ok=True)

Mounted at /content/gdrive


Check GPU/TPU and RAM availability. Each of the following cells is disabled with `%%script false`; delete that line to run it.

In [2]:
%%script false
# GPU info
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
   print(gpu_info)

In [3]:
%%script false
# TPU initialisation for tensorflow 2.X
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))

In [4]:
%%script false
## RAM availability
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## Load data

In [5]:
print('\nLoading data...')
# Reuse the cached datasets only when BOTH splits exist on disk. If either one is
# missing (e.g. a previous run was interrupted before the second save), rebuild
# both instead of crashing on `load` of the missing directory.
# NOTE: the cache is keyed only by `database_name`; delete the cached directories
# after changing batch size / patch / augmentation settings.
if os.path.exists(train_dir) and os.path.exists(valid_dir):

    train_set = tf.data.experimental.load(train_dir)
    valid_set = tf.data.experimental.load(valid_dir)
    print('Data loaded from file.')

else:

    # Project helpers are only needed when the datasets must be (re)built.
    from loader import augment_data, get_data

    # 80% / 10% train/valid split of the source images and masks; the last two
    # returned values are not used here.
    train_imgs, train_msks, valid_imgs, valid_msks, _, _ = \
        get_data(path_train_imgs=image_path,
                 path_train_msks=masks_path,
                 path_valid_imgs='',
                 path_valid_msks='',
                 train_frac=0.8,
                 valid_frac=0.1,
                 image_size=[256, 256],
                 num_patches_per_image=num_patches,
                 num_duplicates_per_image=num_duplicates,
                 )

    # Build the dataset pipelines; batch size, augmentation and shuffling all
    # come from the configuration cell (shuffling was previously hard-coded).
    train_set, valid_set = augment_data(train_imgs,
                                        train_msks,
                                        valid_imgs,
                                        valid_msks,
                                        batch_size,
                                        one_hot=False,
                                        augment_on=apply_augmentation,
                                        shuffle_on=shuffle_on,
                                        )

    # Cache to disk so later runs can skip this slow preprocessing step.
    tf.data.experimental.save(train_set, train_dir)
    tf.data.experimental.save(valid_set, valid_dir)
    print('Data processed, loaded and saved.')

print('Training set length: ', len(train_set))
print('Validation set length: ', len(valid_set))


Loading data...


 91%|█████████▏| 170/186 [11:53<01:07,  4.20s/it]


KeyboardInterrupt: ignored

## Load model

In [None]:
from models import get_unet_model

# Instantiate a U-Net for 256x256 single-channel inputs with two output classes
# (presumably background vs. filament — TODO confirm against the mask encoding).
unet_input_size = (256, 256)
model = get_unet_model(unet_input_size, num_classes=2, num_colour_channels=1)

model.summary()

## Train the model

In [None]:
# Training hyper-parameters
unet_lr = 0.0001  # Initial learning rate, decayed by the optimiser's schedule
num_epochs = 500
# NOTE: `batch_size` is deliberately not redefined here. The datasets are batched
# when they are built and cached (data-loading cell), so a value set here had no
# effect on training. Change `batch_size` in the configuration cell and rebuild
# the cached datasets instead.

In [None]:

# Optimiser: RMSprop with a staircase exponential decay, i.e. the learning rate
# is multiplied by 0.75 after every 1000 optimiser steps.
lr = keras.optimizers.schedules.ExponentialDecay(
    unet_lr, decay_steps=1000, decay_rate=0.75, staircase=True
)
opt = keras.optimizers.RMSprop(learning_rate=lr)

# Compile model.
# NOTE(review): with a sparse categorical loss the targets are integer class
# labels while the outputs are presumably per-class probabilities, so MSE/MAE
# are only a rough signal. They are kept because the analysis cells plot them;
# pixel-wise `sparse_categorical_accuracy` is tracked as a more meaningful metric.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=opt,
              metrics=['mean_squared_error',
                       'mean_absolute_error',
                       'sparse_categorical_accuracy'],
              )

# Checkpointing: keep only the best model seen so far. With the Keras defaults
# this monitors `val_loss` and saves the full model (not only the weights),
# despite the `best_weights_path` name.
cps = [keras.callbacks.ModelCheckpoint(best_weights_path, save_best_only=True)]

# Suppress `CustomMaskWarning`, see: stackoverflow.com/questions/68384466
# `logging.disable` silences WARNING-and-below records from every Python logger
# for the rest of the session, until `logging.disable(logging.NOTSET)` is called.
import logging
logging.disable(logging.WARNING)
# NOTE(review): TensorFlow presumably reads TF_CPP_MIN_LOG_LEVEL when it first
# logs from C++, which has already happened by this point, so this line may have
# no effect; it would need to be set before `import tensorflow` — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

def int_valued_masks(image, mask):
    """Return the ``(image, mask)`` pair with the mask cast to ``tf.int16``.

    Mapped over the training and validation datasets so that the masks are
    integer class labels, as used by the sparse categorical cross-entropy loss.
    """
    int_mask = tf.cast(mask, tf.int16)
    return image, int_mask

# Train the model, validating at the end of each epoch; `cps` checkpoints the
# best model. Masks are cast to integer labels on the fly.
import logging

try:
    history = model.fit(train_set.map(int_valued_masks),
                        epochs=num_epochs,
                        validation_data=valid_set.map(int_valued_masks),
                        callbacks=cps,
                        )
finally:
    # Undo any session-wide `logging.disable(...)` used to hide the
    # CustomMaskWarning during checkpointing, so later cells log normally —
    # even if training is interrupted.
    logging.disable(logging.NOTSET)


Analysis of training success: loss, MSE and MAE curves for the training and validation data.

In [None]:
# Plot history: Model loss, training vs. validation, per epoch.
fig, ax = plt.subplots()
for history_key, split in (('loss', 'training data'), ('val_loss', 'validation data')):
    ax.plot(history.history[history_key], label=f'Model loss ({split})')

ax.set_title('Model loss for U-Net training.')
ax.set_ylabel('Model loss value')
ax.set_xlabel('No. epoch')
ax.legend(loc="upper left")
plt.show()

In [None]:
# Plot history: MSE, training vs. validation, per epoch.
fig, ax = plt.subplots()
for history_key, split in (('mean_squared_error', 'training data'),
                           ('val_mean_squared_error', 'validation data')):
    ax.plot(history.history[history_key], label=f'MSE ({split})')

ax.set_title('MSE for U-Net training.')
ax.set_ylabel('MSE value')
ax.set_xlabel('No. epoch')
ax.legend(loc="upper left")
plt.show()

In [None]:
# Plot history: MAE, training vs. validation, per epoch.
fig, ax = plt.subplots()
for history_key, split in (('mean_absolute_error', 'training data'),
                           ('val_mean_absolute_error', 'validation data')):
    ax.plot(history.history[history_key], label=f'MAE ({split})')

ax.set_title('MAE for U-Net training.')
ax.set_ylabel('MAE value')
ax.set_xlabel('No. epoch')
ax.legend(loc="upper left")
plt.show()