In [1]:
import models

In [2]:
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_one_image(file_path, image_size, device):
    """Load an image from disk and preprocess it into a single-image batch tensor.

    Args:
        file_path: Path of the image file to read.
        image_size: Side length in pixels; the image is resized to a square
            of ``image_size x image_size`` (aspect ratio is not preserved).
        device: torch device the resulting tensor is moved to.

    Returns:
        A float tensor of shape ``(1, 3, image_size, image_size)`` on ``device``,
        normalized with the per-channel mean/std below.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
        PIL.UnidentifiedImageError: if the file cannot be decoded as an image.
    """
    # Context manager closes the underlying file handle deterministically.
    with Image.open(file_path) as img:
        raw_image = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        # Per-channel mean/std; presumably the values the BLIP checkpoint was
        # trained with (CLIP-style constants) — keep in sync with the model.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    # unsqueeze(0) adds the leading batch dimension expected by the model.
    return transform(raw_image).unsqueeze(0).to(device)

In [4]:
from models.blip import blip_feature_extractor

# Input resolution for both preprocessing (load_one_image) and the ViT backbone.
image_size = 224
# Pretrained BLIP base checkpoint, fetched by blip_feature_extractor.
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'

# Build the feature extractor, switch it to inference mode and place it on `device`.
# Module.eval() and Module.to() both return the module itself, so they can be chained.
model = blip_feature_extractor(
    pretrained=model_url,
    image_size=image_size,
    vit='base',
).eval().to(device=device)

load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth


In [5]:
def do_one_inference(caption, image, threshold=0.5, itm_model=None):
    """Score how well ``caption`` matches ``image`` with the model's ITM head.

    Args:
        caption: Text to match against the image.
        image: Preprocessed image batch (e.g. the output of ``load_one_image``).
        threshold: Probability above which the pair is declared a match.
        itm_model: Callable invoked as ``itm_model(image, caption, match_head='itm')``;
            defaults to the notebook-global ``model``.

    Returns:
        ``(itm_score, 0, if_match)``:
        - ``itm_score``: CPU tensor with the softmax probability of column 1 of the
          ITM output (the "match" class), one value per batch row.
        - ``0``: placeholder kept so existing callers can still unpack three values
          (an ITC score is not computed).
        - ``if_match``: ``1`` if ``itm_score > threshold`` else ``0``.
    """
    net = model if itm_model is None else itm_model
    # Pure inference: avoid building and holding an autograd graph.
    with torch.no_grad():
        itm_output = net(image, caption, match_head='itm').to('cpu')
    itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
    # Comparison needs a single-element score (batch of one), as before.
    if_match = 1 if itm_score > threshold else 0
    return itm_score, 0, if_match

In [6]:
import pandas as pd
from PIL import UnidentifiedImageError
from fastprogress.fastprogress import progress_bar
# SVO-Probes validation split; presumably one row per sentence with a positive
# and a negative image id (column names below).
svo_cleaned = pd.read_csv('svo-probs/svo_valid.csv')
# Only the columns needed for the image-retrieval comparison.
testset = svo_cleaned[["sentence", "pos_image_id", "neg_image_id"]]
# Directory containing "<image_id>.jpg" files.
# NOTE(review): machine-specific absolute path — consider making this configurable.
image_path = "/data1/kenan/SVO_Probes/images/"

In [17]:
# Rows from position 35764 onward (presumably resuming where an earlier run stopped),
# restricted to the sentence / positive image / negative image columns.
testset2 = svo_cleaned.iloc[35764:].loc[:, ["sentence", "pos_image_id", "neg_image_id"]]

In [18]:
# For each (sentence, positive image, negative image) triple, compare the sentence
# feature against both image features and record the label in `matched_id`:
#   0              -> positive image has the higher similarity
#   0.5            -> tie
#   1              -> negative image has the higher similarity
#   "image_failed" -> one of the two image files is missing or not decodable
# NOTE(review): similarity is a raw dot product between the first-token outputs of
# mode='image' and mode='text', with no projection or normalization — confirm these
# live in a shared embedding space for this model.


def _image_feature(image_id, sentence):
    """Load ``<image_path><image_id>.jpg``; return (first-token image feature, image tensor)."""
    url = image_path + str(image_id) + ".jpg"
    image = load_one_image(url, image_size=image_size, device=device)
    return model(image, sentence, mode='image')[0, 0], image


def _match_label(sentence, pos, neg):
    """Return 0, 0.5 or 1 for one triple according to the legend above."""
    pos_feat, _ = _image_feature(pos, sentence)
    neg_feat, neg_image = _image_feature(neg, sentence)
    # mode='text' is still given an image tensor, matching the original call.
    text_feat = model(neg_image, sentence, mode='text')[0, 0]
    pos_sim = (pos_feat @ text_feat.t()).item()
    neg_sim = (neg_feat @ text_feat.t()).item()
    if pos_sim > neg_sim:
        return 0
    if pos_sim == neg_sim:
        return 0.5
    return 1


matched_id = []
# Inference only: no_grad avoids building (and holding) an autograd graph per row.
with torch.no_grad():
    for idx, (sentence, pos, neg) in testset2.iterrows():
        if idx % 1000 == 0:
            print(idx)
        try:
            matched_id.append(_match_label(sentence, pos, neg))
        except (UnidentifiedImageError, FileNotFoundError):
            # Record the bad/missing image and keep going so one file can't abort the run.
            matched_id.append("image_failed")

36000


In [19]:
# The loop appends exactly one label per row, so this must equal len(testset2);
# a mismatch would break the column assignment further down.
assert len(matched_id) == len(testset2), "expected one label per row of testset2"
len(matched_id)

366

In [20]:
# Explicit copy: a new column is added to this frame later, and mutating a slice
# of svo_cleaned would trigger SettingWithCopyWarning.
svo_cleaned_35764 = svo_cleaned[35764:].copy()

In [21]:
# Attach the per-row labels. `assign` returns a new frame, so this is safe even if
# svo_cleaned_35764 is a slice of svo_cleaned, and re-running the cell is idempotent.
# pandas raises ValueError if len(matched_id) differs from the number of rows.
svo_cleaned_35764 = svo_cleaned_35764.assign(matched_id=matched_id)

In [22]:
# Persist the labelled shard; the positional row index is not written.
svo_cleaned_35764.to_csv(path_or_buf="svo-probs/svo_35764.csv", index=False)