In [1]:
# Render matplotlib figures inline in the notebook's cell output.
%matplotlib inline

In [2]:
import os
import sys

In [3]:
# Make the sibling "../lmu" directory importable so that `from lmu import ...`
# resolves. The path is relative to the kernel's current working directory.
lmu_path = os.path.abspath("../lmu")
# Guard the append so re-running this cell does not pile up duplicate entries.
if lmu_path not in sys.path:
    sys.path.append(lmu_path)


In [4]:
from lmu import LMUCell

import keras as K
from keras.applications import ResNet50
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Dense, Input, GlobalAveragePooling2D, Conv2DTranspose, Reshape, Lambda
from keras.layers import TimeDistributed
from keras.layers.recurrent import RNN
from keras.models import Sequential, Model
from keras.initializers import Constant
from keras.utils import multi_gpu_model, to_categorical

import numpy as np

Using TensorFlow backend.


In [15]:
def lmu_layer(return_sequences=False, units=1, order=6, theta=15, **kwargs):
    """Build a Keras `RNN` layer wrapping a freshly constructed `LMUCell`.

    The cell's encoders and kernels are initialised with fixed constants:
    the input encoder is 1, the hidden/memory encoders and the input/hidden
    kernels are 0, and only the memory kernel is drawn from `glorot_normal`.

    Args:
        return_sequences: Forwarded to `RNN`; when True the layer emits an
            output per timestep instead of only the last one.
        units: Forwarded to `LMUCell(units=...)`. Defaults to 1.
        order: Forwarded to `LMUCell(order=...)`. Defaults to 6.
        theta: Forwarded to `LMUCell(theta=...)`. Defaults to 15.
            NOTE(review): presumably the memory window length in timesteps,
            which would tie it to the sequence length -- confirm against the
            LMUCell documentation.
        **kwargs: Extra keyword arguments forwarded unchanged to `RNN`.

    Returns:
        A new, uncalled `RNN` layer instance.
    """
    cell = LMUCell(
        units=units,
        order=order,
        theta=theta,
        input_encoders_initializer=Constant(1),
        hidden_encoders_initializer=Constant(0),
        memory_encoders_initializer=Constant(0),
        input_kernel_initializer=Constant(0),
        hidden_kernel_initializer=Constant(0),
        memory_kernel_initializer='glorot_normal',
    )
    return RNN(cell, return_sequences=return_sequences, **kwargs)

def Lmu_stack(input_tensor, return_sequences):
    """Apply an independent LMU layer to every channel and stack the results.

    The input is reshaped to `(batch, shape[1], -1, channels)`, i.e. every
    axis between axis 1 and the last axis is flattened into one. Each slice
    along the last axis is then fed through its own `lmu_layer`, and the
    per-channel outputs are stacked along a new trailing axis.

    Args:
        input_tensor: Keras tensor whose static shape has a known last
            dimension (one LMU layer is created per entry of that axis).
        return_sequences: Forwarded to every `lmu_layer` call.

    Returns:
        A Keras tensor holding the per-channel LMU outputs stacked on the
        last axis.

    Raises:
        ValueError: If the size of the last axis is not statically known.
    """
    shape = K.backend.int_shape(input_tensor)
    n_channels = shape[-1]
    if n_channels is None:
        raise ValueError(
            "Lmu_stack needs a statically known channel (last) dimension, "
            "got shape {}".format(shape)
        )

    flat = Reshape([shape[1], -1, n_channels])(input_tensor)

    per_channel = []
    for channel in range(n_channels):
        # Bind the index through `arguments` instead of closing over the loop
        # variable: a closure is late-bound, so any later re-invocation of the
        # layer (calling the model on new inputs, multi-GPU replication, ...)
        # would make every slice read the *last* channel.
        sliced = Lambda(
            lambda x, index: x[..., index],
            arguments={"index": channel},
        )(flat)
        per_channel.append(lmu_layer(return_sequences=return_sequences)(sliced))

    return Lambda(lambda x: K.backend.stack(x, axis=-1))(per_channel)

    

In [16]:
# Number of timesteps (frames) per input sequence.
seq_len = 15

# --- Disabled: ResNet50 feature extractor applied per frame ------------------
# The block below would run a frozen ResNet50 over every frame of the sequence.
# It is commented out; the model instead takes precomputed features as input.
# resnet_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224,224,3))

# for layer in resnet_model.layers:
#     layer.trainable=False
    
# base_out = resnet_model.output
# base_out = GlobalAveragePooling2D()(base_out)
# base_model = Model(inputs=resnet_model.input, outputs=base_out)

# input_layer = Input(shape=(seq_len, 224, 224, 3))

# # x = TimeDistributed(base_model)(input_layer)
# x = TimeDistributed(resnet_model)(input_layer)

# Model input: one (7, 7, 2048) feature map per timestep.
# NOTE(review): presumably the ResNet50 `include_top=False` output for a
# 224x224 image -- confirm against the feature-extraction step.
i = Input(shape=(seq_len, 7, 7, 2048))

# x = Reshape(target_shape=(seq_len, 2048))(x)
# One LMU layer per feature channel; only the final timestep output is kept.
out = Lmu_stack(i, return_sequences=False)

# Flattened 2-unit dense head. NOTE: `v` is built here but is not passed to
# `Model(...)` below, so it is not part of `model`.
v = Reshape([-1])(out)
v = Dense(2)(v)


# TODO: transposed-convolution decoder for the heat-map output. The stack
# below is disabled because it is too large; how to shrink it is still open.
# For now `heat` is just `out` reshaped to (1, 1, -1).
heat = Reshape([1,1,-1])(out)
# heat = Conv2DTranspose(filters=1024,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)
# heat = Conv2DTranspose(filters=512,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)
# heat = Conv2DTranspose(filters=256,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)
# heat = Conv2DTranspose(filters=128,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)

# The model maps the feature sequence to the (reshaped) LMU stack output only.
model = Model(inputs=i, outputs=heat)

In [17]:
# Print the layer-by-layer summary. Expect very long output: Lmu_stack creates
# one Lambda slice and one RNN layer per input channel.
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, 15, 7, 7, 204 0                                            
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 15, 49, 2048) 0           input_5[0][0]                    
__________________________________________________________________________________________________
lambda_2050 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_2051 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
____________________________________________________________________________________________

lambda_2707 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_2708 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_2709 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_2710 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_2711 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_271

lambda_3591 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_3592 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_3593 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_3594 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_3595 (Lambda)            (None, 15, 49)       0           reshape_4[0][0]                  
__________________________________________________________________________________________________
lambda_359

__________________________________________________________________________________________________
rnn_2310 (RNN)                  (None, 1)            154         lambda_2311[0][0]                
__________________________________________________________________________________________________
rnn_2311 (RNN)                  (None, 1)            154         lambda_2312[0][0]                
__________________________________________________________________________________________________
rnn_2312 (RNN)                  (None, 1)            154         lambda_2313[0][0]                
__________________________________________________________________________________________________
rnn_2313 (RNN)                  (None, 1)            154         lambda_2314[0][0]                
__________________________________________________________________________________________________
rnn_2314 (RNN)                  (None, 1)            154         lambda_2315[0][0]                
__________

rnn_2900 (RNN)                  (None, 1)            154         lambda_2901[0][0]                
__________________________________________________________________________________________________
rnn_2901 (RNN)                  (None, 1)            154         lambda_2902[0][0]                
__________________________________________________________________________________________________
rnn_2902 (RNN)                  (None, 1)            154         lambda_2903[0][0]                
__________________________________________________________________________________________________
rnn_2903 (RNN)                  (None, 1)            154         lambda_2904[0][0]                
__________________________________________________________________________________________________
rnn_2904 (RNN)                  (None, 1)            154         lambda_2905[0][0]                
__________________________________________________________________________________________________
rnn_2905 (

rnn_3499 (RNN)                  (None, 1)            154         lambda_3500[0][0]                
__________________________________________________________________________________________________
rnn_3500 (RNN)                  (None, 1)            154         lambda_3501[0][0]                
__________________________________________________________________________________________________
rnn_3501 (RNN)                  (None, 1)            154         lambda_3502[0][0]                
__________________________________________________________________________________________________
rnn_3502 (RNN)                  (None, 1)            154         lambda_3503[0][0]                
__________________________________________________________________________________________________
rnn_3503 (RNN)                  (None, 1)            154         lambda_3504[0][0]                
__________________________________________________________________________________________________
rnn_3504 (

rnn_4048 (RNN)                  (None, 1)            154         lambda_4049[0][0]                
__________________________________________________________________________________________________
rnn_4049 (RNN)                  (None, 1)            154         lambda_4050[0][0]                
__________________________________________________________________________________________________
rnn_4050 (RNN)                  (None, 1)            154         lambda_4051[0][0]                
__________________________________________________________________________________________________
rnn_4051 (RNN)                  (None, 1)            154         lambda_4052[0][0]                
__________________________________________________________________________________________________
rnn_4052 (RNN)                  (None, 1)            154         lambda_4053[0][0]                
__________________________________________________________________________________________________
rnn_4053 (

Total params: 315,392
Trainable params: 229,376
Non-trainable params: 86,016
__________________________________________________________________________________________________
