[Features] Support multi_modal training #628

Merged · 14 commits · Sep 6, 2023
30 changes: 18 additions & 12 deletions configs/ds_config_multimodal.json
@@ -1,17 +1,23 @@
{
"fp16": {
"enabled": false
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
"enabled": "auto"
},
"comms_logger": {
"enabled": false,
"verbose": false,
"prof_all": false,
"debug": false
},
"steps_per_print": 20000000000000000,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
}
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
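Note: the "auto" placeholders above are filled in by the Hugging Face Trainer's DeepSpeed integration from the training arguments at launch time. A minimal sketch of that resolution, assuming the flag values used in scripts/run_finetune_multi_modal_stage1.sh and a single GPU (this is illustrative, not code from the PR):

    import json

    with open("configs/ds_config_multimodal.json") as f:
        ds_config = json.load(f)

    per_device_batch = 2   # --per_device_train_batch_size
    grad_accum = 1         # --gradient_accumulation_steps
    world_size = 1         # assumed single GPU, for illustration only

    ds_config["fp16"]["enabled"] = True   # --fp16 True
    ds_config["bf16"]["enabled"] = False
    ds_config["train_micro_batch_size_per_gpu"] = per_device_batch
    ds_config["gradient_accumulation_steps"] = grad_accum
    ds_config["train_batch_size"] = per_device_batch * grad_accum * world_size

    print(json.dumps(ds_config, indent=2))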
17 changes: 17 additions & 0 deletions configs/ds_config_vis_chatbot.json
@@ -0,0 +1,17 @@
{
"fp16": {
"enabled": false
},
"bf16": {
"enabled": false
},
"comms_logger": {
"enabled": false,
"verbose": false,
"prof_all": false,
"debug": false
},
"steps_per_print": 20000000000000000,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
}
88 changes: 88 additions & 0 deletions examples/finetune_multi_modal.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
# FIXME should merge with finetune.py
"""A one-line summary of the module or program, terminated by a period.

Leave one blank line. The rest of this docstring should contain an
overall description of the module or program. Optionally, it may also
contain a brief description of exported classes and functions and/or usage
examples.

Typical usage example:

foo = ClassFoo()
bar = foo.FunctionBar()
"""

import sys
import os
# Remove the script's own directory from sys.path so that `lmflow` resolves
# to the installed package rather than local files with the same name.
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
from transformers import HfArgumentParser

from lmflow.args import (
VisModelArguments,
MultiModalDatasetArguments,
AutoArguments,
)

from lmflow.datasets.dataset import Dataset
from lmflow.models.auto_model import AutoModel
from lmflow.pipeline.auto_pipeline import AutoPipeline

from lmflow.models.vision2seq_model import CustomAutoVision2SeqModel
from lmflow.models.vision_encoder import build_vision_tower
from lmflow.datasets.multi_modal_dataset import DataCollatorForSupervisedDataset
from torch.utils.data import DataLoader


def main():
# Parses arguments
pipeline_name = "finetuner"
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

parser = HfArgumentParser((VisModelArguments, MultiModalDatasetArguments, PipelineArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

# Initialization
finetuner = AutoPipeline.get_pipeline(
pipeline_name=pipeline_name,
model_args=model_args,
data_args=data_args,
pipeline_args=pipeline_args,
)
# Do not register deepspeed in the model; the with_deepspeed flag may be
# removed by modifying the tune strategy in the future.
model = AutoModel.get_model(model_args, tune_strategy='none',
ds_config=pipeline_args.deepspeed,
custom_model=True,
with_deepspeed=False)
# FIXME check if need to move this part to hf_encoder_decoder.py
for param in model.backend_model.parameters():
param.requires_grad = False
if "language_projection" in pipeline_args.finetune_part:
for param in model.backend_model.language_projection.parameters():
param.requires_grad = True
if "language_model" in pipeline_args.finetune_part:
for param in model.backend_model.language_model.parameters():
param.requires_grad = True
if "vision_model" in pipeline_args.finetune_part:
for param in model.backend_model.vision_model.parameters():
param.requires_grad = True

dataset = Dataset(data_args, backend="custom_multi_modal")
data_collator = DataCollatorForSupervisedDataset(tokenizer=model.tokenizer)

# Finetuning
tuned_model = finetuner.tune(
model=model, dataset=dataset, data_collator=data_collator)


if __name__ == '__main__':
main()
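The freeze-then-unfreeze loop above is the whole stage-selection mechanism: every parameter is frozen first, then only the sub-modules named in --finetune_part get requires_grad = True. A small sanity check, assuming only that the backend model is a torch nn.Module (the helper below is illustrative, not part of the PR):

    import torch.nn as nn

    def count_trainable(model: nn.Module) -> tuple[int, int]:
        # Returns (trainable, total) parameter counts.
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        return trainable, total

Calling count_trainable(model.backend_model) after the loop confirms that, e.g., with --finetune_part language_projection only the projection layer remains trainable.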
50 changes: 37 additions & 13 deletions examples/vis_chatbot.py
Collaborator Author:
The reason for checking whether to use deepspeed is that loading the model in 8-bit while using deepspeed raises an error.
huggingface/transformers#24540

Contributor:
No problem. Thanks!
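For context, the guard this flag enables looks roughly like the sketch below. The load_in_8bit / with_deepspeed names mirror the PR's arguments, but the loading logic itself is an assumption, not LMFlow's internals:

    import deepspeed
    from transformers import AutoModelForVision2Seq

    def load_chatbot_model(name, load_in_8bit=False, with_deepspeed=True,
                           ds_config=None):
        model = AutoModelForVision2Seq.from_pretrained(
            name, load_in_8bit=load_in_8bit)
        if with_deepspeed and not load_in_8bit:
            # Initializing DeepSpeed on an 8-bit quantized model raises an
            # error (see huggingface/transformers#24540), so callers pass
            # --with_deepspeed False when loading in 8-bit.
            model = deepspeed.init_inference(model, config=ds_config)
        return model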

@@ -64,7 +64,7 @@ class ChatbotArguments:
)
}
)
prompt_format: Optional[str] = field(
chatbot_type: Optional[str] = field(
default="None",
metadata={
"help": (
@@ -80,7 +80,12 @@ class ChatbotArguments:
"help": "whether to do the stream inference"
}
)

with_deepspeed: Optional[bool] = field(
default=True,
metadata={
"help": "whether to use deepspeed"
}
)

def main():
pipeline_name = "inferencer"
@@ -104,10 +109,11 @@ def main():
ds_config=ds_config,
device=pipeline_args.device,
custom_model=model_args.custom_model,
with_deepspeed=chatbot_args.with_deepspeed,
)

data_args = DatasetArguments(dataset_path=None)
dataset = Dataset(data_args)
dataset = Dataset(data_args, backend="dict")

inferencer = AutoPipeline.get_pipeline(
pipeline_name=pipeline_name,
@@ -140,13 +146,21 @@ def main():
# " unconditionally."
# )

sep = "###"

end_string = chatbot_args.end_string
if chatbot_args.prompt_format == "mini_gpt":
if chatbot_args.chatbot_type == "mini_gpt":
context = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
user_name = "Human"
sep = "###"

elif chatbot_args.chatbot_type == "llava":
context = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
user_name = "USER"
sep = " "
else:
context = ""
user_name = ""
sep = "###"
prompt_structure = chatbot_args.prompt_structure

# Load image and input text for reasoning
@@ -161,8 +175,10 @@
input_text = chatbot_args.input_text
if chatbot_args.task == "image_caption" and len(input_text) == 0:
input_text = "a photography of"
if chatbot_args.prompt_format == "mini_gpt":
context += sep + "Human: " + "<Img><ImageHere></Img> "
if chatbot_args.chatbot_type == "mini_gpt":
context += sep + user_name + ": " + "<Img><ImageHere></Img> "
elif chatbot_args.chatbot_type == "llava":
context += sep + user_name + ": " + "<image>\n"

# This flag determines whether we still need to add the "###Human:" prompt;
# if text follows an image load, the prompt is added while loading the image.
@@ -179,7 +195,7 @@ def main():
input_dataset = dataset.from_dict({
"type": "image_text",
"instances": [{"images": np.stack(image_list),
"text": input_text,}]
"text": input_text,}]
})
output = inferencer.inference(model, input_dataset)
print(output.backend_dataset['text'])
@@ -200,7 +216,12 @@ def main():
# batch of image with different shape
raw_image = raw_image.resize(base_size)
image_list.append(np.array(raw_image))
context += sep + "Human: " + "<Img><ImageHere></Img> "
if chatbot_args.chatbot_type == "mini_gpt":
context += sep + user_name + ": " + "<Img><ImageHere></Img> "
elif chatbot_args.chatbot_type == "llava":
context += sep + user_name + ": " + "<image>\n"
else:
raise NotImplementedError
text_after_loading_image = True
print("Finish loading image with path {}".format(image_path))
continue
@@ -213,8 +234,7 @@
continue

if text_after_loading_image is False:
if chatbot_args.prompt_format == "mini_gpt":
context += sep + "Human: "
context += sep + user_name + ": "
else:
text_after_loading_image = False

@@ -229,14 +249,18 @@
"instances": [{"images": np.stack(image_list),
"text": context,}]
})
remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
if chatbot_args.chatbot_type in ["mini_gpt", "llava"]:
remove_image_flag = True
else:
remove_image_flag = False
begin_time = time.time()
if not chatbot_args.stream_inference:
# directly inference the results
output_dataset = inferencer.inference(
model,
input_dataset,
remove_image_flag=remove_image_flag)
remove_image_flag=remove_image_flag,
chatbot_type=chatbot_args.chatbot_type,)
response = output_dataset.backend_dataset['text']
print(response[0])
print("\n", end="")
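Taken together, the vis_chatbot.py changes replace the hard-coded "###Human:" prompt with a per-chatbot-type context, user name, separator, and image tag. A condensed sketch of the branching (template strings copied from the diff; the helper function itself is illustrative, not the PR's code):

    def prompt_pieces(chatbot_type):
        # Returns (context, user_name, sep, image_tag) for a chatbot type.
        if chatbot_type == "mini_gpt":
            context = ("Give the following image: <Img>ImageContent</Img>. "
                       "You will be able to see the image once I provide it "
                       "to you. Please answer my questions.")
            return context, "Human", "###", "<Img><ImageHere></Img> "
        if chatbot_type == "llava":
            context = ("A chat between a curious human and an artificial "
                       "intelligence assistant. The assistant gives helpful, "
                       "detailed, and polite answers to the human's questions.")
            return context, "USER", " ", "<image>\n"
        return "", "", "###", ""

    context, user_name, sep, image_tag = prompt_pieces("llava")
    context += sep + user_name + ": " + image_tag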
8 changes: 4 additions & 4 deletions examples/vis_chatbot_gradio.py
@@ -95,7 +95,7 @@ class ChatbotArguments:
"help": "task for reasoning",
}
)
prompt_format: Optional[str] = field(
chatbot_format: Optional[str] = field(
default="None",
metadata={
"help": "prompt format"
@@ -122,7 +122,7 @@ def upload_image(image_file, history, text_input, chat_state, image_list):
history = history + [((image_file.name,), None)]

if chat_state is None:
if chatbot_args.prompt_format == "mini_gpt":
if chatbot_args.chatbot_format == "mini_gpt":
chat_state = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
else:
chat_state = ''
@@ -134,7 +134,7 @@ def upload_image(image_file, history, text_input, chat_state, image_list):
else:
image_list.append(image.resize(image_list[0].size))

if chatbot_args.prompt_format == "mini_gpt":
if chatbot_args.chatbot_format == "mini_gpt":
chat_state += "### Human: " + "<Img><ImageHere></Img>"
return (
gr.update(interactive=True, placeholder='Enter text and press enter, or upload an image'),
@@ -170,7 +170,7 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0)
"instances": [{"images": np.stack([np.array(i) for i in image_list]),
"text": chat_state}]
})
remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
remove_image_flag = chatbot_args.chatbot_format=="mini_gpt"

chatbot[-1][1] = ''

14 changes: 0 additions & 14 deletions scripts/.nfs0000000094418362000004c4

This file was deleted.

75 changes: 75 additions & 0 deletions scripts/run_finetune_multi_modal_stage1.sh
@@ -0,0 +1,75 @@
#!/bin/bash
# Please run this script under ${project_id} in project directory of
# https://github.com/shizhediao/llm-ft
# COMMIT: d5fecf30ba8011067b10cf51fede53a5ab6574e4

# Parses arguments
model_name_or_path=Salesforce/blip2-flan-t5-xxl
dataset_path=/path/to/cc3m_595k.json
image_folder=/path/to/images
output_dir=output_models/finetune
deepspeed_args="--master_port=12000"

while [[ $# -ge 1 ]]; do
key="$1"
case ${key} in
-m|--model_name_or_path)
model_name_or_path="$2"
shift
;;
-d|--dataset_path)
dataset_path="$2"
shift
;;
-o|--output_model_path)
output_dir="$2"
shift
;;
--deepspeed_args)
deepspeed_args="$2"
shift
;;
*)
echo "error: unknown option \"${key}\"" 1>&2
exit 1
esac
shift
done

# Finetune
exp_id=finetune
project_dir=$(cd "$(dirname $0)"/..; pwd)
log_dir=${project_dir}/log/${exp_id}
mkdir -p ${output_dir} ${log_dir}

deepspeed ${deepspeed_args} \
examples/finetune_multi_modal.py \
--deepspeed configs/ds_config_multimodal.json \
--arch_type vision_encoder_decoder \
--llava_loading True \
--model_name_or_path ${model_name_or_path} \
--image_encoder_name_or_path openai/clip-vit-large-patch14 \
--dataset_path ${dataset_path} \
--output_dir ${output_dir} --overwrite_output_dir \
--image_folder ${image_folder} \
--custom_vision_model True \
--llm_model_name_or_path lmsys/vicuna-7b-v1.5 \
--image_aspect_ratio None \
--fp16 True \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 2 \
--learning_rate 2e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--run_name finetune \
--validation_split_percentage 0 \
--logging_steps 20 \
--do_train \
--ddp_timeout 72000 \
--save_steps 5000 \
--dataloader_num_workers 1 \
--num_train_epochs 1 \
| tee ${log_dir}/train.log \
2> ${log_dir}/train.err