In [1]:
from transformers import AutoModel
model = AutoModel.from_pretrained("OFA-Sys/ofa-large", output_attentions=True)

Downloading:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.28G [00:00<?, ?B/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [2]:
from transformers import AutoModelForSequenceClassification, OFATokenizer, utils
from PIL import Image
from torchvision import transforms
from transformers import OFATokenizer, OFAModel
from transformers.models.ofa.generate import sequence_generator
import torch
from bertviz import model_view

In [3]:
tokenizer = OFATokenizer.from_pretrained("OFA-Sys/ofa-large")

OFA-Sys/ofa-large
<super: <class 'OFATokenizer'>, <OFATokenizer object>>


Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

In [4]:
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
resolution = 384
patch_resize_transform = transforms.Compose([
        lambda image: image.convert("RGB"),
        transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
        transforms.ToTensor(), 
        transforms.Normalize(mean=mean, std=std)
    ])

In [24]:
txt = " why is the woman pointing a gun at him?"
inputs = tokenizer([txt], return_tensors="pt").input_ids
img = Image.open("examples/vcr_3880.jpg")
patch_img = patch_resize_transform(img).unsqueeze(0)

In [25]:
generator = sequence_generator.SequenceGenerator(
                    tokenizer=tokenizer,
                    beam_size=2,
                    max_len_b=16, 
                    min_len=0,
                    no_repeat_ngram_size=3,
                )

In [26]:
data = {}
data["net_input"] = {"input_ids": inputs, 'patch_images': patch_img, 'patch_masks':torch.tensor([True])}
gen_output = generator.generate([model], data)
gen = [gen_output[i][0]["tokens"] for i in range(len(gen_output))]

Type of input_ids: <class 'torch.Tensor'>
Shape of input_ids: torch.Size([1, 12])


In [27]:
gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3) 

print(tokenizer.batch_decode(gen, skip_special_tokens=True))

Type of input_ids: <class 'torch.Tensor'>
Shape of input_ids: torch.Size([1, 12])
[' to scare him']


## Visualize attention with BertViz library https://github.com/jessevig/bertviz

In [9]:
utils.logging.set_verbosity_error()

In [10]:
encoder_input_ids = tokenizer("The family is in the mountains", return_tensors="pt", add_special_tokens=True).input_ids
with tokenizer.as_target_tokenizer():
    decoder_input_ids = tokenizer("The snow is white", return_tensors="pt", add_special_tokens=True).input_ids

outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

Type of input_ids: <class 'torch.Tensor'>
Shape of input_ids: torch.Size([1, 8])


In [11]:
from bertviz import model_view
model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens= encoder_text,
    decoder_tokens = decoder_text,
    display_mode="light"
)

<IPython.core.display.Javascript object>