VisionTextDualEncoder (huggingface#13511)
* init vision_text_dual_encoder

* fix merge

* remove extra heads

* fix tests

* remove VISION_TEXT_DUAL_ENCODER_PRETRAINED_CONFIG_ARCHIVE_MAP

* remove archive map

* fix imports

* fix more imports

* fix init

* delete tokenizers

* fix imports

* clean

* support clip's vision model

* handle None config

* begin tests

* more test and few fixes

* warn about newly init weights

* more tests

* add loss to model

* remove extra classes from doc

* add processor

* doc and small fixes

* add start docstr

* update flax model

* flax tests

* more flax tests

* doc

* quality

* doc and quality

* fix doc

* doc

* remove comments

* update warning

* quality

* fix docs

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* replace asserts, fix imports

* update imports

* fix import

* address some review comments

* fix check

* reduce tolerance

* fix test

* add flax integration test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address Sylvain's comments

* fix style

* add pt_flax_equivalence test in PT tests

* add pt integration test

* update test

* use pre-trained checkpoint in examples

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
3 people authored and Alberto Bégué committed Jan 27, 2022
1 parent 95c16d9 commit 5a754ad
Showing 19 changed files with 2,643 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/source/index.rst
@@ -511,6 +511,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Vision Encoder decoder      | ❌             | ❌             | ✅              | ❌                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisionTextDualEncoder       | ❌             | ❌             | ✅              | ❌                 | ✅           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert                  | ❌             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT                         | ❌             | ❌             | ✅              | ❌                 | ✅           |
@@ -686,6 +688,7 @@ Flax), PyTorch, and/or TensorFlow.
model_doc/unispeech
model_doc/unispeech_sat
model_doc/visionencoderdecoder
model_doc/vision_text_dual_encoder
model_doc/vit
model_doc/visual_bert
model_doc/wav2vec2
56 changes: 56 additions & 0 deletions docs/source/model_doc/vision_text_dual_encoder.rst
@@ -0,0 +1,56 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

VisionTextDualEncoder
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :class:`~transformers.VisionTextDualEncoderModel` can be used to initialize a vision-text dual encoder model with
any pretrained vision autoencoding model as the vision encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT
<deit>`) and any pretrained text autoencoding model as the text encoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`BERT
<bert>`). Two projection layers are added on top of both the vision and text encoders to project the output embeddings
to a shared latent space. The projection layers are randomly initialized, so the model should be fine-tuned on a
downstream task. This model can be used to align the vision-text embeddings with CLIP-like contrastive image-text
training, and can then be applied to zero-shot vision tasks such as image classification or retrieval.

`LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ shows how leveraging
pre-trained (locked/frozen) image and text models for contrastive learning yields a significant improvement on new
zero-shot vision tasks such as image classification or retrieval.

VisionTextDualEncoderConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderConfig
:members:


VisionTextDualEncoderProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderProcessor
:members:


VisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderModel
:members: forward


FlaxVisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxVisionTextDualEncoderModel
:members: __call__
7 changes: 7 additions & 0 deletions src/transformers/__init__.py
@@ -297,6 +297,7 @@
"UniSpeechSatConfig",
],
"models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
"models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"],
"models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
"models.wav2vec2": [
@@ -1306,6 +1307,7 @@
]
)
_import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
_import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"])
_import_structure["models.visual_bert"].extend(
[
"VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1907,6 +1909,7 @@
)

# Flax models structure

_import_structure["models.bart"].extend(
[
"FlaxBartForConditionalGeneration",
@@ -2027,6 +2030,7 @@
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
_import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
_import_structure["models.wav2vec2"].extend(
["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
@@ -2267,6 +2271,7 @@
from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
from .models.vision_text_dual_encoder import VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor
from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
from .models.wav2vec2 import (
@@ -3109,6 +3114,7 @@
UniSpeechSatPreTrainedModel,
)
from .models.vision_encoder_decoder import VisionEncoderDecoderModel
from .models.vision_text_dual_encoder import VisionTextDualEncoderModel
from .models.visual_bert import (
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
VisualBertForMultipleChoice,
@@ -3704,6 +3710,7 @@
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
from .models.wav2vec2 import (
FlaxWav2Vec2ForCTC,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -101,6 +101,7 @@
unispeech,
unispeech_sat,
vision_encoder_decoder,
vision_text_dual_encoder,
visual_bert,
vit,
wav2vec2,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -36,6 +36,7 @@
("trocr", "TrOCRConfig"),
("fnet", "FNetConfig"),
("segformer", "SegformerConfig"),
("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("beit", "BeitConfig"),
@@ -192,6 +193,7 @@
("trocr", "TrOCR"),
("fnet", "FNet"),
("segformer", "SegFormer"),
("vision-text-dual-encoder", "VisionTextDualEncoder"),
("gptj", "GPT-J"),
("beit", "BEiT"),
("rembert", "RemBERT"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -32,6 +32,7 @@
("qdqbert", "QDQBertModel"),
("fnet", "FNetModel"),
("segformer", "SegformerModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("beit", "BeitModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
@@ -29,6 +29,7 @@
[
# Base model mapping
("pegasus", "FlaxPegasusModel"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"),
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -41,6 +41,7 @@
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
]
)

52 changes: 52 additions & 0 deletions src/transformers/models/vision_text_dual_encoder/__init__.py
@@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
    "configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"],
    "processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"],
}


if is_torch_available():
    _import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]


if is_flax_available():
    _import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]


if TYPE_CHECKING:
    from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
    from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor

    if is_torch_available():
        from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel

    if is_flax_available():
        from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
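
Since the module installs a `_LazyModule`, the framework-specific classes above are only imported on first attribute
access; a small illustrative sketch of the effect:

    import transformers

    # neither the torch nor the flax modeling file has been imported yet;
    # this first attribute access triggers the lazy import
    model_cls = transformers.VisionTextDualEncoderModel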
129 changes: 129 additions & 0 deletions src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VisionTextDualEncoder model configuration """

import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..clip.configuration_clip import CLIPVisionConfig


logger = logging.get_logger(__name__)


class VisionTextDualEncoderConfig(PretrainedConfig):
    r"""
    :class:`~transformers.VisionTextDualEncoderConfig` is the configuration class to store the configuration of a
    :class:`~transformers.VisionTextDualEncoderModel`. It is used to instantiate a
    :class:`~transformers.VisionTextDualEncoderModel` according to the specified arguments, defining the text model
    and vision model configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        text_config_dict (:obj:`dict`):
            Dictionary of configuration options that defines the text model config.
        vision_config_dict (:obj:`dict`):
            Dictionary of configuration options that defines the vision model config.
        projection_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the text and vision projection layers.
        logit_scale_init_value (:obj:`float`, `optional`, defaults to 2.6592):
            The initial value of the `logit_scale` parameter. The default is used as per the original CLIP
            implementation.
        kwargs (`optional`):
            Dictionary of keyword arguments.

    Examples::

        >>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel

        >>> # Initializing BERT and ViT configurations
        >>> config_vision = ViTConfig()
        >>> config_text = BertConfig()

        >>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512)

        >>> # Initializing a BERT and ViT model
        >>> model = VisionTextDualEncoderModel(config=config)

        >>> # Accessing the model configuration
        >>> config_vision = model.config.vision_config
        >>> config_text = model.config.text_config

        >>> # Saving the model, including its configuration
        >>> model.save_pretrained('vit-bert')

        >>> # Loading the model and config from the pretrained folder
        >>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained('vit-bert')
        >>> model = VisionTextDualEncoderModel.from_pretrained('vit-bert', config=vision_text_config)
    """

    model_type = "vision-text-dual-encoder"
    is_composition = True

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
        super().__init__(**kwargs)

        if "vision_config" not in kwargs:
            raise ValueError("`vision_config` can not be `None`.")

        if "text_config" not in kwargs:
            raise ValueError("`text_config` can not be `None`.")

        vision_config = kwargs.pop("vision_config")
        text_config = kwargs.pop("text_config")

        vision_model_type = vision_config.pop("model_type")
        text_model_type = text_config.pop("model_type")
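
        # NOTE on the branching below: a full "clip" config carries both towers, and
        # only its vision tower is used here, while a bare "clip_vision_model" config
        # is built directly as CLIPVisionConfig since AutoConfig does not cover it.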

        if vision_model_type == "clip":
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
        elif vision_model_type == "clip_vision_model":
            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)

        self.text_config = AutoConfig.for_model(text_model_type, **text_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value

    @classmethod
    def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs):
        r"""
        Instantiate a :class:`VisionTextDualEncoderConfig` (or a derived class) from a text model configuration and a
        vision model configuration.

        Returns:
            :class:`VisionTextDualEncoderConfig`: An instance of a configuration object
        """

        return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default
        :meth:`~transformers.PretrainedConfig.to_dict`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
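
A quick sketch of the round-trip that this serialization enables (illustrative only, using the classes from this
diff):

    from transformers import BertConfig, ViTConfig, VisionTextDualEncoderConfig

    config = VisionTextDualEncoderConfig.from_vision_text_configs(ViTConfig(), BertConfig())
    d = config.to_dict()
    assert d["model_type"] == "vision-text-dual-encoder"

    # the nested dicts keep their own "model_type" keys, so the composite config
    # can be rebuilt from its serialized form
    restored = VisionTextDualEncoderConfig(vision_config=d["vision_config"], text_config=d["text_config"])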