[Question] Getting output likelihood scores from the model #108
Comments
Hi @vishaal27, thank you for the great question. Yes, it is easy to do this with LLaVA. Here is a simple example that you may start with, by modifying the `model.generate` call:

```python
generation_output = model.generate(
    input_ids,
    images=image_tensor.unsqueeze(0).half().cuda(),
    do_sample=True,
    temperature=0.2,
    max_new_tokens=1024,
    stopping_criteria=[stopping_criteria],
    # add the following two lines
    return_dict_in_generate=True,
    output_scores=True,
)

input_token_len = input_ids.shape[1]
output_ids = generation_output.sequences[0, input_token_len:]
output_scores = generation_output.scores
```
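For reference, here is a minimal sketch (assuming the `generation_output` and `input_token_len` from the snippet above and the Hugging Face `transformers` return format) of turning the returned `scores` into per-token log-probabilities of the sampled continuation. Note that under sampling the scores are the *processed* logits (temperature-scaled, with filtered tokens set to `-inf`):

```python
import torch

# generation_output.scores is a tuple with one entry per generated token,
# each of shape (batch_size, vocab_size).
token_logprobs = []
for step, step_scores in enumerate(generation_output.scores):
    step_logprobs = torch.log_softmax(step_scores, dim=-1)             # (batch, vocab)
    token_id = generation_output.sequences[0, input_token_len + step]  # token actually emitted
    token_logprobs.append(step_logprobs[0, token_id].item())

sequence_logprob = sum(token_logprobs)  # log-likelihood of the generated continuation
```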
Thanks @haotian-liu, but as I understand it, this will return the log-likelihood of the generated output given some initial prompt, right? I don't want to generate more tokens, but rather evaluate the likelihood of a given token sequence under the model. For example, if I wanted to do ImageNet classification with this model, I would evaluate the log-likelihood of a sequence like "This is a photo of a {class}" for each class.
@vishaal27 I think this is also possible. Consider the following (pseudo) code:

```python
message = """Human: <image> what is the object in the photo?
GPT: This is a photo of a """
input_ids = tokenizer(message)
```

The first output token should then be the class name, so you can read its likelihood off the logits at that position.
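To make that concrete, here is a small, text-only sketch of reading class probabilities off the logits at that position. The `classes` list is hypothetical, and for LLaVA the forward call would also need the `images` tensor plus `tokenizer_image_token` preprocessing; this only illustrates the idea for a standard causal LM:

```python
import torch

message = ("Human: <image> what is the object in the photo?\n"
           "GPT: This is a photo of a")
classes = ["dog", "cat", "horse"]  # hypothetical label set

input_ids = tokenizer(message, return_tensors="pt").input_ids.to(model.device)
with torch.inference_mode():
    logits = model(input_ids).logits  # (1, seq_len, vocab_size)

# Distribution over the next token, i.e. the word following "This is a photo of a"
next_token_logprobs = torch.log_softmax(logits[0, -1], dim=-1)

for name in classes:
    # Assumes each class name is a single token; multi-token names
    # would need to be scored token by token.
    token_id = tokenizer(name, add_special_tokens=False).input_ids[0]
    print(name, next_token_logprobs[token_id].item())
```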
@vishaal27 Did you find a solution to your problem (I know it's an old issue)? I have a similar one: I have a set of possible options and I want to compute the log-probability of each option as the output. When using a prompt-based method, tokens are generated; in the case of a single-word output, I still get the probability of one output, not the distribution of probabilities over all my possible options (which in your case were classes, I think). How did you resolve it?
This code should work:

```python
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)
from PIL import Image
import requests
from io import BytesIO
import re
import torch
import numpy as np


def image_parser(image_file):
    out = image_file.split(',')
    return out


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def load_images(image_files):
    out = []
    for image_file in image_files:
        image = load_image(image_file)
        out.append(image)
    return out


def count_all_parameters(model):
    return sum(p.numel() for p in model.parameters())


def eval_model(model_path, image_file, query, options):
    # Model
    disable_torch_init()
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, None, model_name
    )

    qs = query
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        if model.config.mm_use_im_start_end:
            qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
        else:
            qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
    else:
        if model.config.mm_use_im_start_end:
            qs = image_token_se + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    image_files = image_parser(image_file)
    images = load_images(image_files)
    images_tensor = process_images(
        images,
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    # Score each option by appending it to the prompt and reading off the
    # (mean per-token) log-likelihood from the language-modelling loss.
    log_lik_scores = []
    for option in options:
        target_prompt = prompt + ' ' + option
        print(target_prompt)
        input_ids = (
            tokenizer_image_token(target_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )
        attention_mask = torch.ones_like(input_ids)
        with torch.inference_mode(), torch.cuda.amp.autocast():
            outputs = model.forward(
                input_ids=input_ids,
                labels=input_ids,
                attention_mask=attention_mask,
                images=images_tensor,
            )
        log_lik_scores.append((option, -outputs.loss.item()))

    pred_id = np.argmax(np.asarray([x[1] for x in log_lik_scores]))
    print(log_lik_scores)
    print('Prediction: {}'.format(log_lik_scores[pred_id]))


if __name__ == '__main__':
    model_path = "liuhaotian/llava-v1.5-13b"
    prompt = "Describe the image."
    image_file = "https://llava-vl.github.io/static/images/view.jpg"
    shared_prompt = 'This is an image of a '
    options = [shared_prompt + x for x in ['horse', 'lion', 'tiger', 'elephant', 'eagle', 'dog']]

    eval_model(model_path, image_file, prompt, options)
```
Thanks, @vishaal27, it was helpful!
@vishaal27 Though the answers are correct, I am surprised that the probabilities of all options are so close to each other. I computed the log-likelihoods and probabilities with a slightly different prompt but the same image, and I had these options:
Did you get similar scores too?
That could potentially be because your prompts are too long. One option would be to length-normalize your log-likelihood scores by the number of tokens in the prompt. In my experiments this did not make much of a difference, but if you expect your prompts to be long, or of significantly different token lengths, I would recommend using length-normalized log-likelihoods. For reference, see: https://blog.eleuther.ai/multiple-choice-normalization/
Not really. I can change the prompt and the outputs are still similar (very close to uniform):
Yes, however these look quite similar to the scores I was getting. One correction to my earlier comment: the scores are actually already length-normalised, since internally the model uses `nn.CrossEntropyLoss`, which by default has `reduction='mean'`. You could check the length-unnormalised scores by using `log_lik_scores.append((option, -outputs.loss.item() * input_ids.shape[1]))` instead of `log_lik_scores.append((option, -outputs.loss.item()))`. However, I wouldn't expect to see much of a difference; in general, the scores you are getting look similar to what I saw in my experiments.
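If you want the log-likelihood of the option tokens only (excluding the shared prompt), a hedged variant of the scoring loop in the script above is to mask the prompt positions in the labels. This sketch reuses `prompt`, `options`, `images_tensor`, `tokenizer`, and `model` from that script, and assumes LLaVA keeps text labels aligned when it expands the image token (its multimodal preprocessing inserts ignore indices for image features); tokenising the prompt alone can differ by a token at the boundary, so treat this as a sketch rather than an exact recipe:

```python
import torch

# Token length of the shared prompt (the image placeholder counts as one token here).
prompt_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
prompt_len = prompt_ids.shape[0]

log_lik_scores = []
for option in options:
    input_ids = (
        tokenizer_image_token(prompt + ' ' + option, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )
    labels = input_ids.clone()
    labels[:, :prompt_len] = -100  # ignore prompt tokens in the loss (HF/LLaVA ignore index)
    num_option_tokens = (labels != -100).sum().item()

    with torch.inference_mode(), torch.cuda.amp.autocast():
        outputs = model(
            input_ids=input_ids,
            labels=labels,
            attention_mask=torch.ones_like(input_ids),
            images=images_tensor,
        )
    # -loss is the mean per-token log-likelihood of the option;
    # multiply by the option's token count for the unnormalised sum.
    log_lik_scores.append((option, -outputs.loss.item() * num_option_tokens))
```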
Maybe it is better to use `output_scores` to calculate the softmax scores?

```python
def eval_relevance(args, tokenizer, model, image_processor):
    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)

    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print(
            "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                conv_mode, args.conv_mode, args.conv_mode
            )
        )
    else:
        args.conv_mode = conv_mode

    qs = args.query
    if args.image_file != "":
        image_token_se = (
            DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        )
        if IMAGE_PLACEHOLDER in qs:
            if model.config.mm_use_im_start_end:
                qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
            else:
                qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
        else:
            if model.config.mm_use_im_start_end:
                qs = image_token_se + "\n" + qs
            else:
                qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        image_files = image_parser(args)
        images = load_images(image_files)
        image_sizes = [x.size for x in images]
        images_tensor = process_images(images, image_processor, model.config)
        if type(images_tensor) is list:
            for i in range(len(images_tensor)):
                images_tensor[i] = images_tensor[i].to(
                    model.device, dtype=torch.float16
                )
        else:
            images_tensor = images_tensor.to(model.device, dtype=torch.float16)
    else:
        images_tensor = None
        image_sizes = None

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )

    with torch.inference_mode():
        generation_output = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=True if args.temperature > 0 else False,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
            return_dict_in_generate=True,
            output_scores=True,
        )

    # Logits over the vocabulary for the first generated token.
    logits = generation_output.scores[0][0]
    probs = (
        torch.nn.functional.softmax(
            torch.tensor(
                [
                    logits[tokenizer("Yes").input_ids[1]],  # [1] skips the BOS token
                    logits[tokenizer("No").input_ids[1]],
                ]
            ),
            dim=0,
        )
        .detach()
        .cpu()
        .numpy()
    )
    return probs[0]
```

Just replace the tokens "Yes" and "No" with your options.
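As a small generalisation of the last step of the function above to an arbitrary option list (the `candidate_options` names are hypothetical, and each option is assumed to start with a distinct single token, as with "Yes"/"No"):

```python
import torch

candidate_options = ["horse", "lion", "tiger"]   # hypothetical single-token options
logits = generation_output.scores[0][0]          # logits for the first generated token

option_ids = [tokenizer(opt).input_ids[1] for opt in candidate_options]  # [1] skips BOS
option_probs = torch.softmax(logits[option_ids], dim=0).cpu().numpy()

for opt, p in zip(candidate_options, option_probs):
    print(opt, float(p))
```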
I'm using the same script to get the likelihood scores; the output text is correct, but the scores contain 'inf':
The scores that you get have the shape [len(generation_output.scores), batch, vocab_size], i.e. one set of logits per generated token.
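For what it's worth, the `-inf` values typically come from the sampling processors (temperature/top-p set filtered logits to `-inf` in `scores`). A hedged sketch, assuming a recent `transformers` version and a `generation_output` from a call with `return_dict_in_generate=True, output_scores=True`, that converts them into finite per-token log-probabilities of the tokens that were actually generated:

```python
import torch

# One row per sequence, one column per generated token.
transition_scores = model.compute_transition_scores(
    generation_output.sequences,
    generation_output.scores,
    normalize_logits=True,  # renormalise with log-softmax; the sampled tokens have finite log-probs
)
print(transition_scores.shape)
print(transition_scores[0].sum().item())  # log-likelihood of the generated continuation
```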
Question
Hi, is it possible to get the tokenwise log-likelihood scores of different outputs from the model?
The use-case would be something like:
Given an interleaved image/text input and a list of output text candidates, we should be able to get a score for each output candidate and then return their ranked list, rather than generating the outputs directly. This would be close to how LLMs are evaluated on MCQ tasks; see, for example, the T0 paper, page 6 (https://arxiv.org/pdf/2110.08207.pdf).
Is it straightforward to do this with LLaVA? I assume that, since the LM is built with transformers, it should be possible to use the output-scoring functions that are already implemented (I haven't dug into this yet)?
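A rough sketch of the kind of ranking interface being asked about (the `score_fn` helper is hypothetical; concrete scoring approaches are in the comments above):

```python
from typing import Callable, List, Tuple

def rank_candidates(
    score_fn: Callable[[str, str], float],  # (prompt, candidate) -> log-likelihood given the image
    prompt: str,
    candidates: List[str],
) -> List[Tuple[str, float]]:
    # Score every candidate continuation and return them best-first.
    scored = [(cand, score_fn(prompt, cand)) for cand in candidates]
    return sorted(scored, key=lambda x: x[1], reverse=True)
```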