# visual_entailment_head.py
# Provenance: extracted from an archived repository (read-only since Dec 16, 2022).
from typing import Dict, Optional
import torch
from overrides import overrides
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.heads.head import Head
@Head.register("visual_entailment")
class VisualEntailmentHead(Head):
    """
    A classification head for visual entailment.

    Takes the fused image/text representation (`pooled_boxes_and_text`) produced by a
    vision-language backbone and classifies it into the labels of `label_namespace`
    with a single linear layer. Tracks categorical accuracy and macro-averaged F1.

    # Parameters

    vocab : `Vocabulary`
        Used to look up the number of labels in `label_namespace`.
    embedding_dim : `int`
        Size of the pooled multimodal representation fed to the classifier.
    label_namespace : `str`, optional (default = `"labels"`)
        Vocabulary namespace holding the entailment labels.
    """

    def __init__(self, vocab: Vocabulary, embedding_dim: int, label_namespace: str = "labels"):
        super().__init__(vocab)
        num_labels = vocab.get_vocab_size(label_namespace)
        self.label_namespace = label_namespace
        # Single linear projection from the pooled multimodal embedding to label logits.
        self.classifier = torch.nn.Linear(embedding_dim, num_labels)

        # Imported locally (matching the original) to keep the module importable
        # even when metric extras are unavailable at module-load time.
        from allennlp.training.metrics import CategoricalAccuracy, FBetaMeasure

        self.accuracy = CategoricalAccuracy()
        self.fbeta = FBetaMeasure(beta=1.0, average="macro")

    @overrides
    def forward(
        self,  # type: ignore
        encoded_boxes: torch.Tensor,
        encoded_boxes_mask: torch.Tensor,
        encoded_boxes_pooled: torch.Tensor,
        encoded_text: torch.Tensor,
        encoded_text_mask: torch.Tensor,
        encoded_text_pooled: torch.Tensor,
        pooled_boxes_and_text: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        label_weights: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Classify the pooled multimodal representation.

        Only `pooled_boxes_and_text` is used; the other encoder outputs are accepted
        to satisfy the shared head interface. Returns a dict with `logits` and
        `probs`, plus `loss` when `labels` is given.

        Raises `NotImplementedError` if `label_weights` is provided, since per-label
        weighting is not supported by this head.
        """
        if label_weights is not None:
            # Was `assert label_weights is None` — asserts are stripped under -O,
            # so raise explicitly instead of silently ignoring the weights.
            raise NotImplementedError("label_weights are not supported by this head")

        logits = self.classifier(pooled_boxes_and_text)
        probs = torch.softmax(logits, dim=-1)
        output = {"logits": logits, "probs": probs}
        if labels is not None:
            # Bug fix: `F.cross_entropy` already averages over the batch
            # (reduction="mean"); the original divided by `logits.size(0)` a second
            # time, shrinking the loss (and gradients) by an extra factor of
            # batch size.
            output["loss"] = torch.nn.functional.cross_entropy(logits, labels)
            self.accuracy(logits, labels)
            self.fbeta(probs, labels)
        return output

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return macro precision/recall/fscore from `FBetaMeasure` plus `acc`."""
        result = self.fbeta.get_metric(reset)
        result["acc"] = self.accuracy.get_metric(reset)
        return result