Visual entailment model code (#4822)
* VE model code

* adding VE model

* misc minor updates

* update changelog
AkshitaB committed Dec 1, 2020
1 parent 01f3a2d commit 52e9dd9
Showing 11 changed files with 609 additions and 183 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -32,6 +32,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
- Added abstraction and demo implementation for an image augmentation module.
- Added abstraction and concrete implementation for region detectors.
- Transformer toolkit to plug and play with modular components of transformer architectures.
- `VisionReader` and `VisionTextModel` base classes added. `VisualEntailment` and `VQA` inherit from these.

### Changed

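The changelog entry above describes the refactor the rest of this commit carries out: task models now share a common vision-text base class. A minimal sketch of that pattern, using only the import path and the `_compute_loss_and_metrics` hook signature visible in the `vilbert_vqa.py` diff below; the class name and the cross-entropy loss are illustrative, not part of the commit, and `outputs` is typed as a dict here because the hook below indexes `outputs["logits"]`:

from typing import Dict, Optional

import torch

from allennlp.models.vision_text_model import VisionTextModel


class MyVisionTextTask(VisionTextModel):
    """Hypothetical single-label head: the embedding, bi-modal encoding,
    pooling, and fusion all come from the inherited forward pass."""

    def _compute_loss_and_metrics(
        self,
        batch_size: int,
        outputs: Dict[str, torch.Tensor],
        label: torch.Tensor,
        label_weights: Optional[torch.Tensor] = None,
    ):
        # `outputs["logits"]` is produced by the shared forward pass; we assume
        # `label` holds integer class indices for this hypothetical task.
        if label is not None:
            outputs["loss"] = torch.nn.functional.cross_entropy(outputs["logits"], label)
        return outputs
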
1 change: 0 additions & 1 deletion allennlp/data/dataset_readers/vision_reader.py
@@ -157,7 +157,6 @@ def _process_image_paths(
def yield_batch():
# process the images
paths = list(unprocessed_paths)
print(len(paths))
images, sizes = self.image_loader(paths)
with torch.no_grad():
images = images.to(self.cuda_device)
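
The only change in this file removes a stray debug `print(len(paths))` call (0 additions, 1 deletion). The surrounding `yield_batch` code shows the reader's preprocessing pattern: gather image paths, load them as one batch, and push the tensors to the target device under `torch.no_grad()`. A standalone sketch of that pattern; `image_loader` mirrors the reader's `self.image_loader`, while `featurize_in_batches`, `featurizer`, and the batch size are illustrative names, not part of AllenNLP:

from typing import Callable, Iterable, List, Tuple

import torch


def featurize_in_batches(
    paths: Iterable[str],
    image_loader: Callable[[List[str]], Tuple[torch.Tensor, torch.Tensor]],
    featurizer: torch.nn.Module,
    device: torch.device,
    batch_size: int = 16,
) -> List[torch.Tensor]:
    features: List[torch.Tensor] = []
    batch: List[str] = []

    def flush() -> None:
        if not batch:
            return
        images, _sizes = image_loader(batch)
        with torch.no_grad():  # preprocessing only; keep autograd out of it
            images = images.to(device)
            features.append(featurizer(images).cpu())
        batch.clear()

    for path in paths:
        batch.append(path)
        if len(batch) == batch_size:
            flush()
    flush()  # don't drop the final partial batch
    return features
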
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -9,3 +9,4 @@
from allennlp.models.multitask import MultiTaskModel
from allennlp.models.simple_tagger import SimpleTagger
from allennlp.models.vilbert_vqa import VqaVilbert
from allennlp.models.visual_entailment import VisualEntailmentModel
215 changes: 38 additions & 177 deletions allennlp/models/vilbert_vqa.py
@@ -1,7 +1,6 @@
import collections
import logging
from copy import deepcopy
from typing import Dict, List, Optional
from typing import Dict, Optional

from overrides import overrides
import torch
@@ -12,24 +11,31 @@
TextEmbeddings,
ImageFeatureEmbeddings,
BiModalEncoder,
TransformerPooler,
)
from allennlp.nn import util

from transformers.modeling_auto import AutoModel
from allennlp.models.vision_text_model import VisionTextModel


logger = logging.getLogger(__name__)


@Model.register("vqa_vilbert")
@Model.register("vqa_vilbert_from_huggingface", constructor="from_huggingface_model_name")
class VqaVilbert(Model):
class VqaVilbert(VisionTextModel):
"""
Model for VQA task based on the VilBERT paper.
# Parameters
vocab : `Vocabulary`
text_embeddings : `TextEmbeddings`
image_embeddings : `ImageFeatureEmbeddings`
encoder : `BiModalEncoder`
pooled_output_dim : `int`
fusion_method : `str`, optional (default = `"sum"`)
dropout : `float`, optional (default = `0.1`)
label_namespace : `str`, optional (default = `answers`)
"""

def __init__(
@@ -43,7 +49,17 @@ def __init__(
dropout: float = 0.1,
label_namespace: str = "answers",
) -> None:
super().__init__(vocab)
super().__init__(
vocab,
text_embeddings,
image_embeddings,
encoder,
pooled_output_dim,
fusion_method,
dropout,
label_namespace,
)

self.loss = torch.nn.BCELoss()
self.consistency_wrong_map: Dict[str, int] = collections.Counter()
from allennlp.training.metrics import F1MultiLabelMeasure
@@ -52,114 +68,6 @@ def __init__(
from allennlp.training.metrics.vqa import VqaMeasure

self.vqa_metric = VqaMeasure()
self.fusion_method = fusion_method

self.embeddings = text_embeddings
self.image_embeddings = image_embeddings
self.encoder = encoder

self.t_pooler = TransformerPooler(encoder.hidden_size1, pooled_output_dim)
self.v_pooler = TransformerPooler(encoder.hidden_size2, pooled_output_dim)

num_labels = vocab.get_vocab_size(label_namespace)
self.label_namespace = label_namespace

self.classifier = torch.nn.Linear(pooled_output_dim, num_labels)
self.dropout = torch.nn.Dropout(dropout)

@classmethod
def from_huggingface_model_name(
cls,
vocab: Vocabulary,
model_name: str,
image_feature_dim: int,
image_num_hidden_layers: int,
image_hidden_size: int,
image_num_attention_heads: int,
image_intermediate_size: int,
image_attention_dropout: float,
image_hidden_dropout: float,
image_biattention_id: List[int],
image_fixed_layer: int,
text_biattention_id: List[int],
text_fixed_layer: int,
combined_hidden_size: int,
combined_num_attention_heads: int,
pooled_output_dim: int,
pooled_dropout: float = 0.1,
fusion_method: str = "sum",
):
transformer = AutoModel.from_pretrained(model_name)

# TODO(mattg): This call to `transformer.embeddings` works with some transformers, but I'm
# not sure it works for all of them, or what to do if it fails.
# We should probably pull everything up until the instantiation of the image feature
# embedding out into a central "transformers_util" module, or something, and just have a
# method that pulls an initialized embedding layer out of a huggingface model. One place
# for this somewhat hacky code to live, instead of having to duplicate it in various models.
text_embeddings = deepcopy(transformer.embeddings)

# Albert (and maybe others?) has this "embedding_size", that's different from "hidden_size".
# To get them to the same dimensionality, it uses a linear transform after the embedding
# layer, which we need to pull out and copy here.
if hasattr(transformer.config, "embedding_size"):
config = transformer.config

from transformers.modeling_albert import AlbertModel

if isinstance(transformer, AlbertModel):
linear_transform = deepcopy(transformer.encoder.embedding_hidden_mapping_in)
else:
logger.warning(
"Unknown model that uses separate embedding size; weights of the linear "
f"transform will not be initialized. Model type is: {transformer.__class__}"
)
linear_transform = torch.nn.Linear(config.embedding_dim, config.hidden_dim)

# We can't just use torch.nn.Sequential here, even though that's basically all this is,
# because Sequential doesn't accept *inputs, only a single argument.

class EmbeddingsShim(torch.nn.Module):
def __init__(self, embeddings: torch.nn.Module, linear_transform: torch.nn.Module):
super().__init__()
self.linear_transform = linear_transform
self.embeddings = embeddings

def forward(self, *inputs, **kwargs):
return self.linear_transform(self.embeddings(*inputs, **kwargs))

text_embeddings = EmbeddingsShim(text_embeddings, linear_transform)

image_embeddings = ImageFeatureEmbeddings(
feature_dim=image_feature_dim,
hidden_dim=image_hidden_size,
dropout=image_hidden_dropout,
)

encoder = BiModalEncoder.from_pretrained_module(
pretrained_module=transformer,
num_hidden_layers2=image_num_hidden_layers,
hidden_size2=image_hidden_size,
num_attention_heads2=image_num_attention_heads,
combined_hidden_size=combined_hidden_size,
combined_num_attention_heads=combined_num_attention_heads,
intermediate_size2=image_intermediate_size,
attention_dropout2=image_attention_dropout,
hidden_dropout2=image_hidden_dropout,
biattention_id1=text_biattention_id,
biattention_id2=image_biattention_id,
fixed_layer1=text_fixed_layer,
fixed_layer2=image_fixed_layer,
)
return cls(
vocab=vocab,
text_embeddings=text_embeddings,
image_embeddings=image_embeddings,
encoder=encoder,
pooled_output_dim=pooled_output_dim,
fusion_method=fusion_method,
dropout=pooled_dropout,
)

@overrides
def forward(
@@ -171,73 +79,25 @@ def forward(
label_weights: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:

batch_size, _, feature_size = box_features.size()

# TODO(mattg): have this make fewer assumptions.
input_ids = question["tokens"]["token_ids"]
token_type_ids = question["tokens"]["type_ids"]
attention_mask = question["tokens"]["mask"]

# All batch instances will always have the same number of images and boxes, so no masking
# is necessary, and this is just a tensor of ones.
image_attention_mask = torch.ones_like(box_coordinates[:, :, 0])

# (batch_size, num_tokens, embedding_dim)
embedding_output = self.embeddings(input_ids, token_type_ids)
num_tokens = embedding_output.size(1)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of
# causal attention used in OpenAI GPT, we just need to prepare the
# broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).float().log()
extended_image_attention_mask = image_attention_mask.unsqueeze(1).unsqueeze(2).float().log()

# TODO(matt): it looks like the co-attention logic is all currently commented out; not sure
# that this is necessary.
extended_co_attention_mask = torch.zeros(
batch_size,
feature_size,
num_tokens,
dtype=extended_image_attention_mask.dtype,
return super().forward(
box_features, box_coordinates, text=question, label=labels, label_weights=label_weights
)

# (batch_size, num_boxes, image_embedding_dim)
v_embedding_output = self.image_embeddings(box_features, box_coordinates)

encoded_layers_t, encoded_layers_v = self.encoder(
embedding_output,
v_embedding_output,
extended_attention_mask,
extended_image_attention_mask,
extended_co_attention_mask,
)

sequence_output_t = encoded_layers_t[:, :, :, -1]
sequence_output_v = encoded_layers_v[:, :, :, -1]

pooled_output_t = self.t_pooler(sequence_output_t)
pooled_output_v = self.v_pooler(sequence_output_v)

if self.fusion_method == "sum":
pooled_output = self.dropout(pooled_output_t + pooled_output_v)
elif self.fusion_method == "mul":
pooled_output = self.dropout(pooled_output_t * pooled_output_v)
else:
raise ValueError(f"Fusion method '{self.fusion_method}' not supported")

logits = self.classifier(pooled_output)
probs = torch.sigmoid(logits)

outputs = {"logits": logits, "probs": probs}
if labels is not None and label_weights is not None:
label_mask = labels > 1 # 0 is padding, 1 is OOV, which we want to ignore
@overrides
def _compute_loss_and_metrics(
self,
batch_size: int,
outputs: torch.Tensor,
label: torch.Tensor,
label_weights: Optional[torch.Tensor] = None,
):
if label is not None and label_weights is not None:
logits = outputs["logits"]
label_mask = label > 1 # 0 is padding, 1 is OOV, which we want to ignore

weighted_labels = util.masked_index_replace(
logits.new_zeros(logits.size() + (1,)),
labels.clamp(min=0),
label.clamp(min=0),
label_mask,
label_weights.unsqueeze(-1),
).squeeze(-1)
@@ -258,7 +118,8 @@ def forward(
)

self.f1_metric(logits, weighted_labels, binary_label_mask.bool())
self.vqa_metric(logits, labels, label_weights)
self.vqa_metric(logits, label, label_weights)

return outputs

@overrides
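
One detail of the removed `forward` body worth noting: the 2D padding mask is turned into an additive attention mask with `.float().log()`, which maps 1 to 0 and 0 to -inf so it can simply be added to the raw attention scores before the softmax. A tiny self-contained check of that trick (the tensor shapes are illustrative):

import torch

# (batch_size, seq_len) padding mask; the last token is padding.
mask = torch.tensor([[1, 1, 1, 0]])

# (batch_size, 1, 1, seq_len): broadcasts over heads and query positions.
extended_mask = mask.unsqueeze(1).unsqueeze(2).float().log()

scores = torch.randn(1, 2, 4, 4)  # (batch_size, num_heads, from_seq, to_seq)
probs = torch.softmax(scores + extended_mask, dim=-1)

# Padded key positions receive exactly zero attention weight.
assert torch.all(probs[..., -1] == 0)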
