In [None]:
# Importa bibliotecas necessárias

import json
import numpy as np
import torch
import pickle
import warnings
import torchvision.transforms as transforms

from PIL import Image
from tqdm import tqdm

from lavis.models import load_model_and_preprocess
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing while loading the model.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# BLIP-2 feature extractor loaded in evaluation mode together with its image
# preprocessors; the third returned value is not needed in this notebook.
model, vis_processors, _ = load_model_and_preprocess(name="blip2_feature_extractor", model_type="pretrain_vitL", is_eval=True, device=device)

# Silence library warnings in the cells that follow.
warnings.filterwarnings('ignore')

In [None]:
def load_base(name_arq):
    """Load the dataset annotations from a JSON file.

    Args:
        name_arq: Path to the JSON file (read as UTF-8).

    Returns:
        The decoded JSON content, unfiltered.
    """
    # The context manager closes the handle even when parsing fails; the
    # previous version left the file open.
    with open(name_arq, encoding="utf8") as f:
        data = json.load(f)

    # Keeping only the "answerable" entries is currently disabled:
    # data = [d for d in data if d["answerable"] == 1]

    return data

In [None]:
# Image preprocessing pipelines: resize to a fixed square, then apply ToTensor.
def _square_resize_to_tensor(side):
    """Build a pipeline that resizes to (side, side) and then applies ToTensor."""
    return transforms.Compose([
        transforms.Resize(size=(side, side)),
        transforms.ToTensor(),
    ])


# 224x224 version of the image, later sliced into 112x112 patches.
transform_patches = _square_resize_to_tensor(224)

# 112x112 version of the whole image.
transform_img = _square_resize_to_tensor(112)

# Converts a tensor back into a PIL image.
transform_pil = transforms.ToPILImage()

In [None]:
def get_embedding_blip2(model, img_tensor, vis_processors, device, emb_dim=768):
    """Embed a sequence of image tensors with the BLIP-2 feature extractor.

    Args:
        model: Object exposing ``extract_features(sample, mode="image")`` whose
            result has an ``image_embeds`` tensor indexable as ``[0, 0, :]``.
        img_tensor: Iterable of image tensors accepted by ``transform_pil``.
        vis_processors: Mapping whose ``"eval"`` entry turns a PIL image into
            a tensor ready for the model.
        device: Device the processed image is moved to before the forward pass.
        emb_dim: Expected width of one embedding vector. Defaults to 768, the
            value this function previously hard-coded.

    Returns:
        numpy array of shape ``(number of images, emb_dim)``; row ``k`` holds
        ``image_embeds[0, 0, :]`` for image ``k``. An empty input yields an
        array of shape ``(0, emb_dim)``.

    Raises:
        ValueError: If an extracted embedding does not have ``emb_dim`` values.
    """
    rows = []

    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for img in img_tensor:
            image = transform_pil(img)

            # Preprocess the image and add the batch dimension.
            image_processed = vis_processors["eval"](image).unsqueeze(0).to(device)
            sample = {"image": image_processed}

            # First embedding vector of the single image in the batch.
            image_emb = model.extract_features(sample, mode="image").image_embeds[0, 0, :]

            rows.append(image_emb.detach().cpu().numpy().reshape(1, emb_dim))

    # Previously an empty input hit an undefined variable.
    if not rows:
        return np.empty((0, emb_dim), dtype=np.float32)

    # Stack once at the end instead of re-copying the array on every iteration.
    return np.vstack(rows)

In [None]:
# Vision: layout of the dictionary built by get_info_visao. Each image file
# name is a key whose value holds the embedding vectors of that image.
"""
{"imagem.jpg": [vetores],
"imagem2.jpg": [vetores],
...
}
"""

In [None]:
def format_embedding(embedding_img):
    """Stack a sequence of 768-value embeddings into one 2-D array.

    Args:
        embedding_img: Sequence of array-likes, each reshapeable to (1, 768).

    Returns:
        numpy array of shape (len(embedding_img), 768), one row per input
        embedding in the original order. An empty input yields shape (0, 768).
    """
    # Each embedding contributes exactly one row. The previous version added
    # the second embedding twice and failed on inputs shorter than two items.
    rows = [emb.reshape(1, 768) for emb in embedding_img]

    if not rows:
        return np.empty((0, 768))

    return np.concatenate(rows, axis=0)

In [None]:
def get_info_visao(name_arq, model, tam_base, vis_processors, device):
    """Compute BLIP-2 embeddings for every image listed in an annotation file.

    Each image is embedded five times: once as a whole (resized to 112x112)
    and once for each 112x112 slice of its 224x224 resized version.

    Args:
        name_arq: Path to the JSON annotation file. Images are read from the
            directory named like this path up to ".json".
        model: Feature extractor forwarded to get_embedding_blip2.
        tam_base: Currently unused; kept so existing calls remain valid.
        vis_processors: Visual processors forwarded to get_embedding_blip2.
        device: Device forwarded to get_embedding_blip2.

    Returns:
        dict mapping each image file name to the array returned by
        get_embedding_blip2 for [whole image, slice 1, slice 2, slice 3,
        slice 4]. If a name appears more than once, the last result is kept.
    """
    # Read the annotation base.
    data = load_base(name_arq)

    # Directory that holds the images.
    dir_img = name_arq.split(".json")[0]

    # Visual information for each image, keyed by file name.
    info_visao = {}

    for info in tqdm(data):

        # Name of the image for this entry.
        name_img = info["image"]

        # Read the image inside a context manager so the file handle is
        # released on every iteration. Converting to RGB gives the downstream
        # processors three channels even for grayscale or RGBA files.
        with Image.open(dir_img+"/"+name_img) as raw_img:
            img = raw_img.convert("RGB")

        # Standardise the image at both resolutions.
        img_patches = transform_patches(img)
        img_tensor = transform_img(img)

        # Whole image first, then the four 112x112 slices of the 224x224
        # tensor (assumes channel-first layout, so the slices are the
        # top-left, top-right, bottom-left and bottom-right quadrants).
        base = [
            img_tensor,
            img_patches[:, :112, :112],
            img_patches[:, :112, 112:],
            img_patches[:, 112:, :112],
            img_patches[:, 112:, 112:],
        ]

        # Embeddings for the image and its patches.
        info_visao[name_img] = get_embedding_blip2(model, base, vis_processors, device)

    return info_visao

In [None]:
def save_info_visao(name_arq_in, info_visao):
    """Pickle the embeddings dictionary next to the annotation file.

    Args:
        name_arq_in: Path of the annotation file; the output path is this path
            up to ".json" followed by "_info_visao.pkl".
        info_visao: Object to serialise (the dictionary of embeddings).

    Returns:
        None.
    """
    name_arq_out = name_arq_in.split(".json")[0]+"_info_visao.pkl"

    # Write mode, not append: appending stacked a new pickle after the old
    # one on every run, and a single pickle.load only reads the first (stale)
    # object. The context manager closes the file even if dumping fails.
    with open(name_arq_out, 'wb') as file:
        pickle.dump(info_visao, file, pickle.HIGHEST_PROTOCOL)

    return

In [None]:
%%time

# Annotation file to process; get_info_visao reads the images from the folder
# named like this path without ".json".
name_arq = "test.json"
# The third argument (tam_base) is not used by get_info_visao.
info_visao = get_info_visao(name_arq, model, 1, vis_processors, device)
# Stores the embeddings in a file ending in "_info_visao.pkl" next to the annotation file.
save_info_visao(name_arq, info_visao)

In [None]:
# `data` only exists inside load_base/get_info_visao, so it is undefined at
# notebook level on a fresh kernel; re-read the base to count its entries.
len(load_base(name_arq))

In [None]:
#info_visao['VizWiz_train_00000000.jpg'].shape

#### Fontes

https://colab.research.google.com/drive/1jIflL9-gktbXq_2cEE_KM7yYq2PaOKRM?authuser=1#scrollTo=8cRyNhMQaAyh

https://khvmaths.medium.com/vision-transformer-understanding-the-underlying-concept-83d699d71180