Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 27, 2020
1 parent 98edd25 commit 81892db
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion allennlp/models/heads/classifier_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ClassifierHead(Head):
vocab : `Vocabulary`
Used to get the number of labels, if `num_labels` is not provided, and to translate label
indices to strings in `make_output_human_readable`.
seq2vec_encoder : `Seq2VecEncoder`, optional (default = `ClsPooler`)
seq2vec_encoder : `Seq2VecEncoder`
The input to this module is assumed to be a sequence of encoded vectors. We use a
`Seq2VecEncoder` to compress this into a single vector on which we can perform
classification.
Expand Down
Binary file not shown.
5 changes: 3 additions & 2 deletions tests/models/multitask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from allennlp.models.heads import ClassifierHead
from allennlp.models import MultiTaskModel
from allennlp.modules.backbones import PretrainedTransformerBackbone
from allennlp.modules.seq2vec_encoders import ClsPooler


class TestMultiTaskModel(ModelTestCase):
Expand All @@ -16,8 +17,8 @@ def test_forward_works(self):
transformer_name = "epwalsh/bert-xsmall-dummy"
vocab = Vocabulary()
backbone = PretrainedTransformerBackbone(vocab, transformer_name)
head1 = ClassifierHead(vocab, input_dim=20, num_labels=3)
head2 = ClassifierHead(vocab, input_dim=20, num_labels=4)
head1 = ClassifierHead(vocab, seq2vec_encoder=ClsPooler(20), input_dim=20, num_labels=3)
head2 = ClassifierHead(vocab, seq2vec_encoder=ClsPooler(20), input_dim=20, num_labels=4)
# We'll start with one head, and add another later.
model = MultiTaskModel(vocab, backbone, {"cls": head1})

Expand Down

0 comments on commit 81892db

Please sign in to comment.