In [None]:
!pip install comet_ml

In [1]:
import os

import clip
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image

from combiner import Combiner

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Combiner checkpoint. The default is a machine-specific absolute path; set the
# COMBINER_CHECKPOINT environment variable to point elsewhere without editing code.
COMBINER_CHECKPOINT = os.environ.get(
    "COMBINER_CHECKPOINT",
    "/Users/vothanhnhan/Desktop/gift/AIC2/model/cirr_comb_RN50x4_fullft.pt",
)

# NOTE(review): the sizes (640, 2560, 5120) must match the checkpoint; 640 is checked
# against the CLIP feature width below, the other two are presumably the
# projection/hidden sizes used at training time — confirm.
combiner = Combiner(640, 2560, 5120).to(device)
saved_state_dict = torch.load(COMBINER_CHECKPOINT, map_location=device)  # pickle-based: only load trusted files
# strict=False ignores missing/unexpected keys; surface them instead of swallowing them,
# because any layer the checkpoint does not cover keeps its random initialisation.
incompatible_keys = combiner.load_state_dict(saved_state_dict["Combiner"], strict=False)
if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
    print(
        f"Warning: Combiner checkpoint mismatch - missing keys: {incompatible_keys.missing_keys}, "
        f"unexpected keys: {incompatible_keys.unexpected_keys}"
    )
combiner.eval()
combining_function = combiner.combine_features

clip_model, preprocess = clip.load("RN50x4", device=device)
# The Combiner above was built for 640-d features; fail early if the backbone differs.
assert clip_model.visual.output_dim == 640, "Combiner expects 640-d CLIP features (RN50x4)"
print(device)
# * Unlike our previous model, this clip_model uses the 'RN50x4' backbone rather than
#   'ViT-B/32', and it does not perform as well.
# NOTE(review): the checkpoint name ends in 'fullft', which may mean it was trained on
# features from a fine-tuned CLIP, while the stock pretrained RN50x4 weights are loaded
# here — confirm the combiner expects these features.

cpu


In [2]:
pairs_id = []  # not used in this cell; kept in case later cells rely on it
group_members = []  # not used in this cell; kept in case later cells rely on it
reference_names = []  # not used in this cell; kept in case later cells rely on it

# Load data
path = "../../datasets/100577935.jpg"
captions = "A man with a goatee in a black shirt and white latex gloves is using a tattoo gun to place a tattoo on someone 's back ."
# * Assume this is a caption generated by BLIP.
query = "A man is putting a tattoo on a another 's man upper back ."


def cosine_similarity_percent(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the cosine similarity of two feature tensors as a percentage.

    Similarity is taken along dim=1 and converted with `.item()`, which requires the
    result to hold exactly one element, so each input is expected to be a single row.
    The value is multiplied by 100 and rounded to 2 decimals for display.
    """
    return round(F.cosine_similarity(a, b, dim=1).item() * 100, 2)


# Encode caption, query and reference image without tracking gradients (inference only).
# Features are cast to float32 so the Combiner's weights accept them even if CLIP runs
# in half precision on GPU (NOTE(review): confirm clip.load's dtype on CUDA); on CPU
# the cast is a no-op.
with torch.no_grad():
    text_inputs = clip.tokenize(captions, context_length=77).to(device)
    text_features = clip_model.encode_text(text_inputs).float()

    query_inputs = clip.tokenize(query, context_length=77).to(device)
    query_features = clip_model.encode_text(query_inputs).float()

    # The file handle is closed as soon as preprocess has produced the tensor.
    with Image.open(path) as pil_image:
        image = preprocess(pil_image).unsqueeze(0).to(device)
    reference_image_features = clip_model.encode_image(image).float()

    # One image -> one row; -1 keeps the CLIP feature width instead of hard-coding 640.
    batch_predicted_features = combining_function(reference_image_features.reshape(1, -1), text_features)

# Unit-normalise the combined feature (a single row for this single image).
predicted_features = F.normalize(batch_predicted_features, dim=-1)

# * cal cosine similarity between text_features (ground truth) and predicted features (combined features)
print(f"(orignal text with combined) Similarity: {cosine_similarity_percent(text_features, predicted_features)}%")

# * cal cosine similarity between query and predicted features (combined features)
print(f"(query to combined) Similarity: {cosine_similarity_percent(query_features, predicted_features)}%")

# * cal original cosine similarity between query and image features (previous solution)
print(f"(previous solution) Similarity: {cosine_similarity_percent(query_features, reference_image_features)}%")


(orignal text with combined) Similarity: 96.49%
(query to combined) Similarity: 69.48%
(previous solution) Similarity: 40.72%
