Commit

add summarize arguments
amaiya committed Jun 2, 2023
1 parent fe29089 commit e5aadf1
Showing 2 changed files with 29 additions and 7 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
```diff
@@ -6,6 +6,18 @@ Most recent releases are shown at the top. Each release shows:
 - **Changed**: Additional parameters, changes to inputs or outputs, etc
 - **Fixed**: Bug fixes that don't change documented behaviour
 
+## 0.37.1 (TBD)
+
+### new:
+- Supply arguments to `generate` in `TransformerSummarizer.summarize`
+
+### changed:
+- N/A
+
+### fixed:
+- N/A
+
+
 ## 0.37.0 (2023-05-11)
 
 ### new:
```
24 changes: 17 additions & 7 deletions ktrain/text/summarization/core.py
```diff
@@ -26,10 +26,19 @@ def __init__(self, model_name="facebook/bart-large-cnn", device=None):
             self.torch_device
         )
 
-    def summarize(self, doc, **kwargs):
+    def summarize(
+        self,
+        doc,
+        max_length=150,
+        min_length=56,
+        no_repeat_ngram_size=3,
+        length_penalty=2.0,
+        num_beams=4,
+        **kwargs,
+    ):
         """
         ```
-        summarize document text
+        Summarize document text. Extra arguments are fed to generate method
         Args:
           doc(str): text of document
         Returns:
@@ -44,11 +53,12 @@ def summarize(self, doc, **kwargs):
         )["input_ids"].to(self.torch_device)
         summary_ids = self.model.generate(
             answers_input_ids,
-            num_beams=4,
-            length_penalty=2.0,
-            max_length=142,
-            min_length=56,
-            no_repeat_ngram_size=3,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            max_length=max_length,
+            min_length=min_length,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            **kwargs,
         )
 
         exec_sum = self.tokenizer.decode(
```
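The refactor above follows a standard Python pattern: hardcoded `generate` options become keyword parameters with the same defaults, and `**kwargs` forwards anything else. A minimal runnable sketch of that pattern, where `Summarizer` and `fake_generate` are hypothetical stand-ins for ktrain's `TransformerSummarizer` and `model.generate` (so no model download is needed):

```python
# Sketch of the argument-forwarding pattern introduced by this commit.
# fake_generate stands in for model.generate; it returns the options it
# received so the forwarding is easy to inspect.
def fake_generate(input_ids, **gen_kwargs):
    return gen_kwargs


class Summarizer:
    # Hypothetical stand-in for ktrain's TransformerSummarizer.
    def summarize(
        self,
        doc,
        max_length=150,
        min_length=56,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        num_beams=4,
        **kwargs,
    ):
        # Named defaults can now be overridden per call, and any extra
        # generation option (e.g. do_sample) still passes through **kwargs.
        return fake_generate(
            doc,
            num_beams=num_beams,
            length_penalty=length_penalty,
            max_length=max_length,
            min_length=min_length,
            no_repeat_ngram_size=no_repeat_ngram_size,
            **kwargs,
        )


opts = Summarizer().summarize("some text", max_length=60, do_sample=True)
print(opts["max_length"], opts["num_beams"], opts["do_sample"])  # 60 4 True
```

Note that before this change, a call such as `summarize(doc, max_length=100)` would have raised a `TypeError` (duplicate `max_length` keyword), because `max_length=142` was already hardcoded inside the `generate` call while `**kwargs` carried the caller's value a second time.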
