In [None]:
import os
import numpy as np
from tqdm import tqdm
from glob import glob
import yaml
import matplotlib
from matplotlib.colors import hsv_to_rgb
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects
import umap
from sentence_transformers import SentenceTransformer, util

from chatcaptioner.utils import extractQA_chatgpt

In [None]:
# Configuration: point SAVE_PATH at the experiment directory whose results you want to visualise.
SAVE_PATH = 'experiments/test/'
DATA_ROOT = 'datasets/'  # NOTE(review): not referenced in the cells shown here; confirm it is still needed
# Sentence-embedding model used to map every generated question into a shared vector space.
SENTENCE_MODEL_NAME = 'all-mpnet-base-v2'
sentence_model = SentenceTransformer(SENTENCE_MODEL_NAME)

In [None]:
# Collect every follow-up question the questioner asked across all saved dialogues.
# Set DATASETS_TO_LOAD = None to use every dataset directory found under SAVE_PATH.
DATASETS_TO_LOAD = ['cc_val']


def _questions_from_chat(chat, opening_question='Question: Describe this image in details.'):
    """Return the follow-up questions of one ChatCaptioner dialogue, excluding the opening prompt.

    Args:
        chat: the stored dialogue. A ``str`` is treated as a newline-separated transcript:
            lines containing ``opening_question`` are skipped, and every other line containing
            ``'Question:'`` contributes the stripped text after its last ``'Question:'``.
            Any other value is parsed with ``extractQA_chatgpt`` and its first question
            (the opening prompt) is dropped.
        opening_question: transcript marker of the fixed first question to ignore.

    Returns:
        The follow-up questions, in dialogue order.
    """
    if isinstance(chat, str):
        return [
            line.split('Question:')[-1].strip()
            for line in chat.split('\n')
            if 'Question:' in line and opening_question not in line
        ]
    questions, _answers = extractQA_chatgpt(chat)
    return questions[1:]


if DATASETS_TO_LOAD is not None:
    datasets_list = list(DATASETS_TO_LOAD)
else:
    datasets_list = sorted(
        name for name in os.listdir(SAVE_PATH)
        if os.path.isdir(os.path.join(SAVE_PATH, name))
    )

all_questions = []  # every follow-up question, in (sorted) file order
effect_q = []       # number of distinct follow-up questions in each dialogue
for dataset_name in datasets_list:
    print('============================')
    print('          {}          '.format(dataset_name))
    print('============================')

    result_dir = os.path.join(SAVE_PATH, dataset_name, 'caption_result')
    # Sort so the question order -- and therefore embedding rows and sampled ids -- is reproducible.
    save_infos = sorted(glob(os.path.join(result_dir, '*')))
    if not save_infos:
        print('WARNING: no caption results found in {}'.format(result_dir))
    for info_file in save_infos:
        with open(info_file, 'r') as f:
            info = yaml.safe_load(f)
        questions = _questions_from_chat(info['FlanT5 XXL']['ChatCaptioner']['chat'])
        effect_q.append(len(set(questions)))
        all_questions += questions

In [None]:
# Redundancy summary: distinct vs. total questions overall, and the mean number of
# distinct questions per dialogue.
print('Unique Q/ Total Q: {}/{}'.format(len(set(all_questions)), len(all_questions)))
if effect_q:
    print('Average Unique Q Per Dialogue: {}'.format(sum(effect_q) / len(effect_q)))
else:
    # No dialogues were loaded (e.g. wrong SAVE_PATH); avoid ZeroDivisionError.
    print('Average Unique Q Per Dialogue: n/a (no dialogues loaded)')

In [None]:
# Embed all questions in one batched call rather than one forward pass per question.
# encode() on a list batches internally and returns one embedding row per input sentence.
all_embs = np.asarray(sentence_model.encode(all_questions, show_progress_bar=True))
assert all_embs.shape[0] == len(all_questions), 'expected one embedding row per question'

In [None]:
fit = umap.UMAP()
fit_color = umap.UMAP(n_components=1)
%time u = fit.fit_transform(all_embs)
%time c = fit_color.fit_transform(all_embs)
norm_c = (c - c.min())/ (c.max()-c.min())

In [None]:
# Colormap object for tinting individual annotations to match their scatter points.
colormap_name = 'gnuplot2'
cmap = matplotlib.colormaps[colormap_name]

In [None]:
# Overview of the question embedding space: one point per question, positioned by the
# 2-D UMAP and coloured by the 1-D UMAP.
# Axis limits come from the data (1st-99th percentile plus a margin) rather than fixed
# numbers, because the layout depends on the data and the UMAP run; extreme outliers are
# still cropped so they do not squash the main cloud.
u_lo, u_hi = np.percentile(u, [1, 99], axis=0)
u_pad = np.maximum(0.05 * (u_hi - u_lo), 0.5)

fig, ax = plt.subplots()
ax.scatter(u[:, 0], u[:, 1], s=8, alpha=0.5, c=norm_c, cmap='gnuplot2')
ax.set_xlim(u_lo[0] - u_pad[0], u_hi[0] + u_pad[0])
ax.set_ylim(u_lo[1] - u_pad[1], u_hi[1] + u_pad[1])
ax.set_axis_off()
plt.show()

In [None]:
# Annotated view: pick a few questions at random, print them, and write each one on the
# scatter plot at its UMAP position, tinted with the colour of its point.
N_ANNOTATED = 5
sample_rng = np.random.default_rng(0)  # seeded so the same questions are annotated every run
n_pick = min(N_ANNOTATED, len(all_questions))  # do not fail when fewer questions exist
random_ids = sample_rng.choice(len(all_questions), size=n_pick, replace=False).tolist()
for q_id in random_ids:
    print('{}: {}'.format(q_id, all_questions[q_id]))

# Accept either an (n,) or an (n, 1) colour array; a scalar is needed for cmap().
colour_values = np.ravel(norm_c)
# Data-derived limits (1st-99th percentile plus margin) instead of run-specific magic numbers.
u_lo, u_hi = np.percentile(u, [1, 99], axis=0)
u_pad = np.maximum(0.05 * (u_hi - u_lo), 0.5)

fig, ax = plt.subplots()
ax.scatter(u[:, 0], u[:, 1], s=1, c=colour_values, cmap='gnuplot2')
ax.set_xlim(u_lo[0] - u_pad[0], u_hi[0] + u_pad[0])
ax.set_ylim(u_lo[1] - u_pad[1], u_hi[1] + u_pad[1])
for q_id in random_ids:
    # The semi-opaque white box keeps the label readable over dense regions of points.
    ax.text(
        u[q_id, 0], u[q_id, 1], all_questions[q_id],
        ha='center', wrap=True,
        color=cmap(float(colour_values[q_id])),
        bbox=dict(facecolor='white', alpha=0.8, edgecolor='white'),
    )
ax.set_axis_off()
plt.show()