diff --git a/collections/nemo_asr/nemo_asr/audio_preprocessing.py b/collections/nemo_asr/nemo_asr/audio_preprocessing.py
index 8e4e1676027c..7ba120e4880d 100644
--- a/collections/nemo_asr/nemo_asr/audio_preprocessing.py
+++ b/collections/nemo_asr/nemo_asr/audio_preprocessing.py
@@ -27,7 +27,7 @@
 import torchaudio
 try:
     from apex import amp
-except AttributeError:
+except (AttributeError, ModuleNotFoundError) as e:
     print("Unable to import APEX. Mixed precision and distributed training "
           "will not work.")
diff --git a/collections/nemo_asr/nemo_asr/jasper.py b/collections/nemo_asr/nemo_asr/jasper.py
index e855e4319b59..ce6e2282c7fa 100644
--- a/collections/nemo_asr/nemo_asr/jasper.py
+++ b/collections/nemo_asr/nemo_asr/jasper.py
@@ -116,7 +116,6 @@ def __init__(
                 self.dense_residual = True
             groups = lcfg.get('groups', 1)
             separable = lcfg.get('separable', False)
-            tied = lcfg.get('tied', False)
             heads = lcfg.get('heads', -1)
             encoder_layers.append(
                 JasperBlock(feat_in,
@@ -133,7 +132,6 @@ def __init__(
                             residual_mode=residual_mode,
                             normalization=normalization_mode,
                             norm_groups=norm_groups,
-                            tied=tied,
                             activation=activation,
                             residual_panes=dense_res,
                             conv_mask=conv_mask))
diff --git a/collections/nemo_asr/nemo_asr/parts/jasper.py b/collections/nemo_asr/nemo_asr/parts/jasper.py
index f28c30538760..464d938a64cb 100644
--- a/collections/nemo_asr/nemo_asr/parts/jasper.py
+++ b/collections/nemo_asr/nemo_asr/parts/jasper.py
@@ -1,7 +1,22 @@
-# Taken straight from Patter https://github.com/ryanleary/patter
-# TODO: review, and copyright and fix/add comments
+# Copyright (C) NVIDIA CORPORATION. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
 import torch
 import torch.nn as nn
+from torch import Tensor
 
 jasper_activations = {
     "hardtanh": nn.Hardtanh,
@@ -11,7 +26,9 @@
 
 
 def init_weights(m, mode='xavier_uniform'):
-    if isinstance(m, nn.Conv1d) or isinstance(m, MaskedConv1d):
+    if isinstance(m, MaskedConv1d):
+        init_weights(m.conv, mode)
+    if isinstance(m, nn.Conv1d):
         if mode == 'xavier_uniform':
             nn.init.xavier_uniform_(m.weight, gain=1.0)
         elif mode == 'xavier_normal':
@@ -40,50 +57,52 @@ def get_same_padding(kernel_size, stride, dilation):
     return kernel_size // 2
 
 
-class MaskedConv1d(nn.Conv1d):
+class MaskedConv1d(nn.Module):
+    __constants__ = ["use_conv_mask", "real_out_channels", "heads"]
 
     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                  padding=0, dilation=1, groups=1, heads=-1, bias=False,
                  use_mask=True):
+        super(MaskedConv1d, self).__init__()
         if not (heads == -1 or groups == in_channels):
             raise ValueError("Only use heads for depthwise convolutions")
 
+        self.real_out_channels = out_channels
         if heads != -1:
-            self.real_out_channels = out_channels
             in_channels = heads
             out_channels = heads
             groups = heads
 
-        super(MaskedConv1d, self).__init__(in_channels, out_channels,
-                                           kernel_size,
-                                           stride=stride,
-                                           padding=padding, dilation=dilation,
-                                           groups=groups, bias=bias)
+        self.conv = nn.Conv1d(in_channels, out_channels,
+                              kernel_size,
+                              stride=stride,
+                              padding=padding, dilation=dilation,
+                              groups=groups, bias=bias)
         self.use_mask = use_mask
         self.heads = heads
 
     def get_seq_len(self, lens):
-        return ((lens + 2 * self.padding[0] - self.dilation[0] * (
-            self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+        return ((lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (
+            self.conv.kernel_size[0] - 1) - 1) / self.conv.stride[0] + 1)
 
     def forward(self, x, lens):
         if self.use_mask:
             lens = lens.to(dtype=torch.long)
             max_len = x.size(2)
-            mask = torch.arange(max_len).to(lens.device)\
+            mask = torch.arange(max_len).to(lens.device) \
                 .expand(len(lens), max_len) >= lens.unsqueeze(1)
             x = x.masked_fill(
-                mask.unsqueeze(1).type(torch.bool).to(device=x.device), 0
+                mask.unsqueeze(1).to(device=x.device), 0
             )
-            del mask
+            # del mask
             lens = self.get_seq_len(lens)
 
+        sh = x.shape
         if self.heads != -1:
-            sh = x.shape
             x = x.view(-1, self.heads, sh[-1])
 
-        out, lens = super(MaskedConv1d, self).forward(x), lens
+        out = self.conv(x)
 
         if self.heads != -1:
             out = out.view(sh[0], self.real_out_channels, -1)
@@ -112,11 +131,12 @@ def forward(self, x):
 
 class JasperBlock(nn.Module):
+    __constants__ = ["conv_mask", "separable", "residual_mode", "res", "mconv"]
 
     def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
                  dilation=1, padding='same', dropout=0.2, activation=None,
                  residual=True, groups=1, separable=False,
-                 heads=-1, tied=False, normalization="batch",
+                 heads=-1, normalization="batch",
                  norm_groups=1, residual_mode='add',
                  residual_panes=[], conv_mask=False):
         super(JasperBlock, self).__init__()
@@ -129,11 +149,11 @@ def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
         self.separable = separable
         self.residual_mode = residual_mode
 
-        self.conv = nn.ModuleList()
         inplanes_loop = inplanes
+        conv = nn.ModuleList()
 
-        if tied:
-            rep_layer = self._get_conv_bn_layer(
+        for _ in range(repeat - 1):
+            conv.extend(self._get_conv_bn_layer(
                 inplanes_loop,
                 planes,
                 kernel_size=kernel_size,
@@ -144,73 +164,70 @@ def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
                 heads=heads,
                 separable=separable,
                 normalization=normalization,
-                norm_groups=norm_groups)
+                norm_groups=norm_groups))
 
-        for _ in range(repeat - 1):
-            if tied:
-                self.conv.extend(rep_layer)
-            else:
-                self.conv.extend(
-                    self._get_conv_bn_layer(
-                        inplanes_loop,
-                        planes,
-                        kernel_size=kernel_size,
-                        stride=stride,
-                        dilation=dilation,
-                        padding=padding_val,
-                        groups=groups,
-                        heads=heads,
-                        separable=separable,
-                        normalization=normalization,
-                        norm_groups=norm_groups))
-
-            self.conv.extend(
-                self._get_act_dropout_layer(
-                    drop_prob=dropout,
-                    activation=activation))
+            conv.extend(self._get_act_dropout_layer(
+                drop_prob=dropout,
+                activation=activation))
 
             inplanes_loop = planes
 
-        if tied:
-            self.conv.extend(rep_layer)
-        else:
-            self.conv.extend(
-                self._get_conv_bn_layer(
-                    inplanes_loop,
-                    planes,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    dilation=dilation,
-                    padding=padding_val,
-                    groups=groups,
-                    heads=heads,
-                    separable=separable,
-                    normalization=normalization,
-                    norm_groups=norm_groups))
+        conv.extend(self._get_conv_bn_layer(
+            inplanes_loop,
+            planes,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation,
+            padding=padding_val,
+            groups=groups,
+            heads=heads,
+            separable=separable,
+            normalization=normalization,
+            norm_groups=norm_groups))
+
+        self.mconv = conv
 
-        self.res = nn.ModuleList() if residual else None
         res_panes = residual_panes.copy()
         self.dense_residual = residual
+
         if residual:
+            res_list = nn.ModuleList()
             if len(residual_panes) == 0:
                 res_panes = [inplanes]
                 self.dense_residual = False
             for ip in res_panes:
-                self.res.append(
-                    nn.ModuleList(
-                        modules=self._get_conv_bn_layer(
-                            ip,
-                            planes,
-                            kernel_size=1,
-                            normalization=normalization,
-                            norm_groups=norm_groups)))
-        self.out = nn.Sequential(
+                res_list.append(nn.ModuleList(self._get_conv_bn_layer(
+                    ip,
+                    planes,
+                    kernel_size=1,
+                    normalization=normalization,
+                    norm_groups=norm_groups)))
+            self.res = res_list
+        else:
+            self.res = None
+
+        self.mout = nn.Sequential(
             *self._get_act_dropout_layer(
                 drop_prob=dropout,
-                activation=activation
-            )
+                activation=activation)
         )
 
+    def _get_conv(self, in_channels, out_channels, kernel_size=11,
+                  stride=1, dilation=1, padding=0, bias=False,
+                  groups=1, heads=-1, separable=False):
+        use_mask = self.conv_mask
+        if use_mask:
+            return MaskedConv1d(in_channels, out_channels, kernel_size,
+                                stride=stride,
+                                dilation=dilation, padding=padding, bias=bias,
+                                groups=groups, heads=heads,
+                                use_mask=use_mask)
+        else:
+            return nn.Conv1d(in_channels, out_channels, kernel_size,
+                             stride=stride,
+                             dilation=dilation, padding=padding, bias=bias,
+                             groups=groups)
+
     def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size=11,
                            stride=1, dilation=1, padding=0, bias=False,
                            groups=1, heads=-1, separable=False,
@@ -220,23 +237,20 @@ def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size=11,
 
         if separable:
             layers = [
-                MaskedConv1d(in_channels, in_channels, kernel_size,
-                             stride=stride,
-                             dilation=dilation, padding=padding, bias=bias,
-                             groups=in_channels, heads=heads,
-                             use_mask=self.conv_mask),
-                MaskedConv1d(in_channels, out_channels, kernel_size=1,
-                             stride=1,
-                             dilation=1, padding=0, bias=bias, groups=groups,
-                             use_mask=self.conv_mask)
+                self._get_conv(in_channels, in_channels, kernel_size,
+                               stride=stride,
+                               dilation=dilation, padding=padding, bias=bias,
+                               groups=in_channels, heads=heads),
+                self._get_conv(in_channels, out_channels, kernel_size,
+                               stride=1,
+                               dilation=1, padding=0, bias=bias, groups=groups)
             ]
         else:
             layers = [
-                MaskedConv1d(in_channels, out_channels, kernel_size,
-                             stride=stride,
-                             dilation=dilation, padding=padding, bias=bias,
-                             groups=groups,
-                             use_mask=self.conv_mask)
+                self._get_conv(in_channels, out_channels, kernel_size,
+                               stride=stride,
+                               dilation=dilation, padding=padding, bias=bias,
+                               groups=groups)
             ]
 
         if normalization == "group":
@@ -268,15 +282,13 @@ def _get_act_dropout_layer(self, drop_prob=0.2, activation=None):
         ]
         return layers
 
-    def forward(self, input_):
-
+    def forward(self, input_: Tuple[Tensor, Tensor]):
         xs, lens_orig = input_
-
         # compute forward convolutions
         out = xs[-1]
 
         lens = lens_orig
-        for i, l in enumerate(self.conv):
+        for i, l in enumerate(self.mconv):
             # if we're doing masked convolutions, we need to pass in and
             # possibly update the sequence lengths
             # if (i % 4) == 0 and self.conv_mask:
@@ -301,7 +313,7 @@ def forward(self, input_):
                 out = torch.max(out, res_out)
 
         # compute the output
-        out = self.out(out)
+        out = self.mout(out)
 
         if self.res is not None and self.dense_residual:
             return xs + [out], lens
diff --git a/collections/nemo_nlp/nemo_nlp/huggingface/bert.py b/collections/nemo_nlp/nemo_nlp/huggingface/bert.py
index 735d072afd29..076bf46272f5 100644
--- a/collections/nemo_nlp/nemo_nlp/huggingface/bert.py
+++ b/collections/nemo_nlp/nemo_nlp/huggingface/bert.py
@@ -114,6 +114,8 @@ def __init__(self, *,
 
         self.add_module("bert", model)
         self.config = model.config
+        for key, value in self.config.to_dict().items():
+            self._local_parameters[key] = value
 
     @staticmethod
     def list_pretrained_models() -> Optional[List[PretrainedModelInfo]]:
diff --git a/nemo/nemo/backends/pytorch/actions.py b/nemo/nemo/backends/pytorch/actions.py
index 6c5088cbdd2f..4fc4c9dba0bc 100644
--- a/nemo/nemo/backends/pytorch/actions.py
+++ b/nemo/nemo/backends/pytorch/actions.py
@@ -1,21 +1,25 @@
 # Copyright (c) 2019 NVIDIA Corporation
 import importlib
 import itertools
+import json
 import logging
 import os
-from typing import List, Optional
+from pathlib import Path
+from typing import List, Optional, Dict
+# import onnx
+from collections import defaultdict
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-
 from nemo.backends.pytorch.nm import TrainableNM
 
 from .module_wrapper import TrainableNeuralModuleWrapper
 from .nm import DataLayerNM
 from .optimizers import Novograd, AdamW, Lamb, master_params
-from ...core import NmTensor, DeviceType, NeuralModule
+from ...core import NmTensor, DeviceType, NeuralModule, DeploymentFormat
+from ...core.neural_types import *
 from ...core.callbacks import (ActionCallback, EvaluatorCallback,
                                SimpleLossLoggerCallback)
@@ -60,7 +64,7 @@ def __init__(self, local_rank=None, tb_writer=None,
             parallel = importlib.import_module('apex.parallel')
             convert_syncbn = parallel.convert_syncbn_model
             create_syncbn_process_group = (
-                    parallel.create_syncbn_process_group)
+                parallel.create_syncbn_process_group)
             DDP = parallel.DistributedDataParallel
             LARC = parallel.LARC
@@ -1003,6 +1007,129 @@ def _get_all_modules(
             for module in callchain:
                 self.modules.add(module[0])
 
+    @staticmethod
+    def __module_export(module,
+                        output,
+                        d_format: DeploymentFormat,
+                        input_example=None,
+                        output_example=None):
+        # Check if output already exists
+        destination = Path(output)
+        if destination.exists():
+            raise FileExistsError(f"Destination {output} already exists. "
" + f"Aborting export.") + + input_names = list(module.input_ports.keys()) + output_names = list(module.output_ports.keys()) + dynamic_axes = defaultdict(list) + + def __extract_dynamic_axes(port_name: str, ntype: NeuralType, + dynamic_axes: defaultdict): + if ntype.axis2type: + for axis_id, axistype in ntype.axis2type.items(): + if issubclass(axistype.semantics, BatchTag) or issubclass( + axistype.semantics, TimeTag): + dynamic_axes[port_name].append(axis_id) + # for input_ports + for port_name, ntype in module.input_ports.items(): + __extract_dynamic_axes(port_name, ntype, dynamic_axes) + # for output_ports + for port_name, ntype in module.output_ports.items(): + __extract_dynamic_axes(port_name, ntype, dynamic_axes) + if len(dynamic_axes) == 0: + dynamic_axes = None + + local_parameters = {} + for key, value in module._local_parameters.items(): + local_parameters[key] = value + + # Remove NeMo-related things from the module + # We need to change __call__ method. Note that this will change the + # whole class, not just this object! Which is why we need to repair it + # in the finally block + type(module).__call__ = torch.nn.Module.__call__ + + module._local_parameters = None + module._logger = None + module._placement = None + module._factory = None + module._device = None + + module.eval() + try: + if d_format == DeploymentFormat.TORCHSCRIPT: + if input_example is None: + # Route 1 - via torch.jit.script + traced_m = torch.jit.script(module) + traced_m.save(output) + else: + # Route 2 - via tracing + traced_m = torch.jit.trace(module, input_example) + traced_m.save(output) + elif d_format == DeploymentFormat.ONNX: + if input_example is None: + raise ValueError( + f'Example input is None, but ONNX tracing was' + f' attempted') + torch.onnx.export(module, input_example, output, + input_names=input_names, + output_names=output_names, + verbose=True, + export_params=True, + do_constant_folding=True, + dynamic_axes=dynamic_axes, + opset_version=10) + # fn = output + ".readable" + # with open(fn, 'w') as f: + # tempModel = onnx.load(output) + # onnx.save(tempModel, output + ".copy") + # onnx.checker.check_model(tempModel) + # pgraph = onnx.helper.printable_graph(tempModel.graph) + # f.write(pgraph) + + elif d_format == DeploymentFormat.PYTORCH: + torch.save(module.state_dict(), output) + with open(output + ".json", 'w') as outfile: + json.dump(local_parameters, outfile) + + else: + raise NotImplemented( + f"Not supported deployment format: {d_format}") + except Exception as e: # nopep8 + print( + f'ERROR: module export failed for {module} with exception {e}') + finally: + def __old_call__(self, force_pt=False, *input, **kwargs): + pt_call = len(input) > 0 or force_pt + if pt_call: + return nn.Module.__call__(self, *input, **kwargs) + else: + return NeuralModule.__call__(self, **kwargs) + + type(module).__call__ = __old_call__ + + @staticmethod + def deployment_export(module, + output: str, + d_format: DeploymentFormat, + input_example=None, + output_example=None): + """Exports Neural Module instance for deployment. 
+
+        Args:
+            module: neural module to export
+            output (str): where export results should be saved
+            d_format (DeploymentFormat): which deployment format to use
+            input_example: sometimes tracing will require input examples
+            output_example: should match the output of inference on
+                input_example
+        """
+        PtActions.__module_export(
+            module=module,
+            output=output,
+            d_format=d_format,
+            input_example=input_example,
+            output_example=output_example)
+
     def train(self,
               tensors_to_optimize,
               optimizer=None,
@@ -1188,10 +1315,10 @@ def train(self,
                                 f" ({world_size})."
                             )
                         process_group = create_syncbn_process_group(
-                                synced_batchnorm_groupsize)
+                            synced_batchnorm_groupsize)
                         pmodule = convert_syncbn(
-                                pmodule,
-                                process_group=process_group)
+                            pmodule,
+                            process_group=process_group)
 
                 self.module_reference_table[key] = (
                     self.module_reference_table[key][0], pmodule
diff --git a/nemo/nemo/backends/pytorch/nm.py b/nemo/nemo/backends/pytorch/nm.py
index 24b3a13e95d7..8f0d376e7413 100644
--- a/nemo/nemo/backends/pytorch/nm.py
+++ b/nemo/nemo/backends/pytorch/nm.py
@@ -43,12 +43,14 @@ def __call__(self, force_pt=False, *input, **kwargs):
         else:
             return NeuralModule.__call__(self, **kwargs)
 
+    @t.jit.ignore()
     def get_weights(self):
         result = dict()
         for name, parameter in self.named_parameters():
             result[name] = (parameter, parameter.requires_grad)
         return result
 
+    @t.jit.ignore()
     def set_weights(self, name2weight, name2name_and_transform=None):
         if name2weight is not None and len(name2weight) > 0:
             if name2name_and_transform is None:
@@ -60,6 +62,7 @@ def set_weights(self, name2weight, name2name_and_transform=None):
                     {key: name2weight[key][0] for key in name2weight.keys()}
                 )
 
+    @t.jit.ignore()
     def tie_weights_with(self, module, weight_names,
                          name2name_and_transform=None):
         if module is None:
@@ -93,10 +96,12 @@ def tie_weights_with(self, module, weight_names,
             else:
                 rsetattr(self, self_w_name, rgetattr(module, self_w_name))
 
+    @t.jit.ignore()
     def save_to(self, path):
         # t.save(self._pt_module.state_dict(), path)
         t.save(self.state_dict(), path)
 
+    @t.jit.ignore()
     def restore_from(self, path, local_rank=0):
         # self._pt_module.load_state_dict(t.load(path))
         if self.placement == DeviceType.AllGpu:
@@ -105,6 +110,7 @@ def restore_from(self, path, local_rank=0):
             load_device = self._device
         self.load_state_dict(t.load(path, map_location=load_device))
 
+    @t.jit.ignore()
     def freeze(self, weights=None):
         if hasattr(self, "_pt_module"):
             for name, param in self._pt_module.named_parameters():
@@ -115,6 +121,7 @@ def freeze(self, weights=None):
             if weights is None or name in weights:
                 param.requires_grad = False
 
+    @t.jit.ignore()
     def unfreeze(self, weights=None):
         if hasattr(self, "_pt_module"):
             for name, param in self._pt_module.named_parameters():
diff --git a/nemo/nemo/core/neural_factory.py b/nemo/nemo/core/neural_factory.py
index 6b39bdebfc1e..d1e1ac2c23d0 100644
--- a/nemo/nemo/core/neural_factory.py
+++ b/nemo/nemo/core/neural_factory.py
@@ -4,7 +4,8 @@
            'Optimization',
            'DeviceType',
            'Actions',
-           'NeuralModuleFactory']
+           'NeuralModuleFactory',
+           'DeploymentFormat']
 
 from abc import ABC, abstractmethod
 import random
@@ -18,6 +19,14 @@
 from ..utils import ExpManager
 
 
+class DeploymentFormat(Enum):
+    """Which format to use when exporting a Neural Module for deployment"""
+    AUTO = 0
+    PYTORCH = 1
+    TORCHSCRIPT = 2
+    ONNX = 3
+
+
 class Backend(Enum):
     """Supported backends.
     For now, it is only PyTorch."""
@@ -579,6 +588,29 @@ def eval(self,
             optimization_params={'num_epochs': 1}
         )
 
+    def deployment_export(self,
+                          module,
+                          output: str,
+                          d_format: DeploymentFormat,
+                          input_example=None,
+                          output_example=None):
+        """Exports Neural Module instance for deployment.
+
+        Args:
+            module: neural module to export
+            output (str): where export results should be saved
+            d_format (DeploymentFormat): which deployment format to use
+            input_example: sometimes tracing will require input examples
+            output_example: should match the output of inference on
+                input_example
+        """
+        return self._trainer.deployment_export(
+            module=module,
+            output=output,
+            d_format=d_format,
+            input_example=input_example,
+            output_example=output_example
+        )
+
     def infer(self,
               tensors: List[NmTensor],
               checkpoint_dir=None,
diff --git a/nemo/nemo/core/neural_modules.py b/nemo/nemo/core/neural_modules.py
index 9b784911ea1b..5f67637a6bc5 100644
--- a/nemo/nemo/core/neural_modules.py
+++ b/nemo/nemo/core/neural_modules.py
@@ -23,7 +23,6 @@
 
 class WeightShareTransform(Enum):
     """When sharing parameters, what kind of transform to apply."""
-
     SAME = 0
     TRANSPOSE = 1
diff --git a/nemo/setup.py b/nemo/setup.py
index 446cac6750f0..f6976489f592 100644
--- a/nemo/setup.py
+++ b/nemo/setup.py
@@ -23,7 +23,8 @@
         'torchvision',
         'tensorboardX',
         'pandas',
-        'wget'
+        'wget',
+        'onnx'
     ]
 )
diff --git a/tests/common_setup.py b/tests/common_setup.py
index 9ac8caa0cec1..8e2a7c27dc2a 100644
--- a/tests/common_setup.py
+++ b/tests/common_setup.py
@@ -2,6 +2,8 @@
 import unittest
 import os
 from .context import nemo
+# from nemo.backends.pytorch.nm import TrainableNM
+import torch.nn as nn
 
 
 class NeMoUnitTest(unittest.TestCase):
diff --git a/tests/test_deploy_export.py b/tests/test_deploy_export.py
new file mode 100644
index 000000000000..ade5781de972
--- /dev/null
+++ b/tests/test_deploy_export.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2019 NVIDIA Corporation
+import os
+from pathlib import Path
+
+import torch
+from ruamel.yaml import YAML
+
+from .common_setup import NeMoUnitTest
+from .context import nemo, nemo_asr, nemo_nlp
+
+
+class TestDeployExport(NeMoUnitTest):
+    def setUp(self) -> None:
+        self.nf = nemo.core.NeuralModuleFactory(
+            placement=nemo.core.DeviceType.GPU)
+
+    def __test_export_route(self, module, out_name, mode,
+                            input_example=None):
+        out = Path(out_name)
+        if out.exists():
+            os.remove(out)
+
+        self.nf.deployment_export(
+            module=module,
+            output=out_name,
+            input_example=input_example,
+            d_format=mode)
+
+        self.assertTrue(out.exists())
+        if out.exists():
+            os.remove(out)
+
+    def test_simple_module_export(self):
+        simplest_module = \
+            nemo.backends.pytorch.tutorials.TaylorNet(dim=4, factory=self.nf)
+        self.__test_export_route(module=simplest_module,
+                                 out_name="simple.pt",
+                                 mode=nemo.core.DeploymentFormat.TORCHSCRIPT,
+                                 input_example=None)
+
+    def test_simple_module_onnx_export(self):
+        simplest_module = \
+            nemo.backends.pytorch.tutorials.TaylorNet(dim=4, factory=self.nf)
+        self.__test_export_route(module=simplest_module,
+                                 out_name="simple.onnx",
+                                 mode=nemo.core.DeploymentFormat.ONNX,
+                                 input_example=torch.randn(16, 1).cuda())
+
+    def test_TokenClassifier_module_export(self):
+        t_class = nemo_nlp.TokenClassifier(hidden_size=512, num_classes=16,
+                                           use_transformer_pretrained=False)
+        self.__test_export_route(module=t_class,
+                                 out_name="t_class.pt",
+                                 mode=nemo.core.DeploymentFormat.TORCHSCRIPT,
+                                 input_example=torch.randn(16, 16, 512).cuda())
+
+    def test_TokenClassifier_module_onnx_export(self):
+        t_class = nemo_nlp.TokenClassifier(hidden_size=512, num_classes=16,
+                                           use_transformer_pretrained=False)
+        self.__test_export_route(module=t_class,
+                                 out_name="t_class.onnx",
+                                 mode=nemo.core.DeploymentFormat.ONNX,
+                                 input_example=torch.randn(16, 16, 512).cuda())
+
+    def test_jasper_decoder_export_ts(self):
+        j_decoder = nemo_asr.JasperDecoderForCTC(feat_in=1024,
+                                                 num_classes=33)
+        self.__test_export_route(module=j_decoder,
+                                 out_name="j_decoder.ts",
+                                 mode=nemo.core.DeploymentFormat.TORCHSCRIPT,
+                                 input_example=None)
+
+    def test_hf_bert_ts(self):
+        bert = nemo_nlp.huggingface.BERT(
+            pretrained_model_name="bert-base-uncased")
+        input_example = (torch.randint(low=0, high=16, size=(2, 16)).cuda(),
+                         torch.randint(low=0, high=1, size=(2, 16)).cuda(),
+                         torch.randint(low=0, high=1, size=(2, 16)).cuda())
+        self.__test_export_route(module=bert,
+                                 out_name="bert.ts",
+                                 mode=nemo.core.DeploymentFormat.TORCHSCRIPT,
+                                 input_example=input_example)
+
+    def test_hf_bert_pt(self):
+        bert = nemo_nlp.huggingface.BERT(
+            pretrained_model_name="bert-base-uncased")
+        self.__test_export_route(module=bert,
+                                 out_name="bert.pt",
+                                 mode=nemo.core.DeploymentFormat.PYTORCH)
+
+    def test_jasper_encoder_to_onnx(self):
+        with open("tests/data/jasper_smaller.yaml") as file:
+            yaml = YAML(typ="safe")
+            jasper_model_definition = yaml.load(file)
+
+        jasper_encoder = nemo_asr.JasperEncoder(
+            conv_mask=False,
+            feat_in=jasper_model_definition[
+                'AudioToMelSpectrogramPreprocessor']['features'],
+            **jasper_model_definition['JasperEncoder']
+        )
+
+        self.__test_export_route(module=jasper_encoder,
+                                 out_name="jasper_encoder.onnx",
+                                 mode=nemo.core.DeploymentFormat.ONNX,
+                                 input_example=(
+                                     torch.randn(16, 64, 256).cuda(),
+                                     torch.randn(256).cuda()))
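
Usage note (not part of the patch): a minimal sketch of how the deployment_export API introduced in this diff is driven end to end, mirroring tests/test_deploy_export.py. The TaylorNet module and the output file names below are illustrative choices, and a CUDA device is assumed, as in the tests.

    import torch
    import nemo

    # Build a factory and a small module to export.
    nf = nemo.core.NeuralModuleFactory(placement=nemo.core.DeviceType.GPU)
    model = nemo.backends.pytorch.tutorials.TaylorNet(dim=4, factory=nf)

    # TorchScript route: with input_example=None the export goes through
    # torch.jit.script; providing an example switches it to torch.jit.trace.
    nf.deployment_export(module=model,
                         output="taylor_net.pt",
                         d_format=nemo.core.DeploymentFormat.TORCHSCRIPT,
                         input_example=None)

    # ONNX route: an input_example is required because the export is
    # trace-based (see the ValueError raised in __module_export above).
    nf.deployment_export(module=model,
                         output="taylor_net.onnx",
                         d_format=nemo.core.DeploymentFormat.ONNX,
                         input_example=torch.randn(16, 1).cuda())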