In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to imported source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
import torch

from chatcaptioner.chat import set_openai_key, caption_images, get_instructions
from chatcaptioner.blip2 import Blip2
from chatcaptioner.utils import RandomSampledDataset, plot_img, print_info

## Set OpenAI API Key

In [None]:
# Read the OpenAI API key from the environment rather than hard-coding it in
# the notebook. Fail early with an actionable message when the variable is
# unset or empty, instead of handing an unusable key to set_openai_key.
openai_key = os.environ.get("OPENAI_API_KEY")
if not openai_key:
    # KeyError is kept so the exception type matches a plain os.environ[...] lookup.
    raise KeyError(
        "OPENAI_API_KEY is not set (or is empty); export it before running this notebook"
    )
set_openai_key(openai_key)

## Load BLIP-2

In [None]:
# BLIP-2 models handed to caption_images, keyed by model name.
blip2s = dict()

# FlanT5 XXL on GPU 0. Per the original author's note it is too large to load
# without 8-bit mode and needs about 20GB of GPU memory.
blip2s['FlanT5 XXL'] = Blip2('FlanT5 XXL', device_id=0, bit8=True)

# Optional extra models (disabled). Uncomment to load them as well:
# blip2s['OPT2.7B COCO'] = Blip2('OPT2.7B COCO', device_id=1, bit8=False)  # GPU1, about 10GB GPU memory
# blip2s['OPT6.7B COCO'] = Blip2('OPT6.7B COCO', device_id=2, bit8=True)   # GPU2, too large without 8 bit

# Models that may act as the question model; looked up by question_model_tag
# in the caption cell. Empty here, so the tag string itself is used.
blip2s_q = {}

In [None]:
# Alternative setup (disabled): load BLIP-2 models into blip2s_q so that one of
# them can serve as the question model. The caption cell picks an entry of
# blip2s_q whenever question_model_tag matches one of its keys. The last line
# makes blip2s reuse the same FlanT5 XXL instance instead of constructing a
# second one.
# blip2s_q = {
#     'FlanT5 XXL': Blip2('FlanT5 XXL', device_id=0, bit8=True), # load BLIP-2 FlanT5 XXL to GPU0. Too large, need 8 bit. About 20GB GPU Memory
#     # 'OPT2.7B': Blip2('OPT2.7B', device_id=1, bit8=False), # load BLIP-2 OPT2.7B COCO to GPU1. About 10GB GPU Memory
#     # 'OPT6.7B': Blip2('OPT6.7B', device_id=2, bit8=True), # load BLIP-2 OPT6.7B COCO to GPU2. Too large, need 8 bit.
# }
# blip2s = {'FlanT5 XXL': blip2s_q['FlanT5 XXL']}

## Test Settings (Change Them Accordingly)

In [None]:
# --- Experiment settings: edit these values before running the cells below ---

dataset_name = 'cc_val'      # dataset to test; current options: 'artemis', 'cc_val', 'coco_val'
n_test_img = 3               # how many images to sample for testing
n_rounds = 10                # chat rounds between the question model and BLIP-2
n_blip2_context = 1          # chat rounds visible to BLIP-2; a value < 0 exposes the full history

# Whether to print the chat while testing.
# NOTE(review): the caption cell in this notebook passes print_mode='chat'
# directly and does not read this flag.
print_chat = True

# Tag of the model that asks the questions. The caption cell passes this string
# through as-is unless it matches a key of blip2s_q.
question_model_tag = 'gpt-3.5-turbo'

## Load Dataset & Prepare Folder to Save Results

In [None]:
# Load the dataset selected by dataset_name from the local data directory.
DATA_ROOT = 'datasets/'
dataset = RandomSampledDataset(DATA_ROOT, dataset_name)

# Prepare the folder that receives the results; the path encodes the question
# model tag and the dataset name.
SAVE_PATH = 'experiments/0307_{}/{}'.format(question_model_tag, dataset_name)

# Always ensure the 'caption_result' subfolder exists. Guarding on SAVE_PATH
# alone would skip creating it when SAVE_PATH is already present without the
# subfolder; exist_ok=True makes the call safe to repeat on re-runs.
os.makedirs(os.path.join(SAVE_PATH, 'caption_result'), exist_ok=True)

# Store the instructions returned by get_instructions() next to the results.
with open(os.path.join(SAVE_PATH, 'instruction.yaml'), 'w') as f:
    yaml.dump(get_instructions(), f)

## Start Caption

In [None]:
# Ask the dataset for n_test_img image ids to caption.
# NOTE(review): no seed is set in this notebook, so the selection presumably
# differs between runs — confirm whether RandomSampledDataset seeds itself.
# The next cell assigns sample_img_ids again with a fixed id, replacing this value.
sample_img_ids = dataset.random_img_ids(n_test_img)

In [None]:
# Pin the run to one specific image id. This replaces whatever ids were
# sampled in the previous cell; remove the line to caption the sampled ids.
sample_img_ids = ['11627']

# Pick the question model: a loaded model from blip2s_q when the tag names
# one, otherwise the tag string itself.
question_model = blip2s_q.get(question_model_tag, question_model_tag)

# Caption the selected images with every model in blip2s, saving under SAVE_PATH.
caption_images(
    blip2s,
    dataset,
    sample_img_ids,
    save_path=SAVE_PATH,
    n_rounds=n_rounds,
    n_blip2_context=n_blip2_context,
    model=question_model,
    print_mode='chat',
)