# Tilt-Series Augmentation via ConvLSTM Neural Networks
Author: Xiaolei Chu @UCDavis

This notebook takes a brain MRI projection tilt series (100 projections) and splits it into an 80-projection training set and a 20-projection test set. The goal is to build a neural network that takes a sequence of N consecutive projections and predicts the next projection.

In [2]:
# Import everything
import tensorflow as tf
from skimage.transform import warp, AffineTransform
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, IntSlider, ToggleButtons
!pip install mrcfile
import mrcfile
import numpy as np


Collecting mrcfile
  Downloading mrcfile-1.3.0-py2.py3-none-any.whl (40 kB)
[?25l[K     |████████                        | 10 kB 24.1 MB/s eta 0:00:01[K     |████████████████                | 20 kB 29.3 MB/s eta 0:00:01[K     |████████████████████████        | 30 kB 12.3 MB/s eta 0:00:01[K     |████████████████████████████████| 40 kB 9.5 MB/s eta 0:00:01[K     |████████████████████████████████| 40 kB 24 kB/s 
Installing collected packages: mrcfile
Successfully installed mrcfile-1.3.0


## Define some helper functions

In [76]:
# Helper functions
def interactVol(vol, title = "", axis = 2):
  """Show a 3-D volume slice by slice with an interactive ipywidgets slider.

  Args:
    vol: 3-D array-like (numpy array or tensor).
    title: figure title. It is optional because several calls in this notebook omit it.
    axis: index of the axis to slide along (0, 1 or 2; negative values allowed).
  """
  # Move the slicing axis last so every case is handled by vol[:, :, layer].
  # The previous version only special-cased axis 0 and silently ignored axis 1.
  vol = np.moveaxis(np.asarray(vol), axis, -1)

  def explore_3d(layer):
    plt.figure(figsize=(10,5))
    plt.imshow(vol[:,:,layer], cmap = 'gray')
    plt.title(title, fontsize = 10)
    plt.axis(False)

  interact(explore_3d, layer=(0, vol.shape[2]-1))

def resizeVol(vol, resize_shape):
  """Resize every slice of a (H, W, N) volume in-plane to resize_shape[:2].

  Args:
    vol: 3-D array whose last axis indexes slices.
    resize_shape: shape (new_H, new_W, N) of the returned volume.

  Returns:
    float64 numpy array of shape resize_shape.
  """
  resized = np.zeros(resize_shape)
  for i in range(vol.shape[2]):
    # tf.image.resize expects a [height, width, channels] image, so add a
    # channel axis and drop it again afterwards. The previous version passed
    # a 2-D slice, and a stray trailing comma made the right-hand side a tuple.
    slice_hwc = np.asarray(vol[:, :, i])[:, :, np.newaxis]
    resized[:, :, i] = tf.image.resize(slice_hwc, resize_shape[:2]).numpy()[:, :, 0]
  return resized

def get_RandomAffine(img_shape, rotation = 0.01, translate = (0.01,0.01), shear = 0.01, scale = (1.02, 1.02)):
  """Draw a random small affine transform.

  Each parameter is sampled uniformly from a symmetric interval around its identity value:
    rotation in [-rotation, rotation],
    x-shift  in [-translate[0], translate[0]] * img_shape[1] (image width),
    y-shift  in [-translate[1], translate[1]] * img_shape[0] (image height),
    shear    in [-shear, shear],
    x-scale  in [2 - scale[0], scale[0]], y-scale in [2 - scale[1], scale[1]].

  Returns:
    skimage.transform.AffineTransform built from the sampled parameters.
  """
  def symmetric(limit):
    # Uniform sample in [-limit, limit).
    return (np.random.uniform()*2-1)*limit

  # Draws happen in the same order as before: rotation, tx, ty, shear, sx, sy.
  rotation = symmetric(rotation)
  translate = (symmetric(translate[0])*img_shape[1], symmetric(translate[1])*img_shape[0])
  shear = symmetric(shear)
  # The y-scale uses scale[1]; the previous version reused scale[0] for both axes.
  scale = (symmetric(scale[0]-1) + 1, symmetric(scale[1]-1) + 1)
  return AffineTransform(scale = scale, translation = translate, shear = shear, rotation = rotation)

def getWarpedTiltSeries(tilt_series, sequence_axis = 2, scaled_tform = None):
  """Randomly mis-align every projection of a 3-D tilt series.

  Args:
    tilt_series: 3-D array; projections are indexed along sequence_axis.
    sequence_axis: axis that indexes projections (0, 1 or 2).
    scaled_tform: if not None, the *recorded* parameters are normalised:
      1 is subtracted from the diagonal, and the translations are divided by
      0.02 * tilt_series.shape[0]. The images themselves are always warped
      with the unmodified random transform.

  Returns:
    (mis_aligned, tform_array):
      mis_aligned is a float64 array with the same shape as tilt_series.
      tform_array has shape (6, n_proj); column i holds the top two rows of
      projection i's 3x3 matrix, flattened row-major.
  """
  n_proj = tilt_series.shape[sequence_axis]
  mis_aligned = np.zeros(tilt_series.shape)
  # A view with projections first, so slice i can be written as out_slices[i].
  out_slices = np.moveaxis(mis_aligned, sequence_axis, 0)
  # Sized by the projection axis. The previous version always used shape[2].
  tform_array = np.zeros((6, n_proj))

  for i in range(n_proj):
    image = np.take(tilt_series, i, axis=sequence_axis)
    params = get_RandomAffine(img_shape = image.shape).params
    out_slices[i] = warp(image, params)

    recorded = params.copy()
    if scaled_tform is not None:
      # Normalise only the stored parameters. The previous version modified the
      # matrix before warping, so the image was warped with a near-zero scale.
      recorded[0,0] -= 1
      recorded[1,1] -= 1
      recorded[0,2] = recorded[0,2]/0.02/tilt_series.shape[0]
      recorded[1,2] = recorded[1,2]/0.02/tilt_series.shape[0]
    tform_array[:, i] = recorded[:2, :].reshape(6)

  return mis_aligned, tform_array

def warpback(mis_aligned, tform_array):
  """Undo per-slice affine warps of a (H, W, N) stack.

  Args:
    mis_aligned: 3-D array whose last axis indexes slices.
    tform_array: either (3, 3, N) full matrices, or the (6, N) layout returned
      by getWarpedTiltSeries (the top two rows of each matrix, flattened row-major).
      NOTE(review): parameters recorded with scaled_tform are normalised and
      cannot be inverted meaningfully here.

  Returns:
    float64 array of the same shape as mis_aligned, where slice i is warped
    with the inverse of matrix i.
  """
  tform_array = np.asarray(tform_array)
  if tform_array.ndim == 2 and tform_array.shape[0] == 6:
    # Rebuild the 3x3 homogeneous matrices by appending the [0, 0, 1] row.
    top_rows = tform_array.T.reshape(-1, 2, 3)                          # (N, 2, 3)
    bottom_row = np.broadcast_to([0.0, 0.0, 1.0], (top_rows.shape[0], 1, 3))
    tform_array = np.moveaxis(np.concatenate([top_rows, bottom_row], axis=1), 0, -1)  # (3, 3, N)

  warp_back = np.zeros(mis_aligned.shape)
  for i in range(mis_aligned.shape[2]):
    warp_back[:,:,i] = warp(mis_aligned[:,:,i], np.linalg.inv(tform_array[:,:,i]))
  return warp_back

def custom_mae(y_true, y_pred):
  """Mean absolute error over the first two entries along axis 0.

  NOTE(review): if this is used as a Keras loss, axis 0 is the batch axis, so
  only the first two samples of each batch are compared. If the intent was to
  compare the top two rows of (batch, 3, 3) transform matrices, the slice
  should probably be [:, :2, :]. Confirm before use.
  """
  abs_err = tf.math.abs(y_true[:2, :, :] - y_pred[:2, :, :])
  return tf.math.reduce_mean(abs_err)
def addback001(tform, shape):
  """Append the homogeneous [0, 0, 1] row to a 2x3 affine matrix.

  Args:
    tform: values broadcastable to shape (2, 3); they become the first two rows.
    shape: unused; kept for call compatibility.

  Returns:
    new float64 array of shape (3, 3).
  """
  top_rows = np.broadcast_to(np.asarray(tform, dtype=float), (2, 3))
  return np.vstack([top_rows, np.array([[0.0, 0.0, 1.0]])])

## Read in the projection file

In [6]:
# Copy the data while the file is still open so the array does not depend on
# the file handle. The `with` block closes the file, so no explicit close() is needed.
with mrcfile.open("brain_mri_proj.mrc") as projmrc:
  proj_volume = np.array(projmrc.data)

### Resize each projection to 128×128 to save memory

In [74]:
# tf.image.resize treats the last axis (the projection index) as channels, so
# every projection is resized in-plane independently.
proj_volume_resized = tf.image.resize(proj_volume, [128,128]).numpy()
interactVol(proj_volume_resized, "Projections resized to 128x128")

interactive(children=(IntSlider(value=49, description='layer', max=99), Output()), _dom_classes=('widget-inter…

In [23]:
# Hold out the last 20 projections as the test set; the rest are used for training.
train_proj, test_proj = proj_volume_resized[:, :, :-20], proj_volume_resized[:, :, -20:]
train_proj.shape, test_proj.shape

((128, 128, 80), (128, 128, 20))

In [24]:
# Normalize the data: divide by the maximum of the whole resized volume and move
# the projection axis first, giving (n_proj, H, W) arrays.
# NOTE(review): train_proj/test_proj are rebound here, so re-running this cell
# without first re-running the split cell transposes and rescales them again.
MAX_INTENSITY = np.max(proj_volume_resized)
train_proj = np.moveaxis(train_proj, 2, 0) / MAX_INTENSITY
test_proj = np.moveaxis(test_proj, 2, 0) / MAX_INTENSITY

13749.5

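Next, we cut the training (and test) projections into inputs and labels. Each input is a window of 7 consecutive projections (WINDOW_SIZE = 7), and its label is the projection that follows the window (HORIZON = 1).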

In [27]:
# Make sequence (X) and label (Y) splits from the train tilt-series:
# each input is WINDOW_SIZE consecutive projections, followed by HORIZON label projection(s).
WINDOW_SIZE, HORIZON = 7, 1

def make_sequence_label(vol, window_size, horizon):
  """Slice a stack of projections into (input window, label) pairs.

  Sample i uses vol[i : i + window_size] as the input and the `horizon`
  projections that immediately follow it as the label.

  Args:
    vol: array of shape (n_proj, H, W); axis 0 indexes projections.
    window_size: number of consecutive projections per input window (>= 1).
      The previous version ignored this and used the global WINDOW_SIZE.
    horizon: number of following projections per label (>= 1).
      The previous version ignored this and always used 1.

  Returns:
    (sequence_data, label_data), both float64:
      sequence_data: (n_samples, window_size, H, W)
      label_data:    (n_samples, H, W) when horizon == 1,
                     (n_samples, horizon, H, W) otherwise,
    where n_samples = max(n_proj - window_size - horizon + 1, 0).

  Raises:
    ValueError: if window_size or horizon is smaller than 1.
  """
  if window_size < 1 or horizon < 1:
    raise ValueError("window_size and horizon must both be >= 1")
  vol = np.asarray(vol, dtype=np.float64)
  frame_shape = vol.shape[1:]
  n_samples = max(vol.shape[0] - window_size - horizon + 1, 0)

  # Preallocate rather than growing with np.concatenate, which copies on every step.
  sequence_data = np.empty((n_samples, window_size) + frame_shape)
  label_data = np.empty((n_samples, horizon) + frame_shape)
  for i in range(n_samples):
    sequence_data[i] = vol[i:i + window_size]
    label_data[i] = vol[i + window_size:i + window_size + horizon]

  if horizon == 1:
    # Keep the single-projection label shape (n_samples, H, W) used by the model.
    label_data = label_data[:, 0]
  return sequence_data, label_data

In [28]:
# Window both splits the same way; the test windows only use test projections.
(train_sequence, train_label), (test_sequence, test_label) = [
    make_sequence_label(split, window_size=WINDOW_SIZE, horizon=HORIZON)
    for split in (train_proj, test_proj)
]
train_sequence.shape, train_label.shape, test_sequence.shape, test_label.shape

((73, 7, 128, 128), (73, 128, 128), (13, 7, 128, 128), (13, 128, 128))

In [31]:
## Random warping for a single input sequence, used by tf_function_warping in the
## tf.data pipeline. Every projection in the sequence gets its own random mis-alignment.
def getWarpedTiltSeries_forTuple(tilt_series):
  """Return a randomly mis-aligned copy of a (n_proj, H, W) sequence.

  Each projection tilt_series[i] is warped with an independent transform drawn
  by get_RandomAffine. The result is always float64, because
  tf_function_warping declares tf.float64 as the output type.
  """
  mis_aligned = np.zeros((tilt_series.shape[0], tilt_series.shape[1], tilt_series.shape[2]))
  for i in range(tilt_series.shape[0]):
    image = np.take(tilt_series, i, axis=0)
    params = get_RandomAffine(img_shape = image.shape).params
    mis_aligned[i] = warp(image, params)
  return mis_aligned
def tf_function_warping(inputs):
  """tf.data-compatible wrapper around getWarpedTiltSeries_forTuple.

  tf.numpy_function returns a tensor with no static shape. The warp keeps the
  input shape, so the input's static shape is restored on the output for the
  downstream layers.
  """
  warped = tf.numpy_function(getWarpedTiltSeries_forTuple, [inputs], tf.float64)
  warped.set_shape(inputs.shape)
  return warped

Here we create the `tf.data.Dataset` pipelines for training and testing.

In [15]:
## Create performant tf.data pipelines
BATCH_SIZE = 1
# Set to True to randomly mis-align every input projection on the fly via
# tf_function_warping. The map is re-run on each pass over the dataset; labels
# are left unchanged.
USE_WARP_AUGMENTATION = False

train_sequence_dataset = tf.data.Dataset.from_tensor_slices(train_sequence)
if USE_WARP_AUGMENTATION:
  train_sequence_dataset = train_sequence_dataset.map(tf_function_warping, num_parallel_calls=tf.data.AUTOTUNE)
train_label_dataset = tf.data.Dataset.from_tensor_slices(train_label)

# Shuffle the (sequence, label) pairs together so the overlapping, consecutive
# windows are not always seen in the same order.
train_dataset = (tf.data.Dataset.zip((train_sequence_dataset, train_label_dataset))
                 .shuffle(buffer_size=len(train_sequence))
                 .batch(BATCH_SIZE)
                 .prefetch(tf.data.AUTOTUNE))
test_dataset = tf.data.Dataset.from_tensor_slices((test_sequence, test_label)).batch(BATCH_SIZE)

## Model building
Now we build a simple bidirectional ConvLSTM2D model that predicts the next projection from a sequence of 7 projections.

In [59]:
# Model building
# (The previous unused import of BatchNormalization from the private
# tensorflow.python.keras package has been dropped; layers.BatchNormalization is used.)
import tensorflow.keras.layers as layers

# The input geometry comes from the data: (window, H, W) sequences map to one (H, W, 1) projection.
SEQ_LEN, IMG_H, IMG_W = train_sequence.shape[1:]

model_0 = tf.keras.Sequential([
    layers.Input(shape=(SEQ_LEN, IMG_H, IMG_W)),
    # ConvLSTM2D with channels_last expects (time, rows, cols, channels), so add a singleton channel axis.
    layers.Reshape((SEQ_LEN, IMG_H, IMG_W, 1)),
    layers.Bidirectional(
        layers.ConvLSTM2D(filters=64,
                          kernel_size=(3,3),
                          data_format='channels_last',
                          return_sequences=False,
                          padding='same')),
    layers.BatchNormalization(),
    # Dense on a 4-D tensor acts on the last (feature) axis only, i.e. per pixel.
    layers.Dense(64, activation='relu'),
    layers.Dense(1, activation='relu'),
], name="model_0_ConvLSTM2D")
model_0.summary()

Model: "model_0_ConvLSTM2D"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape_5 (Reshape)         (None, 7, 128, 128, 1)    0         
                                                                 
 bidirectional_5 (Bidirectio  (None, 128, 128, 128)    300032    
 nal)                                                            
                                                                 
 batch_normalization_2 (Batc  (None, 128, 128, 128)    512       
 hNormalization)                                                 
                                                                 
 dense_10 (Dense)            (None, 128, 128, 64)      8256      
                                                                 
 dense_11 (Dense)            (None, 128, 128, 1)       65        
                                                                 
Total params: 308,865
Trainable params: 308,609


In [60]:
# NOTE(review): test_dataset is also passed as validation data, so the test
# windows drive both checkpoint selection and the learning-rate schedule.
# Scores later computed on test_dataset are therefore not an unbiased estimate.
model_0.compile(loss = tf.keras.losses.mae,
                optimizer = tf.keras.optimizers.Adam(),
                metrics = [tf.keras.metrics.MSE])
# Keep only the weights from the epoch with the lowest val_loss.
# save_weights_only must be a bool; the previous version passed the string "True".
checkpoint_func = tf.keras.callbacks.ModelCheckpoint("model_3_as2_add_batchNorm/model_0.ckpt",
                                                     monitor='val_loss',
                                                     verbose=0,
                                                     save_best_only=True,
                                                     save_weights_only=True)
# Multiply the learning rate by 0.2 after 20 epochs without val_loss improvement.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=20,
                                                 factor=0.2,
                                                 monitor="val_loss")

history_0 = model_0.fit(train_dataset,
                        validation_data = test_dataset,
                        epochs = 100,
                        verbose = 1,
                        callbacks=[checkpoint_func, reduce_lr])

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [63]:
# Restore the best (lowest val_loss) weights saved by the ModelCheckpoint callback.
# This uses the same relative path as the callback instead of a hard-coded Colab
# absolute path, so it also works outside /content.
model_0.load_weights("model_3_as2_add_batchNorm/model_0.ckpt")
pred = model_0.predict(test_dataset)
pred.shape

(13, 128, 128, 1)

In [64]:
interactVol(test_label, "Test labels (ground truth)", axis=0)

interactVol(tf.squeeze(pred), "Model predictions on test windows", axis=0)

interactive(children=(IntSlider(value=6, description='layer', max=12), Output()), _dom_classes=('widget-intera…

interactive(children=(IntSlider(value=6, description='layer', max=12), Output()), _dom_classes=('widget-intera…

In [71]:
## What about feeding the predictions back in as inputs to predict further into the sequence?
# Seed with the last WINDOW_SIZE training projections, so the first prediction targets the first test projection.
last_sequence = train_proj[-WINDOW_SIZE:, :, :]
def continous_prediction(model, beginning_sequence, num_of_predictions):
  """Autoregressively predict projections, feeding each prediction back in.

  Args:
    model: object with a Keras-style predict(x, verbose=...) method that maps a
      (1, window, H, W) batch to one projection with H*W elements.
    beginning_sequence: (window, H, W) array of seed projections.
    num_of_predictions: number of projections to generate (values <= 0 yield none).

  Returns:
    float64 array of shape (num_of_predictions, H, W).
  """
  window = np.asarray(beginning_sequence)
  frame_shape = window.shape[1:]
  n_steps = max(num_of_predictions, 0)
  # Preallocate instead of concatenating onto a growing array at every step.
  predictions = np.zeros((n_steps,) + frame_shape)
  for step in range(n_steps):
    # verbose=0 avoids printing a progress bar for every single-sample predict call.
    pred = model.predict(window[np.newaxis], verbose=0)
    next_frame = np.reshape(pred, (1,) + frame_shape)
    predictions[step] = next_frame[0]
    # Slide the window: drop the oldest projection and append the newest prediction.
    window = np.concatenate([window[1:], next_frame], axis=0)
  return predictions



In [72]:
# Roll the model forward 20 steps, seeded with the tail of the training set.
prediction_20 = continous_prediction(model=model_0, beginning_sequence=last_sequence, num_of_predictions=20)

In [73]:
interactVol(prediction_20, "Autoregressive predictions (20 steps)", axis = 0)

interactive(children=(IntSlider(value=9, description='layer', max=19), Output()), _dom_classes=('widget-intera…