In [1]:
import tensorflow as tf

from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
from rnnconv import RnnConv

# Load CIFAR-10 (uint8 pixel arrays plus integer class labels).
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1.
# Cast to float32 *before* dividing: uint8 / 255.0 would yield float64 copies
# (twice the memory), while Keras' default float type is float32 anyway.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

In [2]:
# Encoder -> 32-channel bottleneck -> decoder image autoencoder, NCHW layout throughout.
# NOTE(review): channels_first Conv2D and NCHW depth_to_space are typically only
# implemented on GPU in TensorFlow -- confirm a GPU is available before running.

NUM_ITERATIONS = 15        # number of encoder/decoder stages stacked back to back
INPUT_SHAPE = (3, 32, 32)  # (channels, height, width), i.e. channels_first


def _add_conv(model, filters, kernel_size, strides=(1, 1)):
    """Append a ReLU Conv2D with 'same' padding in channels_first layout."""
    model.add(layers.Conv2D(filters=filters, kernel_size=kernel_size, activation="relu",
                            strides=strides, padding='same', data_format='channels_first'))


def _add_conv_lstm(model, filters, kernel_size, strides=(1, 1)):
    """Append a ReLU ConvLSTM2D that processes the current feature map as a length-1 sequence.

    In channels_first mode ConvLSTM2D expects 5-D input (batch, time, channels, rows, cols),
    so a time axis of size 1 is inserted first. With a single timestep (and the default
    stateful=False) the LSTM state starts from zero on every call, so the recurrence never
    actually feeds back; the layer acts as a gated convolution.
    """
    model.add(layers.Lambda(lambda x: tf.expand_dims(x, axis=1)))
    model.add(layers.ConvLSTM2D(filters=filters, kernel_size=kernel_size, strides=strides,
                                activation="relu", padding='same', data_format='channels_first'))


def _add_depth_to_space(model):
    """Append a 2x depth-to-space upsample in NCHW layout (C, H, W) -> (C / 4, 2H, 2W)."""
    model.add(layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2, data_format='NCHW')))


model = tf.keras.Sequential()
# Declare the input once. Sequential only honours `input_shape` on its first layer,
# so repeating it on every stage's Conv2D (as before) had no effect.
model.add(tf.keras.Input(shape=INPUT_SHAPE))

# NOTE(review): every pass appends a fresh, independently-weighted copy of the whole
# autoencoder. No weights, LSTM state, or residuals are shared between stages, so this
# builds a NUM_ITERATIONS-times deeper (and very large) network, not an iterative/recurrent
# coder. Confirm this is intended; NUM_ITERATIONS = 1 gives a single autoencoder.
for _ in range(NUM_ITERATIONS):
    # Encoder: four stride-2 stages, spatial 32x32 -> 16 -> 8 -> 4 -> 2.
    _add_conv(model, 64, 3, strides=(2, 2))
    _add_conv_lstm(model, 256, 3, strides=(2, 2))
    _add_conv_lstm(model, 512, 3, strides=(2, 2))
    _add_conv_lstm(model, 512, 3, strides=(2, 2))

    # Bottleneck: squeeze to 32 channels with a 1x1 conv, then expand back to 512.
    _add_conv(model, 32, 1)
    _add_conv(model, 512, 1)

    # Decoder: each ConvLSTM is followed by a 2x depth-to-space, spatial 2 -> 4 -> 8 -> 16 -> 32.
    _add_conv_lstm(model, 512, 2)
    _add_depth_to_space(model)
    _add_conv_lstm(model, 512, 3)
    _add_depth_to_space(model)
    _add_conv_lstm(model, 256, 3)
    _add_depth_to_space(model)
    _add_conv_lstm(model, 128, 3)
    _add_depth_to_space(model)

    # Project back to 3 channels so the next stage (or the loss) sees INPUT_SHAPE again.
    _add_conv(model, 3, 1)

In [3]:
# Compile with the Adam optimizer and a mean-squared-error loss;
# MSE and MAE are additionally tracked as reported metrics.
model.compile(
    optimizer="adam",
    loss="mse",
    metrics=["mse", "mae"],
)
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 64, 16, 16)        1792      
_________________________________________________________________
lambda (Lambda)              (None, 1, 64, 16, 16)     0         
_________________________________________________________________
conv_lst_m2d (ConvLSTM2D)    (None, 256, 8, 8)         2950144   
_________________________________________________________________
lambda_1 (Lambda)            (None, 1, 256, 8, 8)      0         
_________________________________________________________________
conv_lst_m2d_1 (ConvLSTM2D)  (None, 512, 4, 4)         14157824  
_________________________________________________________________
lambda_2 (Lambda)            (None, 1, 512, 4, 4)      0         
_________________________________________________________________
conv_lst_m2d_2 (ConvLSTM2D)  (None, 512, 2, 2)         1

In [None]:
# cifar10.load_data() returns channels-last arrays, (N, 32, 32, 3), but the model was
# built channels_first with input shape (3, 32, 32). A plain reshape would only
# reinterpret the memory and scramble the pixels, so permute the axes instead.
N_TRAIN_SAMPLES = 100  # small subset for a quick smoke-test run; raise for real training

train_images_ = tf.convert_to_tensor(train_images.transpose(0, 3, 1, 2).astype("float32"))

# Autoencoder training: the images are both the input and the reconstruction target.
# NOTE: `model3` holds the History object returned by fit() (per-epoch loss/metrics), not a model.
model3 = model.fit(train_images_[:N_TRAIN_SAMPLES], train_images_[:N_TRAIN_SAMPLES],
                   batch_size=64, epochs=1)