
Commit

Merge branch 'develop'
amaiya committed Jul 29, 2020
2 parents 16b843e + 4535e3f commit bfbf85f
Showing 3 changed files with 35 additions and 23 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,18 @@ Most recent releases are shown at the top. Each release shows:
- **Changed**: Additional parameters, changes to inputs or outputs, etc
- **Fixed**: Bug fixes that don't change documented behaviour

## 0.19.1 (2020-07-29)

### New:
- N/A

### Changed:
- Adjusted `no_grad` scope in `ZeroShotClassifier.predict`

### Fixed:
- N/A


## 0.19.0 (2020-07-29)

### New:
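The 0.19.1 entry above is a one-line change: `ZeroShotClassifier.predict` now runs its whole inference loop under `torch.no_grad()` rather than only the model forward pass (see the `core.py` diff below). As a reminder of what that context manager does, here is a minimal plain-PyTorch sketch (not ktrain code):

import torch

x = torch.randn(2, 3, requires_grad=True)

# Operations executed under no_grad() are not recorded by autograd,
# so no computation graph is built and less memory is held during
# pure inference such as ZeroShotClassifier.predict.
with torch.no_grad():
    y = (x * 2).softmax(dim=1)

print(y.requires_grad)  # False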
44 changes: 22 additions & 22 deletions ktrain/text/zsl/core.py
@@ -48,27 +48,27 @@ def predict(self, doc, topic_strings=[], include_labels=False, max_length=512, b
             inferred probabilities
         """
         import torch
-        if topic_strings is None or len(topic_strings) == 0:
-            raise ValueError('topic_strings must be a list of strings')
-        if batch_size > len(topic_strings): batch_size = len(topic_strings)
-        topic_chunks = list(U.list2chunks(topic_strings, n=math.ceil(len(topic_strings)/batch_size)))
-        if len(topic_strings) >= 100 and batch_size==8:
-            warnings.warn('TIP: Try increasing batch_size to speedup ZeroShotClassifier predictions')
-        result = []
-        for topics in topic_chunks:
-            pairs = []
-            for topic_string in topics:
-                premise = doc
-                hypothesis = 'This text is about %s.' % (topic_string)
-                pairs.append( (premise, hypothesis) )
-            batch = self.tokenizer.batch_encode_plus(pairs, return_tensors='pt', max_length=max_length, truncation='only_first', padding=True).to(self.torch_device)
-            with torch.no_grad():
+        with torch.no_grad():
+            if topic_strings is None or len(topic_strings) == 0:
+                raise ValueError('topic_strings must be a list of strings')
+            if batch_size > len(topic_strings): batch_size = len(topic_strings)
+            topic_chunks = list(U.list2chunks(topic_strings, n=math.ceil(len(topic_strings)/batch_size)))
+            if len(topic_strings) >= 100 and batch_size==8:
+                warnings.warn('TIP: Try increasing batch_size to speedup ZeroShotClassifier predictions')
+            result = []
+            for topics in topic_chunks:
+                pairs = []
+                for topic_string in topics:
+                    premise = doc
+                    hypothesis = 'This text is about %s.' % (topic_string)
+                    pairs.append( (premise, hypothesis) )
+                batch = self.tokenizer.batch_encode_plus(pairs, return_tensors='pt', max_length=max_length, truncation='only_first', padding=True).to(self.torch_device)
                 logits = self.model(batch['input_ids'], attention_mask=batch['attention_mask'])[0]
-            entail_contradiction_logits = logits[:,[0,2]]
-            probs = entail_contradiction_logits.softmax(dim=1)
-            true_probs = list(probs[:,1].cpu().detach().numpy())
-            if include_labels:
-                true_probs = list(zip(topics, true_probs))
-            result.extend(true_probs)
-        return result
+                entail_contradiction_logits = logits[:,[0,2]]
+                probs = entail_contradiction_logits.softmax(dim=1)
+                true_probs = list(probs[:,1].cpu().detach().numpy())
+                if include_labels:
+                    true_probs = list(zip(topics, true_probs))
+                result.extend(true_probs)
+            return result

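For orientation, a hypothetical usage sketch of the method touched by this diff. The parameter names (`topic_strings`, `include_labels`, `batch_size`) come from the signature shown above; the import path is assumed to mirror the module location `ktrain/text/zsl/core.py`, and the constructor's default NLI model may vary across ktrain versions:

# Usage sketch only: assumes ktrain is installed and the default
# zero-shot NLI model can be downloaded on first use.
from ktrain.text.zsl import ZeroShotClassifier

zsl = ZeroShotClassifier()
doc = 'I am wondering whether I should switch from TensorFlow to PyTorch.'
preds = zsl.predict(doc,
                    topic_strings=['machine learning', 'politics', 'sports'],
                    include_labels=True,
                    batch_size=8)
# With include_labels=True, each entry pairs a topic with its
# inferred probability, e.g. ('machine learning', 0.97).
print(preds)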
2 changes: 1 addition & 1 deletion ktrain/version.py
@@ -1,2 +1,2 @@
 __all__ = ['__version__']
-__version__ = '0.19.0'
+__version__ = '0.19.1'
