Skip to content

Commit

Permalink
PunctuationCapitalizationModel exportable (#1378)
Browse files Browse the repository at this point in the history
Signed-off-by: Sergei Nikolaev <snikolaev@nvidia.com>
  • Loading branch information
drnikolaev committed Oct 31, 2020
1 parent 6107566 commit 75774cf
Show file tree
Hide file tree
Showing 3 changed files with 476 additions and 185 deletions.
Expand Up @@ -15,6 +15,7 @@
import os
from typing import Dict, List, Optional

import onnx
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
Expand All @@ -30,14 +31,16 @@
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
from nemo.collections.nlp.parts.utils_funcs import tensor2list
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.modelPT import ModelPT
from nemo.core.neural_types import LogitsType, NeuralType
from nemo.utils import logging
from nemo.utils.export_utils import attach_onnx_to_onnx

__all__ = ['PunctuationCapitalizationModel']


class PunctuationCapitalizationModel(ModelPT):
class PunctuationCapitalizationModel(ModelPT, Exportable):
@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return self.bert_model.input_types
Expand Down Expand Up @@ -408,3 +411,86 @@ def list_available_models(cls) -> Optional[Dict[str, str]]:
)
)
return result

def _prepare_for_export(self):
return self.bert_model._prepare_for_export()

def export(
self,
output: str,
input_example=None,
output_example=None,
verbose=False,
export_params=True,
do_constant_folding=True,
keep_initializers_as_inputs=False,
onnx_opset_version: int = 12,
try_script: bool = False,
set_eval: bool = True,
check_trace: bool = True,
use_dynamic_axes: bool = True,
):
"""
Unlike other models' export() this one creates 5 output files, not 3:
punct_<output> - fused punctuation model (BERT+PunctuationClassifier)
capit_<output> - fused capitalization model (BERT+CapitalizationClassifier)
bert_<output> - common BERT neural net
punct_classifier_<output> - Punctuation Classifier neural net
capt_classifier_<output> - Capitalization Classifier neural net
"""
if input_example is not None or output_example is not None:
logging.warning(
"Passed input and output examples will be ignored and recomputed since"
" PunctuationCapitalizationModel consists of three separate models with different"
" inputs and outputs."
)

bert_model_onnx = self.bert_model.export(
os.path.join(os.path.dirname(output), 'bert_' + os.path.basename(output)),
None, # computed by input_example()
None,
verbose,
export_params,
do_constant_folding,
keep_initializers_as_inputs,
onnx_opset_version,
try_script,
set_eval,
check_trace,
use_dynamic_axes,
)

punct_classifier_onnx = self.punct_classifier.export(
os.path.join(os.path.dirname(output), 'punct_classifier_' + os.path.basename(output)),
None, # computed by input_example()
None,
verbose,
export_params,
do_constant_folding,
keep_initializers_as_inputs,
onnx_opset_version,
try_script,
set_eval,
check_trace,
use_dynamic_axes,
)

capit_classifier_onnx = self.capit_classifier.export(
os.path.join(os.path.dirname(output), 'capit_classifier_' + os.path.basename(output)),
None, # computed by input_example()
None,
verbose,
export_params,
do_constant_folding,
keep_initializers_as_inputs,
onnx_opset_version,
try_script,
set_eval,
check_trace,
use_dynamic_axes,
)

punct_output_model = attach_onnx_to_onnx(bert_model_onnx, punct_classifier_onnx, "PTCL")
onnx.save(punct_output_model, os.path.join(os.path.dirname(output), 'punct_' + os.path.basename(output)))
capit_output_model = attach_onnx_to_onnx(bert_model_onnx, capit_classifier_onnx, "CPCL")
onnx.save(capit_output_model, os.path.join(os.path.dirname(output), 'capit_' + os.path.basename(output)))
28 changes: 28 additions & 0 deletions tests/collections/nlp/test_nlp_exportables.py
Expand Up @@ -116,6 +116,34 @@ def test_TokenClassificationModel_export_to_onnx(self):
assert onnx_model.graph.input[0].name == 'input_ids'
assert onnx_model.graph.output[0].name == 'logits'

def test_PunctuationCapitalizationModel_export_to_onnx(self):
model = nemo_nlp.models.PunctuationCapitalizationModel.from_pretrained(
model_name="Punctuation_Capitalization_with_BERT"
)
with tempfile.TemporaryDirectory() as tmpdir:
filename = os.path.join(tmpdir, 'puncap.onnx')
punct_filename = os.path.join(tmpdir, 'punct_puncap.onnx')
capit_filename = os.path.join(tmpdir, 'capit_puncap.onnx')
model.export(output=filename)
onnx_model = onnx.load(punct_filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert len(onnx_model.graph.node) == 1160
assert onnx_model.graph.node[0].name == 'Unsqueeze_0'
assert onnx_model.graph.node[1159].name == 'PTCLLogSoftmax_2'
assert onnx_model.graph.node[30].name == 'Add_30'
assert onnx_model.graph.input[0].name == 'input_ids'
assert onnx_model.graph.input[2].name == 'token_type_ids'
assert onnx_model.graph.output[0].name == 'logits'
onnx_model = onnx.load(capit_filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert len(onnx_model.graph.node) == 1160
assert onnx_model.graph.node[0].name == 'Unsqueeze_0'
assert onnx_model.graph.node[1159].name == 'CPCLLogSoftmax_2'
assert onnx_model.graph.node[30].name == 'Add_30'
assert onnx_model.graph.input[0].name == 'input_ids'
assert onnx_model.graph.input[2].name == 'token_type_ids'
assert onnx_model.graph.output[0].name == 'logits'


@pytest.fixture()
def dummy_data(test_data_dir):
Expand Down

0 comments on commit 75774cf

Please sign in to comment.