In [20]:
import torch
import json
import argparse
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image
import sys
sys.path.append('/home/hyang/llava_paso/LLaVA')

from llava.mm_utils import get_model_name_from_path
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

In [14]:
def load_hades(mm_use_im_start_end=False, model_name="vicuna-7b-v1.5"):
    hades = load_dataset("Monosail/HADES")['test']
    hades = [item for item in hades if item['step'] == 5]
    malicious = []
    for item in hades:
        image = item['image'].convert('RGB')
        inst = item['instruction']

        # if item["category"] == "behavior":
        #     mask = "conduct this behavior on"
        # else:
        #     mask = f"this {item['category']}"
        # new_inst = f"The bottom of the image shows a {item['category']}. " + inst.replace(item["keywords"], mask)
        qs = inst
        if mm_use_im_start_end:
            qs = (
                DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_TOKEN
                + DEFAULT_IM_END_TOKEN
                + "\n"
                + qs
            )
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        else:
            conv_mode = "llava_v0"

        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        malicious.append({'id':item['id'], 'image':image,'prompt': prompt})
    return malicious

In [27]:
len(load_hades())

750

In [28]:
def load_normal(path, mm_use_im_start_end=False,model_name="vicuna-7b-v1.5"):
    with open(path, "r") as file:
        content = file.read()
    normal_dataset = json.loads(content)
    normal = []
    for item in tqdm(normal_dataset[:750]):
        image = Image.open("/home/hyang/llava_paso/coco/" + item['image']).convert('RGB').resize((1024, 1324))
        for conv_item in item['conversations']:
            if conv_item['from'] == 'human':
                qs = conv_item['value']
                if mm_use_im_start_end:
                    qs = (
                        DEFAULT_IM_START_TOKEN
                        + DEFAULT_IMAGE_TOKEN
                        + DEFAULT_IM_END_TOKEN
                        + "\n"
                        + qs
                    )
                else:
                    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

                if "llama-2" in model_name.lower():
                    conv_mode = "llava_llama_2"
                elif "v1" in model_name.lower():
                    conv_mode = "llava_v1"
                elif "mpt" in model_name.lower():
                    conv_mode = "mpt"
                else:
                    conv_mode = "llava_v0"

                conv = conv_templates[conv_mode].copy()
                conv.append_message(conv.roles[0], qs)
                conv.append_message(conv.roles[1], None)
                prompt = conv.get_prompt()
                normal.append({'id':item['id'], 'image':image,'prompt': prompt})
    return normal


In [29]:
load_normal(path="/home/hyang/llava_paso/coco/conversation_58k.json")[0]

100%|██████████| 750/750 [00:12<00:00, 59.15it/s]


{'id': '000000033471',
 'image': <PIL.Image.Image image mode=RGB size=1024x1324>,
 'prompt': "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat are the colors of the bus in the image?\n<image> ASSISTANT:"}

In [None]:
tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path="liuhaotian/llava-v1.5-7b", model_base=None, model_name="llava-v1.5-7b",cache_dir = "/home/hyang/llava_paso/.cache"
    )
model_name = get_model_name_from_path("liuhaotian/llava-v1.5-7b")

In [6]:
qs = new_inst
if model.config.mm_use_im_start_end:
    qs = (
        DEFAULT_IM_START_TOKEN
        + DEFAULT_IMAGE_TOKEN
        + DEFAULT_IM_END_TOKEN
        + "\n"
        + qs
    )
else:
    qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
elif "v1" in model_name.lower():
    conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
    conv_mode = "mpt"
else:
    conv_mode = "llava_v0"

In [7]:
conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

image_tensor = (
    image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
    .half()
    .cuda()
)

input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .cuda()
)

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

In [10]:
outputs = model(
            input_ids=input_ids,
            images=image_tensor,  # Pass image tensor to the model
            output_hidden_states=True
        )

In [12]:
tl_pair = []
lm_head = model.lm_head
decoding=True
k_indices=5
# Get normalization layer
if hasattr(model, "model") and hasattr(model.model, "norm"):
    norm = model.model.norm
elif hasattr(model, "transformer") and hasattr(model.transformer, "ln_f"):
    norm = model.transformer.ln_f
else:
    raise ValueError("Incorrect Model: Missing norm layer.")

# Process hidden states layer by layer
for i, hidden_state in enumerate(outputs.hidden_states):
    layer_logits = []
    layer_output = norm(hidden_state)
    logits = lm_head(layer_output)

    # Get logits for the next token
    next_token_logits = logits[:, -1, :]
    top_values, top_indices = torch.topk(next_token_logits, k_indices, dim=-1)

    # Decode or return raw token indices
    decoded_texts = [tokenizer.decode([idx], skip_special_tokens=False) for idx in top_indices.squeeze().tolist()]
    top_values = top_values.detach().cpu()
    if decoding:
        for value, token in zip(top_values.squeeze().tolist(), decoded_texts):
            layer_logits.append([token, value])
    else:
        for value, token in zip(top_values.squeeze().tolist(), top_indices.squeeze().tolist()):
            layer_logits.append([token, value])
    tl_pair.append(layer_logits)

# Extract hidden states for analysis
res_hidden_states = [hidden_state.detach().cpu().numpy() for hidden_state in outputs.hidden_states]

In [None]:
tl_pair

In [9]:
class Weak2StrongExplanation_VLM:
    def __init__(self, model, tokenizer, layer_nums=32, return_report=True, return_visual=True):
        self.model = model
        self.tokenizer = tokenizer
        self.layer_sums = layer_nums + 1
        self.forward_info = {}
        self.return_report = return_report
        self.return_visual = return_visual