In [None]:
import tensorflow as tf
import os
import random
import numpy as np
from tqdm import tqdm 
from skimage.io import imread, imshow
from skimage.transform import resize
import matplotlib.pyplot as plt
from timeit import default_timer as timer
import skimage.measure
from keras.preprocessing import image

from model_definition import unet

from prediction import predict_image

from datetime import datetime


### Input image for prediction

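* This assumes that one image is taken per day.
* Each image is saved in a new folder inside `images/`, named with the capture date. Note: the code below matches folder names against today's date in `YYYY-MM-DD` format (e.g. `2024-01-31`), not `DD-MM-YY`.
* The image file is called `image.jpg`.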

The image is read, resized, and formatted as required by the model.

We then show the input image, the raw (non-thresholded) prediction, and the thresholded prediction.

Finally, the cell count is estimated as the number of connected regions in the thresholded prediction.

In [None]:
IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

input_path = 'images/'

# Each daily image lives in its own sub-folder of input_path, named with the
# capture date as YYYY-MM-DD. Only the folder matching today's date is used.
today = datetime.today().strftime('%Y-%m-%d')
# Build a filtered list rather than calling list.remove() while iterating over
# the same list: that skips the element after every removal and can leave
# other folders in the list.
input_ids = [folder for folder in next(os.walk(input_path))[1] if folder == today]
if not input_ids:
    # Fail loudly instead of running the model on an all-zero placeholder image.
    raise FileNotFoundError(f"No input folder for today ({today}) found in {input_path!r}")

# Model input batch: one slot per matching folder (folder names are unique, so at most one).
X_input = np.zeros((len(input_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)

sizes_test = []

# Read and resize the input image, assuming the file is called 'image.jpg'.
for n, id_ in tqdm(enumerate(input_ids), total=len(input_ids)):
    path = input_path + id_
    print(path)
    # Keep only the first IMG_CHANNELS channels (e.g. drops an alpha channel if present).
    img = imread(path + '/image.jpg')[:,:,:IMG_CHANNELS]
    sizes_test.append([img.shape[0], img.shape[1]])
    img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    X_input[n] = img

# define model parameters
# Rebuild the U-Net for the configured input size, then restore trained weights
# from the same checkpoint file the training section saves to.
kernel_size = 8
unet_model_name = 'checkpoint_unet.h5'
checkpoint_filepath = unet_model_name

model = unet(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, kernel_size)
model.load_weights(checkpoint_filepath)

# prediction
preds_input = model.predict(X_input, verbose=1)
# Binarise the model output at 0.5 to obtain a foreground mask.
preds_input_t = (preds_input > 0.5).astype(np.uint8)

ix = 0  # index of the batch element to display and count
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10,10))

# show input image
ax[0].set_title("Input")
ax[0].imshow(X_input[ix])

# show non-thresholded image
ax[1].set_title("Predicted without Threshold")
ax[1].imshow(np.squeeze(preds_input[ix]), cmap='gray')

# show thresholded image
ax[2].set_title("Predicted with Threshold")
ax[2].imshow(np.squeeze(preds_input_t[ix]), cmap='gray')

for a in ax:
    a.axis("off")

plt.tight_layout()
plt.show()

# Predict the cell count: label the connected foreground regions of the 2-D
# thresholded mask (connectivity=2 -> diagonal neighbours are connected) and
# use the number of labels returned by skimage.
labels, cell_count = skimage.measure.label(np.squeeze(preds_input_t[ix]), connectivity=2, return_num=True)
print("Cell count: ", cell_count)

#TODO
# generate report of prediction data and results

### Model training

In [None]:
# Load (or reload) the TensorBoard notebook extension and embed a TensorBoard
# view of the 'logs' directory, which the TensorBoard callback in the training cell writes to.
%reload_ext tensorboard
%tensorboard --logdir logs

In [None]:
# Define common variables:
IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 3

TRAIN_PATH = 'train/'
TEST_PATH = 'test/'

# One sub-folder per sample. os.walk lists entries in arbitrary,
# filesystem-dependent order, so sort the names to make array indices (and the
# seeded shuffle applied later) reproducible across machines.
train_ids = sorted(next(os.walk(TRAIN_PATH))[1])
test_ids = sorted(next(os.walk(TEST_PATH))[1])

In [None]:
# Define X train and Y train tensors:
# Inputs are uint8 images at model resolution; targets are single-channel boolean
# masks with the same batch/height/width, derived from the input tensor's shape.
X_train = np.zeros(
    (len(train_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
    dtype=np.uint8,
)
Y_train = np.zeros_like(X_train[..., :1], dtype=np.bool_)

In [None]:
# Pre-flight check: make sure every training sample folder contains its image
# file, so missing files are reported up front (all at once) rather than
# partway through the loading loop. Only existence is checked here; images are
# decoded once, in the data-cleaning cell.
missing_images = []
for id_ in train_ids:
    img_path = TRAIN_PATH + id_ + '/images/' + id_ + '.png'
    if not os.path.isfile(img_path):
        missing_images.append(img_path)

if missing_images:
    raise FileNotFoundError(
        f"{len(missing_images)} training image(s) missing, e.g. {missing_images[:5]}"
    )

In [None]:
# Data Cleaning:
# Load every training image and all of its per-object mask files, resize them to
# the model input size, and merge the masks of each image into one binary mask.

start = timer()

for n, id_ in tqdm(enumerate(train_ids), total=len(train_ids)):
    path = TRAIN_PATH + id_
    # Keep only the first IMG_CHANNELS channels (e.g. drops an alpha channel if present).
    img = imread(path + '/images/' + id_ + '.png')[:,:,:IMG_CHANNELS]
    img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    X_train[n] = img  # Fill empty X_train with values from img
    mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool_)
    for mask_file in next(os.walk(path + '/masks/'))[2]:
        mask_ = imread(path + '/masks/' + mask_file)
        mask_ = np.expand_dims(resize(mask_, (IMG_HEIGHT, IMG_WIDTH), mode='constant',
                                      preserve_range=True), axis=-1)
        # Any non-zero (possibly interpolated) pixel counts as foreground. Combine
        # as booleans so `mask` stays bool instead of being promoted to float.
        mask = np.logical_or(mask, mask_ > 0)

    Y_train[n] = mask

end = timer()
print("Training data loaded in", end - start, "seconds")

In [None]:
# Test images:
# Pre-allocate the test-set input tensor: one uint8 image per test id at model resolution.
X_test = np.zeros(
    shape=(len(test_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
    dtype=np.uint8,
)

In [None]:
# Data Cleaning:
# Load and resize every test image; sizes_test records each image's original
# (height, width) before resizing.

start = timer()

sizes_test = []

for n, id_ in tqdm(enumerate(test_ids), total=len(test_ids)):
    path = TEST_PATH + id_
    # Keep only the first IMG_CHANNELS channels (e.g. drops an alpha channel if present).
    img = imread(path + '/images/' + id_ + '.png')[:,:,:IMG_CHANNELS]
    sizes_test.append([img.shape[0], img.shape[1]])
    img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
    X_test[n] = img

end = timer()
print("Test data loaded in", end - start, "seconds")

In [None]:
# Shuffle the training set with ONE shared permutation so every image stays
# paired with its mask. Two separate tf.random.shuffle calls with the same
# op-level seed do not yield the same order in eager mode (each call advances an
# internal counter), which would misalign images and masks.
assert len(X_train) == len(Y_train), "X_train and Y_train must have the same number of samples"
perm = np.random.default_rng(101).permutation(len(X_train))
X_train = X_train[perm]
Y_train = Y_train[perm]

# Check tensor shapes:
print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)

In [None]:
# Display random x_train and y_train image:
# Show 5 randomly chosen training images next to their ground-truth masks as a
# visual check that inputs and masks line up.

for _ in range(5):
    # randrange excludes its upper bound; random.randint(0, n) includes n and
    # can index one past the end of the array.
    ix = random.randrange(len(X_train))

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8,8))

    ax[0].set_title("Input")
    ax[0].imshow(X_train[ix])

    ax[1].set_title("Ground Truth")
    ax[1].imshow(np.squeeze(Y_train[ix]))

    for a in ax:
        a.axis("off")

    plt.tight_layout()

plt.show()

In [None]:
# Build a U-Net (from model_definition) for the configured input size;
# kernel_size is passed straight through to unet().
kernel_size = 8
model = unet(
    IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS,
    kernel_size,
)

In [None]:
callbacks = [
    # Stop when training accuracy has not improved for 10 epochs.
    tf.keras.callbacks.EarlyStopping(patience=10, monitor='accuracy'),
    tf.keras.callbacks.TensorBoard(log_dir='logs', histogram_freq=1)
]

# Keep the weights with the lowest validation loss seen so far.
mcp_save = tf.keras.callbacks.ModelCheckpoint('.mdl_wts.hdf5', save_best_only=True, monitor='val_loss', mode='min')
# `min_delta` is the current name of ReduceLROnPlateau's deprecated `epsilon` argument.
reduce_lr_loss = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, verbose=1, min_delta=1e-4, mode='min')


start = timer()

# Keras takes validation_split from the LAST 20% of samples (before its own
# per-epoch shuffling), so X_train/Y_train must already be shuffled together.
# Callbacks are passed as a single flat list.
results = model.fit(X_train, Y_train, validation_split=0.2, batch_size=32, epochs=200,
                    callbacks=callbacks + [mcp_save, reduce_lr_loss])

end = timer()
print("\nTime taken for Model to Run: ", end - start, "seconds\n")

In [None]:
# Render the model graph, including tensor shapes, to model_plot.png.
# tf.keras.utils.plot_model is the public entry point; keras.utils.vis_utils is
# a private module that newer Keras releases no longer ship.
tf.keras.utils.plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

In [None]:
# Plot the training and validation accuracy and loss at each epoch


def _plot_train_val(history, metric, label, title, ylabel):
    """Plot a training metric and its validation counterpart against epoch number.

    Args:
        history: the ``History.history`` dict returned by ``model.fit``.
        metric: key of the training metric (e.g. 'loss'); the validation series
            is read from ``'val_' + metric``.
        label: short name used in the legend, e.g. 'loss' gives
            'Training loss' / 'Validation loss'.
        title: figure title.
        ylabel: y-axis label.
    """
    train_vals = history[metric]
    val_vals = history['val_' + metric]
    epochs = range(1, len(train_vals) + 1)

    fig, ax = plt.subplots()
    ax.plot(epochs, train_vals, 'y', label='Training ' + label)
    ax.plot(epochs, val_vals, 'r', label='Validation ' + label)
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()


_plot_train_val(results.history, 'loss', 'loss', 'Training and validation loss', 'Loss')
_plot_train_val(results.history, 'accuracy', 'acc', 'Training and validation accuracy', 'Accuracy')

In [None]:
# Run the trained model on each split, then binarise its outputs
# (presumably per-pixel foreground probabilities) at 0.5.
preds_train = model.predict(X_train, verbose=1)
preds_train_t = (preds_train > 0.5).astype(np.uint8)

preds_test = model.predict(X_test, verbose=1)
preds_test_t = (preds_test > 0.5).astype(np.uint8)

In [None]:
# Save trained model
# model.save() writes the full model (architecture + weights); the .h5 suffix selects HDF5.
unet_model_name = 'checkpoint_unet.h5'
fcn8_model_name = 'checkpoint_fcn8.h5'  # NOTE(review): unused in this cell; leftover from an FCN-8 variant?
model.save(unet_model_name)

In [None]:
# Load saved model weights from the U-Net checkpoint file.
# Previously this cell passed the literal string 'unet_model_name' as the path
# and used an invalid `%model...` line magic, so it always failed.
unet_model_name = 'checkpoint_unet.h5'
checkpoint_filepath = unet_model_name
model.load_weights(checkpoint_filepath)

In [None]:
# Plot model
# Render the layer graph top-to-bottom (rankdir='TB') without tensor shapes to model.png.
tf.keras.utils.plot_model(
    model,
    to_file='model.png',
    show_shapes=False,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=False,
    dpi=96,
)

In [None]:
# Run predict_image() from the local `prediction` module (imported at the top);
# its inputs and outputs are defined in that module, not in this notebook.
predict_image()