In [None]:
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from functools import partial
from os.path import expanduser

from ruamel.yaml import YAML

import nemo
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.helpers import monitor_asr_train_progress
from nemo.core import NeuralGraph, OperationMode, DeviceType
from nemo.utils import logging
from nemo.utils.app_state import AppState

# Create Neural(Module)Factory, use CPU.
# `nf` is the factory object whose train() and reset_trainer() methods are called later in this notebook.
nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU)

### Tutorial II: The advanced functionality

In this second part of the Neural Graphs (NGs) tutorial we will focus on a more complex example: training of an End-to-End Convolutional Neural Acoustic Model called JASPER. We will build a "model graph" and show how we can nest it into other graphs, how we can freeze/unfreeze modules, use graph configuration and save/load graph checkpoints.

#### This part covers the following:
 * how to nest one graph into another
 * how to serialize and deserialize a graph 
 * how to export and import configuration to/from YAML files
 * how to save and load graph checkpoints
 * how to freeze/unfreeze modules in a graph

In order to learn more about graph nesting and input/output binding please refer to the first part of the tutorial.


In [None]:
# Locations of the dataset "manifests" and of the model configuration file.
# A leading "~" stands for the user's home directory.
train_manifest = "~/TestData/an4_dataset/an4_train.json"
val_manifest = "~/TestData/an4_dataset/an4_val.json"
model_config_file = "~/workspace/nemo/examples/asr/configs/jasper_an4.yaml"

# Parse the model configuration using the "safe" YAML loader.
yaml = YAML(typ="safe")
model_config_path = expanduser(model_config_file)  # resolve "~" before opening
with open(model_config_path) as f:
    config = yaml.load(f)

# The vocabulary is stored under the "labels" key of the configuration.
vocab = config['labels']

In [None]:
# Instantiate the neural modules. The first four are built from their sections of the
# loaded configuration via the Neural Module deserialization feature; `overwrite_params`
# replaces selected values from the configuration.

# Training data layer: use our training manifest and a batch size of 16.
data_layer = nemo_asr.AudioToTextDataLayer.deserialize(
    config["AudioToTextDataLayer_train"],
    overwrite_params={
        "manifest_filepath": train_manifest,
        "batch_size": 16,
    },
)

# Preprocessor applied to the raw audio signal.
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor.deserialize(
    config["AudioToMelSpectrogramPreprocessor"]
)

# Jasper encoder and CTC decoder; the decoder's number of classes is set to the vocabulary size.
jasper_encoder = nemo_asr.JasperEncoder.deserialize(
    config["JasperEncoder"]
)
jasper_decoder = nemo_asr.JasperDecoderForCTC.deserialize(
    config["JasperDecoderForCTC"],
    overwrite_params={"num_classes": len(vocab)},
)

# The loss and the greedy decoder are constructed directly (no configuration section is used).
ctc_loss = nemo_asr.CTCLossNM(num_classes=len(vocab))
greedy_decoder = nemo_asr.GreedyCTCDecoder()

In [None]:
# Create the Jasper "model" graph.
# The graph is created with OperationMode.both under the name "jasper"; the context manager
# binds it to `Jasper`, and the module calls made inside the block define its content.
with NeuralGraph(operation_mode=OperationMode.both, name="jasper") as Jasper:
    # Copy one input port definitions - using "user" port names.
    # The preprocessor's "input_signal" port definition becomes the graph input named "input".
    Jasper.inputs["input"] = data_preprocessor.input_ports["input_signal"]
    # Bind selected inputs - bind other using the default port name.
    # Passing the graph object itself as `length` (instead of a tensor) presumably binds that port
    # as a graph input under its default name "length" - the graph is later called with
    # `length=...`, which is consistent with this; TODO confirm against the NeuralGraph docs.
    i_processed_signal, i_processed_signal_len = data_preprocessor(input_signal=Jasper.inputs["input"], length=Jasper)
    # Chain the modules: preprocessor -> encoder -> decoder.
    i_encoded, i_encoded_len = jasper_encoder(audio_signal=i_processed_signal, length=i_processed_signal_len)
    i_log_probs = jasper_decoder(encoder_output=i_encoded)
    # Bind selected outputs - using "user" port names.
    # Only these two tensors are exposed as outputs of the graph.
    Jasper.outputs["log_probs"] = i_log_probs
    Jasper.outputs["encoded_len"] = i_encoded_len

# Print the summary.
logging.info(Jasper.summary())

In [None]:
# Serialize graph
# The result is kept in `serialized_jasper`: later cells deserialize a copy of the graph from it
# and compare it with the copy's own serialization.
serialized_jasper = Jasper.serialize()
logging.info("Serialized JasperNet:\n {}".format(serialized_jasper))

In [None]:
# Serialize decoder.
# A single module offers serialize() as well; here its result is only logged, not stored.
logging.info("Serialized Jasper Decoder:\n {}".format(jasper_decoder.serialize()))

In [None]:
# We can also export the serialized configuration to a file.
# "my_jasper.yml" is a relative path; a later cell imports a copy of the graph from this same file.
Jasper.export_to_config("my_jasper.yml")

In [None]:
# Log the summaries of the graph registry and of the module registry, in that order.
for registry_summary in (AppState().graphs.summary, AppState().modules.summary):
    logging.info(registry_summary())

In [None]:
# Drop the notebook's references to the graph and to its three modules,
# just as a test to show that reusing existing modules works! ;)
# NOTE(review): the original comment said "aside of jasper encoder", yet the encoder
# is deleted here as well - confirm which of the two is intended.
del Jasper, data_preprocessor, jasper_encoder, jasper_decoder

# In "pure" python - that will remove ALL existing references (both registries are Dicts with weak references!)

In [None]:
# Log the registry summaries again (graphs first, then modules) to see what is still registered.
for registry_summary in (AppState().graphs.summary, AppState().modules.summary):
    logging.info(registry_summary())

In [None]:
# Build a copy of the JASPER "model" graph from its serialized form.
# The modules already exist, so the graph has to be allowed to "reuse" them.
# (Commenting out reuse_existing_modules will raise a KeyError.)
jasper_copy = NeuralGraph.deserialize(
    serialized_jasper,
    reuse_existing_modules=True,
)

# Serializing the copy must give THE SAME result as the original serialization.
# Please note name of the graph is not exported.
serialized_jasper_copy = jasper_copy.serialize()
assert serialized_jasper == serialized_jasper_copy

# Log the summary of the copy, followed by the graph and module registries.
logging.info(jasper_copy.summary())
for registry_summary in (AppState().graphs.summary, AppState().modules.summary):
    logging.info(registry_summary())

In [None]:
# Alternatively, import a copy of the JASPER "model" from the exported configuration file.
# This rebinds `jasper_copy` to a new graph named "jasper_copy2", again reusing existing modules.
jasper_copy = NeuralGraph.import_from_config(
    "my_jasper.yml",
    reuse_existing_modules=True,
    name="jasper_copy2",
)

# Log the summary of the imported graph, followed by the graph and module registries.
logging.info(jasper_copy.summary())
for registry_summary in (AppState().graphs.summary, AppState().modules.summary):
    logging.info(registry_summary())

In [None]:
# Create the "training" graph.
# This graph nests the Jasper graph (`jasper_copy`) between the data layer and the loss.
with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
    # Create the "implicit" training graph.
    # The data layer is called without arguments and yields four tensors.
    o_audio_signal, o_audio_signal_len, o_transcript, o_transcript_len = data_layer()
    # Use Jasper module as any other neural module.
    # The nested graph is called through its bound input names "input" and "length";
    # presumably its outputs come back in the order they were bound (log_probs, encoded_len) - TODO confirm.
    o_log_probs, o_encoded_len = jasper_copy(input=o_audio_signal, length=o_audio_signal_len)
    o_predictions = greedy_decoder(log_probs=o_log_probs)
    o_loss = ctc_loss(
        log_probs=o_log_probs, targets=o_transcript, input_length=o_encoded_len, target_length=o_transcript_len
    )
    # Set graph output.
    # Only the loss is bound; binding the predictions as a second output is noted by the author as not working.
    training_graph.outputs["o_loss"] = o_loss
    # training_graph.outputs["o_predictions"] = o_predictions # DOESN'T WORK!

# Print the summary.
logging.info(training_graph.summary())

In [None]:
# Create training callback.
tensors_to_evaluate = [o_loss, o_predictions, o_transcript, o_transcript_len]
train_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=tensors_to_evaluate, print_func=partial(monitor_asr_train_progress, labels=vocab), step_freq=1
)

# Train the graph.
nf.train(
    # tensors_to_optimize=[o_loss, o_predictions], # DOESN'T WORK!
    # tensors_to_optimize=[o_loss],
    training_graph=training_graph,
    optimizer="novograd",
    callbacks=[train_callback],
    optimization_params={"max_steps": 5, "lr": 0.01},
)

In [None]:
# Finally, I can save the graph checkpoint!
# (The author left an optional argument commented out here: module_names=["jasperencoder0"].)
jasper_copy.save_to("my_jasper.chkpt")
# Please note only "trainable" modules will be saved.

In [None]:
# In this case saving the whole graph should result in the same checkpoint...
# Both the configuration and the checkpoint of the whole training graph are written to relative paths.
training_graph.export_to_config("my_whole_graph.yml")
training_graph.save_to("my_whole_graph.chkpt")

# ...BUT GreedyCTCDecoder is declared as `class GreedyCTCDecoder(TrainableNM)` (author's note), so
# presumably it is stored in this checkpoint as well, which makes it differ from the
# Jasper-only checkpoint - TODO confirm.

In [None]:
# Finally, I can load everything and continue training.
# The training graph is rebuilt from the exported YAML file, reusing the already existing modules.
new_training_graph = NeuralGraph.import_from_config("my_whole_graph.yml", reuse_existing_modules=True)

# Let's restore only the encoder
# (`module_names` presumably limits restoring to the listed modules - confirm against the NeuralGraph docs).
new_training_graph.restore_from("my_whole_graph.chkpt", module_names=["jasperencoder0"])

In [None]:
# Or maybe not...
# This time restore from the checkpoint without a `module_names` filter
# (presumably this restores every module stored in the checkpoint, not only the encoder - confirm).
new_training_graph.restore_from("my_whole_graph.chkpt")

In [None]:
# Analogously - create a loss callback, this time for the re-imported training graph.
def log_train_loss(x):
    """Log the first evaluated tensor (the training loss) as a Python scalar."""
    logging.info(f'Train Loss: {str(x[0].item())}')


# The callback evaluates only the graph's "o_loss" output tensor, at every step.
loss_callback = nemo.core.SimpleLossLoggerCallback(
    tensors=[new_training_graph.output_tensors["o_loss"]],
    print_func=log_train_loss,
    step_freq=1,
)

In [None]:
# And what will happen if we freeze our graph?
# NOTE(review): freezing is done through `training_graph`, while the next cell trains
# `new_training_graph`; presumably both share the same module objects because the latter was
# imported with reuse_existing_modules=True - confirm.
training_graph.freeze() # we can also freeze a subset, using "module_names=[]"
# Let us finetune only the decoder.
# Order matters: everything is frozen first, then only the decoder is unfrozen by name.
training_graph.unfreeze(module_names=["jasperdecoderforctc0"])

# Ok, let us see what the graph looks like now.
logging.info(training_graph.summary())

In [None]:
# And continue training...
# The trainer is reset before train() is called a second time; the author notes that it is
# unclear why this reset is required.
nf.reset_trainer() # I do not understand why do I have to "reset the trainer" when calling train() function again :]
# Train the re-imported graph for 5 more steps, logging only the loss.
nf.train(
    training_graph=new_training_graph,
    optimizer="novograd",
    callbacks=[loss_callback],
    optimization_params={"max_steps": 5, "lr": 0.01},
)

# This will throw an error as all trainable modules are frozen!
# NOTE(review): the previous cell unfreezes "jasperdecoderforctc0" after freezing the graph,
# so not all modules appear to be frozen at this point - confirm whether this note still holds.

# Neural Graph plans and extensions

## 1. Long-term goal: "training with graphs" (since November 2019;])

### training with training/evaluation graphs
 * train(training_graph=graph1, evaluation_graph=graph2 [OPTIONAL], ...)

### Expanded: training with callbacks 
 * train(training_graph=graph1, training_callbacks=callbacks1 [OPTIONAL], evaluation_graph=graph2 [OPTIONAL], evaluation_callbacks=callbacks2 [OPTIONAL], ...)

### Inference/evaluation
 * infer(evaluation_graph=graph2, ...)

### Expanded: inference with callbacks 
 * infer(evaluation_graph=graph2, evaluation_callbacks=callbacks2 [OPTIONAL], ...)
 

## 2. "Other main" graph actions

 * inputs/outputs binding [DONE]
 * graph nesting [DONE]
 * import_from_config()/export_to_config() [DONE]
 * serialize()/deserialize() [DONE]
 * save_to()/restore_from() [DONE]
 
 
## 3. "Partial" graph actions
### (will be used in the "main actions", but also could be called by the user directly)

 * freeze()/unfreeze() [DONE]
 * is_valid()
 * to(device)
 * graph nesting "with duplication" (@duplicate)
 * get_batch() -> batch
 * forward(batch) # Evelina's "infer with user input" (Complete Dialog Pipeline)
 * backward() (?)
