Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nmt service update #1909

Merged
merged 10 commits into from
Mar 24, 2021
6 changes: 2 additions & 4 deletions examples/nlp/machine_translation/nmt_webapp/README.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
**STEPS 2 RUN**
===============
0. (Make sure you have Flask installed - ``pip install flask``)
1. Train NMT model derived from ``nemo.collections.nlp.models.machine_translation.mt_enc_dec_model.MTEncDecModel``
2. Download resulting .nemo file and store it locally at location PATH2NEMO_FILE
3. In ``nmt_service.py`` file set PATH2NEMO_FILE to correct location
0. (Make sure you have Flask installed - ``pip install flask flask-cors``)
1. Edit "config.json" file to only contain models you need. If model's location starts with "NGC/" - it will load this model from NVIDIA's NGC. Otherwise, specify full path to .nemo file.
4. To run: ``python nmt_service.py``
5. To translate: ``http://127.0.0.1:5000/translate?text=Frohe%20Weihnachten`` (here %20 means space)
14 changes: 11 additions & 3 deletions examples/nlp/machine_translation/nmt_webapp/config.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
{
"en-de": "/home/okuchaiev/Workspace/MTModels/DeEn/en-de.nemo",
"de-en": "/home/okuchaiev/Workspace/MTModels/DeEn/de-en.nemo"
}
"en-de": "NGC/nmt_en_de_transformer12x2",
"de-en": "NGC/nmt_de_en_transformer12x2",
"en-es": "NGC/nmt_en_es_transformer12x2",
"es-en": "NGC/nmt_es_en_transformer12x2",
"en-ru": "NGC/nmt_en_ru_transformer6x6",
"ru-en": "NGC/nmt_ru_en_transformer6x6",
"en-fr": "NGC/nmt_en_fr_transformer12x2",
"fr-en": "NGC/nmt_fr_en_transformer12x2",
"en-zh": "NGC/nmt_en_zh_transformer6x6",
"zh-en": "NGC/nmt_zh_en_transformer6x6"
}
56 changes: 51 additions & 5 deletions examples/nlp/machine_translation/nmt_webapp/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
$( "#tgt_text" ).val(result.translation);
},
error: function (err) {
alert('No');
alert('Sorry, can not reach translation service');
},
})
} );
Expand All @@ -51,6 +51,11 @@
<div class="contact-us">
<h1>Neural Machine Translation Demo</h1>
<p>
<br/>
<h4>Running NeMo commit id: r1.0.0rc1 </h4>
<br/>
</p>
<p>
<fieldset>
<legend>Select translation direction: </legend>
<table>
Expand All @@ -61,17 +66,58 @@ <h1>Neural Machine Translation Demo</h1>
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="de-en" selected="selected">German to English</label>
<input type="radio" name="langpair" id="de-en" checked>
<label for="de-en">German to English</label>
<input type="radio" name="langpair" id="de-en">
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="es-en">Spanish to English</label>
<input type="radio" name="langpair" id="es-en">
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="en-es" selected="selected">English to Spanish</label>
<input type="radio" name="langpair" id="en-es" checked>
</td>
</tr>
<tr>
<td>
<label for="en-ru">English to Russian</label>
<input type="radio" name="langpair" id="en-ru">
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="ru-en">Russian to English</label>
<input type="radio" name="langpair" id="ru-en">
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="en-fr">English to French</label>
<input type="radio" name="langpair" id="en-fr">
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="fr-en">French to English</label>
<input type="radio" name="langpair" id="fr-en">
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="en-zh">English to Chinese</label>
<input type="radio" name="langpair" id="en-zh">
</td>
<td>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</td>
<td>
<label for="zh-en">Chinese to English</label>
<input type="radio" name="langpair" id="zh-en">
</td>
</tr>
</table>
</fieldset>
<div>
<table>
<tr>
<td>
<textarea id="src_text" rows="12" cols="60"></textarea>
<textarea id="src_text" rows="12" cols="60">Type something here, chose language pair and click "Translate" button.</textarea>
</td>
</tr>
<tr>
Expand All @@ -91,4 +137,4 @@ <h1>Neural Machine Translation Demo</h1>


</body>
</html>
</html>
12 changes: 9 additions & 3 deletions examples/nlp/machine_translation/nmt_webapp/nmt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
import flask
import torch
from flask import Flask, json, request
from flask_cors import CORS

import nemo.collections.nlp as nemo_nlp
from nemo.utils import logging

PATH2NEMO_FILE = '[PATH TO YOUR NMT MODEL .nemo FILE]'
MODELS_DICT = {}

model = None
api = Flask(__name__)
CORS(api)


def initialize(config_file_path: str):
Expand All @@ -35,6 +36,8 @@ def initialize(config_file_path: str):
__MODELS_DICT = None

logging.info("Starting NMT service")
logging.info(f"I will attempt to load all the models listed in {config_file_path}.")
logging.info(f"Edit {config_file_path} to disable models you don't need.")
if torch.cuda.is_available():
logging.info("CUDA is available. Running on GPU")
else:
Expand All @@ -47,7 +50,10 @@ def initialize(config_file_path: str):
if __MODELS_DICT is not None:
for key, value in __MODELS_DICT.items():
logging.info(f"Loading model for {key} from file: {value}")
model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=value)
if value.startswith("NGC/"):
model = nemo_nlp.models.machine_translation.MTEncDecModel.from_pretrained(model_name=value[4:])
else:
model = nemo_nlp.models.machine_translation.MTEncDecModel.restore_from(restore_path=value)
if torch.cuda.is_available():
model = model.cuda()
MODELS_DICT[key] = model
Expand Down Expand Up @@ -91,4 +97,4 @@ def get_translation():

if __name__ == '__main__':
initialize('config.json')
api.run()
api.run(host='0.0.0.0')