In [None]:
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from functools import partial
from os.path import expanduser, join, abspath, dirname, exists
import tarfile

from ruamel.yaml import YAML

import nemo
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.helpers import monitor_asr_train_progress
from nemo.core import NeuralGraph, OperationMode, DeviceType, SimpleLossLoggerCallback
from nemo.utils import logging
from nemo.utils.app_state import AppState

# Create the Neural(Module)Factory that drives training below; place computations on the CPU.
nf = nemo.core.NeuralModuleFactory(placement=DeviceType.CPU)

### Tutorial II: The advanced functionality

In this second part of the Neural Graphs (NGs) tutorial we will focus on a more complex example: training an End-to-End Convolutional Neural Acoustic Model called JASPER. We will build a "model graph" and show how to nest it into other graphs, how to freeze/unfreeze modules, how to use graph configuration files, and how to save/load graph checkpoints.

#### This part covers the following:
 * how to nest one graph into another
 * how to serialize and deserialize a graph
 * how to export and import serialized graph configuration to/from YAML files
 * how to save and load graph checkpoints (containing weights of the Trainable NMs)
 * how to freeze/unfreeze modules in a graph
 
Additionally, we will show how to use `AppState` to list all the modules and graphs we have created in the scope of our application.
In order to learn more about graph nesting and input/output binding please refer to the first part of the tutorial.


In [None]:
# Prepare the samples for training JASPER - we will use the data available in NeMo tests.
# The path is relative to the directory the notebook is run from.
data_folder = abspath("../../tests/data/")
logging.info("Looking up for test ASR data")
if not exists(join(data_folder, "asr")):
    # The "asr" folder is missing, so unpack it from the bundled archive.
    logging.info("Extracting ASR data to: {0}".format(join(data_folder, "asr")))
    # The context manager closes the archive even if extraction fails part-way.
    # NOTE: extractall() trusts the archive's member paths; this is acceptable only because
    # the archive ships with the NeMo test data - do not reuse this pattern on untrusted archives.
    with tarfile.open(join(data_folder, "asr.tar.gz"), "r:gz") as tar:
        tar.extractall(path=data_folder)
else:
    logging.info("ASR data found in: {0}".format(join(data_folder, "asr")))

In [None]:
# Locations of the JASPER model configuration and of the tarred AN4 test samples.
model_config_file = abspath("../asr/configs/jasper_an4.yaml")
manifest_path = join(data_folder, 'asr/tarred_an4/tarred_audio_manifest.json')
tarpath = join(data_folder, 'asr/tarred_an4/audio_1.tar')

# Parse the model configuration using the "safe" YAML loader.
yaml = YAML(typ="safe")
with open(expanduser(model_config_file)) as config_stream:
    config = yaml.load(config_stream)

# The vocabulary (list of labels) and its size are needed by several modules below.
vocab = config['labels']
vocab_len = len(vocab)

In [None]:
# DataLayer reading the tarred samples listed in the manifest, 16 samples per batch.
data_layer = nemo_asr.TarredAudioToTextDataLayer(
    audio_tar_filepaths=tarpath,
    manifest_filepath=manifest_path,
    labels=vocab,
    batch_size=16,
)
logging.info("Loaded {} samples that we will use for training".format(len(data_layer)))

# The remaining modules are re-created from the serialized module definitions in the model config.
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor.deserialize(
    config["AudioToMelSpectrogramPreprocessor"]
)
jasper_encoder = nemo_asr.JasperEncoder.deserialize(config["JasperEncoder"])
# The decoder's "num_classes" parameter is overridden with the vocabulary size.
jasper_decoder = nemo_asr.JasperDecoderForCTC.deserialize(
    config["JasperDecoderForCTC"],
    overwrite_params={"num_classes": vocab_len},
)

# The CTC loss and the greedy decoder are constructed directly.
ctc_loss = nemo_asr.CTCLossNM(num_classes=vocab_len)
greedy_decoder = nemo_asr.GreedyCTCDecoder()

In [None]:
# Create the Jasper "model" graph: preprocessor -> encoder -> decoder.
# Module calls made inside the `with` block are recorded into `jasper_model`.
with NeuralGraph(operation_mode=OperationMode.both, name="jasper_model") as jasper_model:
    # Copy one input port definition - exposed under the "user" port name "input".
    jasper_model.inputs["input"] = data_preprocessor.input_ports["input_signal"]
    # Bind selected inputs: "input_signal" goes to the graph input defined above, while passing
    # the graph itself for `length` binds that port as a graph input under its default port name.
    i_processed_signal, i_processed_signal_len = data_preprocessor(input_signal=jasper_model.inputs["input"], length=jasper_model)
    i_encoded, i_encoded_len = jasper_encoder(audio_signal=i_processed_signal, length=i_processed_signal_len)
    i_log_probs = jasper_decoder(encoder_output=i_encoded)
    # Bind selected outputs - using "user" port names.
    jasper_model.outputs["log_probs"] = i_log_probs
    jasper_model.outputs["encoded_len"] = i_encoded_len

# Print the summary.
logging.info(jasper_model.summary())

In [None]:
# Serialize the whole graph.
serialized_jasper = jasper_model.serialize()
logging.info("Serialized JASPER model:\n {}".format(serialized_jasper))

In [None]:
# You can also serialize/deserialize a single NeuralModule, e.g. a decoder.
logging.info("Serialized JASPER Decoder:\n {}".format(jasper_decoder.serialize()))

In [None]:
# We can also export the serialized configuration to a file.
jasper_model.export_to_config("my_jasper.yml")

In [None]:
# Display the lists of graph and modules.
logging.info(AppState().graphs.summary())
logging.info(AppState().modules.summary())

In [None]:
# Deserialize graph - create a copy of the JASPER "model".
# The modules referenced by the serialized graph already exist, so the graph must be allowed to "reuse" them.
# (Omitting reuse_existing_modules=True will raise a KeyError.)
jasper_copy = NeuralGraph.deserialize(serialized_jasper, reuse_existing_modules=True)
serialized_jasper_copy = jasper_copy.serialize()
# THE SAME! The round trip reproduces the original serialization; note that the graph name is not exported.
assert serialized_jasper == serialized_jasper_copy

In [None]:
# Alternatively, import a copy of the JASPER "model" from the YAML config exported earlier.
# Existing modules are reused again; this copy is explicitly named "jasper_copy" and rebinds `jasper_copy`.
jasper_copy = NeuralGraph.import_from_config("my_jasper.yml", reuse_existing_modules=True, name="jasper_copy")

# Print the summary.
logging.info(jasper_copy.summary())

# Display the lists of graphs and modules.
logging.info(AppState().graphs.summary())
logging.info(AppState().modules.summary())

Note that there are now two graphs in the "Graph Registry", yet the list of modules hasn't changed. This means that both graphs are built on the same set of modules.

In [None]:
# Create the "training" graph, which nests the JASPER graph imported above.
with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
    # Create the "implicit" training graph: unpack the four data-layer outputs
    # (audio, audio length, transcript, transcript length).
    o_audio_signal, o_audio_signal_len, o_transcript, o_transcript_len = data_layer()
    # Use the JASPER graph as any other neural module, feeding its "input" and "length" ports.
    o_log_probs, o_encoded_len = jasper_copy(input=o_audio_signal, length=o_audio_signal_len)
    # Greedy decoding of the log-probabilities; its output is not bound as a graph output.
    o_predictions = greedy_decoder(log_probs=o_log_probs)
    # CTC loss between the predicted log-probabilities and the target transcripts.
    o_loss = ctc_loss(
        log_probs=o_log_probs, targets=o_transcript, input_length=o_encoded_len, target_length=o_transcript_len
    )
    # Set the graph output.
    training_graph.outputs["o_loss"] = o_loss

# Print the summary.
logging.info(training_graph.summary())

In [None]:
# Log the training loss (the graph output "o_loss") at every step.
loss_callback = SimpleLossLoggerCallback(
    tensors=[training_graph.output_tensors["o_loss"]],
    print_func=lambda losses: logging.info(f'Train Loss: {str(losses[0].item())}'),
    step_freq=1,
)

# Train the graph: 5 steps of NovoGrad with a learning rate of 0.01.
nf.train(
    training_graph=training_graph,
    optimizer="novograd",
    callbacks=[loss_callback],
    optimization_params={"max_steps": 5, "lr": 0.01},
)

Please note that the loss is going down. Still, we use only 65 samples, so we cannot really expect the model to be useful ;)

In [None]:
# Now we can save the graph checkpoint!
# Optionally, the names of the modules to be saved can be passed as well, e.g.:
#   jasper_copy.save_to("my_jasper.chkpt", module_names=["jasperencoder0"])
jasper_copy.save_to("my_jasper.chkpt")
# Please note only "trainable" modules will be saved.

In [None]:
# We can also export the config of the whole training graph and save its checkpoint,
# which in this case will result in the same checkpoint as the one saved for the JASPER graph...
training_graph.export_to_config("my_whole_graph.yml")
training_graph.save_to("my_whole_graph.chkpt")

In [None]:
# Load everything back: re-create the training graph from its config, reusing the existing modules,
# so that we can continue training.
new_training_graph = NeuralGraph.import_from_config("my_whole_graph.yml", reuse_existing_modules=True)

# Let's restore only the encoder from the checkpoint.
new_training_graph.restore_from("my_whole_graph.chkpt", module_names=["jasperencoder0"])

In [None]:
# So let us freeze the whole graph that we are going to train next...
# (a subset can be frozen as well, by passing module_names=[...]).
# new_training_graph was imported with reuse_existing_modules=True, so it shares its module
# objects with training_graph; operating on it keeps this cell consistent with the training cell.
new_training_graph.freeze()
# ... and finetune only the decoder.
new_training_graph.unfreeze(module_names=["jasperdecoderforctc0"])

# Ok, let us see what the graph looks like now.
logging.info(new_training_graph.summary())

In [None]:
# Log the loss of the restored graph (its "o_loss" output) at every step.
loss_callback = SimpleLossLoggerCallback(
    tensors=[new_training_graph.output_tensors["o_loss"]],
    print_func=lambda losses: logging.info(f'Train Loss: {str(losses[0].item())}'),
    step_freq=1,
)

# Call reset_trainer() on the factory before starting this second train() run.
nf.reset_trainer()
# Continue training: another 5 NovoGrad steps with a learning rate of 0.01.
nf.train(
    training_graph=new_training_graph,
    optimizer="novograd",
    callbacks=[loss_callback],
    optimization_params={"max_steps": 5, "lr": 0.01},
)
# Please note that training will throw an error if all the trainable modules are frozen!