In [None]:
###### FOR PACE ICE - replace GT username below ######
# Switch the kernel's working directory to the user's scratch space; the
# relative paths configured later (IMAGE_FOLDER, LLAVA_OUTPUT_PATH) are
# resolved against whatever directory is current when they are used.
%cd /home/hice1/nbalakrishna3/scratch
# Echo the resulting working directory as a sanity check.
!pwd

In [None]:
import os
import json
import base64
from openai import OpenAI
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from dotenv import load_dotenv
from tqdm import tqdm

In [None]:
# Read the API keys from the environment instead of hardcoding them here, so
# secrets never end up in a saved or shared copy of this notebook.
# Export OPENAI_API_KEY / ANTHROPIC_API_KEY in the shell that launches Jupyter,
# or call load_dotenv() before this cell to pull them from a local .env file.
# A missing variable yields "" so the validation cell that follows still fires.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")

In [None]:
# Fail fast when either credential is missing: the OpenAI key is needed to
# generate the caption/questions and the Anthropic key to call the judge model.
_missing_keys = [
    key_name
    for key_name, key_value in (
        ("OPENAI_API_KEY", OPENAI_API_KEY),
        ("ANTHROPIC_API_KEY", ANTHROPIC_API_KEY),
    )
    if not key_value
]
if _missing_keys:
    raise ValueError(f"Missing API key(s): {', '.join(_missing_keys)}")

In [None]:
# Folder of input images, resolved relative to the current working directory.
IMAGE_FOLDER = "datasets/coco/images/train2017"          
# JSONL file written by generate_llava_outputs(); it is opened in "w" mode, so
# an existing file at this path is overwritten on every run.
LLAVA_OUTPUT_PATH = "llava_multi_exp2_responses.jsonl" # TODO: rename per experiment
# OpenAI model used by generate_questions() to write the caption and questions.
GPT_MODEL = "gpt-4.1-mini"
# NOTE(review): not referenced by the code visible in this notebook;
# eval_answer() passes its own hardcoded model id to the Anthropic client.
CLAUDE_MODEL = "claude-3-5-sonnet-20241022"
# Passed as max_output_tokens to the OpenAI call in generate_questions().
MAX_OUTPUT = 200               

In [None]:
# Build one client per provider. Both are module-level globals that the helper
# functions further down (generate_questions, eval_answer) use directly.
openai_client = OpenAI(api_key=OPENAI_API_KEY)
anthropic_client = Anthropic(api_key=ANTHROPIC_API_KEY)
# Listing the models issues a request with the Anthropic key, so a bad key
# presumably surfaces here rather than midway through a run — TODO confirm.
print(anthropic_client.models.list())

In [None]:
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as a base64 string.

    The file is opened in binary mode and closed before returning; the result
    is the base64 text decoded as UTF-8 (no data-URL prefix is added).
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

In [None]:
# Keys that every generate_questions() result must contain.
_QUESTION_KEYS = ("correct_caption", "L0", "L1", "L2", "L3", "L4")


def _extract_json_object(text):
    """Return the outermost ``{...}`` span of ``text`` (``text`` itself if none)."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return text
    return text[start:end + 1]


def generate_questions(base64_image):
    """
    Ask the OpenAI model for a caption and a five-level question ladder.

    Sends the instruction prompt below together with the image (as a JPEG
    data URL) to ``openai_client.responses.create`` using ``GPT_MODEL`` and
    ``MAX_OUTPUT`` output tokens, then parses the reply as JSON.

    Args:
        base64_image: base64-encoded image bytes (no data-URL prefix).

    Returns:
        dict with the string keys "correct_caption", "L0", "L1", "L2", "L3"
        and "L4".

    Raises:
        ValueError: if the reply is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``), is not a JSON object, or lacks one of the
            required keys.
    """
    prompt = """
You are preparing controlled experimental materials for multimodal evaluation.

Given the IMAGE (provided separately), generate the following:

============================================================
1. Correct Caption
============================================================
• Accurately describe the visible scene.
• 9–15 words, objective, simple, and factual.
• Should mention main objects; avoid inference beyond evidence.

============================================================
2. Visual Necessity Question Ladder (VNL): Levels L0 → L4
============================================================

GENERAL RULES:
• L1–L4 MUST require looking at the image to answer.
• All questions MUST be answerable using only the given image.
• Do NOT include the answers.
• No question should exceed 14 words.
• Return concise, natural wording.

------------------------------------------------------------
L0 – Baseline Question (Language-prior only)
------------------------------------------------------------
• A question humans can answer **without seeing the image**.
• May refer to the world generally (NOT the specific image).
• Purpose: control for language-only biases.
• 6–12 words.
Examples:
– “What season often has the coldest weather?”  
– “Which animal is larger, a dog or an elephant?”  
– “What do people usually use to take photographs?”

------------------------------------------------------------
L1 – Basic Visual Recognition
------------------------------------------------------------
• Requires the image.
• Ask about a **primary object** or its basic property.
• No reasoning, no inference.
Examples:
– “What object is the person holding?”  
– “What color is the animal?”  
– “How many people are visible?”

------------------------------------------------------------
L2 – Intermediate Visual Detail
------------------------------------------------------------
• Also requires the image.
• Ask about a **secondary property** of a main object.
• Slightly more specific than L1.
Examples:
– “What pattern is on the person’s shirt?”  
– “What type of hat is the man wearing?”  
– “What material is the table made of?”

------------------------------------------------------------
L3 – Relational / Spatial Reasoning
------------------------------------------------------------
• Requires image + spatial relations + relational understanding.
Examples:
– “Where is the dog positioned relative to the child?”  
– “What object is behind the bicycle?”  
– “Which person is closest to the camera?”

------------------------------------------------------------
L4 – High-Level Visual Reasoning
------------------------------------------------------------
• Hardest level; requires the entire scene.
• Ask about interactions, goals, implied roles, or multi-object context.
• Still must be answerable from the image alone (no external inference).
Examples:
– “What activity are the people engaged in?”  
– “Why is the man extending his arm?”  
– “What is the group collectively doing?”

============================================================
Return EXACTLY this JSON structure:
{
  "correct_caption": "<string>",
  "L0": "<string>",
  "L1": "<string>",
  "L2": "<string>",
  "L3": "<string>",
  "L4": "<string>"
}
============================================================


"""
    response = openai_client.responses.create(
        model=GPT_MODEL,
        max_output_tokens=MAX_OUTPUT,
        input=[
            {
                "role": "user",
                "content": [
                    {"type": "input_text", "text": prompt},
                    {
                        "type": "input_image",
                        "image_url": f"data:image/jpeg;base64,{base64_image}"
                    }
                ]
            }
        ]
    )

    # Chat models frequently wrap JSON in a Markdown code fence or add a short
    # preamble; keep only the object itself before parsing.
    parsed = json.loads(_extract_json_object(response.output_text))

    if not isinstance(parsed, dict):
        raise ValueError("Question generator did not return a JSON object")

    # Surface an incomplete reply (e.g. one cut off by the token limit) with a
    # clear message instead of a KeyError at the call site.
    missing = [key for key in _QUESTION_KEYS if key not in parsed]
    if missing:
        raise ValueError(f"Question generator reply is missing keys: {missing}")

    return parsed

In [None]:
import torch
import math

# NaN note: attention maps from a float16 model must not be processed in half
# precision. There the 1e-9 guard rounds to 0, so zero probabilities survive
# and 0 * log(0) evaluates to NaN. All arithmetic below runs in float32.

def compute_attention_entropy(attentions):
    """
    Mean Shannon entropy (in nats) of the final query row of each attention map.

    ``attentions`` may be a flat sequence of tensors or a sequence of
    tuples/lists of tensors; anything that is not a 4-D tensor is ignored.
    For every 4-D tensor the last query row is averaged over the first two
    dimensions (batch and heads in the usual layout), renormalised to sum to
    one, and its entropy ``-sum(p * log p)`` is computed with the convention
    ``0 * log 0 = 0``. The entropy is not divided by ``log(key_len)``.

    Returns:
        float: entropy averaged over all usable tensors, or
        None: if ``attentions`` is None or no tensor produced a finite value.
    """
    if attentions is None:
        return None

    # Flatten a tuple-of-tuples (one inner tuple per generation step) into a
    # single list of tensors.
    flat_attns = []
    for layer in attentions:
        if isinstance(layer, torch.Tensor):
            flat_attns.append(layer)
        elif isinstance(layer, (tuple, list)):
            flat_attns.extend(x for x in layer if isinstance(x, torch.Tensor))

    entropies = []
    for layer_attn in flat_attns:
        if layer_attn.ndim != 4:
            continue

        # Slice the last query row first (cheap), then upcast and average over
        # batch + heads -> [key_len].
        final_attn = layer_attn[:, :, -1, :].float().mean(dim=(0, 1))

        total = final_attn.sum()
        if not torch.isfinite(total) or total.item() <= 0:
            continue

        p = final_attn / total
        # xlogy(p, p) is p * log(p) with the value 0 where p == 0.
        entropy = -torch.xlogy(p, p).sum().item()
        if math.isfinite(entropy):
            entropies.append(entropy)

    if not entropies:
        return None

    # average across layers (and generation steps)
    return sum(entropies) / len(entropies)


In [None]:
def extract_final_rows(attentions):
    """
    Collect the final query row of every 4-D attention tensor.

    ``attentions`` may be a flat sequence of tensors or a sequence of
    tuples/lists of tensors. Each 4-D tensor contributes one 1-D row: its last
    query row averaged over the first two dimensions (batch and heads in the
    usual layout). Rows are returned as float32 so that later arithmetic is
    not carried out in half precision.

    Returns:
        list[torch.Tensor]: one ``[key_len]`` row per usable tensor, in order.
    """
    flat = []
    for a in attentions:
        if isinstance(a, torch.Tensor):
            flat.append(a)
        elif isinstance(a, (tuple, list)):
            flat.extend(b for b in a if isinstance(b, torch.Tensor))

    rows = []
    for layer_attn in flat:
        if layer_attn.ndim == 4:
            # Slice the last query row before upcasting to keep the copy small.
            rows.append(layer_attn[:, :, -1, :].float().mean(dim=(0, 1)))

    return rows


def compute_attention_shift(prev_attn, curr_attn):
    """
    Mean L1 distance between the final-row attention of two model calls.

    Rows from ``prev_attn`` and ``curr_attn`` are paired by position (extra
    rows on the longer side are ignored). Each row is normalised to sum to
    one, both rows are then cut to the shorter key length, and the absolute
    differences are summed. Because truncation happens after normalisation,
    the longer row may sum to less than one over the compared span.

    Returns:
        float: the L1 distance averaged over all row pairs, or
        None: if either input is None, yields no rows, or no pair produced a
        finite distance.
    """
    if prev_attn is None or curr_attn is None:
        return None

    prev_rows = extract_final_rows(prev_attn)
    curr_rows = extract_final_rows(curr_attn)

    if len(prev_rows) == 0 or len(curr_rows) == 0:
        return None

    shifts = []

    for p, c in zip(prev_rows, curr_rows):
        # Normalise; the rows are float32, so the epsilon is not rounded away.
        p = p / (p.sum() + 1e-9)
        c = c / (c.sum() + 1e-9)

        # Align lengths: the two calls generally have different key lengths.
        common_len = min(p.shape[0], c.shape[0])
        shift = (p[:common_len] - c[:common_len]).abs().sum().item()

        # Skip NaN/inf so a corrupt map cannot poison the average.
        if math.isfinite(shift):
            shifts.append(shift)

    if not shifts:
        return None

    return sum(shifts) / len(shifts)


In [None]:
# RAJ 

def compute_llava_mdi(attentions, inputs, image_token_id=32000, vision_token_count=576):
    """
    Compute MDI as the share of final-row attention that falls on image positions.

    NOTE: a later cell in this notebook defines ``compute_llava_mdi`` again, so
    when the cells run top to bottom that later definition replaces this one.

    The first occurrence of ``image_token_id`` in ``inputs.input_ids[0]`` is
    taken as the start of the image span, and the span is assumed to cover
    ``vision_token_count`` consecutive positions of the attention key axis
    (i.e. one placeholder token expanded inside the model — TODO confirm this
    matches the installed processor/model version).

    Args:
        attentions: nested sequence iterated as (outer item -> inner tensors);
            each tensor is reduced with ``mean(dim=(0, 1))`` and its last row
            is used.
        inputs: object exposing ``input_ids`` whose first row holds token ids.
        image_token_id: id of the image placeholder token.
        vision_token_count: number of key positions attributed to the image.

    Returns:
        float: ``avg_vis / (avg_vis + avg_text + 1e-9)``, or
        None: if ``attentions`` is empty, no image token is found, or no
        scores were collected.
    """
    if not attentions or len(attentions) == 0:
        return None
    
    # 1. Find where the <image> token is in the input_ids
    input_ids = inputs.input_ids[0]  # [seq_len]
    image_indices = torch.where(input_ids == image_token_id)[0]
    
    if len(image_indices) == 0:
        return None
    
    # Position of the first <image> token in input_ids
    img_token_pos = image_indices[0].item()
    
    # 2. Map that position onto the attention key axis.
    # In input_ids: [token_0, token_1, ..., token_img_token_pos (=<image>), ..., token_n]
    # In attention:  [token_0, token_1, ..., [vision_token_count embeddings], ..., token_n]
    vis_start = img_token_pos
    vis_end = img_token_pos + vision_token_count

    visual_scores = []
    textual_scores = []
    
    # Iterate over the outer tuple (presumably generation steps) and the inner
    # tuple (presumably layers) — verify against the generate() output format.
    for token_step_attentions in attentions:
        for layer_attention in token_step_attentions:
            
            # assumes layer_attention is [batch, heads, query_len, key_len]
            avg_attention = layer_attention.mean(dim=(0, 1)) 
            
            # Last query row: attention from the most recent position to all keys
            final_attn_row = avg_attention[-1, :]
            
            total_len = final_attn_row.shape[0]
            
            if vis_end <= total_len:
                # Attention mass on the image span
                visual_val = final_attn_row[vis_start:vis_end].sum().item()
                
                # Attention mass on everything before and after the image span
                text_before = final_attn_row[:vis_start].sum().item()
                text_after = final_attn_row[vis_end:].sum().item()
                
                textual_val = text_before + text_after
            else:
                # Image span does not fit in this row: count all mass as text.
                # This silently yields a visual share of 0 for the row.
                visual_val = 0.0
                textual_val = final_attn_row.sum().item()

            visual_scores.append(visual_val)
            textual_scores.append(textual_val)
    
    if not visual_scores:
        return None
    
    # Average over every (outer, inner) attention tensor visited above
    avg_vis = sum(visual_scores) / len(visual_scores)
    avg_text = sum(textual_scores) / len(textual_scores)
    
    # Visual share of total attention; the epsilon avoids division by zero
    mdi = avg_vis / (avg_vis + avg_text + 1e-9)
    return mdi

In [None]:
# NEYA - V1

def compute_llava_mdi(attentions, inputs, image_token_id=32000, vision_token_count=None):
    """
    MDI for a single-image LLaVA call.
    MDI = (attention_on_vision_tokens) / (attention_on_all_tokens)

    For every attention tensor the last query row of batch item 0 is averaged
    over heads; the mass inside the image span counts as visual, the rest as
    textual. Scores are averaged over all tensors before taking the ratio.

    The image span starts at the first ``image_token_id`` in
    ``inputs.input_ids[0]`` and ends after the last one. If the model expands
    the placeholder internally (fewer ids than attention keys), the span is
    widened by the difference so that it covers the expanded embeddings.

    Args:
        attentions: sequence (per generation step) of sequences (per layer) of
            tensors indexed as (batch, heads, query_len, key_len).
        inputs: object exposing ``input_ids``; row 0 holds the prompt token ids.
        image_token_id: id of the image placeholder token.
        vision_token_count: optional total number of key positions the image
            occupies. When None it is inferred from the first attention
            tensor, which is assumed to come from the prompt (prefill) step.

    Returns:
        float in [0, 1], or None when there are no attentions or no image
        token in the prompt.
    """
    if not attentions:
        return None

    # 1. Locate <image> token block (start/end)
    input_ids = inputs.input_ids[0]
    img_positions = torch.where(input_ids == image_token_id)[0]
    if len(img_positions) == 0:
        return None

    img_start = img_positions[0].item()
    img_end = img_positions[-1].item() + 1   # non-inclusive

    first_attn = next((a for step in attentions for a in step), None)
    if first_attn is None:
        return None

    # 2. Account for placeholder expansion that happens inside the model.
    if vision_token_count is not None:
        hidden_extra = vision_token_count - (img_end - img_start)
    else:
        # Prefill key length minus prompt length = embeddings added in-model.
        hidden_extra = first_attn.shape[-1] - input_ids.shape[0]
    img_end += max(hidden_extra, 0)

    vision_scores = []
    text_scores = []

    # 3. Iterate over generation steps and layers
    for step_attn in attentions:
        for attn in step_attn:
            # Use the LAST query row: the prefill tensor has one row per prompt
            # token, and row 0 is the first prompt token, not the position
            # that predicts the next token. -> (key_len,)
            row = attn[0, :, -1, :].float().mean(dim=0)

            vis = row[img_start:img_end].sum().item()
            vision_scores.append(vis)
            text_scores.append(row.sum().item() - vis)

    if not vision_scores:
        return None

    vis_avg = sum(vision_scores) / len(vision_scores)
    text_avg = sum(text_scores) / len(text_scores)

    mdi = vis_avg / (vis_avg + text_avg + 1e-9)
    return mdi

In [None]:
def ask_llava(
    image_path,
    caption,
    question,
    history=None,
    max_new_tokens=50,
    return_metrics=True,
    last_turn_only=False
):
    """
    Run LLaVA 1.5 on an image with a (caption + question) text prompt.

    The prompt is built as zero or more earlier ``USER:``/``ASSISTANT:`` turns
    taken from ``history``, followed by a final ``USER:`` turn that carries the
    ``<image>`` placeholder, the caption as context, and the question.
    Decoding is greedy (``do_sample=False``).

    Relies on the module-level ``processor`` and ``model`` objects and on
    ``PIL.Image`` being imported before the first call.

    Args:
        image_path: path of the image file to load.
        caption: text inserted after "Context:".
        question: text inserted after "Question:".
        history: optional list of ``(question, answer)`` pairs from earlier
            turns; None means no earlier turns.
        max_new_tokens: generation length limit.
        return_metrics: when True, also return MDI and raw attentions.
        last_turn_only: accepted for call-site compatibility; not used here
            (the caller decides what to put into ``history``).

    Returns:
        ``(answer, mdi, attentions)`` when ``return_metrics`` is True, where
        ``mdi`` is 0.0 if it could not be computed; otherwise just ``answer``.
    """
    if history is None:
        history = []

    # ---- 1. Load image ----
    # The context manager closes the file handle; convert() returns an
    # independent in-memory copy that stays valid afterwards.
    with Image.open(image_path) as img_file:
        image = img_file.convert("RGB")

    # ---- 2. Build LLaVA-format prompt (multi-turn) ----
    prompt_parts = []
    for q_prev, a_prev in history:
        prompt_parts.append(f"USER: {q_prev}\n")
        prompt_parts.append(f"ASSISTANT: {a_prev}\n")

    prompt_parts.append("USER: <image>\n")
    prompt_parts.append(f"Context: {caption}\n")
    prompt_parts.append(f"Question: {question}\n")
    prompt_parts.append("ASSISTANT:")

    text_prompt = "".join(prompt_parts)

    # ---- 3. Preprocess ----
    inputs = processor(
        text=text_prompt,
        images=image,
        return_tensors="pt").to(model.device)

    # ---- 4. Generate with attention ----
    # No `temperature` argument: it has no effect under greedy decoding and
    # generation-config validation flags the do_sample=False + temperature
    # combination.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,          # deterministic
            output_attentions=True,   # needed for the attention metrics
            return_dict_in_generate=True,
            output_hidden_states=False
        )

    # ---- 5. Decode the answer ----
    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
    answer = processor.decode(generated_ids, skip_special_tokens=True).strip()

    # Clean prefix in case the model echoed the role marker.
    if "ASSISTANT:" in answer:
        answer = answer.split("ASSISTANT:")[-1].strip()

    # ---- 6. Compute metrics ----
    if return_metrics:
        # Pass the entire attentions tuple plus the inputs (for token ids).
        final_mdi = compute_llava_mdi(outputs.attentions, inputs)
        if final_mdi is None:
            final_mdi = 0.0

        # Attentions are returned so the caller can derive entropy and shift.
        return answer, final_mdi, outputs.attentions

    return answer

In [None]:
###### FOR PACE ICE ONLY - replace GT username below ######

# Point the Hugging Face caches at the scratch directory instead of the home
# directory. These variables are set here, ahead of the cell that imports
# transformers, presumably so the library picks them up when it loads.
# NOTE(review): TRANSFORMERS_CACHE is believed to be deprecated in favour of
# HF_HOME in recent transformers releases — confirm against the installed version.
os.environ["HF_HOME"] = "/home/hice1/nbalakrishna3/scratch/huggingface"
os.environ["HF_DATASETS_CACHE"] = "/home/hice1/nbalakrishna3/scratch/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/home/hice1/nbalakrishna3/scratch/hf_cache"

In [None]:
from transformers import AutoProcessor, LlavaForConditionalGeneration
import torch
from PIL import Image

# `device` is reported for information; actual placement of the weights is
# decided by device_map="auto" in the from_pretrained call.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("✅ Using device:", device)

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

print("Loading LLaVA model...")
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Do not chain .to(device) here: with device_map="auto" the weights are already
# dispatched, and moving a dispatched model afterwards is at best redundant and
# fails when some modules were offloaded.
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="eager",   # eager attention so attention weights can be returned
    device_map="auto"
)

print("✅ LLaVA Model loaded!")

In [None]:
# Debug check: show which attention class the first decoder layer uses
# (the model was requested with attn_implementation="eager").
print(type(model.language_model.model.layers[0].self_attn))

In [None]:
def create_eval_prompt(caption, condition, question, model_answer):
    """
    Build the instruction text sent to the judge model.

    The prompt tells the judge to reply with a single character, "1" when
    ``model_answer`` matches what is visible in the image and "0" otherwise,
    and to treat ``caption`` as context only.

    Args:
        caption: caption that was given to the evaluated model.
        condition: free-text label of the experimental condition.
        question: question the evaluated model was asked.
        model_answer: answer produced by the evaluated model.

    Returns:
        str: the complete prompt, with the four values interpolated verbatim.
    """
    return f"""
You are an automated scoring module. You MUST output ONLY a single character: "0" or "1".
NEVER output words, sentences, explanations, punctuation, or reasoning.

YOUR TASK:
Judge whether the MODEL_ANSWER correctly matches what is visually true in the IMAGE.

IMPORTANT:
• The CAPTION was only input context for another model. Do NOT treat it as truth.
• Ignore any misleading or incorrect information in the CAPTION.
• Use ONLY the IMAGE + the QUESTION + world knowledge to judge correctness.
• If the MODEL_ANSWER matches the visible truth in the image, output "1".
• If the MODEL_ANSWER is wrong, unsupported, or contradicts the image, output "0".
• Your ENTIRE reply MUST be exactly one character: "0" or "1".

Now evaluate:

CAPTION (context only): {caption}
CONDITION: {condition}
QUESTION: {question}
MODEL_ANSWER: {model_answer}

Reply with ONLY "0" or "1".
"""

In [None]:
# Media types passed to the judge API, keyed by lower-cased file extension.
_JUDGE_MEDIA_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
}


def eval_answer(image_path, caption, condition, question, model_answer):
    """
    Ask the Claude judge whether ``model_answer`` is correct for the image.

    Builds the judge prompt with ``create_eval_prompt``, base64-encodes the
    image file, sends both through ``anthropic_client.messages.create`` and
    parses the reply.

    Args:
        image_path: path of the image shown to the judge. The declared media
            type follows the file extension (falls back to JPEG if unknown).
        caption, condition, question, model_answer: forwarded to
            ``create_eval_prompt``.

    Returns:
        int: 1 if the judge accepted the answer, 0 otherwise.

    Raises:
        ValueError: if the judge reply is anything other than "0" or "1"
            (optionally wrapped in quotes or followed by a period).
    """
    # ---- Build prompt ----
    prompt = create_eval_prompt(caption, condition, question, model_answer)

    # ---- Encode image ----
    with open(image_path, "rb") as f:
        img_bytes = f.read()
    b64img = base64.b64encode(img_bytes).decode("utf-8")

    # Declare the real format: the image folder scan also accepts .png files,
    # and a media type that does not match the bytes is rejected by the API.
    extension = os.path.splitext(image_path)[1].lower()
    media_type = _JUDGE_MEDIA_TYPES.get(extension, "image/jpeg")

    # ---- Call Claude ----
    response = anthropic_client.messages.create(
        model="claude-sonnet-4-5-20250929",
        max_tokens=5,
        temperature=0,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": b64img
                        }
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ]
    )

    # ---- Parse output ----
    output = response.content[0].text.strip()
    # Tolerate harmless decoration such as "1" in quotes or a trailing period,
    # so one cosmetic deviation does not discard the whole image record.
    verdict = output.strip("\"'`.").strip()

    if verdict not in ("0", "1"):
        raise ValueError(f"Unexpected Claude judge output: {output}")

    return int(verdict)

In [None]:
import random

from concurrent.futures import ThreadPoolExecutor, as_completed

def generate_llava_outputs(subset_size=None, last_turn_only=False):
    """
    Run the full pipeline over the images in ``IMAGE_FOLDER`` and write JSONL.

    For each image: generate a caption and the L0-L4 questions with the OpenAI
    model, ask LLaVA each question in order (feeding earlier turns back in as
    conversation history), compute MDI / entropy / attention-shift per level,
    have the Claude judge score every answer, and append one JSON object to
    ``LLAVA_OUTPUT_PATH`` (the file is overwritten at the start of the run).
    An image whose processing raises is reported and skipped.

    Args:
        subset_size: number of images to sample at random; None processes all.
            Values larger than the number of available images are capped.
        last_turn_only: when True, only the immediately preceding
            question/answer pair is given to LLaVA as history; when False,
            all earlier pairs for the image are.
    """
    levels = ("L0", "L1", "L2", "L3", "L4")

    # Sorted so that the candidate list does not depend on directory order.
    all_image_files = sorted(
        f for f in os.listdir(IMAGE_FOLDER)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    )

    if subset_size is not None:
        # random.sample raises if asked for more items than exist.
        sample_size = min(subset_size, len(all_image_files))
        image_files = random.sample(all_image_files, sample_size)
    else:
        image_files = all_image_files

    print(f"Found {len(image_files)} images.\n")

    with open(LLAVA_OUTPUT_PATH, "w", encoding="utf-8") as out:
        for img_file in tqdm(image_files, desc="Processing"):
            image_id = os.path.splitext(img_file)[0]
            path = os.path.join(IMAGE_FOLDER, img_file)

            try:
                # ---- 1) GPT caption + questions ----
                generated = generate_questions(encode_image(path))
                correct_caption = generated["correct_caption"]
                questions = {lvl: generated[lvl] for lvl in levels}

                # ---- 2) LLaVA answers + attention metrics ----
                answers = {}
                metrics = {}
                history = []
                prev_attn = None

                for lvl in levels:
                    question = questions[lvl]

                    # Pass a copy of the earlier turns so the conversation is
                    # actually multi-turn and the callee cannot mutate ours.
                    ans, mdi, attn = ask_llava(
                        path,
                        correct_caption,
                        question,
                        history=list(history),
                        return_metrics=True,
                        last_turn_only=last_turn_only,
                    )

                    if last_turn_only:
                        history = [(question, ans)]
                    else:
                        history.append((question, ans))

                    answers[lvl] = ans

                    ent = compute_attention_entropy(attn)
                    # The first level has no previous turn to compare against.
                    shift = (
                        compute_attention_shift(prev_attn, attn)
                        if prev_attn is not None else None
                    )

                    metrics[lvl] = {
                        "mdi": round(mdi, 3),
                        "entropy": round(ent, 3) if ent is not None else None,
                        "shift": round(shift, 3) if shift is not None else None,
                    }

                    prev_attn = attn

                # Drop the references to the large attention tensors before the
                # (slow) judge calls so their memory can be reclaimed.
                prev_attn = attn = None

                # ---- 3) Parallel Claude evaluation ----
                with ThreadPoolExecutor(max_workers=8) as ex:
                    futures = {
                        lvl: ex.submit(
                            eval_answer,
                            path,
                            correct_caption,
                            "correct caption condition",
                            questions[lvl],
                            answers[lvl],
                        )
                        for lvl in levels
                    }
                    # result() re-raises any judge failure, which sends the
                    # whole image to the except branch below.
                    eval_scores = {lvl: fut.result() for lvl, fut in futures.items()}

                # ---- 4) Assemble and write one JSON line ----
                output = {
                    "image_id": image_id,
                    "caption": correct_caption,
                    "questions": questions,
                    "answers": answers,
                    "metrics": metrics,
                    "eval_scores": eval_scores,
                }
                out.write(json.dumps(output, ensure_ascii=False) + "\n")
                # Flush per record so finished work survives a later crash.
                out.flush()

            except Exception as e:
                # Best effort: report the failing image and continue.
                print(f"\nError with {image_id}: {e}")

    print(f"\nDone. JSONL saved to: {LLAVA_OUTPUT_PATH}\n")

In [None]:
if __name__ == "__main__":
    
    ######## LLAVA ########
    
    # Small run: sample 3 images, keep the full question/answer history
    # bookkeeping (last_turn_only=False), and write LLAVA_OUTPUT_PATH.
    generate_llava_outputs(subset_size=3, last_turn_only=False) 