In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from glob import glob
import csv
import yaml
from chatcaptioner.chat import get_chat_log
from chatcaptioner.blip2 import Blip2
from chatcaptioner.utils import print_info, plot_img, extractQA_chatgpt, RandomSampledDataset

In [None]:
# specify SAVE_PATH to visualize the result you want
SAVE_PATH = 'experiments/testV4_chatgpt/'
DATA_ROOT = 'datasets/'

In [None]:
# Instantiate the BLIP-2 wrapper that later cells query through blip2.ask().
# NOTE(review): device_id=0 / bit8=True presumably select the first GPU and
# 8-bit weights — confirm against chatcaptioner.blip2.Blip2.
blip2 = Blip2('FlanT5 XXL', device_id=0, bit8=True)

In [None]:
# Scan the saved ChatCaptioner dialogues and collect every
# (image id, question, answer) triple whose answer looks like an admission of
# uncertainty. Images with at least one such answer are plotted.
#
# Only 'cc_val' is inspected. The previous os.listdir(SAVE_PATH) call was
# overwritten on the very next line, so it only cost a directory listing (and
# raised if SAVE_PATH was missing). To cover every dataset saved under
# SAVE_PATH, assign os.listdir(SAVE_PATH) here instead.
datasets_list = ['cc_val']
uncertainty_list = []

for dataset_name in datasets_list:
    print('============================')
    print('          {}          '.format(dataset_name))
    print('============================')
    dataset = RandomSampledDataset(DATA_ROOT, dataset_name)

    save_infos = glob(os.path.join(SAVE_PATH, dataset_name, 'caption_result', '*'))
    for info_file in save_infos:
        with open(info_file, 'r') as f:
            info = yaml.safe_load(f)
        # The image id is stored either at the top level or under 'setting'.
        img_id = info['id'] if 'id' in info else info['setting']['id']
        test_img, _ = dataset.fetch_img(img_id)

        chat = info['FlanT5 XXL']['ChatCaptioner']['chat']
        questions, answers = extractQA_chatgpt(chat)
        not_sure = False
        for q, a in zip(questions, answers):
            # Case-sensitive substring heuristic: matches answers containing
            # 'sure' (e.g. "not sure") or 'know' (e.g. "don't know").
            if 'sure' in a or 'know' in a:
                not_sure = True
                print('Question: {}'.format(q))
                print('Answer: {}'.format(a))
                uncertainty_list.append((img_id, q, a))
        # Show the image once if any of its answers was flagged.
        if not_sure:
            plot_img(test_img)
        

In [None]:
# Group the flagged questions by image id (answers are dropped) and persist
# the mapping to not_sure.yaml in the current working directory.
uncertainty_dict = {}
for flagged_id, flagged_q, _flagged_a in uncertainty_list:
    uncertainty_dict.setdefault(flagged_id, []).append(flagged_q)

with open('not_sure.yaml', 'w') as out_file:
    yaml.dump(uncertainty_dict, out_file)

In [None]:
# Show which image ids have at least one flagged question.
uncertainty_dict.keys()

In [None]:
# Spot-check the flagged questions of one image; raises KeyError if '13778'
# is not among the keys of uncertainty_dict.
uncertainty_dict['13778']

In [None]:
# Load the saved result of a single image so its dialogue can be re-run by hand.
# NOTE: relies on `dataset_name` and `dataset` left over from the scanning loop.
info_file = os.path.join(SAVE_PATH, dataset_name, 'caption_result', '13276.yaml')
with open(info_file, 'r') as f:
    info = yaml.safe_load(f)

# Fetch the image that belongs to this record as well. Without this, `test_img`
# is still whatever image the scanning loop happened to load last, so the
# questions re-asked further down would be answered about the wrong picture.
img_id = info['id'] if 'id' in info else info['setting']['id']
test_img, _ = dataset.fetch_img(img_id)

In [None]:
# Split the stored chat of the loaded record into its questions and answers.
# The answers are bound to `orig_answers`, a name separate from the `answers`
# list that a later cell rebuilds.
questions, orig_answers = extractQA_chatgpt(info['FlanT5 XXL']['ChatCaptioner']['chat'])

In [None]:
# Display the questions extracted from the saved chat.
questions

In [None]:
# Display the answers stored in the saved chat.
orig_answers

In [None]:
# Prompt variant that explicitly allows the model to admit uncertainty.
# It used to be assigned to ANSWER_INSTRUCTION and then overwritten on the next
# line; it now has its own name so the active prompt is unambiguous.
ANSWER_INSTRUCTION_WITH_UNCERTAINTY = 'Answer given questions. If you are not sure about the answer, say you don\'t know honestly. Don\'t imagine any contents that are not in the image.'
# Active prompt: same instruction without the "say you don't know" sentence.
ANSWER_INSTRUCTION = 'Answer given questions. Don\'t imagine any contents that are not in the image.'
SUB_ANSWER_INSTRUCTION = 'Answer: '  # template following blip2 huggingface demo

In [None]:
# Re-ask every question of the loaded dialogue. Each prompt is the instruction,
# the chat log built by get_chat_log(..., last_n=1) from the questions so far
# and the answers collected so far, and the answer prefix, joined by newlines.
answers = []
for i, current_question in enumerate(questions):
    print('Question: {}'.format(current_question))
    prompt_parts = [
        ANSWER_INSTRUCTION,
        get_chat_log(questions[:i + 1], answers, last_n=1),
        SUB_ANSWER_INSTRUCTION,
    ]
    blip2_prompt = '\n'.join(prompt_parts)
    raw_answer = blip2.ask(test_img, blip2_prompt)
    # Keep only the text before any 'Question:' in the reply, flattened to one line.
    answer = raw_answer.split('Question:')[0].replace('\n', ' ').strip()
    print('Answer: {}'.format(answer))
    answers.append(answer)

In [None]:
# Collect the human answers from the annotation CSV, grouped under a
# '<image_id>_<question>' tag.
results = {}

# newline='' is what the csv module asks for when opening files: without it,
# line breaks embedded inside quoted fields may be interpreted incorrectly.
with open('h_uncertain.csv', 'r', newline='') as csvfile:
    # DictReader yields one dict per row, keyed by the header's column names.
    csvreader = csv.DictReader(csvfile)

    for row in csvreader:
        # Access the values in the row by column name
        img_id = row['Input.image_id']
        question = row['Input.question']
        tag = img_id + '_' + question
        answer = row['Answer.summary']

        # Several rows can share a tag; keep all of their answers.
        results.setdefault(tag, []).append(answer)

In [None]:
# Number of distinct '<image_id>_<question>' tags read from the CSV.
len(results)

In [None]:
# Split the tags by how many of their human answers contain 'none'
# (case-insensitive): two or more -> uncertainQ (tag only),
# otherwise -> certainQ as [tag, answers].
uncertainQ = []
certainQ = []
for tag, human_answers in results.items():
    none_votes = sum(1 for reply in human_answers if 'none' in reply.lower())
    if none_votes < 2:
        certainQ.append([tag, human_answers])
    else:
        uncertainQ.append(tag)

In [None]:
# Regroup the certain questions per image: {img_id: {question: human answers}}.
certain_img = {}
for tag, h_answers in certainQ:
    # Tags are built as img_id + '_' + question, so split on the FIRST
    # underscore only; a plain split('_') raises ValueError as soon as the
    # question text itself contains an underscore.
    img_id, question = tag.split('_', 1)
    certain_img.setdefault(img_id, {})[question] = h_answers
        

In [None]:
# Number of tags kept as certain (fewer than two 'none' answers).
len(certainQ)

In [None]:
# Prompt pieces for BLIP-2, set again here so this cell defines them itself.
ANSWER_INSTRUCTION = 'Answer given questions. Don\'t imagine any contents that are not in the image.'
SUB_ANSWER_INSTRUCTION = 'Answer: '  # template following blip2 huggingface demo

# For every image with human-answered questions, replay its whole dialogue with
# BLIP-2 and print the human answers next to the questions that have them.
# Relies on `dataset_name` and `dataset` left over from the scanning loop.
for img_id, c_questions in certain_img.items():
    result_path = os.path.join(SAVE_PATH, dataset_name, 'caption_result', '{}.yaml'.format(img_id))
    with open(result_path, 'r') as f:
        info = yaml.safe_load(f)
    test_img, _ = dataset.fetch_img(img_id)

    chat = info['FlanT5 XXL']['ChatCaptioner']['chat']
    questions, _ = extractQA_chatgpt(chat)

    answers = []
    for i, asked in enumerate(questions):
        # Questions with human answers are framed by the '?' / '!' marker lines.
        has_human_answers = asked in c_questions
        if has_human_answers:
            print('?????????????????')
        print('Question: {}'.format(asked))
        prompt_parts = [
            ANSWER_INSTRUCTION,
            get_chat_log(questions[:i + 1], answers, last_n=1),
            SUB_ANSWER_INSTRUCTION,
        ]
        blip2_prompt = '\n'.join(prompt_parts)
        raw_answer = blip2.ask(test_img, blip2_prompt)
        # Keep only the text before any 'Question:' in the reply, on one line.
        answer = raw_answer.split('Question:')[0].replace('\n', ' ').strip()
        answers.append(answer)
        print('Answer: {}'.format(answer))
        if has_human_answers:
            for h_answer in c_questions[asked]:
                print('Human: {}'.format(h_answer))
            print('!!!!!!!!!!!!!!!!!!!!')
    plot_img(test_img)