[Features] Support multi_modal training #628

Merged: 14 commits, Sep 6, 2023
30 changes: 18 additions & 12 deletions configs/ds_config_multimodal.json
@@ -1,17 +1,23 @@
 {
     "fp16": {
-        "enabled": false
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "min_loss_scale": 1
     },
     "bf16": {
-        "enabled": false
+        "enabled": "auto"
     },
-    "comms_logger": {
-        "enabled": false,
-        "verbose": false,
-        "prof_all": false,
-        "debug": false
-    },
-    "steps_per_print": 20000000000000000,
-    "train_micro_batch_size_per_gpu": 1,
-    "wall_clock_breakdown": false
-}
+    "train_micro_batch_size_per_gpu": "auto",
+    "train_batch_size": "auto",
+    "gradient_accumulation_steps": "auto",
+    "zero_optimization": {
+        "stage": 2,
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto"
+    }
+}
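Switching these fields to "auto" defers them to the HuggingFace Trainer integration: when this JSON is passed via the deepspeed training argument, Trainer fills each "auto" entry from the corresponding TrainingArguments value at launch time. A minimal sketch of the wiring (paths and values here are illustrative):

# Sketch: HF Trainer resolves "auto" entries in the deepspeed config from
# TrainingArguments, so precision and batch sizes are set on the command
# line instead of being hard-coded in the JSON.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output_models/multimodal",  # illustrative path
    per_device_train_batch_size=4,   # -> "train_micro_batch_size_per_gpu": "auto"
    gradient_accumulation_steps=2,   # -> "gradient_accumulation_steps": "auto"
    bf16=True,                       # -> "bf16": {"enabled": "auto"}
    deepspeed="configs/ds_config_multimodal.json",
)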
17 changes: 17 additions & 0 deletions configs/ds_config_vis_chatbot.json
@@ -0,0 +1,17 @@
{
    "fp16": {
        "enabled": false
    },
    "bf16": {
        "enabled": false
    },
    "comms_logger": {
        "enabled": false,
        "verbose": false,
        "prof_all": false,
        "debug": false
    },
    "steps_per_print": 20000000000000000,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false
}
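This inference-time config, by contrast, keeps both precision flags off and fixes the micro-batch at 1; the enormous steps_per_print value effectively silences deepspeed's periodic step logging. A sketch of how such a config is typically loaded and handed over (the vis_chatbot example below passes it as ds_config):

# Sketch: read the deepspeed JSON and pass it to the model loader, as the
# vis_chatbot example below does with its ds_config argument.
import json

with open("configs/ds_config_vis_chatbot.json") as f:
    ds_config = json.load(f)
# e.g. AutoModel.get_model(model_args, ds_config=ds_config, ...)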
44 changes: 44 additions & 0 deletions data/download.sh
@@ -183,6 +183,50 @@ function main() {
        tar zxvf ${filename}
        rm ${filename}
    fi

    # multimodal
    if [ "$1" = "coco2017" -o "$1" = "all" ]; then
        echo "downloading coco 2017 dataset for multimodal finetuning"
        mkdir coco2017
        cd coco2017
        wget "http://images.cocodataset.org/zips/train2017.zip"
        wget "http://images.cocodataset.org/zips/val2017.zip"
        wget "http://images.cocodataset.org/zips/test2017.zip"
        unzip train2017.zip
        unzip val2017.zip
        unzip test2017.zip
        rm train2017.zip
        rm val2017.zip
        rm test2017.zip
        cd ../
    fi

    if [ "$1" = "llava_instruction_finetune_80k" -o "$1" = "all" ]; then
        echo "downloading llava instruction finetune dataset with 80k conversations"
        python ../utils/download_hf_file.py \
            --repo_id liuhaotian/LLaVA-Instruct-150K \
            --filename llava_instruct_80k.json
    fi

    if [ "$1" = "llava_cc3m_pretrain_595k" -o "$1" = "all" ]; then
        echo "downloading llava pretrain images"
        filepath="llava_cc3m_pretrain_595k"
        python ../utils/download_hf_file.py \
            --repo_id liuhaotian/LLaVA-CC3M-Pretrain-595K \
            --filename images.zip \
            --target_path ${filepath}

        python ../utils/download_hf_file.py \
            --repo_id liuhaotian/LLaVA-CC3M-Pretrain-595K \
            --filename chat.json \
            --target_path ${filepath}

        cd ${filepath}
        unzip images.zip
        rm images.zip
        cd ../
    fi
}
main "$@"

88 changes: 88 additions & 0 deletions examples/finetune_multi_modal.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
# FIXME should merge with finetune.py
"""A one-line summary of the module or program, terminated by a period.

Leave one blank line. The rest of this docstring should contain an
overall description of the module or program. Optionally, it may also
contain a brief description of exported classes and functions and/or usage
examples.

Typical usage example:

foo = ClassFoo()
bar = foo.FunctionBar()
"""

import sys
import os
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
from transformers import HfArgumentParser

from lmflow.args import (
VisModelArguments,
MultiModalDatasetArguments,
AutoArguments,
)

from lmflow.datasets.dataset import Dataset
from lmflow.models.auto_model import AutoModel
from lmflow.pipeline.auto_pipeline import AutoPipeline

from lmflow.models.vision2seq_model import CustomAutoVision2SeqModel
from lmflow.models.vision_encoder import build_vision_tower
from lmflow.datasets.multi_modal_dataset import DataCollatorForSupervisedDataset
from torch.utils.data import DataLoader


def main():
    # Parses arguments
    pipeline_name = "finetuner"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((VisModelArguments, MultiModalDatasetArguments, PipelineArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a
        # json file, let's parse it to get our arguments.
        model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

    # Initialization
    finetuner = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
    )
    # Do not register deepspeed in the model; the with_deepspeed flag may be
    # removed by modifying the tune strategy in the future.
    model = AutoModel.get_model(model_args, tune_strategy='none',
                                ds_config=pipeline_args.deepspeed,
                                custom_model=True,
                                with_deepspeed=False)
    # FIXME check if need to move this part to hf_encoder_decoder.py
    # Freeze the whole backbone, then unfreeze only the selected parts.
    for param in model.backend_model.parameters():
        param.requires_grad = False
    if "language_projection" in pipeline_args.finetune_part:
        for param in model.backend_model.language_projection.parameters():
            param.requires_grad = True
    if "language_model" in pipeline_args.finetune_part:
        for param in model.backend_model.language_model.parameters():
            param.requires_grad = True
    if "vision_model" in pipeline_args.finetune_part:
        for param in model.backend_model.vision_model.parameters():
            param.requires_grad = True

    dataset = Dataset(data_args, backend="custom_multi_modal")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=model.tokenizer)

    # Finetuning
    tuned_model = finetuner.tune(
        model=model, dataset=dataset, data_collator=data_collator)


if __name__ == '__main__':
    main()
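Because everything outside the submodules named in finetune_part stays frozen, a quick sanity check before launching a run is to count trainable parameters. A small helper for that (illustrative, not part of the PR):

# Illustrative helper: verify that only the submodules selected by
# finetune_part were actually unfrozen.
def count_trainable_parameters(module):
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    print(f"trainable params: {trainable:,} / {total:,} "
          f"({100 * trainable / total:.2f}%)")

# e.g. count_trainable_parameters(model.backend_model)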
50 changes: 37 additions & 13 deletions examples/vis_chatbot.py
Collaborator (author) commented on examples/vis_chatbot.py:

The reason for checking whether to use deepspeed is that loading the model in 8-bit raises an error when deepspeed is enabled. See huggingface/transformers#24540.

Contributor:

No problem. Thanks!
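For context, a sketch of the resulting guard (an illustration of the pattern described above, not the exact LMFlow code): when the model is loaded in 8-bit, the deepspeed config is simply never registered before from_pretrained:

# Illustration only. HfDeepSpeedConfig must exist before from_pretrained()
# for ZeRO-3 weight partitioning, but combining it with load_in_8bit raises
# an error (huggingface/transformers#24540), so 8-bit loading skips it.
from transformers import AutoModelForCausalLM
from transformers.deepspeed import HfDeepSpeedConfig  # import path used by transformers 4.x in 2023

def load_backend_model(name, ds_config=None, load_in_8bit=False, with_deepspeed=True):
    dschf = None  # keep a reference alive while the model is in use
    if with_deepspeed and not load_in_8bit and ds_config is not None:
        dschf = HfDeepSpeedConfig(ds_config)
    model = AutoModelForCausalLM.from_pretrained(name, load_in_8bit=load_in_8bit)
    return model, dschf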

@@ -64,7 +64,7 @@ class ChatbotArguments:
             )
         }
     )
-    prompt_format: Optional[str] = field(
+    chatbot_type: Optional[str] = field(
         default="None",
         metadata={
             "help": (
@@ -80,7 +80,12 @@
"help": "whether to do the stream inference"
}
)

with_deepspeed: Optional[bool] = field(
default=True,
metadata={
"help": "whether to use deepspeed"
}
)

def main():
pipeline_name = "inferencer"
@@ -104,10 +104,11 @@ def main():
         ds_config=ds_config,
         device=pipeline_args.device,
         custom_model=model_args.custom_model,
+        with_deepspeed=chatbot_args.with_deepspeed,
     )
 
     data_args = DatasetArguments(dataset_path=None)
-    dataset = Dataset(data_args)
+    dataset = Dataset(data_args, backend="dict")
 
     inferencer = AutoPipeline.get_pipeline(
         pipeline_name=pipeline_name,
@@ -140,13 +146,21 @@
# " unconditionally."
# )

sep = "###"

end_string = chatbot_args.end_string
if chatbot_args.prompt_format == "mini_gpt":
if chatbot_args.chatbot_type == "mini_gpt":
context = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
user_name = "Human"
sep = "###"

elif chatbot_args.chatbot_type == "llava":
context = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
user_name = "USER"
sep = " "
else:
context = ""
user_name = ""
sep = "###"
prompt_structure = chatbot_args.prompt_structure

# Load image and input text for reasoning
@@ -161,8 +175,10 @@
     input_text = chatbot_args.input_text
     if chatbot_args.task == "image_caption" and len(input_text) == 0:
         input_text = "a photography of"
-    if chatbot_args.prompt_format == "mini_gpt":
-        context += sep + "Human: " + "<Img><ImageHere></Img> "
+    if chatbot_args.chatbot_type == "mini_gpt":
+        context += sep + user_name + ": " + "<Img><ImageHere></Img> "
+    elif chatbot_args.chatbot_type == "llava":
+        context += sep + user_name + ": " + "<image>\n"

# this flag is for determining if we need to add the ###Human: prompt
# if text after loading image, we add it when loading image
@@ -179,7 +195,7 @@
         input_dataset = dataset.from_dict({
             "type": "image_text",
             "instances": [{"images": np.stack(image_list),
-                           "text": input_text,}]
+                          "text": input_text,}]
         })
         output = inferencer.inference(model, input_dataset)
         print(output.backend_dataset['text'])
@@ -200,7 +216,12 @@
             # batch of image with different shape
             raw_image = raw_image.resize(base_size)
             image_list.append(np.array(raw_image))
-            context += sep + "Human: " + "<Img><ImageHere></Img> "
+            if chatbot_args.chatbot_type == "mini_gpt":
+                context += sep + user_name + ": " + "<Img><ImageHere></Img> "
+            elif chatbot_args.chatbot_type == "llava":
+                context += sep + user_name + ": " + "<image>\n"
+            else:
+                raise NotImplementedError
             text_after_loading_image = True
             print("Finish loading image with path {}".format(image_path))
             continue
@@ -213,8 +234,7 @@
             continue
 
         if text_after_loading_image is False:
-            if chatbot_args.prompt_format == "mini_gpt":
-                context += sep + "Human: "
+            context += sep + user_name + ": "
         else:
             text_after_loading_image = False
 
@@ -229,14 +249,18 @@
"instances": [{"images": np.stack(image_list),
"text": context,}]
})
remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
if chatbot_args.chatbot_type in ["mini_gpt", "llava"]:
remove_image_flag = True
else:
remove_image_flag = False
begin_time = time.time()
if not chatbot_args.stream_inference:
# directly inference the results
output_dataset = inferencer.inference(
model,
input_dataset,
remove_image_flag=remove_image_flag)
remove_image_flag=remove_image_flag,
chatbot_type=chatbot_args.chatbot_type,)
response = output_dataset.backend_dataset['text']
print(response[0])
print("\n", end="")
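To make the two prompt conventions concrete, here is the first user turn assembled by the branches above under each chatbot_type (a worked example reusing the exact strings from the diff; "<your question>" is a placeholder):

# Worked example of the prompt assembly above for a single user turn.
mini_gpt_turn = (
    "Give the following image: <Img>ImageContent</Img>. "
    "You will be able to see the image once I provide it to you. "
    "Please answer my questions."
    "###Human: <Img><ImageHere></Img> <your question>"
)

llava_turn = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's "
    "questions."
    " USER: <image>\n<your question>"
)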
8 changes: 4 additions & 4 deletions examples/vis_chatbot_gradio.py
@@ -95,7 +95,7 @@ class ChatbotArguments:
"help": "task for reasoning",
}
)
prompt_format: Optional[str] = field(
chatbot_format: Optional[str] = field(
default="None",
metadata={
"help": "prompt format"
@@ -122,7 +122,7 @@ def upload_image(image_file, history, text_input, chat_state, image_list):
     history = history + [((image_file.name,), None)]
 
     if chat_state is None:
-        if chatbot_args.prompt_format == "mini_gpt":
+        if chatbot_args.chatbot_format == "mini_gpt":
             chat_state = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
         else:
             chat_state = ''
@@ -134,7 +134,7 @@ def upload_image(image_file, history, text_input, chat_state, image_list):
     else:
         image_list.append(image.resize(image_list[0].size))
 
-    if chatbot_args.prompt_format == "mini_gpt":
+    if chatbot_args.chatbot_format == "mini_gpt":
         chat_state += "### Human: " + "<Img><ImageHere></Img>"
     return (
         gr.update(interactive=True, placeholder='Enter text and press enter, or upload an image'),
@@ -170,7 +170,7 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0):
"instances": [{"images": np.stack([np.array(i) for i in image_list]),
"text": chat_state}]
})
remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
remove_image_flag = chatbot_args.chatbot_format=="mini_gpt"

chatbot[-1][1] = ''
