Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Jun 25, 2020
2 parents 2cdc795 + 8bc1766 commit 5b00b5a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.17.1 (2020-06-24)

### New:
- N/A

### Changed
- N/A

### Fixed:
- Properly set device in `text.Translator` and use cuda when available


## 0.17.0 (2020-06-24)

### New:
Expand Down
4 changes: 2 additions & 2 deletions ktrain/text/translation/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, model_name=None, device=None):
if self.torch_device is None: self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
from transformers import MarianMTModel, MarianTokenizer
self.tokenizer = MarianTokenizer.from_pretrained(model_name)
self.model = MarianMTModel.from_pretrained(model_name)
self.model = MarianMTModel.from_pretrained(model_name).to(self.torch_device)


def translate(self, src_text, join_with='\n'):
Expand All @@ -48,7 +48,7 @@ def translate(self, src_text, join_with='\n'):
"""
# tokenize text into sentences:
sentences = TU.sent_tokenize(src_text)
translated = self.model.generate(**self.tokenizer.prepare_translation_batch(sentences))
translated = self.model.generate(**self.tokenizer.prepare_translation_batch(sentences).to(self.torch_device))
tgt_sentences = [self.tokenizer.decode(t, skip_special_tokens=True) for t in translated]
return join_with.join(tgt_sentences)

Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.17.0'
__version__ = '0.17.1'

0 comments on commit 5b00b5a

Please sign in to comment.