In [None]:
import sys
sys.path.append("/mnt/lustre/lujinghui1/ofa_transformers_official/")
from ofa.modeling_ofa import OFAModel
from ofa.tokenization_ofa import OFATokenizer
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import json, cv2
import numpy as np
import torch  # used only to probe for CUDA availability

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on a machine without CUDA instead of failing at the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Preprocessing settings: square side length of the model input and the
# per-channel statistics handed to Normalize.
resolution = 384
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]


def _to_rgb(image):
    """Return `image` converted to RGB mode."""
    return image.convert("RGB")


# Pipeline applied to every input image: force RGB, resize to
# (resolution, resolution) with bicubic interpolation, convert to a tensor,
# then normalize each channel with `mean` / `std`.
patch_resize_transform = transforms.Compose([
    _to_rgb,
    transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

def coord2bin(coords, w_resize_ratio, h_resize_ratio, max_image_size=512, num_bins=1000):
    """Quantize a box given as pixel coordinates into ``<bin_N>`` tokens.

    Each coordinate is scaled by its resize ratio, normalized by
    ``max_image_size`` and mapped onto the integer range ``0 .. num_bins - 1``.

    Args:
        coords: Whitespace-separated string ``"x0 y0 x1 y1"``. Only the first
            four values are used; any extra values are ignored.
        w_resize_ratio: Scale factor applied to the x values (positions 0 and 2).
        h_resize_ratio: Scale factor applied to the y values (positions 1 and 3).
        max_image_size: Size the scaled coordinates are normalized by.
            Defaults to 512, the value that used to be hard-coded here.
        num_bins: Number of quantization bins. Defaults to 1000, the value
            that used to be hard-coded here.

    Returns:
        A string of four ``<bin_N>`` tokens joined by single spaces.

    Raises:
        ValueError: If ``coords`` holds fewer than four values, or a value
            cannot be parsed as a float.
    """
    values = [float(value) for value in coords.strip().split()]
    if len(values) < 4:
        raise ValueError(
            "coords must contain four values 'x0 y0 x1 y1', got {!r}".format(coords)
        )

    # x values use the width ratio, y values use the height ratio.
    ratios = (w_resize_ratio, h_resize_ratio, w_resize_ratio, h_resize_ratio)
    top_bin = num_bins - 1

    # zip() stops after the four ratios, so surplus values are ignored.
    return ' '.join(
        "<bin_{}>".format(int(round(value * ratio / max_image_size * top_bin)))
        for value, ratio in zip(values, ratios)
    )

def bin2coord(bins, w_resize_ratio, h_resize_ratio, max_image_size=512, num_bins=1000):
    """Convert ``<bin_N>`` tokens back into pixel coordinates.

    This is the inverse of ``coord2bin`` (up to quantization): each bin index
    is divided by ``num_bins - 1``, multiplied by ``max_image_size`` and
    divided by the matching resize ratio.

    Args:
        bins: Whitespace-separated string of ``<bin_N>`` tokens. Only the
            first four are converted, but every token must be well formed.
        w_resize_ratio: Scale factor undone for the x values (positions 0 and 2).
        h_resize_ratio: Scale factor undone for the y values (positions 1 and 3).
        max_image_size: Size the bins were normalized by. Defaults to 512,
            the value that used to be hard-coded here.
        num_bins: Number of quantization bins. Defaults to 1000, the value
            that used to be hard-coded here.

    Returns:
        A list of four floats ``[x0, y0, x1, y1]``.

    Raises:
        ValueError: If a token is not of the form ``<bin_N>`` or fewer than
            four tokens are present.
    """
    prefix, suffix = "<bin_", ">"
    indices = []
    for token in bins.strip().split():
        # Check the wrapper instead of slicing blindly: a stray token such as
        # "hello12>" would otherwise be silently read as bin 12.
        if not (token.startswith(prefix) and token.endswith(suffix)):
            raise ValueError("expected a '<bin_N>' token, got {!r}".format(token))
        indices.append(int(token[len(prefix):-len(suffix)]))

    if len(indices) < 4:
        raise ValueError(
            "bins must contain four '<bin_N>' tokens, got {!r}".format(bins)
        )

    # x values use the width ratio, y values use the height ratio.
    ratios = (w_resize_ratio, h_resize_ratio, w_resize_ratio, h_resize_ratio)
    top_bin = num_bins - 1

    # zip() stops after the four ratios, so surplus tokens are ignored.
    return [
        index / top_bin * max_image_size / ratio
        for index, ratio in zip(indices, ratios)
    ]

In [None]:
# Load the pretrained tokenizer and model from the local checkpoint directory.
model_dir = '/mnt/lustre/lujinghui1/ofa_models/OFA_large'
tokenizer = OFATokenizer.from_pretrained(model_dir)
model = OFAModel.from_pretrained(model_dir, use_cache=False)
model = model.to(device)

# Turn on both generation flags; later cells read gen['sequences'] and
# gen['sequences_scores'] from the value returned by generate().
for option in ('output_scores', 'return_dict_in_generate'):
    setattr(model.config, option, True)

In [None]:
# Read the validation annotations: one JSON object per line.
gts = []
with open('/mnt/lustre/lujinghui1/events/anno/trash/trash_7_val.jsonl', 'r', encoding='utf-8') as fin:
    # Iterate the file lazily rather than materializing it with readlines().
    for raw_line in fin:
        # Blank or whitespace-only lines would make json.loads raise; skip them.
        if not raw_line.strip():
            continue
        gts.append(json.loads(raw_line))

# Split the records by whether they carry at least one annotated instance.
positives = [gt for gt in gts if len(gt['instances']) > 0]
negatives = [gt for gt in gts if len(gt['instances']) == 0]

In [None]:
# Show the split sizes as (records with instances, records without).
tuple(len(split) for split in (positives, negatives))

In [None]:
# Open the second positive record's image.
sample = positives[1]
img = Image.open('/mnt/lustre/lujinghui1/events/data/' + sample['filename'])

w, h = img.size
print(f'w is {w}; h is {h}')

# Per-axis scale factors from the original image size to the square model
# input; bin2coord divides by them to get back to original pixels.
w_resize_ratio, h_resize_ratio = resolution / w, resolution / h

# Text query asking the model to locate the referred object.
ref = 'trash bin'
txt = f" which region does the text ' {ref} ' describe?"

# Token ids for the query and a single-image batch, both moved to `device`.
encoded = tokenizer([txt], return_tensors="pt")
inputs = encoded.input_ids.to(device)
patch_img = patch_resize_transform(img).unsqueeze(0).to(device)

In [None]:
# Beam-search settings for the grounding query: keep only the best sequence.
generation_kwargs = dict(
    patch_images=patch_img,
    num_beams=5,
    no_repeat_ngram_size=0,
    num_return_sequences=1,
)
gen = model.generate(inputs, **generation_kwargs)

# Decode the generated ids to text, then print the text and its score.
outputs = tokenizer.batch_decode(gen['sequences'], skip_special_tokens=True)
print(outputs)
print(gen['sequences_scores'])


In [None]:
# Draw the predicted box on the image.
# NOTE: `img` is rebound from a PIL image to a NumPy array here; the plotting
# cell below reads this array.
img = np.array(img)
coord_list = bin2coord(outputs[0], w_resize_ratio, h_resize_ratio)

# cv2 needs integer pixel corners: (x0, y0) top-left, (x1, y1) bottom-right.
x0, y0, x1, y1 = (int(value) for value in coord_list[:4])
box_color = (0, 255, 0)
box_thickness = 3
cv2.rectangle(img, (x0, y0), (x1, y1), box_color, box_thickness)

coord_list

In [None]:
# Show the annotated image on a 25 x 50 figure, without pixel smoothing.
canvas_size = (25, 50)
plt.figure(figsize=canvas_size)
plt.imshow(img, interpolation='nearest')