forked from huggingface/transformers
VisionTextDualEncoder (huggingface#13511)
* init vision_text_dual_encoder
* fix merge
* remove extra heads
* fix tests
* remove VISION_TEXT_DUAL_ENCODER_PRETRAINED_CONFIG_ARCHIVE_MAP
* remove archive map
* fix imports
* fix more imports
* fix init
* delete tokenizers
* fix imports
* clean
* support clip's vision model
* handle None config
* begin tests
* more test and few fixes
* warn about newly init weights
* more tests
* add loss to model
* remove extra classes from doc
* add processor
* doc and small fixes
* add start docstr
* update flax model
* flax tests
* more flax tests
* doc
* quality
* doc and quality
* fix doc
* doc
* remove comments
* update warning
* quality
* fix docs
* Apply suggestions from code review (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* replace asserts, fix imports
* update imports
* fix import
* address some review comments
* fix check
* reduce tolerance
* fix test
* add flax integration test
* Apply suggestions from code review (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* address Sylvain's comments
* fix style
* add pt_flax_equivalence test in PT tests
* add pt integration test
* update test
* use pre-trained checkpoint in examples

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 95c16d9 · commit 5a754ad
Showing 19 changed files with 2,643 additions and 0 deletions.
@@ -0,0 +1,56 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

VisionTextDualEncoder
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :class:`~transformers.VisionTextDualEncoderModel` can be used to initialize a vision-text dual encoder model with
any pretrained vision autoencoding model as the vision encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT
<deit>`) and any pretrained text autoencoding model as the text encoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`BERT
<bert>`). Two projection layers are added on top of both the vision and text encoders to project the output embeddings
into a shared latent space. Because the projection layers are randomly initialized, the model should be fine-tuned on a
downstream task. This model can be used to align the vision and text embeddings with CLIP-style contrastive image-text
training, and can then be applied to zero-shot vision tasks such as image classification or retrieval.

In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how
leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant improvement on
new zero-shot vision tasks such as image classification or retrieval.

VisionTextDualEncoderConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderConfig
    :members:


VisionTextDualEncoderProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderProcessor
    :members:


VisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderModel
    :members: forward


FlaxVisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxVisionTextDualEncoderModel
    :members: __call__
src/transformers/models/vision_text_dual_encoder/__init__.py (52 additions, 0 deletions)
@@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
    "configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"],
    "processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"],
}


if is_torch_available():
    _import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]


if is_flax_available():
    _import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]


if TYPE_CHECKING:
    from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
    from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor

    if is_torch_available():
        from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel

    if is_flax_available():
        from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
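The `sys.modules[__name__] = _LazyModule(...)` line above implements lazy importing: heavy submodules (the PyTorch and Flax models) are only imported when one of their names is first accessed. A minimal stdlib-only sketch of this pattern, assuming a simplified `LazyModule` that is not the actual `_LazyModule` implementation:

```python
import importlib
import types

class LazyModule(types.ModuleType):
    """Minimal sketch of the lazy-import pattern; hypothetical simplification,
    not transformers' actual _LazyModule."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to the module that actually defines it.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # Guard private names to avoid recursing before __init__ finishes.
        if attr.startswith("_") or attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(self._name_to_module[attr])
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so __getattr__ runs only once per name
        return value

# Demo: the stdlib json module stands in for a heavy model submodule;
# it is not imported until `lazy.dumps` is first touched.
lazy = LazyModule("demo", {"json": ["dumps", "loads"]})
result = lazy.dumps({"a": 1})
```

The payoff is that `import transformers` stays fast even when optional backends like torch or flax are installed, since their modules load only on first use.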
src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py (129 additions, 0 deletions)
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VisionTextDualEncoder model configuration """


import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..clip.configuration_clip import CLIPVisionConfig


logger = logging.get_logger(__name__)


class VisionTextDualEncoderConfig(PretrainedConfig):
    r"""
    :class:`~transformers.VisionTextDualEncoderConfig` is the configuration class to store the configuration of a
    :class:`~transformers.VisionTextDualEncoderModel`. It is used to instantiate a
    :class:`~transformers.VisionTextDualEncoderModel` according to the specified arguments, defining the text model
    and vision model configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        text_config_dict (:obj:`dict`):
            Dictionary of configuration options that defines the text model config.
        vision_config_dict (:obj:`dict`):
            Dictionary of configuration options that defines the vision model config.
        projection_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the text and vision projection layers.
        logit_scale_init_value (:obj:`float`, `optional`, defaults to 2.6592):
            The initial value of the `logit_scale` parameter. The default is used as per the original CLIP
            implementation.
        kwargs (`optional`):
            Dictionary of keyword arguments.

    Examples::

        >>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel

        >>> # Initializing BERT and ViT configurations
        >>> config_vision = ViTConfig()
        >>> config_text = BertConfig()

        >>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512)

        >>> # Initializing a BERT and ViT model
        >>> model = VisionTextDualEncoderModel(config=config)

        >>> # Accessing the model configuration
        >>> config_vision = model.config.vision_config
        >>> config_text = model.config.text_config

        >>> # Saving the model, including its configuration
        >>> model.save_pretrained('my-model')

        >>> # Loading model and config from a pretrained folder
        >>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained('vit-bert')
        >>> model = VisionTextDualEncoderModel.from_pretrained('vit-bert', config=vision_text_config)
    """

    model_type = "vision-text-dual-encoder"
    is_composition = True

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
        super().__init__(**kwargs)

        if "vision_config" not in kwargs:
            raise ValueError("`vision_config` can not be `None`.")

        if "text_config" not in kwargs:
            raise ValueError("`text_config` can not be `None`.")

        vision_config = kwargs.pop("vision_config")
        text_config = kwargs.pop("text_config")

        vision_model_type = vision_config.pop("model_type")
        text_model_type = text_config.pop("model_type")

        if vision_model_type == "clip":
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
        elif vision_model_type == "clip_vision_model":
            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)

        self.text_config = AutoConfig.for_model(text_model_type, **text_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value

    @classmethod
    def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs):
        r"""
        Instantiate a :class:`VisionTextDualEncoderConfig` (or a derived class) from a text model configuration and a
        vision model configuration.

        Returns:
            :class:`VisionTextDualEncoderConfig`: An instance of a configuration object
        """
        return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default
        :meth:`~transformers.PretrainedConfig.to_dict`.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
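The composition pattern above (a parent config holding nested sub-configs and serializing them recursively in `to_dict`) can be sketched with stdlib-only stand-ins. `BaseConfig` and `DualEncoderConfig` below are hypothetical minimal versions, not `PretrainedConfig` or the real class:

```python
import copy

class BaseConfig:
    """Hypothetical minimal stand-in for PretrainedConfig."""
    model_type = "base"

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def to_dict(self):
        d = copy.deepcopy(self.__dict__)
        d["model_type"] = self.model_type
        return d

class DualEncoderConfig(BaseConfig):
    """Composition config: holds vision and text sub-configs and serializes
    them as nested dicts, mirroring VisionTextDualEncoderConfig.to_dict()."""
    model_type = "vision-text-dual-encoder"

    def __init__(self, vision_config, text_config, projection_dim=512, **kwargs):
        super().__init__(**kwargs)
        self.vision_config = vision_config
        self.text_config = text_config
        self.projection_dim = projection_dim

    def to_dict(self):
        d = super().to_dict()
        # Replace the sub-config objects with their own serialized dicts,
        # so the saved JSON round-trips back into two sub-configs.
        d["vision_config"] = self.vision_config.to_dict()
        d["text_config"] = self.text_config.to_dict()
        return d

cfg = DualEncoderConfig(BaseConfig(hidden_size=768), BaseConfig(hidden_size=512))
serialized = cfg.to_dict()
```

Overriding `to_dict` this way is what lets a single saved config file carry both encoders' settings while each sub-config remains loadable by its own model class.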