In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from glob import glob
import yaml

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from chatcaptioner.utils import print_info, plot_img, extractQA_chatgpt, RandomSampledDataset

In [None]:
def split_sentence(sentence, max_len=38):
    """Greedily word-wrap ``sentence`` so each line holds at most ``max_len`` chars.

    Words are never broken: a single word longer than ``max_len`` is placed on
    its own (over-long) line.

    Args:
        sentence: Text to wrap; words are separated by single spaces.
        max_len: Maximum line width in characters.

    Returns:
        A ``(wrapped, n_lines)`` tuple: ``wrapped`` is the text with lines joined
        by newlines, ``n_lines`` is the number of lines. The plotting helpers use
        ``n_lines`` to reserve vertical space for a text box.
    """
    if len(sentence) <= max_len:
        return sentence, 1

    lines = []
    current = ''
    for word in sentence.split(' '):
        # Only insert a separator once the line has content, so no line starts
        # with a space and an over-long first word does not produce an empty
        # leading line.
        candidate = current + ' ' + word if current else word
        if current and len(candidate) > max_len:
            lines.append(current)
            current = word
        else:
            current = candidate
    lines.append(current)

    return '\n'.join(lines), len(lines)
    

def plot_dialogue(lefts, rights, xs=(0.1, 0.7), init_y=1, y_gap=0.07, line_h=0.045,
                  max_len=38):
    """Draw alternating text bubbles (left, right, left, ...) top-down on the current axes.

    Args:
        lefts: Texts for left-aligned bubbles (the questions in this notebook).
        rights: Texts for right-aligned bubbles (the answers in this notebook).
            Pairs are formed with ``zip``, so surplus entries in the longer
            sequence are not drawn.
        xs: ``(x_left, x_right)``; left bubbles are left-aligned at ``xs[0]``,
            right bubbles are right-aligned at ``xs[1]``.
        init_y: y coordinate of the top of the first bubble.
        y_gap: Vertical step after a single-line bubble.
        line_h: Extra vertical step for each wrapped line beyond the first.
        max_len: Wrapping width in characters, passed to ``split_sentence``.

    Returns:
        The y coordinate at which the next element should be drawn.
    """
    colors = {'left': '#ecf5e6', 'right': '#e7f0fd'}

    def plot_text(x, y, s, pos):
        # Box filled and outlined with the same colour so it reads as a bubble.
        plt.text(
            x=x, y=y, s=s,
            horizontalalignment=pos,
            multialignment='left',
            verticalalignment='top',
            bbox=dict(boxstyle='round', fc=colors[pos], ec=colors[pos]))

    cur_y = init_y
    for left, right in zip(lefts, rights):
        for text, x, pos in ((left, xs[0], 'left'), (right, xs[1], 'right')):
            wrapped, n_lines = split_sentence(text, max_len)
            plot_text(x=x, y=cur_y, s=wrapped, pos=pos)
            # Wrapped text is taller; reserve line_h for every extra line.
            cur_y -= y_gap + line_h * (n_lines - 1)

    return cur_y
        
def plot_summary(summary, x, y, max_len=43):
    """Draw ``summary`` word-wrapped inside a rounded, filled box on the current axes.

    Args:
        summary: Caption text to display.
        x: Horizontal centre of the box.
        y: y coordinate of the top of the box.
        max_len: Wrapping width in characters, passed to ``split_sentence``.
    """
    wrapped, _ = split_sentence(summary, max_len)
    summary_color = '#ffe5b5'
    plt.text(
        x=x, y=y, s=wrapped,
        horizontalalignment='center',
        multialignment='left',
        verticalalignment='top',
        bbox=dict(boxstyle='round', fc=summary_color, ec=summary_color))

        
def fancy_plot(img, questions, answers, summary, xs=(0, 1), init_y=1,
               img_height=256, zoom=0.5):
    """Compose a figure on the current axes: image, Q/A dialogue below it, then the summary.

    Args:
        img: Image to display. Presumably a PIL image (``.size`` and ``.resize``
            are used) — TODO confirm against ``RandomSampledDataset.fetch_img``.
        questions: Texts for the left-hand bubbles.
        answers: Texts for the right-hand bubbles.
        summary: Caption drawn beneath the dialogue.
        xs: ``(x_left, x_right)`` horizontal extent; the image and summary are
            centred between them.
        init_y: y coordinate of the bottom edge of the image.
        img_height: Pixel height the image is resized to (aspect ratio kept).
        zoom: Scale factor ``OffsetImage`` applies when drawing the image.
    """
    ax = plt.gca()
    # Size must come from the ``img`` argument itself (never a notebook global)
    # so the function works for any caller.
    w, h = img.size
    resized = img.resize((int(img_height / h * w), img_height))

    x_center = (xs[0] + xs[1]) / 2
    imagebox = OffsetImage(resized, zoom=zoom)
    # box_alignment=(0.5, 0) pins the image's bottom-centre at (x_center, init_y),
    # so the picture extends upward and the dialogue is drawn below it.
    ab = AnnotationBbox(imagebox, (x_center, init_y), frameon=False, box_alignment=(0.5, 0))
    ax.add_artist(ab)

    y = plot_dialogue(questions, answers, xs=xs, init_y=init_y - 0.03)
    plot_summary(summary, x_center, y - 0.01)

    plt.axis('off')
    
    
    

In [None]:
# Directory of saved ChatCaptioner results to visualize; the plotting loop reads
# every file under SAVE_PATH/<dataset_name>/caption_result/ as YAML.
SAVE_PATH = 'experiments/test/'
# Dataset root passed to RandomSampledDataset to fetch the images being plotted.
DATA_ROOT = 'datasets/'

In [None]:
# Datasets to visualize. To plot every dataset found under SAVE_PATH instead,
# use: datasets_list = sorted(os.listdir(SAVE_PATH))
datasets_list = ['artemis', 'coco_val']
for dataset_name in datasets_list:
    print('============================')
    print('          {}          '.format(dataset_name))
    print('============================')
    fig_path = 'figs/testV4_chatgpt/{}'.format(dataset_name)
    os.makedirs(fig_path, exist_ok=True)

    dataset = RandomSampledDataset(DATA_ROOT, dataset_name)

    # glob returns files in arbitrary filesystem order; sort for reproducible runs.
    save_infos = sorted(glob(os.path.join(SAVE_PATH, dataset_name, 'caption_result', '*')))
    for info_file in save_infos:
        with open(info_file, 'r') as f:
            info = yaml.safe_load(f)

        # The image id is stored either at top level or under info['setting'].
        img_id = info['id'] if 'id' in info else info['setting']['id']
        test_img, _ = dataset.fetch_img(img_id)

        result = info['FlanT5 XXL']['ChatCaptioner']
        questions, answers = extractQA_chatgpt(result['chat'])

        # A fresh figure per image, closed by handle so figures never accumulate
        # and no leftover figure from another cell is drawn into.
        fig = plt.figure()
        fancy_plot(test_img, questions, answers, result['caption'])
        plt.gca().set_aspect(1.3)
        fig.savefig(os.path.join(fig_path, '{}.pdf'.format(img_id)), bbox_inches='tight')
        plt.close(fig)