Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
adding multilabel option (#4843)
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB committed Dec 5, 2020
1 parent 7887119 commit 52fdd75
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
1 change: 1 addition & 0 deletions allennlp/models/vilbert_vqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
fusion_method,
dropout,
label_namespace,
is_multilabel=True,
)

self.loss = torch.nn.BCELoss()
Expand Down
11 changes: 10 additions & 1 deletion allennlp/models/vision_text_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class VisionTextModel(Model):
fusion_method : `str`, optional (default = `"sum"`)
dropout : `float`, optional (default = `0.1`)
label_namespace : `str`, optional (default = `"labels"`)
is_multilabel: `bool`, optional (default = `False`)
Whether the output classification is multilabel.
(i.e., can have multiple correct answers)
"""

def __init__(
Expand All @@ -50,6 +53,7 @@ def __init__(
fusion_method: str = "sum",
dropout: float = 0.1,
label_namespace: str = "labels",
is_multilabel: bool = False,
) -> None:

super().__init__(vocab)
Expand All @@ -69,6 +73,8 @@ def __init__(
self.classifier = torch.nn.Linear(pooled_output_dim, num_labels)
self.dropout = torch.nn.Dropout(dropout)

self.is_multilabel = is_multilabel

@classmethod
def from_huggingface_model_name(
cls,
Expand Down Expand Up @@ -230,7 +236,10 @@ def forward(
raise ValueError(f"Fusion method '{self.fusion_method}' not supported")

logits = self.classifier(pooled_output)
probs = torch.sigmoid(logits)
if self.is_multilabel:
probs = torch.sigmoid(logits)
else:
probs = torch.softmax(logits, dim=-1)

outputs = {"logits": logits, "probs": probs}
outputs = self._compute_loss_and_metrics(batch_size, outputs, label, label_weights)
Expand Down
1 change: 1 addition & 0 deletions allennlp/models/visual_entailment.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __init__(
fusion_method,
dropout,
label_namespace,
is_multilabel=False,
)

self.accuracy = CategoricalAccuracy()
Expand Down

0 comments on commit 52fdd75

Please sign in to comment.