New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tensorflow 1 and 2 Step Savers, And Base Classes #5
Changes from 3 commits
bb20f6b
6c943dd
8c06467
fd5b1fa
d5ccc5d
d623bdb
8337d09
536b4dd
84680d6
0efda6b
bb41bee
a98b150
c6baca8
333945d
d68b876
1d72294
16846bf
d562cc9
2333f12
6d4f34e
50bcd7d
40f156d
ffbf749
1c68f89
7824c5d
7d26a91
7db0771
4864bb9
d41e46e
1cbbdd1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,7 +61,7 @@ def __init__( | |
self.variable_scope = variable_scope | ||
alexbrillant marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.has_expected_outputs = has_expected_outputs | ||
self.create_feed_dict = create_feed_dict | ||
self.loss = [] | ||
self.losses = [] | ||
self.print_loss = print_loss | ||
if print_func is None: | ||
print_func = print | ||
|
@@ -147,9 +147,9 @@ def fit_model(self, data_inputs, expected_outputs=None) -> BaseStep: | |
feed_dict.update(additional_feed_dict_arguments) | ||
|
||
results = self.session.run([self['optimizer'], self['loss']], feed_dict=feed_dict) | ||
self.loss.append(results[1]) | ||
self.losses.append(results[1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. srry for the double comment here. I think this could even be |
||
if self.print_loss: | ||
self.print_func(self.loss[-1]) | ||
self.print_func(self.losses[-1]) | ||
|
||
return self | ||
|
||
|
@@ -164,9 +164,7 @@ def transform_model(self, data_inputs): | |
:param data_inputs: | ||
:return: | ||
""" | ||
inference_output_name = 'output' | ||
if len(self['inference_output'].get_shape().as_list()) > 0: | ||
inference_output_name = 'inference_output' | ||
inference_output_name = self._get_inference_output_name() | ||
|
||
feed_dict = { | ||
self['data_inputs']: data_inputs | ||
|
@@ -176,6 +174,19 @@ def transform_model(self, data_inputs): | |
|
||
return results | ||
|
||
def _get_inference_output_name(self): | ||
""" | ||
Return the output tensor name for inference (transform). | ||
In create_graph, the user can return a tuple of two elements : the output tensor for training, and the output tensor for inference. | ||
|
||
:return: | ||
""" | ||
inference_output_name = 'output' | ||
if len(self['inference_output'].get_shape().as_list()) > 0: | ||
inference_output_name = 'inference_output' | ||
|
||
return inference_output_name | ||
|
||
def __getitem__(self, item): | ||
alexbrillant marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Get a graph tensor by name using get item. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
Neuraxle | ||
neuraxle>=0.3.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was indentation lost during auto-format? where should this comment go?