
Batch Inference on a single image #143

Closed · boofarboofar opened this issue Nov 30, 2023 · 8 comments
@boofarboofar

Hello,

Love the model. Is it possible to do batch inference with multiple queries for a single image? I've tried it myself but couldn't get things quite right.

@1049451037
Member

1049451037 commented Dec 1, 2023

It's indeed a bit complicated to do this in sat. Is there a solution for the Hugging Face version? @kq-chen

@kq-chen
Collaborator

kq-chen commented Dec 1, 2023

For the Hugging Face version, yes, try something like this:

import torch
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogvlm-chat-hf',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to('cuda').eval()

input_sample1 = model.build_conversation_input_ids(
    tokenizer,
    images=[Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/openai_demo/demo.jpg?raw=true', stream=True).raw).convert('RGB'),],
    query='Do you think this is a spring or winter photo?',  # Q2
    history=[
        (
            "What's in this image?",   # Q1
            'The image displays a wooden boardwalk extending through a vibrant green grassy wetland.'  # A1
         )
        ], 
    )
input_sample2 = model.build_conversation_input_ids(
    tokenizer,
    images=[Image.open(requests.get('https://github.com/THUDM/CogVLM/blob/main/examples/1.png?raw=true', stream=True).raw).convert('RGB'),],
    query='Describe this image',  # Q1
    history=[], 
    )

def recur_move_to(item, tgt, criterion_func):
    # Recursively move every element matching criterion_func (here: tensors) inside
    # nested lists/tuples/dicts to the target device or dtype.
    if criterion_func(item):
        return item.to(tgt)
    elif isinstance(item, list):
        return [recur_move_to(v, tgt, criterion_func) for v in item]
    elif isinstance(item, tuple):
        return tuple(recur_move_to(v, tgt, criterion_func) for v in item)
    elif isinstance(item, dict):
        return {k: recur_move_to(v, tgt, criterion_func) for k, v in item.items()}
    else:
        return item

def collate_fn(features, tokenizer) -> dict:
    # Separate the preprocessed image tensors (the tokenizer must not pad those),
    # left-pad the text features to a common length, then reattach the images.
    images = [feature.pop('images') for feature in features]
    tokenizer.padding_side = 'left'
    padded_features = tokenizer.pad(features)
    return {**padded_features, 'images': images}

input_batch = collate_fn([input_sample1, input_sample2], tokenizer)
input_batch = recur_move_to(input_batch, 'cuda', lambda x: isinstance(x, torch.Tensor))
input_batch = recur_move_to(input_batch, torch.bfloat16, lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))

gen_kwargs = {"max_length": 2048, "do_sample": False}

with torch.no_grad():
    outputs = model.generate(**input_batch, **gen_kwargs)
    outputs = outputs[:, input_batch['input_ids'].shape[1]:]
    print(tokenizer.batch_decode(outputs))
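
A small follow-up (not from the original comment; just a common post-processing step): the sequences in the batch finish at different lengths, so the raw batch_decode output can carry trailing EOS/padding tokens. One way to strip them, assuming the default Llama special tokens:

# Optional post-processing (assumption, not part of the original snippet):
# skip_special_tokens drops any trailing </s> / pad tokens from the generated text.
answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for i, answer in enumerate(answers):
    print(f'sample {i}: {answer.strip()}')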

@boofarboofar
Author

> For the Hugging Face version, yes, try something like this: [quoted code above]

@kq-chen @1049451037

This is very helpful and works great, thank you for the quick response. I have one last question: one of the issues I'm trying to solve is keeping the GPU saturated. Because of the code structure, I can't really use more than one thread to feed the GPU (I can't figure out how to use a torch DataLoader, since the model has to be available just to generate the input tensors).

I tried loading the model on the CPU in the loader processes, but that requires too much RAM. I also poked around in the code, but couldn't figure out how to generate the input tensors independently of the model itself.

@kq-chen
Collaborator

kq-chen commented Dec 4, 2023

@boofarboofar The function build_conversation_input_ids actually only uses the config, so you can copy the relevant part and make a small modification, like this:

from typing import Optional, Tuple, List, Literal, TYPE_CHECKING
import torch

from PIL import Image
from torchvision import transforms

if TYPE_CHECKING:
    import PIL
    from transformers import PreTrainedTokenizer

LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1


def _history_to_prompt(signal_type, history, query):
    if signal_type == 'base':
        return query
    elif signal_type == 'vqa':
        answer_format = 'Short answer:'
    elif signal_type == 'chat':
        answer_format = 'Answer:'
    else:
        assert False, f"Unknown signal type {signal_type}"

    prompt = ''
    for i, (old_query, response) in enumerate(history):
        prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
    prompt += 'Question: {} {}'.format(query, answer_format)
    return prompt


def build_conversation_input_ids(
        config,
        tokenizer: "PreTrainedTokenizer",
        *,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        images: Optional[List["PIL.Image"]] = None,
        template_version: Optional[Literal["base", "chat", "vqa"]] = None,
):
    image_size: int = config.vision_config['image_size']
    patch_size: int = config.vision_config['patch_size']
    template_version = template_version or config.template_version
    assert images is None or len(images) <= 1, "multiple images are not supported for now."
    history = history or []
    text = _history_to_prompt(template_version, history, query)

    input_ids = [tokenizer.bos_token_id]
    token_type_ids = [LANGUAGE_TOKEN_TYPE]
    if images is not None and len(images) == 1:
        # vision
        transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )
        images = [transform(images[0])]
        # language
        vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
        input_ids += [tokenizer.pad_token_id] * vision_token_num
        token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
    text_ids = tokenizer.encode(text, add_special_tokens=False)

    input_ids += text_ids
    token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
    attention_mask = [1] * len(input_ids)

    return {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
        'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
        'images': images,
    }


if __name__ == '__main__':
    import requests
    from transformers import LlamaTokenizer, AutoConfig

    tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
    config = AutoConfig.from_pretrained('THUDM/cogvlm-chat-hf', trust_remote_code=True)
    input_sample1 = build_conversation_input_ids(
        config=config,
        tokenizer=tokenizer,
        images=[
            Image.open(
                requests.get('https://github.com/THUDM/CogVLM/blob/main/openai_demo/demo.jpg?raw=true', stream=True).raw
            ).convert('RGB'), ],
        query='Do you think this is a spring or winter photo?',  # Q2
        history=[
            (
                "What's in this image?",  # Q1
                'The image displays a wooden boardwalk extending through a vibrant green grassy wetland.'  # A1
            )
        ],
    )
    print(input_sample1)
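
To address the earlier question about keeping the GPU fed: since the function above only needs the config and the tokenizer, preprocessing can run inside torch DataLoader worker processes while generate stays on the main process. A minimal sketch under that assumption (the dataset format, batch size, and worker count are hypothetical), reusing collate_fn, recur_move_to, model, and gen_kwargs from the earlier batching example:

# Hypothetical wiring (assumption): preprocess in DataLoader workers, generate on the GPU.
# Assumes build_conversation_input_ids (above) plus config, tokenizer, model, gen_kwargs,
# collate_fn, and recur_move_to from the earlier batching example are in scope.
from functools import partial

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

class CaptionDataset(Dataset):
    def __init__(self, samples, config, tokenizer):
        # samples: a list of (image_path, query) pairs -- hypothetical input format
        self.samples = samples
        self.config = config
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, query = self.samples[idx]
        image = Image.open(image_path).convert('RGB')
        # Only config + tokenizer are needed here, so this runs fine in worker processes.
        return build_conversation_input_ids(
            self.config, self.tokenizer, query=query, images=[image], history=[]
        )

loader = DataLoader(
    CaptionDataset(samples, config, tokenizer),  # `samples` is your own list of pairs
    batch_size=4,
    num_workers=4,
    collate_fn=partial(collate_fn, tokenizer=tokenizer),
)

for batch in loader:
    batch = recur_move_to(batch, 'cuda', lambda x: isinstance(x, torch.Tensor))
    batch = recur_move_to(batch, torch.bfloat16,
                          lambda x: isinstance(x, torch.Tensor) and torch.is_floating_point(x))
    with torch.no_grad():
        outputs = model.generate(**batch, **gen_kwargs)
        outputs = outputs[:, batch['input_ids'].shape[1]:]
        print(tokenizer.batch_decode(outputs, skip_special_tokens=True))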

@boofarboofar
Author

> @boofarboofar The function build_conversation_input_ids actually only uses the config, so you can copy the relevant part and make a small modification, like this: [quoted code above]

Thanks again, this works well. My primary bottleneck now is generate running on a single Python thread, but that's out of scope here; I suspect I'll need to convert the model to https://github.com/huggingface/candle.

@heyalexchoi

heyalexchoi commented Jan 6, 2024

> For the Hugging Face version, yes, try something like this: [quoted code above]

@kq-chen I am trying to set up batch captioning (1 image and 1 query each) using this example but am having an issue:

Traceback (most recent call last):
  File "<stdin>", line 4, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/cogagent-vqa-hf/9c742a5c237ff3911077a1847cf5db72bf8092f0/modeling_cogagent.py", line 786, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/cogagent-vqa-hf/9c742a5c237ff3911077a1847cf5db72bf8092f0/modeling_cogagent.py", line 539, in forward
    return self.llm_forward(
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/cogagent-vqa-hf/9c742a5c237ff3911077a1847cf5db72bf8092f0/modeling_cogagent.py", line 631, in llm_forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/cogagent-vqa-hf/9c742a5c237ff3911077a1847cf5db72bf8092f0/modeling_cogagent.py", line 382, in forward
    attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/cogagent-vqa-hf/9c742a5c237ff3911077a1847cf5db72bf8092f0/modeling_cogagent.py", line 324, in forward
    context_layer = attention_fn(
  File "/root/.cache/huggingface/modules/transformers_modules/THUDM/cogagent-vqa-hf/9c742a5c237ff3911077a1847cf5db72bf8092f0/modeling_cogagent.py", line 141, in attention_fn
    attention_scores = attention_scores + attention_mask
RuntimeError: The size of tensor a (313) must match the size of tensor b (2) at non-singleton dimension 2

It seems the shape of attention_mask ([2, 313] with batch size 2) does not match the shape of attention_scores (https://huggingface.co/THUDM/cogagent-vqa-hf/blob/9c742a5c237ff3911077a1847cf5db72bf8092f0/modeling_cogagent.py#L141, I believe).

I was wondering if someone could help me figure this out.

EDIT: this issue is happening with 'THUDM/cogagent-vqa-hf', not with 'THUDM/cogvlm-chat-hf'
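
To make the mismatch concrete, here is a minimal, hypothetical illustration (the shapes and head count are assumptions, not taken from the model code) of why a 2-D mask fails to broadcast against 4-D attention scores, and which mask shape would broadcast:

import torch

scores = torch.randn(2, 32, 313, 313)  # (batch, heads, q_len, k_len) -- hypothetical shapes
bad_mask = torch.zeros(2, 313)         # 2-D (batch, seq_len) mask, like the one in the error
good_mask = torch.zeros(2, 1, 1, 313)  # 4-D mask broadcasts over heads and query positions

print((scores + good_mask).shape)      # torch.Size([2, 32, 313, 313])
# scores + bad_mask -> RuntimeError: The size of tensor a (313) must match
#                      the size of tensor b (2) at non-singleton dimension 2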

@zhaoyucs

zhaoyucs commented Jan 6, 2024

> @kq-chen I am trying to set up batch captioning (1 image and 1 query each) using this example but am having an issue… [quoted comment and traceback above]

@heyalexchoi I had the same issue. I replaced line 612 in modeling_cogagent.py with:

cross_attention_mask = torch.ones(
    (batch_size, 1, 1, 1), dtype=torch.bool, device=inputs_embeds.device
)

@zzzzzero

zzzzzero commented Apr 1, 2024

I’d like to know what your GPU utilization is. I changed the code to batch inference and found that the GPU utilization is still not 100%, only about 75%. Is this normal?
