# vqa_head.py
from typing import Dict, Optional

import torch
from overrides import overrides

from allennlp.data.vocabulary import Vocabulary
from allennlp.models.heads.head import Head
from allennlp.nn import util

@Head.register("vqa")
class VqaHead(Head):
    def __init__(self, vocab: Vocabulary, embedding_dim: int, label_namespace: str = "answers"):
        from allennlp_models.vision.metrics.vqa import VqaMeasure
        from allennlp.training.metrics import F1MultiLabelMeasure

        super().__init__(vocab)

        num_labels = vocab.get_vocab_size(label_namespace)
        self.label_namespace = label_namespace

        # One logit per answer in the vocabulary; answers are scored
        # independently through a sigmoid, so this is a multi-label classifier.
        self.classifier = torch.nn.Linear(embedding_dim, num_labels)

        self.f1_metric = F1MultiLabelMeasure(average="micro")
        self.vqa_metric = VqaMeasure()
    @overrides
    def forward(
        self,  # type: ignore
        encoded_boxes: torch.Tensor,
        encoded_boxes_mask: torch.Tensor,
        encoded_boxes_pooled: torch.Tensor,
        encoded_text: torch.Tensor,
        encoded_text_mask: torch.Tensor,
        encoded_text_pooled: torch.Tensor,
        pooled_boxes_and_text: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        label_weights: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        # Only the fused (boxes + text) pooled representation is used here;
        # the other encodings are part of the shared head interface.
        logits = self.classifier(pooled_boxes_and_text)
        output = {
            "logits": logits,
            "probs": torch.sigmoid(logits),
        }
        if labels is not None and label_weights is not None:
            label_mask = labels > 1  # 0 is padding, 1 is OOV, which we want to ignore
            # Scatter each answer's weight into a dense (batch_size, num_labels)
            # target tensor; positions for padding/OOV labels stay zero.
            weighted_labels = util.masked_index_replace(
                logits.new_zeros(logits.size() + (1,)),
                labels.clamp(min=0),
                label_mask,
                label_weights.unsqueeze(-1),
            ).squeeze(-1)

            # weighted_labels now has shape (batch_size, num_labels). We need to ignore
            # the first two columns of this in our loss function and accuracy metric:
            # the first column is the padding label, and the second is the OOV label.
            # The loss is computed on every other label.
            binary_label_mask = weighted_labels.new_ones(logits.size())
            binary_label_mask[:, 0] = 0
            binary_label_mask[:, 1] = 0

            output["loss"] = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, weighted_labels, weight=binary_label_mask, reduction="sum"
            ) / logits.size(0)

            self.f1_metric(logits, weighted_labels, binary_label_mask.bool())
            self.vqa_metric(logits, labels, label_weights)

        return output
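
    # Illustrative example of the scatter above (hypothetical numbers, not part
    # of the model): with num_labels = 5 and one instance whose gold answers are
    # labels = [[2, 4]] with label_weights = [[0.9, 0.3]], masked_index_replace
    # yields weighted_labels = [[0.0, 0.0, 0.9, 0.0, 0.3]], and
    # binary_label_mask = [[0, 0, 1, 1, 1]] then drops the padding and OOV
    # columns from the loss.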
    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        result = self.f1_metric.get_metric(reset)
        result["vqa"] = self.vqa_metric.get_metric(reset)["score"]
        return result
    @overrides
    def make_output_human_readable(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        if len(output_dict) <= 0:
            return output_dict
        logits = output_dict["logits"]
        best_answer_index = logits.argmax(-1)
        # Look up answers in the configured namespace rather than a hard-coded
        # "answers", so a custom label_namespace still round-trips correctly.
        best_answer = [
            self.vocab.get_token_from_index(int(i), self.label_namespace)
            for i in best_answer_index
        ]
        output_dict["best_answer"] = best_answer
        return output_dict

    default_predictor = "vilbert_vqa"
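

# A minimal smoke test of the head in isolation (a sketch: the tiny vocabulary,
# shapes, and values below are illustrative assumptions, and running it requires
# allennlp plus allennlp-models to be installed).
if __name__ == "__main__":
    vocab = Vocabulary()
    # Indices 0 and 1 are reserved for padding/OOV, so "yes"/"no" land at 2 and 3.
    vocab.add_tokens_to_namespace(["yes", "no"], namespace="answers")

    head = VqaHead(vocab, embedding_dim=8)
    batch_size = 2
    dummy_seq = torch.zeros(batch_size, 1, 8)  # unused by this head, part of the interface
    dummy_mask = torch.ones(batch_size, 1, dtype=torch.bool)
    output = head(
        encoded_boxes=dummy_seq,
        encoded_boxes_mask=dummy_mask,
        encoded_boxes_pooled=torch.zeros(batch_size, 8),
        encoded_text=dummy_seq,
        encoded_text_mask=dummy_mask,
        encoded_text_pooled=torch.zeros(batch_size, 8),
        pooled_boxes_and_text=torch.randn(batch_size, 8),
        labels=torch.tensor([[2], [3]]),  # "yes", "no"
        label_weights=torch.tensor([[1.0], [0.6]]),
    )
    print(output["probs"].shape, float(output["loss"]))  # torch.Size([2, 4]) and a scalar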