In [None]:
# Run-wide settings read by the cells below.
device = "cuda:0"  # torch device string the model and its inputs are moved to
image_num = 100  # presumably the cap on how many images get evaluated - TODO confirm against the evaluation loop

In [None]:
# llava

from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch

# Local checkpoint directory for the LLaVA-1.5-7B weights and processor config.
model_id = '/mnt/sdb1/niesen/model/llava-hf/llava-1.5-7b-hf'

# Load the weights in half precision, then move the model to the configured device.
llava_model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
llava_model = llava_model.to(device)

# The processor is loaded from the same checkpoint; it builds the prompts and
# decodes the generated tokens in the evaluation cell.
llava_processor = AutoProcessor.from_pretrained(model_id)

In [None]:
import os
import json
import torch
from glob import glob
from PIL import Image

# Single-turn user message asking for a one-sentence caption. The content is a
# text part followed by an image placeholder, in the format consumed by
# `apply_chat_template`.
description_conversation = [
    dict(
        role="user",
        content=[
            dict(type="text", text="Describe the image briefly in one sentence."),
            dict(type="image"),
        ],
    ),
]

# Set paths and parameters
adv_image_dir = '...'  # directory holding the adversarial images (placeholder value)
vqa_questions_path = 'coco300_vqa_main.json'
output_json = "llava_eval_res.json"

# Load VQA questions.
# Read explicitly as UTF-8 so decoding does not depend on the platform's
# default locale encoding; the results file is also written as UTF-8.
with open(vqa_questions_path, 'r', encoding='utf-8') as f:
    vqa_data = json.load(f)

# Map each entry's 'image' value to its 'vqa' value for quick lookup by
# filename. If two entries share the same 'image', the later one wins.
questions_map = {item['image']: item['vqa'] for item in vqa_data}

# Function to create conversation for a specific question
def create_question_conversation(question):
    """Wrap ``question`` in a single-turn conversation for the chat template.

    Args:
        question: Text placed in the user message.

    Returns:
        A one-element list containing a user message whose content is the
        question text part followed by an image placeholder part.
    """
    text_part = {"type": "text", "text": question}
    image_part = {"type": "image"}
    user_turn = {"role": "user", "content": [text_part, image_part]}
    return [user_turn]

# Collect the adversarial image paths (jpg / png / jpeg) in sorted order.
adv_images = []  # never filled in the visible code
adv_files = sorted(
    path
    for pattern in ('*.jpg', '*.png', '*.jpeg')
    for path in glob(os.path.join(adv_image_dir, pattern))
)

def _generate_response(image, conversation, max_new_tokens=200):
    """Generate the model's reply to ``conversation`` about ``image``.

    Args:
        image: PIL image handed to the processor.
        conversation: Message list in the chat-template format (a user turn
            with a text part and an image placeholder).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply with special tokens removed and the prompt tokens
        stripped from the front.
    """
    prompt = llava_processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = llava_processor(images=image, text=prompt, return_tensors='pt').to(device, torch.float16)
    output = llava_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only what follows the prompt so the reply does not echo the prompt.
    prompt_len = inputs['input_ids'].shape[1]
    return llava_processor.decode(output[0][prompt_len:], skip_special_tokens=True)


# Store results
results = []

with torch.no_grad():
    # `image_num` comes from the config cell. Files skipped below still count
    # toward the cap, since the slice is taken before any filtering.
    for adv_file in adv_files[:image_num]:
        # Map a .png extension (and only the extension) to .jpg, presumably
        # because the question file is keyed by .jpg names.
        stem, ext = os.path.splitext(os.path.basename(adv_file))
        filename = stem + ('.jpg' if ext == '.png' else ext)
        print(f'Processing: {filename}')

        # Skip if we don't have questions for this image
        if filename not in questions_map:
            print(f"Skipping {filename} - no questions found")
            continue

        # Take exactly three questions, padding with empty strings if needed.
        # This builds a new list rather than extending the one stored in
        # questions_map, so re-running the cell sees the data unchanged.
        available = questions_map[filename]
        if len(available) < 3:
            print(f"Warning: Only {len(available)} questions found for {filename}")
        questions = (list(available) + [""] * 3)[:3]

        # Load the adversarial image; the context manager closes the file
        # handle once the RGB copy has been made.
        with Image.open(adv_file) as raw_image:
            adv_image = raw_image.convert('RGB')

        # One-sentence description
        description = _generate_response(adv_image, description_conversation)

        # Answer each question; empty (padding) questions yield an empty
        # answer without calling the model.
        question_responses = [
            _generate_response(adv_image, create_question_conversation(question))
            if question.strip()
            else ""
            for question in questions
        ]

        # Save results
        results.append({
            "filename": filename,
            "adversarial_response_1": description,
            "adversarial_response_2": question_responses[0],
            "adversarial_response_3": question_responses[1],
            "adversarial_response_4": question_responses[2],
        })

# Write all collected results to disk as indented UTF-8 JSON, keeping
# non-ASCII characters unescaped.
with open(output_json, 'w', encoding='utf-8') as out_file:
    json.dump(results, out_file, ensure_ascii=False, indent=4)

print(f"Results saved to {output_json}")