In [None]:
import numpy as np
import matplotlib.pyplot as plt
 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Pre-built clip array stored on the mounted Google Drive (Colab path).
fpath = '/content/drive/MyDrive/Dataset/syn_occ_sh.npy'
dataset = np.load(fpath)

# Later cells treat axis 0 as clips and axis 1 as time, with the trailing axes
# as (height, width, channels) -- TODO confirm against how the file was built.
print(dataset.shape)

(872, 5, 80, 30, 3)


In [None]:
# Shuffle the clip indices and keep the first 90% for training and the rest
# for validation. A dedicated, seeded Generator makes the split (and every
# metric computed downstream) reproducible across kernel restarts.
SEED = 42
TRAIN_FRACTION = 0.90

rng = np.random.default_rng(SEED)
n_clips = dataset.shape[0]
indexes = rng.permutation(n_clips)
n_train = int(TRAIN_FRACTION * n_clips)  # computed once, used for both slices
train_index = indexes[:n_train]
val_index = indexes[n_train:]
len(train_index), len(val_index)

(784, 88)

In [None]:
# Integer-array (fancy) indexing returns copies of the selected clips.
train_dataset, val_dataset = dataset[train_index], dataset[val_index]

In [None]:
# Normalize the data to the 0-1 range.
# Rebuild from the raw `dataset` instead of dividing the existing split arrays,
# so re-running this cell cannot scale the data by 1/255 a second time.
# NOTE(review): assumes raw pixel values span 0-255 (e.g. uint8) -- confirm.
train_dataset = dataset[train_index] / 255.0
val_dataset = dataset[val_index] / 255.0

In [None]:
def create_shifted_frames(data):
    """Split a batch of clips into input frames and a single target frame.

    Args:
        data: array indexed as ``data[clip, frame, ...]``; the indexing below
            spells out four axes, so it must have at least four dimensions.

    Returns:
        ``(x, y)``: ``x`` holds frames ``0 .. n_frames - 3`` of every clip and
        ``y`` holds the last frame. The second-to-last frame is not used.
        Both are basic slices, i.e. views of ``data``.

    NOTE(review): despite the name, ``y`` is not ``x`` shifted by one step.
    For 5-frame clips, x = frames 0-2 and y = frame 4; confirm that skipping
    frame 3 is intended.
    """
    n_frames = data.shape[1]
    inputs = data[:, : n_frames - 2, :, :]
    target = data[:, -1, :, :]
    return inputs, target


# Build (inputs, target) pairs for each split; train is processed first.
(x_train, y_train), (x_val, y_val) = (
    create_shifted_frames(train_dataset),
    create_shifted_frames(val_dataset),
)

In [None]:
# Report input/target array shapes for each split.
print(f"Training Dataset Shapes: {x_train.shape}, {y_train.shape}")
print(f"Validation Dataset Shapes: {x_val.shape}, {y_val.shape}")

Training Dataset Shapes: (784, 3, 80, 30, 3), (784, 80, 30, 3)
Validation Dataset Shapes: (88, 3, 80, 30, 3), (88, 80, 30, 3)


In [None]:
# (filters, kernel_size) for each hidden ConvLSTM2D layer, in order. Every one
# of them returns the full sequence and is followed by BatchNormalization.
convlstm_stack = [
    (128, (5, 5)),
    (128, (3, 3)),
    (64, (3, 3)),
    (64, (3, 3)),
    (32, (3, 3)),
    (32, (3, 3)),
    (32, (3, 3)),
    (16, (3, 3)),
]

# Variable-length sequence of frames with the per-frame shape of x_train.
inp = layers.Input(shape=(None, *x_train.shape[2:]))

x = inp
for n_filters, kernel in convlstm_stack:
    x = layers.ConvLSTM2D(
        filters=n_filters,
        kernel_size=kernel,
        padding="same",
        return_sequences=True,
        activation="relu",
    )(x)
    x = layers.BatchNormalization()(x)

# Collapse the time axis: only the final step's 3-channel output is kept.
x = layers.ConvLSTM2D(
    filters=3,
    kernel_size=(1, 1),
    padding="same",
    return_sequences=False,
    activation="relu",
)(x)
# Sigmoid keeps each output pixel in [0, 1], the same range as the /255 targets.
x = layers.Conv2D(filters=3, kernel_size=(3, 3), activation="sigmoid", padding="same")(x)

model = keras.models.Model(inp, x)
model.compile(
    loss=keras.losses.binary_crossentropy,
    optimizer=keras.optimizers.Adam(),
)

In [None]:
# Print the per-layer table of output shapes and parameter counts for `model`.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, 80, 30, 3)] 0         
_________________________________________________________________
conv_lst_m2d (ConvLSTM2D)    (None, None, 80, 30, 128) 1677312   
_________________________________________________________________
batch_normalization (BatchNo (None, None, 80, 30, 128) 512       
_________________________________________________________________
conv_lst_m2d_1 (ConvLSTM2D)  (None, None, 80, 30, 128) 1180160   
_________________________________________________________________
batch_normalization_1 (Batch (None, None, 80, 30, 128) 512       
_________________________________________________________________
conv_lst_m2d_2 (ConvLSTM2D)  (None, None, 80, 30, 64)  442624    
_________________________________________________________________
batch_normalization_2 (Batch (None, None, 80, 30, 64)  256   

In [None]:
# Stop once val_loss has not improved for 10 epochs, and roll the weights back
# to the best epoch when that happens. Without restore_best_weights, the
# in-memory model would keep the final-epoch weights, which are `patience`
# epochs past the best one.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=10, restore_best_weights=True
)
# Lower the learning rate when val_loss plateaus for 5 epochs.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)

# hyperparameters.
epochs = 150
batch_size = 64

# Keep the History object so the per-epoch loss curves can be inspected later.
# NOTE(review): the trained model is never saved in this notebook; the
# evaluation cells below load a separate .h5 file from Drive.
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_val, y_val),
    callbacks=[early_stopping, reduce_lr],
)

Epoch 1/150
Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150
Epoch 6/150
Epoch 7/150
Epoch 8/150
Epoch 9/150
Epoch 10/150
Epoch 11/150
Epoch 12/150
Epoch 13/150
Epoch 14/150
Epoch 15/150
Epoch 16/150
Epoch 17/150
Epoch 18/150
Epoch 19/150
Epoch 20/150
Epoch 21/150
Epoch 22/150
Epoch 23/150
Epoch 24/150
Epoch 25/150
Epoch 26/150
Epoch 27/150
Epoch 28/150
Epoch 29/150
Epoch 30/150
Epoch 31/150
Epoch 32/150
Epoch 33/150
Epoch 34/150
Epoch 35/150
Epoch 36/150
Epoch 37/150
Epoch 38/150
Epoch 39/150
Epoch 40/150
Epoch 41/150
Epoch 42/150
Epoch 43/150
Epoch 44/150
Epoch 45/150
Epoch 46/150
Epoch 47/150
Epoch 48/150
Epoch 49/150
Epoch 50/150
Epoch 51/150
Epoch 52/150
Epoch 53/150
Epoch 54/150
Epoch 55/150
Epoch 56/150
Epoch 57/150
Epoch 58/150
Epoch 59/150
Epoch 60/150
Epoch 61/150
Epoch 62/150
Epoch 63/150
Epoch 64/150
Epoch 65/150
Epoch 66/150
Epoch 67/150
Epoch 68/150
Epoch 69/150
Epoch 70/150
Epoch 71/150
Epoch 72/150
Epoch 73/150
Epoch 74/150
Epoch 75/150
Epoch 76/150
Epoch 77/150
Epoch 78

<tensorflow.python.keras.callbacks.History at 0x7f0b503b5d90>

In [None]:
# Load through tensorflow.keras, as the rest of the notebook does. The
# standalone `keras` package may be a different version than tf.keras in
# this environment and can deserialize the .h5 file differently.
from tensorflow.keras.models import load_model

# NOTE(review): the model trained above is not saved by this notebook, so
# confirm this file is the intended checkpoint.
m = load_model('/content/drive/MyDrive/Dataset/model_syn_occ1.h5')

In [None]:
# Show model predictions next to the inputs and the ground truth for a few
# randomly drawn validation clips.
N_EXAMPLES = 10
panel_titles = ["Frame 1", "Frame 2", "Frame 3", "Occluded Frame 3", "LSTM Model Output"]

for _ in range(N_EXAMPLES):
    example = val_dataset[np.random.randint(len(val_dataset))]

    # Mirror create_shifted_frames for one 5-frame clip: the first three frames
    # are the model input, and the last frame is the ground-truth target.
    frames = example[:3, ...]
    original_frames = example[-1, ...]

    # predict() needs a batch axis. Take the single prediction back out with
    # [0] instead of hard-coding its spatial shape.
    new_prediction = m.predict(np.expand_dims(frames, axis=0))
    prediction = new_prediction[0]

    fig, axes = plt.subplots(1, 5, figsize=(20, 5))
    panels = [frames[0], frames[1], original_frames, frames[2], prediction]
    for ax, img, title in zip(axes, panels, panel_titles):
        ax.imshow(img)
        ax.set_title(title)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations