Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fix Invalid Index Reference for labels in Vocabulary (#2926)
Browse files: browse the repository at this point in the history
* Fix bug with an invalid index reference when the labels dict is empty

* Simplify the if-then to a get() call and make the same change in basic_classifier

* Changes per PR comments

* Update bert_for_classification.py

* Update basic_classifier.py
  • Loading branch information
sbhaktha authored and joelgrus committed Jun 7, 2019
1 parent c629093 commit 5b2066b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
7 changes: 5 additions & 2 deletions allennlp/models/basic_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ def __init__(self,
else:
self._dropout = None

self._label_namespace = label_namespace

if num_labels:
self._num_labels = num_labels
else:
self._num_labels = vocab.get_vocab_size(namespace=label_namespace)
self._num_labels = vocab.get_vocab_size(namespace=self._label_namespace)
self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels)
self._accuracy = CategoricalAccuracy()
self._loss = torch.nn.CrossEntropyLoss()
Expand Down Expand Up @@ -139,7 +141,8 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
classes = []
for prediction in predictions_list:
label_idx = prediction.argmax(dim=-1).item()
label_str = self.vocab.get_token_from_index(label_idx, namespace="labels")
label_str = (self.vocab.get_index_to_token_vocabulary(self._label_namespace)
.get(label_idx, str(label_idx)))
classes.append(label_str)
output_dict["label"] = classes
return output_dict
Expand Down
7 changes: 5 additions & 2 deletions allennlp/models/bert_for_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ def __init__(self,

in_features = self.bert_model.config.hidden_size

self._label_namespace = label_namespace

if num_labels:
out_features = num_labels
else:
out_features = vocab.get_vocab_size(label_namespace)
out_features = vocab.get_vocab_size(namespace=self._label_namespace)

self._dropout = torch.nn.Dropout(p=dropout)

Expand Down Expand Up @@ -139,7 +141,8 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
classes = []
for prediction in predictions_list:
label_idx = prediction.argmax(dim=-1).item()
label_str = self.vocab.get_token_from_index(label_idx, namespace="labels")
label_str = (self.vocab.get_index_to_token_vocabulary(self._label_namespace)
.get(label_idx, str(label_idx)))
classes.append(label_str)
output_dict["label"] = classes
return output_dict
Expand Down

0 comments on commit 5b2066b

Please sign in to comment.