In [9]:
import os
from getpass import getpass

# Never hardcode Hugging Face access tokens in a notebook: they end up in version control
# and in every shared copy. The token that was previously written here should be revoked.
# Reuse HF_TOKEN if the environment already provides it; otherwise prompt for it without
# echoing. If the checkpoints used below are public, leaving the prompt blank should be fine.
if not os.environ.get("HF_TOKEN"):
    hf_token = getpass("Hugging Face access token (optional, press Enter to skip): ")
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token

In [3]:
import tensorflow
from transformers import T5Tokenizer, TFT5Model
from transformers import ViTImageProcessor, TFViTModel

# Hugging Face Hub checkpoints used throughout this notebook.
pretrained_t5_path, pretrained_vit_path = 'google-t5/t5-small', 'google/vit-base-patch16-224-in21k'

# Text side: tokenizer plus the full T5 encoder-decoder; the encoder sub-layer is also
# exposed under its own name for convenience.
tokenizer = T5Tokenizer.from_pretrained(pretrained_t5_path)
t5 = TFT5Model.from_pretrained(pretrained_t5_path)
t5_encoder = t5.encoder

# Vision side: ViT backbone producing the visual embeddings.
vit = TFViTModel.from_pretrained(pretrained_vit_path)


Downloading spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 2.52MB/s]
Downloading tokenizer_config.json: 100%|██████████| 2.32k/2.32k [00:00<00:00, 775kB/s]
Downloading config.json: 100%|██████████| 1.21k/1.21k [00:00<00:00, 402kB/s]
Downloading tf_model.h5: 100%|██████████| 242M/242M [00:18<00:00, 13.3MB/s] 
All model checkpoint layers were used when initializing TFT5Model.

All the layers of TFT5Model were initialized from the model checkpoint at google-t5/t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5Model for predictions without further training.
All model checkpoint layers were used when initializing TFViTModel.

All the layers of TFViTModel were initialized from the model checkpoint at google/vit-base-patch16-224-in21k.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTModel for predictions without further training.


In [13]:
from transformers import T5Tokenizer, TFT5Model, TFViTModel
from tensorflow.keras import layers, Model, Input
import tensorflow as tf

def VQAModel(t5, visual_embedding_shape=(197, 768), text_embedding_shape=(15, 512),
             use_text_mask=False):
    """Build a Keras model that fuses visual and text embeddings and decodes with T5.

    The visual embeddings are projected to the text width and used as queries in a
    dot-product attention over the text embeddings. The result is concatenated with
    the text embeddings along the sequence axis and passed to ``t5.decoder`` as
    ``encoder_hidden_states``.

    Args:
        t5: A ``TFT5Model`` (or any object exposing a compatible ``decoder`` and,
            optionally, ``config.d_model``).
        visual_embedding_shape: ``(num_visual_tokens, visual_dim)`` of one example,
            without the batch axis. The default matches the
            ``vit(...).last_hidden_state`` fed to this model in the example below.
        text_embedding_shape: ``(text_len, d_model)`` of one example, without the
            batch axis. ``d_model`` must equal the T5 hidden size because the fused
            sequence is fed to the T5 decoder. The default matches ``t5-small`` with
            15-token (padded) inputs.
        use_text_mask: If True, the model takes a fourth input, ``text_attention_mask``
            of shape ``(batch, text_len)``: the tokenizer's ``attention_mask``, with
            nonzero values marking real tokens. Padding positions are then excluded
            from both the fusion attention and the decoder cross-attention. Defaults
            to False, which builds the original three-input model (padding attended).

    Returns:
        A ``tf.keras.Model`` with inputs ``[visual_embedding, text_embedding,
        decoder_input_ids]`` (plus ``text_attention_mask`` when ``use_text_mask`` is
        True). Its output is the decoder's ``last_hidden_state``.

    Raises:
        ValueError: If ``text_embedding_shape[1]`` differs from ``t5.config.d_model``.
    """
    # Fail early with a clear message instead of a shape error deep inside the decoder.
    d_model = getattr(getattr(t5, 'config', None), 'd_model', None)
    if d_model is not None and text_embedding_shape[1] != d_model:
        raise ValueError(
            f'text_embedding_shape[1]={text_embedding_shape[1]} does not match '
            f't5.config.d_model={d_model}; the fused sequence cannot be fed to the T5 decoder.'
        )

    visual_embedding = Input(visual_embedding_shape, name='visual_embedding')
    text_embedding = Input(text_embedding_shape, name='text_embedding')
    decoder_input_ids = Input(shape=(None,), dtype=tf.int32, name='decoder_input_ids')
    model_inputs = [visual_embedding, text_embedding, decoder_input_ids]

    # Project the visual embeddings to the text embedding width.
    # NOTE(review): ReLU zeroes every negative component of the projection; a linear
    # projection may have been intended here. Kept as-is to preserve behavior.
    x_v = layers.Dense(
        units=text_embedding_shape[1],
        activation='relu',
        use_bias=True,
    )(visual_embedding)

    # Combine visual and text embeddings: visual tokens (queries) attend over text tokens.
    if use_text_mask:
        text_attention_mask = Input(
            shape=(text_embedding_shape[0],), dtype=tf.int32, name='text_attention_mask'
        )
        model_inputs.append(text_attention_mask)
        # Keras Attention takes mask=[query_mask, value_mask]; only the text (value) side is padded.
        attention_output = layers.Attention()(
            [x_v, text_embedding], mask=[None, tf.cast(text_attention_mask, tf.bool)]
        )
        # Every fused visual position is valid; text positions follow the tokenizer mask.
        visual_mask = tf.ones_like(attention_output[:, :, 0], dtype=tf.int32)
        encoder_attention_mask = tf.concat([visual_mask, text_attention_mask], axis=1)
    else:
        attention_output = layers.Attention()([x_v, text_embedding])
        encoder_attention_mask = None

    # Concatenate along the sequence axis: (num_visual_tokens + text_len, d_model).
    x = layers.Concatenate(axis=1)([attention_output, text_embedding])

    # Run the T5 decoder with the fused sequence as its encoder hidden states.
    t5_output = t5.decoder(
        input_ids=decoder_input_ids,
        encoder_hidden_states=x,
        encoder_attention_mask=encoder_attention_mask,
    ).last_hidden_state

    return Model(inputs=model_inputs, outputs=t5_output)


In [16]:
# Example inputs.
# Pad *and* truncate to the 15 tokens the model's text_embedding Input expects, and keep
# the attention mask so the T5 encoder does not attend to the padding positions.
text_tokens = tokenizer(
    'I am a Ironman.',
    return_tensors='tf',
    padding='max_length',
    truncation=True,
    max_length=15,
)
text_input = text_tokens.input_ids

# Dummy channels-first image batch with values in [-1, 1). A stateless op with a fixed
# seed makes the example reproducible across runs.
visual_input = tf.random.stateless_uniform((1, 3, 224, 224), seed=(0, 0), minval=-1, maxval=1)
visual_embedding = vit(visual_input).last_hidden_state

# Convert the text input to embeddings with the T5 encoder.
text_embedding = t5.encoder(
    input_ids=text_input,
    attention_mask=text_tokens.attention_mask,
).last_hidden_state

# Build the decoder input ids.
# NOTE(review): these are placeholder ids for a shape check; a real decoder input would
# normally start from T5's decoder start token rather than a tokenized prompt.
decoder_input_ids = tokenizer('translate English to German: This is a test.', return_tensors='tf').input_ids

# Create the VQA model.
model = VQAModel(t5)

# Run a forward pass.
output = model([visual_embedding, text_embedding, decoder_input_ids])
print(output)

tf.Tensor(
[[[ 3.1208003e-02  9.1505021e-02 -1.1999548e-01 ...  7.0121162e-02
   -5.2835385e-05 -2.0129712e-01]
  [ 2.2196151e-02  1.0012906e-01 -3.8860645e-02 ...  3.2922305e-02
    2.8177467e-04 -1.2783377e-01]
  [ 2.7133652e-03 -3.7767246e-02  4.9384516e-02 ...  9.5747434e-02
    7.4668397e-04 -7.8224063e-02]
  ...
  [ 8.6837165e-02  1.2081293e-01 -1.0253500e-01 ... -6.0062591e-02
    1.3146683e-04 -3.0444263e-02]
  [ 1.2823616e-01  8.4139861e-02 -9.7969873e-03 ...  1.4082185e-02
    3.5240687e-04  3.7147999e-02]
  [ 8.5922129e-02  1.0110289e-01 -5.7305817e-02 ...  2.5469355e-02
    4.3694774e-05  5.9990045e-02]]], shape=(1, 12, 512), dtype=float32)
