In [1]:
import json
import torch
import numpy as np

from sentence_transformers import SentenceTransformer, util
import random
from tqdm.notebook import tqdm
from itertools import chain
# import os
%config Completer.use_jedi = False

from nltk.util import ngrams

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1', device=device)

In [3]:
def generate_ngrams(text, n):
    """
    Generate all word-level ngrams of size 'n' from a string.

    :param text: The input string; it is split on whitespace into words.
    :param n: The size of the ngram, or a list/tuple/range of sizes. When
        several sizes are given, the ngrams for each size are concatenated
        in the order the sizes appear.
    :return: A list of ngrams, each a string of 'n' space-joined words.
        Empty when the text has fewer than 'n' words.
    :raises ValueError: If a size is negative.
    """
    # Several sizes: gather the ngrams of each size, one size after another.
    if isinstance(n, (list, tuple, range)):
        ngram_list = []
        for ni in n:
            ngram_list.extend(generate_ngrams(text, ni))
        return ngram_list

    if n < 0:
        raise ValueError("n must be >= 0")

    # Split the text into words
    words = text.split()

    # Sliding window: zipping the word list against copies of itself shifted
    # by 0..n-1 positions yields every run of n consecutive words and stops
    # as soon as the most-shifted copy runs out.
    shifted = [words[i:] for i in range(n)]
    return [' '.join(gram) for gram in zip(*shifted)]

## DomainNet / OfficeHome

In [4]:
def get_accuracy(similarity_matrix, gt_label, topk=5):
    """
    Compute top-1 and top-k accuracy of similarity-based predictions.

    :param similarity_matrix: 2-D tensor of scores with one row per sample
        and one column per class; a higher score means a better match.
    :param gt_label: Ground-truth class index for each row (a sequence of
        ints or a 1-D tensor).
    :param topk: Number of highest-scoring classes considered for the top-k
        accuracy. Clamped to the number of classes.
    :return: Tuple ``(top1_acc, topk_acc)`` of 0-dim tensors.
    """
    similarity_matrix = torch.as_tensor(similarity_matrix)

    # Asking topk() for more candidates than there are classes raises.
    k = min(topk, similarity_matrix.shape[1])
    predicted = similarity_matrix.topk(k, dim=1).indices

    # Integer labels on the same device as the predictions; this avoids the
    # legacy float-only torch.Tensor constructor and CPU/GPU mismatches.
    target = torch.as_tensor(gt_label, dtype=torch.long, device=predicted.device).view(-1, 1)

    # target has one column and broadcasts across the k candidate columns.
    matched_labels = predicted == target
    num_samples = target.shape[0]
    top1_acc = matched_labels[:, 0].sum() / num_samples
    topk_acc = matched_labels.any(1).sum() / num_samples
    return top1_acc, topk_acc

In [5]:
name = "officeHome"

MAX_TAG_LEN = 40  # max(tag_lens)
PAD_TAG = 'EOS'  # filler for images with fewer than MAX_TAG_LEN tags
# Score assigned to padded slots. Cosine similarity is bounded below by -1,
# so a padded slot can never win the per-class max over a real tag.
PAD_SIMILARITY = -1.0

# The metadata file and the class names do not depend on the domain, so read
# and encode them once instead of once per loop iteration.
with open("../metadata/{}.json".format(name.lower())) as meta_fh:
    data = json.load(meta_fh)

# Alternative class-name clean-ups kept from earlier runs:
# classnames = [v['category_name'].replace("+"," ").replace("_"," ") for v in data['categories']]
# classnames = [v['category_name'].replace("_indoor","").replace("_outdoor","").replace("_"," ") for v in data['categories']]
classnames = [v['category_name'].replace("_"," ") for v in data['categories']]
class_embeddings = model.encode(classnames)

# for domain in ["clipart", "sketch", "painting"]:
for domain in ['art', 'product', 'real', 'clipart']:

    geo = data[f'{domain}_train']
    id_to_classid = {ann["image_id"]:ann["category"] for ann in geo["annotations"]}

    # Unigrams and bigrams of each image's `blip2_cap` text (presumably a
    # BLIP-2 caption) serve as that image's tags.
    id_to_tags = {m["image_id"]: generate_ngrams(m["blip2_cap"], [1,2]) for m in geo["metadata"]}
    all_flickrids = list(id_to_tags.keys())
    all_tags = [id_to_tags[f] for f in all_flickrids]
    tag_lens = [len(t) for t in all_tags]

    # Truncate/pad every tag list to exactly MAX_TAG_LEN entries and record
    # which slots hold real tags.
    padded_tags = []
    mask = []
    for tags in all_tags:
        kept = tags[:MAX_TAG_LEN]
        pad_len = MAX_TAG_LEN - len(kept)
        padded_tags.append(kept + [PAD_TAG]*pad_len)
        mask.append([True]*len(kept) + [False]*pad_len)
    mask = torch.tensor(mask, dtype=torch.bool)  # (num_images, MAX_TAG_LEN)

    flattened_tags = list(chain.from_iterable(padded_tags))
    tag_embedding = model.encode(flattened_tags, batch_size=256)

    # One row per (image, slot) pair, one column per class; regroup the rows
    # into (num_images, MAX_TAG_LEN, num_classes).
    similarity = util.cos_sim(tag_embedding, class_embeddings)
    similarity_reshaped = similarity.reshape(*mask.shape, -1)

    # Padded slots must not take part in the max below. Multiplying by a 0/1
    # mask would score them 0, which beats every real tag whose similarity is
    # negative; give them the lowest possible score instead.
    masked_similarity = similarity_reshaped.masked_fill(~mask[..., None], PAD_SIMILARITY)

    # Per image and class: the best similarity over that image's tags.
    pseudo_labels = masked_similarity.max(1).values

    gt_label = [id_to_classid[f] for f in all_flickrids]

    print(get_accuracy(pseudo_labels, gt_label))

    tag_label = pseudo_labels.argmax(1).cpu().numpy()

    # One "<image_id> <predicted_class>" line per image, streamed to the file.
    with open("../hard_labels/{}_{}_{}_tagMatchPL.txt".format(name, domain, domain), "w") as fh:
        for pl, fid in tqdm(zip(tag_label, all_flickrids), total=len(all_flickrids)):
            fh.write(f"{fid} {pl}\n")

(tensor(0.7702), tensor(0.9054))


0it [00:00, ?it/s]

(tensor(0.8671), tensor(0.9592))


0it [00:00, ?it/s]

(tensor(0.8021), tensor(0.9496))


0it [00:00, ?it/s]

(tensor(0.7238), tensor(0.8838))


0it [00:00, ?it/s]

In [7]:
# Print the shell commands that copy each target domain's pseudo-label file
# to the name expected for every (source, target) pair with source != target.
domains = ['art', 'product', 'real', 'clipart']
domain_pairs = [(src, tgt) for src in domains for tgt in domains if src != tgt]
for src, tgt in domain_pairs:
    print(f"cp officeHome_{tgt}_{tgt}_tagMatchPL.txt officeHome_{src}_{tgt}_tagMatchPL.txt")

cp officeHome_product_product_tagMatchPL.txt officeHome_art_product_tagMatchPL.txt
cp officeHome_real_real_tagMatchPL.txt officeHome_art_real_tagMatchPL.txt
cp officeHome_clipart_clipart_tagMatchPL.txt officeHome_art_clipart_tagMatchPL.txt
cp officeHome_art_art_tagMatchPL.txt officeHome_product_art_tagMatchPL.txt
cp officeHome_real_real_tagMatchPL.txt officeHome_product_real_tagMatchPL.txt
cp officeHome_clipart_clipart_tagMatchPL.txt officeHome_product_clipart_tagMatchPL.txt
cp officeHome_art_art_tagMatchPL.txt officeHome_real_art_tagMatchPL.txt
cp officeHome_product_product_tagMatchPL.txt officeHome_real_product_tagMatchPL.txt
cp officeHome_clipart_clipart_tagMatchPL.txt officeHome_real_clipart_tagMatchPL.txt
cp officeHome_art_art_tagMatchPL.txt officeHome_clipart_art_tagMatchPL.txt
cp officeHome_product_product_tagMatchPL.txt officeHome_clipart_product_tagMatchPL.txt
cp officeHome_real_real_tagMatchPL.txt officeHome_clipart_real_tagMatchPL.txt


In [None]:
# Turn the per-class scores into a probability row per image with a
# temperature-scaled softmax.
# NOTE(review): `pseudo_labels` is set inside the per-domain loop, so when the
# notebook runs top to bottom this only covers the last domain of that loop.
temperature = 0.07
scaled_scores = pseudo_labels / temperature
soft_tag_label = torch.nn.functional.softmax(scaled_scores, dim=-1).cpu().numpy()

In [None]:
# Write one line per image: the image id followed by its space-separated
# class probabilities. Lines are streamed to the file instead of being
# accumulated in one large string, and tqdm is given the total so the
# progress bar can show real progress.
with open("../soft_labels/{}_{}_tagPL.txt".format(name, domain), "w") as fh:
    for pl, fid in tqdm(zip(soft_tag_label, all_flickrids), total=len(all_flickrids)):
        vals = " ".join(map(str, pl.tolist()))
        fh.write(f"{fid} {vals}\n")

In [None]:
# Display which dataset and domain the variables in the kernel currently refer to.
(name, domain)

In [None]:
import matplotlib.pyplot as plt

# Distribution of the number of tags per image, counted before truncation to
# MAX_TAG_LEN. `tag_lens`, `name` and `domain` come from the per-domain loop.
fig, ax = plt.subplots()
ax.hist(tag_lens)
ax.set_xlabel("number of tags per image")
ax.set_ylabel("number of images")
ax.set_title(f"Tag count distribution ({name}, {domain})")
plt.show()