Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Jun 7, 2020
2 parents 7ff72ec + 2c12611 commit c9d4a66
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 8 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.16.2 (2020-06-07)

### New:
- N/A

### Changed
- default model for summarization changed to `facebook/bart-large-cnn` due to breaking change in v2.11
- added `device` argument to `TransformerSummarizer` constructor to control PyTorch device

### Fixed:
- require `transformers>=2.11.0` due to breaking changes in 2.11 related to `bart` models


## 0.16.1 (2020-06-05)

### New:
Expand Down
2 changes: 1 addition & 1 deletion FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ smaller sample sizes (e.g., 500, 1000) may be sufficient for your use case.
Examples include:

- **medical informatics:** analyzing doctors' written analyses of patients and medical imagery
- **finance:** analyzing financial and stock-related news stories
- **finance:** financial crime analytics, mining stock-related news stories
- **insurance:** detecting fraud in insurance claims
- **social science:** making sense of text-based responses in surveys and emotion-classification from text data
- **linguistics:** detecting sarcasm in the news
Expand Down
2 changes: 1 addition & 1 deletion ktrain/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,7 @@ def load_predictor(fpath, batch_size=U.DEFAULT_BS):
#warnings.warn('could not load .preproc file as %s - attempting to load as %s' % (os.path.join(fpath, U.PREPROC_NAME), preproc_name))
with open(preproc_name, 'rb') as f: preproc = pickle.load(f)
except:
raise Exception('Could not find a .preproc file in either the post v0.16.x loction (%s) or pre v0.16.x location (%s)' % (os.path.join(fpath. U.PRERPC_NAME), fpath+'.preproc'))
raise Exception('Could not find a .preproc file in either the post v0.16.x loction (%s) or pre v0.16.x location (%s)' % (os.path.join(fpath. U.PREPROC_NAME), fpath+'.preproc'))

# load the model
model = _load_model(fpath, preproc=preproc)
Expand Down
10 changes: 6 additions & 4 deletions ktrain/text/summarization/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,23 @@ class TransformerSummarizer():
interface to Transformer-based text summarization
"""

def __init__(self, model_name='bart-large-cnn'):
def __init__(self, model_name='facebook/bart-large-cnn', device=None):
"""
interface to BART-based text summarization using transformers library
Args:
model_name(str): name of BART model
model_name(str): name of BART model for summarization
device(str): device to use (e.g., 'cuda', 'cpu')
"""
if model_name.split('-')[0] != 'bart':
if 'bart' not in model_name:
raise ValueError('TransformerSummarizer currently only accepts BART models')
try:
import torch
except ImportError:
raise Exception('TransformerSummarizer requires PyTorch to be installed.')
self.torch_device = device
if self.torch_device is None: self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
from transformers import BartTokenizer, BartForConditionalGeneration
self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.tokenizer = BartTokenizer.from_pretrained(model_name)
self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.torch_device)

Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.16.1'
__version__ = '0.16.2'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'seqeval',
'packaging',
'tensorflow_datasets',
'transformers>=2.7.0',
'transformers>=2.11.0',
'ipython',
'syntok',
'whoosh'
Expand Down

0 comments on commit c9d4a66

Please sign in to comment.