In [1]:
from IPython.core.interactiveshell import InteractiveShell
# Echo the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

import argparse
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from PIL import Image
import torch.nn.functional as F

import pandas as pd
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
import numpy as np
import sys

# Default value of the `--image-file` command-line argument.
# NOTE(review): absolute, machine-specific path — consider making it relative
# to a configurable data directory so the notebook runs elsewhere.
IMAGE_PATH = "/code/yingqi/LLaVA_visiPruner/v1_68.png"

def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait for the HTTP request before giving up.
            Only used for URLs.

    Returns:
        The image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the download exceeds ``timeout``.
    """
    if image_file.startswith(('http://', 'https://')):
        # Without a timeout a stalled server would hang the kernel forever.
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # `convert` returns a new, fully loaded image, so the lazily opened
    # original (and its file handle) can be closed right away.
    with Image.open(source) as opened:
        return opened.convert('RGB')

# When run inside a Jupyter kernel, sys.argv carries the kernel's own launch
# arguments; drop them so argparse falls back to the defaults below.
if "ipykernel_launcher" in sys.argv[0]:
    sys.argv = sys.argv[:1]


def _str2bool(value):
    """Parse a command-line boolean such as "true", "False", "1" or "no".

    `type=bool` would treat every non-empty string (including "False") as
    True, because argparse simply calls `bool("False")`.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y"}:
        return True
    if lowered in {"", "0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# Model and input selection.
parser.add_argument("--model-path", type=str, default="/code/yingqi/models/liuhaotian/llava-v1.5-7b")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, default=IMAGE_PATH)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
parser.add_argument("--answer-with-sentence", type=_str2bool, default=False)
parser.add_argument("--num-chunks", type=int, default=1)
parser.add_argument("--chunk-idx", type=int, default=0)
# Generation settings.
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=None)
parser.add_argument("--num_beams", type=int, default=1)
parser.add_argument("--max_new_tokens", type=int, default=128)
# Quantisation and debugging switches.
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

# NOTE(review): presumably skips default weight initialisation to speed up
# loading — confirm against llava.utils.
disable_torch_init()

# Load the tokenizer, model and image processor for the configured checkpoint.
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    args.model_base,
    model_name,
    args.load_8bit,
    args.load_4bit,
    device=args.device,
)

# Infer the conversation template from the checkpoint name. The pairs are
# checked in order and the first matching substring wins, so the more
# specific "v1.6-34b" must stay ahead of the generic "v1".
conv_mode = next(
    (
        mode
        for needle, mode in (
            ("llama-2", "llava_llama_2"),
            ("mistral", "mistral_instruct"),
            ("v1.6-34b", "chatml_direct"),
            ("v1", "llava_v1"),
            ("mpt", "mpt"),
        )
        if needle in model_name.lower()
    ),
    "llava_v0",  # fallback when no substring matches
)

  from .autonotebook import tqdm as notebook_tqdm
You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 2/2 [02:47<00:00, 83.82s/it] 


In [100]:
# An explicitly requested template wins over the inferred one; only warn
# about the mismatch. With no explicit request, adopt the inferred template.
if args.conv_mode is None or conv_mode == args.conv_mode:
    args.conv_mode = conv_mode
else:
    print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))

# Work on a copy so the shared template is never mutated.
conv = conv_templates[args.conv_mode].copy()
roles = ('user', 'assistant') if "mpt" in model_name.lower() else conv.roles

# NOTE(review): this path is hard-coded and ignores `args.image_file` —
# confirm which image the experiment is meant to use.
image = load_image("/code/yingqi/LLaVA_visiPruner/images/v1_73.jpg")
image_size = image.size
image_tensor = process_images([image], image_processor, model.config)
# The preprocessing result may be a single tensor or a list of tensors;
# either way move it to the model's device as float16.
image_tensor = (
    [tensor.to(model.device, dtype=torch.float16) for tensor in image_tensor]
    if type(image_tensor) is list
    else image_tensor.to(model.device, dtype=torch.float16)
)

inp = "How many cars?\nAnswer using a single word or a phrase."

# Print the assistant's role label; the streamed answer follows on the same line.
print(f"{roles[1]}: ", end="")

if image is not None:
    # First message: prepend the image placeholder, wrapped in start/end
    # markers when the model config asks for them.
    inp = (
        (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            if model.config.mm_use_im_start_end
            else DEFAULT_IMAGE_TOKEN
        )
        + '\n'
        + inp
    )
    image = None

# User turn, followed by an empty assistant turn for the model to complete.
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Tokenise the prompt, add a batch dimension and move it to the model's device.
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
# Stream decoded tokens to stdout as they are generated, without the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

with torch.inference_mode():
    output = model.generate(
        input_ids,
        # `image_tensor` was already moved to model.device as float16 when it
        # was built, and may be a list of tensors. Calling `.to(...)` on it
        # again crashed for the list case and forced 'cuda' regardless of
        # where the model actually lives.
        images=image_tensor,
        image_sizes=[image_size],
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        top_p=args.top_p,
        num_beams=args.num_beams,
        # Honour the `--max_new_tokens` argument (default 128).
        max_new_tokens=args.max_new_tokens,
        streamer=streamer,
        use_cache=True,
        # Return attentions, hidden states and the KV cache for later analysis.
        output_attentions=True,
        output_hidden_states=True,
        return_dict_in_generate=True
    )
        

ASSISTANT: 4




In [101]:
# For every layer: push the cached value vectors of the last position through
# that layer's output projection one head at a time, apply the final norm and
# the LM head, and print the top-k tokens for the selected head.

# Read the head geometry from the model config instead of hard-coding 32 / 128,
# so the cell also works for other model sizes.
# NOTE(review): assumes as many key/value heads as attention heads (no
# grouped-query attention) — confirm for other checkpoints.
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // num_heads
head_idx = -1  # which head to inspect; -1 selects the last one
top_k = 5

# NOTE(review): assumes each cache entry is a (key, value) pair whose tensors
# are laid out as (batch, heads, seq, head_dim) — TODO confirm for the
# installed transformers version.
values = [layer[1] for layer in output['past_key_values']]
values_states = [layer[0] for layer in values]  # first batch element only
decoded = []

# Pure analysis: no autograd graph is needed.
with torch.inference_mode():
    for layer_idx, value_states_layer in enumerate(values_states):
        # Split the transposed projection weight along its input dimension
        # into one (head_dim, out_features) slab per head.
        o_proj_weight = model.model.layers[layer_idx].self_attn.o_proj.weight
        output_weights = o_proj_weight.data.T.view(num_heads, head_dim, -1)
        # Batched matmul over heads: one projected sequence per head.
        value_output = torch.bmm(value_states_layer, output_weights)
        # Keep only the last sequence position before the final norm.
        value_output = model.model.norm(value_output[:, -1, :])

        logits = torch.matmul(value_output, model.lm_head.weight.detach().T)

        next_token_logits = logits[head_idx, :]
        top_token_ids = torch.topk(next_token_logits, top_k, dim=-1).indices

        all_decoded = [tokenizer.decode([token_id]) for token_id in top_token_ids.tolist()]
        print(f"Layer {layer_idx} : {all_decoded} ")
        decoded.append(all_decoded)


Layer 0 : ['apt', 'sono', 'anha', 'summ', 'Generated'] 
Layer 1 : ['uba', 'beskre', '❯', 'rör', 'onymes'] 
Layer 2 : ['cgi', 'CLI', 'Hal', 'neg', 'hal'] 
Layer 3 : ['aturen', 'hein', 'aret', 'togg', 'bol'] 
Layer 4 : ['BS', 'nt', 'ССР', 'icode', 'Draw'] 
Layer 5 : ['cam', 'Loren', 'libs', 'Mens', 'FI'] 
Layer 6 : ['kel', 'proof', 'arrow', 'Silva', 'touch'] 
Layer 7 : ['Sdk', 'Dol', 'fid', 'kes', 'hom'] 
Layer 8 : ['bbe', 'comot', 'iewer', 'anja', 'pel'] 
Layer 9 : ['bum', 'inale', 'Govern', 'СП', 'ki'] 
Layer 10 : ['Ir', 'ivan', 'ь', 'surr', 'audi'] 
Layer 11 : ['stract', 'HT', 'цу', 'ali', 'avant'] 
Layer 12 : ['eredet', 'Elis', 'bon', 'cade', 'ente'] 
Layer 13 : ['ewnętrz', 'opere', '起', 'esterni', 'Attributes'] 
Layer 14 : ['ahren', 'ghan', 'SOUR', 'bě', 'ăr'] 
Layer 15 : ['wood', 'Jenkins', 'bos', 'cord', 'bos'] 
Layer 16 : ['хі', 'стову', 'Rena', 'utsch', 'conseil'] 
Layer 17 : ['Four', '四', 'fourth', 'four', 'four'] 
Layer 18 : ['four', 'five', 'three', 'six', 'four'] 
Layer 19 :