Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/NVIDIA/NeMo into add_confor…
Browse files Browse the repository at this point in the history
…mer3
  • Loading branch information
VahidooX committed Oct 22, 2020
2 parents 9420578 + 53b31f7 commit ef7a236
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Expand Up @@ -61,7 +61,7 @@ pipeline {
}
}
steps {
sh 'pytest -m "unit and not pleasefixme" --cpu'
sh 'CUDA_VISIBLE_DEVICES="" pytest -m "unit and not pleasefixme" --cpu'
}
}

Expand Down
60 changes: 59 additions & 1 deletion nemo/collections/asr/models/label_models.py
Expand Up @@ -18,6 +18,7 @@
import pickle as pkl
from typing import Dict, List, Optional, Union

import onnx
import torch
from omegaconf import DictConfig
from omegaconf.omegaconf import open_dict
Expand All @@ -31,13 +32,15 @@
from nemo.collections.common.metrics import TopKClassificationAccuracy
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.neural_types import *
from nemo.utils import logging
from nemo.utils.export_utils import attach_onnx_to_onnx

__all__ = ['EncDecSpeakerLabelModel', 'ExtractSpeakerEmbeddingsModel']


class EncDecSpeakerLabelModel(ModelPT):
class EncDecSpeakerLabelModel(ModelPT, Exportable):
"""Encoder decoder class for speaker label models.
Model class creates training, validation methods for setting up data
performing model forward pass.
Expand Down Expand Up @@ -309,6 +312,61 @@ def setup_finetune_model(self, model_config: DictConfig):

logging.info(f"Changed decoder output to # {self.decoder._num_classes} classes.")

def export(
self,
output: str,
input_example=None,
output_example=None,
verbose=False,
export_params=True,
do_constant_folding=True,
keep_initializers_as_inputs=False,
onnx_opset_version: int = 12,
try_script: bool = False,
set_eval: bool = True,
check_trace: bool = True,
use_dynamic_axes: bool = True,
):
if input_example is not None or output_example is not None:
logging.warning(
"Passed input and output examples will be ignored and recomputed since"
" EncDecSpeakerModel consists of two separate models (encoder and decoder) with different"
" inputs and outputs."
)

encoder_onnx = self.encoder.export(
os.path.join(os.path.dirname(output), 'encoder_' + os.path.basename(output)),
None, # computed by input_example()
None,
verbose,
export_params,
do_constant_folding,
keep_initializers_as_inputs,
onnx_opset_version,
try_script,
set_eval,
check_trace,
use_dynamic_axes,
)

decoder_onnx = self.decoder.export(
os.path.join(os.path.dirname(output), 'decoder_' + os.path.basename(output)),
None, # computed by input_example()
None,
verbose,
export_params,
do_constant_folding,
keep_initializers_as_inputs,
onnx_opset_version,
try_script,
set_eval,
check_trace,
use_dynamic_axes,
)

output_model = attach_onnx_to_onnx(encoder_onnx, decoder_onnx, "SL")
onnx.save(output_model, output)


class ExtractSpeakerEmbeddingsModel(EncDecSpeakerLabelModel):
"""
Expand Down
2 changes: 1 addition & 1 deletion nemo/core/classes/modelPT.py
Expand Up @@ -88,7 +88,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._trainer = trainer

# Set device_id in AppState
if torch.cuda.current_device() is not None:
if torch.cuda.is_available() and torch.cuda.current_device() is not None:
app_state = AppState()
app_state.device_id = torch.cuda.current_device()

Expand Down
21 changes: 13 additions & 8 deletions nemo/utils/export_utils.py
Expand Up @@ -191,25 +191,30 @@ def replace_for_export(model: nn.Module, replace_1D_2D: bool = False) -> nn.Modu

def attach_onnx_to_onnx(model1: onnx.ModelProto, model2: onnx.ModelProto, prefix2: str):

if len(model1.graph.output) < 1 or len(model1.graph.output) != len(model2.graph.output):
if len(model1.graph.output) < 1 or len(model1.graph.output) != len(model2.graph.input):
raise ValueError(
'Incompatible input/output dimensions: {} != {}'.format(len(model1.graph.output), len(model2.graph.output))
'Incompatible input/output dimensions: {} != {}'.format(len(model1.graph.output), len(model2.graph.input))
)
for i in range(len(model2.graph.initializer)):
model2.graph.initializer[i].name = prefix2 + model2.graph.initializer[i].name
for i in range(len(model2.graph.node)):
model2.graph.node[i].name = prefix2 + model2.graph.node[i].name

for o in range(len(model1.graph.output)):
for i in range(len(model2.graph.node)):
for j in range(len(model2.graph.node[i].input)):
for i in range(len(model2.graph.node)):
for j in range(len(model2.graph.node[i].input)):
for o in range(len(model1.graph.output)):
if model2.graph.node[i].input[j] == model2.graph.input[o].name:
model2.graph.node[i].input[j] = model1.graph.output[o].name
else:
model2.graph.node[i].input[j] = prefix2 + model2.graph.node[i].input[j]
for j in range(len(model2.graph.node[i].output)):
if model2.graph.node[i].output[j] != model2.graph.output[o].name:
model2.graph.node[i].output[j] = prefix2 + model2.graph.node[i].output[j]
for j in range(len(model2.graph.node[i].output)):
inner_output = True
for p in range(len(model2.graph.output)):
if model2.graph.node[i].output[j] == model2.graph.output[p].name:
inner_output = False
break
if inner_output:
model2.graph.node[i].output[j] = prefix2 + model2.graph.node[i].output[j]

graph = onnx.GraphProto()
graph.node.extend(model1.graph.node)
Expand Down
52 changes: 51 additions & 1 deletion tests/collections/asr/test_asr_exportables.py
Expand Up @@ -18,7 +18,7 @@
import pytest
from omegaconf import DictConfig, ListConfig

from nemo.collections.asr.models import EncDecClassificationModel, EncDecCTCModel
from nemo.collections.asr.models import EncDecClassificationModel, EncDecCTCModel, EncDecSpeakerLabelModel
from nemo.collections.asr.modules import ConvASRDecoder, ConvASREncoder


Expand Down Expand Up @@ -84,6 +84,20 @@ def test_EncDecClassificationModel_export_to_onnx(self, speech_classification_mo
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logits'

def test_EncDecSpeakerLabelModel_export_to_onnx(self, speaker_label_model):
model = speaker_label_model.train()
with tempfile.TemporaryDirectory() as tmpdir:
filename = os.path.join(tmpdir, 'sl.onnx')
model.export(output=filename)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert len(onnx_model.graph.node) == 31
assert onnx_model.graph.node[0].name == 'Conv_0'
assert onnx_model.graph.node[12].name == 'SLConstant_9'
assert onnx_model.graph.node[30].name == 'SLGemm_27'
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logits'

def setup_method(self):
self.preprocessor = {
'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor',
Expand Down Expand Up @@ -193,3 +207,39 @@ def speech_classification_model():
)
model = EncDecClassificationModel(cfg=modelConfig)
return model


@pytest.fixture()
def speaker_label_model():
preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})}
encoder = {
'cls': 'nemo.collections.asr.modules.ConvASREncoder',
'params': {
'feat_in': 64,
'activation': 'relu',
'conv_mask': True,
'jasper': [
{
'filters': 512,
'repeat': 1,
'kernel': [1],
'stride': [1],
'dilation': [1],
'dropout': 0.0,
'residual': False,
'separable': False,
}
],
},
}

decoder = {
'cls': 'nemo.collections.asr.modules.SpeakerDecoder',
'params': {'feat_in': 512, 'num_classes': 2, 'pool_mode': 'xvector', 'emb_sizes': [1024]},
}

modelConfig = DictConfig(
{'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder)}
)
speaker_model = EncDecSpeakerLabelModel(cfg=modelConfig)
return speaker_model

0 comments on commit ef7a236

Please sign in to comment.