In [1]:
%matplotlib inline

In [2]:
import os
import sys

In [3]:
lmu_path = os.path.abspath("../lmu")
sys.path.append(lmu_path)


In [80]:
from lmu import LMUCell

import keras as K
from keras.applications import ResNet50
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Dense, Input, GlobalAveragePooling2D, Conv2DTranspose, Reshape, Lambda
from keras.layers import TimeDistributed
from keras.layers.recurrent import RNN
from keras.models import Sequential, Model
from keras.initializers import Constant
from keras.utils import multi_gpu_model, to_categorical

import numpy as np

In [122]:
def lmu_layer(return_sequences=False,**kwargs):
    return RNN(LMUCell(units=6,
                       order=6,
                       theta=15,
                       input_encoders_initializer=Constant(1),
                       hidden_encoders_initializer=Constant(0),
                       memory_encoders_initializer=Constant(0),
                       input_kernel_initializer=Constant(0),
                       hidden_kernel_initializer=Constant(0),
                       memory_kernel_initializer='glorot_normal',
                      ),
               return_sequences=return_sequences,
               **kwargs)

def Lmu_stack(input_tensor, return_sequences):
    t = []
    for i in range(K.backend.int_shape(input_tensor)[-1]):
        x = Lambda(lambda x: x[...,i])(input_tensor)
        x = Reshape([input_tensor.shape[1], -1])(x)
        t.append(lmu_layer(return_sequences=return_sequences)(x))
    
    return Lambda(lambda x: K.backend.stack(x, axis=-1))(t)

    

In [132]:
seq_len = 15

# resnet_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224,224,3))

# for layer in resnet_model.layers:
#     layer.trainable=False
    
# base_out = resnet_model.output
# base_out = GlobalAveragePooling2D()(base_out)
# base_model = Model(inputs=resnet_model.input, outputs=base_out)

# input_layer = Input(shape=(seq_len, 224, 224, 3))

# # x = TimeDistributed(base_model)(input_layer)
# x = TimeDistributed(resnet_model)(input_layer)

i = Input(shape=(seq_len, 7, 7, 2048))

# x = Reshape(target_shape=(seq_len, 2048))(x)
out = Lmu_stack(i, return_sequences=False)

v = Reshape([-1])(out)
v = Dense(2)(v)


#Deconvolution, this is huge, how to fix?
heat = Reshape([1,1,-1])(out)
heat = Conv2DTranspose(filters=1024,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)
heat = Conv2DTranspose(filters=512,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)
heat = Conv2DTranspose(filters=256,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)
heat = Conv2DTranspose(filters=128,kernel_size=5,strides=2,activation='relu', padding = 'same')(heat)

model = Model(inputs=i, outputs=heat)

In [133]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_135 (InputLayer)          (None, 15, 7, 7, 204 0                                            
__________________________________________________________________________________________________
lambda_246 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_247 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_248 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_249

lambda_599 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_600 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_601 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_602 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_603 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_604

lambda_734 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_735 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_736 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_737 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_738 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_739

lambda_984 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_985 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_986 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_987 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_988 (Lambda)             (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_989

__________________________________________________________________________________________________
lambda_1234 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1235 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1236 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1237 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1238 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________

__________________________________________________________________________________________________
lambda_1484 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1485 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1486 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1487 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1488 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________

lambda_1733 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1734 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1735 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1736 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1737 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_173

lambda_1983 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1984 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1985 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1986 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_1987 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_198

__________________________________________________________________________________________________
lambda_2233 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_2234 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_2235 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_2236 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________________________________________________________________________________________________
lambda_2237 (Lambda)            (None, 15, 7, 7)     0           input_135[0][0]                  
__________

__________________________________________________________________________________________________
reshape_4386 (Reshape)          (None, Dimension(15) 0           lambda_435[0][0]                 
__________________________________________________________________________________________________
reshape_4387 (Reshape)          (None, Dimension(15) 0           lambda_436[0][0]                 
__________________________________________________________________________________________________
reshape_4388 (Reshape)          (None, Dimension(15) 0           lambda_437[0][0]                 
__________________________________________________________________________________________________
reshape_4389 (Reshape)          (None, Dimension(15) 0           lambda_438[0][0]                 
__________________________________________________________________________________________________
reshape_4390 (Reshape)          (None, Dimension(15) 0           lambda_439[0][0]                 
__________

reshape_4635 (Reshape)          (None, Dimension(15) 0           lambda_684[0][0]                 
__________________________________________________________________________________________________
reshape_4636 (Reshape)          (None, Dimension(15) 0           lambda_685[0][0]                 
__________________________________________________________________________________________________
reshape_4637 (Reshape)          (None, Dimension(15) 0           lambda_686[0][0]                 
__________________________________________________________________________________________________
reshape_4638 (Reshape)          (None, Dimension(15) 0           lambda_687[0][0]                 
__________________________________________________________________________________________________
reshape_4639 (Reshape)          (None, Dimension(15) 0           lambda_688[0][0]                 
__________________________________________________________________________________________________
reshape_46

reshape_4885 (Reshape)          (None, Dimension(15) 0           lambda_934[0][0]                 
__________________________________________________________________________________________________
reshape_4886 (Reshape)          (None, Dimension(15) 0           lambda_935[0][0]                 
__________________________________________________________________________________________________
reshape_4887 (Reshape)          (None, Dimension(15) 0           lambda_936[0][0]                 
__________________________________________________________________________________________________
reshape_4888 (Reshape)          (None, Dimension(15) 0           lambda_937[0][0]                 
__________________________________________________________________________________________________
reshape_4889 (Reshape)          (None, Dimension(15) 0           lambda_938[0][0]                 
__________________________________________________________________________________________________
reshape_48

__________________________________________________________________________________________________
reshape_5135 (Reshape)          (None, Dimension(15) 0           lambda_1184[0][0]                
__________________________________________________________________________________________________
reshape_5136 (Reshape)          (None, Dimension(15) 0           lambda_1185[0][0]                
__________________________________________________________________________________________________
reshape_5137 (Reshape)          (None, Dimension(15) 0           lambda_1186[0][0]                
__________________________________________________________________________________________________
reshape_5138 (Reshape)          (None, Dimension(15) 0           lambda_1187[0][0]                
__________________________________________________________________________________________________
reshape_5139 (Reshape)          (None, Dimension(15) 0           lambda_1188[0][0]                
__________

reshape_5444 (Reshape)          (None, Dimension(15) 0           lambda_1493[0][0]                
__________________________________________________________________________________________________
reshape_5445 (Reshape)          (None, Dimension(15) 0           lambda_1494[0][0]                
__________________________________________________________________________________________________
reshape_5446 (Reshape)          (None, Dimension(15) 0           lambda_1495[0][0]                
__________________________________________________________________________________________________
reshape_5447 (Reshape)          (None, Dimension(15) 0           lambda_1496[0][0]                
__________________________________________________________________________________________________
reshape_5448 (Reshape)          (None, Dimension(15) 0           lambda_1497[0][0]                
__________________________________________________________________________________________________
reshape_54

reshape_5801 (Reshape)          (None, Dimension(15) 0           lambda_1850[0][0]                
__________________________________________________________________________________________________
reshape_5802 (Reshape)          (None, Dimension(15) 0           lambda_1851[0][0]                
__________________________________________________________________________________________________
reshape_5803 (Reshape)          (None, Dimension(15) 0           lambda_1852[0][0]                
__________________________________________________________________________________________________
reshape_5804 (Reshape)          (None, Dimension(15) 0           lambda_1853[0][0]                
__________________________________________________________________________________________________
reshape_5805 (Reshape)          (None, Dimension(15) 0           lambda_1854[0][0]                
__________________________________________________________________________________________________
reshape_58

reshape_6134 (Reshape)          (None, Dimension(15) 0           lambda_2183[0][0]                
__________________________________________________________________________________________________
reshape_6135 (Reshape)          (None, Dimension(15) 0           lambda_2184[0][0]                
__________________________________________________________________________________________________
reshape_6136 (Reshape)          (None, Dimension(15) 0           lambda_2185[0][0]                
__________________________________________________________________________________________________
reshape_6137 (Reshape)          (None, Dimension(15) 0           lambda_2186[0][0]                
__________________________________________________________________________________________________
reshape_6138 (Reshape)          (None, Dimension(15) 0           lambda_2187[0][0]                
__________________________________________________________________________________________________
reshape_61

rnn_4478 (RNN)                  (None, 6)            469         reshape_4555[0][0]               
__________________________________________________________________________________________________
rnn_4479 (RNN)                  (None, 6)            469         reshape_4556[0][0]               
__________________________________________________________________________________________________
rnn_4480 (RNN)                  (None, 6)            469         reshape_4557[0][0]               
__________________________________________________________________________________________________
rnn_4481 (RNN)                  (None, 6)            469         reshape_4558[0][0]               
__________________________________________________________________________________________________
rnn_4482 (RNN)                  (None, 6)            469         reshape_4559[0][0]               
__________________________________________________________________________________________________
rnn_4483 (

__________________________________________________________________________________________________
rnn_4759 (RNN)                  (None, 6)            469         reshape_4836[0][0]               
__________________________________________________________________________________________________
rnn_4760 (RNN)                  (None, 6)            469         reshape_4837[0][0]               
__________________________________________________________________________________________________
rnn_4761 (RNN)                  (None, 6)            469         reshape_4838[0][0]               
__________________________________________________________________________________________________
rnn_4762 (RNN)                  (None, 6)            469         reshape_4839[0][0]               
__________________________________________________________________________________________________
rnn_4763 (RNN)                  (None, 6)            469         reshape_4840[0][0]               
__________

rnn_5220 (RNN)                  (None, 6)            469         reshape_5297[0][0]               
__________________________________________________________________________________________________
rnn_5221 (RNN)                  (None, 6)            469         reshape_5298[0][0]               
__________________________________________________________________________________________________
rnn_5222 (RNN)                  (None, 6)            469         reshape_5299[0][0]               
__________________________________________________________________________________________________
rnn_5223 (RNN)                  (None, 6)            469         reshape_5300[0][0]               
__________________________________________________________________________________________________
rnn_5224 (RNN)                  (None, 6)            469         reshape_5301[0][0]               
__________________________________________________________________________________________________
rnn_5225 (

rnn_5508 (RNN)                  (None, 6)            469         reshape_5585[0][0]               
__________________________________________________________________________________________________
rnn_5509 (RNN)                  (None, 6)            469         reshape_5586[0][0]               
__________________________________________________________________________________________________
rnn_5510 (RNN)                  (None, 6)            469         reshape_5587[0][0]               
__________________________________________________________________________________________________
rnn_5511 (RNN)                  (None, 6)            469         reshape_5588[0][0]               
__________________________________________________________________________________________________
rnn_5512 (RNN)                  (None, 6)            469         reshape_5589[0][0]               
__________________________________________________________________________________________________
rnn_5513 (

__________________________________________________________________________________________________
rnn_6008 (RNN)                  (None, 6)            469         reshape_6085[0][0]               
__________________________________________________________________________________________________
rnn_6009 (RNN)                  (None, 6)            469         reshape_6086[0][0]               
__________________________________________________________________________________________________
rnn_6010 (RNN)                  (None, 6)            469         reshape_6087[0][0]               
__________________________________________________________________________________________________
rnn_6011 (RNN)                  (None, 6)            469         reshape_6088[0][0]               
__________________________________________________________________________________________________
rnn_6012 (RNN)                  (None, 6)            469         reshape_6089[0][0]               
__________

                                                                 rnn_4799[0][0]                   
                                                                 rnn_4800[0][0]                   
                                                                 rnn_4801[0][0]                   
                                                                 rnn_4802[0][0]                   
                                                                 rnn_4803[0][0]                   
                                                                 rnn_4804[0][0]                   
                                                                 rnn_4805[0][0]                   
                                                                 rnn_4806[0][0]                   
                                                                 rnn_4807[0][0]                   
                                                                 rnn_4808[0][0]                   
          

                                                                 rnn_5798[0][0]                   
                                                                 rnn_5799[0][0]                   
                                                                 rnn_5800[0][0]                   
                                                                 rnn_5801[0][0]                   
                                                                 rnn_5802[0][0]                   
                                                                 rnn_5803[0][0]                   
                                                                 rnn_5804[0][0]                   
                                                                 rnn_5805[0][0]                   
                                                                 rnn_5806[0][0]                   
                                                                 rnn_5807[0][0]                   
          