LSTM layer support #8

Closed
anbac opened this issue Jan 31, 2018 · 71 comments

@anbac

anbac commented Jan 31, 2018

It would be very nice to also support LSTM layers.

@Dobiasd
Owner

Dobiasd commented Feb 1, 2018

Totally agree. Are you interested in implementing it? :-)

edit: RNNs have also been asked for here.

@rcshubhadeep

rcshubhadeep commented Mar 12, 2018

I was about to write about the same thing and then found this issue thread. I think it will be a great addition. If you agree, I can try my luck. I haven't done C++ in a while, but it would be great to go back.

@Dobiasd
Owner

Dobiasd commented Mar 12, 2018

@rcshubhadeep That would be awesome! Once the algorithm works, we can adjust the C++ style afterwards.

So if you'd like to start implementing, here are the steps that are needed:

  • use the layer type in question in get_test_model_small in keras_export/generate_test_models.py
  • add an export function for the layer (if needed) to keras_export/convert_model.py
  • add the new layer class to include/fdeep/layers (actual implementation)
  • add a create function (deserialization from JSON) to include/fdeep/import_model.hpp
  • also in include/fdeep/import_model.hpp, extend the unordered_map creators in create_layer to also dispatch to your new create function.

When you have all this, your new layer type will automatically be tested (i.e. its results compared with the Keras implementation) in the unit tests. You can run the tests locally as described in README.md or by creating a pull request here in this repo and letting Travis CI do it.

If any questions arise, just let me know. :)

@rcshubhadeep

I'm here to say I am sorry, but a newborn kid and the pressure of my current job are preventing me from taking up this venture. I would have loved to.

@Dobiasd
Owner

Dobiasd commented Aug 9, 2018

No problem at all. Having kids is one of the few things that is even more awesome than implementing deep-learning layers. 😉 So I totally understand.

@chammika

@Dobiasd I briefly looked into implementing the SimpleRNN recurrent layer. The weight layout of the model seems simple, and exporting the weights wouldn't be hard. I implemented the forward pass in Python using NumPy, and it agrees with the Keras predictions.
https://gist.github.com/chammika/0448fbb0e96d1365326721137ff86da6

Next steps would be to follow your outline above to have a functional layer. I am not clear yet how to handle the extra time dimension in the data though...

Regarding recurrent layers in general, the input shape has another dimension for time, i.e., seq_length. How would you envision fitting it into fdeep? Should we define tensor4 and shape4 with (seq_length, depth, height, width)?
And if you count the batch_size dimension (which is needed to implement stateful RNN layers), it's a 5D tensor. It would be great if you could give some ideas on this.

@Dobiasd
Owner

Dobiasd commented Aug 17, 2018

@chammika Wow, good work on implementing it in Python already!

Yes, it seems like we need a Tensor4, but I'm not yet sure if it makes sense to let all fdeep-layers just work with Tensor4 instead of Tensor3 to unify things. The seq_length dimension might be confused with a batch-size dimension. What do you think?

Do stateful RNN layers need a fifth dimension (batch_size) also when leaving out training and just implementing forward passes?

@chammika

As far as I know, if we add Tensor4, it only applies to the following two cases:

  • inputs of RNN layers
  • outputs of RNN layers if return_sequences=True (the next layer must be an RNN layer)

When return_sequences=False, the output of an RNN layer is a Tensor3, which matches the inputs of the non-recurrent layers. To illustrate this, I updated the example code above to have two SimpleRNN layers and a Dense layer, and simplified the calculations a bit.
https://gist.github.com/chammika/0448fbb0e96d1365326721137ff86da6

Note that the first SimpleRNN returns a sequence, which can be fed to an RNN, whereas the second one does not; it cannot if we are to connect a Dense layer next. I could replicate the Keras predictions in NumPy as before.
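A plain-Python sketch of that two-layer setup, for readers without the gist at hand. It assumes Keras' SimpleRNN conventions (tanh activation, `kernel` of shape (input_dim, units), `recurrent_kernel` of shape (units, units), zero initial state) and uses made-up weights; `simple_rnn` and `dense` are illustrative helpers, not Keras or fdeep API. The outer list is the time dimension: `return_sequences=True` returns one vector per step, which the next RNN can consume, while `return_sequences=False` returns only the last step, a flat vector that a Dense layer can consume.

```python
import math


def simple_rnn(xs, kernel, recurrent_kernel, bias, return_sequences=False):
    """Keras-style SimpleRNN forward pass (tanh), starting from a zero state.

    xs: list of time steps, each a list of input_dim floats.
    kernel: input_dim x units; recurrent_kernel: units x units; bias: units.
    """
    units = len(bias)
    h = [0.0] * units
    outputs = []
    for x in xs:
        # h_t = tanh(x_t . kernel + h_{t-1} . recurrent_kernel + bias)
        h = [
            math.tanh(
                sum(x[i] * kernel[i][j] for i in range(len(x)))
                + sum(h[k] * recurrent_kernel[k][j] for k in range(units))
                + bias[j]
            )
            for j in range(units)
        ]
        outputs.append(h)
    return outputs if return_sequences else outputs[-1]


def dense(v, kernel, bias):
    """Linear Dense layer: v . kernel + bias."""
    return [
        sum(v[i] * kernel[i][j] for i in range(len(v))) + bias[j]
        for j in range(len(bias))
    ]


# Two stacked SimpleRNNs followed by a Dense layer (arbitrary weights).
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # seq_len=3, input_dim=2
seq = simple_rnn(xs, [[0.5, -0.5], [0.25, 0.1]], [[0.1, 0.0], [0.0, 0.1]],
                 [0.0, 0.0], return_sequences=True)  # 3 steps x 2 units
last = simple_rnn(seq, [[1.0], [1.0]], [[0.5]], [0.0])  # last step only
y = dense(last, [[2.0]], [1.0])
```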

but I'm not yet sure if it makes sense to let all fdeep-layers just work with Tensor4 instead of Tensor3 to unify things

There is no need. Only the above cases will have Tensor4.

Do stateful RNN layers need a fifth dimension (batch_size) also when leaving out training and just implementing forward passes?

I am not clear about this yet. According to the above example it doesn't: whether stateful is True or False, Keras gives the same predictions. So I am guessing it's only an option during training? According to this article, however, there should be a difference

https://fairyonice.github.io/Understand-Keras's-RNN-behind-the-scenes-with-a-sin-wave-example.html

Enlighten me if you figure this out 😃

@Dobiasd
Owner

Dobiasd commented Aug 17, 2018

There is no need. Only the above cases will have Tensor4.

Yeah, but if one layer type takes Tensor4 and the other layer type takes Tensor3, what type should model::predict (and layer::apply) take?

I guess we need some kind of unification, or do you see a loophole somewhere? 🙂

So I am guessing it's only an option during training? According to this article, however, there should be a difference

😕

@Dobiasd
Owner

Dobiasd commented Aug 18, 2018

Leaving the stateful question aside, perhaps there is no problem with "tensor4". Currently model::predict and layer::apply take fdeep::tensor3s (i.e., std::vector<tensor3>), because that's just how models are implemented in Keras: they can take multiple input tensors, and some layer types also take multiple tensors as input (e.g., merge layers). Layers like conv2d can of course only work when this tensor3 vector has size 1. Perhaps LSTM is just like, for example, the concatenate layer, taking a tensor3 vector with more than one element. In that case we could add LSTM without changing the overall structure and types of the lib. What do you think?

@chammika-become
Contributor

chammika-become commented Aug 20, 2018

Yes, that would be the best solution, keeping the current layer API intact. In that case size(tensor3s) == seq_len. According to an answer to the question below:

https://stackoverflow.com/questions/42763928/how-to-use-model-reset-states-in-keras

stateful=True is usually used when you want to treat consecutive batches as consecutive inputs. In this case the model treats consecutive batches as if they were in the same batch.

If we flatten batches into one sequence, we can get the same effect as stateful predictions. In that case we might want to update the API to

apply(const tensor3s& inputs, batch_size=1)

Pass the appropriate batch_size so that the layer generates outputs at batch_size intervals.
Since I haven't seen any difference in stateful predictions from my simulations of Keras models, I would first implement the layer without stateful and see if the predictions really need this extra batch_size parameter to work.
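The "flatten batches into one sequence" idea can be checked with a small sketch (a hypothetical 1-unit SimpleRNN in plain Python with arbitrary weights): running over a whole sequence from a zero state gives the same outputs as running over a first chunk and then over a second chunk that starts from the state the first one ended in. Carrying the state across calls is what a stateful layer does between batches; restarting from zero on every call is the stateful=False behavior, and it changes the second chunk's outputs.

```python
import math


def rnn_1unit(xs, h0, w, u, b):
    """1-unit SimpleRNN: h_t = tanh(w*x_t + u*h_{t-1} + b).

    Returns (per-step outputs, final state) so the state can be carried
    into the next call.
    """
    h = h0
    outputs = []
    for x in xs:
        h = math.tanh(w * x + u * h + b)
        outputs.append(h)
    return outputs, h


w, u, b = 0.7, 0.5, -0.1
chunk_1 = [1.0, 1.0, 1.0]
chunk_2 = [0.25, 2.0]

# One long sequence, starting from the zero state.
whole, _ = rnn_1unit(chunk_1 + chunk_2, 0.0, w, u, b)

# Two calls, the second continuing from where the first left off ("stateful").
out_1, carried = rnn_1unit(chunk_1, 0.0, w, u, b)
out_2, _ = rnn_1unit(chunk_2, carried, w, u, b)

# Two calls, each starting from zero ("stateful=False").
reset_2, _ = rnn_1unit(chunk_2, 0.0, w, u, b)
```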

@Dobiasd
Owner

Dobiasd commented Aug 20, 2018

Sounds good. 👍

One question though: You mean changing the interface of fdeep::layer::apply from virtual tensor3s apply(const tensor3s& input) const final to virtual tensor3s apply(const tensor3s& input, batch_size=1) const final?

Since frugally-deep does not support any kind of batching (not counting parallel predictions), I'd like to keep the word batch out. Do you think it is possible to leave the interface as it is and just assume (or calculate) the batch size if needed inside the LSTM-layer implementation?

@chammika-become
Contributor

It's possible to keep the interface as it is for stateful=False layers. For stateful layers it might be necessary to change the signature of the interface as above. And yes, you are right: my suggestion was about the virtual layer::apply method.
Still, I am not sure it's necessary, because I am not fully clear on stateful predictions yet. Sorry for confusing you with suggestions I don't fully understand.
I will first implement the SimpleRNN layer (with stateful=False) and see if it can predict a stateful layer without modifying the interface.

@Dobiasd
Owner

Dobiasd commented Aug 21, 2018

Sounds like a very good plan. 👍

Let me know if you have any questions regarding the integrations or if I can be of any other kind of help. 🙂

@n-Guard
Contributor

n-Guard commented Aug 29, 2018

@Dobiasd First of all: keep up the good work, I think this project is really useful!

I recently got my first job at an audio software company and it's my task to implement the LSTM, TimeDistributed, and Bidirectional layers for production.

I took a shot at implementing the lstm_layer class with the inner function that computes the LSTM output, but only for my special use case where I basically just have a 2-dimensional input with (seq_length, 1D_data). The third dimension (first in Tensor3) is ignored right now because it seemed to me that this would be a little bit harder to implement since we don't work with Eigen-unsupported where n-dimensional matrices can be used.

For testing, I manually copied the weights from the keras model and fed them into the function. Output values are reasonably close to keras output.
I know the implementation is probably far from ideal because I'm relatively new at C++ but maybe it's a starting point.

Actual implementation: https://gist.github.com/n-Guard/50a64f4ab837b06777263758b15e6118
Tests: https://gist.github.com/n-Guard/afe8a288ac54a32880ceb3f1e37d95d8

Next thing would obviously be implementing the model conversion to be able to actually test random cases.

I'm looking forward to getting some feedback!
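For comparison with the gist, here is a plain-Python sketch of the LSTM recurrence. It assumes Keras' weight layout: `kernel` is (input_dim, 4*units), `recurrent_kernel` is (units, 4*units), `bias` has 4*units entries, and the column blocks are ordered input gate, forget gate, cell candidate, output gate. It also assumes `tanh` as activation and the logistic sigmoid as recurrent activation, which is the tf.keras default; older standalone Keras versions defaulted to `hard_sigmoid`, so the exported layer config should be checked. The initial `h` and `c` are zero, and `lstm` is an illustrative name.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm(xs, kernel, recurrent_kernel, bias):
    """Keras-style LSTM forward pass; returns (per-step h, final h, final c)."""
    units = len(bias) // 4
    h = [0.0] * units
    c = [0.0] * units
    outputs = []
    for x in xs:
        # 4*units pre-activations, laid out as [i | f | c~ | o]
        z = [
            sum(x[k] * kernel[k][j] for k in range(len(x)))
            + sum(h[k] * recurrent_kernel[k][j] for k in range(units))
            + bias[j]
            for j in range(4 * units)
        ]
        i = [sigmoid(v) for v in z[0:units]]              # input gate
        f = [sigmoid(v) for v in z[units:2 * units]]      # forget gate
        g = [math.tanh(v) for v in z[2 * units:3 * units]]  # cell candidate
        o = [sigmoid(v) for v in z[3 * units:]]           # output gate
        c = [f[j] * c[j] + i[j] * g[j] for j in range(units)]
        h = [o[j] * math.tanh(c[j]) for j in range(units)]
        outputs.append(h)
    return outputs, h, c
```

Returning the final `h` and `c` alongside the per-step outputs is what a later stateful or `return_state` variant would need.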

@Dobiasd
Owner

Dobiasd commented Aug 29, 2018

@n-Guard Hi, and thanks for the nice feedback and this good contribution. Code cleanup regarding C++ style can be done later. :)

The more important thing from my POV is that we now seem to have two people working on the same feature. Maybe to further improve quality (and to not waste your time too) a collaboration of you two would be reasonable? @chammika / @chammika-become What do you think?

@chammika-become
Contributor

chammika-become commented Aug 30, 2018

@n-Guard Great work! 👍
@Dobiasd I have been working on SimpleRNN, but so far I have only managed to export the weights from the model. On the other hand, @n-Guard has already implemented a working version of LSTM, so let's go ahead and integrate his solution.
Only thing I would suggest now is to keep bias in RowMajorMatrixXf from the beginning so you don't have to regenerate the b here for every prediction.
https://gist.github.com/n-Guard/50a64f4ab837b06777263758b15e6118#file-lstm_layer-hpp-L123

Could you also post a link to the Python model that produces the same predictions, for ease of comparison? I particularly want to see how the seq_len of the input vectors is handled. My idea for SimpleRNN is len(tensor3s) == seq_len, whereas in your case it looks like x_width.

@Dobiasd
Owner

Dobiasd commented Aug 30, 2018

OK, cool. @n-Guard now that this part is clear, feel free to open a PR with your work if you like, so we can view it and so that you have automatic tests in the CI etc.

@n-Guard
Contributor

n-Guard commented Aug 30, 2018

@chammika-become Yes, it is indeed better to keep bias as RowMajorMatrixXf and I changed it accordingly.
Here is my python code that I used for reference:
https://gist.github.com/n-Guard/cedd69b50803c82f6ce437ea771a620c
I chose x_width as seq_len because in Tensor3 the first dimension is the number of channels. But I think your suggestion with seq_len being the first dimension is more consistent, so let's go with that!

@Dobiasd Yes, I will soon open a PR :)
Btw is there a specific reason why you didn't use Eigen-unsupported for having n-dimensional matrices? I'm interested because tensorflow seems to use it.

@Dobiasd
Owner

Dobiasd commented Aug 30, 2018

But I think your suggestion with seq_len being the first dimension is more consistent, so let's go with that!

fdeep::layer::apply takes tensor3s (i.e., vector<tensor3>). Would it make sense to use the size of this vector as seq_len instead of reusing "channels", "height" or "width"?


Btw is there a specific reason why you didn't use Eigen-unsupported for having n-dimensional matrices? I'm interested because tensorflow seems to use it.

I did not yet look into it. The main reason might be historical. For quite a long time fdeep did not use Eigen at all. I then only added it to speed up the matrix multiplications. ;)

@n-Guard
Contributor

n-Guard commented Aug 30, 2018

fdeep::layer::apply takes tensor3s (i.e., vector<tensor3>). Would it make sense to use the size of this vector as seq_len instead of reusing "channels", "height" or "width"?

Ah, so if I understand it correctly you would have a vector with elements representing multiple tensor3 instances (i.e. tensor3s), where each tensor3 element represents one time step with three dimensions (channels, height, width)? If so, I think this would make sense.

@Dobiasd
Owner

Dobiasd commented Aug 30, 2018

Yes, exactly. And to me this sounds more intuitive. On the other hand, I don't know much about LSTMs. So if this poses a problem, feel free to use an alternative approach.

@n-Guard
Contributor

n-Guard commented Aug 30, 2018

After doing some research, I don't think there is a problem specific to LSTMs since keras LSTM-layer (and RNN-layers in general) only accept a 3D Tensor as input (batch_size, timesteps, input_dim), so basically we only have 2 dimensions left when leaving out batch dimension. This also implies that we don't need n-dimensional matrices at this point (for LSTMs).
So the problem of having more than 3 dimensions only occurs when the input layer is not an RNN-layer.

@Dobiasd
Owner

Dobiasd commented Aug 30, 2018

If I understand correctly, we can basically choose which one of our four dimensions ([tensor_count, channels, height, width] from tensor3s, aka vector<tensor3>) we want to use for seq_len, right?

So we might take the option that translates the simplest from what Keras does, i.e., the way that avoids conversions (swapaxes stuff) when we export the validation input data for a model from Keras to C++.

Naturally this should then automatically be what would be most intuitive for our users to create when calling tensor3s model::predict(const tensor3s& inputs) const.

@n-Guard
Contributor

n-Guard commented Sep 5, 2018

Ok, I think we can definitely agree on having seq_len as tensor3s.size() since time dimension is always first dimension in keras.
In my latest commit to my PR, I changed the input accordingly, but with the feature dimension being tensor3.height() because I thought this would be most appropriate (having no channel dimension), so the input to the LSTM layer would be (seq_len, 1, n_features, 1).

But: Since the flatten layer returns a tensor3 with shape (n, 1, 1) I think it would be better to have tensor3.depth() as the feature dimension because when we have a greater than 2-dim input to an RNN layer (e.g. from a Conv2D layer) the input has to be flattened to use in the RNN layer.
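A sketch of the layout under discussion, using nested Python lists indexed [depth][height][width] as a stand-in for `fdeep::tensor3`: a Keras-style input of shape (seq_len, n_features) becomes `seq_len` tensors, one per time step, each of shape (n_features, 1, 1), i.e. the shape a flatten layer produces. The helper names are hypothetical.

```python
def sequence_to_step_tensors(seq):
    """Keras-style (seq_len, n_features) -> seq_len tensors of shape (n_features, 1, 1).

    A tensor is a nested list indexed [depth][height][width].
    """
    return [[[[value]] for value in step] for step in seq]


def step_tensors_to_sequence(tensors):
    """Inverse: read each step's feature vector back out of the depth dimension."""
    return [[column[0][0] for column in tensor] for tensor in tensors]


def shape3(tensor):
    """(depth, height, width) of a nested-list tensor."""
    return (len(tensor), len(tensor[0]), len(tensor[0][0]))


seq = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # seq_len=2, n_features=3
tensors = sequence_to_step_tensors(seq)   # len(tensors) plays the role of seq_len
```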

@Dobiasd
Owner

Dobiasd commented Sep 5, 2018

Good catch! flatten (in Keras and in fdeep) only makes the depth dimension != 1. So if it's a normal use case in Keras to do flatten -> LSTM, how do the LSTM layers there work with this output if they expect their seq_len as "number of tensors" (instead of depth)?

Is it possible for us at all to "just" do it like Keras does, or is this in conflict with how we handle tensors in general?

(It might be that I don't understand the problem fully, so please correct me if needed. 🙂)

@Dobiasd
Owner

Dobiasd commented Nov 14, 2018

Right. Up to now we only forward the raw tensors. It probably would result in some modifications in get_layer_output/layer::get_output/node::get_output.

The basic idea of these functions is to not push the data through the model from front to end, but instead pull it out from the end. This "pull" then propagates through the computational graph up to the input layer(s).

One advantage of this is the following.

Consider we have such a graph (A is our only input layer, H is our only output layer):

         +-->C---->D
         |
A---->B--+                 +-->H
         |                 |
         +-->E---->F--->G--+

Pushing from A would also invoke C and D. But actually computing these is not needed. Pulling from H solves this issue. The calculations in C and D will not be executed.
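The pull approach can be illustrated with a few lines of Python (a sketch, not fdeep's actual get_output code): each node first pulls its inputs, results are memoized so a shared node is computed only once, and nodes the output does not depend on are never visited. The graph is the one from the diagram; `compute` is a placeholder for a layer's work.

```python
# Each node maps to the list of nodes it takes its input from.
graph = {
    "A": [], "B": ["A"],
    "C": ["B"], "D": ["C"],
    "E": ["B"], "F": ["E"], "G": ["F"],
    "H": ["G"],
}


def compute(node, inputs):
    """Placeholder layer: the input layer yields 1.0, others add 1 to the sum of inputs."""
    return 1.0 if not inputs else sum(inputs) + 1.0


def pull(node, cache, evaluated):
    """Compute `node` by first pulling its inputs; memoize results in `cache`."""
    if node not in cache:
        inputs = [pull(parent, cache, evaluated) for parent in graph[node]]
        cache[node] = compute(node, inputs)
        evaluated.append(node)
    return cache[node]


evaluated = []
result = pull("H", {}, evaluated)  # C and D are never computed
```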

@Dobiasd
Owner

Dobiasd commented Nov 22, 2018

Moved our conversation to a new issue. :)

@marco-monforte

Hi guys,
First and foremost, thanks for the great job you're doing!

For the project I'm working on I need stateful LSTM (and a way to reset the internal cell state). Is this implemented in any of your functions? Unfortunately, my knowledge of C++ is limited and I don't feel I can help with it.

@Dobiasd
Owner

Dobiasd commented Feb 18, 2019

@marco-monforte So you are looking for something like this in frugally-deep?

const auto result_1 = model.predict(some_input_tensors);
model.reset_some_hidden_LSTM_state();
const auto result_2 = model.predict(some_other_input_tensors);

@marco-monforte

@Dobiasd Yes! And if I input two consecutive tensors, the hidden LSTM state should be preserved from the first call to the second.

@Dobiasd
Owner

Dobiasd commented Feb 19, 2019

Currently something like that is not supported. fdeep::model::predict is even const, meaning it is not intended to change internal state. What you are trying to achieve can not be done with a TimeDistributed layer, right?

@marco-monforte

No, it's a different concept from the TimeDistributed layer.

When we run a prediction, the LSTM network output is built upon an internal state of the cells, obtained from the gates. What the model outputs, however, is just the output; the internal state is reset after each prediction. In some cases it's useful to keep this cell state intact so that the next prediction also depends on the previous one, but then we have to control the resetting of the memory manually. It is something used in particular cases, also because of the "danger" associated with this memory preservation.

Thanks anyway, the library is really great! Hopefully someone will add this feature soon :)

@keithchugg
Collaborator

keithchugg commented Sep 4, 2019

@Dobiasd - what @marco-monforte is asking for is very useful. A common case is that an RNN is trained with input: (batch_size, sequence_length, features_dim) and stateful=False, return_sequences=True. Then you can convert the trained model to another "streaming" Keras model that has input: (1, 1, features_dim) and stateful=True, return_sequences=False. This streaming model then runs on a sequence of indefinite length, remembering its past. If you want to start over, you reset the states explicitly.

Since in many cases the state is the previous output (e.g., SimpleRNN, GRU), perhaps a solution could be passing both the input and the state (the last output)? Sorry, I am not a C++ expert either...

@Dobiasd
Owner

Dobiasd commented Sep 4, 2019

Thanks for the explanation. Could you give a minimal Keras code example that does it?

@keithchugg
Collaborator

keithchugg commented Sep 4, 2019

Below is an example. The feature you added to the LSTM layer to set the state could be used to make this work: the state would have to be returned and then sent back as an input for the next step. It would be cleaner if the recurrent layers had "stateful" (remember state) and reset_states() options.

Thanks for frugally-deep!

import h5py
import keras
import numpy as np
from keras.layers import Input, Dense, GRU

##### generate toy data
train_seq_length = 4
feature_dim = 2
num_seqs = 8
x =  np.random.randint(0, high=2, size = (num_seqs * train_seq_length, feature_dim) )
x = np.sign( x - 0.5 )
y = np.sum( ( x == np.roll(x, 1, axis = 0) ), axis = 1 )
### y[n] = number of agreements between x[n], x[n-1]
x = x.reshape( (num_seqs, train_seq_length, feature_dim) )
y = y.reshape( (num_seqs, train_seq_length, 1) )


######  Define/Build/Train Training Model
training_in_shape = x.shape[1:]
training_in = Input(shape=training_in_shape)
# training_in = Input(batch_shape=(None,train_seq_length,feature_dim)) this works too
foo = GRU(4, return_sequences=True, stateful=False)(training_in)
training_pred = Dense(1)(foo)

training_model = keras.Model(inputs=training_in, outputs=training_pred)
training_model.compile(loss='mean_squared_error', optimizer='adam')
training_model.summary()

training_model.fit(x, y, batch_size=2, epochs=10)

##### define the streaming-inference model
streaming_in = Input(batch_shape=(1,1,feature_dim))  ## stateful ==> needs batch_shape specified
foo = GRU(4, return_sequences=False, stateful=True )(streaming_in)
streaming_pred = Dense(1)(foo)
streaming_model = keras.Model(inputs=streaming_in, outputs=streaming_pred)

streaming_model.compile(loss='mean_squared_error', optimizer='adam')
streaming_model.summary()

##### copy the weights from trained model to streaming-inference model
training_model.save_weights('weights.hd5', overwrite=True)
streaming_model.load_weights('weights.hd5')

##### demo the behavior
print('\n\n******the streaming-inference model can replicate the sequence-based trained model:\n')
for s in range(num_seqs):
    print(f'\n\nRunning Sequence {s} with STATE RESET:\n')
    in_seq = x[s].reshape( (1, train_seq_length, feature_dim) )
    seq_pred = training_model.predict(in_seq)
    seq_pred = seq_pred.reshape(train_seq_length)
    for n in range(train_seq_length):
        in_feature_vector = x[s][n].reshape(1,1,feature_dim)
        single_pred = streaming_model.predict(in_feature_vector)[0][0]
        print(f'Seq-model Prediction, Streaming-Model Prediction, difference [{n}]: {seq_pred[n] : 3.2f}, {single_pred : 3.2f}, {seq_pred[n] - single_pred: 3.2f}')
    streaming_model.reset_states()

print('\n\n******streaming-inference state needs reset between sequences to replicate sequence-based trained model:\n')
for s in range(num_seqs):
    print(f'\n\nRunning Sequence {s} with NO STATE RESET:\n')
    in_seq = x[s].reshape( (1, train_seq_length, feature_dim) )
    seq_pred = training_model.predict(in_seq)
    seq_pred = seq_pred.reshape(train_seq_length)
    for n in range(train_seq_length):
        in_feature_vector = x[s][n].reshape(1,1,feature_dim)
        single_pred = streaming_model.predict(in_feature_vector)[0][0]
        print(f'Seq-model Prediction, Streaming-Model Prediction, difference [{n}]: {seq_pred[n] : 3.2f}, {single_pred : 3.2f}, {seq_pred[n] - single_pred: 3.2f}')
    #### NO STATE RESET HERE: the streaming model treats multiple sequences as one long sequence,
    #### so after the first sequence the streaming output will differ; the difference decays over time as the effect of the initial state fades

for s in range(2):
    N = np.random.randint(1, 10)
    print(f'\n\n******streaming-inference can work on sequences of indefinite length -- running length {N}:\n')
    for n in range(N):
        x_sample =  np.random.randint(0, high=2, size = ( 1, 1, feature_dim) )
        x_sample = np.sign( x_sample - 0.5 )
        single_pred = streaming_model.predict(x_sample)[0][0]
        print(f'Streaming-Model Prediction[{n}]:  {single_pred : 3.2f}')
    streaming_model.reset_states()

@Dobiasd
Owner

Dobiasd commented Sep 5, 2019

Thanks a lot. I'm still in the process of trying to understand it, but you seem to already do. Also, you write good code. Would you be interested in trying to implement it in frugally-deep in a PR?

I'd prefer the cleaner solution you proposed. If I understand correctly, we would need to give up the const-ness of model::predict. But to still keep this class thread-safe, we could use thread_local storage for the mutating state. What do you think?

@keithchugg
Collaborator

keithchugg commented Sep 5, 2019

Regarding the code sample, the main point is that the training_model is trained using sequences and w/o statefulness, so each sequence is a separate training sample and starts from the zero-state. The streaming model is the same as the training model, except: (i) the input shape is just one time sample (sequence length 1) and (ii) it is stateful, meaning that each call picks up from where the last left off. You need to use sequences of a fixed length to train, but you may want to run the trained model on a time series of indefinite length and the streaming model does that.

I am happy to help with adding this feature to the recurrent layers, but I am just starting to familiarize myself with frugally-deep and am a C++ novice -- I don't really get all of the headerless and lambda stuff. Despite your nice comment, I am not much of a programmer. BTW, if you have any good resources to get up to speed on the C++ approaches used in fd, please share (I am starting from ANSI C and some C++).

I understand the recurrent layer functionality and math pretty well though, so can help on that front...

If the recurrent layer was a class in C++, then the feature could be implemented with (i) a private vector that is the state (in LSTMs, this is typically represented as two state vectors), (ii) a public function that would allow you to set/reset the state, and (iii) another private variable defining whether the layer is stateful (i.e., whether the state is reset to 0 at each call or retains the previous state).

I think what you are saying is that your current implementation does not allow for state in model::predict (for thread-safe reasons?) and that you could accomplish the vanilla-C++ approach above using the thread_local storage method. I don't really understand any of that.. ;-) -- but yeah, if that accomplishes the same thing with the benefits you mention, great!

BTW, in keras, the reset_states() for a model resets the states of all recurrent layers and the state of a given layer can be set to a specific value using:

my_model.layers[i].reset_states() # sets to zero
my_model.layers[i].reset_states(my_desired_state) # sets to my_desired_state
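A minimal sketch of the three-part design, written in Python rather than C++ and for a 1-unit SimpleRNN to keep it short: (i) the state is a private member, (ii) `reset_states` zeroes it or sets it to a given value, similar to the Keras calls above, and (iii) a `stateful` flag decides whether `apply` starts from zero or continues from the previous call. The class is hypothetical and not part of frugally-deep.

```python
import math


class StatefulSimpleRNN:
    """Hypothetical 1-unit SimpleRNN layer: h_t = tanh(w*x_t + u*h_{t-1} + b)."""

    def __init__(self, w, u, b, stateful=False):
        self._w, self._u, self._b = w, u, b
        self._stateful = stateful  # (iii) keep the state between calls?
        self.reset_states()

    def reset_states(self, state=0.0):
        """(ii) Set the state to zero, or to a chosen value."""
        self._state = state  # (i) the private state

    def apply(self, xs):
        if not self._stateful:
            self.reset_states()  # stateful=False: every call starts from zero
        outputs = []
        for x in xs:
            self._state = math.tanh(self._w * x + self._u * self._state + self._b)
            outputs.append(self._state)
        return outputs
```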

@Dobiasd
Owner

Dobiasd commented Sep 12, 2019

I am happy to help with adding this feature to the recurrent layers, but I am just starting to familiarize myself with frugally-deep and am a C++ novice -- I don't really get all of the headerless and lambda stuff. Despite your nice comment, I am not much of a programmer. BTW, if you have any good resources to get up to speed on the C++ approaches used in fd, please share (I am starting from ANSI C and some C++).

Ah, OK, understood.
With "headerless" I guess you mean "header-only". :)
Regarding the usage of FunctionalPlus in frugally-deep's code, the readme file of it might be a good starting point.


If the recurrent layer was a class in C++, then the feature could be implemented with (i) a private vector that is the state (in LSTMs, this is typically represented as two state vectors), (ii) a public function that would allow you to set/reset the state, and (iii) another private variable defining whether the layer is stateful (i.e., whether the state is reset to 0 at each call or retains the previous state).

BTW, in keras, the reset_states() for a model resets the states of all recurrent layers and the state of a given layer can be set to a specific value using:
my_model.layers[i].reset_states() # sets to zero
my_model.layers[i].reset_states(my_desired_state) # sets to my_desired_state

Yes, it is a class in C++.

Currently, the model class does not provide public access to the layers. Also, it only stores them as base-class (fdeep::layer) pointers, so it does not even know the type of any of the layers it consists of.

One way to give the user of fdeep::model a convenient way to set the LSTM's state might be: We add a reset_states member function to fdeep::layer in general. For stateless layers (i.e., everything except LSTM), this would just do nothing (or raise an exception if more appropriate). For the LSTM layer it would be implemented to do whatever it needs to do there. No new private is_stateful flag would be needed. The model itself would expose a new reset_states function, which might be better than exposing the layers themselves completely.

Would fdeep::model::reset_states();, i.e. resetting all LSTM states in the model, suffice for your use-case?
Or would you need a way to only reset specific layers, i.e., fdeep::model::reset_states(std::size_t layer_index);?
Or even fdeep::model::reset_states(std::size_t layer_index, const whatever_type& new_state); to set specific states instead of just empty everything?
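The proposed interface can be sketched in Python with toy layers (an illustration of the idea, not frugally-deep's classes): the base class has a no-op `reset_states`, only the stateful layer overrides it, and the model exposes `reset_states` for all layers or, optionally, for a single layer index, so the layers themselves stay private. `Accumulator` is a hypothetical stand-in for an LSTM whose state persists across `predict` calls.

```python
class Layer:
    """Base class: stateless layers inherit the no-op reset_states."""

    def apply(self, x):
        raise NotImplementedError

    def reset_states(self):
        pass


class Scale(Layer):
    """Stateless toy layer."""

    def __init__(self, factor):
        self.factor = factor

    def apply(self, x):
        return self.factor * x


class Accumulator(Layer):
    """Stateful toy layer (stand-in for an LSTM): the state survives predict calls."""

    def __init__(self):
        self.reset_states()

    def reset_states(self):
        self.total = 0.0

    def apply(self, x):
        self.total += x
        return self.total


class Model:
    def __init__(self, layers):
        self._layers = layers  # not exposed to the user

    def predict(self, x):
        for layer in self._layers:
            x = layer.apply(x)
        return x

    def reset_states(self, layer_index=None):
        """Reset all layers, or only the layer at layer_index."""
        targets = self._layers if layer_index is None else [self._layers[layer_index]]
        for layer in targets:
            layer.reset_states()


model = Model([Scale(2.0), Accumulator()])
```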


I think what you are saying is that your current implementation does not allow for state in model::predict (for thread-safe reasons?) and that you could accomplish the vanilla-C++ approach above using the thread_local storage method. I don't really understand any of that.. ;-) -- but yeah, if that accomplishes the same thing with the benefits you mention, great!

Yeah, basically the approach with the state variables in the LSTM-layer class would become threadsafe when we simply declare them as thread_local. :)
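Python's `threading.local` can serve as a stand-in to sketch the idea: one layer instance is shared by two threads, and each thread sees only its own state. One caveat for the C++ side: `thread_local` applies only to variables at namespace or block scope and to static data members, not to ordinary non-static members, so a C++ version would need something like a static thread_local container keyed by layer instance. The class below is a hypothetical toy, not fdeep code.

```python
import threading


class ThreadLocalAccumulator:
    """Toy stateful layer whose state is kept separately per thread."""

    def __init__(self):
        self._local = threading.local()  # attributes are separate per thread

    def apply(self, x):
        total = getattr(self._local, "total", 0.0) + x  # zero state on first use
        self._local.total = total
        return total

    def reset_states(self):
        self._local.total = 0.0  # resets only the calling thread's state


layer = ThreadLocalAccumulator()  # one instance shared by all threads
results = {}


def worker(name, value):
    for _ in range(1000):
        last = layer.apply(value)
    results[name] = last


threads = [threading.Thread(target=worker, args=("a", 1.0)),
           threading.Thread(target=worker, args=("b", 2.0))]
for t in threads:
    t.start()
for t in threads:
    t.join()
```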


Once we are clear about the questions from above, maybe we could share the work like the following:

  • I adjust the architecture, i.e., fdeep::model and fdeep::layer and add a reset_states interface.
  • You then could take care of adding the actual functionality to the LSTM layer. A pull request would be a good place for that. Adding test cases if possible (generate_test_models.py) would be nice too.
  • Once it's working (tests are green), we can take care of the thread-safety part.

@keithchugg
Collaborator

keithchugg commented Sep 12, 2019

Thanks. I am ok to write the tests in generate_test_models.py and to do the detailed layer implementation as needed (I think I can figure that out).

I am starting to understand your code a little bit better -- thanks for the pointers! I don't think that just having a reset_states() functionality is what we need. Essentially the default (stateful=False) behavior in keras and your code is to reset_state() at the start of each call to model::predict. What we need is a method to remember_states between model.predict() calls. Having a reset_states() function in addition to that would be good too.

My understanding is that it is not simple to add a private variable to the LSTM/GRU layers that would capture the state because the model::predict is const, meaning that it cannot set any variables in the classes/subclasses.

However, following your lead, can we make the entire model have: (i) a stateful bool variable that would apply to all layers and (ii) a reset_states() that would apply to all layers? Having granularity to the states of each layer individually would be good, but it is not needed to cover the case that started this conversation and I think it is not needed for 99% of applications.

I see that your lstm_impl() function takes in initial states (for h, c) and has a return_state option. This should be enough to implement the desired functionality if state variables can be held in the model class. We should add consistent functionality to all recurrent layers (looks like you just have GRU and LSTM, but we could add simpleRNN). Below is pseudo code for what I mean. I am not clear on how the state and sequence are stored in the return tensor for LSTM when return_state is true, so this is not precise. The basic idea is to pass and return the state for each layer when either the model is stateful or if the layer has return_state = True, but pass the state to the next layer (or output) only if return_state = True. In this approach, there are state variables in the model class, but not in the layer classes.

Not sure that this provides a path to a thread-safe implementation, because the new predict function needs to update the state variable of the model...

class model

private:

    stateful_                  // model property, initialize to false
    states_                    // per-layer property, initialize to all zeros
    recurrent_                 // per-layer property, initialize to all false
    return_states_internally_  // per-layer property, initialize to all false
    return_states_keras_       // per-layer property, initialize to all false

inline model read_model()
    for layer in keras_layers_list:
        if layer.name in ['GRU', 'LSTM']:
            recurrent_[layer.idx] = True
            if layer.stateful:
                stateful_ = True
                return_states_internally_[layer.idx] = True
            if layer.return_state:
                return_states_internally_[layer.idx] = True
                return_states_keras_[layer.idx] = True
...


tensor5s original_predict(const tensor5s& inputs) const
    activation_signal = inputs
    new_states = states_  // same shape as states_
    for l in layers:
        if recurrent_[l]:
            out_sig = l.apply(activation_signal, states_[l])
            if return_states_keras_[l] == False:
                // strip the states, pass only the sequence on
                activation_signal, new_states[l] = split_outputs(out_sig)
            else:
                // Keras expects the states as extra outputs, so pass everything on
                new_states[l] = get_state_from_outputs(out_sig)
                activation_signal = out_sig
        else:
            activation_signal = l.apply(activation_signal)
    return activation_signal, new_states

tensor5s predict(const tensor5s& inputs)
    if stateful_ == False:
        reset_states()
    y, states_ = original_predict(inputs)
    return y

public:

reset_states()
    zerofill(states_)

Below is some more detail about what is going on with the "stateful" setting in keras:


Think of an RNN (LSTM, GRU, etc.) as a black box that takes in one input time sample at time n: x[n]. The RNN is in state s[n] at time n. Then, the output at time n is y[n], and it is a function of both the state s[n] and the input x[n]. The next state is also a function of these two variables, so one time step of the RNN is:

y[n] = next_output(s[n], x[n])
s[n+1] = next_state(s[n], x[n])

In this sense, RNNs are always "stateful". However, the way keras and your GRU/LSTM code work, they take in a sequence of inputs {x[n]} for n=0...N-1. During one call with this input sequence, the steps are run by the above equations, but the initial state is set to 0. In keras this is stateful=False behavior.

Stateful=True behavior in keras is that the initial state is remembered from the last call. Suppose you have x[0]...x[199] and your RNN takes in length-100 sequences. If you make two calls to RNN.predict, first using x[0:100] and then x[100:200], then with keras stateful=True, the second call has its initial state set to s[100], i.e., the final state from the first call. If stateful=False, then the second call starts with a zero state value.

The use case I highlighted is just using a sequence length of 1 with stateful=True, which is useful in practice.

So, your current implementation effectively runs a "reset_states()" at the start of each call.
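To make the stateful/stateless distinction concrete, here is a minimal C++ sketch using a toy scalar RNN. The struct `ToyRnn`, its weights, and the tanh update are illustrative assumptions, not the LSTM/GRU equations; it only follows y[n] = next_output(s[n], x[n]) and s[n+1] = next_state(s[n], x[n]), with both functions collapsed into one tanh step:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy scalar RNN, used only to illustrate stateful vs. stateless calls.
// Weights and the tanh update are arbitrary assumptions, not LSTM/GRU math.
struct ToyRnn {
    float w_state = 0.5f;
    float w_input = 1.0f;
    float state = 0.0f;  // s[n], carried across calls only when stateful

    // One time step: y[n] = tanh(w_state * s[n] + w_input * x[n]), s[n+1] = y[n].
    float step(float x) {
        const float y = std::tanh(w_state * state + w_input * x);
        state = y;
        return y;
    }

    // Process a whole sequence. stateful == false mimics the Keras default:
    // the state is reset to zero before the call.
    std::vector<float> predict(const std::vector<float>& xs, bool stateful) {
        if (!stateful) {
            state = 0.0f;
        }
        std::vector<float> ys;
        ys.reserve(xs.size());
        for (float x : xs) {
            ys.push_back(step(x));
        }
        return ys;
    }
};
```

Feeding x[0:100] and then x[100:200] with `stateful == true` reproduces the last 100 outputs of a single 200-step call, while `stateful == false` restarts the second call from s = 0.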


@Dobiasd
Owner

Dobiasd commented Sep 13, 2019

I am starting to understand your code a little bit better -- thanks for the pointers! I don't think that just having a reset_states() functionality is what we need. Essentially the default (stateful=False) behavior in keras and your code is to reset_state() at the start of each call to model::predict. What we need is a method to remember_states between model.predict() calls. Having a reset_states() function in addition to that would be good too.

Yeah, I understand that. But that's an implementation detail of an individual layer class. The architecture is not concerned with that, except:

My understanding is that it is not simple to add a private variable to the LSTM/GRU layers that would capture the state because the model::predict is const, meaning that it cannot set any variables in the classes/subclasses.

Yes, we would drop the const property of model::predict or declare the state-holding member variables as mutable.
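A minimal sketch of the `mutable` option, using a hypothetical `stateful_toy_layer` class (not frugally-deep's real layer) with a scalar standing in for the h/c state tensors. The const `apply` can still update the cached state:

```cpp
#include <cassert>

// Hypothetical, simplified stand-in for a recurrent layer (not frugally-deep's class).
// A scalar replaces the real h/c state tensors.
class stateful_toy_layer {
public:
    explicit stateful_toy_layer(bool stateful) : stateful_(stateful) {}

    // Stays const (like model::predict), yet may update state_ because it is mutable.
    float apply(float x) const {
        if (!stateful_) {
            state_ = 0.0f;  // Keras default: every call starts from a zero state
        }
        state_ = 0.5f * state_ + x;  // arbitrary toy recurrence
        return state_;
    }

    void reset_states() { state_ = 0.0f; }

private:
    bool stateful_;
    mutable float state_ = 0.0f;
};
```

Keeping the member `mutable` leaves the existing const API untouched; the trade-off is that a const object is no longer free of side effects, which matters for the thread-safety question further down.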

However, following your lead, can we make the entire model have: (i) a stateful bool variable that would apply to all layers and (ii) a reset_states() that would apply to all layers?

I don't yet see why we need this bool. The base class layer could have a reset_states() function, and each layer decides on its own whether to override it with an implementation that actually does something. model::reset_states() would just call layer::reset_states() for every layer it has, and only a few of those calls might actually do something. The details of what is happening (if anything is happening) would be hidden in the classical OOP style.
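The OOP layout described above could look roughly like this. All class names are hypothetical simplifications, and a scalar stands in for the real state tensors. The point is only that the base-class `reset_states()` is a no-op, recurrent layers override it, and a model (itself a layer) forwards the call to all its layers, which also covers nested models:

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical class names; mirrors the design idea, not frugally-deep's actual code.
class layer {
public:
    virtual ~layer() = default;
    // Default: stateless layers have nothing to reset.
    virtual void reset_states() {}
};

class toy_lstm_layer : public layer {
public:
    float state = 0.0f;  // stands in for the real h and c tensors
    void reset_states() override { state = 0.0f; }
};

// A model is itself a layer, so nested models are reset recursively.
class model : public layer {
public:
    explicit model(std::vector<std::shared_ptr<layer>> layers)
        : layers_(std::move(layers)) {}

    void reset_states() override {
        for (const auto& l : layers_) {
            l->reset_states();  // no-op for stateless layers
        }
    }

private:
    std::vector<std::shared_ptr<layer>> layers_;
};
```

Only the recurrent layers carry any logic; every other layer inherits the empty default, so the public surface stays at a single `model::reset_states()`.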

Having granularity to the states of each layer individually would be good, but it is not needed to cover the case that started this conversation and I think it is not needed for 99% of applications.

Cool, having only model::reset_states() and no model::reset_states(std::size_t layer_index) or model::reset_states(std::size_t layer_index, const whatever_type& new_state) does make things simpler. :)

I am not clear on how the state and sequence are stored in the return tensor for LSTM when return_state is true

If return_state is true, it just returns additional tensors (see here), the same as Keras does.

Not sure that this provides a path to a thread-safe implementation, because the new predict function needs to update the state variable of the model...

We just make these stateful private variables thread_local and we should be fine. :)
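One caveat worth noting: C++ only allows `thread_local` on namespace-scope, block-scope, and static member variables, not on ordinary (non-static) data members. A minimal sketch of one way to get per-thread layer state anyway, assuming a hypothetical `toy_recurrent_layer` with a scalar state, is a block-scope `thread_local` map keyed by the layer's address:

```cpp
#include <cassert>
#include <thread>
#include <unordered_map>

// Hypothetical sketch. Non-static data members cannot be thread_local, so each
// thread keeps its own map from layer address to that layer's state.
class toy_recurrent_layer {
public:
    float apply(float x) const {
        float& s = state();
        s = 0.5f * s + x;  // arbitrary toy recurrence
        return s;
    }

    void reset_states() const { state() = 0.0f; }

private:
    float& state() const {
        // One map per thread; operator[] value-initializes a missing entry to 0.0f.
        thread_local std::unordered_map<const toy_recurrent_layer*, float> states;
        return states[this];
    }
};
```

Each thread starts from a zero state and only sees its own updates. A downside of this particular sketch is that map entries are never erased, so a new layer that happens to reuse a destroyed layer's address on the same thread would inherit its stale state.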

Stateful=True behavior in keras is that the initial state is remembered from the last call.
So, your current implementation effectively runs a "reset_states()" at the start of each call.

Yes, I understand now. :)


Maybe I'm missing something, but right now I don't see an actual problem/blocker here. I think we could simply start implementing it.

Suggestion: I open a new branch called stateful. I commit the needed changes in model and layer there and push them. You check out this branch, start working on the internal layer stuff, and regularly push to a pull request, so I can review the code. What do you think?

@keithchugg
Collaborator

I don't yet see why we need this bool. The base class layer could have a reset_states() function, and each layer decides on its own whether to override it with an implementation that actually does something. model::reset_states() would just call layer::reset_states() for every layer it has, and only a few of those calls might actually do something. The details of what is happening (if anything is happening) would be hidden in the classical OOP style.

Yes, this is the simplest. Just keep the states in a private variable in each (recurrent) layer and then, if stateful=False, call reset_states(). reset_states() can also be available at the model level. This is perfect and the most simple and direct way to do it. I just did not understand how big of a deal it would be to have state in the layers.

We just make these stateful private variables thread_local and we should be fine. :)

great!

Suggestion: I open a new branch called stateful. I commit the needed changes in model and layer there and push them. You check out this branch, start working on the internal layer stuff, and regularly push to a pull request, so I can review the code. What do you think?

This sounds good. The changes should be small in the layers -- we can do LSTM and GRU.
One question: you have initial states passed to the LSTM layer. I am not sure how these get set from the model object. It could be cleaner to have reset_states() take optional arguments that are the state values to be set.

@Dobiasd
Owner

Dobiasd commented Sep 13, 2019

I just did not understand how big of a deal it would be to have state in the layers.
I think it's no big deal at all if we manage to encapsulate it nicely.

One question: you have initial states passed to the LSTM layer. I am not sure how these get set from the model object.

The same as Keras does. If an LSTM layer is invoked with only one input tensor, that means no initial states. If there are 3 input tensors, then the last two represent the initial state. This happens here in the code.
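The input-count convention can be sketched as follows. `toy_tensor`, `lstm_initial_state`, and `initial_state_from_inputs` are hypothetical names, with a flat `std::vector<float>` standing in for a real tensor:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a tensor: a flat vector of floats.
using toy_tensor = std::vector<float>;

struct lstm_initial_state {
    toy_tensor h;
    toy_tensor c;
};

// Keras-style convention: one input means "no initial state" (zeros),
// three inputs mean [sequence, initial_h, initial_c].
lstm_initial_state initial_state_from_inputs(const std::vector<toy_tensor>& inputs,
                                             std::size_t units) {
    if (inputs.size() == 1) {
        return {toy_tensor(units, 0.0f), toy_tensor(units, 0.0f)};
    }
    if (inputs.size() == 3) {
        return {inputs[1], inputs[2]};
    }
    throw std::invalid_argument("LSTM expects 1 or 3 input tensors");
}
```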

It could be cleaner to have reset_states() take optional arguments that are the state values to be set.

I thought "99% of applications" don't require that, and simply setting to all zeros is enough. Or did I misunderstand that?


Maybe we just start with that simple approach and then see how it goes.

I suggest the following game plan.

Tobias:

  • Open feature branch stateful.
  • Add layer::reset_states() (dummy implementation that does nothing).
  • Add model::reset_states() (calls reset_states on every layer, also recursively for model_layers).
  • Call model::reset_states() after loading the model (to remove test-case state).
  • Allow stateful == true in convert_model.py
  • Pass stateful flag in import_model.hpp to concerned layer constructors.
  • Add an example override of this function to lstm_layer and gru_layer.
  • Add some todo comments to the layer.
  • Commit and push.

Keith:

  • Pull/checkout stateful branch
  • Add new model test_model_lstm_stateful to generate_test_models.py.
  • Fork, commit, push, and open (WIP) pull request (so Tobias can review).
  • Add the new test model to CMakeLists.txt.
  • Add a new cpp file for it in the tests directory.
  • Add the new test model to applications_performance.cpp (for possible local debugging).
  • Implement statefulness in the LSTM layer.
    • Use a mutable private member variable for this, so we can leave model::predict const for now, i.e. not break the existing API.
  • Commit and push not only when completely done but also regularly in between.

@Dobiasd
Owner

Dobiasd commented Sep 13, 2019

OK, done. I just created the new branch and pushed a commit with the architectural skeleton.

Let me know if you feel blocked at some point or something is missing, etc. 🙂

@keithchugg
Collaborator

OK. Thanks. I will start on it...

@keithchugg
Collaborator

It could be cleaner to have reset_states() take optional arguments that are the state values to be set.

I thought "99% of applications" don't require that, and simply setting to all zeros is enough. Or did I misunderstand that?

Yes, I don't have a use for setting the states to a specific value; I just thought that this functionality and reset_states() are so similar that you may want to combine them.

The plan you posted is good. I may be a little slow on the uptake, but I will work on it and let you know...

@Dobiasd
Owner

Dobiasd commented Sep 14, 2019

Yes, I don't have a use for setting the states to a specific value; I just thought that this functionality and reset_states() are so similar that you may want to combine them.

Similar, but, maybe surprisingly, way more complex. Having something like model::reset_states(std::size_t layer_index, const tensor5s& new_state); is only straightforward for simple sequential models, for two reasons:

  • Models can be nested.
  • In one model (outer or one of the nested ones), the computational graph can have parallel branches.

Currently, it's not clear to me how we could come up with meaningful indexing. We might need to expose the whole model architecture (graph + nesting) to the public API so that all layers can be traversed or found by some other means in order to set the state of a particular layer. If possible, I'd like to avoid increasing the "surface" of the library that drastically.

@Dobiasd
Copy link
Owner

Dobiasd commented Oct 29, 2019

Thanks to the awesome work of @keithchugg, stateful models are now fully supported with the latest release, i.e. v0.10.0-p0. 🎉

@marco-monforte

Hey guys! I've been finally able to test the stateful LSTM in my C++ code and it works amazingly! Thank you very much! Very appreciated 😃

@marco-monforte

marco-monforte commented Oct 15, 2020

Hey @Dobiasd ! I'm sorry to come back to this issue again, but I'm working on my old code with LSTMs after a long time and I'm not able to adapt these two lines of code to the new tensor definitions:

fdeep::tensor5s result = model_stateful->predict_stateful({fdeep::tensor5(fdeep::shape5(1, 1, 1, 1, 3), {inputX, inputY, inputZ})});
std::vector<float> vec = *result.front().as_vector();

Could you help me?

@Dobiasd
Owner

Dobiasd commented Oct 15, 2020

I'm afk right now. What error message do you get?
Have you tried removing the *?

@marco-monforte

I'm getting that 'tensor5' in the first line is not a member of fdeep. The issues are apparently there, in the first line. I don't know if I should use 'tensor' or 'tensors' as input, and especially how to replace shape5, which doesn't exist anymore in the library.

Removing the * doesn't help

@Dobiasd
Owner

Dobiasd commented Oct 15, 2020

@marco-monforte

Thanks! Apparently, I've been able to compile by doing the following:

fdeep::tensor t(fdeep::tensor_shape(1,1,1,1,3),0);
t.set(fdeep::tensor_pos(0,0,0,0,0), x);
t.set(fdeep::tensor_pos(0,0,0,0,1), y);
t.set(fdeep::tensor_pos(0,0,0,0,2), z);
auto result = model_stateful->predict_stateful({t});
std::vector<float> vec = result[0].to_vector();

but then I get this error at runtime:

terminate called after throwing an instance of 'std::runtime_error'
  what():  Invalid inputs shape.
The model takes [2(Nothing, Just 3)] but provided was: [5(1, 1, 1, 1, 3)]

I don't get the meaning of [2(Nothing, Just 3)], while the second pair of brackets should be right, given my t tensor.

@Dobiasd
Owner

Dobiasd commented Oct 15, 2020

Since your problem is not related to LSTM layers, I guess this is not the right place to discuss it. Also, it might spam the other participants of this thread with notifications that are of no interest to them.

Basically, you're not providing the right input shape for your model. Try to only provide sizes for the dimensions that are actually used, i.e., the last two.
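The error message seems to use a Haskell-style `Maybe` notation: `2(...)` is the rank, `Nothing` means the size of that dimension is not fixed (Keras' `None`, here presumably the number of time steps), and `Just 3` means exactly 3. Under that reading, the check behaves roughly like the hypothetical `shape_matches` function below (not frugally-deep's actual code), which is why a rank-5 tensor is rejected while a shape like (1, 3) would pass:

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical illustration of the shape check. An expected dimension of
// std::nullopt plays the role of "Nothing" (any size); a value plays "Just n".
using expected_shape = std::vector<std::optional<std::size_t>>;
using provided_shape = std::vector<std::size_t>;

bool shape_matches(const expected_shape& expected, const provided_shape& provided) {
    if (expected.size() != provided.size()) {
        return false;  // rank mismatch, e.g. rank 2 expected vs. rank 5 provided
    }
    for (std::size_t i = 0; i < expected.size(); ++i) {
        // A fixed expected size must match exactly; "Nothing" accepts anything.
        if (expected[i] && *expected[i] != provided[i]) {
            return false;
        }
    }
    return true;
}
```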

In case of further problems, please open a separate issue.


10 participants