In [1]:
from sentence_transformers import SentenceTransformer, util
import json
import torch
import numpy as np
import random
from tqdm.notebook import tqdm
# import os
import inflect
%config Completer.use_jedi = False

In [2]:
def pluralize(word):
    """Return the English plural of `word` using `inflect`.

    The inflect engine is created on first use and cached on the function
    object, because this helper is called once per class token inside
    per-image loops; building a fresh engine each call is wasted work.
    """
    engine = getattr(pluralize, "_engine", None)
    if engine is None:
        engine = inflect.engine()
        pluralize._engine = engine
    return engine.plural(word)

In [3]:
# Sentence encoder shared by every section below. Fall back to CPU so the
# notebook still runs on machines without a GPU (the old `.cuda()` call crashed there).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', device=DEVICE)

In [4]:
def strip_prefix(caption):
    """Remove a leading "An image of ..." boilerplate phrase from a caption.

    Prefixes are tried in order and only the first match is removed, e.g.
    "An image of a cat" -> "cat", "An image of cats" -> "cats",
    "An image of... dogs" -> "dogs". Captions with none of these prefixes are
    returned unchanged.

    The prefix is removed by slicing rather than ``str.split(prefix)[-1]``:
    splitting discarded everything before the *last* occurrence, which
    truncated captions that repeat the phrase later on.
    """
    # "An image of a " must be tried before "An image of ", which is its prefix.
    for prefix in ("An image of a ", "An image of ", "An image of... "):
        if caption.startswith(prefix):
            return caption[len(prefix):]
    return caption

def strip_classname(caption):
    """Extract the class name from an LLM response string.

    - "Possible category: X" -> "X" (only the leading marker is removed).
    - A response starting with "Sure!" is reduced to its last line, stripped.
    - Any remaining multi-line text is reduced to its last line.

    The marker is removed by slicing rather than ``split(marker)[-1]`` so text
    is not lost when the marker appears again later in the response.
    """
    marker = "Possible category: "
    if caption.startswith(marker):
        caption = caption[len(marker):]
    elif caption.startswith("Sure!"):
        # Chatty preamble; the answer is on the last line.
        caption = caption.rsplit("\n", 1)[-1].strip()

    if "\n" in caption:
        caption = caption.rsplit("\n", 1)[-1]

    return caption

def get_classdesc(classnames, json_dict):
    """Build one "<readable name> with <d1> or <d2> ..." string per class.

    Args:
        classnames: iterable of raw class names; each must be a key of `json_dict`.
        json_dict: maps a raw class name to an iterable of descriptive strings,
            which are joined with " or ".

    Returns:
        list of description strings in the same order as `classnames`; "+" and
        "_" in each class name are replaced with spaces.
    """
    def _readable(raw):
        return raw.replace("+", " ").replace("_", " ")

    return [
        "{} with {}".format(_readable(raw), " or ".join(json_dict[raw]))
        for raw in classnames
    ]

### Caption - label similarity

In [5]:
def get_accuracy(caption_embeddings, class_embeddings, topk=5, gt_labels=None):
    """Score captions against class-name embeddings by cosine similarity.

    Every caption is matched to the classes whose embeddings are most similar.

    Args:
        caption_embeddings: one embedding per caption (rows).
        class_embeddings: one embedding per class (rows); row i is class id i.
        topk: k for top-k accuracy; clamped to the number of classes.
        gt_labels: ground-truth class id per caption, aligned with the rows of
            `caption_embeddings`. If None, falls back to the notebook globals
            `id_to_classid` and `all_flickrids` (the original behaviour).

    Returns:
        (top1_acc, topk_acc, pseudo_labels): scalar tensors with the fraction of
        captions whose best / any top-k class equals the ground truth, and the
        argmax class index per caption.
    """
    if gt_labels is None:
        # Backward-compatible path: reads globals set by the calling cell.
        gt_labels = [id_to_classid[fid] for fid in all_flickrids]

    similarity_matrix = util.cos_sim(caption_embeddings, class_embeddings)
    # topk() raises when k exceeds the number of classes.
    topk = min(topk, similarity_matrix.shape[1])
    pseudo_topk = similarity_matrix.topk(topk, dim=1).indices

    gt = torch.as_tensor(gt_labels, dtype=torch.long).view(-1, 1)
    matched_labels = pseudo_topk == gt  # broadcasts gt over the k columns

    n = len(gt)
    top1_acc = matched_labels[:, 0].sum() / n
    topk_acc = matched_labels.any(dim=1).sum() / n
    return top1_acc, topk_acc, similarity_matrix.argmax(1)

In [15]:
## category
with open("../metadata/officehome.json") as fh:
    meta = json.load(fh)

NUM_CLASSES = 65  # Office-Home category count (class_embeddings below is (65, 384))

# `cats` used to be built only by a commented-out line, so a fresh kernel hit a
# NameError here. Index i of `cats` is category id i.
# NOTE(review): assumes meta["categories"] entries have "category_id"/"category_name"
# and ids 0..64 — confirm against the metadata file.
cat_id_to_name = {int(c["category_id"]): c["category_name"] for c in meta["categories"]}
cats = [cat_id_to_name[idx].replace("_", " ").lower() for idx in range(NUM_CLASSES)]

In [16]:
# Cleaned class names (lower-case, underscores -> spaces); list index is the category id.
cats

['drill',
 'exit sign',
 'bottle',
 'glasses',
 'computer',
 'file cabinet',
 'shelf',
 'toys',
 'sink',
 'laptop',
 'kettle',
 'folder',
 'keyboard',
 'flipflops',
 'pencil',
 'bed',
 'hammer',
 'toothbrush',
 'couch',
 'bike',
 'postit notes',
 'mug',
 'webcam',
 'desk lamp',
 'telephone',
 'helmet',
 'mouse',
 'pen',
 'monitor',
 'mop',
 'sneakers',
 'notebook',
 'backpack',
 'alarm clock',
 'push pin',
 'paper clip',
 'batteries',
 'radio',
 'fan',
 'ruler',
 'pan',
 'screwdriver',
 'trash can',
 'printer',
 'speaker',
 'eraser',
 'bucket',
 'chair',
 'calendar',
 'calculator',
 'flowers',
 'lamp shade',
 'spoon',
 'candles',
 'clipboards',
 'scissors',
 'tv',
 'curtains',
 'fork',
 'soda',
 'table',
 'knives',
 'oven',
 'refrigerator',
 'marker']

In [17]:
# One sentence embedding per class name, in the same order as `cats`.
class_embeddings = model.encode(cats)

In [18]:
# (number of classes, embedding dimension)
class_embeddings.shape

(65, 384)

In [22]:
# Caption -> class-name matching per source domain. Everything computed here
# depends only on the source domain `dom`, so it is done once per domain and the
# same pseudo-labels are written to one file per target domain (previously the
# captions were re-encoded and scored three times per domain).
DOMAINS = ['art', 'product', 'real', 'clipart']
for dom in DOMAINS:
    geo = meta['{}_train'.format(dom)]
    # Kept as globals: get_accuracy's default path reads id_to_classid / all_flickrids.
    id_to_classid = {ann["image_id"]: ann["category"] for ann in geo["annotations"]}
    id_to_blip = {e["image_id"]: e["blip2_cap"] for e in geo["metadata"]}

    all_flickrids = list(id_to_blip.keys())
    all_captions = [id_to_blip[v] for v in all_flickrids]
    caption_embeddings = model.encode(all_captions)

    top1, top5, pseudo = get_accuracy(caption_embeddings, class_embeddings)
    print("{}:{:.2f}/{:.2f}".format(dom, top1*100, top5*100))

    # One "<image_id> <predicted class id>" line per caption.
    write_str = "".join("{} {}\n".format(fid, int(pl)) for fid, pl in zip(all_flickrids, pseudo))
    for tgt in DOMAINS:
        if tgt == dom:
            continue
        with open("../hard_labels/officeHome_{}_{}_embedMatchPL.txt".format(tgt, dom), "w") as fh:
            fh.write(write_str)

art:78.69/89.71
art:78.69/89.71
art:78.69/89.71
product:88.66/96.15
product:88.66/96.15
product:88.66/96.15
real:85.48/94.61
real:85.48/94.61
real:85.48/94.61
clipart:73.46/87.80
clipart:73.46/87.80
clipart:73.46/87.80


In [20]:
# Number of captions encoded for the last source domain processed by the loop above.
len(caption_embeddings)

3975

In [None]:
# Sanity check: print captions that contain none of their class-name tokens
# (singular or plural). Works on the last source domain processed by the loop
# above (`id_to_classid`, `id_to_blip`).
# `id_to_cname` was previously only defined by a commented-out line; rebuild it
# from `cats`, whose index is the category id that get_accuracy scores against.
id_to_cname = {fid: cats[cid] for fid, cid in id_to_classid.items()}

for idx, cname in id_to_cname.items():
    tag_set = set(id_to_blip[idx].split(" "))
    cls = {v.replace(" ", "") for v in cname.split(",")}
    cls = cls | {pluralize(v) for v in cls}

    if not tag_set & cls:
        print("Tags:{}".format(tag_set))
        print("Class:{}".format(cls))
        print()

In [None]:
# Class-token set (with plurals) from the last iteration of the tag-check loop.
cls

### Caption - Caption similarity

In [None]:
## Per-region LLaMA class-name extractions (machine-specific paths; adjust as needed).
GEOPLACES_META_PATH = "/home/tarun/metadata/geoPlaces_metadata.json"
EXTRACTED_CAPTIONS_PATH = "/home/tarun/llama/extracted_captions_geoplaces_{}.json"

with open(GEOPLACES_META_PATH) as fh:
    meta = json.load(fh)
cname_to_label = {c["category_name"]: c["category_id"] for c in meta["categories"]}


def load_extracted(region):
    """Load the extracted class names for one region.

    Returns:
        (flickr_id -> extracted class name with strip_prefix applied,
         flickr_id -> ground-truth category id looked up via cname_to_label)
    """
    with open(EXTRACTED_CAPTIONS_PATH.format(region)) as fh:
        extracted = json.load(fh)["extracted"]
    id_to_cap = {e["flickr_id"]: strip_prefix(e["extracted_class_name"]) for e in extracted}
    id_to_label = {e["flickr_id"]: cname_to_label[e["gt_category"]] for e in extracted}
    return id_to_cap, id_to_label


## usa
id_to_cap_usa, id_to_label_usa = load_extracted("usa")
## asia
id_to_cap_asia, id_to_label_asia = load_extracted("asia")

In [None]:
usa_flickrids = list(id_to_cap_usa.keys())
asia_flickrids = list(id_to_cap_asia.keys())

# Both label tensors use torch.tensor so they share a dtype; Asia previously used
# torch.Tensor, which always produces float32 while the USA labels were integer.
usa_captions = [id_to_cap_usa[v] for v in usa_flickrids]
usa_embeddings = model.encode(usa_captions)
usa_labels = torch.tensor([id_to_label_usa[v] for v in usa_flickrids])

asia_captions = [id_to_cap_asia[v] for v in asia_flickrids]
asia_embeddings = model.encode(asia_captions)
asia_labels = torch.tensor([id_to_label_asia[v] for v in asia_flickrids])

In [None]:
def get_similarity_acc(source_embed, target_embed, source_label, target_label, within=False, topks=(1, 5)):
    """Nearest-neighbour label accuracy between caption-embedding sets.

    Each source sample takes the labels of its k most cosine-similar neighbours
    and counts as correct when a neighbour label equals its own label.

    Args:
        source_embed, target_embed: embeddings, one row per sample.
        source_label, target_label: per-sample labels aligned with those rows.
        within: if True, neighbours are searched among the source samples
            themselves (excluding the query) and their labels are read from
            `source_label`; `target_embed` / `target_label` are then unused.
        topks: only ``max(topks)`` is used, as k for the top-k accuracy.

    Returns:
        (top1_acc, topk_acc) as scalar tensors.
    """
    topk = max(topks)
    source_label = torch.as_tensor(source_label)

    if within:
        similarity_matrix = util.cos_sim(source_embed, source_embed)
        # Take k+1 and drop the first column, which is normally the query itself.
        mostSimilar = similarity_matrix.topk(topk + 1, dim=1).indices[:, 1:]
        # Indices point into the source set, so labels must come from it too
        # (previously they were looked up in target_label).
        neighbour_labels = source_label
    else:
        similarity_matrix = util.cos_sim(source_embed, target_embed)
        mostSimilar = similarity_matrix.topk(topk, dim=1).indices
        neighbour_labels = torch.as_tensor(target_label)

    similarLabels = neighbour_labels[mostSimilar]  # (n_source, topk)
    matched_labels = similarLabels == source_label.reshape(-1, 1)

    n = len(source_label)
    top1_acc = matched_labels[:, 0].sum() / n
    topk_acc = matched_labels.any(dim=1).sum() / n
    return top1_acc, topk_acc

In [None]:
# (top-1, top-5) accuracy of labelling each Asia caption with the labels of its nearest USA captions.
get_similarity_acc(asia_embeddings, usa_embeddings, asia_labels, usa_labels)

In [None]:
# Cosine similarity of every Asia caption (rows) to every USA caption (columns).
similarity_matrix = util.cos_sim(asia_embeddings, usa_embeddings)

In [None]:
# For every USA caption (column), indices of the 5 most similar Asia captions,
# best first. Only the top 5 rows are used downstream, so topk avoids fully
# sorting the whole Asia x USA matrix.
ind = similarity_matrix.topk(5, dim=0).indices

In [None]:
# Labels of the 5 most similar Asia captions for each USA caption: shape (5, n_usa).
k_orderedNeighbors = torch.Tensor(asia_labels)[ind[:5]]

In [None]:
# Expected: (5, number of USA captions)
k_orderedNeighbors.shape

In [None]:
# Majority vote over the 5 neighbour labels -> one pseudo-label per USA caption.
assigned_target_labels = torch.mode(k_orderedNeighbors, dim=0).values

In [None]:
# Fraction of USA captions whose majority-vote label equals their ground truth.
matched_labels = torch.eq(assigned_target_labels, usa_labels)
matched_labels.sum() / len(usa_labels)

## Nearest Captions

In [None]:
## Consolidated GeoPlaces metadata for the nearest-caption experiments.
with open("../metadata/geoplaces.json") as fh:
    meta = json.load(fh)

# Caption field to compare (alternative: "llm_cap_llama_13b").
cap_src = "caption"


def _region_maps(split):
    """Return (image_id -> caption, image_id -> category id, image_id -> class name) for `split`."""
    metadata, annotations = meta[split]['metadata'], meta[split]['annotations']
    return (
        {e["image_id"]: e[cap_src] for e in metadata},
        {e["image_id"]: e["category"] for e in annotations},
        {e["image_id"]: e["class_name"] for e in annotations},
    )


## usa
id_to_cap_usa, id_to_label_usa, id_to_class_usa = _region_maps('usa_train')
## asia
id_to_cap_asia, id_to_label_asia, id_to_class_asia = _region_maps('asia_train')

In [None]:
# Fix an image order per region, then embed the captions in that order.
usa_flickrids, asia_flickrids = list(id_to_cap_usa), list(id_to_cap_asia)

usa_captions = [id_to_cap_usa[fid] for fid in usa_flickrids]
asia_captions = [id_to_cap_asia[fid] for fid in asia_flickrids]

usa_embeddings = model.encode(usa_captions)
asia_embeddings = model.encode(asia_captions)

In [None]:
# Category ids aligned with the embedding rows of each region.
usa_labels = torch.tensor(list(map(id_to_label_usa.__getitem__, usa_flickrids)))
asia_labels = torch.tensor(list(map(id_to_label_asia.__getitem__, asia_flickrids)))

In [None]:
# Class names aligned with the embedding rows of each region (used by the spot-check).
usa_classes = list(map(id_to_class_usa.__getitem__, usa_flickrids))
asia_classes = list(map(id_to_class_asia.__getitem__, asia_flickrids))

In [None]:
# Asia -> USA: label each Asia caption by majority vote over its K most similar USA captions.
K = 5
similarity = util.cos_sim(asia_embeddings, usa_embeddings)
mostSimilar = similarity.topk(K, dim=-1).indices  # row i: USA indices nearest to Asia caption i

mostSimilarLabels = usa_labels[mostSimilar]
pseudoLabels = torch.mode(mostSimilarLabels, dim=-1).values

matched_labels = pseudoLabels == asia_labels
matched_labels.sum() / len(asia_labels)

In [None]:
# USA -> Asia: label each USA caption by majority vote over its K most similar Asia captions.
# Results get their own names: reusing `mostSimilar` / `pseudoLabels` here overwrote
# the Asia -> USA results that the accuracy and spot-check cells below index with
# Asia positions.
similarity_usa2asia = util.cos_sim(usa_embeddings, asia_embeddings)

K = 5
mostSimilar_usa2asia = similarity_usa2asia.topk(K, dim=-1).indices

mostSimilarLabels_usa2asia = asia_labels[mostSimilar_usa2asia]
pseudoLabels_usa = torch.mode(mostSimilarLabels_usa2asia, dim=-1).values

matched_labels = pseudoLabels_usa == usa_labels
matched_labels.sum() / len(usa_labels)

In [None]:
# (top-1, top-5) nearest-neighbour label accuracy in both directions: Asia -> USA, then USA -> Asia.
print(get_similarity_acc(asia_embeddings, usa_embeddings, asia_labels, usa_labels))
print(get_similarity_acc(usa_embeddings, asia_embeddings, usa_labels, asia_labels))

In [None]:
# Asia accuracy of the Asia -> USA majority vote. `pseudoLabels` must hold one label
# per Asia caption; fail loudly if another cell has reused the name for USA results.
assert len(pseudoLabels) == len(asia_labels), "pseudoLabels is not aligned with asia_labels"
matched_labels = (torch.Tensor(pseudoLabels) == torch.Tensor(asia_labels))
matched_labels.sum()/len(asia_labels)

In [None]:
# Spot-check random Asia captions against their single nearest USA caption.
# "[1]" marks pairs whose class names agree. The sample is drawn from the actual
# number of Asia captions (the old hard-coded 1e5 could index past the end),
# without replacement, and seeded for reproducibility.
assert mostSimilar.shape[0] == len(asia_captions), "mostSimilar must hold Asia -> USA neighbours"
N_SPOT_CHECK = 20
spot_rng = np.random.default_rng(0)
sample = spot_rng.choice(len(asia_captions), size=min(N_SPOT_CHECK, len(asia_captions)), replace=False)
for idx in sample:
    index = int(idx)
    otherIdx = int(mostSimilar[index][0])
    print("[{}]{}:{}::{}:{}\n".format(int(asia_classes[index] == usa_classes[otherIdx]),
                                      asia_captions[index],
                                      asia_classes[index],
                                      usa_captions[otherIdx],
                                      usa_classes[otherIdx]))
    