In [1]:
import json

def read_jsonl(file_path):
    """Load a JSON Lines file into a list of parsed objects.

    Args:
        file_path: Path to a UTF-8 encoded ``.jsonl`` file holding one JSON value per line.

    Returns:
        list: The parsed JSON values, in file order. Blank (whitespace-only)
        lines, such as a trailing empty line at the end of the file, are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            # json.loads('') raises, so blank lines must be skipped explicitly.
            if not stripped:
                continue
            data.append(json.loads(stripped))
    return data


# Pre-computed experiment records (one JSON object per line). Later cells read
# the "option", "downsampled_images_paths" and "one_tune_questions" fields.
file_path = './experiment_res/interval_prompt_without_option_newprompt_13b.jsonl'
jsonl_data = read_jsonl(file_path)

In [None]:
# Peek at the first record to see which fields are available.
jsonl_data[0]

In [3]:
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

from openai import OpenAI

import base64
import requests

import json
import re
import os
# Turn off HuggingFace tokenizers' internal parallelism. Presumably this avoids the
# tokenizers fork warning when a torch DataLoader starts worker processes — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [4]:

def show_image(image_paths, sentence=None):
    """Display images in a grid with five columns.

    Args:
        image_paths: Sequence of image file paths. Each tile is titled with the
            file's base name. A missing file is drawn as a black tile titled
            'File Not Found'.
        sentence: Optional figure-level title.

    Returns immediately when ``image_paths`` is empty, because
    ``plt.subplots`` cannot build a grid with zero rows.
    """
    image_paths = list(image_paths)
    if not image_paths:
        return
    ncols = 5
    nrows = (len(image_paths) + ncols - 1) // ncols
    # squeeze=False always yields a 2-D array of axes, so flatten() is safe for any grid size.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 8), squeeze=False)
    axs = axs.flatten()

    for ax, img_path in zip(axs, image_paths):
        ax.axis('off')
        try:
            img = Image.open(img_path)
            ax.imshow(img)
            ax.set_title(os.path.basename(img_path))
        except FileNotFoundError:
            ax.imshow(np.zeros((10, 10, 3), dtype=int))
            ax.set_title('File Not Found')
    # Hide the unused cells of the last row instead of drawing empty framed axes.
    for ax in axs[len(image_paths):]:
        ax.axis('off')

    if sentence:
        fig.suptitle(sentence, fontsize=16)
    plt.tight_layout()
    plt.show()
    # Release the figure; this function is called repeatedly in loops.
    plt.close(fig)

class ImageEmbedder:
    """Holds an image-projection callable together with its preprocessing callable."""

    def __init__(self, model, preprocessor):
        """Store the callables used to embed images.

        Args:
            model: Callable that projects a prepared image (batch) to a vector.
            preprocessor: Callable that loads an image and prepares it for
                ``model``. It is stored as ``self.processor``.
        """
        self.processor = preprocessor
        self.model = model

def BLIP_BASELINE():
    """Load the ChatIR BLIP checkpoint and build a text encoder and an image embedder.

    Returns:
        tuple: ``(dialog_encoder, image_embedder)``.
            - ``dialog_encoder`` tokenizes text (truncated to 200 tokens) and
              returns L2-normalized projected text embeddings.
            - ``image_embedder`` is an ``ImageEmbedder``. Its ``model`` projects a
              preprocessed image batch to L2-normalized embeddings. Its
              ``processor`` loads an image path into a 224x224 normalized tensor.

    Side effects:
        Prepends ``./BLIP`` to ``sys.path`` (only once), and loads weights from
        ``./BLIP/chatir_weights.ckpt`` onto CUDA when available, otherwise onto CPU.
    """
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    import sys
    # Let BLIP's package-internal imports resolve. The guard keeps re-runs of this
    # cell from stacking duplicate sys.path entries.
    if './BLIP' not in sys.path:
        sys.path.insert(0, './BLIP')
    from BLIP.models.blip_itm import blip_itm
    # load model
    model = blip_itm(pretrained='./BLIP/chatir_weights.ckpt',  # Download from Google Drive, see README.md
                     med_config='BLIP/configs/med_config.json',
                     image_size=224,
                     vit='base'
                     )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    # Bicubic resize to 224x224, then per-channel mean/std normalization.
    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    def blip_project_img(image):
        """Project a preprocessed image batch to L2-normalized embeddings.

        Uses the first token of the visual encoder output (presumably [CLS]).
        Autograd is not disabled here; wrap calls in ``torch.no_grad()`` for inference.
        """
        embeds = model.visual_encoder(image)
        projection = model.vision_proj(embeds[:, 0, :])
        return F.normalize(projection, dim=-1)

    def blip_prep_image(path):
        """Load ``path`` as RGB and apply ``transform_test``."""
        raw = Image.open(path).convert('RGB')
        return transform_test(raw)

    image_embedder = ImageEmbedder(blip_project_img, blip_prep_image)

    # define dialog encoder (dialog --> img_feature)
    def dialog_encoder(dialog):
        """Encode text into L2-normalized embeddings from the first token of the text encoder."""
        text = model.tokenizer(dialog, padding='longest', truncation=True,
                               max_length=200,
                               return_tensors="pt"
                               ).to(device)

        text_output = model.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                         return_dict=True, mode='text')

        shift = model.text_proj(text_output.last_hidden_state[:, 0, :])
        return F.normalize(shift, dim=-1)

    return dialog_encoder, image_embedder

class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset over raw strings that yields ``{'text': str}`` items.

    With the default DataLoader collate, a batch of these items becomes
    ``{'text': [str, ...]}``. That is the form ``encode_text`` passes to the encoder.
    """

    def __init__(self, texts: list):
        """
        Args:
            texts (list of str): Text strings to serve, in order. No processor is
                taken; tokenization is left to the model that consumes the batches.
        """
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the text at position ``idx``, wrapped as ``{'text': ...}``."""
        return {'text': self.texts[idx]}

def encode_text(dataset, model, batch_size=20, num_workers=0):
    """Encode every text in ``dataset`` with ``model`` and concatenate the results.

    Args:
        dataset: Map-style dataset whose items are ``{'text': str}`` dicts
            (for example ``TextDataset``).
        model: Callable that takes a list of strings and returns a tensor with
            one row per string (here the BLIP ``dialog_encoder``).
        batch_size: Number of texts per forward pass.
        num_workers: Number of DataLoader worker processes. Defaults to 0: items are
            plain strings, and starting a worker for the single-query calls made in
            this notebook costs far more than it saves.

    Returns:
        torch.Tensor: CPU tensor of features concatenated along dim 0, in dataset order.

    Raises:
        ValueError: If ``dataset`` is empty (``torch.cat`` cannot join zero tensors).
    """
    if len(dataset) == 0:
        raise ValueError("encode_text received an empty dataset")
    # pin_memory is omitted: batches contain only Python strings, which cannot be pinned.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              num_workers=num_workers,
                                              shuffle=False)
    all_features = []
    with torch.no_grad():
        for batch in data_loader:
            features = model(batch['text'])
            all_features.append(features.cpu())
    return torch.cat(all_features)

def retrieve_topk_images(query: list,
                         topk=10,
                         faiss_model=None,
                         blip_model=None,
                         id2image=None,
                         processor=None, ):
    """Retrieve the ``topk`` nearest gallery images for each query string.

    Args:
        query: List of query strings.
        topk: Number of neighbours requested from the FAISS index per query.
        faiss_model: FAISS index searched with the query embeddings.
        blip_model: Text encoder handed to ``encode_text``.
        id2image: Mapping from FAISS id to image path. Ids without an entry map
            to ``'path/not/found'``.
        processor: Unused; kept so existing call sites keep working.

    Returns:
        tuple: ``(image_paths, indices)``. ``image_paths`` is a list with one list
        of paths per query. ``indices`` is a numpy array of FAISS ids with one row per query.
    """
    embeddings = encode_text(TextDataset(query), blip_model).numpy()
    # Row-wise L2 re-normalization, done in place, before the index search.
    embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

    _, neighbour_ids = faiss_model.search(embeddings, topk)
    neighbour_ids = np.array(neighbour_ids)
    image_paths = []
    for row in neighbour_ids:
        image_paths.append([id2image.get(idx, 'path/not/found') for idx in row])
    return image_paths, neighbour_ids



def find_index_in_list(element, my_list, not_found=50000):
    """Return the 0-based position of ``element`` in ``my_list``.

    Args:
        element: Value to look up, compared with ``==``.
        my_list: List searched from the front; the first match wins.
        not_found: Value returned when ``element`` is absent. The default 50000
            exceeds every retrieval depth used in this notebook, so a missing
            target ranks worse than any found one.

    Returns:
        int: Index of the first occurrence, or ``not_found``.
    """
    try:
        # A single scan; the previous `in` check followed by `.index` walked the list twice.
        return my_list.index(element)
    except ValueError:
        return not_found



In [5]:

def answer(question, base64_target_image, api_key=None, model="gpt-4o", max_tokens=100):
    """Ask an OpenAI chat model a question about one base64-encoded image.

    Args:
        question: Question text inserted into the prompt.
        base64_target_image: Base64 string of the image bytes, sent as a
            ``data:image/jpeg;base64,...`` URL.
        api_key: OpenAI API key. When omitted, the ``OPENAI_API_KEY`` environment
            variable is used, so no secret has to be written into the notebook.
        model: Chat model name.
        max_tokens: Upper bound on the length of the reply.

    Returns:
        str: The assistant's reply text.
    """
    client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
    response = client.chat.completions.create(
      model=model,
      messages=[
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": f"according to the image, answer the question:{question}，Your answer must be direct and simple",
            },
            {
              "type": "image_url",
              "image_url": {
                "url": f"data:image/jpeg;base64,{base64_target_image}",
              },
            }
          ],
        }
      ],
      max_tokens=max_tokens,
    )
    return response.choices[0].message.content


def summary(query, question, answer, api_key=None, model="gpt-4o"):
    """Have an OpenAI chat model fold a question/answer pair into an image description.

    Args:
        query: Original image description; the prompt tells the model not to alter it.
        question: Question that was asked about the target image.
        answer: Reply to ``question``. Inside this body it shadows the
            module-level ``answer`` function.
        api_key: OpenAI API key; falls back to ``OPENAI_API_KEY`` when omitted.
        model: Chat model name.

    Returns:
        str: The model's augmented description.
    """
    client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
    response = client.chat.completions.create(
    model=model,
    messages=[
        {"role": "system", "content": f"""
        Your task is to summarize the information from the image's question and answer and add this information to the original image description.\
        Remember: the summarized information must be concise, and the original description should not be altered.

        <question>
        {question}
        <answer>
        {answer}
        <image description>
        {query}


The information extracted from the question and answer should be added to the original description as an attribute or a simple attributive clause.
        """
        },
        {"role": "user", "content": ""}
      ]
    )
    return response.choices[0].message.content

def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

In [None]:

# Load the pre-built retrieval artefacts:
# - the FAISS index searched by retrieve_topk_images,
# - the FAISS-id -> image-path mapping (id2image),
# - the stored image embeddings, indexed by FAISS id in the 'clustering'
#   down-sampling strategy.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted checkpoint files.
faiss_model = faiss.read_index('./checkpoints/blip_faiss.index')
with open('./checkpoints/id2image.pickle', 'rb') as f:
    id2image = pickle.load(f)
    
with open('./checkpoints/blip_image_embedding.pickle', 'rb') as f:
    image_vector = pickle.load(f)
# BLIP text encoder (used for query retrieval) and image embedder.
dialog_encoder, image_embedder = BLIP_BASELINE()

In [7]:
# Per-query metadata. The evaluation cells look up each query's 'target_image'
# by matching its 'option' column.
recall_res = pd.read_csv('recall_res.csv')


In [None]:
# Preview the first rows of the recall table.
recall_res.head()

### Load the PLLaVA model

In [None]:
from peft import get_peft_model, LoraConfig, TaskType
from safetensors import safe_open
from pllava import PllavaProcessor, PllavaForConditionalGeneration, PllavaConfig
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
import os

def load_pllava(repo_id, num_frames,
                use_lora=False, weight_dir=None,
                lora_alpha=32, use_multi_gpus=False,
                pooling_shape=(16,12,12)):
    """Build a PLLaVA model and processor, optionally with LoRA and extra weights.

    Args:
        repo_id: Hugging Face repo id or local directory of the base model.
        num_frames: Number of frames (here: images) fed per sample.
        use_lora: Wrap the language model with a LoRA adapter (r=128 on
            q_proj/v_proj). Requires ``weight_dir``, which is also where the
            config is read from.
        weight_dir: Directory containing ``model.safetensors`` or sharded
            ``model-0*`` safetensors files, loaded on top of the base model.
        lora_alpha: LoRA alpha. The effective scaling is ``lora_alpha / 128``.
        use_multi_gpus: Split the model across GPUs with accelerate.
        pooling_shape: Pooling projector shape for the config. When
            ``num_frames == 0`` it is replaced by ``(0, 12, 12)``.

    Returns:
        tuple: ``(model, processor)``, with the model in eval mode.

    Raises:
        ValueError: If ``use_lora`` is set without ``weight_dir``.
    """
    if use_lora and weight_dir is None:
        raise ValueError("use_lora=True requires weight_dir: pass a folder to your lora weight")

    # All config overrides live in one dict. Previously pooling_shape was passed both
    # explicitly and through **kwargs when num_frames == 0, which raised
    # "got multiple values for keyword argument 'pooling_shape'".
    kwargs = {
        'num_frames': num_frames,
        'pooling_shape': pooling_shape,
    }
    if num_frames == 0:
        kwargs['pooling_shape'] = (0, 12, 12)  # placeholder; breaks if the pooling projector is ever used
    config = PllavaConfig.from_pretrained(
        repo_id if not use_lora else weight_dir,
        **kwargs,
    )

    model = PllavaForConditionalGeneration.from_pretrained(repo_id,
                                                           config=config,
                                                           torch_dtype=torch.bfloat16,
                                                           )

    try:
        processor = PllavaProcessor.from_pretrained(repo_id)
    except Exception:
        # repo_id may not ship processor files; fall back to the LLaVA-1.5 processor.
        processor = PllavaProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')

    # config lora
    if use_lora:
        print("Use lora")
        lora_rank = 128
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, target_modules=["q_proj", "v_proj"],
            r=lora_rank, lora_alpha=lora_alpha, lora_dropout=0.
        )
        print("Lora Scaling:", lora_alpha / lora_rank)
        model.language_model = get_peft_model(model.language_model, peft_config)
        print("Finish use lora")

    # load weights
    if weight_dir is not None:
        state_dict = _read_pllava_state_dict(weight_dir)
        # Some checkpoints nest the weights under a top-level 'model' key.
        weights = state_dict['model'] if 'model' in state_dict.keys() else state_dict
        with torch.device('meta'):  # load large scaler model weight
            msg = model.load_state_dict(weights, strict=False, assign=True)
        print(msg)

    # dispatch model weight
    if use_multi_gpus:
        max_memory = get_balanced_memory(
            model,
            max_memory=None,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16',
            low_zero=False,
        )

        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16'
        )

        dispatch_model(model, device_map=device_map)
        print(model.hf_device_map)

    model = model.eval()

    return model, processor


def _read_pllava_state_dict(weight_dir):
    """Read every tensor from the safetensors checkpoint(s) in ``weight_dir`` into a CPU dict.

    A lone ``model.safetensors`` is used when it exists and no ``model-0*`` shard
    is present. Otherwise, all ``model-0*`` shards are merged.
    """
    state_dict = {}
    save_fnames = os.listdir(weight_dir)
    use_full = ("model.safetensors" not in save_fnames
                or any(fn.startswith('model-0') for fn in save_fnames))
    if not use_full:
        print("Loading weight from", weight_dir, "model.safetensors")
        shard_names = ["model.safetensors"]
    else:
        print("Loading weight from", weight_dir)
        shard_names = [fn for fn in save_fnames if fn.startswith('model-0')]
    for fn in shard_names:
        with safe_open(f"{weight_dir}/{fn}", framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
    return state_dict

# Load PLLaVA-7B with its LoRA weights. num_frames and pooling_shape are sized for
# the 5 down-sampled candidate images fed per query.
llava_model, llava_processor = load_pllava(repo_id='MODELS/pllava-7b', #'llava-hf/llava-1.5-7b-hf',
            num_frames=5, # num_images = 5
            use_lora=True,
            weight_dir='MODELS/pllava-7b',
            lora_alpha=4,
            use_multi_gpus=False,
            pooling_shape=(5,12,12)
            )
# The trailing ';' suppresses the very long module repr that .to() would otherwise display.
llava_model.to('cuda');

In [10]:
import io

def load_image(image_file, timeout=30):
    """Open an image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: Local file path, or a URL starting with ``http://`` or ``https://``.
        timeout: Seconds to wait for a URL download before giving up.

    Returns:
        PIL.Image.Image: The image converted to RGB.

    Raises:
        requests.HTTPError: If the URL download returns an error status.
    """
    # Match the URL scheme rather than a bare "http" prefix, so a local path such
    # as "http_cache/x.jpg" is not fetched over the network.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    return Image.open(image_file).convert("RGB")


def extract_json(text):
    """Return the span of ``text`` from the first '{' to the last '}' (inclusive).

    This outermost brace-delimited chunk, newlines included, is expected to be a
    JSON object. If no '}' follows the first '{', the sentinel string
    'parse incorrectly' is returned instead.
    """
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end < start:
        return 'parse incorrectly'
    return text[start:end + 1]

def question_attribute_llava(images_path, query, model=None,
                     processor=None, device="cuda"):
    """Ask PLLaVA for one question that tells the candidate images apart.

    The images are given to PLLaVA as the frames of one video (``media_type='video'``),
    together with a few-shot prompt that asks for a JSON reply containing the question.

    Args:
        images_path: Image paths or URLs. The prompt text assumes there are 5.
        query: Current text description. The prompt asks for a question that does
            not overlap with it.
        model: PLLaVA model returned by ``load_pllava``.
        processor: The matching ``PllavaProcessor``.
        device: Device the processed inputs are moved to.

    Returns:
        str: The generated question.

    Raises:
        ValueError: If the reply has no parsable JSON object with the key
            ``"Question to differentiate the pictures"``. The raw reply is included
            in the message.
    """
    prompt = f"""You are Pllava, a large vision-language assistant. You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
                 Follow the instructions carefully and explain your answers in detail based on the provided video.
                  USER:<image>  USER: You need to find a common object that appears in all 5 pictures but has distinguishing features. Based on this object, ask a question to differentiate the pictures.

                Remember, you must ensure the question is specific, not abstract, and the answer should be directly obtainable by looking at the images.

                For example:
                Example 1: All 5 pictures have people, but the number of people differs. You can ask about the number of people.
                Example 2: All 5 pictures have cats, but the colors are different. You can ask about the color.
                Example 3: All 5 pictures have traffic lights, but their positions differ. You can ask about the position of the traffic lights.

                Ask a specific question based on the object that will help distinguish the pictures. The question is not overlapped with the description: {query}.
                Don't ask 2 questions each time. such as what is the attribute of a or b.

                Output as the following format
                {{
                "Question to differentiate the pictures":""
                }}
                ""
                ASSISTANT:"""
    images = [load_image(img_file) for img_file in images_path]
    inputs = processor(prompt, images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        # Greedy decoding. top_p/temperature have no effect when do_sample=False,
        # so they are not passed (top_p=0.9 only produced a transformers warning).
        output_token = model.generate(**inputs, media_type='video',
                                      do_sample=False,
                                      max_new_tokens=500,
                                      num_beams=1,
                                      min_length=1,
                                      repetition_penalty=1,
                                      length_penalty=1,
                                      )
    torch.cuda.empty_cache()  # release cached GPU memory between calls
    output_text = processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # The decoded text presumably echoes the prompt; keep only what follows the last 'ASSISTANT:'.
    output_text = output_text.split('ASSISTANT:')[-1].strip()
    key = "Question to differentiate the pictures"
    try:
        return json.loads(extract_json(output_text))[key]
    except (json.JSONDecodeError, KeyError, TypeError) as err:
        raise ValueError(f"Could not parse a question from PLLaVA output: {output_text!r}") from err

from collections import defaultdict
def downsample_retrieved_images(images_path:list, strategy='clustering', image_vector=None, retrieve_image_indices:list=None,
                                num_samples=5, interval_size=10):
    """Pick a small, diverse subset of the retrieved image paths.

    Args:
        images_path: Retrieved image paths, best match first.
        strategy: One of the following.
            - ``'clustering'``: k-means (k = ``num_samples``, 10 iterations) over
              the embeddings of the retrieved images, then one random path per
              non-empty cluster. Needs ``image_vector`` and ``retrieve_image_indices``.
            - ``'interval'``: one random path from each of ``num_samples``
              consecutive windows of ``interval_size`` paths. With the defaults
              the windows are ``[0:10], [10:20], ..., [40:50]``.
            - ``'topk_cos_similiarity'`` (spelling kept for existing callers):
              returns ``images_path`` unchanged.
        image_vector: Gallery embeddings, indexable by FAISS id.
        retrieve_image_indices: FAISS ids aligned with ``images_path``.
        num_samples: Number of clusters or windows.
        interval_size: Window length for ``'interval'``.

    Returns:
        list: The sampled paths. Draws come from the global ``random`` state, so
        results differ between calls unless that state is seeded.

    Raises:
        ValueError: For an unknown ``strategy``.
        IndexError: For ``'interval'``, when a window is empty (fewer than
            ``(num_samples - 1) * interval_size + 1`` paths).
    """
    if strategy == 'clustering':
        topk_embedding = image_vector[retrieve_image_indices]  # presumably (topk, embed_dim) for 1-D indices
        kmeans = faiss.Kmeans(d=topk_embedding.shape[1], k=num_samples, niter=10, verbose=False)  # niter = k-means iterations
        kmeans.train(topk_embedding)
        _, clustering_label = kmeans.index.search(topk_embedding, 1)  # nearest centroid per embedding

        label_to_paths = defaultdict(list)
        for path, label in zip(images_path, clustering_label.flatten()):
            label_to_paths[label].append(path)
        # Clusters are visited in order of their first appearance in images_path.
        return [random.choice(paths) for paths in label_to_paths.values()]
    if strategy == 'interval':
        return [random.choice(images_path[start:start + interval_size])
                for start in range(0, num_samples * interval_size, interval_size)]
    if strategy == 'topk_cos_similiarity':
        return images_path
    raise ValueError(f"Unknown downsampling strategy: {strategy!r}")

In [None]:
from tqdm import tqdm

# For each query:
# 1. PLLaVA asks a distinguishing question about 5 down-sampled candidates.
# 2. GPT-4o answers it on the target image and folds the answer into the query.
# 3. The target's rank is compared before and after.
# Counters accumulate over the run; per-query records are appended to RESULT_PATH.
# NOTE: the file is opened in append mode, so re-running this cell appends duplicate records.
N_EVAL = min(700, len(jsonl_data))  # guard against a shorter jsonl file
TOPK_HIT = 10                       # ranks are 0-based, so "top-10" means rank < 10
RESULT_PATH = './rank_res/all_processes_pllava7b.jsonl'
os.makedirs(os.path.dirname(RESULT_PATH), exist_ok=True)

top10=0
better=0

save_data = []
for i in tqdm(range(N_EVAL)):
    query=jsonl_data[i]["option"]
    # Retrieve 50 candidates for the query and down-sample them for PLLaVA.
    images_path, indices = retrieve_topk_images([query],
                                             topk=50,
                                             faiss_model=faiss_model,
                                             blip_model=dialog_encoder,
                                             id2image=id2image)
    images_path, indices = images_path[0], indices[0]
    sampled_image_paths = downsample_retrieved_images(images_path,
                                                       strategy='interval',
                                                       image_vector=image_vector,
                                                       retrieve_image_indices=indices
                                                      )
    # Generate a distinguishing question with PLLaVA.
    question = question_attribute_llava(images_path=sampled_image_paths,
                                 query=query,
                                 model=llava_model,
                                 processor=llava_processor)

    target_image_path='./playground/data/css_data/'+recall_res.loc[recall_res['option'] == query, 'target_image'].values[0]
    # Baseline: 0-based rank of the target for the original query (50000 if not retrieved).
    image_paths,indices=retrieve_topk_images([query],
                             topk=10000,
                             faiss_model=faiss_model,
                             blip_model=dialog_encoder,
                             id2image=id2image,
                             processor=None, )
    image_rank=find_index_in_list(target_image_path , image_paths[0])

    # Answer the question on the target image, merge the answer into the query, then re-rank.
    base64_target_image=encode_image(target_image_path)
    answer_of_question=answer(question,base64_target_image)
    summary_of_question_and_option=summary(query,question,answer_of_question)

    image_paths_new,indices=retrieve_topk_images([summary_of_question_and_option],
                             topk=10000,
                             faiss_model=faiss_model,
                             blip_model=dialog_encoder,
                             id2image=id2image,
                             processor=None, )
    image_rank_new=find_index_in_list(target_image_path , image_paths_new[0])

    # Ranks are 0-based list indices, so a top-10 hit is rank 0..9 (the old `<= 10` also counted the 11th).
    is_top10 = int(image_rank_new < TOPK_HIT)
    is_better = int(image_rank_new < image_rank)
    top10 += is_top10
    better += is_better

    record_dict = {'downsampled_images_path': sampled_image_paths,
                   'one_tune_questions': question,
                   'target_image_path': target_image_path,
                   'image_rank': image_rank,
                   'option': query,
                   'answer_of_question':answer_of_question,
                   'summary_of_question_and_option': summary_of_question_and_option,
                   'image_rank_new': image_rank_new,
                    'is_top10': is_top10,
                    'is_better': is_better
                   }
    save_data.append(record_dict)
    with open(RESULT_PATH, 'a', encoding='utf-8') as f:
        f.write(json.dumps(record_dict) + '\n')


In [None]:
# Totals from the evaluation loop: (queries counted as top-10 hits after the
# question/answer step, queries whose target rank improved).
top10,better

In [None]:
# Inspect one stored record. Reuse its saved down-sampled images and PLLaVA
# question, re-ask GPT-4o, and compare the target's rank before and after.
# Local flags are used instead of incrementing the global top10/better counters,
# so running this cell does not corrupt the evaluation loop's totals.
i=25
query=jsonl_data[i]["option"]

target_image_path='./playground/data/css_data/'+recall_res.loc[recall_res['option'] == query, 'target_image'].values[0]
sampled_image_paths=jsonl_data[i]['downsampled_images_paths']
image_paths,indices=retrieve_topk_images([query],
                         topk=10000,
                         faiss_model=faiss_model,
                         blip_model=dialog_encoder,
                         id2image=id2image,
                         processor=None, )
image_rank=find_index_in_list(target_image_path , image_paths[0])
show_image([target_image_path], sentence=None)
show_image(sampled_image_paths, sentence=None)

question=jsonl_data[i]['one_tune_questions']

base64_target_image=encode_image(target_image_path)
answer_of_question=answer(question,base64_target_image)
summary_of_question_and_option=summary(query,question,answer_of_question)

image_paths_new,indices=retrieve_topk_images([summary_of_question_and_option],
                         topk=10000,
                         faiss_model=faiss_model,
                         blip_model=dialog_encoder,
                         id2image=id2image,
                         processor=None, )
image_rank_new=find_index_in_list(target_image_path , image_paths_new[0])

# Ranks are 0-based list indices, so a top-10 hit is rank 0..9.
is_top10 = int(image_rank_new < 10)
is_better = int(image_rank_new < image_rank)
image_rank, image_rank_new, is_top10, is_better



In [None]:
# Show the original-query ranking from the top down to and including the target,
# five images per figure. image_rank is a 0-based index, so image_rank + 1 images
# are needed; a (image_rank + 4) // 5 row count drops the target whenever
# image_rank is a multiple of 5 (rank 0 would show nothing).
if image_rank >= len(image_paths[0]):
    print(f"Target image not in the retrieved list (rank {image_rank}); nothing to show.")
else:
    ranked_prefix = image_paths[0][:image_rank + 1]
    for start in range(0, len(ranked_prefix), 5):
        show_image(ranked_prefix[start:start + 5], sentence=None)

In [None]:
# Show the ranking for the augmented query from the top down to and including the
# target, five images per figure. image_rank_new is a 0-based index, so
# image_rank_new + 1 images are needed; a (image_rank_new + 4) // 5 row count
# drops the target whenever the rank is a multiple of 5.
if image_rank_new >= len(image_paths_new[0]):
    print(f"Target image not in the retrieved list (rank {image_rank_new}); nothing to show.")
else:
    ranked_prefix_new = image_paths_new[0][:image_rank_new + 1]
    for start in range(0, len(ranked_prefix_new), 5):
        show_image(ranked_prefix_new[start:start + 5], sentence=None)