[feat] Add ViLT image and text processors
Add the ViLT image processor and BERT text processor.

ghstack-source-id: bff7e216a41a02c5f36ed0e2cb6cc65edbf67818
Pull Request resolved: facebookresearch#1097
Ryan-Qiyu-Jiang committed Oct 1, 2021
1 parent bc49510 commit 28d44a9
Showing 7 changed files with 177 additions and 6 deletions.
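
For orientation before the per-file diffs: MMF processors are registered by name and constructed from a config. Below is a minimal, hypothetical sketch of instantiating the two new processors directly, assuming the standard registry.get_processor_class lookup; the config keys mirror the ones read in the __init__ methods in the diffs, and the values are illustrative only.

from omegaconf import OmegaConf
from mmf.common.registry import registry
import mmf.datasets.processors  # importing the modules triggers registration

# Fetch the classes under the names they register below.
tokenizer_cls = registry.get_processor_class("vilt_text_tokenizer")
image_proc_cls = registry.get_processor_class("vilt_image_processor")

# Build them from small configs.
tokenizer = tokenizer_cls(OmegaConf.create({"max_seq_length": 128}))
image_processor = image_proc_cls(OmegaConf.create({"size": [384, 384]}))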
3 changes: 3 additions & 0 deletions mmf/common/constants.py
@@ -33,3 +33,6 @@
}

DOWNLOAD_CHUNK_SIZE = 1024 * 1024

IMAGE_COLOR_MEAN = [0.485, 0.456, 0.406]
IMAGE_COLOR_STD = [0.229, 0.224, 0.225]
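
These are the standard ImageNet channel mean and standard deviation used by torchvision's pretrained models; centralizing them as constants lets the processors below share one definition instead of repeating the literals.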
44 changes: 44 additions & 0 deletions mmf/datasets/processors/bert_processors.py
@@ -353,3 +353,47 @@ def __init__(self, config, *args, **kwargs):
self.fusion_strategy = config.get("fusion", "concat")
self.tokenizer = RobertaTokenizer(config, *args, **kwargs)
self._probability = config.get("mask_probability", 0)


@registry.register_processor("vilt_text_tokenizer")
class VILTTextTokenizer(MaskedTokenProcessor):
def __init__(self, config, *args, **kwargs):
from transformers import BertTokenizer

if isinstance(config, str):
config = {"from_pretrained": config}

from_pretrained_name = config.get("from_pretrained", "bert-base-uncased")
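        # Checkpoints with "uncased" in the name use a lowercased vocabulary,
        # so lowercase the inputs to match.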
kwargs_dict = dict(kwargs, do_lower_case="uncased" in from_pretrained_name)
self._tokenizer = BertTokenizer.from_pretrained(
from_pretrained_name, **kwargs_dict
)
self._max_seq_length = config.get("max_seq_length", 25)
self._probability = config.get("mask_probability", 0)

def __call__(self, item):
if "text" in item:
text_a = item["text"]
elif "text_a" in item:
text_a = item["text_a"]
else:
text_a = " ".join(item["tokens"])

if isinstance(text_a, list):
text_a = " ".join(text_a)

tokens_a = self.tokenize(text_a)

# 'text_b' can be defined in the dataset preparation
tokens_b = None
if "text_b" in item:
text_b = item["text_b"]
if text_b:
tokens_b = self.tokenize(text_b)

self._truncate_seq_pair(tokens_a, tokens_b, self._max_seq_length)
output = self._convert_to_indices(
tokens_a, tokens_b, probability=self._probability
)
output["text"] = output["tokens"]
return output
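
A quick usage sketch (assumptions: transformers is installed and can fetch the bert-base-uncased vocab; the output keys follow MaskedTokenProcessor._convert_to_indices, as exercised by the tests below):

tokenizer = VILTTextTokenizer({"max_seq_length": 25})
sample = tokenizer({"text": "a cat sat on the mat"})
# sample["input_ids"], sample["input_mask"], and sample["segment_ids"] are
# LongTensors of length 25: [CLS] tokens [SEP], zero-padded to the right.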
22 changes: 22 additions & 0 deletions mmf/datasets/processors/image_processors.py
@@ -6,10 +6,12 @@
import warnings

import torch
from mmf.common.constants import IMAGE_COLOR_MEAN, IMAGE_COLOR_STD
from mmf.common.registry import registry
from mmf.datasets.processors.processors import BaseProcessor
from omegaconf import OmegaConf
from torchvision import transforms
from torchvision.transforms import Compose, Normalize, Resize, ToTensor


@registry.register_processor("torchvision_transforms")
@@ -162,3 +164,23 @@ def __call__(self, image):

return padded_image
return image


@registry.register_processor("vilt_image_processor")
class VILTImageProcessor(BaseProcessor):
def __init__(self, config, *args, **kwargs):
image_size = config.get("size", [224, 224])
transforms_list = []
transforms_list.append(Resize(image_size))
transforms_list.append(ToTensor())
transforms_list.append(GrayScaleTo3Channels())
transforms_list.append(Normalize(IMAGE_COLOR_MEAN, IMAGE_COLOR_STD))
self.transform = Compose(transforms_list)

def __call__(self, x):
# Support both a sample mapping ({"image": ...}) and a raw image/tensor
if isinstance(x, collections.abc.Mapping):
x = x["image"]
return {"image": self.transform(x)}
else:
return self.transform(x)
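
A matching sketch for the image processor; as the __call__ above shows, it accepts either a raw PIL image/tensor or a sample mapping with an "image" key:

from PIL import Image

processor = VILTImageProcessor({"size": [224, 224]})
image = Image.new("RGB", (500, 300))

tensor_out = processor(image)           # Tensor of shape [3, 224, 224]
dict_out = processor({"image": image})  # {"image": Tensor [3, 224, 224]}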
8 changes: 5 additions & 3 deletions mmf/datasets/processors/processors.py
@@ -79,6 +79,7 @@ def __call__(self, item, *args, **kwargs):

import numpy as np
import torch
from mmf.common.constants import IMAGE_COLOR_MEAN, IMAGE_COLOR_STD
from mmf.common.registry import registry
from mmf.common.typings import ProcessorConfigType
from mmf.utils.configuration import get_mmf_cache_dir, get_mmf_env
@@ -726,7 +727,8 @@ class GraphVQAAnswerProcessor(BaseProcessor):
"answers" or "answers_tokens". "answers" are preprocessed to generate
"answers_tokens" if passed.
This version also takes a graph vocab and predicts a main and graph
stream simultaneously.
Args:
config (DictConfig): Configuration for the processor
@@ -1751,14 +1753,14 @@ def __init__(self, config, *args, **kwargs):
),
),
T.ToTensor(),
            T.Normalize(IMAGE_COLOR_MEAN, IMAGE_COLOR_STD),
]
)
self.inference_transform = T.Compose(
[
T.RandomResize([config.test_image_size], max_size=config.max_size),
T.ToTensor(),
            T.Normalize(IMAGE_COLOR_MEAN, IMAGE_COLOR_STD),
]
)

83 changes: 83 additions & 0 deletions tests/datasets/test_bert_processors.py
@@ -102,3 +102,86 @@ def test_bert_tokenizer(self):

# Test [MASK] token is present
self.assertTrue(103 in results["input_ids"])

def test_vilt_tokenizer(self):
from mmf.datasets.processors.bert_processors import VILTTextTokenizer

test_utils.setup_proxy()
processor = VILTTextTokenizer(self.config)

# Test normal caption
arg = {"text": "This will be a test of tokens?"}
results = processor(arg)
expected_input_ids = torch.zeros(128, dtype=torch.long)
expected_input_ids[:11] = torch.tensor(
[101, 2023, 2097, 2022, 1037, 3231, 1997, 19204, 2015, 1029, 102],
dtype=torch.long,
)
expected_segment_ids = torch.zeros(128, dtype=torch.long)
expected_masks = torch.zeros(128, dtype=torch.long)
expected_masks[:11] = 1
self.assertTrue(torch.equal(results["input_ids"], expected_input_ids))
self.assertTrue(torch.equal(results["segment_ids"], expected_segment_ids))
self.assertTrue(torch.equal(results["input_mask"], expected_masks))

# Test empty caption
arg = {"text": ""}
results = processor(arg)
expected_input_ids = torch.zeros(128, dtype=torch.long)
expected_input_ids[:2] = torch.tensor([101, 102], dtype=torch.long)
expected_segment_ids = torch.zeros(128, dtype=torch.long)
expected_masks = torch.zeros(128, dtype=torch.long)
expected_masks[:2] = 1
self.assertTrue(torch.equal(results["input_ids"], expected_input_ids))
self.assertTrue(torch.equal(results["segment_ids"], expected_segment_ids))
self.assertTrue(torch.equal(results["input_mask"], expected_masks))

# Test long caption
arg = {"text": "I am working for facebook " * 100} # make a long sentence
results = processor(arg)
expected_input_ids = [1045, 2572, 2551, 2005, 9130] * 100
expected_input_ids.insert(0, 101) # [CLS]
expected_input_ids = expected_input_ids[:128]
expected_input_ids[-1] = 102 # [SEP]
expected_input_ids = torch.tensor(expected_input_ids, dtype=torch.long)
expected_segment_ids = torch.zeros(128, dtype=torch.long)
expected_masks = torch.ones(128, dtype=torch.long)
self.assertTrue(torch.equal(results["input_ids"], expected_input_ids))
self.assertTrue(torch.equal(results["segment_ids"], expected_segment_ids))
self.assertTrue(torch.equal(results["input_mask"], expected_masks))

# Test two captions
arg = {
"text_a": "This will be a test of tokens?",
"text_b": "I am working for facebook",
}
results = processor(arg)
expected_input_ids = torch.zeros(128, dtype=torch.long)
expected_input_ids[:17] = torch.tensor(
[101, 2023, 2097, 2022, 1037, 3231, 1997, 19204, 2015, 1029, 102]
+ [1045, 2572, 2551, 2005, 9130, 102],
dtype=torch.long,
)
expected_segment_ids = torch.zeros(128, dtype=torch.long)
expected_segment_ids[11:17] = 1
expected_masks = torch.zeros(128, dtype=torch.long)
expected_masks[:17] = 1
self.assertTrue(torch.equal(results["input_ids"], expected_input_ids))
self.assertTrue(torch.equal(results["segment_ids"], expected_segment_ids))
self.assertTrue(torch.equal(results["input_mask"], expected_masks))

# Test masked caption
processor._probability = 1.0
arg = {"text": "This will be a test of tokens?"}
results = processor(arg)
expected_input_ids = torch.zeros(128, dtype=torch.long)
expected_input_ids[:11] = torch.tensor(
[101, 2023, 2097, 2022, 1037, 3231, 1997, 19204, 2015, 1029, 102],
dtype=torch.long,
)
expected_segment_ids = torch.zeros(128, dtype=torch.long)
self.assertFalse(torch.equal(results["input_ids"], expected_input_ids))
self.assertTrue(torch.equal(results["segment_ids"], expected_segment_ids))

# Test [MASK] token is present
self.assertTrue(103 in results["input_ids"])
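
For reference, the magic numbers asserted in these tests are the special-token ids of the bert-base-uncased vocabulary, which can be verified directly:

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")
print(tok.convert_tokens_to_ids(["[CLS]", "[SEP]", "[MASK]", "[PAD]"]))
# [101, 102, 103, 0]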
18 changes: 18 additions & 0 deletions tests/datasets/test_processors.py
@@ -4,6 +4,7 @@
import unittest

import torch
from mmf.datasets.processors.image_processors import VILTImageProcessor
from mmf.datasets.processors.processors import (
CaptionProcessor,
EvalAIAnswerProcessor,
@@ -173,3 +174,20 @@ def test_multi_class_from_file(self):

self.assertRaises(AssertionError, processor, {"label": "UNK"})
os.unlink(f.name)

def test_vilt_image_processor(self):
from torchvision.transforms import ToPILImage

size = 384
config = OmegaConf.create({"size": [size, size]})
image_processor = VILTImageProcessor(config)

expected_size = torch.Size([3, size, size])

image = ToPILImage()(torch.ones(3, 300, 500))
processed_image = image_processor(image)
self.assertEqual(processed_image.size(), expected_size)

image = ToPILImage()(torch.ones(1, 224, 224))
processed_image = image_processor(image)
self.assertEqual(processed_image.size(), expected_size)
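
The 1-channel case passes because of the GrayScaleTo3Channels transform in the processor's pipeline. Its definition is outside this diff; below is a minimal sketch of the behavior the test relies on (a hypothetical reimplementation, not the file's actual code):

import torch

class GrayScaleTo3Channels:
    # Repeat a single-channel tensor to 3 channels; pass 3-channel input through.
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 3 and x.size(0) == 1:
            x = x.repeat(3, 1, 1)
        return x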
5 changes: 2 additions & 3 deletions tools/scripts/features/extract_resnet152_feat.py
@@ -9,18 +9,17 @@
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from mmf.common.constants import IMAGE_COLOR_MEAN, IMAGE_COLOR_STD
from PIL import Image
from torch.autograd import Variable


TARGET_IMAGE_SIZE = [448, 448]
data_transforms = transforms.Compose(
[
transforms.Resize(TARGET_IMAGE_SIZE),
transforms.ToTensor(),
        transforms.Normalize(IMAGE_COLOR_MEAN, IMAGE_COLOR_STD),
]
)
