In [1]:
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import pickle
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import random
import json
import re
import requests, io


os.environ["TOKENIZERS_PARALLELISM"] = "false"


def seed_everything(seed=42):
    """Seed every random number generator the notebook uses.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Exported for any interpreter launched later; the hash seed of the
    # already-running process is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick algorithms by timing them, which can vary
    # between runs and defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
seed_everything(42)


In [2]:

# Select CUDA device 0 as the current device for this process.
# NOTE(review): presumably this fails on a machine without a CUDA device — confirm before running on CPU.
torch.cuda.set_device(0)

In [3]:
class ImageEmbedder:
    """Bundle an image-embedding callable with its matching preprocessing callable."""

    def __init__(self, model, preprocessor):
        """Store both callables.

        Args:
            model: Projects a prepared image to a vector.
            preprocessor: Loads and prepares an image for ``model``; exposed
                as the ``processor`` attribute.
        """
        self.model, self.processor = model, preprocessor

def BLIP_BASELINE():
    """Load the BLIP ITM checkpoint and build the text and image encoders.

    Returns:
        tuple: ``(dialog_encoder, image_embedder)`` where ``dialog_encoder`` is a
        function mapping text to an L2-normalized projected feature and
        ``image_embedder`` is an ``ImageEmbedder`` wrapping the image projection
        function and the image-loading/preprocessing function.
    """
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    import sys
    # Put the local BLIP checkout first on the import path before importing from it.
    sys.path.insert(0, './BLIP')
    from BLIP.models.blip_itm import blip_itm
    # Load the model from the local checkpoint.
    model = blip_itm(pretrained='./BLIP/chatir_weights.ckpt',  # Download from Google Drive, see README.md
                     med_config='BLIP/configs/med_config.json',
                     image_size=224,
                     vit='base'
                     )

    # Use the GPU when available and switch the model to eval mode.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    # Image preprocessing: resize to 224x224 (bicubic), convert to tensor, normalize per channel.
    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    def blip_project_img(image):
        """Encode an image batch and return the L2-normalized projection of the first token."""
        # NOTE(review): `image` is not moved to `device` here — presumably the caller does that; confirm.
        embeds = model.visual_encoder(image)
        # Index 0 along dim 1 is projected; presumably the [CLS] token — TODO confirm.
        projection = model.vision_proj(embeds[:, 0, :])
        return F.normalize(projection, dim=-1)

    def blip_prep_image(path):
        """Open the image at `path` as RGB and apply `transform_test`."""
        raw = Image.open(path).convert('RGB')
        return transform_test(raw)

    image_embedder = ImageEmbedder(blip_project_img, lambda path: blip_prep_image(path))

    # Dialog encoder: text --> normalized projected feature.
    def dialog_encoder(dialog):
        """Tokenize `dialog` (truncated to 200 tokens) and return its L2-normalized text projection."""
        text = model.tokenizer(dialog, padding='longest', truncation=True,
                               max_length=200,
                               return_tensors="pt"
                               ).to(device)

        text_output = model.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                         return_dict=True, mode='text')

        # Project the hidden state at position 0 of each sequence.
        shift = model.text_proj(text_output.last_hidden_state[:, 0, :])
        return F.normalize(shift, dim=-1)

    return dialog_encoder, image_embedder

class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves each string as a ``{'text': ...}`` record."""

    def __init__(self, texts: list):
        """Keep a reference to the texts.

        Args:
            texts (list of str): Text strings, addressed by integer position.
        """
        self.texts = texts

    def __len__(self):
        """Number of texts held."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the record for the text stored at ``idx``."""
        return dict(text=self.texts[idx])



def encode_text(dataset, model, batch_size=100, num_workers=4):
    """Encode every text in ``dataset`` with ``model`` and stack the results.

    Args:
        dataset: Map-style dataset whose items are ``{'text': str}`` records.
        model: Callable taking a list of strings (one batch) and returning a
            tensor with one row per string.
        batch_size: Number of texts passed to ``model`` per call.
        num_workers: DataLoader worker processes; ``0`` loads in-process.

    Returns:
        torch.Tensor: CPU tensor with the per-batch outputs concatenated along
        dim 0, in dataset order.

    Raises:
        RuntimeError: From ``torch.cat`` when ``dataset`` is empty.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
        shuffle=False,
    )
    if num_workers > 0:
        # prefetch_factor is only meaningful (and only accepted by some torch
        # versions) when worker processes are used.
        loader_kwargs['prefetch_factor'] = 2
    data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    all_features = []
    with torch.no_grad():
        for batch in data_loader:
            # Move each batch to CPU right away so results do not pile up on the GPU.
            all_features.append(model(batch['text']).cpu())
    return torch.cat(all_features)

def retrieve_topk_images(query: list,
                         topk=10,
                         faiss_model=None,
                         text_encoder=None,
                         id2image=None,
                         ):
    """Retrieve the ``topk`` nearest images for every query text.

    Args:
        query: Sequence of query strings.
        topk: Number of neighbours to return per query.
        faiss_model: Index object exposing ``search(vectors, k)``.
        text_encoder: Callable mapping a batch of strings to a tensor of
            embeddings (passed to ``encode_text``).
        id2image: Mapping from index id to image path.

    Returns:
        tuple: ``(image_paths, indices)`` — ``image_paths`` is a list (one entry
        per query) of lists of paths, with ``'path/not/found'`` for ids missing
        from ``id2image``; ``indices`` is the id array returned by the index.
    """
    text_dataset = TextDataset(query)
    query_vec = encode_text(text_dataset, text_encoder)
    # Hand the index a contiguous float32 array, the layout the FAISS Python
    # wrapper expects; .float() also covers half/bfloat16 encoders, which
    # NumPy cannot represent.
    query_vec = np.ascontiguousarray(query_vec.float().numpy(), dtype=np.float32)

    # L2-normalize each row; leave all-zero rows untouched instead of producing NaN.
    norms = np.linalg.norm(query_vec, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    query_vec /= norms

    distance, indices = faiss_model.search(query_vec, topk)
    indices = np.array(indices)
    image_paths = [[id2image.get(int(idx), 'path/not/found') for idx in row] for row in indices]
    return image_paths, indices


In [None]:
# Load the retrieval artifacts from ./checkpoints and build the BLIP encoders.
# data = pd.read_csv('recall_res.csv')
faiss_model = faiss.read_index('./checkpoints/blip_faiss.index')
# NOTE(security): pickle.load can execute arbitrary code — only load pickle files from a trusted source.
with open('./checkpoints/id2image.pickle', 'rb') as f:
    id2image = pickle.load(f)
    
with open('./checkpoints/blip_image_embedding.pickle', 'rb') as f:
    image_vector = pickle.load(f)
# dialog_encoder is the text encoder passed to retrieve_topk_images; image_embedder is not used later in the visible cells.
dialog_encoder, image_embedder = BLIP_BASELINE()

### 2. Load the PLLaVA question-generation model

In [9]:
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from safetensors import safe_open
from pllava import PllavaProcessor, PllavaForConditionalGeneration, PllavaConfig
from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map, load_checkpoint_in_model
from accelerate.utils import get_balanced_memory


def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha=32, use_multi_gpus=False, pooling_shape=(16,12,12)):
    """Build a PLLaVA model and processor, optionally adding LoRA and extra weights.

    Args:
        repo_id: Model id or local path of the base model and processor.
        num_frames: Number of frames (images) per sample, stored in the config.
        use_lora: Wrap the language model with LoRA adapters (only takes effect
            when ``weight_dir`` is also given).
        weight_dir: Directory holding ``model.safetensors`` or ``model-0*``
            shard files to load on top of the base weights; ``None`` skips it.
        lora_alpha: LoRA alpha; the adapter rank is fixed at 128.
        use_multi_gpus: Dispatch the model across devices with ``accelerate``.
        pooling_shape: Pooling shape stored in the config; forced to
            ``(0, 12, 12)`` when ``num_frames == 0``.

    Returns:
        tuple: ``(model, processor)`` with the model in eval mode.
    """
    if num_frames == 0:
        # Override the argument directly. Previously this value was also passed
        # through **kwargs next to an explicit pooling_shape=..., which raised
        # "got multiple values for keyword argument 'pooling_shape'".
        # Original author's note: this produces a bug if the pooling projector is ever used.
        pooling_shape = (0, 12, 12)

    lora_enabled = use_lora and weight_dir is not None
    config = PllavaConfig.from_pretrained(
        # Fall back to repo_id when LoRA is requested without a weight_dir
        # instead of passing None as the config location.
        weight_dir if lora_enabled else repo_id,
        pooling_shape=pooling_shape,
        num_frames=num_frames,
    )

    with torch.no_grad():
        model = PllavaForConditionalGeneration.from_pretrained(repo_id, config=config, torch_dtype=torch.bfloat16)

    try:
        processor = PllavaProcessor.from_pretrained(repo_id)
    except Exception as e:
        # Best-effort fallback kept from the original; report why it happened.
        print("Could not load processor from", repo_id, "- falling back to llava-hf/llava-1.5-7b-hf:", e)
        processor = PllavaProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')

    # config lora
    if lora_enabled:
        lora_rank = 128
        print("Use lora")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False,  target_modules=["q_proj", "v_proj"],
            r=lora_rank, lora_alpha=lora_alpha, lora_dropout=0.
        )
        print("Lora Scaling:", lora_alpha/lora_rank)
        model.language_model = get_peft_model(model.language_model, peft_config)
        print("Finish use lora")

    # load weights
    if weight_dir is not None:
        state_dict = {}
        save_fnames = os.listdir(weight_dir)
        shard_fnames = sorted(fn for fn in save_fnames if fn.startswith('model-0'))
        # Shards take precedence; the single file is used only when no shard exists.
        if "model.safetensors" in save_fnames and not shard_fnames:
            print("Loading weight from", weight_dir, "model.safetensors")
            weight_fnames = ["model.safetensors"]
        else:
            print("Loading weight from", weight_dir)
            weight_fnames = shard_fnames

        for fn in weight_fnames:
            with safe_open(os.path.join(weight_dir, fn), framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)

        if 'model' in state_dict.keys():
            msg = model.load_state_dict(state_dict['model'], strict=False)
        else:
            msg = model.load_state_dict(state_dict, strict=False)
        # strict=False: the printed message lists any missing/unexpected keys.
        print(msg)

    # dispatch model weight
    if use_multi_gpus:
        max_memory = get_balanced_memory(
            model,
            max_memory=None,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16',
            low_zero=False,
        )

        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16'
        )

        dispatch_model(model, device_map=device_map)
        print(model.hf_device_map)

    model = model.eval()

    return model, processor


In [None]:
# Load PLLaVA with LoRA adapters; base model and extra weights come from the same directory.
# NOTE(review): the paths below are absolute and machine-specific — consider a configurable model directory.
llava_model, llava_processor = load_pllava(repo_id='/root/autodl-tmp/.autodl/HYF/questionIR/CSS/MODELS/pllava-7b', # alternative: 'llava-hf/llava-1.5-7b-hf'
            num_frames=5, # num_images = 5
            use_lora=True, 
            weight_dir='/root/autodl-tmp/.autodl/HYF/questionIR/CSS/MODELS/pllava-7b', 
            lora_alpha=4, 
            use_multi_gpus=False, 
            pooling_shape=(5,12,12)
            )
# Move the model to the GPU; as the last expression of the cell its return value is also displayed.
llava_model.to('cuda')

In [9]:
def load_image(image_file):
    """Load an image from a URL or a local path and return it in RGB mode.

    Args:
        image_file: ``http://`` / ``https://`` URL or filesystem path.

    Returns:
        PIL.Image.Image: The decoded image converted to RGB.

    Raises:
        requests.HTTPError: When the download returns an error status.
    """
    # Match the scheme explicitly so a local file whose name merely starts with
    # "http" is not treated as a URL.
    if image_file.startswith(("http://", "https://")):
        # A timeout keeps a stalled server from hanging the whole run.
        response = requests.get(image_file, timeout=30)
        # Fail on 4xx/5xx here instead of trying to decode an error page as an image.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    # convert() returns a fully loaded copy, so the file handle can be closed.
    with Image.open(image_file) as img:
        return img.convert("RGB")

def generate_question(option: str, 
                     images_path: list,
                     model=None,
                     processor=None):
    """Ask the model for one question that differentiates the given images.

    Params:
       option: str: the user's description; interpolated into the prompt so
           the generated question does not overlap with it
       images_path: list: paths (or URLs) of the images to compare; each is
           loaded with load_image
       model: model exposing generate(); called with media_type='video'
       processor: processor used to build the inputs and decode the output

    Return:
        str: the decoded text after the last 'ASSISTANT:' marker, stripped.
        This is the raw model answer (expected to contain the JSON object
        requested by the prompt), not yet the extracted question.
    """

    prompt = f"""You are Pllava, a large vision-language assistant. 
    You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
    Follow the instructions carefully and explain your answers in detail based on the provided video.
    USER:<image>  USER: You need to find a common object that appears in all 5 pictures but has distinguishing features. Based on this object, ask a question to differentiate the pictures.

                Remember, you must ensure the question is specific, not abstract, and the answer should be directly obtainable by looking at the images.

                For example:
                Example 1: All 5 pictures have people, but the number of people differs. You can ask about the number of people.
                Example 2: All 5 pictures have cats, but the colors are different. You can ask about the color.
                Example 3: All 5 pictures have traffic lights, but their positions differ. You can ask about the position of the traffic lights.

                Ask a specific question based on the object that will help distinguish the pictures. The question is not overlapped with the description: {option}.
                Don't ask 2 questions each time. such as what is the attribute of a or b

                Output as the following format
                {{
                "What is the common object that appears in all five pictures":"",
                "What is he distinguishing feature that can help differentiate the picture":"",
                "Question to differentiate the pictures":""
                }}
                ""
                ASSISTANT:
                """
    
    
    # Despite the name, this is a list of PIL images, not a tensor.
    image_tensor = [load_image(img_file) for img_file in images_path]
    inputs = processor(prompt, image_tensor, return_tensors="pt")
    # Inputs are moved to "cuda" unconditionally, so the model must live on the GPU.
    inputs = {k:v.to("cuda") for k,v in inputs.items()}

    with torch.no_grad():
        # do_sample=False with num_beams=1 requests greedy decoding.
        # NOTE(review): top_p / temperature are presumably ignored when sampling is off — confirm.
        output_token = model.generate(**inputs, media_type='video',
                                    do_sample=False,
                                    max_new_tokens=500, 
                                      num_beams=1, 
                                      min_length=1, 
                                    top_p=0.9, 
                                      repetition_penalty=1, 
                                      length_penalty=1, 
                                      temperature=1,
                                    ) # the answer is short, so a long generation budget is not needed
    torch.cuda.empty_cache() # release cached GPU memory after this call
    output_text = processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # Keep only the text after the last 'ASSISTANT:' marker (the decoded output also contains the prompt).
    output_text = output_text.split('ASSISTANT:')[-1].strip()
    return output_text


In [17]:
def extract_json(text):
    """Return the substring from the first '{' to the last '}' of ``text``.

    The match is greedy and spans newlines. When ``text`` contains no such
    span, the sentinel string ``'parse incorrectly'`` is returned instead.
    """
    found = re.search(r'{.*}', text, flags=re.DOTALL)
    return found.group() if found else 'parse incorrectly'

def extract_output(text):
    """Parse the JSON object found in ``text`` and return its question field.

    Raises whatever ``json.loads`` or the key lookup raises when the text holds
    no valid JSON object or the "Question to differentiate the pictures" key
    is absent.
    """
    parsed = json.loads(extract_json(text))
    return parsed["Question to differentiate the pictures"]

### 3. Downsample retrieved images

In [13]:
from collections import defaultdict
def downsample_retrieved_images(images_path:list, strategy='clustering', image_vector=None, retrieve_image_indices:list=None,
                                num_samples=5, interval_size=10):
    """Pick a small subset of the retrieved images for one query.

    Args:
        images_path: Retrieved image paths for a single query, best match first.
        strategy: ``'clustering'`` (k-means over the retrieved embeddings, one
            random path per non-empty cluster), ``'interval'`` (one random path
            from each consecutive rank window) or ``'topk_cos_similiarity'``
            (return ``images_path`` unchanged).
        image_vector: Embedding matrix indexed by image id; only used by
            ``'clustering'``.
        retrieve_image_indices: Ids of the retrieved images, aligned with
            ``images_path``; only used by ``'clustering'``.
        num_samples: Number of clusters / rank windows.
        interval_size: Width of each rank window for ``'interval'``.

    Returns:
        list: Selected paths. May hold fewer than ``num_samples`` entries when
        a cluster is empty or fewer paths than windows are available.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """
    if strategy == 'clustering':
        topk_embedding = image_vector[retrieve_image_indices] # [topk, embed_dim]

        kmeans = faiss.Kmeans(d=topk_embedding.shape[1], k=num_samples, niter=10, verbose=False) # k=num_clusters, niter=epoch
        # Train the KMeans object
        kmeans.train(topk_embedding)
        # Nearest centroid of every embedding = its cluster label.
        distance, clustering_label = kmeans.index.search(topk_embedding, 1)

        clustering_label = clustering_label.flatten()
        label_to_paths = defaultdict(list)
        for path, label in zip(images_path, clustering_label):
            label_to_paths[label].append(path)

        sampled_paths_list = [random.choice(paths) for paths in label_to_paths.values()]
    elif strategy == 'interval':
        sampled_paths_list = []
        for start in range(0, num_samples * interval_size, interval_size):
            window = images_path[start:start + interval_size]
            # Fewer paths than windows: skip the empty window instead of
            # letting random.choice raise IndexError.
            if len(window) > 0:
                sampled_paths_list.append(random.choice(window))
    elif strategy == 'topk_cos_similiarity':
        sampled_paths_list = images_path
    else:
        # Previously fell through to an UnboundLocalError at the return.
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected 'clustering', 'interval' or 'topk_cos_similiarity'"
        )
    return sampled_paths_list

In [15]:
# data = pd.read_csv('recall_res.csv')
# option_list = data['option'].tolist()
# # option_list

# non-top10 dataset

# Build the evaluation set from every interval file except the top-10 ('0-9') one.
interval_dir = './top_interval_res'
interval_frames = [
    pd.read_csv(os.path.join(interval_dir, fname))
    # os.listdir order is arbitrary; sort so the row order (and therefore the
    # order of generated questions) is the same on every run.
    for fname in sorted(os.listdir(interval_dir))
    if '0-9' not in fname
]
# Concatenate once instead of growing a DataFrame inside the loop.
recall_res_raw = pd.concat(interval_frames, axis=0)
recall_res_raw = recall_res_raw[['option', 'recall_images', 'target_image', 'target_image_position']]
# reset_index() keeps each row's per-file position in an 'index' column.
recall_res = recall_res_raw.reset_index()
option_list = recall_res['option'].tolist()
# option_list

In [None]:

# topk = 100 and strategy='clustering' --> Kmeans labeled samples
topk = 50 # and strategy='interval' --> interval sampled
# topk = 5 # topk sample

# One JSON record per option is appended to this file as soon as it is ready,
# so a crash part-way through keeps everything generated so far.
output_path = './experiment_res/interval_prompt_without_option_newprompt_13b.jsonl'
os.makedirs(os.path.dirname(output_path), exist_ok=True)

one_tune_questions = []
downsampled_images_paths = []
top100_images_path = []  # holds the top-`topk` paths per option (name kept for compatibility)
for opt in tqdm(option_list):
    # the type of query hyperparameter in retrieve_topk_images is list
    images_path, indices = retrieve_topk_images([opt],
                                             topk=topk,
                                             faiss_model=faiss_model,
                                             text_encoder=dialog_encoder,
                                             id2image=id2image,
                             )
    # A single query was sent, so keep only its row.
    images_path, indices = images_path[0], indices[0]
    top100_images_path.append(images_path)
    # Downsample retrieved image
    downsampled_images_path = downsample_retrieved_images(images_path,
                                                       strategy='interval',
                                                       image_vector=image_vector,
                                                       retrieve_image_indices=indices
                                                      )
    downsampled_images_paths.append(downsampled_images_path)
    # the type of option hyperparameter in generate_question is str
    output_ori = generate_question(opt,
                                         downsampled_images_path,
                                         model=llava_model,
                                         processor=llava_processor
                                       )
    # extract output
    try:
        output_question = extract_output(output_ori)
    except (ValueError, KeyError, TypeError):
        # Malformed or incomplete JSON from the model: record an empty question
        # and keep the raw text in "output_ori". Only parsing errors are caught
        # so Ctrl-C and real bugs are no longer swallowed by a bare except.
        output_question = ''

    one_tune_questions.append(output_question)

    result = {
            "option": opt,
            "top100_images_path": images_path,
            "downsampled_images_paths": downsampled_images_path,
            "one_tune_questions": output_question,
            "output_ori": output_ori
        }
    with open(output_path, 'a') as f:  # 'a' for append mode
        json.dump(result, f)
        f.write("\n")  # Add newline to separate entries

### 5. Batch Inference

In [19]:
# Simple version
class MultiModalDataset(torch.utils.data.Dataset):
    """Dataset pairing each user description with its group of retrieved images."""

    def __init__(self, option, images_path, processor):
        """
        Args:
            option: Sequence of user descriptions, one per sample.
            images_path: Sequence of image-path lists, aligned with ``option``.
            processor: Callable turning ``(prompt, images)`` into model inputs.
        """
        self.option = option
        self.images_path = images_path
        self.processor = processor


    def __len__(self):
        # One sample per image group. (The original read `self.data`, which is
        # never assigned and raised AttributeError.)
        return len(self.images_path)

    def __getitem__(self, idx):
        """Build the prompt for sample ``idx`` and return the processor outputs."""
        image_paths = self.images_path[idx] # the image group for this sample
        option = self.option[idx]

        # Load images (a list of PIL images despite the variable name)
        image_tensor = [Image.open(image_path).convert("RGB") for image_path in image_paths]
        prompt = f"""You are Pllava, a large vision-language assistant. 
    You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
    Follow the instructions carefully and explain your answers in detail based on the provided video.
     USER:<image>  USER: our task is to identify the target image based on descriptions provided by users. Due to the ambiguity in user descriptions, you have already searched and found 5 images based on these descriptions. 
     Your task is to analyze the content of these 5 images, ask a question to clarify the user's needs, and your question is not overlapping with the descriptions. Thereby helping the user quickly find the target image. You only need to output one questions within 30 words.
                Complete the following tasks step by step:
                1. Combine the textual description, observe these 5 images, and analyze and summarize their common points and differences.
                2. To find the target image, ask a question based on these differences that can clarify the user’s needs and help them quickly find the target image.
                [Your Question]
                Based on the images and sentence description, {option}. One question you asked is:
     ASSISTANT:"""


        encode = self.processor(prompt, image_tensor, return_tensors="pt")

        # The leading batch dim of the text tensors is squeezed away so the
        # collate function can pad variable-length sequences.
        return {'input_ids': encode['input_ids'].squeeze(0), # shape: [seq_len]
                'attention_mask': encode['attention_mask'].squeeze(0), # shape: [seq_len]
                'pixel_values': encode['pixel_values'], # one entry per image; spatial size depends on the processor
                }

def collate_fn(batch):
    """Merge per-sample dicts into one padded batch dict.

    Args:
        batch: List of dicts with ``'input_ids'``, ``'attention_mask'`` and
            ``'pixel_values'`` tensors.

    Returns:
        dict with
            input_ids: ``[batch_size, max_seq_len]``, right-padded with the
                pad token id of the global ``llava_processor`` tokenizer.
            attention_mask: ``[batch_size, max_seq_len]``, right-padded with 0.
            pixel_values: all samples' image tensors concatenated along dim 0.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    token_ids = [sample['input_ids'] for sample in batch]
    masks = [sample['attention_mask'] for sample in batch]
    images = [sample['pixel_values'] for sample in batch]

    return {
        'input_ids': pad(token_ids, batch_first=True,
                         padding_value=llava_processor.tokenizer.pad_token_id),
        'attention_mask': pad(masks, batch_first=True, padding_value=0),
        # Images from every sample are flattened into a single leading dim.
        'pixel_values': torch.cat(images, dim=0),
    }

# load temp json file
def generate_question_batch(option: list, 
                     images_path: list[list],
                     model=None,
                     processor=None):
    """Generate one clarifying question per (description, image group) pair.

    Args:
        option: User descriptions, one per sample.
        images_path: One list of image paths per sample, aligned with ``option``.
        model: Model exposing ``generate()``; expected to live on the GPU.
        processor: Processor used to build the inputs and decode the outputs.

    Returns:
        list of str: For every sample, the decoded text after the last
        ``'ASSISTANT:'`` marker, stripped.
    """
    # Build the dataset from this function's own arguments (the original
    # referenced an unrelated global `data` with the wrong arity).
    dataset = MultiModalDataset(option, images_path, processor)
    dataloader = torch.utils.data.DataLoader(dataset, 
                                             batch_size=3,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=4,
                                             # Sequences differ in length, so the default
                                             # collation cannot stack them; pad instead.
                                             collate_fn=collate_fn,
                                            )
    output_res = []
    for batch in tqdm(dataloader):
        batch = {k:v.to("cuda") for k,v in batch.items()}
        with torch.no_grad():
            # Pass the current batch (the original unpacked an undefined `inputs`).
            output_token = model.generate(**batch, media_type='video',
                                    do_sample=False,
                                    max_new_tokens=500, 
                                      num_beams=1, 
                                      min_length=1, 
                                    top_p=0.9, 
                                      repetition_penalty=1, 
                                      length_penalty=1, 
                                      temperature=1,
                                    )
        output_text = processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for text in output_text:
            # The decoded text includes the prompt; keep only the answer part.
            processed_text = text.split('ASSISTANT:')[-1].strip()
            output_res.append(processed_text)
    return output_res



In [20]:
# Retrieve the top-100 images for every option in recall_res.csv in one call.
data = pd.read_csv('recall_res.csv')
option = data['option']
# Batch retrieval: `option` is a pandas Series here, not a list.
retrieved_images_path, retrieved_indices = retrieve_topk_images(option,
                                             topk=100,
                                             faiss_model=faiss_model,
                                             text_encoder=dialog_encoder,
                                             id2image=id2image,
                             )

# Downsample the retrieved images, one query at a time.
# NOTE(review): the 'interval' strategy as defined in this notebook only samples from
# ranks 0-49, so ranks 50-99 of the top-100 are never picked — confirm this is intended.
downsampled_images_paths = []
for img_path, idx in zip(retrieved_images_path, retrieved_indices):
    downsampled_images_path = downsample_retrieved_images(img_path, 
                                                       strategy='interval', 
                                                       image_vector=image_vector, 
                                                       retrieve_image_indices=idx
                                                      )
    downsampled_images_paths.append(downsampled_images_path)
    


In [None]:
# Generate one question per option from the downsampled image groups.
one_tune_question = generate_question_batch(option,
                       downsampled_images_paths,
                        model=llava_model,
                         processor=llava_processor
                       )

In [None]:

# Example descriptions. NOTE(review): `query` is not used in this cell — only `option` below is.
query = [
    'a women dressed in white playing tennis on a clay court',
    'a white and gray cat perched on a door that is partially open',
    'a bird sitting on a branch of a tree',
    'a close up of a wine glass with a bartender in the background',
    'a group of giraffes is standing in a savannah'
]

option = 'rows of wooden chairs and benches in a classroom'

# retrieve_topk_images expects a list of queries, so the single option is wrapped in a list
images_path, indices = retrieve_topk_images([option],
                         topk=100,
                         faiss_model=faiss_model,
                         text_encoder=dialog_encoder,
                         id2image=id2image,
                         )
# Keep the row of the single query.
images_path, indices = images_path[0], indices[0]
selected_images_path = downsample_retrieved_images(images_path, strategy='clustering', image_vector=image_vector, retrieve_image_indices=indices)
# generate_question expects the option as a plain string
output_question = generate_question(option, 
                     selected_images_path,
                     model=llava_model,
                     processor=llava_processor)


In [None]:
output_question

In [None]:
# NOTE(review): `show_image` is not defined in the part of the notebook visible here — confirm it is defined or imported before this cell runs.
show_image(images_path, option)

In [None]:
# Scratch cell: retrieve neighbours for a hand-picked query list.
# query = data['option'].tolist()
query = [
    'a women dressed in white playing tennis on a clay court',
    # 'a white and gray cat perched on a door that is partially open',
    # 'a bird sitting on a branch of a tree',
    # 'a close up of a wine glass with a bartender in the background',
    # 'a group of giraffes is standing in a savannah'
]
# NOTE(review): this `topk` is not passed to the call below, which hard-codes topk=10.
topk = 100

image_paths, indices = retrieve_topk_images(query,
                         topk=10,
                         faiss_model=faiss_model,
                         text_encoder=dialog_encoder,
                         id2image=id2image,
                         )


In [None]:
indices

In [None]:
# Look up the embeddings of the retrieved ids.
# assumes image_vector is a 2-D array so the result is [num_queries, topk, embed_dim] — TODO confirm
topk_embedding = image_vector[indices]

In [None]:
topk_embedding.shape

In [None]:
# Cluster the retrieved embeddings of the first query into n_clusters groups.
n_clusters = 5
kmeans = faiss.Kmeans(d=image_vector.shape[1], k=n_clusters, niter=10, verbose=True)
# Train the KMeans object on the first query's embeddings
kmeans.train(topk_embedding[0])

In [None]:
# Get the cluster centroids
centroids = kmeans.centroids
centroids

In [None]:
# Assign each embedding to its nearest centroid; clustering_label has one column (k=1 search)
distance, clustering_label = kmeans.index.search(topk_embedding[0], 1)

In [None]:
image_paths, clustering_label

In [None]:
# distance

In [None]:
# Hard-coded example of ten retrieved paths and their cluster labels.
# These assignments overwrite `image_paths` and `clustering_label` from the cells above.
import numpy as np
import random


image_paths = [
    './playground/data/css_data/unlabeled2017/000000446161.jpg',
    './playground/data/css_data/unlabeled2017/000000396918.jpg',
    './playground/data/css_data/unlabeled2017/000000282138.jpg',
    './playground/data/css_data/unlabeled2017/000000296263.jpg',
    './playground/data/css_data/unlabeled2017/000000208485.jpg',
    './playground/data/css_data/unlabeled2017/000000087032.jpg',
    './playground/data/css_data/unlabeled2017/000000382609.jpg',
    './playground/data/css_data/unlabeled2017/000000201751.jpg',
    './playground/data/css_data/unlabeled2017/000000307338.jpg',
    './playground/data/css_data/unlabeled2017/000000239956.jpg'
]
# One label per path, in the single-column layout used by the k=1 search above.
clustering_label = np.array([[4],
                              [0],
                              [1],
                              [0],
                              [2],
                              [4],
                              [3],
                              [0],
                              [4],
                              [3]])


# Flatten to a 1-D array of labels.
clustering_label = clustering_label.flatten()
clustering_label



In [None]:
def show_clusted_images(clustering_label, image_paths):
    """Display the images grouped by cluster label, one figure per label.

    Args:
        clustering_label: Cluster label per image, either 1-D (``[n]``) or in
            the single-column layout returned by a k=1 search (``[n, 1]``).
        image_paths: Image paths aligned with ``clustering_label``.
    """
    # Accept both layouts; the original indexed `label[0]`, which fails on
    # an already-flattened label array.
    labels = np.asarray(clustering_label).reshape(-1)

    images_by_label = {}
    for label, path in zip(labels, image_paths):
        images_by_label.setdefault(int(label), []).append(path)

    cols = 6
    for label, paths in images_by_label.items():
        rows = int(np.ceil(len(paths) / cols))
        fig, axs = plt.subplots(rows, cols, figsize=(20, 5 * rows))
        fig.suptitle(f'Label {label}', fontsize=16)

        axs = np.atleast_1d(axs).flatten()

        for ax, path in zip(axs, paths):
            # Close the file as soon as the pixels have been drawn.
            with Image.open(path) as img:
                ax.imshow(img)

        # Hide the axes of used and unused cells alike.
        for ax in axs:
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure so many labels do not accumulate open figures.
        plt.close(fig)



In [None]:
# For every query: cluster its retrieved embeddings and show the images per cluster.
n_clusters=5
for tk in range(topk_embedding.shape[0]):
    print('Query: ', query[tk])
    kmeans = faiss.Kmeans(d=image_vector.shape[1], k=n_clusters, niter=10, verbose=True)
    kmeans.train(topk_embedding[tk])
    # Nearest centroid per embedding gives the cluster label (single column).
    distance, clustering_label = kmeans.index.search(topk_embedding[tk], 1)
    # NOTE(review): expects image_paths to hold one list of paths per query; the hard-coded
    # sample cell above rebinds image_paths to a flat list — confirm the cell order.
    clustering_label_image_path = image_paths[tk]
    show_clusted_images(clustering_label, clustering_label_image_path)

In [12]:
import pandas as pd 
# Reload the recall results.
data = pd.read_csv('recall_res.csv')

In [None]:
# Repeat every row five times (consecutively) and renumber the index from 0.
data_repeat_5times = data.loc[data.index.repeat(5)].reset_index(drop=True)
data_repeat_5times

In [3]:
import re 
import json
def extract_json(text):
    """Return the greedy '{' ... '}' span of ``text`` (newlines included).

    Falls back to the sentinel string ``'parse incorrectly'`` when ``text``
    has no such span.
    """
    brace_span = re.compile(r'{.*}', re.DOTALL).search(text)
    if brace_span is None:
        return 'parse incorrectly'
    return brace_span.group()

def extract_llm_ouptut(text):
    """Return the generated question from a raw model answer, or '' on failure.

    The JSON object embedded in ``text`` is parsed and the value stored under
    ``'Question to differentiate the pictures'`` is returned. An empty string
    is returned when the JSON is invalid, is not an object, or lacks that key.
    """
    ext_json = extract_json(text)
    try:
        question = json.loads(ext_json)['Question to differentiate the pictures']
    except (ValueError, KeyError, TypeError):
        # ValueError covers json.JSONDecodeError; KeyError a missing field;
        # TypeError a JSON value that is not an object. Anything else (for
        # example KeyboardInterrupt) is no longer swallowed.
        question = ''
    return question

In [None]:
tpm= "{\n                  \"What is the common object that appears in all five pictures\": \"Book\",\n                  \"What is the distinguishing feature that can help differentiate the picture\": \"The color of the book cover\",\n                  \"Question to differentiate the pictures\": \"What color is the book cover in each of the five pictures?\"\n                  }"
extract_llm_ouptut(tpm)

In [None]:
def load_jsonl(filename):
    """Read a JSON Lines file and return its records as a list.

    Args:
        filename: Path of a file holding one JSON document per line.

    Returns:
        list: Parsed records in file order. Blank lines are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(filename, "r", encoding="utf-8") as f:
        # Stream line by line instead of reading the whole file first.
        for line in f:
            line = line.strip()
            if not line:
                # A trailing or stray empty line is not a record.
                continue
            records.append(json.loads(line))
    return records

load_jsonl('./experiment_res/one_option_five_images_five_response_13b.jsonl')