In [1]:
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer, util
import random
from tqdm.notebook import tqdm
from itertools import chain
# import os
%config Completer.use_jedi = False

from nltk.util import ngrams

In [2]:
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1').cuda()

In [9]:
def generate_ngrams(text, n, sep=" "):
    """
    Function to generate all ngrams for a given 'n' from a string.

    :param text: The input string from which to generate ngrams.
    :param n: The size of the ngram.
    :return: A list of ngrams as strings.
    """
    if isinstance(n, list):
        ngram_list = []
        for ni in n:
            ngram_list += generate_ngrams(text, ni, sep)
        return ngram_list
    # Split the text into words
    words = text.split(sep)

    # Generate ngrams
    ngram_list = [' '.join(gram) for gram in ngrams(words, n)]

    return ngram_list

def preprocess(text):
    caption = text["caption"]
    ngram_caption = generate_ngrams(caption, [1,2,3,4])
    
    tags = text["tags"]#.split(",")
    ngram_tags = generate_ngrams(tags, [1,2,3,4], sep=",")
#     ngram_tags = tags
    
    all_ngrams = ngram_caption + ngram_tags
    return list(set(all_ngrams))

## Places

In [None]:
MAX_DESC_LEN=12
N_CLS=205

In [None]:
def strip_name(cname):
    return cname.replace("_indoor","").replace("_outdoor","").replace("_"," ").replace("-"," ")

def get_accuracy(similarity_matrix, topk=5):
    pseudoLabel = similarity_matrix.topk(topk, 1).indices  
    gt_label = torch.Tensor([id_to_classid[fid] for fid in all_flickrids]).view(-1,1).repeat(1,topk)
    matched_labels = (torch.Tensor(pseudoLabel) == torch.Tensor(gt_label))
    top1_acc = matched_labels[:,0].sum()/len(gt_label)
    topk_acc = matched_labels.any(1).sum()/len(gt_label)
    return top1_acc, topk_acc

In [None]:
places_desc = json.load(open("descriptors_geoplaces.json"))
places_desc = [[strip_name(k)] + v for k,v in places_desc.items()]
# places_desc = [[strip_name(k)] for k,v in places_desc.items()]

padded_desc = []
mask = []
for p in places_desc:
    pad_len = MAX_DESC_LEN-len(p)
    padded_desc.append(p + ['EOS']*pad_len)
    mask.append([1]*len(p) + [0]*pad_len)
mask = torch.Tensor(mask)
mask = mask[None]

all_desc = list(chain.from_iterable(padded_desc))

all_embeddings = model.encode(all_desc)

## load data and json files
domain = "asia"
meta = json.load(open("/home/tarun/metadata/geoPlaces_metadata.json"))
geo = meta[f'{domain}_train']
id_to_classid = {ann["image_id"]:ann["category"] for ann in geo["annotations"]}

cap = json.load(open(f"/home/tarun/llama/extracted_captions_geoplaces_{domain}.json"))["extracted"]
id_to_cap = {e["flickr_id"]:e["extracted_class_name"] for e in cap}

all_flickrids = list(id_to_cap.keys())
all_captions = [id_to_cap[v] for v in all_flickrids]
caption_embeddings = model.encode(all_captions)

similarity_matrix = util.cos_sim(caption_embeddings, all_embeddings)
similarity_matrix = similarity_matrix.reshape(-1, N_CLS, MAX_DESC_LEN)

In [None]:
similarity_matrix = similarity_matrix * mask

averaged_similarity = similarity_matrix.max(-1).values#/mask.sum(-1)

get_accuracy(averaged_similarity)

In [None]:
all_captions[0]

In [None]:
places_desc[1]

In [None]:
similarity_matrix[0][1]

In [None]:
mask.sum(-1)

## GeoImnet

In [None]:
MAX_DESC_LEN=12
N_CLS=600

In [None]:
def strip_name(cname):
    return cname.replace("+"," ").replace("_"," ")

In [None]:
imnet_desc = json.load(open("descriptors_geoimnet.json"))
imnet_desc = [[strip_name(k)] + v for k,v in imnet_desc.items()]
# imnet_desc = [[strip_name(k)] for k,v in imnet_desc.items()]

padded_desc = []
mask = []
for p in imnet_desc:
    pad_len = MAX_DESC_LEN-len(p)
    padded_desc.append(p + ['EOS']*pad_len)
    mask.append([1]*len(p) + [0]*pad_len)
mask = torch.Tensor(mask)
mask = mask[None]

In [None]:
all_desc = list(chain.from_iterable(padded_desc))

all_embeddings = model.encode(all_desc)

## load data and json files
domain = "asia"
meta = json.load(open("/home/tarun/metadata/geoImnet_metadata.json"))
geo = meta[f'{domain}_train']
id_to_classid = {ann["image_id"]:ann["category"] for ann in geo["annotations"]}

cap = json.load(open(f"/home/tarun/llama/extracted_captions_geoimnet_{domain}.json"))["extracted"]
id_to_cap = {e["flickr_id"]:e["extracted_class_name"] for e in cap}

all_flickrids = list(id_to_cap.keys())
all_captions = [id_to_cap[v] for v in all_flickrids]
caption_embeddings = model.encode(all_captions)

similarity_matrix = util.cos_sim(caption_embeddings, all_embeddings)
similarity_matrix = similarity_matrix.reshape(-1, N_CLS, MAX_DESC_LEN)

In [None]:
similarity_matrix = similarity_matrix * mask

averaged_similarity = similarity_matrix.sum(-1)/mask.sum(-1)

get_accuracy(averaged_similarity)

In [None]:
mask.sum(-1).shape

## Use tags to find the best possible labels: GeoImnet

In [4]:
def get_accuracy(similarity_matrix, gt_label, topk=5):
    pseudoLabel = similarity_matrix.topk(topk, 1).indices  
    gt_label = torch.Tensor(gt_label).view(-1,1).repeat(1,topk)
    matched_labels = (torch.Tensor(pseudoLabel) == torch.Tensor(gt_label))
    top1_acc = matched_labels[:,0].sum()/len(gt_label)
    topk_acc = matched_labels.any(1).sum()/len(gt_label)
    return top1_acc, topk_acc

In [10]:
# for name in ["GeoPlaces", "GeoImnet", "GeoYFCC"]:
for name in ["GeoYFCC"]:
    for source, target in zip(["asia", "usa"],["usa","asia"]):
        data = json.load(open("../metadata/{}.json".format(name.lower())))
        geo = data[f'{target}_train']
        id_to_classid = {ann["image_id"]:ann["category"] for ann in geo["annotations"]}
        
        if name == "GeoImnet":
            classnames = [v['category_name'].replace("+"," ").replace("_"," ") for v in data['categories']]
        elif name == "GeoPlaces":
            classnames = [v['category_name'].replace("_indoor","").replace("_outdoor","").replace("_"," ") for v in data['categories']]
        else:
            classnames = [v['category_name'].replace(",","") for v in data['categories']]
            
        class_embeddings = model.encode(classnames)

        id_to_tags = {m["image_id"]: preprocess(m) for m in geo["metadata"]}
        all_flickrids = list(id_to_tags.keys())
        all_tags = [id_to_tags[f] for f in all_flickrids]
        tag_lens = [len(t) for t in all_tags]
        MAX_TAG_LEN=90#max(tag_lens)

        padded_tags = []
        mask = []
        for p in all_tags:
            pad_len = max(0, MAX_TAG_LEN-len(p))
            padded_tags.append(p[:MAX_TAG_LEN] + ['EOS']*pad_len)
            mask.append([1]*min(MAX_TAG_LEN, len(p)) + [0]*pad_len)
        mask = torch.Tensor(mask)

        flattened_tags = list(chain.from_iterable(padded_tags))

        tag_embedding = model.encode(flattened_tags, batch_size=256)

        similarity = util.cos_sim(tag_embedding, class_embeddings)

        similarity_reshaped = similarity.reshape(*mask.shape, -1)

        mask = mask[...,None]

        masked_similarity = similarity_reshaped * mask

        pseudo_labels = masked_similarity.max(1).values

        gt_label = [id_to_classid[f] for f in all_flickrids]
        
        print("Name: {}, {}->{}".format(name.lower(), source, target))
        print(get_accuracy(pseudo_labels, gt_label))
        
        tag_label = pseudo_labels.argmax(1).cpu().numpy()
        write_str = ""
#         with open("../soft_labels/{}_{}_{}_tagMatchPL.txt".format(name.lower(), source, target), "w") as fh:
#             for pl, fid in tqdm(zip(tag_label, all_flickrids)):
#                 write_str += f"{fid} "
#                 write_str += f"{pl}"
#                 write_str += "\n"
#             fh.write(write_str)

Name: geoyfcc, asia->usa
(tensor(0.9276), tensor(0.9963))
Name: geoyfcc, usa->asia
(tensor(0.9197), tensor(0.9947))


In [None]:
len(id_to_tags)

In [None]:
all_flickrids = list(id_to_tags.keys())

In [None]:
flickrid = all_flickrids[random.sample(range(len(all_flickrids)) ,1)[0]]
print(flickrid)
print(id_to_tags[flickrid])

In [None]:
print(id_to_tags[3560730811])

In [None]:
id_to_metadata = {m["image_id"]: m for m in geo["metadata"]}

In [None]:
id_to_metadata[3560730811]

In [None]:
generate_ngrams(id_to_metadata[3560730811]["caption"], [1,2])

In [None]:
id_to_tags[14031943238]

In [None]:
soft_tag_label = torch.nn.functional.softmax(pseudo_labels/0.07, dim=-1).cpu().numpy()

In [None]:
write_str = ""
with open("../soft_labels/{}_{}_tagPL.txt".format(name, domain), "w") as fh:
    for pl, fid in tqdm(zip(soft_tag_label, all_flickrids)):
        write_str += f"{fid} "
        vals = list(map(str, pl.tolist()))
        write_str += " ".join(vals)
        write_str += "\n"
    fh.write(write_str)

In [None]:
name, domain

In [None]:
import matplotlib.pyplot as plt
plt.hist(tag_lens)