In [None]:
import os
import random

import keras
import matplotlib.pyplot as plt
import numpy as np
import PIL as pil
from PIL import Image

In [None]:
# Root of the UCSD Ped1 dataset. Set the UCSD_PED1_PATH environment variable to
# point elsewhere instead of editing this machine-specific default.
DATA_PATH = os.environ.get('UCSD_PED1_PATH', '/Users/gil-arnaudcoche/Documents/ijoutaku/data/UCSDped1/')
# os.path.join(..., '') appends a trailing separator, so plain string
# concatenation onto these paths keeps working even if DATA_PATH lacks one.
TRAIN_PATH = os.path.join(DATA_PATH, 'Train', '')
TEST_PATH = os.path.join(DATA_PATH, 'Test', '')

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that serves strided frame clips for a video autoencoder.

    ``data_path`` is expected to contain one sub-directory per video sequence,
    each holding that sequence's frames as ``.tif`` files. The first
    ``_SEQUENCE_SIZE`` frames of every sequence are split into
    ``_STRIDE_WINDOW`` interleaved clips (frames ``k, k + stride, k + 2*stride,
    ...`` for ``k`` in ``range(stride)``). Those clips are grouped into
    non-overlapping batches of ``_BATCH_SIZE``.

    Each item is an ``(X, X)`` pair: the reconstruction target is the input
    itself. Batches are rebuilt, with the sequence order reshuffled, at the end
    of every epoch.
    """

    _IMAGE_WIDTH = 128
    _IMAGE_HEIGHT = 88
    _TIF_EXTENSION = '.tif'
    _STRIDE_WINDOW = 5
    _BATCH_SIZE = 3
    _SEQUENCE_SIZE = 200  # maximum number of frames read from each sequence

    def __init__(self, data_path):
        """Index the sequences under ``data_path`` and build the first epoch's batches.

        Args:
            data_path: directory whose sub-directories are the frame sequences.
        """
        super().__init__()
        self._data_path = data_path
        self._sequences = list()
        self._batches = list()
        self._len = 0
        self.__load__()

    def __make_batches__(self, ):
        """Reshuffle the sequences and regroup their strided clips into batches.

        Sets ``self._batches`` to a list of non-overlapping batches (each a list
        of ``_BATCH_SIZE`` clips, each clip a list of frame paths). Sets
        ``self._len`` to the number of batches. Trailing clips that cannot fill
        a whole batch are dropped so that every batch is full. Sequence
        directories with no ``.tif`` frames are skipped.
        """
        random.shuffle(self._sequences)
        strided = list()
        for sequence in self._sequences:
            # Only real frame files; hidden entries (e.g. macOS "._*" / ".DS_Store")
            # and non-tif files are ignored. Truncating to _SEQUENCE_SIZE
            # guarantees each clip fits the time axis allocated in __getitem__.
            frame_names = sorted(
                name for name in os.listdir(sequence)
                if not name.startswith('.') and name.lower().endswith(self._TIF_EXTENSION)
            )[:self._SEQUENCE_SIZE]
            if not frame_names:
                continue
            image_files = [os.path.join(sequence, name) for name in frame_names]
            for stride in range(self._STRIDE_WINDOW):
                strided.append(image_files[stride::self._STRIDE_WINDOW])
        self._len = len(strided) // self._BATCH_SIZE
        # Consecutive, non-overlapping slices of _BATCH_SIZE clips each.
        self._batches = [
            strided[b * self._BATCH_SIZE:(b + 1) * self._BATCH_SIZE]
            for b in range(self._len)
        ]

    def __load__(self, ):
        """Collect the sequence directories under ``self._data_path``, then build batches.

        Only visible sub-directories are kept, so stray files in the data
        directory are not mistaken for sequences. The list is sorted so that,
        given a seeded ``random``, the shuffle order is reproducible.
        """
        self._sequences = sorted(
            os.path.join(self._data_path, entry)
            for entry in os.listdir(self._data_path)
            if not entry.startswith('.') and os.path.isdir(os.path.join(self._data_path, entry))
        )
        self.__make_batches__()

    def __len__(self, ):
        """Return the number of batches per epoch."""
        return self._len

    def __getitem__(self, index):
        """Load batch ``index`` as an ``(X, X)`` autoencoder pair.

        Returns:
            ``X`` with shape ``(_BATCH_SIZE, _SEQUENCE_SIZE // _STRIDE_WINDOW,
            _IMAGE_HEIGHT, _IMAGE_WIDTH, 1)`` and dtype float16, returned as
            both input and target. Pixel values are scaled to ``[0, 1]`` so the
            targets are reachable by the model's sigmoid output. Clips shorter
            than the time axis are zero-padded at the end.
        """
        frames = self._SEQUENCE_SIZE // self._STRIDE_WINDOW
        X = np.zeros((self._BATCH_SIZE, frames, self._IMAGE_HEIGHT, self._IMAGE_WIDTH, 1), dtype=np.float16)
        for b, image_paths in enumerate(self._batches[index]):
            for t, image_path in enumerate(image_paths[:frames]):
                # ``with`` closes the underlying file handle; convert('L') forces a
                # single channel so the frame fits X[..., 0].
                with Image.open(image_path) as image:
                    frame = image.convert('L').resize((self._IMAGE_WIDTH, self._IMAGE_HEIGHT))
                # NOTE(review): assumes 8-bit frames (0-255) — confirm for the dataset used.
                X[b, t, :, :, 0] = np.asarray(frame, dtype=np.float32) / 255.0
        return X, X

    def on_epoch_end(self, ):
        """Reshuffle the sequences and rebuild the batches for the next epoch."""
        self.__make_batches__()

In [None]:
train_set = DataGenerator(TRAIN_PATH)
# Fail fast with a clear message rather than failing later, less clearly, inside fit().
assert len(train_set) > 0, f'no training batches found under {TRAIN_PATH!r}'

In [5]:
EPOCHS = 3
# Frames per clip: each sequence's _SEQUENCE_SIZE frames split across _STRIDE_WINDOW clips.
FRAMES_PER_SAMPLE = DataGenerator._SEQUENCE_SIZE // DataGenerator._STRIDE_WINDOW

# Spatio-temporal autoencoder: a per-frame Conv2D encoder (stride 4), a
# ConvLSTM2D over the time axis, then a per-frame Conv2DTranspose (stride 4)
# decoder and a single-channel sigmoid output in [0, 1].
seq = keras.models.Sequential()
seq.add(keras.layers.TimeDistributed(
    keras.layers.Conv2D(128, (11, 11), strides=4, padding="same"),
    batch_input_shape=(DataGenerator._BATCH_SIZE, FRAMES_PER_SAMPLE,
                       DataGenerator._IMAGE_HEIGHT, DataGenerator._IMAGE_WIDTH, 1),
))
seq.add(keras.layers.LayerNormalization())
# # # # #
seq.add(keras.layers.ConvLSTM2D(64, (3, 3), padding="same", return_sequences=True))
seq.add(keras.layers.LayerNormalization())
# seq.add(keras.layers.ConvLSTM2D(32, (3, 3), padding="same", return_sequences=True))
# seq.add(keras.layers.LayerNormalization())
# seq.add(keras.layers.ConvLSTM2D(64, (3, 3), padding="same", return_sequences=True))
# seq.add(keras.layers.LayerNormalization())
# # # # #
seq.add(keras.layers.TimeDistributed(keras.layers.Conv2DTranspose(128, (11, 11), strides=4, padding="same")))
seq.add(keras.layers.LayerNormalization())
seq.add(keras.layers.TimeDistributed(keras.layers.Conv2D(1, (11, 11), activation="sigmoid", padding="same")))
# summary() prints the table itself and returns None; wrapping it in print()
# would only add a stray "None" line.
seq.summary()
seq.compile(loss='mse', optimizer=keras.optimizers.legacy.Adam(learning_rate=1e-4, epsilon=1e-6))
# The Sequence already yields whole batches, and Keras documents that
# batch_size must not be given for Sequence inputs. shuffle=False keeps the
# batch order built by DataGenerator, which reshuffles in on_epoch_end.
history = seq.fit(train_set, epochs=EPOCHS, shuffle=False)
