Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Apr 9, 2020
2 parents 6223f4d + ff3c618 commit 3c7115e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.13.2 (2020-04-09)

### New:
- N/A

### Changed
- `TransformerSummarizer` accepts BART `model_name` as parameter


### Fixed:
- N/A



## 0.13.1 including 0.13.0 (2020-04-09)

Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ learner.fit(0.01, 1, cycle_len=5)

*ktrain* is a lightweight wrapper for the deep learning library [TensorFlow Keras](https://www.tensorflow.org/guide/keras/overview) (and other libraries) to help build, train, and deploy neural networks. With only a few lines of code, ktrain allows you to easily and quickly:

- estimate an optimal learning rate for your model given your data using a Learning Rate Finder
- utilize learning rate schedules such as the [triangular policy](https://arxiv.org/abs/1506.01186), the [1cycle policy](https://arxiv.org/abs/1803.09820), and [SGDR](https://arxiv.org/abs/1608.03983) to effectively minimize loss and improve generalization
- employ fast and easy-to-use pre-canned models for `text`, `vision`, and `graph` data:
- employ fast, accurate, and easy-to-use pre-canned models for `text`, `vision`, and `graph` data:
- `text` data:
- **Text Classification**: [BERT](https://arxiv.org/abs/1810.04805), [DistilBERT](https://arxiv.org/abs/1910.01108), [NBSVM](https://www.aclweb.org/anthology/P12-2018), [fastText](https://arxiv.org/abs/1607.01759), and other models <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/IMDb-BERT.ipynb)]</sup></sub>
- **Text Regression**: [BERT](https://arxiv.org/abs/1810.04805), [DistilBERT](https://arxiv.org/abs/1910.01108), Embedding-based linear text regression, [fastText](https://arxiv.org/abs/1607.01759), and other models <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/text_regression_example.ipynb)]</sup></sub>
Expand All @@ -58,6 +56,8 @@ learner.fit(0.01, 1, cycle_len=5)
- `graph` data:
- **node classification** with graph neural networks ([GraphSAGE](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)) <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/graphs/pubmed_node_classification-GraphSAGE.ipynb)]</sup></sub>
- **link prediction** with graph neural networks ([GraphSAGE](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)) <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/graphs/cora_link_prediction-GraphSAGE.ipynb)]</sup></sub>
- estimate an optimal learning rate for your model given your data using a Learning Rate Finder
- utilize learning rate schedules such as the [triangular policy](https://arxiv.org/abs/1506.01186), the [1cycle policy](https://arxiv.org/abs/1803.09820), and [SGDR](https://arxiv.org/abs/1608.03983) to effectively minimize loss and improve generalization
- build text classifiers for any language (e.g., [Chinese Sentiment Analysis with BERT](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/ChineseHotelReviews-BERT.ipynb), [Arabic Sentiment Analysis with NBSVM](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/ArabicHotelReviews-nbsvm.ipynb))
- easily train NER models for any language (e.g., [Dutch NER](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/CoNLL2002_Dutch-BiLSTM.ipynb) )
- load and preprocess text and image data from a variety of formats
Expand Down
17 changes: 13 additions & 4 deletions ktrain/text/summarization/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,27 @@

class TransformerSummarizer():
"""
NER preprocessing base class
interface to Transformer-based text summarization
"""

def __init__(self):
def __init__(self, model_name='bart-large-cnn'):
"""
interface to BART-based text summarization using transformers library
Args:
model_name(str): name of BART model
"""
if model_name.split('-')[0] != 'bart':
raise ValueError('TransformerSummarizer currently only accepts BART models')
try:
import torch
except ImportError:
raise Exception('TransformerSummarizer requires PyTorch to be installed.')
from transformers import BartTokenizer, BartForConditionalGeneration
self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
self.model = BartForConditionalGeneration.from_pretrained('bart-large-cnn').to(self.torch_device)
self.tokenizer = BartTokenizer.from_pretrained(model_name)
self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.torch_device)


def summarize(self, doc):
"""
Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ['__version__']
__version__ = '0.13.1'
__version__ = '0.13.2'

0 comments on commit 3c7115e

Please sign in to comment.