Commit

Merge branch 'develop'
amaiya committed Sep 3, 2020
2 parents 27b54c2 + 0fddb11 commit 2bd430e
Showing 4 changed files with 47 additions and 16 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,20 @@ Most recent releases are shown at the top. Each release shows:
- **Changed**: Additional parameters, changes to inputs or outputs, etc
- **Fixed**: Bug fixes that don't change documented behaviour

## 0.21.1 (2020-09-03)

### New:
- N/A

### Changed:
- added `num_beams` and `early_stopping` arguments to the `translate` methods in the `translation` module; these can be set to improve translation speed
- added `half` parameter to the `Translator` constructor

### Fixed:
- N/A
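
A minimal usage sketch of the 0.21.1 additions (the import path, model name, and input text are illustrative assumptions, not taken from this commit):

```python
from ktrain.text.translation import Translator  # import path assumed

# half=True (new in 0.21.1) casts the MarianMT weights to FP16; intended for GPU inference
translator = Translator(model_name='Helsinki-NLP/opus-mt-de-en', device='cuda', half=True)

# num_beams and early_stopping are forwarded to the underlying beam search
print(translator.translate('Das ist ein kleiner Test.', num_beams=4, early_stopping=True))
```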



## 0.21.0 (2020-09-03)

### New:
42 changes: 32 additions & 10 deletions ktrain/text/translation/core.py
@@ -9,13 +9,14 @@ class Translator():
Translator: basic wrapper around MarianMT model for language translation
"""

def __init__(self, model_name=None, device=None):
def __init__(self, model_name=None, device=None, half=False):
"""
basic wrapper around MarianMT model for language translation
Args:
model_name(str): Helsinki-NLP model
device(str): device to use (e.g., 'cuda', 'cpu')
half(bool): If True, use half precision.
"""
if 'Helsinki-NLP' not in model_name:
raise ValueError('Translator requires a Helsinki-NLP model: https://huggingface.co/Helsinki-NLP')
@@ -28,11 +29,13 @@ def __init__(self, model_name=None, device=None):
from transformers import MarianMTModel, MarianTokenizer
self.tokenizer = MarianTokenizer.from_pretrained(model_name)
self.model = MarianMTModel.from_pretrained(model_name).to(self.torch_device)
if half: self.model = self.model.half()


def translate(self, src_text, join_with='\n'):
def translate(self, src_text, join_with='\n', num_beams=None, early_stopping=None):
"""
translate sentence using model_name as model
Translate document (src_text).
To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
Args:
src_text(str): source text.
The source text can either be a single sentence or an entire document with multiple sentences
@@ -43,29 +46,41 @@ def translate(self, src_text, join_with='\n'):
feed each chunk separately into translate to avoid out-of-memory issues.
join_with(str): list of translated sentences will be delimited with this character.
default: each sentence on separate line
num_beams(int): Number of beams for beam search. Defaults to None. If None, the transformers library defaults this to 1,
which means no beam search.
early_stopping(bool): Whether to stop the beam search when at least ``num_beams`` sentences
are finished per batch or not. Defaults to None. If None, the transformers library
sets this to False.
Returns:
str: translated text
"""
sentences = TU.sent_tokenize(src_text)
tgt_sentences = self.translate_sentences(sentences)
tgt_sentences = self.translate_sentences(sentences, num_beams=num_beams, early_stopping=early_stopping)
return join_with.join(tgt_sentences)


def translate_sentences(self, sentences):
def translate_sentences(self, sentences, num_beams=None, early_stopping=None):
"""
translate sentence using model_name as model
Translate sentences using model_name as model.
To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
Args:
sentences(list): list of strings representing sentences that need to be translated
IMPORTANT NOTE: Sentences are joined together and fed to model as single batch.
If the input text is very large (e.g., an entire book), you should
break it up into reasonably-sized chunks (e.g., pages, paragraphs, or sentences) and
feed each chunk separately into translate to avoid out-of-memory issues.
num_beams(int): Number of beams for beam search. Defaults to None. If None, the transformers library defaults this to 1,
which means no beam search.
early_stopping(bool): Whether to stop the beam search when at least ``num_beams`` sentences
are finished per batch or not. Defaults to None. If None, the transformers library
sets this to False.
Returns:
list: translated sentences
"""
import torch
with torch.no_grad():
translated = self.model.generate(**self.tokenizer.prepare_seq2seq_batch(sentences).to(self.torch_device))
translated = self.model.generate(**self.tokenizer.prepare_seq2seq_batch(sentences).to(self.torch_device),
num_beams=num_beams, early_stopping=early_stopping)
tgt_sentences = [self.tokenizer.decode(t, skip_special_tokens=True) for t in translated]
return tgt_sentences
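
For reference, a hedged sketch of how the batched method above might be called (import path, model name, and sentences are illustrative assumptions):

```python
from ktrain.text.translation import Translator  # import path assumed

translator = Translator(model_name='Helsinki-NLP/opus-mt-de-en')

# the whole list is tokenized as a single batch; num_beams/early_stopping are
# passed straight through to transformers' generate()
sentences = ['Guten Morgen.', 'Wie geht es Ihnen heute?']
translations = translator.translate_sentences(sentences, num_beams=4, early_stopping=True)
for src, tgt in zip(sentences, translations):
    print(src, '->', tgt)
```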

@@ -116,9 +131,11 @@ def __init__(self, src_lang=None, device=None):
raise ValueError('lang:%s is currently not supported.' % (src_lang))


def translate(self, src_text, join_with='\n'):
def translate(self, src_text, join_with='\n', num_beams=None, early_stopping=None):
"""
translate source sentence to English.
Translate source document to English.
To speed up translations, you can set num_beams and early_stopping (e.g., num_beams=4, early_stopping=True).
Args:
src_text(str): source text. Must be in language specified by src_lang (language code) supplied to constructor
The source text can either be a single sentence or an entire document with multiple sentences
@@ -129,12 +146,17 @@ def translate(self, src_text, join_with='\n'):
feed each chunk separately into translate to avoid out-of-memory issues.
join_with(str): list of translated sentences will be delimited with this character.
default: each sentence on separate line
num_beams(int): Number of beams for beam search. Defaults to None. If None, the transformers library defaults this to 1,
which means no beam search.
early_stopping(bool): Whether to stop the beam search when at least ``num_beams`` sentences
are finished per batch or not. Defaults to None. If None, the transformers library
sets this to False.
Returns:
str: translated text
"""
text = src_text
for t in self.translators:
text = t.translate(text, join_with=join_with)
text = t.translate(text, join_with=join_with, num_beams=num_beams, early_stopping=early_stopping)
return text
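
A hedged sketch of the pass-through in the English-translation wrapper above (the class name EnglishTranslator, the import path, and support for src_lang='de' are assumptions based on the ktrain codebase, not shown in this diff):

```python
from ktrain.text.translation import EnglishTranslator  # class name and import path assumed

# the wrapper chains one or more Translator objects (self.translators) and now
# forwards num_beams/early_stopping to each underlying translate() call
en = EnglishTranslator(src_lang='de')
print(en.translate('Das Wetter ist heute schön.', num_beams=4, early_stopping=True))
```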


2 changes: 1 addition & 1 deletion ktrain/version.py
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.21.0'
__version__ = '0.21.1'
5 changes: 0 additions & 5 deletions setup.py
@@ -1,8 +1,5 @@
import sys
if sys.version_info.major != 3: raise Exception('ktrain requires Python 3')
tf_version_str = 'tensorflow==2.1.0'
if sys.version_info.minor == 8:
tf_version_str = 'tensorflow>=2.2.0'

from distutils.core import setup
import setuptools
@@ -28,8 +25,6 @@
url = 'https://github.com/amaiya/ktrain',
keywords = ['tensorflow', 'keras', 'deep learning', 'machine learning'],
install_requires=[
#tf_version_str, # removed TensorFlow dependency in v0.20.3
#'tensorflow_datasets',
#'scipy==1.4.1', # removed due to https://github.com/tensorflow/tensorflow/commit/78026d6a66f7f0fc
#'pillow'
'scikit-learn>=0.21.3', # previously pinned to 0.21.3 due to TextPredictor.explain, but no longer needed as of 0.19.7
