
Commit

[feat] Add ViLT image and text processors
Add ViT image processor and BERT text processor.

ghstack-source-id: c7f7fe2b89fe90298909dbf363c396806e859ab8
Pull Request resolved: facebookresearch#1097
Ryan-Qiyu-Jiang committed Sep 22, 2021
1 parent b0cc2f0 commit 2435279
Showing 3 changed files with 55 additions and 0 deletions.
12 changes: 12 additions & 0 deletions mmf/datasets/processors/bert_processors.py
@@ -353,3 +353,15 @@ def __init__(self, config, *args, **kwargs):
        self.fusion_strategy = config.get("fusion", "concat")
        self.tokenizer = RobertaTokenizer(config, *args, **kwargs)
        self._probability = config.get("mask_probability", 0)


@registry.register_processor("vilt_text_processor")
class VILTTextProcessor(BertTokenizer):
def __init__(self, config, *args, **kwargs):
from omegaconf import OmegaConf, open_dict

with open_dict(config):
config.tokenizer_config = OmegaConf.create(
{"type": "bert-base-uncased", "params": {"do_lower_case": True}}
)
super().__init__(config, *args, **kwargs)
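For reference, a minimal usage sketch (not part of this commit), assuming the usual MMF convention of calling text processors with a {"text": ...} sample dict; the max_seq_length value and the example sentence are illustrative assumptions:

    from omegaconf import OmegaConf
    from mmf.datasets.processors.bert_processors import VILTTextProcessor

    config = OmegaConf.create({"max_seq_length": 25})  # assumed config key
    processor = VILTTextProcessor(config)
    # BertTokenizer-style processors are expected to return a dict of
    # tensors such as input_ids and input_mask (based on the base class).
    sample = processor({"text": "a dog playing with a ball"})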
25 changes: 25 additions & 0 deletions mmf/datasets/processors/image_processors.py
@@ -162,3 +162,28 @@ def __call__(self, image):

            return padded_image
        return image


@registry.register_processor("vilt_image_processor")
class VILTImageProcessor(BaseProcessor):
from torchvision.transforms import Resize, ToTensor, Normalize, Compose

def __init__(self, config, *args, **kwargs):
image_size = getattr(config, "size", [224, 224])
transforms_list = []
transforms_list.append(self.Resize(image_size))
transforms_list.append(self.ToTensor())
transforms_list.append(GrayScaleTo3Channels())

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
transforms_list.append(self.Normalize(mean, std))
self.transform = self.Compose(transforms_list)

def __call__(self, x):
# Support both dict and normal mode
if isinstance(x, collections.abc.Mapping):
x = x["image"]
return {"image": self.transform(x)}
else:
return self.transform(x)
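A corresponding sketch for the image processor (again, not part of this commit), grounded in the __call__ shown above: it accepts either a raw PIL image or tensor, or an {"image": ...} sample dict, and returns the transformed tensor in the same kind of container:

    from omegaconf import OmegaConf
    from PIL import Image
    from mmf.datasets.processors.image_processors import VILTImageProcessor

    config = OmegaConf.create({"size": [224, 224]})
    processor = VILTImageProcessor(config)

    image = Image.new("RGB", (500, 300))
    tensor = processor(image)             # plain mode -> tensor of shape [3, 224, 224]
    sample = processor({"image": image})  # dict mode  -> {"image": tensor}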
18 changes: 18 additions & 0 deletions tests/datasets/test_processors.py
@@ -4,6 +4,7 @@
import unittest

import torch
from mmf.datasets.processors.image_processors import VILTImageProcessor
from mmf.datasets.processors.processors import (
    CaptionProcessor,
    EvalAIAnswerProcessor,
@@ -173,3 +174,20 @@ def test_multi_class_from_file(self):

        self.assertRaises(AssertionError, processor, {"label": "UNK"})
        os.unlink(f.name)

    def test_vilt_image_processor(self):
        from torchvision.transforms import ToPILImage

        size = 384
        config = OmegaConf.create({"size": [size, size]})
        image_processor = VILTImageProcessor(config)

        # ToTensor yields [C, H, W] and GrayScaleTo3Channels guarantees
        # three channels, so the channel dimension must be included here.
        expected_size = torch.Size([3, size, size])

        image = ToPILImage()(torch.ones(3, 300, 500))
        processed_image = image_processor(image)
        self.assertEqual(processed_image.size(), expected_size)

        image = ToPILImage()(torch.ones(1, 224, 224))
        processed_image = image_processor(image)
        self.assertEqual(processed_image.size(), expected_size)
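Because both classes are registered via registry.register_processor, they can also be resolved by name; a sketch assuming MMF's registry.get_processor_class helper:

    from mmf.common.registry import registry
    # Assumes the processor modules have been imported so that the
    # register_processor decorators have already run.
    image_processor_cls = registry.get_processor_class("vilt_image_processor")
    text_processor_cls = registry.get_processor_class("vilt_text_processor")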
