In [27]:
import pandas as pd
# import faiss
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

from openai import OpenAI

import base64
import requests

import json
import re

In [28]:

def show_image(image_paths, sentence=None, ncols=5):
    """Display the images at ``image_paths`` in a grid with ``ncols`` images per row.

    Each subplot is titled with the file name. A path that does not exist is drawn as a
    small black placeholder titled 'File Not Found'. Grid cells left over in the last
    row are hidden.

    Args:
        image_paths: list of image file paths.
        sentence: optional figure title (suptitle).
        ncols: images per row (default 5, as before).
    """
    if not image_paths:
        # plt.subplots(nrows=0, ...) raises, and there is nothing to draw anyway.
        return

    n_rows = (len(image_paths) + ncols - 1) // ncols
    # Grow the figure with the number of rows so large grids stay legible; one or two
    # rows keep the previous 20x8 size. squeeze=False keeps `axs` 2-D even for 1x1 grids.
    fig, axs = plt.subplots(nrows=n_rows, ncols=ncols,
                            figsize=(20, max(8, 4 * n_rows)), squeeze=False)
    axs = axs.flatten()

    for ax, img_path in zip(axs, image_paths):
        try:
            # Open in a context manager so the file handle is released; imshow converts
            # the PIL image to an array immediately, so closing afterwards is safe.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.axis('off')
            ax.set_title(img_path.split('/')[-1])
        except FileNotFoundError:
            ax.imshow(np.zeros((10, 10, 3), dtype=int))
            ax.axis('off')
            ax.set_title('File Not Found')

    # Hide the unused cells at the end of the last row instead of drawing empty frames.
    for ax in axs[len(image_paths):]:
        ax.axis('off')

    if sentence:
        fig.suptitle(sentence, fontsize=16)
    plt.tight_layout()
    plt.show()



class ImageEmbedder:
    """Bundle an image-projection callable with the callable that prepares its input.

    Attributes:
        model: callable that maps prepared images to embedding vectors.
        processor: callable that loads an image (e.g. from a path) and prepares it
            for ``model``.
    """

    def __init__(self, model, preprocessor):
        # Stored under `processor` (not `preprocessor`) — callers read `.processor`.
        self.model, self.processor = model, preprocessor

def BLIP_BASELINE():
    """Build the BLIP text encoder and image embedder used for retrieval.

    Loads ``blip_itm`` (``vit='base'``, ``image_size=224``) from
    ``./BLIP/chatir_weights.ckpt``, moves it to CUDA when available (otherwise CPU) and
    switches it to eval mode.

    Returns:
        tuple (dialog_encoder, image_embedder):
            dialog_encoder(dialog) -- tokenizes ``dialog`` (padding to the longest item,
                truncating at 200 tokens), runs the BLIP text encoder in ``mode='text'``
                and returns the L2-normalized ``text_proj`` of the first token position
                of each sequence.
            image_embedder -- an ``ImageEmbedder`` whose ``processor`` opens an image
                path as an RGB tensor (bicubic resize to 224x224, ToTensor, Normalize)
                and whose ``model`` maps image tensors to L2-normalized ``vision_proj``
                outputs.

    Neither function disables autograd itself; wrap calls in ``torch.no_grad()`` for
    inference (``encode_text`` does this).
    """
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    import sys
    # Put ./BLIP on sys.path before importing from it (presumably so BLIP's own
    # internal absolute imports resolve) — this must stay above the import below.
    sys.path.insert(0, './BLIP')
    from BLIP.models.blip_itm import blip_itm
    # load model
    model = blip_itm(pretrained='./BLIP/chatir_weights.ckpt',  # Download from Google Drive, see README.md
                     med_config='BLIP/configs/med_config.json',
                     image_size=224,
                     vit='base'
                     )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    # define Image Embedder (raw_image --> img_feature)
    # NOTE(review): these mean/std values appear to be the CLIP normalization statistics.
    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])

    def blip_project_img(image):
        # `image` is not moved to `device` here; the caller must place it there.
        # Presumably a (batch, 3, 224, 224) tensor — TODO confirm.
        embeds = model.visual_encoder(image)
        # [:, 0, :] keeps the first token of each sequence (presumably the [CLS] token).
        projection = model.vision_proj(embeds[:, 0, :])
        return F.normalize(projection, dim=-1)

    def blip_prep_image(path):
        # Load from disk, force 3 channels, then apply the resize/normalize pipeline.
        raw = Image.open(path).convert('RGB')
        return transform_test(raw)

    # The lambda just forwards to blip_prep_image.
    image_embedder = ImageEmbedder(blip_project_img, lambda path: blip_prep_image(path))

    def dialog_encoder(dialog):
        # Tokenize the batch and move the token tensors to the model's device.
        text = model.tokenizer(dialog, padding='longest', truncation=True,
                               max_length=200,
                               return_tensors="pt"
                               ).to(device)

        text_output = model.text_encoder(text.input_ids, attention_mask=text.attention_mask,
                                         return_dict=True, mode='text')

        # Project the first-position hidden state into the shared embedding space.
        shift = model.text_proj(text_output.last_hidden_state[:, 0, :])
        return F.normalize(shift, dim=-1)

    return dialog_encoder, image_embedder

class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset over raw text strings.

    Each item is a dict ``{'text': <string>}``. PyTorch's default DataLoader collate
    function batches these into ``{'text': [str, ...]}``, which is the form
    ``encode_text`` hands to the text encoder. No tokenization happens here.
    """

    def __init__(self, texts: list):
        """
        Args:
            texts (list of str): the text strings to serve, in order.
        """
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {'text': self.texts[idx]}

def encode_text(dataset, model, batch_size=20, num_workers=1):
    """Encode every text in ``dataset`` with ``model`` and stack the results.

    Args:
        dataset: a dataset whose items are ``{'text': str}`` dicts (e.g. ``TextDataset``).
        model: callable taking a list of strings and returning a tensor with one row per
            string (in this notebook, the BLIP ``dialog_encoder``).
        batch_size: number of texts per forward pass.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        A CPU tensor concatenating the per-batch outputs along dim 0, in dataset order.

    Raises:
        ValueError: if ``dataset`` is empty (there is nothing to concatenate).
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,  # keep output rows aligned with dataset order
        num_workers=num_workers,
        # Pinning only helps host->GPU copies; without CUDA it just triggers a warning.
        pin_memory=torch.cuda.is_available(),
    )
    if num_workers > 0:
        # prefetch_factor is only valid when worker processes are used.
        loader_kwargs['prefetch_factor'] = 2
    data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    all_features = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for batch in data_loader:
            all_features.append(model(batch['text']).cpu())
    if not all_features:
        raise ValueError("encode_text received an empty dataset; nothing to encode")
    return torch.cat(all_features)

def retrieve_topk_images(query: list,
                         topk=10,
                         faiss_model=None,
                         blip_model=None,
                         id2image=None,
                         processor=None,
                         verbose=False):
    """Retrieve the ``topk`` nearest images for each text in ``query``.

    Args:
        query: list of query strings.
        topk: number of neighbours requested from the faiss index.
        faiss_model: faiss index searched with the L2-normalized query vectors.
        blip_model: text encoder passed to ``encode_text``.
        id2image: mapping from faiss row id to image path; ids missing from it
            (including any -1 that faiss may return) map to 'path/not/found'.
        processor: unused; kept for backward compatibility.
        verbose: print the query-matrix shape (debug output).

    Returns:
        (image_paths, indices): ``image_paths[q]`` lists the ranked paths for query ``q``;
        ``indices`` is the raw faiss id array.
    """
    text_dataset = TextDataset(query)
    query_vec = encode_text(text_dataset, blip_model).numpy()
    # faiss expects a C-contiguous float32 matrix.
    query_vec = np.ascontiguousarray(query_vec, dtype=np.float32)
    norms = np.linalg.norm(query_vec, axis=1, keepdims=True)
    # Guard against division by zero for an all-zero vector.
    query_vec /= np.where(norms == 0, 1.0, norms)
    if verbose:
        print('query_vec.shape------------', query_vec.shape)
    _, indices = faiss_model.search(query_vec, topk)
    indices = np.asarray(indices)
    image_paths = [[id2image.get(idx, 'path/not/found') for idx in row] for row in indices]
    return image_paths, indices



# Sentinel meaning "no default supplied" (None is a legitimate default value).
_FIND_INDEX_NO_DEFAULT = object()


def find_index_in_list(element, my_list, default=_FIND_INDEX_NO_DEFAULT):
    """Return the index of the first occurrence of ``element`` in ``my_list``.

    Args:
        element: value to look for.
        my_list: list to search.
        default: value returned when ``element`` is absent. If omitted, the absence
            raises ValueError exactly as ``list.index`` does.

    Raises:
        ValueError: if ``element`` is absent and no ``default`` was given.
    """
    try:
        return my_list.index(element)
    except ValueError:
        if default is _FIND_INDEX_NO_DEFAULT:
            raise
        return default





In [29]:


def generate_valid_question(base64_image_top, query, max_retries=5):
    """Ask ``question_attribute`` for a clarifying question, retrying on bad replies.

    Each attempt calls ``question_attribute``, extracts the JSON object from the reply
    and returns its question field.

    Args:
        base64_image_top: base64-encoded candidate images forwarded to question_attribute.
        query: the user's text description, forwarded to question_attribute.
        max_retries: maximum number of attempts.

    Returns:
        The question text (stripped), or None when every attempt failed.
    """
    # Accept the correct key and the misspelling "Questin ..." that the
    # question_attribute output format has used, so either prompt version works.
    accepted_keys = ("Question to differentiate the pictures",
                     "Questin to differentiate the pictures")

    for attempt in range(max_retries):
        try:
            print(f"Attempt {attempt + 1} to generate a valid question.")

            question_fewshot = question_attribute(base64_image_top, query)
            question_fewshot_json = json.loads(extract_json(question_fewshot))

            if not isinstance(question_fewshot_json, dict):
                print("Parsed JSON is not an object, retrying...")
                continue

            for key in accepted_keys:
                candidate = question_fewshot_json.get(key)
                # An empty or non-string value is as useless as a missing key.
                if isinstance(candidate, str) and candidate.strip():
                    return candidate.strip()
            print("Key 'Question to differentiate the pictures' not found, retrying...")

        except json.JSONDecodeError:
            print("JSONDecodeError: The result is not valid JSON, retrying...")
        except Exception as e:
            # API/network errors: log and retry rather than abort the caller's loop.
            print(f"An unexpected error occurred: {e}, retrying...")

    print("Failed to generate a valid question after multiple attempts.")
    return None

In [30]:
def extract_json(text):
    """Return the span from the first '{' to the last '}' in ``text`` (across newlines).

    Returns the string 'parse incorrectly' when no such span exists or when ``text`` is
    not a string (chat completions can return None content). That value is not valid
    JSON, so a subsequent ``json.loads`` fails with JSONDecodeError.
    """
    if not isinstance(text, str):
        return 'parse incorrectly'
    # Greedy + DOTALL: spans nested objects and multi-line replies.
    match = re.search(r'{.*}', text, re.DOTALL)
    return match.group() if match else 'parse incorrectly'

In [31]:

def question_simple(base64_image_list,query):
    """Ask the vision-language model for one clarifying question about 5 candidate images.

    Args:
        base64_image_list: base64-encoded JPEG strings; the first 5 are sent.
        query: the user's text description, inserted into the prompt.

    Returns:
        The raw message content of the model's reply.

    Raises:
        ValueError: if fewer than 5 images are given.
    """
    if len(base64_image_list) < 5:
        raise ValueError(f"question_simple needs 5 images, got {len(base64_image_list)}")
    # Credentials/endpoint come from the OPENAI_API_KEY / OPENAI_BASE_URL environment
    # variables rather than being hard-coded in the notebook.
    client = OpenAI()
    image_parts = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}
        for b64 in base64_image_list[:5]
    ]

    response = client.chat.completions.create(
      model="navigation_llava_4o",
      messages=[
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": f"""
              
            Your task is to identify the target image based on descriptions provided by users. Due to the ambiguity in user descriptions, you have already searched and found 5 images based on these descriptions. Your task is to analyze the content of these 5 images, ask a question to clarify the user's needs, and your question is not overlapping with the descriptions. Thereby helping the user quickly find the target image. You only need to output one questions within 30 words.
            Complete the following tasks step by step:
            1. Combine the textual description:, observe these 5 images, and analyze and summarize their common points and differences.
            2. To find the target image, ask a question based on these differences that can clarify the user’s needs and help them quickly find the target image.
            [Your Question]
            Based on the images and sentence description <{query}>, your questions are:
              """,
            },
            *image_parts,
          ],
        }
      ],
      max_tokens=500,
    )
    return response.choices[0].message.content


In [32]:

def question_attribute(base64_image_list,query):
    """Ask the vision-language model for one question that tells 5 candidate images apart.

    The prompt requests a JSON object whose "Question to differentiate the pictures"
    field is the key ``generate_valid_question`` reads, so the two must stay spelled
    identically.

    Args:
        base64_image_list: base64-encoded JPEG strings; the first 5 are sent.
        query: the user's text description. NOTE(review): it is not inserted into the
            prompt, so the model never sees it.

    Returns:
        The raw message content of the model's reply.

    Raises:
        ValueError: if fewer than 5 images are given.
    """
    if len(base64_image_list) < 5:
        raise ValueError(f"question_attribute needs 5 images, got {len(base64_image_list)}")
    # Credentials/endpoint come from the OPENAI_API_KEY / OPENAI_BASE_URL environment
    # variables rather than being hard-coded in the notebook.
    client = OpenAI()
    image_parts = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}
        for b64 in base64_image_list[:5]
    ]

    response = client.chat.completions.create(
      model="navigation_llava_4o",
      messages=[
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": f"""
              
                You need to find a common object that appears in all 5 pictures but has distinguishing features. Based on this object, ask a question to differentiate the pictures.

                Remember, you must ensure the question is specific, not abstract, and the answer should be directly obtainable by looking at the images.
                
                For example:
                Example 1: All 5 pictures have people, but the number of people differs. You can ask about the number of people.
                Example 2: All 5 pictures have cats, but the colors are different. You can ask about the color.
                Example 3: All 5 pictures have traffic lights, but their positions differ. You can ask about the position of the traffic lights.
                
                Ask a specific question based on the object that will help distinguish the pictures.
                Don't ask 2 questions each time. such as what is the attribute of a or b

                Output must follow the format
                {{
                "What is the common object that appears in all five pictures":"",
                "What is the distinguishing feature that can help differentiate the picture":"",
                "Question to differentiate the pictures":""
                }}
                ""
                
              """,
            },
            *image_parts,
          ],
        }
      ],
      max_tokens=500,
    )
    return response.choices[0].message.content


In [33]:
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 UTF-8 string."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

In [34]:

def answer(question,base64_target_image):
    """Answer ``question`` about a single image with the vision-language model.

    Args:
        question: the clarifying question text.
        base64_target_image: base64-encoded JPEG of the image to look at.

    Returns:
        The raw message content of the model's reply (``max_tokens=100``).

    Raises:
        ValueError: if ``question`` is None (e.g. question generation failed), rather
            than sending the literal text "None" to the model.
    """
    if question is None:
        raise ValueError("answer() needs a question, got None")
    # Credentials/endpoint come from the OPENAI_API_KEY / OPENAI_BASE_URL environment
    # variables rather than being hard-coded in the notebook.
    client = OpenAI()

    response = client.chat.completions.create(
      model="navigation_llava_4o",
      messages=[
        {
          "role": "user",
          "content": [
            {
              "type": "text",
              "text": f"according to the image, answer the question:{question}，Your answer must be direct and simple",
            },
            {
              "type": "image_url",
              "image_url": {
                "url": f"data:image/jpeg;base64,{base64_target_image}",
              },
            }
          ],
        }
      ],
      max_tokens=100,
    )
    return response.choices[0].message.content

In [35]:

def summary(query,question,answer):
    """Fold a question/answer pair into the original image description via the LLM.

    Args:
        query: the original image description.
        question: the clarifying question that was asked.
        answer: the answer text (this parameter shadows the module-level ``answer``
            function inside this body only).

    Returns:
        The raw message content of the model's reply: the enriched description.
    """
    # Credentials/endpoint come from the OPENAI_API_KEY / OPENAI_BASE_URL environment
    # variables rather than being hard-coded in the notebook.
    client = OpenAI()

    response = client.chat.completions.create(
    model="navigation_llava_4o",
    messages=[
        {"role": "system", "content": f"""
        Your task is to summarize the information from the image's question and answer and add this information to the original image description.\
        Remember: the summarized information must be concise, and the original description should not be altered.

        <question>
        {question}
        <answer>
        {answer}
        <image description>
        {query}


The information extracted from the question and answer should be added to the original description as an attribute or a simple attributive clause.
        """
        },
        {"role": "user", "content": ""}
      ]
    )
    return response.choices[0].message.content

In [None]:
import faiss
import pickle

# Load the prebuilt retrieval artefacts. Each step logs its progress; on failure the
# error is printed and re-raised, so Run All stops here instead of failing later with a
# NameError on `faiss_model` / `id2image` / `image_vector`.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted checkpoints.
try:
    print("Load faiss Model...")
    faiss_model = faiss.read_index('./checkpoints/blip_faiss.index')
    print("faiss Model Load successful")
except Exception as e:
    print(f"Error: {e}")
    raise

try:
    print("Load id2image Data...")
    with open('./checkpoints/id2image.pickle', 'rb') as f:
        id2image = pickle.load(f)
    print("id2image DataLoad Successful")
except Exception as e:
    print(f"Load id2image Error: {e}")
    raise

try:
    print("Load image_vector Data...")
    with open('./checkpoints/blip_image_embedding.pickle', 'rb') as f:
        image_vector = pickle.load(f)
    print("image_vector DataLoad Successful")
except Exception as e:
    print(f"Load image_vector Error: {e}")
    raise


In [37]:
# Load a bert-base-uncased tokenizer and model from a local directory.
# NOTE(review): neither `tokenizer` nor this `model` is used by any visible cell, and the
# absolute path ties the notebook to one machine — consider a configurable path.
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('/root/autodl-tmp/.autodl/HYF/questionIR/CSS/MODELS/bert-base-uncased')
# Hub alternative: model = BertModel.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('/root/autodl-tmp/.autodl/HYF/questionIR/CSS/MODELS/bert-base-uncased') 


In [None]:

# Load the faiss index, the id -> image-path mapping and the image embeddings, then build
# the BLIP encoders; `dialog_encoder` is what later cells pass to retrieve_topk_images.
# NOTE(review): the same three files are also loaded by an earlier cell; this repeats that work.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted checkpoint files.
# NOTE(review): `image_vector` and `image_embedder` are not used by any visible cell.
faiss_model = faiss.read_index('./checkpoints/blip_faiss.index')
with open('./checkpoints/id2image.pickle', 'rb') as f:
    id2image = pickle.load(f)
    
with open('./checkpoints/blip_image_embedding.pickle', 'rb') as f:
    image_vector = pickle.load(f)
dialog_encoder, image_embedder = BLIP_BASELINE()

In [39]:
# recall_res = pd.read_csv('recall_res.csv')
#len(recall_res)


In [40]:
def read_jsonl(file_path):
    """Read a JSON Lines file into a list with one parsed object per non-blank line.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    raising JSONDecodeError.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]



In [None]:

def read_csv(file_path):
    """Load ``file_path`` into a DataFrame with ``pd.read_csv`` default options."""
    return pd.read_csv(file_path)

# Load the evaluation table. Later cells read its "option" column as the text query and
# its "target_image" column as a path relative to ./playground/data/css_data/.
# NOTE(review): absolute machine-specific path — consider a configurable DATA_DIR.
file_path = '/root/autodl-tmp/.autodl/HYF/questionIR/CSS/downloads/group_20-50.csv'
csv_data = read_csv(file_path)

print(csv_data.shape)  


In [None]:
# Preview the first rows of the evaluation table.
csv_data.head()

In [None]:

# Single-query walkthrough on row `i` of csv_data:
# 1. retrieve with the original query;
# 2. ask one clarifying question about images sampled from the top candidates;
# 3. answer it from the target image and fold the answer into the query;
# 4. retrieve again and compare the target's rank.
i=30
TOPK_BEFORE = 100     # topk=10000 is too many here
TOPK_AFTER = 10000
N_CANDIDATES = 40     # the question is built from images sampled among the top-40 results
N_SAMPLED = 5         # question_attribute sends exactly 5 images
demo_rng = random.Random(0)  # seeded so the sampled candidate images are reproducible

query = csv_data.iloc[i]["option"]

target_image_path='./playground/data/css_data/'+csv_data.iloc[i]["target_image"]
show_image([target_image_path], sentence=None)

image_paths,indices=retrieve_topk_images([query],
                         topk=TOPK_BEFORE,
                         faiss_model=faiss_model,
                         blip_model=dialog_encoder,
                         id2image=id2image,
                         processor=None, )

# Print only the head of the ranking rather than all TOPK_BEFORE paths.
print(f'target_image_path{target_image_path}, image_paths{image_paths[0][:10]}')

# Raises ValueError if the target is not within the TOPK_BEFORE results.
image_rank=find_index_in_list(target_image_path , image_paths[0])

top_images_path=image_paths[0][:N_CANDIDATES]
random_selection_path = demo_rng.sample(top_images_path, N_SAMPLED)
base64_image_top=[encode_image(image_path) for image_path in random_selection_path]

question = generate_valid_question(base64_image_top, query)
if question is None:
    # Without this guard the literal text "None" would be sent to answer() as the question.
    raise RuntimeError("generate_valid_question returned None; cannot continue this walkthrough")

base64_target_image=encode_image(target_image_path)
answer_of_question=answer(question,base64_target_image)

summary_of_question_and_option=summary(query,question,answer_of_question)

image_paths_new,indices=retrieve_topk_images([summary_of_question_and_option],
                         topk=TOPK_AFTER,
                         faiss_model=faiss_model,
                         blip_model=dialog_encoder,
                         id2image=id2image,
                         processor=None, )

image_rank_new=find_index_in_list(target_image_path , image_paths_new[0])

print(query)
print(question)
print(summary_of_question_and_option)

print("old_rank:",image_rank)
print("new_rank:",image_rank_new)

In [None]:
# Display the original query, the generated question and the enriched query together.
query,question,summary_of_question_and_option

In [None]:
# Show every image ranked at or above the target for the original query (the target is the last one shown).
show_image(image_paths[0][:image_rank+1], sentence=None)

In [None]:
# Re-run retrieval for the enriched query at a shallower depth (topk=100) and locate the target.
# NOTE(review): find_index_in_list raises ValueError if the target is not within these 100
# results, and the print below writes out all 100 retrieved paths.
image_paths_new,indices=retrieve_topk_images([summary_of_question_and_option],
                         topk=100, 
                         faiss_model=faiss_model,
                         blip_model=dialog_encoder,
                         id2image=id2image,
                         processor=None, )  
print(f'target_image_path:  {target_image_path}, image_paths:  {image_paths_new[0]}')
image_rank_new=find_index_in_list(target_image_path , image_paths_new[0])


In [None]:
# Display the target's 0-based rank for the enriched query.
image_rank_new

In [None]:
# Compare the target's 0-based rank for the original query and for the enriched query.
print("old_rank:",image_rank)
print("new_rank:",image_rank_new)

In [None]:
# Show every image ranked at or above the target for the enriched query (the target is the last one shown).
show_image(image_paths_new[0][:image_rank_new+1], sentence=None)

In [None]:
from tqdm import tqdm

# Batch evaluation: for the first N_EVAL rows of csv_data, compare Recall@RECALL_K of the
# original query against the query enriched with one generated question and its answer.
N_EVAL = 100
EVAL_TOPK = 50000
RECALL_K = 10
N_CANDIDATES = 40   # questions are built from images sampled among the top-40 results
N_SAMPLED = 5       # question_attribute sends exactly 5 images
eval_rng = random.Random(0)  # seeded so candidate-image sampling is reproducible


def _rank_or_miss(target, paths):
    """Index of ``target`` in ``paths``, or ``len(paths)`` (i.e. "not retrieved") if absent."""
    try:
        return paths.index(target)
    except ValueError:
        return len(paths)


query_question = []  # [query, question, answer, enriched query] per evaluated row
rank_storage = []    # [rank_before, rank_after, rank_before - rank_after] per evaluated row
rank_change = []     # rank_before - rank_after (positive = target moved up)
skipped_rows = []    # (row index, reason) for rows that could not be evaluated

top_10_hits_before = 0
top_10_hits_after = 0
total_queries = min(N_EVAL, len(csv_data))

for i in tqdm(range(total_queries)):
    row = csv_data.iloc[i]  # positional, independent of the DataFrame's index labels
    query = row["option"]
    target_image_path = './playground/data/css_data/' + row["target_image"]

    image_paths, indices = retrieve_topk_images([query],
                             topk=EVAL_TOPK,
                             faiss_model=faiss_model,
                             blip_model=dialog_encoder,
                             id2image=id2image,
                             processor=None)
    image_rank = _rank_or_miss(target_image_path, image_paths[0])

    random_selection_path = eval_rng.sample(image_paths[0][:N_CANDIDATES], N_SAMPLED)
    base64_image_top = [encode_image(path) for path in random_selection_path]

    question = generate_valid_question(base64_image_top, query)
    if question is None:
        print("Failed to generate a valid question. Skipping this sample.")
        skipped_rows.append((i, "no valid question"))
        continue

    try:
        base64_target_image = encode_image(target_image_path)
        answer_of_question = answer(question, base64_target_image)
        summary_of_question_and_option = summary(query, question, answer_of_question)
    except Exception as e:
        # One failed API/IO call should not discard the whole run; record it and move on.
        print(f"Row {i}: answer/summary failed ({e}). Skipping this sample.")
        skipped_rows.append((i, repr(e)))
        continue

    image_paths_new, indices = retrieve_topk_images([summary_of_question_and_option],
                             topk=EVAL_TOPK,
                             faiss_model=faiss_model,
                             blip_model=dialog_encoder,
                             id2image=id2image,
                             processor=None)
    image_rank_new = _rank_or_miss(target_image_path, image_paths_new[0])

    # Count hits only once a row is fully evaluated, so "before" and "after" are measured
    # on the same set of queries with the same denominator.
    if image_rank < RECALL_K:
        top_10_hits_before += 1
    if image_rank_new < RECALL_K:
        top_10_hits_after += 1

    query_question.append([query, question, answer_of_question, summary_of_question_and_option])
    rank_storage.append([image_rank, image_rank_new, image_rank - image_rank_new])
    rank_change.append(image_rank - image_rank_new)

evaluated_queries = total_queries - len(skipped_rows)
print(f'evaluated {evaluated_queries} of {total_queries} queries ({len(skipped_rows)} skipped)')
if evaluated_queries:
    top_10_recall_rate_before = top_10_hits_before / evaluated_queries
    top_10_recall_rate_after = top_10_hits_after / evaluated_queries
    print(f'top_10_hits: {top_10_hits_before}, total_queries: {evaluated_queries}')
    print(f"Top-10 Recall Rate before: {top_10_recall_rate_before * 100:.4f}%")
    print('---------------------------------------')
    print(f'top_10_hits: {top_10_hits_after}, total_queries: {evaluated_queries}')
    print(f"Top-10 Recall Rate after: {top_10_recall_rate_after * 100:.4f}%")
else:
    print("No queries were evaluated; recall is undefined.")



