
Neural Graphs #520

Merged
merged 106 commits on May 5, 2020
Commits (106)
de587bd
Manual rebase of dev-module-composition on top of current master
tkornuta-nvidia Mar 11, 2020
15b3b11
Adding neural graph examples
tkornuta-nvidia Mar 24, 2020
2756a66
formatting fixes
tkornuta-nvidia Mar 24, 2020
5bad9c9
new proposal
blisc Apr 1, 2020
e26a906
Merge branch 'master' of github.com:NVIDIA/NeMo into fix-sign
tkornuta-nvidia Apr 6, 2020
85bb729
Merge branch 'fix-sign' of github.com:NVIDIA/NeMo into fix-sign
tkornuta-nvidia Apr 6, 2020
0576ce7
format fix of Jason's example
tkornuta-nvidia Apr 6, 2020
3f1ac01
Added modules recording, with retrieval of module name
tkornuta-nvidia Apr 7, 2020
a57b34d
keyerror comment
tkornuta-nvidia Apr 7, 2020
fc3deac
Introduced name property of modules
tkornuta-nvidia Apr 7, 2020
5221fef
implemented NameRegistry basic functionality, test5 working, unit tes…
tkornuta-nvidia Apr 7, 2020
dd9af75
example 5 working, ObjectRegistry finished, app_state moved to utils
tkornuta-nvidia Apr 8, 2020
ab35621
test app_state fixed
tkornuta-nvidia Apr 8, 2020
bbd3589
TokenClassification self.name fix
tkornuta-nvidia Apr 8, 2020
56efff5
ObjectRegistry tests
tkornuta-nvidia Apr 8, 2020
c838f13
cleanup of neural graphs class, NG manager inheriting from ObjectRegi…
tkornuta-nvidia Apr 8, 2020
1153001
fixes to graph and manager, NeMo tests passing (locally)
tkornuta-nvidia Apr 8, 2020
9df6568
code formatting
tkornuta-nvidia Apr 8, 2020
868ec8a
turning examples into unit/integration tests, small fixes here and th…
tkornuta-nvidia Apr 9, 2020
51e852c
Graph and module nesting - operation mode injected, implemented funct…
tkornuta-nvidia Apr 9, 2020
e2ed33a
formatting fix
tkornuta-nvidia Apr 9, 2020
069dd10
merged with master
tkornuta-nvidia Apr 9, 2020
a054311
cleanup of output port logic in __call__, other cleanups, LGTM fixes
tkornuta-nvidia Apr 9, 2020
c0e5f0b
Further cleanup of neural_modules.py
tkornuta-nvidia Apr 9, 2020
98cfe43
LGTM cleanups
tkornuta-nvidia Apr 9, 2020
d109b5c
style fix:]
tkornuta-nvidia Apr 9, 2020
6487a65
micro cleanup
tkornuta-nvidia Apr 9, 2020
6c16c6e
micro cleanup2
tkornuta-nvidia Apr 9, 2020
b7d3f1c
punctuation module name fix
tkornuta-nvidia Apr 9, 2020
5548495
Working on NmTensor extensions
tkornuta-nvidia Apr 10, 2020
c835279
style fix
tkornuta-nvidia Apr 10, 2020
70320fe
Cleaned up the call output tensor logic in neural module and graph cl…
tkornuta-nvidia Apr 10, 2020
c9b2585
NmTensor test cleanup
tkornuta-nvidia Apr 10, 2020
adeeccc
style fix
tkornuta-nvidia Apr 10, 2020
127307b
LGTM cleanups
tkornuta-nvidia Apr 10, 2020
7b52fbe
style fixes
tkornuta-nvidia Apr 10, 2020
81b87ea
BoundOutputs class and unit tests
tkornuta-nvidia Apr 10, 2020
fbfc87e
style fixes
tkornuta-nvidia Apr 10, 2020
41ee935
manual output port binding operational
tkornuta-nvidia Apr 10, 2020
bce7e90
manual output port binding operational
tkornuta-nvidia Apr 10, 2020
752b810
Recording and updating all modules during connecting to a graph
tkornuta-nvidia Apr 14, 2020
f68c883
work in progress on output port binding, not operations
tkornuta-nvidia Apr 15, 2020
9c3f78b
bound output test fix
tkornuta-nvidia Apr 15, 2020
b8bf040
preparing ground for tensor copy
tkornuta-nvidia Apr 15, 2020
691e90b
Rewritten graph nesting - works by executing inner graph modules' cal…
tkornuta-nvidia Apr 18, 2020
ec2bfe6
fixed tests, formatted the code
tkornuta-nvidia Apr 18, 2020
cb250da
refactoring and cleanup of neural graphs
tkornuta-nvidia Apr 21, 2020
6907cd7
Refactoring, cleanups of NeuralGraphs nesting/binding
tkornuta-nvidia Apr 21, 2020
cb69367
reformating fix
tkornuta-nvidia Apr 21, 2020
8389f21
Fixes, unit tests for binding of nested graphs, automatic and manual,…
tkornuta-nvidia Apr 22, 2020
9235ca8
Added several tests for graph nesting, covered (and fixed) case of pa…
tkornuta-nvidia Apr 22, 2020
59b1d75
final tests for graph nesting
tkornuta-nvidia Apr 22, 2020
fc9d995
style fix
tkornuta-nvidia Apr 22, 2020
ece11c3
refactored the neural module serialization, made it more modular, so …
tkornuta-nvidia Apr 23, 2020
cd703de
style fix
tkornuta-nvidia Apr 23, 2020
465979d
graph serialization 80%
tkornuta-nvidia Apr 23, 2020
ba46466
graph serialization operational
tkornuta-nvidia Apr 23, 2020
fa1e4b3
work on deserialization, working up to the execution of graph
tkornuta-nvidia Apr 24, 2020
b810243
format fix
tkornuta-nvidia Apr 24, 2020
9f46e20
serialization and deserialization of JASPER (processor, encoder and d…
tkornuta-nvidia Apr 27, 2020
77ad78d
Merge remote-tracking branch 'origin/master' into merge
tkornuta-nvidia Apr 27, 2020
7b5b9d4
merged with master and reformatted
tkornuta-nvidia Apr 27, 2020
26a7621
Updated EN and ZH:] version of documentation for configuration/custom…
tkornuta-nvidia Apr 28, 2020
314cf10
Unification of singleton meta class
tkornuta-nvidia Apr 28, 2020
7f0ee86
Added to the train() function signatures
tkornuta-nvidia Apr 28, 2020
4076808
formatting fix
tkornuta-nvidia Apr 28, 2020
41bb68f
LGTM fixes
tkornuta-nvidia Apr 28, 2020
cb56b35
line numbers
tkornuta-nvidia Apr 28, 2020
5ff3896
Tweak of deserialize - moved name and overwrite_params
tkornuta-nvidia Apr 28, 2020
c955a7a
style fix
tkornuta-nvidia Apr 28, 2020
3b47bd9
Commeting the condition for both None: eval() seems to be calling :]
tkornuta-nvidia Apr 28, 2020
fd88c70
LGTM fix
tkornuta-nvidia Apr 28, 2020
98b03cf
method comment
tkornuta-nvidia Apr 28, 2020
18077fc
reorganization of directories in unit/core/utils
tkornuta-nvidia Apr 28, 2020
a7a04e3
Fix of NM init_params collection, made it much more robust and faster…
tkornuta-nvidia Apr 28, 2020
a932c8a
Renamed NG integration tests
tkornuta-nvidia Apr 28, 2020
7c57cdf
formatting
tkornuta-nvidia Apr 28, 2020
6d785f1
additional tests
tkornuta-nvidia Apr 28, 2020
5633963
style fix
tkornuta-nvidia Apr 28, 2020
8f6ddc5
A simple test for graph config import and export
tkornuta-nvidia Apr 28, 2020
1977db0
import/export/serialization/deserialization tests cnt, minor fixes he…
tkornuta-nvidia Apr 29, 2020
bd6c12e
minor cleanup, unit test: a simple graph with a loop (not passing)
tkornuta-nvidia Apr 29, 2020
7463dd8
style fix
tkornuta-nvidia Apr 29, 2020
17e4dd9
Refactored the whole solution, switching from module_name:port_name t…
tkornuta-nvidia Apr 29, 2020
7b94c57
format fix
tkornuta-nvidia Apr 29, 2020
e5917fe
Added neural type export to config file for both graph bound input an…
tkornuta-nvidia Apr 29, 2020
35b31c9
jasper polished
tkornuta-nvidia Apr 29, 2020
5375055
PR polish, added type hints to all classes, added graph summary
tkornuta-nvidia May 1, 2020
9f6ee98
removed 'name' in for in actions.py
tkornuta-nvidia May 1, 2020
45dacd3
NG-related integration tests
tkornuta-nvidia May 1, 2020
356e0c2
LGTM fix
tkornuta-nvidia May 1, 2020
cd1e23e
bind description updated
tkornuta-nvidia May 1, 2020
37ba923
removed app_state from jasper
tkornuta-nvidia May 1, 2020
e28979e
graph input_ports and output_ports are now immutable, added some unit…
tkornuta-nvidia May 1, 2020
693b744
removed a line - formatting fix :]
tkornuta-nvidia May 2, 2020
413c0fa
name fix in test
tkornuta-nvidia May 2, 2020
ccc1b21
Minor touches here and there
tkornuta-nvidia May 2, 2020
a191dbe
Unique names generates from names of classes
tkornuta-nvidia May 2, 2020
68bd7e2
moved graph managed to utils, removed it from init
tkornuta-nvidia May 4, 2020
bc216db
Moved neural_graph_manager to utils that enable to remove it (along w…
tkornuta-nvidia May 4, 2020
c54156c
removed graph examples
tkornuta-nvidia May 4, 2020
cf72cf8
Updated docstring in GraphInputs
tkornuta-nvidia May 4, 2020
327cc6d
Moved all NeuralGraph helper classes to nemo.utils.neural_graph. AppS…
tkornuta-nvidia May 4, 2020
831d269
inference -> evaluation, as requested
tkornuta-nvidia May 5, 2020
d2615e1
fix
tkornuta-nvidia May 5, 2020
da91ec0
changelog updated
tkornuta-nvidia May 5, 2020
@@ -21,14 +21,14 @@

.. literalinclude:: ../../../../../examples/start_here/module_configuration.py
:language: python
:lines: 25-35
:lines: 24-34

Now we can export the configuration of any of the existing modules by calling :meth:`export_to_config()`, for \
example we can export the configuration of :class:`TaylorNet` by calling:

.. literalinclude:: ../../../../../examples/start_here/module_configuration.py
:language: python
:lines: 38
:lines: 37

Importing the configuration
---------------------------
@@ -37,7 +37,7 @@

.. literalinclude:: ../../../../../examples/start_here/module_configuration.py
:language: python
:lines: 41
:lines: 40

.. note::
The :meth:`import_from_config()` function actually creates a new instance of the class that was stored \
@@ -49,7 +49,7 @@

.. literalinclude:: ../../../../../examples/start_here/module_configuration.py
:language: python
:lines: 43-
:lines: 42-


.. include:: module_custom_configuration.rst
@@ -15,21 +15,21 @@

.. literalinclude:: ../../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 33-35
:lines: 28-30

Now let us define the :class:`CustomTaylorNet` Neural Module class:

.. literalinclude:: ../../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 38-43
:lines: 33-38


In order to properly handle the export of the :class:`Status` enum, we must implement a custom function \
:meth:`export_to_config()`:

.. literalinclude:: ../../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 45-76
:lines: 40-61


Note that the configuration is actually a dictionary consisting of two sections:
@@ -40,35 +40,28 @@
Those parameters are stored in the protected ``self._init_params`` field of the base :class:`NeuralModule` class.
It is ensured that the user cannot access or use them directly.

Analogically, we must overload the :meth:`import_from_config()` method:
Analogically, we must overload the :meth:`_deserialize_configuration()` method:

.. literalinclude:: ../../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 79-119

Please note that the base :class:`NeuralModule` class provides several protected methods for us to use, \
with the most important being:

* :meth:`_create_config_header()` generating the appropriate header, and \
* :meth:`_validate_config_file()` validating the loaded configuration file (checking the header content).

:lines: 63-86

.. note::
It is once again worth emphasizing that :meth:`import_from_config()` is a class method, actually returning \
It is once again worth emphasizing that :meth:`_deserialize_configuration()` is a class method, actually returning \
a new object instance - in this case of the :class:`CustomTaylorNet` type.


Now we can simply create an instance and export its configuration by calling:

.. literalinclude:: ../../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 128-129,134-135
:lines: 95-96,101-102

And instantiate a second one by loading that configuration:

.. literalinclude:: ../../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 137-139
:lines: 104-106

As a result we will see that the new object has set the status to the same value as the original one:

10 changes: 5 additions & 5 deletions docs/sources/source/tutorials/module_configuration.rst
@@ -17,18 +17,18 @@ In the following example we will once again train a model to learn Taylor's coef
However, we will extend the example by showing how to export configuration of the module to a YAML file and \
create a second instance having the same set of parameters.

Let us start by creating the :class:`NeuralFactory` object and instantiating the modules from the original example:
Let us start by creating the :class:`NeuralModuleFactory` object and instantiating the modules from the original example:

.. literalinclude:: ../../../../examples/start_here/module_configuration.py
:language: python
:lines: 25-35
:lines: 24-34

Now we can export the configuration of any of the existing modules by using the :meth:`export_to_config()`, for \
example we can export the configuration of the trainable :class:`TaylorNet` by calling:

.. literalinclude:: ../../../../examples/start_here/module_configuration.py
:language: python
:lines: 38
:lines: 37

Importing the configuration
---------------------------
@@ -37,7 +37,7 @@ There is an analogical function :meth:`import_from_config()` responsible for loa

.. literalinclude:: ../../../../examples/start_here/module_configuration.py
:language: python
:lines: 41
:lines: 40

.. note::
The :meth:`import_from_config()` function actually creates a new instance of object of the class that was stored \
Expand All @@ -49,7 +49,7 @@ For example, we can build a graph and train it with a NeMo trainer:

.. literalinclude:: ../../../../examples/start_here/module_configuration.py
:language: python
:lines: 43-
:lines: 42-


.. include:: module_custom_configuration.rst
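For reference, the round-trip described above can be sketched in a few lines. This is a hedged reconstruction, not the actual example file: module_configuration.py is only referenced by line numbers in this diff, so the module parameters and the YAML file name below are illustrative (the module classes and factory are taken from the second example file in this PR).

# A minimal sketch of the export/import round-trip, assuming the NeMo 0.x API
# from this PR; parameter values and the file name "taylor_net.yml" are
# illustrative, not copied from module_configuration.py.
from nemo.backends.pytorch.tutorials import MSELoss, RealFunctionDataLayer, TaylorNet
from nemo.core import NeuralModuleFactory

nf = NeuralModuleFactory()

# Instantiate the modules, as in the original example.
dl = RealFunctionDataLayer(n=100, batch_size=32)
fx = TaylorNet(dim=4)
loss = MSELoss()

# Export the configuration of the trainable TaylorNet to a YAML file.
fx.export_to_config("taylor_net.yml")

# import_from_config() creates a brand-new TaylorNet instance
# initialized with the parameters stored in that file.
fx2 = TaylorNet.import_from_config("taylor_net.yml")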
28 changes: 11 additions & 17 deletions docs/sources/source/tutorials/module_custom_configuration.rst
@@ -15,21 +15,21 @@ and extend it by those methods. But first, let us define a simple :class:`Status` enu

.. literalinclude:: ../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 33-35
:lines: 28-30

Now let us define the :class:`CustomTaylorNet` Neural Module class:

.. literalinclude:: ../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 38-43
:lines: 33-38


In order to properly handle the export of the :class:`Status` enum we must implement a custom function \
:meth:`export_to_config()`:
:meth:`_serialize_configuration()`:

.. literalinclude:: ../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 45-76
:lines: 49-61


Note that the configuration is actually a dictionary consisting of two sections:
@@ -40,35 +40,29 @@
Those parameters are stored in the protected ``self._init_params`` field of the base :class:`NeuralModule` class.
It is assumed that (aside from this use case) the user won't access or use them directly.

Analogically, we must overload the :meth:`import_from_config()` method:
Analogically, we must overload the :meth:`_deserialize_configuration()` method:

.. literalinclude:: ../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 79-119

Please note that the base :class:`NeuralModule` class provides several protected methods that we used, \
with the most important being:

* :meth:`_create_config_header()` generating the appropriate header, and \
* :meth:`_validate_config_file()` validating the loaded configuration file (checking the header content).

:lines: 63-86

.. note::
It is once again worth emphasizing that the :meth:`import_from_config()` is a class method, actually returning a \
new object instance - in this case of the hardcoded :class:`CustomTaylorNet` type.
It is worth emphasizing that the :meth:`_deserialize_configuration()` is a class method,
analogically to public :meth:`import_from_config()` and :meth:`deserialize()` methods
that return a new object instance - in this case of the hardcoded :class:`CustomTaylorNet` type.


Now we can simply create an instance and export its configuration by calling:

.. literalinclude:: ../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 128-129,134-135
:lines: 95-96,101-102

And instantiate a second by loading that configuration:

.. literalinclude:: ../../../../examples/start_here/module_custom_configuration.py
:language: python
:lines: 137-139
:lines: 104-106

As a result we will see that the new object has set the status to the same value as the original one:

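To make the two overloads concrete, here is a hedged sketch of a module like the one this tutorial builds. The hook names :meth:`_serialize_configuration()` and :meth:`_deserialize_configuration()` come from the text above, but their exact signatures are assumptions, since module_custom_configuration.py is only referenced by line numbers in this diff.

# A hedged sketch of the customization described above; the hook signatures
# are assumptions, and the class bodies are illustrative.
from enum import Enum

from nemo.backends.pytorch.tutorials import TaylorNet


class Status(Enum):
    success = 0
    error = 1


class CustomTaylorNet(TaylorNet):
    """TaylorNet variant with a constructor argument (an enum) that raw YAML cannot store."""

    def __init__(self, dim, status: Status):
        super().__init__(dim=dim)
        self._status = status

    def _serialize_configuration(self):
        # Turn the init parameters into a YAML-friendly dict: the Status
        # member is stored by name instead of as a raw enum object.
        return {"dim": self._init_params["dim"], "status": self._status.name}

    @classmethod
    def _deserialize_configuration(cls, serialized_init_params):
        # Class method: rebuilds the init parameters (converting the stored
        # string back into a Status member) and returns a new instance.
        return cls(dim=serialized_init_params["dim"], status=Status[serialized_init_params["status"]])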
122 changes: 122 additions & 0 deletions examples/start_here/graph_composition_integration_tests0_jasper.py
@@ -0,0 +1,122 @@
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from functools import partial
from os.path import expanduser

from ruamel.yaml import YAML

import nemo
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.helpers import monitor_asr_train_progress
from nemo.core import NeuralGraph, OperationMode
from nemo.utils import logging

nf = nemo.core.NeuralModuleFactory()

logging.info(
"This example shows how one can build a Jasper model using the explicit graph."
F" This approach works for applications containing a single graph."
)

# Set paths to "manifests" and model configuration files.
train_manifest = "~/TestData/an4_dataset/an4_train.json"
val_manifest = "~/TestData/an4_dataset/an4_val.json"
model_config_file = "~/workspace/nemo/examples/asr/configs/jasper_an4.yaml"

yaml = YAML(typ="safe")
with open(expanduser(model_config_file)) as f:
    config = yaml.load(f)
# Get vocabulary.
vocab = config['labels']

# Create neural modules.
data_layer = nemo_asr.AudioToTextDataLayer.deserialize(
config["AudioToTextDataLayer_train"], overwrite_params={"manifest_filepath": train_manifest, "batch_size": 16},
)

data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor.deserialize(config["AudioToMelSpectrogramPreprocessor"])

jasper_encoder = nemo_asr.JasperEncoder.deserialize(config["JasperEncoder"])
jasper_decoder = nemo_asr.JasperDecoderForCTC.deserialize(
config["JasperDecoderForCTC"], overwrite_params={"num_classes": len(vocab)}
)
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
greedy_decoder = nemo_asr.GreedyCTCDecoder()

# Create the Jasper "model".
with NeuralGraph(operation_mode=OperationMode.both) as Jasper:
    # Copy one input port definition - using a "user" port name.
    Jasper.inputs["input"] = data_preprocessor.input_ports["input_signal"]
    # Bind selected inputs - bind others using the default port name.
    i_processed_signal, i_processed_signal_len = data_preprocessor(input_signal=Jasper.inputs["input"], length=Jasper)
    i_encoded, i_encoded_len = jasper_encoder(audio_signal=i_processed_signal, length=i_processed_signal_len)
    i_log_probs = jasper_decoder(encoder_output=i_encoded)
    # Bind selected outputs - using "user" port names.
    Jasper.outputs["log_probs"] = i_log_probs
    Jasper.outputs["encoded_len"] = i_encoded_len

# Print the summary.
logging.info(Jasper.summary())

# Serialize graph
serialized_jasper = Jasper.serialize()
print("Serialized:\n", serialized_jasper)

# Delete everything - aside from the jasper encoder, just as a test to show that reusing works! ;)
del Jasper
del data_preprocessor
# del jasper_encoder #
del jasper_decoder

# Deserialize graph - copy of the JASPER "model".
jasper_copy = NeuralGraph.deserialize(serialized_jasper, reuse_existing_modules=True, name="jasper_copy")
serialized_jasper_copy = jasper_copy.serialize()
# print("Deserialized:\n", serialized_jasper_copy)
assert serialized_jasper == serialized_jasper_copy
Member: are these tests or examples? If these are tests, why aren't they following your own test conventions/classifications?

Contributor Author: not really - just wanted to show you that those two guys are "equivalent"...

Contributor Author: (actually - the same)

Contributor Author: ok, so once again:

  • every file "graph_composition_integration_tests_*" in examples/start_here is an example to show you what you can do with NeuralGraphs
  • those files will be removed after the PR is merged, or, more precisely, when i will finish the NG tutorials


# Create the "training" graph.
with NeuralGraph(name="training") as training_graph:
# Create the "implicit" training graph.
o_audio_signal, o_audio_signal_len, o_transcript, o_transcript_len = data_layer()
# Use Jasper module as any other neural module.
o_log_probs, o_encoded_len = jasper_copy(input=o_audio_signal, length=o_audio_signal_len)
o_predictions = greedy_decoder(log_probs=o_log_probs)
o_loss = ctc_loss(
log_probs=o_log_probs, targets=o_transcript, input_length=o_encoded_len, target_length=o_transcript_len
)
# Set graph output.
training_graph.outputs["o_loss"] = o_loss
# training_graph.outputs["o_predictions"] = o_predictions # DOESN'T WORK?!?

# Print the summary.
logging.info(training_graph.summary())

tensors_to_evaluate = [o_loss, o_predictions, o_transcript, o_transcript_len]
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=tensors_to_evaluate, print_func=partial(monitor_asr_train_progress, labels=vocab)
)
# import pdb;pdb.set_trace()
nf.train(
    # tensors_to_optimize=[o_loss, o_predictions], # DOESN'T WORK?!?
    training_graph=training_graph,
    optimizer="novograd",
    callbacks=[train_callback],
    optimization_params={"num_epochs": 50, "lr": 0.01},
)
48 changes: 48 additions & 0 deletions examples/start_here/graph_composition_integration_tests1.py
@@ -0,0 +1,48 @@
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from nemo.backends.pytorch.tutorials import MSELoss, RealFunctionDataLayer, TaylorNet
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

nf = NeuralModuleFactory(placement=DeviceType.CPU)
# Instantiate the necessary neural modules.
dl = RealFunctionDataLayer(n=100, batch_size=32)
m2 = TaylorNet(dim=4)
loss = MSELoss()

logging.info("This example shows how one can build an `explicit` graph.")

with NeuralGraph(operation_mode=OperationMode.training) as g0:
    x, t = dl()
    p = m2(x=x)
    lss = loss(predictions=p, target=t)
    # Manual bind.
    g0.outputs["output"] = lss

# Print the summary.
logging.info(g0.summary())

# SimpleLossLoggerCallback will print loss values to console.
callback = SimpleLossLoggerCallback(
    tensors=[lss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'),
)

# Invoke "train" action.
nf.train([lss], callbacks=[callback], optimization_params={"num_epochs": 2, "lr": 0.0003}, optimizer="sgd")