Skip to content

Commit

Permalink
Add basic grpc MT server
Browse files Browse the repository at this point in the history
Add readme, server updates

Signed-off-by: Ryan Leary <rleary@nvidia.com>
  • Loading branch information
ryanleary committed Feb 26, 2021
1 parent db5a787 commit 619df56
Show file tree
Hide file tree
Showing 6 changed files with 470 additions and 0 deletions.
11 changes: 11 additions & 0 deletions examples/nlp/machine_translation/nmt_server/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# NMT Server Getting Started
1. Install latest grpc (pip install grpc grpcio-tools)
2. Create a models directory and copy nemo models there. Note that models should be named `<source>-<target>.nemo`
where `<source>` and `<target>` are both two letter language codes, e.g. `en-es.nemo`.
3. Start the server, explicitly loading each model:
```
python server.py --model models/en-es.nemo --model models/es-en.nemo --model models/En-Ja.nemo
```

## Notes
Port can be overridden with `--port` flag. Default is 50052. Beam decoder parameters can also be set at server start time. See `--help` for more details.
199 changes: 199 additions & 0 deletions examples/nlp/machine_translation/nmt_server/api/nmt_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

72 changes: 72 additions & 0 deletions examples/nlp/machine_translation/nmt_server/api/nmt_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import nmt_pb2 as nmt__pb2


class JarvisTranslateStub(object):
"""Jarvis NLP Services implement task-specific APIs for popular NLP tasks including
intent recognition (as well as slot filling), and entity extraction.
"""

def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.TranslateText = channel.unary_unary(
'/nvidia.jarvis.nmt.JarvisTranslate/TranslateText',
request_serializer=nmt__pb2.TranslateTextRequest.SerializeToString,
response_deserializer=nmt__pb2.TranslateTextResponse.FromString,
)


class JarvisTranslateServicer(object):
"""Jarvis NLP Services implement task-specific APIs for popular NLP tasks including
intent recognition (as well as slot filling), and entity extraction.
"""

def TranslateText(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_JarvisTranslateServicer_to_server(servicer, server):
rpc_method_handlers = {
'TranslateText': grpc.unary_unary_rpc_method_handler(
servicer.TranslateText,
request_deserializer=nmt__pb2.TranslateTextRequest.FromString,
response_serializer=nmt__pb2.TranslateTextResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'nvidia.jarvis.nmt.JarvisTranslate', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class JarvisTranslate(object):
"""Jarvis NLP Services implement task-specific APIs for popular NLP tasks including
intent recognition (as well as slot filling), and entity extraction.
"""

@staticmethod
def TranslateText(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/nvidia.jarvis.nmt.JarvisTranslate/TranslateText',
nmt__pb2.TranslateTextRequest.SerializeToString,
nmt__pb2.TranslateTextResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
47 changes: 47 additions & 0 deletions examples/nlp/machine_translation/nmt_server/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

from concurrent import futures
from time import time
import math
import logging

import grpc
import argparse
import api.nmt_pb2 as nmt
import api.nmt_pb2_grpc as nmtsrv

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--target", default="es", type=str)
parser.add_argument("--source", default="en", type=str)
parser.add_argument("--text", default="", type=str)
parser.add_argument("--port", default=50052, type=int, required=False)

args = parser.parse_args()
return args

if __name__ == '__main__':
args = get_args()
with grpc.insecure_channel(f'localhost:{args.port}') as channel:
stub = nmtsrv.JarvisTranslateStub(channel)

iterations = 1
start_time = time()
for _ in range(iterations):
req = nmt.TranslateTextRequest(texts=[args.text], source_language=args.source, target_language=args.target)
result = stub.TranslateText(req)
end_time = time()
print(f"Time to complete {iterations} synchronous requests: {end_time-start_time}")
print(result)
print(result.translations[0].translation)

# iterations = 1
# start_time = time()
# futures = []
# for _ in range(iterations):
# req = nmt.TranslateTextRequest(texts=["Hello, can you hear me?"])
# futures.append(stub.TranslateText.future(req))
# for f in futures:
# f.result()
# end_time = time()
# print(f"Time to complete {iterations} asynchronous requests: {end_time-start_time}")
# print(futures[0].result())
39 changes: 39 additions & 0 deletions examples/nlp/machine_translation/nmt_server/nmt.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

syntax = "proto3";

package nvidia.jarvis.nmt;

option cc_enable_arenas = true;
option go_package = "nvidia.com/jarvis_speech";

// Jarvis NLP Services implement task-specific APIs for popular NLP tasks including
// intent recognition (as well as slot filling), and entity extraction.
service JarvisTranslate {
rpc TranslateText(TranslateTextRequest) returns (TranslateTextResponse) {}

}

message TranslateTextRequest {
repeated string texts = 1;

string source_language = 3;
string target_language = 4;
}

message Translation {
string translation = 1;

string language = 2;
}

message TranslateTextResponse {
repeated Translation translations = 1;

}
Loading

0 comments on commit 619df56

Please sign in to comment.