In [1]:
import random
import numpy as np
import os
from os.path import exists
import json
from pprint import pprint as pprint
from typing import List, Optional
import copy
import pickle
from tqdm import tqdm

In [2]:
import torch
import clip
from nltk import data
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = 1000000000

In [3]:
from collections import Counter

In [4]:
# Adapted from https://doi.org/10.18653/v1/2023.acl-long.88
from nltk.corpus import wordnet as wn
from wiktionaryparser import WiktionaryParser

In [5]:
# Run configuration shared by the cells of this notebook.
# CLIP_MODEL is handed to clip.load; the author's notes list ViT-B/32, ViT-L/14 and a
# local checkpoint path (/mnt/model/ViT-B-32.pt) as options.
CLIP_MODEL = "ViT-B/32"  # (ViT-B/32, ViT-L/14 /mnt/model/ViT-B-32.pt)
# Which definition source Dictionary_wrapper.get_definitions uses.
dictionary_type = 'compensate' # GPT_gen (DG or CADG), compensate (WN+DG or WN+CADG), wordnet (WN)
# Dataset split; the path-setup code in this notebook only handles 'train'.
d_split = 'train'
# JSON file read by GPT_definitions.
GPT_def_path = 'text/GPT_Context_Definitions.json' # definition path

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load returns the model together with the image transform used when loading images.
CLIP_model, preprocess = clip.load(CLIP_MODEL, device=device)

In [19]:
def image_loader(path, preprocess_fn=None):
    """Open and preprocess every file found directly inside ``path``.

    Args:
        path: Directory whose entries are all opened as images.
        preprocess_fn: Callable applied to each opened image. When omitted, the
            module-level ``preprocess`` transform (returned by ``clip.load``) is used.
            Accepting it as an optional second argument keeps both the one-argument
            and the two-argument call styles working.

    Returns:
        dict mapping each file name to the preprocessed image with a leading batch
        axis added by ``unsqueeze(0)``.
    """
    if preprocess_fn is None:
        preprocess_fn = preprocess

    imgs = {}
    for file in tqdm(os.listdir(path)):
        file_path = os.path.join(path, file)
        # Close each image as soon as it has been converted so that a large folder
        # does not keep one open file handle per image.
        with Image.open(file_path) as img:
            imgs[file] = preprocess_fn(img).unsqueeze(0)
    return imgs

# Resolve the dataset locations for the selected split.
if d_split == 'train':
    image_path = "data/train_v1/train_images_v1"
    data_file_path = "data/train_v1/train.data.v1.txt"
    gold_file_path = "data/train_v1/train.gold.v1.txt"
    image_dict_path = 'temp/img_dict_train.pkl'
else:
    # Fail here with a clear message instead of a NameError on the first use of the paths.
    raise ValueError("Unsupported d_split %r: only 'train' paths are configured" % (d_split,))

# Preprocessing every image is slow, so the result is cached on disk and reused.
# NOTE: pickle.load can execute arbitrary code; only load a cache file you created yourself.
if os.path.isfile(image_dict_path):
    with open(image_dict_path, 'rb') as cache_file:
        img_dict = pickle.load(cache_file)
else:
    # image_loader applies the module-level `preprocess` transform itself.
    img_dict = image_loader(image_path)
    cache_dir = os.path.dirname(image_dict_path)
    if cache_dir:
        # Create the cache folder on first use so that the dump below cannot fail on it.
        os.makedirs(cache_dir, exist_ok=True)
    with open(image_dict_path, 'wb') as cache_file:
        pickle.dump(img_dict, cache_file)


In [22]:
class GPT_definitions(object):
    """Lookup table of generated definitions, keyed by target word."""

    def __init__(self, GPT_def_path):
        """Read the definition file and regroup its content by word.

        The JSON file is read as a two-level mapping ``{outer_key: {word: definition}}``.
        The outer keys are discarded and the definitions of each word are collected,
        in file order, into one list per word.

        Args:
            GPT_def_path: Path of the JSON definition file.
        """
        # The context manager closes the file as soon as it has been parsed.
        with open(GPT_def_path, encoding='utf-8') as def_file:
            temp_dict = json.load(def_file)

        GPT_dict = {}
        for per_key_definitions in temp_dict.values():
            for word, definition in per_key_definitions.items():
                GPT_dict.setdefault(word, []).append(definition)
        self.GPT_dict = GPT_dict

    def get_senses(self, target_word):
        """Return the list of definitions stored for ``target_word``.

        An unknown word yields an empty list instead of a KeyError, so callers can
        treat "no definition available" uniformly.
        """
        return self.GPT_dict.get(target_word, [])

In [23]:
class Dictionary_wrapper(object):
    """Single entry point for the definition sources used in this notebook.

    Wraps WordNet (``wn``), a ``WiktionaryParser`` instance and the ``GPT_definitions``
    lookup table built from ``GPT_def_path``.
    """

    def __init__(self):
        self.wn = wn
        self.wiktionary_parser = WiktionaryParser()
        self.GPT_definitions = GPT_definitions(GPT_def_path)

    @staticmethod
    def _first_clause(definition):
        """Return the text before the first ';' of a definition."""
        return definition.split(';')[0]

    @staticmethod
    def _unique(items):
        """Drop duplicates while keeping first-seen order, so runs are reproducible."""
        return list(dict.fromkeys(items))

    def get_wn_definitions(self, target_word):
        """Return the distinct first clauses of all WordNet synset definitions."""
        return self._unique(
            self._first_clause(synset.definition())
            for synset in self.wn.synsets(target_word)
        )

    def get_wiktionary_definitions(self, target_word, lang):
        """Return the distinct first clauses of all Wiktionary senses.

        Args:
            target_word: Word to look up.
            lang: Language name forwarded to ``WiktionaryParser.fetch``.
        """
        sense_definitions = []
        for entry in self.wiktionary_parser.fetch(target_word, lang):
            for polysemy in entry['definitions']:
                # Index 0 of 'text' is skipped, as before (presumably the headword
                # line rather than a sense -- confirm against wiktionaryparser output).
                for sense in polysemy['text'][1:]:
                    # Collect every sense; appending outside this loop kept only the
                    # last one and failed when an entry listed no senses at all.
                    sense_definitions.append(self._first_clause(sense))
        return self._unique(sense_definitions)

    def get_GPT_definitions(self, target_word):
        """Return the generated definitions stored for ``target_word``."""
        return self.GPT_definitions.get_senses(target_word)

    def get_definitions(self, target_word, dictionary_type = "wordnet", lang='english'):
        """Return the definitions of ``target_word`` from the selected source.

        Args:
            target_word: Word to look up.
            dictionary_type: ``'wordnet'``, ``'GPT_gen'``, or ``'compensate'`` (WordNet
                first, generated definitions only when WordNet returns nothing).
            lang: Accepted for interface compatibility; not used by these sources.

        Raises:
            ValueError: If ``dictionary_type`` is not one of the supported values.
        """
        if dictionary_type == 'wordnet':
            return self.get_wn_definitions(target_word)
        if dictionary_type == 'GPT_gen':
            return self.get_GPT_definitions(target_word)
        if dictionary_type == 'compensate':
            sense_definitions = self.get_wn_definitions(target_word)
            if not sense_definitions:
                # Copy so that callers never share the lookup table's own list.
                sense_definitions = list(self.get_GPT_definitions(target_word))
            return sense_definitions
        raise ValueError("Unknown dictionary_type: %r" % (dictionary_type,))

In [24]:
# Shared definition lookup; the constructor also reads the JSON file at GPT_def_path.
dictionary = Dictionary_wrapper()

In [25]:
def data_loader(data_file_path, dictionary, dictionary_type="wordnet", gold_file_path = None):
    """Read the tab-separated data file and attach definitions to every example.

    Each non-blank line is split on tabs into the target word, the context phrase and
    the remaining columns, which are the candidate names. Examples are keyed by their
    zero-based line number in the file; blank lines are skipped but still counted, so
    the gold file stays aligned line by line.

    Args:
        data_file_path: Path of the UTF-8 data file.
        dictionary: Object exposing ``get_definitions(target_word, dictionary_type)``.
        dictionary_type: Source used for the ``'sense_definitions'`` field.
        gold_file_path: Optional file with one gold candidate name per line.

    Returns:
        dict mapping line index to a dict with the keys ``'target_word'``,
        ``'sense_definitions'``, ``'wordnet_definitions'``, ``'context'``,
        ``'candidates'`` and, when a gold file is given, ``'gold'``.
    """

    def target_word_preprocessing(target_word):
        # Identity hook kept as the single place to normalise target words.
        return target_word

    text_data = {}
    candidate_lens = []
    # `with` closes the file even when a malformed line raises part-way through.
    with open(data_file_path, encoding='utf-8') as fin_data:
        for data_index, line in tqdm(enumerate(fin_data)):
            line = line.strip()
            if not line:
                continue

            cols = line.split('\t')
            target_word = target_word_preprocessing(cols[0])
            context = cols[1]
            candidates = cols[2:]

            sense_definitions = dictionary.get_definitions(target_word, dictionary_type)
            wordnet_definitions = dictionary.get_definitions(target_word, 'wordnet')

            text_data[data_index] = {'target_word': target_word,
                                     'sense_definitions': sense_definitions,
                                     'wordnet_definitions': wordnet_definitions,
                                     'context': context,
                                     'candidates': candidates}

            candidate_lens.append(len(candidates))

    if gold_file_path:
        # Read with the same encoding as the data file: gold names are later matched
        # against the candidate names parsed above.
        with open(gold_file_path, encoding='utf-8') as fin_gold:
            for gold_index, line in enumerate(fin_gold):
                gold = line.strip()
                if not gold:
                    continue
                text_data[gold_index]['gold'] = gold

    # Average number of candidates per example.
    print(np.mean(candidate_lens))
    return text_data

In [29]:
def data_loader_gpt3_5(data_file_path, dictionary, dictionary_type="wordnet", gold_file_path = None,
                       definitions_path='text/gpt3_5_train.json'):
    """Read the data file and attach pre-generated definitions from a JSON file.

    The JSON file maps each target word to a list of senses, each sense itself being a
    list whose first element is the definition text (this is how the ``llm=True``
    consumers index it). Every entry is normalised before use:

    * an entry whose first definition text contains ``'definition'`` is replaced by a
      single blank placeholder (presumably a response that talks about the definition
      instead of giving one -- TODO confirm);
    * entries with fewer than five senses are padded with blank placeholders;
    * every empty sense is replaced by a blank placeholder.

    Args:
        data_file_path: Path of the UTF-8, tab-separated data file
            (target word, context, candidate names).
        dictionary: Object exposing ``get_definitions(target_word, dictionary_type)``.
        dictionary_type: Source used for words missing from the JSON file.
        gold_file_path: Optional file with one gold candidate name per line.
        definitions_path: JSON file with the generated definitions.

    Returns:
        dict mapping line index to a dict with the keys ``'target_word'``,
        ``'sense_definitions'``, ``'wordnet_definitions'``, ``'context'``,
        ``'candidates'`` and, when a gold file is given, ``'gold'``.
    """

    def target_word_preprocessing(target_word):
        # Identity hook kept as the single place to normalise target words.
        return target_word

    def normalise_senses(senses):
        # Guard the refusal check so an empty entry or empty first sense cannot raise.
        if senses and len(senses[0]) and 'definition' in senses[0][0]:
            senses = [[' ']]
        padded = list(senses) + [[' '] for _ in range(5 - len(senses))]
        # Repair every empty sense, including any beyond the fifth.
        return [sense if len(sense) else [' '] for sense in padded]

    with open(definitions_path, encoding='utf-8') as def_file:
        raw_definitions = json.load(def_file)
    de = {word: normalise_senses(senses) for word, senses in raw_definitions.items()}

    if de:
        # Average number of senses per word after normalisation.
        print(sum(len(senses) for senses in de.values()) / len(de))

    text_data = {}
    candidate_lens = []
    with open(data_file_path, encoding='utf-8') as fin_data:
        for data_index, line in tqdm(enumerate(fin_data)):
            line = line.strip()
            if not line:
                continue

            cols = line.split('\t')
            target_word = target_word_preprocessing(cols[0])
            context = cols[1]
            candidates = cols[2:]

            if target_word in de:
                sense_definitions = de[target_word]
            else:
                # Fall back to the dictionary, wrapping each string in a list so the
                # entry has the same shape as the JSON senses.
                sense_definitions = [[definition] for definition in
                                     dictionary.get_definitions(target_word, dictionary_type)]

            wordnet_definitions = dictionary.get_definitions(target_word, 'wordnet')
            text_data[data_index] = {'target_word': target_word,
                                     'sense_definitions': sense_definitions,
                                     'wordnet_definitions': wordnet_definitions,
                                     'context': context,
                                     'candidates': candidates}

            candidate_lens.append(len(candidates))

    if gold_file_path:
        # Same encoding as the data file: gold names are matched against candidates.
        with open(gold_file_path, encoding='utf-8') as fin_gold:
            for gold_index, line in enumerate(fin_gold):
                gold = line.strip()
                if not gold:
                    continue
                text_data[gold_index]['gold'] = gold

    # Average number of candidates per example.
    print(np.mean(candidate_lens))
    return text_data

In [30]:
def data_loader_gpt4(data_file_path, dictionary, dictionary_type="wordnet", gold_file_path = None,
                     definitions_path='text/gpt4_train.json'):
    """Read the data file and attach pre-generated definitions from a JSON file.

    The JSON file maps each target word to a list of senses, each sense itself being a
    list whose first element is the definition text (this is how the ``llm=True``
    consumers index it). Every entry is normalised before use:

    * entries with fewer than five senses are padded with blank placeholders;
    * every empty sense is replaced by a blank placeholder;
    * an entry whose first definition text contains ``'definition'`` or ``'sorry'`` is
      replaced by a single blank placeholder (presumably a refusal-style response --
      TODO confirm).

    Args:
        data_file_path: Path of the UTF-8, tab-separated data file
            (target word, context, candidate names).
        dictionary: Object exposing ``get_definitions(target_word, dictionary_type)``.
        dictionary_type: Source used for words missing from the JSON file.
        gold_file_path: Optional file with one gold candidate name per line.
        definitions_path: JSON file with the generated definitions.

    Returns:
        dict mapping line index to a dict with the keys ``'target_word'``,
        ``'sense_definitions'``, ``'wordnet_definitions'``, ``'context'``,
        ``'candidates'`` and, when a gold file is given, ``'gold'``.
    """

    def target_word_preprocessing(target_word):
        # Identity hook kept as the single place to normalise target words.
        return target_word

    def normalise_senses(senses):
        padded = list(senses) + [[' '] for _ in range(5 - len(senses))]
        # Repair every empty sense, including any beyond the fifth.
        padded = [sense if len(sense) else [' '] for sense in padded]
        # After padding there is always a non-empty first sense to inspect.
        first_text = padded[0][0]
        if 'definition' in first_text or 'sorry' in first_text:
            return [[' ']]
        return padded

    with open(definitions_path, encoding='utf-8') as def_file:
        raw_definitions = json.load(def_file)
    de = {word: normalise_senses(senses) for word, senses in raw_definitions.items()}

    if de:
        # Average number of senses per word after normalisation.
        print(sum(len(senses) for senses in de.values()) / len(de))

    text_data = {}
    candidate_lens = []
    with open(data_file_path, encoding='utf-8') as fin_data:
        for data_index, line in tqdm(enumerate(fin_data)):
            line = line.strip()
            if not line:
                continue

            cols = line.split('\t')
            target_word = target_word_preprocessing(cols[0])
            context = cols[1]
            candidates = cols[2:]

            if target_word in de:
                sense_definitions = de[target_word]
            else:
                # The dictionary returns plain strings; wrap each one in a list so that
                # consumers indexing sense[0] get the whole definition rather than its
                # first character.
                sense_definitions = [[definition] for definition in
                                     dictionary.get_definitions(target_word, dictionary_type)]

            wordnet_definitions = dictionary.get_definitions(target_word, 'wordnet')
            text_data[data_index] = {'target_word': target_word,
                                     'sense_definitions': sense_definitions,
                                     'wordnet_definitions': wordnet_definitions,
                                     'context': context,
                                     'candidates': candidates}

            candidate_lens.append(len(candidates))

    if gold_file_path:
        # Same encoding as the data file: gold names are matched against candidates.
        with open(gold_file_path, encoding='utf-8') as fin_gold:
            for gold_index, line in enumerate(fin_gold):
                gold = line.strip()
                if not gold:
                    continue
                text_data[gold_index]['gold'] = gold

    # Average number of candidates per example.
    print(np.mean(candidate_lens))
    return text_data

In [None]:
import torch.nn.functional as F
class Q_VWSD_QC(object):
    """Encodes every example with CLIP and stores the resulting features on disk."""

    def __init__(self, CLIP_model, CLIP_preprocess):
        """Keep references to the CLIP model and its image transform."""
        self.CLIP_model = CLIP_model
        self.CLIP_preprocess = CLIP_preprocess

    def code(self, text_data, img_dict, llm= False,
             state_path='/mnt/state_gpt4.0_large.pt',
             image_feature_path='/mnt/project_large.pt'):
        """Compute and save one text state and one image-feature matrix per example.

        For each example the context (target word wrapped in double quotes) and the
        prompts ``"<context> : <definition>"`` are encoded with CLIP. The text state is
        the sum of the L2-normalised definition features, each weighted by its cosine
        similarity to the context feature. The candidate images are encoded and
        L2-normalised.

        Args:
            text_data: Mapping from example index to a dict with the keys ``'context'``,
                ``'target_word'``, ``'candidates'`` and ``'sense_definitions'``.
            img_dict: Mapping from candidate name to a preprocessed image tensor.
            llm: When true each sense definition is a sequence whose first element is
                the definition text; otherwise each one is the text itself.
            state_path: Where the stacked text states are written with ``torch.save``.
            image_feature_path: Where the stacked image features are written.

        Note:
            ``torch.stack`` requires every example to have the same number of
            candidates.
        """
        CLIP_model = self.CLIP_model
        states1 = []
        states2 = []

        # Iterate over the keys actually present; skipped blank lines leave gaps in them.
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            candidates = data['candidates']
            target_word = data['target_word']
            context = data['context'].replace(target_word, '\"'+target_word+'\"')

            if llm:
                sense_definitions = [context + ' : ' + sense_definition[0]
                                     for sense_definition in data['sense_definitions']]
            else:
                sense_definitions = [context + ' : ' + sense_definition
                                     for sense_definition in data['sense_definitions']]
            if not sense_definitions:
                # Without any definition the context alone stands in for one.
                sense_definitions = [context]

            with torch.no_grad():
                context_text = clip.tokenize([context], truncate = True).to(device)
                definition_text = clip.tokenize(sense_definitions, truncate = True).to(device)

                # Remove only the per-image batch axis; a bare squeeze() would also
                # drop the candidate axis when there is a single candidate.
                images = torch.stack([img_dict[candidate] for candidate in candidates])
                images = images.squeeze(1).to(device)

                image_features = CLIP_model.encode_image(images)
                text_features = CLIP_model.encode_text(context_text)
                def_features = CLIP_model.encode_text(definition_text)

                text_features = F.normalize(text_features, p=2, dim=1)
                def_features = F.normalize(def_features, p=2, dim=1)
                image_features = F.normalize(image_features, p=2, dim=1)

                cosine_similarity = F.cosine_similarity(text_features, def_features, dim=1).unsqueeze(dim=0)
                state = torch.matmul(cosine_similarity, def_features)

                states1.append(state)
                states2.append(image_features)

        states1 = torch.stack(states1, dim=0)
        states2 = torch.stack(states2, dim=0)
        torch.save(states1, state_path)
        torch.save(states2, image_feature_path)
# Encode the loaded examples and write the feature tensors to disk (llm=True: each
# sense definition is indexed with [0] to obtain its text).
# NOTE(review): in the visible notebook `text_data` is only assigned by the
# data_loader_gpt3_5 / data_loader_gpt4 calls in later cells, so on a fresh kernel this
# cell presumably has to run after one of them -- confirm the intended order.
VWSD_CLIP = Q_VWSD_QC(CLIP_model, preprocess)
VWSD_CLIP.code(text_data, img_dict,True)


In [14]:
class Q_VWSD_QI(object):
    """CLIP-based rankers that pick the candidate image matching a word in context.

    Three scoring schemes are provided: context only (``evaluate_posterior``), a
    mixture over sense definitions (``evaluate_bayesian_posterior``) and the same
    mixture plus a pairwise interference term (``evaluate_QI_posterior``).

    The methods rely on the module-level names ``clip``, ``device`` and ``tqdm``.
    """

    def __init__(self, CLIP_model, CLIP_preprocess):
        """Keep references to the CLIP model and its image transform."""
        self.CLIP_model = CLIP_model
        self.CLIP_preprocess = CLIP_preprocess

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _quoted_context(data):
        """Return the context with the target word wrapped in double quotes."""
        target_word = data['target_word']
        return data['context'].replace(target_word, '\"'+target_word+'\"')

    @staticmethod
    def _definition_prompts(context, sense_definitions, llm):
        """Build one ``"<context> : <definition>"`` prompt per sense definition.

        With ``llm`` set each definition is a sequence whose first element is the text.
        Falls back to the bare context when there is no definition at all.
        """
        if llm:
            prompts = [context + ' : ' + sense_definition[0] for sense_definition in sense_definitions]
        else:
            prompts = [context + ' : ' + sense_definition for sense_definition in sense_definitions]
        return prompts if prompts else [context]

    @staticmethod
    def _candidate_images(img_dict, candidates):
        """Stack the candidates' preprocessed images into one batch on ``device``."""
        images = torch.stack([img_dict[candidate] for candidate in candidates])
        # Remove only the per-image batch axis; a bare squeeze() would also drop the
        # candidate axis when there is a single candidate.
        return images.squeeze(1).to(device)

    @staticmethod
    def _rank_candidates(scores, gold_index):
        """Return ``(index of the best score, reciprocal rank of the gold candidate)``."""
        flat_scores = np.asarray(scores).reshape(-1)
        pred = int(np.argmax(flat_scores))
        ranking = np.argsort(flat_scores)[::-1]
        rank = int(np.nonzero(ranking == gold_index)[0][0]) + 1
        return pred, 1 / rank

    # --------------------------------------------------------- pairwise helpers

    def calculate_lower_triangle_matrix_2spp(self,tensor):
        """Return the elementwise products of all distinct row pairs of ``tensor``.

        Pairs ``(i, j)`` with ``i > j`` are emitted row by row, i.e. in the same order
        as ``calculate_lower_triangle_matrix_cos``, so the two results line up. The
        previous loops paired each row with itself, never used the last row and did
        not follow that order. Any number of columns is supported.

        Returns:
            float32 tensor of shape ``(n * (n - 1) / 2, tensor.shape[1])``.
        """
        n = tensor.shape[0]
        rows, cols = torch.tril_indices(n, n, offset=-1, device=tensor.device)
        return (tensor[rows] * tensor[cols]).to(torch.float32)

    def calculate_lower_triangle_matrix_cos(self,vector_tensor):
        """Return the cosine similarity of all distinct row pairs of ``vector_tensor``.

        Pairs ``(i, j)`` with ``i > j`` are emitted row by row (strict lower triangle).
        """
        n = vector_tensor.shape[0]
        cosine_similarity_matrix = torch.nn.functional.cosine_similarity(
            vector_tensor.unsqueeze(1), vector_tensor.unsqueeze(0), dim=2)
        rows, cols = torch.tril_indices(n, n, offset=-1, device=vector_tensor.device)
        return cosine_similarity_matrix[rows, cols]

    # -------------------------------------------------------------- evaluation

    def evaluate_posterior(self, text_data, img_dict):
        """Rank the candidates with CLIP using the quoted context alone.

        Returns:
            ``(preds, golds, answers, partial_answers)``: predicted and gold candidate
            names, 1/0 correctness flags and reciprocal ranks of the gold candidate.
        """
        CLIP_model = self.CLIP_model
        preds = []
        golds = []
        answers = []
        partial_answers = []
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            candidates = data['candidates']
            context = self._quoted_context(data)
            gold = data['gold']; gold_index = candidates.index(gold)

            # truncate=True keeps over-long text from raising, as in the other scorers.
            text = clip.tokenize([context], truncate = True).to(device)
            with torch.no_grad():
                images = self._candidate_images(img_dict, candidates)
                logits_per_image, logits_per_text = CLIP_model(images, text)
                probs = logits_per_text.softmax(dim=-1).cpu().numpy()

            pred, reciprocal_rank = self._rank_candidates(probs, gold_index)
            preds.append(candidates[pred])
            golds.append(gold)
            answers.append(1 if pred == gold_index else 0)
            partial_answers.append(reciprocal_rank)
        return preds, golds, answers, partial_answers

    def evaluate_bayesian_posterior(self, text_data, img_dict, llm=False):
        """Rank the candidates by mixing per-definition image distributions.

        The softmax over context/definition dot products weights the softmaxed
        definition-to-image logits of each definition prompt.

        Returns:
            ``(preds, golds, answers, partial_answers)`` as in ``evaluate_posterior``.
        """
        CLIP_model = self.CLIP_model
        preds = []
        golds = []
        answers = []
        partial_answers = []
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            candidates = data['candidates']
            context = self._quoted_context(data)
            sense_definitions = self._definition_prompts(context, data['sense_definitions'], llm)
            gold = data['gold']; gold_index = candidates.index(gold)

            with torch.no_grad():
                context_text = clip.tokenize([context], truncate = True).to(device)
                definition_text = clip.tokenize(sense_definitions, truncate = True).to(device)
                images = self._candidate_images(img_dict, candidates)

                text_features = CLIP_model.encode_text(context_text)
                def_features = CLIP_model.encode_text(definition_text)
                logits_per_definition = torch.matmul(text_features, def_features.T)
                prob_dist_definitions = logits_per_definition.softmax(dim=-1)

                logits_per_image, logits_per_text = CLIP_model(images, definition_text)
                probs_per_text = logits_per_text.softmax(dim=-1)
                bayesian_probs = torch.matmul(prob_dist_definitions, probs_per_text).cpu().numpy()

            pred, reciprocal_rank = self._rank_candidates(bayesian_probs, gold_index)
            preds.append(candidates[pred])
            golds.append(gold)
            answers.append(1 if pred == gold_index else 0)
            partial_answers.append(reciprocal_rank)
        return preds, golds, answers, partial_answers

    def evaluate_QI_posterior(self, text_data, img_dict, llm= False):
        """Rank the candidates with the definition mixture plus an interference term.

        With L2-normalised features, the definition weights are the squared
        context/definition similarities normalised to sum to one, and the per-definition
        image scores are the squared definition/image similarities. For every pair of
        distinct definitions the term ``2 * sqrt(w_i * s_i) * sqrt(w_j * s_j) * cos_ij``
        is added, where ``cos_ij`` is the cosine similarity of the two definition
        features.

        Returns:
            ``(preds, golds, answers, partial_answers)`` as in ``evaluate_posterior``.
        """
        CLIP_model = self.CLIP_model
        preds = []
        golds = []
        answers = []
        partial_answers = []
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            candidates = data['candidates']
            context = self._quoted_context(data)
            sense_definitions = self._definition_prompts(context, data['sense_definitions'], llm)
            gold = data['gold']; gold_index = candidates.index(gold)

            with torch.no_grad():
                context_text = clip.tokenize([context], truncate = True).to(device)
                definition_text = clip.tokenize(sense_definitions, truncate = True).to(device)
                images = self._candidate_images(img_dict, candidates)

                image_features = CLIP_model.encode_image(images)
                text_features = CLIP_model.encode_text(context_text)
                def_features = CLIP_model.encode_text(definition_text)

                text_features = torch.nn.functional.normalize(text_features,p=2,dim=1)
                def_features = torch.nn.functional.normalize(def_features,p=2,dim=1)
                image_features = torch.nn.functional.normalize(image_features,p=2,dim=1)

                logits_per_definition = torch.matmul(text_features, def_features.T)
                prob_dist_definitions = logits_per_definition**2
                sum_prob_dist_definitions = torch.sum(prob_dist_definitions,dim=1)
                prob_dist_definitions = prob_dist_definitions / sum_prob_dist_definitions

                logits_per_text = torch.matmul(def_features, image_features.T)
                probs_per_text = logits_per_text**2

                # All tensors already live on `device`; no explicit .cuda() is needed,
                # which also keeps the CPU fallback working.
                cos = self.calculate_lower_triangle_matrix_cos(def_features).unsqueeze(1)
                sp1p2 = torch.sqrt(prob_dist_definitions.T) * torch.sqrt(probs_per_text)
                sp1p2p3p4 = torch.sum(2*self.calculate_lower_triangle_matrix_2spp(sp1p2)*cos,dim=0)
                qbayesian_probs = (torch.matmul(prob_dist_definitions, probs_per_text)+sp1p2p3p4).cpu().numpy()

            pred, reciprocal_rank = self._rank_candidates(qbayesian_probs, gold_index)
            preds.append(candidates[pred])
            golds.append(gold)
            answers.append(1 if pred == gold_index else 0)
            partial_answers.append(reciprocal_rank)

        return preds, golds, answers, partial_answers

In [None]:
# Load the examples with the GPT-3.5 definitions attached, then score them with
# evaluate_posterior, which ranks the candidates from the quoted context alone.
text_data = data_loader_gpt3_5(data_file_path,
                        dictionary,
                        dictionary_type,
                        gold_file_path=gold_file_path)
VWSD_CLIP = Q_VWSD_QI(CLIP_model, preprocess)
p_preds, p_golds, p_answers, p_partial_answers = VWSD_CLIP.evaluate_posterior(text_data, img_dict)
# Accuracy is the mean of the 1/0 correctness flags, in percent.
print("Accuracy:", "%.2f" % (np.mean(p_answers) * 100))
# MRR is the mean reciprocal rank of the gold candidate, in percent.
print("MRR:", "%.2f" % (np.mean(p_partial_answers) * 100))

In [None]:
# Break Hits@1 down by the number of WordNet definitions found for each target word.
pb_sense_nums_w = []  # definition counts of the wrongly predicted examples
pb_sense_nums_r = []  # definition counts of the correctly predicted examples
for t, p, g in zip(text_data, p_preds, p_golds):
    sense_num = len(text_data[t]['wordnet_definitions'])
    if p != g:
        pb_sense_nums_w.append(sense_num)
    else:
        pb_sense_nums_r.append(sense_num)

wrong_counts = Counter(pb_sense_nums_w)
right_counts = Counter(pb_sense_nums_r)
print(sorted(wrong_counts.items()))
print(sorted(right_counts.items()))

# Look the buckets up by key. Taking the first and second sorted items silently reads
# the wrong bucket (or raises) whenever no example has exactly 0 or 1 definitions.
right_when_zero = right_counts[0]
wrong_when_zero = wrong_counts[0]

right_when_one = right_counts[1]
wrong_when_one = wrong_counts[1]

right_when_over_one = sum(c for s, c in right_counts.items() if s > 1)
wrong_when_over_one = sum(c for s, c in wrong_counts.items() if s > 1)

for bucket_label, right, wrong in (('==0', right_when_zero, wrong_when_zero),
                                   ('==1', right_when_one, wrong_when_one),
                                   ('>1', right_when_over_one, wrong_when_over_one)):
    total = right + wrong
    # An empty bucket is reported as nan instead of raising ZeroDivisionError.
    hits = right / total * 100 if total else float('nan')
    print('Hits@1 |D^t|%s: %.2f' % (bucket_label, hits))

In [None]:
# Load the examples with the GPT-4 definitions attached, then score them with
# evaluate_QI_posterior (llm=True: each sense definition is indexed with [0] to obtain
# its text).
text_data = data_loader_gpt4(data_file_path,
                        dictionary,
                        dictionary_type,
                        gold_file_path = gold_file_path)
VWSD_CLIP = Q_VWSD_QI(CLIP_model, preprocess)
p_preds, p_golds, p_answers, p_partial_answers =  VWSD_CLIP.evaluate_QI_posterior(text_data, img_dict,True)
# Accuracy is the mean of the 1/0 correctness flags, in percent.
print("Accuracy:", "%.2f"%(np.mean(p_answers)*100))
# MRR is the mean reciprocal rank of the gold candidate, in percent.
print("MRR:", "%.2f"%(np.mean(p_partial_answers)*100))

In [None]:
# Break Hits@1 down by the number of WordNet definitions found for each target word.
pb_sense_nums_w = []  # definition counts of the wrongly predicted examples
pb_sense_nums_r = []  # definition counts of the correctly predicted examples
for t, p, g in zip(text_data, p_preds, p_golds):
    sense_num = len(text_data[t]['wordnet_definitions'])
    if p != g:
        pb_sense_nums_w.append(sense_num)
    else:
        pb_sense_nums_r.append(sense_num)

wrong_counts = Counter(pb_sense_nums_w)
right_counts = Counter(pb_sense_nums_r)
print(sorted(wrong_counts.items()))
print(sorted(right_counts.items()))

# Look the buckets up by key. Taking the first and second sorted items silently reads
# the wrong bucket (or raises) whenever no example has exactly 0 or 1 definitions.
right_when_zero = right_counts[0]
wrong_when_zero = wrong_counts[0]

right_when_one = right_counts[1]
wrong_when_one = wrong_counts[1]

right_when_over_one = sum(c for s, c in right_counts.items() if s > 1)
wrong_when_over_one = sum(c for s, c in wrong_counts.items() if s > 1)

for bucket_label, right, wrong in (('==0', right_when_zero, wrong_when_zero),
                                   ('==1', right_when_one, wrong_when_one),
                                   ('>1', right_when_over_one, wrong_when_over_one)):
    total = right + wrong
    # An empty bucket is reported as nan instead of raising ZeroDivisionError.
    hits = right / total * 100 if total else float('nan')
    print('Hits@1 |D^t|%s: %.2f' % (bucket_label, hits))