In [None]:
import os
from glob import glob
import yaml
from tqdm import tqdm
import nltk
from nltk.corpus import wordnet
from sentence_transformers import SentenceTransformer, util

from chatcaptioner.utils import RandomSampledDataset, print_info, plot_img

In [None]:
def map_word_to_hypernym(word):
    """Replace ``word`` with the name of a more general WordNet concept.

    The first synset WordNet returns for ``word`` (any part of speech) is taken,
    and the first lemma name of that synset's first hypernym is returned.

    Falls back to returning ``word`` unchanged when:
    - WordNet has no synset for ``word``, or
    - that first synset has no hypernym.

    NOTE(review): not called anywhere in the visible notebook.
    """
    candidates = wordnet.synsets(word)
    if not candidates:
        return word

    # Only the first synset and its first hypernym are considered.
    parents = candidates[0].hypernyms()
    if not parents:
        return word

    return parents[0].lemmas()[0].name()

In [None]:
def is_included(noun1, noun2, threshold=0.9):
    """Return True if two nouns plausibly name the same kind of object.

    The nouns match when they are the identical string, or when some pair of
    their WordNet noun synsets satisfies either condition:
    - the pair's Wu-Palmer similarity is strictly greater than ``threshold``;
    - one synset appears in the transitive hyponym closure of the other
      (checked in both directions).

    The identical-string check matters for words WordNet does not know. Those
    words have no synsets, so the synset tests alone would report them as
    unmatched even against themselves.

    Args:
        noun1: first word, looked up with ``pos=wordnet.NOUN``.
        noun2: second word, looked up with ``pos=wordnet.NOUN``.
        threshold: Wu-Palmer similarity cut-off (default 0.9).

    Returns:
        bool: whether the two nouns are considered a match.
    """
    if noun1 == noun2:
        return True

    synsets1 = wordnet.synsets(noun1, pos=wordnet.NOUN)
    synsets2 = wordnet.synsets(noun2, pos=wordnet.NOUN)

    def hyponyms(synset):
        return synset.hyponyms()

    for synset1 in synsets1:
        for synset2 in synsets2:
            # Near-synonyms: close together in the WordNet hierarchy.
            similarity_score = synset1.wup_similarity(synset2)
            if similarity_score is not None and similarity_score > threshold:
                return True
            # Inclusion: one concept is a (transitive) specialisation of the other.
            if synset1 in synset2.closure(hyponyms) or synset2 in synset1.closure(hyponyms):
                return True
    return False

In [None]:
def extract_nouns(text):
    """Return the tokens of ``text`` that NLTK tags as nouns.

    The text is processed in three steps:
    - split into sentences;
    - each sentence is word-tokenized and POS-tagged;
    - every token whose tag starts with 'N' (the noun tags) is kept.

    Tokens are returned in their original order, duplicates included.
    """
    return [
        token
        for sentence in nltk.sent_tokenize(text)
        for token, tag in nltk.pos_tag(nltk.word_tokenize(sentence))
        if tag.startswith('N')
    ]

In [None]:
# Sentence-embedding model.
# NOTE(review): `sentence_model` is not referenced by any other cell visible
# in this notebook, so loading it here looks unnecessary — confirm before removing.
SENTENCE_MODEL_NAME = 'all-mpnet-base-v2'
sentence_model = SentenceTransformer(SENTENCE_MODEL_NAME)

In [None]:
# Build the project's RandomSampledDataset for the 'pascal' dataset under DATA_ROOT.
# NOTE(review): `dataset` is not read by the evaluation cells visible in this
# notebook — confirm whether it is still needed.
DATA_ROOT = 'datasets'
DATASET_NAME = 'pascal'
dataset = RandomSampledDataset(DATA_ROOT, DATASET_NAME)

In [None]:
# Experiment output directory to evaluate; point it at the run you want to inspect.
# Caption results are read from <SAVE_PATH>/pascal/caption_result/*.
SAVE_PATH = 'experiments/test'

In [None]:
def check_cover(gt_objs, cap_objs):
    """Count how many ground-truth objects a caption mentions.

    A ground-truth object counts as covered when ``is_included`` matches it
    against at least one caption noun. Before matching, the caption noun
    'people' is replaced with 'person'.

    Args:
        gt_objs: ground-truth object names (duplicates are counted separately).
        cap_objs: nouns extracted from the caption.

    Returns:
        tuple: ``(number of covered gt_objs, len(gt_objs))``.
    """
    hits = sum(
        1
        for gt_name in gt_objs
        if any(
            is_included(gt_name, 'person' if noun == 'people' else noun)
            for noun in cap_objs
        )
    )
    return hits, len(gt_objs)

In [None]:
# Evaluate object coverage. For every saved caption result, count how many
# ground-truth object names are matched by a noun in the 'BLIP2+OurPrompt'
# caption and in the 'ChatCaptioner' caption.
# Each results list receives one (covered, total) pair per file.
results_blip2 = []
results_our = []

result_dir = os.path.join(SAVE_PATH, 'pascal', 'caption_result')
# glob() order depends on the filesystem. Sorting makes the run deterministic,
# including which file is left in `info_file` after the loop.
save_infos = sorted(glob(os.path.join(result_dir, '*')))
assert save_infos, f"no caption results found under {result_dir!r}; check SAVE_PATH"

for info_file in tqdm(save_infos):
    with open(info_file, 'r') as f:
        info = yaml.safe_load(f)

    blip2 = extract_nouns(info['FlanT5 XXL']['BLIP2+OurPrompt']['caption'])
    our = extract_nouns(info['FlanT5 XXL']['ChatCaptioner']['caption'])

    # The first GT caption string is '_'-joined object names. Names containing
    # a space are skipped, so they do not count towards either total.
    gt_objs_tmp = info['setting']['GT']['caption'][0].split('_')
    gt_objs = [obj for obj in gt_objs_tmp if ' ' not in obj]

    results_blip2.append(check_cover(gt_objs, blip2))
    results_our.append(check_cover(gt_objs, our))
    


In [None]:
# ChatCaptioner: total covered GT objects (x) out of all scored GT objects (y).
x = sum(covered for covered, _ in results_our)
y = sum(total for _, total in results_our)
print(x, y)

In [None]:
# BLIP-2 baseline: total covered GT objects (x) out of all scored GT objects (y).
x = sum(covered for covered, _ in results_blip2)
y = sum(total for _, total in results_blip2)
print(x, y)

In [None]:
# Inspect one result by hand. This reloads the file still bound to
# `info_file` after the evaluation loop (the last one it processed) and
# recomputes the intermediate values the loop scored.
with open(info_file, 'r') as f:
    info = yaml.safe_load(f)
img_id = info['id'] if 'id' in info else info['setting']['id']

blip2 = extract_nouns(info['FlanT5 XXL']['BLIP2+OurPrompt']['caption'])
our = extract_nouns(info['FlanT5 XXL']['ChatCaptioner']['caption'])

# Raw '_'-split GT names, plus the filtered list that check_cover is actually
# given: names containing a space are dropped, as in the evaluation loop.
gt_objs_tmp = info['setting']['GT']['caption'][0].split('_')
gt_objs = [obj for obj in gt_objs_tmp if ' ' not in obj]

# Show which image is being inspected.
img_id
    

In [None]:
# Nouns extracted from the 'BLIP2+OurPrompt' caption of the inspected file.
blip2

In [None]:
# Nouns extracted from the 'ChatCaptioner' caption of the inspected file.
our

In [None]:
# Raw GT object names ('_'-split), before names containing a space are filtered out.
gt_objs_tmp