In [2]:
import torch
import copy
import time
import requests
import io
import numpy as np
import re

import ipdb

from PIL import Image

from vilt.config import ex
from vilt.modules import ViLTransformerSS

from vilt.modules.objectives import cost_matrix_cosine, ipot
from vilt.transforms import pixelbert_transform
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer

In [3]:
_config = {'exp_name': 'vilt', 'seed': 0, 'datasets': ['coco', 'vg', 'sbu', 'gcc'], 'loss_names': {'itm': 1, 'mlm': 1, 'mpp': 0, 'vqa': 0, 'nlvr2': 0, 'irtr': 0}, 'batch_size': 4096, 'train_transform_keys': ['pixelbert'], 'val_transform_keys': ['pixelbert'], 'image_size': 384, 'max_image_len': -1, 'patch_size': 32, 'draw_false_image': 1, 'image_only': False, 'vqav2_label_size': 3129, 'max_text_len': 40, 'tokenizer': 'bert-base-uncased', 'vocab_size': 30522, 'whole_word_masking': False, 'mlm_prob': 0.15, 'draw_false_text': 0, 'vit': 'vit_base_patch32_384', 'hidden_size': 768, 'num_heads': 12, 'num_layers': 12, 'mlp_ratio': 4, 'drop_rate': 0.1, 'optim_type': 'adamw', 'learning_rate': 0.0001, 'weight_decay': 0.01, 'decay_power': 1, 'max_epoch': 100, 'max_steps': 25000, 'warmup_steps': 2500, 'end_lr': 0, 'lr_mult': 1, 'get_recall_metric': False, 'resume_from': None, 'fast_dev_run': False, 'val_check_interval': 1.0, 'test_only': False, 'data_root': '', 'log_dir': 'result', 'per_gpu_batchsize': 0, 'num_gpus': 0, 'num_nodes': 1, 'load_path': 'weights/vilt_200k_mlm_itm.ckpt', 'num_workers': 8, 'precision': 16}
_config = copy.deepcopy(_config)

In [4]:

loss_names = {
    "itm": 0,
    "mlm": 0.5,
    "mpp": 0,
    "vqa": 0,
    "imgcls": 0,
    "nlvr2": 0,
    "irtr": 0,
    "arc": 0,
}
tokenizer = get_pretrained_tokenizer(_config["tokenizer"])


_config.update(
    {
        "loss_names": loss_names,
    }
)

model = ViLTransformerSS(_config)
model.setup("test")
#开启评估模式
model.eval()


device = "cuda:0" if _config["num_gpus"] > 0 else "cpu"
model.to(device)

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(40, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (token_type_embeddings): Embedding(2, 768)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
    )
    (pos_drop): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.1, inplace=False)
        )
        (drop_path): Identity()
        

In [5]:
def infer(url, mp_text, hidx):
    try:
        res = requests.get(url)
        image = Image.open(io.BytesIO(res.content)).convert("RGB")
        img = pixelbert_transform(size=384)(image)
        img = img.unsqueeze(0).to(device)
        # 完成img的预处理
    except:
        return False

    batch = {"text": [""], "image": [None]}
    tl = len(re.findall("\[MASK\]", mp_text))
    inferred_token = [mp_text]
    batch["image"][0] = img


    with torch.no_grad():
        for i in range(tl):
            batch["text"] = inferred_token
            encoded = tokenizer(inferred_token)
            batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
            encoded = encoded["input_ids"][0][1:-1]
            infer = model(batch)
            print('111')
            print(infer["text_feats"])
            mlm_logits = model.mlm_score(infer["text_feats"])[0, 1:-1]
            mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
            mlm_values[torch.tensor(encoded) != 103] = 0
            select = mlm_values.argmax().item()
            encoded[select] = mlm_ids[select].item()
            inferred_token = [tokenizer.decode(encoded)]

    selected_token = ""
    encoded = tokenizer(inferred_token)
    # 完成text的预处理

    if hidx > 0 and hidx < len(encoded["input_ids"][0][:-1]):
        with torch.no_grad():
            batch["text"] = inferred_token
            batch["text_ids"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_labels"] = torch.tensor(encoded["input_ids"]).to(device)
            batch["text_masks"] = torch.tensor(encoded["attention_mask"]).to(device)
            infer = model(batch)
            
            txt_emb, img_emb = infer["text_feats"], infer["image_feats"]

            print(batch["image"])
            print(batch["text"])

            print("text_feats", txt_emb)
            print("image_feats", img_emb)
            print(infer["cls_feats"].shape)

            
            txt_mask, img_mask = (
                infer["text_masks"].bool(),
                infer["image_masks"].bool(),
            )
            for i, _len in enumerate(txt_mask.sum(dim=1)):
                txt_mask[i, _len - 1] = False
            txt_mask[:, 0] = False
            img_mask[:, 0] = False
            txt_pad, img_pad = ~txt_mask, ~img_mask

            cost = cost_matrix_cosine(txt_emb.float(), img_emb.float())
           
            joint_pad = txt_pad.unsqueeze(-1) | img_pad.unsqueeze(-2)
            cost.masked_fill_(joint_pad, 0)

            txt_len = (txt_pad.size(1) - txt_pad.sum(dim=1, keepdim=False)).to(
                dtype=cost.dtype
            )
            img_len = (img_pad.size(1) - img_pad.sum(dim=1, keepdim=False)).to(
                dtype=cost.dtype
            )
            T = ipot(
                cost.detach(),
                txt_len,
                txt_pad,
                img_len,
                img_pad,
                joint_pad,
                0.1,
                1000,
                1,
            )

            plan = T[0]
            plan_single = plan * len(txt_emb)
            cost_ = plan_single.t()

            cost_ = cost_[hidx][1:].cpu()

            patch_index, (H, W) = infer["patch_index"]
            heatmap = torch.zeros(H, W)
            for i, pidx in enumerate(patch_index[0]):
                h, w = pidx[0].item(), pidx[1].item()
                heatmap[h, w] = cost_[i]

            heatmap = (heatmap - heatmap.mean()) / heatmap.std()
            heatmap = np.clip(heatmap, 1.0, 3.0)
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

            _w, _h = image.size
            overlay = Image.fromarray(np.uint8(heatmap * 255), "L").resize(
                (_w, _h), resample=Image.NEAREST
            )
            image_rgba = image.copy()
            image_rgba.putalpha(overlay)
            image = image_rgba

            selected_token = tokenizer.convert_ids_to_tokens(
                encoded["input_ids"][0][hidx]
            )

    return [np.array(image), inferred_token[0], selected_token]

In [6]:
url = 'https://s3.geograph.org.uk/geophotos/06/21/24/6212487_1cca7f3f_1024x1024.jpg'
mp_text = 'a display of flowers growing out and over the retaining wall in front of cottages on a cloudy day."'
hidx = 4

ans = infer(url, mp_text, hidx)
# 将数组反序列化成图片
Image.fromarray(np.uint8(ans[0]))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


AttributeError: 'list' object has no attribute 'shape'