
Tensorflow 1 and 2 Step Savers, And Base Classes #5

Merged on Jan 16, 2020 (30 commits)

Changes from 21 commits

Commits (30)
bb20f6b  Create Setup.py, And Folder Structure  (alexbrillant, Dec 2, 2019)
6c943dd  Add TensorflowModelWrapperMixin  (alexbrillant, Dec 2, 2019)
8c06467  Move Usage Example Link In ReadMe  (alexbrillant, Dec 2, 2019)
fd5b1fa  Add Docstrings To Tensorflow Model V1 Wrapper Mixin  (alexbrillant, Dec 2, 2019)
d5ccc5d  Add Tensorflow V2 Step Saver  (alexbrillant, Dec 30, 2019)
d623bdb  Add New Line In file Headers before copyright notice  (alexbrillant, Dec 30, 2019)
8337d09  Simplify Tensorflow 1 Model And Saver Wip  (alexbrillant, Dec 30, 2019)
536b4dd  Simplify Tensorflow 2 Models, And Saver  (alexbrillant, Dec 31, 2019)
84680d6  Simplify Tensorflow 2 Models, And Saver  (alexbrillant, Dec 31, 2019)
0efda6b  Add Docstrings To Base Tensorflow Steps  (alexbrillant, Dec 31, 2019)
bb41bee  Fix Tensorflow 1 And Tensorflow 2 Test / Examples  (alexbrillant, Dec 31, 2019)
a98b150  Apply Review Comments #5  (alexbrillant, Jan 6, 2020)
c6baca8  Apply Review Comments Pr #5  (alexbrillant, Jan 6, 2020)
333945d  Apply Review Comments #5  (alexbrillant, Jan 6, 2020)
d68b876  Fix README  (alexbrillant, Jan 6, 2020)
1d72294  Fix Tensorflow V1 Model Graph Argument And Loss  (alexbrillant, Jan 7, 2020)
16846bf  Add the possibility to return 2 outputs on the create graph function …  (alexbrillant, Jan 8, 2020)
d562cc9  Save Loss array in Tensorflow v1 Model step, And Add Print Loss func  (alexbrillant, Jan 11, 2020)
2333f12  Apply Review Comments #5  (alexbrillant, Jan 13, 2020)
6d4f34e  Apply Review Comments #5  (alexbrillant, Jan 13, 2020)
50bcd7d  Fix tensorflow v1 create graph return in test  (alexbrillant, Jan 13, 2020)
40f156d  Add create_inputs to tensorflow models, and add data_inputs, expected…  (alexbrillant, Jan 14, 2020)
ffbf749  Add Train Loss, And Test Loss To BaseTensorflow Model  (alexbrillant, Jan 14, 2020)
1c68f89  Remove Dead Code  (alexbrillant, Jan 14, 2020)
7824c5d  Fix Test Losses  (alexbrillant, Jan 14, 2020)
7d26a91  Fix Transform Data Container For Tensorflow 2  (alexbrillant, Jan 14, 2020)
7db0771  Fix Tensorflow 2 Transform  (alexbrillant, Jan 14, 2020)
4864bb9  Fix Transform Data Container In Tensorflow 2  (alexbrillant, Jan 14, 2020)
d41e46e  Fit Transform Data Container In Tensorflow 2  (alexbrillant, Jan 14, 2020)
1cbbdd1  Add device support in tensorflow 2 (GPU, CPU, etc.)  (alexbrillant, Jan 15, 2020)
3 changes: 3 additions & 0 deletions .gitignore
@@ -102,3 +102,6 @@ venv.bak/

# mypy
.mypy_cache/

# custom
.idea/*
74 changes: 73 additions & 1 deletion README.md
@@ -1,2 +1,74 @@
# Neuraxle-TensorFlow
TensorFlow steps, savers, and utilities for Neuraxle.

TensorFlow steps, savers, and utilities for [Neuraxle](https://github.com/Neuraxio/Neuraxle).

Neuraxle is a Machine Learning (ML) library for building neat pipelines, providing the right abstractions to ease research, development, and deployment of your ML applications.

## Usage example

[See also a complete example](https://github.com/Neuraxio/LSTM-Human-Activity-Recognition/blob/neuraxle-refactor/steps/lstm_rnn_tensorflow_model_wrapper.py)

### Tensorflow 1

Create a tensorflow 1 model step by giving it a graph, an optimizer, and a loss function.

```python
def create_graph(step: TensorflowV1ModelStep):
    tf.placeholder('float', name='data_inputs')
    tf.placeholder('float', name='expected_outputs')

    tf.Variable(np.random.rand(), name='weight')
    tf.Variable(np.random.rand(), name='bias')

    return tf.add(tf.multiply(step['data_inputs'], step['weight']), step['bias'])

"""
# Note: you can also return a tuple containing two elements : tensor for training (fit), tensor for inference (transform)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was indentation lost during auto-format? where should this comment go?

def create_graph(step: TensorflowV1ModelStep)
# ...
decoder_outputs_training = create_training_decoder(step, encoder_state, decoder_cell)
decoder_outputs_inference = create_inference_decoder(step, encoder_state, decoder_cell)

return decoder_outputs_training, decoder_outputs_inference
"""


def create_loss(step: TensorflowV1ModelStep):
    return tf.reduce_sum(tf.pow(step['output'] - step['expected_outputs'], 2)) / (2 * N_SAMPLES)

def create_optimizer(step: TensorflowV1ModelStep):
    return tf.train.GradientDescentOptimizer(step.hyperparams['learning_rate'])

model_step = TensorflowV1ModelStep(
    create_graph=create_graph,
    create_loss=create_loss,
    create_optimizer=create_optimizer,
    has_expected_outputs=True
).set_hyperparams(HyperparameterSamples({
    'learning_rate': 0.01
})).set_hyperparams_space(HyperparameterSpace({
    'learning_rate': LogUniform(0.0001, 0.01)
}))
```
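
A minimal usage sketch, assuming the definitions above plus `import numpy as np` and `N_SAMPLES = 100` (which `create_loss` relies on but the snippet does not show); `setup()` is called explicitly here to build the graph and session before fitting:

```python
import numpy as np

N_SAMPLES = 100
data_inputs = np.random.rand(N_SAMPLES).astype(np.float32)
expected_outputs = 2.0 * data_inputs + 1.0  # toy linear relation to regress

model_step.setup()  # builds the placeholders, variables, loss and optimizer in a fresh session
for _ in range(10):
    model_step = model_step.fit(data_inputs, expected_outputs)

predicted_outputs = model_step.transform(data_inputs)
```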

### Tensorflow 2

Create a tensorflow 2 model step by giving it a model, an optimizer, and a loss function.

```python
def create_model(step: Tensorflow2ModelStep):
    return LinearModel()

def create_optimizer(step: Tensorflow2ModelStep):
    return tf.keras.optimizers.Adam(0.1)

def create_loss(step: Tensorflow2ModelStep, expected_outputs, predicted_outputs):
    return tf.reduce_mean(tf.abs(predicted_outputs - expected_outputs))

model_step = Tensorflow2ModelStep(
    create_model=create_model,
    create_optimizer=create_optimizer,
    create_loss=create_loss,
    tf_model_checkpoint_folder=os.path.join(tmpdir, 'tf_checkpoints')
)
```
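
The `LinearModel` used in `create_model` is not defined in this README. A minimal sketch of what such a Keras model could look like is shown below; the single `Dense(1)` layer is an illustrative assumption, not part of this PR:

```python
import tensorflow as tf


class LinearModel(tf.keras.Model):
    """Toy linear regression model: y = W * x + b."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)  # one weight and one bias

    def call(self, inputs):
        return self.dense(inputs)
```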
1 change: 1 addition & 0 deletions neuraxle_tensorflow/__init__.py
@@ -0,0 +1 @@
__version__ = "0.0.1"
17 changes: 17 additions & 0 deletions neuraxle_tensorflow/tensorflow.py
@@ -0,0 +1,17 @@
from neuraxle.base import BaseStep


class BaseTensorflowModelStep(BaseStep):
    def __init__(self, create_model, create_loss, create_optimizer, step_saver):
        self.create_model = create_model
        self.create_loss = create_loss
        self.create_optimizer = create_optimizer

        self.set_hyperparams(self.__class__.HYPERPARAMS)
        self.set_hyperparams_space(self.__class__.HYPERPARAMS_SPACE)

        BaseStep.__init__(
            self,
            savers=[step_saver],
            hyperparams=self.HYPERPARAMS
        )
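
Concrete steps are expected to define the `HYPERPARAMS` and `HYPERPARAMS_SPACE` class attributes that this constructor reads, and to pass in their own step saver. A minimal sketch of a subclass, modeled on the `TensorflowV1ModelStep` constructor further down in this diff (the class name and hyperparameter values are illustrative only):

```python
from neuraxle.hyperparams.space import HyperparameterSamples, HyperparameterSpace


class MyTensorflowStep(BaseTensorflowModelStep):
    # Class-level defaults read by BaseTensorflowModelStep.__init__ through self.__class__.
    HYPERPARAMS = HyperparameterSamples({'learning_rate': 0.01})
    HYPERPARAMS_SPACE = HyperparameterSpace({})

    def __init__(self, create_model, create_loss, create_optimizer, step_saver):
        BaseTensorflowModelStep.__init__(
            self,
            create_model=create_model,
            create_loss=create_loss,
            create_optimizer=create_optimizer,
            step_saver=step_saver
        )
```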
289 changes: 289 additions & 0 deletions neuraxle_tensorflow/tensorflow_v1.py
@@ -0,0 +1,289 @@
"""
Neuraxle Tensorflow V1 Utility classes
=========================================
Neuraxle utility classes for tensorflow v1.

..
    Copyright 2019, Neuraxio Inc.
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at
        http://www.apache.org/licenses/LICENSE-2.0
    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.

"""
import os

import tensorflow as tf
from neuraxle.base import BaseSaver, BaseStep, ExecutionContext
from neuraxle.hyperparams.space import HyperparameterSamples, HyperparameterSpace

from neuraxle_tensorflow.tensorflow import BaseTensorflowModelStep


class TensorflowV1ModelStep(BaseTensorflowModelStep):
"""
Base class for tensorflow 1 steps.
It uses :class:`TensorflowV1StepSaver` for saving the model.

.. seealso::
`Using the saved model format <https://www.tensorflow.org/guide/checkpoint>`_,
:class:`~neuraxle.base.BaseStep`
"""
HYPERPARAMS = HyperparameterSamples({})
HYPERPARAMS_SPACE = HyperparameterSpace({})

    def __init__(
            self,
            create_graph,
            create_loss,
            create_optimizer,
            create_feed_dict=None,
            variable_scope=None,
            has_expected_outputs=True,
            print_loss=False,
            print_func=None
    ):
        BaseTensorflowModelStep.__init__(
            self,
            create_model=create_graph,
            create_loss=create_loss,
            create_optimizer=create_optimizer,
            step_saver=TensorflowV1StepSaver()
        )

        if variable_scope is None:
            variable_scope = self.name
        self.variable_scope = variable_scope
        self.has_expected_outputs = has_expected_outputs
        self.create_feed_dict = create_feed_dict
        self.losses = []
        self.print_loss = print_loss
        if print_func is None:
            print_func = print
        self.print_func = print_func

    def setup(self) -> BaseStep:
        """
        Set up the tensorflow 1 graph and session using the variable scope.

        :return: self
        :rtype: BaseStep
        """
        if self.is_initialized:
            return self

        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.variable_scope(self.variable_scope, reuse=tf.AUTO_REUSE):
                self.session = tf.Session(config=tf.ConfigProto(log_device_placement=True), graph=self.graph)

                model = self.create_model(self)
                if not isinstance(model, tuple):
                    tf.identity(model, name='output')
                else:
                    tf.identity(model[0], name='output')
                    tf.identity(model[1], name='inference_output')

                tf.identity(self.create_loss(self), name='loss')
                self.create_optimizer(self).minimize(self['loss'], name='optimizer')

                init = tf.global_variables_initializer()
                self.session.run(init)
        self.is_initialized = True
        return self

    def teardown(self) -> BaseStep:
        """
        Close the tensorflow session on teardown.

        :return: self
        """
        if self.session is not None:
            self.session.close()
        self.is_initialized = False

        return self

    def strip(self):
        """
        Strip the tensorflow 1 properties from the step to make it serializable.

        :return: stripped step
        :rtype: BaseStep
        """
        self.graph = None
        self.session = None

        return self

    def fit(self, data_inputs, expected_outputs=None) -> 'BaseStep':
        with tf.variable_scope(self.variable_scope, reuse=tf.AUTO_REUSE):
            return self.fit_model(data_inputs, expected_outputs)

    def fit_model(self, data_inputs, expected_outputs=None) -> BaseStep:
        """
        Fit tensorflow model using the variable scope.

        :param data_inputs: data inputs
        :param expected_outputs: expected outputs to fit on
        :return: fitted self
        :rtype: BaseStep
        """
        feed_dict = {
            self['data_inputs']: data_inputs
        }

        if self.has_expected_outputs:
            feed_dict.update({
                self['expected_outputs']: expected_outputs
            })

        if self.create_feed_dict is not None:
            additional_feed_dict_arguments = self.create_feed_dict(self, data_inputs, expected_outputs)
            feed_dict.update(additional_feed_dict_arguments)

        results = self.session.run([self['optimizer'], self['loss']], feed_dict=feed_dict)
        self.losses.append(results[1])

Review comment (inline in this PR): Sorry for the double comment here. I think this could even be train_losses or test_losses, or two variables for those losses if we compute both at some point. It's okay to not rename for now, although this will need to be made clear when doing a documentation pass.

        if self.print_loss:
            self.print_func(self.losses[-1])

        return self
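
For reference, the optional `create_feed_dict` callback used above is expected to return extra `{placeholder: value}` entries that get merged into the feed dict built by `fit_model`. A minimal sketch; the `keep_probability` placeholder is a hypothetical example and would have to be created in `create_graph`:

```python
def create_feed_dict(step: TensorflowV1ModelStep, data_inputs, expected_outputs):
    # Return additional feed_dict entries to merge into the one built by fit_model().
    return {
        step['keep_probability']: 0.75  # hypothetical dropout placeholder defined in create_graph
    }
```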

    def transform(self, data_inputs, expected_outputs=None) -> 'BaseStep':
        with tf.variable_scope(self.variable_scope, reuse=tf.AUTO_REUSE):
            return self.transform_model(data_inputs)

    def transform_model(self, data_inputs):
        """
        Transform the data inputs using the tensorflow model and the variable scope.

        :param data_inputs: data inputs to transform
        :return: predicted outputs
        """
        inference_output_name = self._get_inference_output_name()

        feed_dict = {
            self['data_inputs']: data_inputs
        }

        results = self.session.run(self[inference_output_name], feed_dict=feed_dict)

        return results

    def _get_inference_output_name(self):
        """
        Return the output tensor name to use for inference (transform).
        In create_graph, the user can return a tuple of two elements: the output tensor for training, and the output tensor for inference.

        :return: name of the output tensor to run at inference time
        """
        inference_output_name = 'output'
        if len(self['inference_output'].get_shape().as_list()) > 0:
            inference_output_name = 'inference_output'

        return inference_output_name

    def __getitem__(self, item):
"""
Get a graph tensor by name using get item.

:param item: tensor name
:type item: str

:return: tensor
:rtype: tf.Tensor
"""
if ":" in item:
split = item.split(":")
tensor_name = split[0]
device = split[1]
        else:
            tensor_name = item
            device = "0"

        try:
            result = self.graph.get_tensor_by_name("{0}/{1}:{2}".format(self.variable_scope, tensor_name, device))
        except KeyError:
            result = None

        if result is None:
            try:
                result = self.graph.get_operation_by_name("{0}/{1}".format(self.variable_scope, tensor_name))
            except KeyError:
                result = tf.get_variable(tensor_name, [])

        return result
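
For illustration, this is how the lookup above is used throughout this file and in the README example; the scope prefix is whatever `variable_scope` was given to the constructor (the step name by default):

```python
# Given a TensorflowV1ModelStep instance named `step` (e.g. inside create_loss or fit_model):
loss_tensor = step['loss']                 # resolves "<variable_scope>/loss:0" in the graph
inputs_placeholder = step['data_inputs']   # resolves "<variable_scope>/data_inputs:0"
first_output = step['output:0']            # the suffix after ":" (named device above) defaults to "0"
```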


class TensorflowV1StepSaver(BaseSaver):
"""
Step saver for a tensorflow Session using tf.train.Saver().
It saves, or restores the tf.Session() checkpoint at the context path using the step name as file name.

.. seealso::
`Using the saved model format <https://www.tensorflow.org/guide/saved_model>`_,
:class:`~neuraxle.base.BaseSaver`
"""

    def save_step(self, step: 'TensorflowV1ModelStep', context: 'ExecutionContext') -> 'BaseStep':
        """
        Save a step that is using tf.train.Saver().
        :param step: step to save
        :type step: BaseStep
        :param context: execution context to save from
        :type context: ExecutionContext
        :return: saved step
        """
        with step.graph.as_default():
            saver = tf.train.Saver()
            saver.save(step.session, self._get_saved_model_path(context, step))
            step.strip()

        return step

    def load_step(self, step: 'TensorflowV1ModelStep', context: 'ExecutionContext') -> 'BaseStep':
        """
        Load a step that is using tensorflow using tf.train.Saver().
        :param step: step to load
        :type step: BaseStep
        :param context: execution context to load from
        :type context: ExecutionContext
        :return: loaded step
        """
        step.is_initialized = False
        step.setup()

        with step.graph.as_default():
            saver = tf.train.Saver()
            saver.restore(step.session, self._get_saved_model_path(context, step))

        return step

    def can_load(self, step: 'TensorflowV1ModelStep', context: 'ExecutionContext'):
        """
        Returns whether or not the saved checkpoint can be loaded.
        :param step: step to load
        :type step: BaseStep
        :param context: execution context to load from
        :type context: ExecutionContext
        :return: whether or not the checkpoint files exist
        """
        meta_exists = os.path.exists(os.path.join(context.get_path(), "{0}.ckpt.meta".format(step.get_name())))
        index_exists = os.path.exists(os.path.join(context.get_path(), "{0}.ckpt.index".format(step.get_name())))

        return meta_exists and index_exists

    def _get_saved_model_path(self, context: ExecutionContext, step: BaseStep):
        """
        Returns the saved model path using the given execution context, and step name.
        :param step: step to save or load
        :type step: BaseStep
        :param context: execution context to save to or load from
        :type context: ExecutionContext
        :return: path of the saved model checkpoint
        """
        return os.path.join(context.get_path(), "{0}.ckpt".format(step.get_name()))
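
To make the checkpoint naming concrete, here is a small sketch of the path construction that `save_step`, `load_step` and `can_load` share; the context path and step name below are placeholders, not values from this PR:

```python
import os

context_path = "/some/context/path"   # placeholder for context.get_path()
step_name = "TensorflowV1ModelStep"   # placeholder for step.get_name()

checkpoint_prefix = os.path.join(context_path, "{0}.ckpt".format(step_name))
# tf.train.Saver().save() writes several files from this prefix; can_load() checks two of them:
can_load = (
    os.path.exists(checkpoint_prefix + ".meta")
    and os.path.exists(checkpoint_prefix + ".index")
)
```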