### Method 1: VLM-based image retrieval

In [None]:
import json
import re
from sentence_transformers import SentenceTransformer
import numpy as np
import csv
from sklearn.metrics.pairwise import cosine_similarity
import ollama
import os
from tqdm import tqdm

def image_captioning(data_path, prompt, output_csv, batch_size=10):
    """Caption every entry of ``data_path`` with the ``llava:34b`` Ollama model.

    Results are appended to ``output_csv`` as ``photo_id,caption`` rows and are
    flushed every ``batch_size`` images, so an interrupted run keeps the rows
    that were already written.

    Args:
        data_path: Directory whose entries are sent to the model one by one.
        prompt: Instruction text passed to the model together with each image.
        output_csv: CSV file to append to. The header row is written only when
            the file is missing or empty.
        batch_size: Number of rows buffered between two writes.

    An entry that cannot be read or captioned is reported on stdout and stored
    with the caption ``"Error"`` instead of aborting the whole run.
    """
    batch = []
    # Decide about the header from the file itself rather than from the running
    # counter, so appending to an existing file never duplicates the header.
    write_header = not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0

    def flush_batch():
        """Append the buffered rows (and the header, once) to the CSV."""
        nonlocal write_header
        with open(output_csv, 'a', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            if write_header:
                writer.writerow(['photo_id', 'caption'])
                write_header = False
            writer.writerows(batch)
        batch.clear()

    imglist = tqdm(os.listdir(data_path))
    for count, img in enumerate(imglist, start=1):
        image_path = os.path.join(data_path, img)
        imglist.set_description(f"Processing {count} :{img}")

        try:
            # Reading the file is inside the try block so that an unreadable
            # entry (e.g. a sub-directory) is recorded as "Error" like any
            # other failure instead of crashing and losing the pending batch.
            with open(image_path, 'rb') as f:
                image_bytes = f.read()
            response = ollama.generate(
                'llava:34b',
                prompt,
                images=[image_bytes],
                options={"num_predict": 76}
            )['response']
            # Keep each caption on a single CSV line.
            response = response.replace('\n', '').strip()
        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            response = "Error"

        # The photo id is the file name up to its first dot.
        photo_id = os.path.basename(image_path).split(".")[0]
        batch.append([photo_id, response])

        if count % batch_size == 0:
            flush_batch()

    # Write whatever is left over from the last, incomplete batch.
    if batch:
        flush_batch()

    print(f"Results saved to {output_csv}")

def basic_text_preprocessing(text):
    """Normalize ``text`` into a lowercase string of unique words.

    Every character that is neither a word character nor whitespace is replaced
    by a space, whitespace runs are collapsed, the text is lowercased, and
    repeated words are dropped while keeping the order of first appearance.
    """
    cleaned = re.sub(r'\s+', ' ', re.sub(r'[^\w\s]', ' ', text)).strip().lower()

    seen = set()
    kept = []
    for token in cleaned.split():
        if token not in seen:
            seen.add(token)
            kept.append(token)
    return ' '.join(kept)

def process_query(dialogue_data):
    """Build the retrieval query text for one dialogue.

    If at least one turn of ``dialogue_data['dialogue']`` has a truthy
    ``share_photo`` flag, the messages of all turns (from every participant)
    are joined and normalized with ``basic_text_preprocessing``. Otherwise the
    result is the preprocessing of an empty string.

    Args:
        dialogue_data: Mapping with a ``'dialogue'`` list; each turn is a
            mapping with a ``'message'`` string and an optional
            ``'share_photo'`` flag.

    Returns:
        The preprocessed query string.
    """
    dialogue = dialogue_data['dialogue']
    if not any(entry.get('share_photo') for entry in dialogue):
        return basic_text_preprocessing('')

    # Join with a space: plain concatenation fuses the last word of one message
    # with the first word of the next ("...thereHow...") and produces junk tokens.
    return basic_text_preprocessing(' '.join(entry['message'] for entry in dialogue))


# --- Configuration: image folder, VLM prompts and caption cache files ---
data_path = 'test_images/test_images'
obj_prompt = """
List out all objects for the photo without interpretations and repeating. Focus on objects or people's clothing. Skip auxiliary words and helping verbs.
"""
caption_prompt = """
Generate a brief caption describing the image, focusing on the most important elements.
"""
obj_output_csv = 'test_vlm_caption.csv'
caption_output_csv = 'test_vlm_caption_short.csv'

# Each CSV acts as a cache: the captioning pass is run only when its output
# file does not exist yet.
for cache_csv, vlm_prompt in ((obj_output_csv, obj_prompt), (caption_output_csv, caption_prompt)):
    if not os.path.exists(cache_csv):
        print(f"{cache_csv} does not exist. Running the function...")
        image_captioning(data_path, vlm_prompt, cache_csv)
    else:
        print(f"{cache_csv} already exists. Skipping the function.")

# --- Load the dialogues (one JSON object per line) ---
queries = []
with open('test.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if line:  # tolerate blank lines, e.g. a trailing newline at end of file
            queries.append(json.loads(line))

# --- Load the two VLM outputs: short captions and object lists ---
vlm_caption_data = []
with open(caption_output_csv, 'r', newline='', encoding='utf-8') as file:
    reader = csv.DictReader(file)
    for row in reader:
        vlm_caption_data.append({
            'photo_id': row['photo_id'],
            'caption': row['caption']
        })

vlm_object_data = []
with open(obj_output_csv, 'r', newline='', encoding='utf-8') as file:
    reader = csv.DictReader(file)
    for row in reader:
        vlm_object_data.append({
            'photo_id': row['photo_id'],
            'caption': row['caption']
        })

# --- Load the photo descriptions shipped with the dataset ---
description_database = []
with open('test_images.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if not line:
            continue
        data = json.loads(line)
        description_database.append({
            'photo_id': data['photo_id'],
            'photo_description': data['photo_description']
        })

# --- Merge the three text sources per photo ---
# Index both side tables by photo_id once instead of scanning them for every
# photo. setdefault keeps the first row of a duplicated id, which is what a
# first-match scan would return.
object_by_id = {}
for item in vlm_object_data:
    object_by_id.setdefault(item['photo_id'], item['caption'])

description_by_id = {}
for item in description_database:
    description_by_id.setdefault(item['photo_id'], item['photo_description'])

merged_database = []
for caption_entry in vlm_caption_data:
    photo_id = caption_entry['photo_id']
    caption = caption_entry['caption']

    # Photos without an object list or description contribute an empty string.
    object_description = object_by_id.get(photo_id, "")
    photo_obj_description = description_by_id.get(photo_id, "")

    combined_text = basic_text_preprocessing(f"{caption} {object_description} {photo_obj_description}")
    merged_database.append({
        'photo_id': photo_id,
        'photo_description': combined_text
    })


# --- Embed every merged photo text once ---
model = SentenceTransformer('all-MiniLM-L6-v2')

photo_descriptions = [item['photo_description'] for item in merged_database]
photo_ids = [item['photo_id'] for item in merged_database]
photo_embeddings = model.encode(photo_descriptions)
photo_embeddings = photo_embeddings / np.linalg.norm(photo_embeddings, axis=1, keepdims=True)  # L2-normalize

# --- Rank the photos for each dialogue and keep the 30 most similar ---
results = []

for i, query in enumerate(queries):
    processed_query = process_query(query)
    query_embedding = model.encode([processed_query])
    query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)  # L2-normalize

    similarities = cosine_similarity(photo_embeddings, query_embedding).flatten()

    # argsort is ascending, so reverse it before taking the first 30 indices.
    top_30_indices = np.argsort(similarities)[::-1][:30]
    top_30_photo_ids = [str(photo_ids[idx]) for idx in top_30_indices]

    results.append({
        'dialogue_id': i + 1,  # dialogue ids are 1-based positions in test.jsonl
        'photo_id': " ".join(top_30_photo_ids)
    })

# --- Write one row per dialogue: its id and the space-separated photo ids ---
csv_filename = 'caption_obj_descript_similarity_result.csv'
with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["dialogue_id", "photo_id"])
    for result in results:
        writer.writerow([result['dialogue_id'], result['photo_id']])
print(f"Result is saved at {csv_filename}")


### Method 2: CLIP-based image retrieval

In [None]:
import os
import json
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import re
import csv
from tqdm import tqdm

def process_text(dialogue_data):
    """Build the CLIP text query for one dialogue.

    Only the messages written by the participant of the first turn flagged
    with ``share_photo`` are used. They are joined with a space, stripped of
    punctuation, lowercased and cut to the first 77 whitespace-separated words.
    When no turn shares a photo the result is an empty string.

    Args:
        dialogue_data: Mapping with a ``'dialogue'`` list; each turn is a
            mapping with ``'user_id'``, ``'message'`` and an optional
            ``'share_photo'`` flag.

    Returns:
        The normalized query string.
    """
    def basic_text_preprocessing(text, max_tokens=77):
        """Delete punctuation, collapse whitespace, lowercase, keep ``max_tokens`` words."""
        text = re.sub(r'[^\w\s]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        text = text.lower()
        return " ".join(text.split()[:max_tokens])

    dialogue = dialogue_data['dialogue']
    sharing_turn = next((d for d in dialogue if d.get('share_photo')), None)
    if sharing_turn is None:
        return basic_text_preprocessing('')

    user = sharing_turn['user_id']
    # Join with a space: punctuation is deleted (not replaced by a space) above,
    # so plain concatenation would glue the last word of one message to the
    # first word of the next one.
    msg = ' '.join(d['message'] for d in dialogue if d['user_id'] == user)
    return basic_text_preprocessing(msg)

# --- Load CLIP on the GPU when one is available ---
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "openai/clip-vit-large-patch14"
model = CLIPModel.from_pretrained(model_name).to(device)
processor = CLIPProcessor.from_pretrained(model_name)

# --- Load the dialogues (one JSON object per line) ---
data = []
with open('test.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if line:  # tolerate blank lines, e.g. a trailing newline at end of file
            data.append(json.loads(line))

# --- Load every image of the folder as RGB ---
image_folder = "test_images/test_images"
image_paths = [os.path.join(image_folder, img) for img in os.listdir(image_folder)]
images = []
for path in image_paths:
    # convert() returns an independent in-memory image, so the file handle can
    # be closed right away instead of waiting for garbage collection.
    with Image.open(path) as opened:
        images.append(opened.convert("RGB"))

# --- Embed the images once, in mini-batches ---
# The image embeddings do not depend on the dialogue, so they are computed a
# single time here instead of re-encoding the whole folder for every dialogue.
# NOTE(review): the vision_model -> visual_projection path is expected to match
# the image_embeds produced by CLIPModel's forward pass; confirm against the
# installed transformers version.
image_batch_size = 32
feature_chunks = []
with torch.no_grad():
    for start in tqdm(range(0, len(images), image_batch_size), desc="Encoding images"):
        pixel_inputs = processor(images=images[start:start + image_batch_size], return_tensors="pt").to(device)
        vision_outputs = model.vision_model(pixel_values=pixel_inputs["pixel_values"])
        chunk = model.visual_projection(vision_outputs.pooler_output)
        # L2-normalize so the dot product below is a cosine similarity.
        feature_chunks.append(chunk / chunk.norm(dim=-1, keepdim=True))
image_features = torch.cat(feature_chunks)

# --- Rank the images for each dialogue and write the 30 best photo ids ---
csv_filename = "zeroshot_clip_results.csv"
with open(csv_filename, mode='w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerow(["dialogue_id", "photo_id"])
    with tqdm(total=len(data), desc="Processing dialogues") as pbar:
        for idx, dialogue_data in enumerate(data):
            pbar.set_postfix({"Current index": idx})
            text_description = process_text(dialogue_data)
            text_inputs = processor(text=[text_description], return_tensors="pt", padding=True, truncation=True).to(device)

            with torch.no_grad():
                text_outputs = model.text_model(
                    input_ids=text_inputs["input_ids"],
                    attention_mask=text_inputs["attention_mask"],
                )
                text_features = model.text_projection(text_outputs.pooler_output)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)

                # One similarity score per image for this dialogue.
                similarities = (text_features @ image_features.T).squeeze(0)

            sorted_indices = similarities.argsort(descending=True)
            top_30_indices = sorted_indices[:30].tolist()
            # The photo id is the image file name up to its first dot.
            top_30_photo_ids = [os.path.basename(image_paths[img_idx]).split(".")[0] for img_idx in top_30_indices]

            # Dialogue ids are 1-based positions in test.jsonl.
            writer.writerow([idx+1, " ".join(top_30_photo_ids)])
            pbar.update(1)

print(f"Result is saved at {csv_filename}")

