In [12]:
import torch
from torchvision import transforms
from transformers import OFATokenizer, OFAModel
from transformers.models.ofa.generate import sequence_generator
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import urllib.request

In [2]:
ckpt_dir = "/Users/iustingrigoras/Desktop/OFA/OFA-base"
tokenizer = OFATokenizer.from_pretrained(ckpt_dir)

/Users/iustingrigoras/Desktop/OFA/OFA-base
<super: <class 'OFATokenizer'>, <OFATokenizer object>>


In [None]:
model = OFAModel.from_pretrained(ckpt_dir, use_cache=True, output_atentions=True, output_hidden_states=True)
generator = sequence_generator.SequenceGenerator(
                    tokenizer=tokenizer,
                    beam_size=3,
                    max_len_b=10, 
                    min_len=0,
                    no_repeat_ngram_size=3,
                    temperature=0.5,
                )

In [3]:
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Use Metal
else:
    device = torch.device("cpu")

In [5]:
ckpt_file = "/Users/iustingrigoras/Desktop/OFA/OFA-base/pytorch_model.bin"
ckpt = torch.load(ckpt_file)
model = OFAModel.from_pretrained(ckpt_dir, use_cache=True, output_atentions=True, output_hidden_states=True)
model.eval()
model = model.to(device)

In [9]:
import torchvision.transforms as transforms

def get_transform():
    # Normalize using the mean and standard deviation that were used for model training
    # These values are typically for ImageNet if not specified otherwise by the model
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.ToPILImage(),            # Convert numpy array or tensor to PIL Image
        transforms.Resize((224, 224)),      # Resize to the input size that the model expects
        transforms.ToTensor(),              # Convert PIL Image to a tensor
        normalize,                          # Normalize the image
    ])
    return transform

In [13]:
def fetch_image(self, image_location: str):
        '''
        This helper function takes the path to an image (either an URL or a local path) and
        returns the image as an numpy array.
        '''
        if image_location.startswith('http'):
            urllib.request.urlretrieve(image_location, 'temp.jpg')
            image_location = 'temp.jpg'

        img = Image.open(image_location).convert('RGB')
        img = np.array(img)
        return img

In [24]:
def build_batch(self, input_text, image, answer=None, person_info=None):
        if not input_text:
            input_text = ''
        if answer is None:
            input_ids = self.tokenizer.encode(input_text, return_tensors='pt', padding=True, truncation=True)
            input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)  # all*input_ids
            txt_type_ids = torch.zeros_like(input_ids)
        else:
            input_ids_q = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(input_text))
            input_ids_c = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(answer))
            input_ids = [torch.tensor(self.tokenizer.build_inputs_with_special_tokens(input_ids_q, input_ids_c))]
            input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)  # all*input_ids

            txt_type_ids = torch.tensor((len(input_ids_q) + 2 )* [0] + (len(input_ids_c) + 1) * [2])

        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long).unsqueeze(0)
        num_sents = [input_ids.size(0)]       
        txt_lens = [i.size(0) for i in input_ids]

        if image is None:
            images_batch = None
        else:
            images_batch = torch.as_tensor(image.copy(), dtype=torch.float32)
            images_batch = get_transform(images_batch.permute(2, 0, 1))

        batch = {'input_ids': input_ids, 'txt_type_ids': txt_type_ids, 'position_ids': position_ids, 'images': images_batch,
                "txt_lens": txt_lens, "num_sents": num_sents, 'person_info': person_info}
        batch = self.move_to_device(batch)
        
        return batch

def data_setup(self, ex_id, image_location, input_text):
    image = self.fetch_image(image_location) if image_location else None

    batch = self.build_batch(input_text, image, answer=None, person_info=None)
    scores, hidden_states, attentions = self.model(batch,
                                                   compute_loss=False,
                                                    output_attentions=True,
                                                    output_hidden_states=True)

    attentions = torch.stack(attentions).transpose(1,0).detach().cpu()[0]

    if batch['images'] is None:
        img, img_coords = np.array([]), []
        len_img = 0
    else:
        image1, mask1 = self.model.preprocess_image(batch['images'].to(self.device))
        image1 = (image1 * self.model.pixel_std + self.model.pixel_mean) * mask1
        img = image1.cpu().numpy().astype(int).squeeze().transpose(1,2,0)

        h, w, _ = img.shape
        h0, w0 = h//64, w//64
        len_img = w0 * h0
        img_coords = np.fliplr(list(np.ndindex(h0, w0)))

    input_ids = batch['input_ids'].cpu()
    len_text = input_ids.size(1)
    txt_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0, :len_text])

    len_tokens = len_text + len_img
    attentions = attentions[:, :, :len_tokens, :len_tokens]
    hidden_states = [hs[0].detach().cpu().numpy()[:len_tokens] for hs in hidden_states]

In [26]:
images = ["examples/family.jpeg"]
texts = ["what is the family doing?"]
data = [data_setup(i, img, txt) for i, (img, txt) in enumerate(zip(images, texts))]

TypeError: data_setup() missing 1 required positional argument: 'input_text'