Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion BLOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ We released [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-

<img width="650" alt="image" src="https://github.com/OpenGVLab/InternVL/assets/8529570/0e60912e-c52b-46fa-bd61-5f94a221d1fc">


## InternVL

> Date: 2023/12/12<br>
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ InternVL scales up the ViT to _**6B parameters**_ and aligns it with LLM.
<summary>Multimodal Dialogue (click to expand)</summary>

- Compared with SOTA VLLMs

| name | image size | MMMU<br>(val) | MMMU<br>(test) | MathVista<br>(testmini) | MMB<br>(test) | MMB−CN<br>(test) | MMVP | MME | ScienceQA<br>(image) | POPE | TextVQA | SEEDv1<br>(image) | VizWiz<br>(test) | GQA<br>(test) |
| ------------------ | ---------- | ------------- | -------------- | ----------------------- | ------------- | ---------------- | ---- | -------- | -------------------- | ---- | ------- | ----------------- | ---------------- | ------------- |
| GPT-4V\* | unknown | 56.8 | 55.7 | 49.9 | 77.0 | 74.4 | 38.7 | 1409/517 | - | - | 78.0 | 71.6 | - | - |
Expand All @@ -343,7 +343,7 @@ InternVL scales up the ViT to _**6B parameters**_ and aligns it with LLM.
| | | | | | | | | | | | | | | |
| LLaVA-NEXT-34B | 672x672 | 51.1 | 44.7 | 46.5 | 79.3 | 79.0 | - | 1631/397 | 81.8 | 87.7 | 69.5 | 75.9 | 63.8 | 67.1 |
| InternVL-Chat-V1.2 | 448x448 | 51.6 | 46.2 | 47.7 | 82.2 | 81.2 | 56.7 | 1672/509 | 83.3 | 88.0 | 69.7 | 75.6 | 60.0 | 64.0 |

\* denotes proprietary models. MMBench results are collected from the [leaderboard](https://mmbench.opencompass.org.cn/leaderboard). In most benchmarks, InternVL-Chat-V1.2 achieves better performance than LLaVA-NeXT-34B.

- Zero-Shot Image Captioning [\[see details\]](./internvl_g#zero-shot-image-captioning)
Expand Down
8 changes: 5 additions & 3 deletions internvl_chat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ The hyperparameters used for finetuning are listed in the following table.

## 📊 Evaluation

\* Training set observed.

**MultiModal Benchmark**

| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE | MMVP | MathVista |
Expand All @@ -151,14 +153,14 @@ The hyperparameters used for finetuning are listed in the following table.
| model | MMMU<sub>val/test</sub> | CMMMU<sub>val/test</sub> | Tiny<sub>LVLM</sub> | LLaVA<sub>bench</sub> | MM-Vet |
| --------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ------------------------ | ------------------- | --------------------- | ------ |
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 39.1 / 35.3 | 34.8 / 34.0 | 344.5 | 76.3 | 45.0 |
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | 51.6 / [46.2](https://eval.ai/web/challenges/challenge-page/2179/leaderboard/5377) | TODO | 350.3 | - | 48.9 |
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | 51.6 / [46.2](https://eval.ai/web/challenges/challenge-page/2179/leaderboard/5377) | - | 350.3 | - | 48.9 |

**Visual Question Answering**

| model | VQAv2<sub>test</sub> | OKVQA<sub>val</sub> | TextVQA<sub>val</sub> | VizWiz<sub>val/test</sub> | AI2D<sub>test</sub> | GQA<sub>test</sub> | SQA<sub>test</sub> |
| --------------------------------------------------------------------------------- | -------------------- | ------------------- | --------------------- | ------------------------- | ------------------- | ------------------ | ------------------ |
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 80.9 | 64.2 | 65.8 | 58.3 / 57.3 | 70.2 | 62.4 | 91.2 |
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | - | 62.5 | 69.7 | 61.9 / 60.0 | 71.6 | 64.0 | 83.3 |
| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 80.9\* | 64.2\* | 65.8 | 58.3 / 57.3 | 70.2\* | 62.4\* | 91.2\* |
| [InternVL-Chat-V1.2](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-2) | - | 62.5\* | 69.7 | 61.9 / 60.0 | 71.6\* | 64.0\* | 83.3 |

**Image Captioning**

Expand Down
2 changes: 1 addition & 1 deletion internvl_chat/eval/scienceqa/evaluate_scienceqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def post_process(pred, option):
if v in pred:
return k

return random.choice(option_candidate)
return pred


def evaluate_chat_model():
Expand Down
6 changes: 5 additions & 1 deletion internvl_chat/internvl/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
from .llama_rmsnorm_monkey_patch import \
replace_llama_rmsnorm_with_fused_rmsnorm
from .pad_data_collator import pad_data_collator
from .train_sampler_patch import replace_train_sampler

__all__ = ['replace_llama_attn_with_flash_attn',
'replace_llama_rmsnorm_with_fused_rmsnorm',
'replace_llama2_attn_with_flash_attn']
'replace_llama2_attn_with_flash_attn',
'replace_train_sampler',
'pad_data_collator']
49 changes: 49 additions & 0 deletions internvl_chat/internvl/patch/pad_data_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import torch

IGNORE_INDEX = -100


def pad_data_collator(features, pad_id=0):

first = features[0]
batch = {}

batch_lens = [feat['input_ids'].shape for feat in features]
max_item_length = max(batch_lens)[0]
for idx in range(len(features)):
feat = features[idx]
temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
feat['input_ids'] = temp_input_ids
temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
temp_labels[:feat['labels'].shape[0]] = feat['labels']
feat['labels'] = temp_labels
feat['attention_mask'] = feat['input_ids'].ne(pad_id)

# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if 'label' in first and first['label'] is not None:
label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
dtype = torch.long if isinstance(label, int) else torch.float
batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
elif 'label_ids' in first and first['label_ids'] is not None:
if isinstance(first['label_ids'], torch.Tensor):
batch['labels'] = torch.stack([f['label_ids'] for f in features])
else:
dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)

# Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in first.items():
if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
elif isinstance(v, np.ndarray):
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
else:
batch[k] = torch.tensor([f[k] for f in features])

return batch
31 changes: 31 additions & 0 deletions internvl_chat/internvl/patch/train_sampler_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Optional

import torch
import transformers
from transformers.trainer import (LengthGroupedSampler, RandomSampler,
has_length)


# patch trainer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
# Build the sampler.
if self.args.group_by_length:
lengths = []
for dataset in self.train_dataset.datasets:
lengths = lengths + dataset.length
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
return LengthGroupedSampler(
self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
lengths=lengths,
model_input_name=model_input_name,
)
else:
return RandomSampler(self.train_dataset)


def replace_train_sampler():
transformers.Trainer._get_train_sampler = _get_train_sampler
print('Replace train sampler!!')
45 changes: 34 additions & 11 deletions internvl_chat/internvl/train/internvl_chat_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
InternVisionModel,
InternVLChatConfig,
InternVLChatModel)
from internvl.patch import (replace_llama2_attn_with_flash_attn,
replace_llama_rmsnorm_with_fused_rmsnorm)
from internvl.patch import (pad_data_collator,
replace_llama2_attn_with_flash_attn,
replace_llama_rmsnorm_with_fused_rmsnorm,
replace_train_sampler)
from internvl.train.dataset import (TCSLoader, WeightedConcatDataset,
build_transform)
from PIL import Image, ImageFile, PngImagePlugin
Expand All @@ -39,6 +41,7 @@
# Upgrade transformers to v4.36.2, we don't need it anymore
# replace_llama2_attn_with_flash_attn()
replace_llama_rmsnorm_with_fused_rmsnorm()
replace_train_sampler()

try:
from petrel_client.client import Client
Expand Down Expand Up @@ -182,6 +185,7 @@ def preprocess(
tokenizer: transformers.PreTrainedTokenizer,
num_image_token: int,
text_only: bool = False,
group_by_length: bool = False,
) -> Dict:
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
Expand Down Expand Up @@ -213,7 +217,7 @@ def preprocess(
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding='max_length',
padding=False if group_by_length else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
Expand Down Expand Up @@ -283,6 +287,7 @@ def preprocess_mpt(
tokenizer: transformers.PreTrainedTokenizer,
num_image_token: int,
text_only: bool = False,
group_by_length: bool = False,
) -> Dict:
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
Expand Down Expand Up @@ -314,7 +319,7 @@ def preprocess_mpt(
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding='max_length',
padding=False if group_by_length else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
Expand Down Expand Up @@ -368,7 +373,7 @@ class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
image_size=224, is_train=True, pad2square=False):
image_size=224, is_train=True, pad2square=False, group_by_length=False):
super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer
self.template_name = template_name
Expand All @@ -384,6 +389,21 @@ def __init__(self, template_name, meta, tokenizer, tcs_loader, num_image_token,
self.root = meta['root']
self.cached_data_dict = {}
self.tcs_loader = tcs_loader
self.group_by_length = group_by_length
if self.group_by_length:
self.conv2length = {}
self.length = []
for data_item in self.raw_data:
conversations = ''.join(data_item.split('conversations')[1:])
str_length = len(conversations)
if str_length not in self.conv2length:
token_length = tokenizer(
conversations, return_tensors='pt', padding=False, truncation=False,
).input_ids.size(1)
self.conv2length[str_length] = token_length
else:
token_length = self.conv2length[str_length]
self.length.append(token_length)

def __len__(self):
return len(self.raw_data)
Expand All @@ -405,7 +425,7 @@ def multi_modal_get_item(self, data_item):
else:
preprocess_function = preprocess
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, self.num_image_token)
self.tokenizer, self.num_image_token, group_by_length=self.group_by_length)
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
Expand All @@ -425,7 +445,8 @@ def pure_text_get_item(self, data_item):
else:
preprocess_function = preprocess
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, self.num_image_token, text_only=True)
self.tokenizer, self.num_image_token, text_only=True,
group_by_length=self.group_by_length)
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
Expand Down Expand Up @@ -455,7 +476,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return ret


def build_datasets(data_args, tokenizer, tcs_loader, model):
def build_datasets(data_args, tokenizer, tcs_loader, model, group_by_length=False):
datasets = []
lengths = []
ds_collections = json.loads(open(data_args.meta_path).read())
Expand All @@ -469,7 +490,8 @@ def build_datasets(data_args, tokenizer, tcs_loader, model):
num_image_token=model.num_image_token,
image_size=data_args.force_image_size,
is_train=ds_collections[ds_name]['data_augment'],
pad2square=data_args.pad2square
pad2square=data_args.pad2square,
group_by_length=group_by_length
)
except Exception:
logger.info(f'Error in loading dataset: {ds_name}')
Expand Down Expand Up @@ -623,7 +645,8 @@ def main():
if model_args.grad_checkpoint:
model.language_model._set_gradient_checkpointing()

train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model)
train_dataset = build_datasets(data_args, tokenizer, tcs_loader, model,
group_by_length=training_args.group_by_length)

def _freeze_params(module):
for param in module.parameters():
Expand Down Expand Up @@ -672,7 +695,7 @@ def _freeze_params(module):
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=None,
tokenizer=tokenizer,
data_collator=default_data_collator,
data_collator=default_data_collator if not training_args.group_by_length else pad_data_collator,
)

# Training
Expand Down
2 changes: 1 addition & 1 deletion internvl_chat/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "internvl_chat"
version = "1.2.0"
version = "1.2.1"
description = "Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks."
readme = "README.md"
requires-python = ">=3.8"
Expand Down
3 changes: 2 additions & 1 deletion internvl_chat/tools/json2jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

data = json.load(open(args.path))
writer = open(args.path.replace('.json', '.jsonl'), 'w')
for item in data:
for idx, item in enumerate(data):
conversations = item['conversations']
if conversations[0]['from'] == 'system':
item['conversations'] = item['conversations'][1:]
item['id'] = idx
writer.write(json.dumps(item, ensure_ascii=False) + '\n')

writer.close()
25 changes: 25 additions & 0 deletions internvl_chat/tools/resize_pos_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import argparse

import torch
from internvl.model.internvl_chat import InternVLChatModel
from transformers import AutoTokenizer

argparse = argparse.ArgumentParser()
argparse.add_argument('model_path', type=str, default='')
argparse.add_argument('output_path', type=str, default='')
argparse.add_argument('force_image_size', type=int, default=448)

args = argparse.parse_args()

model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
new_size=args.force_image_size,
patch_size=14)
model.config.vision_config.image_size = args.force_image_size
model.config.force_image_size = args.force_image_size

model.save_pretrained(args.output_path)

tokenizer = AutoTokenizer.from_pretrained(args.model_path)
tokenizer.save_pretrained(args.output_path)
print('finished')
12 changes: 12 additions & 0 deletions internvl_chat/zero_stage3_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
Expand Down