Merge pull request #628 from OptimalScale/lianqing/multi_modal_training

[Features] Support multi_modal training

research4pan committed Sep 6, 2023
2 parents 0833606 + c529723, commit bc569db
Showing 30 changed files with 2,120 additions and 215 deletions.
30 changes: 18 additions & 12 deletions configs/ds_config_multimodal.json
@@ -1,17 +1,23 @@
 {
     "fp16": {
-        "enabled": false
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "min_loss_scale": 1
     },
     "bf16": {
-        "enabled": false
+        "enabled": "auto"
     },
-    "comms_logger": {
-        "enabled": false,
-        "verbose": false,
-        "prof_all": false,
-        "debug": false
-    },
-    "steps_per_print": 20000000000000000,
-    "train_micro_batch_size_per_gpu": 1,
-    "wall_clock_breakdown": false
-}
+    "train_micro_batch_size_per_gpu": "auto",
+    "train_batch_size": "auto",
+    "gradient_accumulation_steps": "auto",
+    "zero_optimization": {
+        "stage": 2,
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto"
+    }
+}
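The "auto" entries defer these settings to the launcher: with HuggingFace's Trainer, each "auto" is resolved from the corresponding TrainingArguments field at startup, and ZeRO stage 2 shards optimizer states and gradients across data-parallel ranks. A minimal sketch of the consuming side (output_dir and the batch values here are illustrative, not taken from this PR):

# Sketch: HF Trainer fills each "auto" in the DeepSpeed JSON from these
# arguments, e.g. bf16=True resolves "bf16.enabled", and the batch fields
# resolve "train_micro_batch_size_per_gpu" and friends.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="output_models/finetune_multimodal",  # illustrative path
    per_device_train_batch_size=1,   # -> "train_micro_batch_size_per_gpu"
    gradient_accumulation_steps=4,   # -> "gradient_accumulation_steps"
    bf16=True,                       # -> "bf16.enabled"
    deepspeed="configs/ds_config_multimodal.json",
)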
17 changes: 17 additions & 0 deletions configs/ds_config_vis_chatbot.json
@@ -0,0 +1,17 @@
{
    "fp16": {
        "enabled": false
    },
    "bf16": {
        "enabled": false
    },
    "comms_logger": {
        "enabled": false,
        "verbose": false,
        "prof_all": false,
        "debug": false
    },
    "steps_per_print": 20000000000000000,
    "train_micro_batch_size_per_gpu": 1,
    "wall_clock_breakdown": false
}
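Unlike the training config above, this inference config pins everything: mixed precision off, micro batch size 1, no ZeRO, and a steps_per_print large enough that DeepSpeed never interrupts an interactive chat with throughput logs. A quick structural comparison (a sketch, run from the configs' parent directory):

# Compare the two DeepSpeed configs touched by this PR.
import json

with open("configs/ds_config_multimodal.json") as f:
    train_cfg = json.load(f)
with open("configs/ds_config_vis_chatbot.json") as f:
    infer_cfg = json.load(f)

# Training-only knobs: ZeRO plus the Trainer-resolved batch settings.
print(sorted(train_cfg.keys() - infer_cfg.keys()))
# ['gradient_accumulation_steps', 'train_batch_size', 'zero_optimization']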
44 changes: 44 additions & 0 deletions data/download.sh
@@ -183,6 +183,50 @@ function main() {
         tar zxvf ${filename}
         rm ${filename}
     fi
+
+    # multimodal
+    if [ "$1" = "coco2017" -o "$1" = "all" ]; then
+        echo "downloading coco 2017 dataset for multimodal finetuning"
+        mkdir coco2017
+        cd coco2017
+        wget "http://images.cocodataset.org/zips/train2017.zip"
+        wget "http://images.cocodataset.org/zips/val2017.zip"
+        wget "http://images.cocodataset.org/zips/test2017.zip"
+        unzip train2017.zip
+        unzip val2017.zip
+        unzip test2017.zip
+        rm train2017.zip
+        rm val2017.zip
+        rm test2017.zip
+        cd ../
+    fi
+
+    if [ "$1" = "llava_instruction_finetune_80k" -o "$1" = "all" ]; then
+        echo "downloading llava instruction finetune dataset with 80k conversations"
+        python ../utils/download_hf_file.py \
+            --repo_id liuhaotian/LLaVA-Instruct-150K \
+            --filename llava_instruct_80k.json
+    fi
+
+    if [ "$1" = "llava_cc3m_pretrain_595k" -o "$1" = "all" ]; then
+        echo "downloading llava pretrain images"
+        filepath="llava_cc3m_pretrain_595k"
+        python ../utils/download_hf_file.py \
+            --repo_id liuhaotian/LLaVA-CC3M-Pretrain-595K \
+            --filename images.zip \
+            --target_path ${filepath}
+
+        python ../utils/download_hf_file.py \
+            --repo_id liuhaotian/LLaVA-CC3M-Pretrain-595K \
+            --filename chat.json \
+            --target_path ${filepath}
+
+        cd ${filepath}
+        unzip images.zip
+        rm -rf images.zip
+        cd ../
+    fi
 }
 main "$@"
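utils/download_hf_file.py itself is not shown in this diff; a minimal equivalent built on huggingface_hub (an assumption about its implementation, not the PR's actual code) would be:

# Hypothetical stand-in for utils/download_hf_file.py: fetch one file
# from a HuggingFace dataset repo into a target directory.
import argparse
from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser()
parser.add_argument("--repo_id", required=True)
parser.add_argument("--filename", required=True)
parser.add_argument("--target_path", default=".")
args = parser.parse_args()

hf_hub_download(
    repo_id=args.repo_id,
    filename=args.filename,
    repo_type="dataset",        # the LLaVA files live in dataset repos
    local_dir=args.target_path,
)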
88 changes: 88 additions & 0 deletions examples/finetune_multi_modal.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
# FIXME should merge with finetune.py
"""Finetune a multi-modal (vision + language) model.

Freezes the backbone, selectively unfreezes the components named in
--finetune_part (language_projection, language_model, and/or vision_model),
and then runs the standard finetuner pipeline on a multi-modal dataset.
"""

import sys
import os
# Remove the script's own directory from sys.path so that `lmflow`
# resolves to the installed package rather than a local file.
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
from transformers import HfArgumentParser

from lmflow.args import (
    VisModelArguments,
    MultiModalDatasetArguments,
    AutoArguments,
)

from lmflow.datasets.dataset import Dataset
from lmflow.models.auto_model import AutoModel
from lmflow.pipeline.auto_pipeline import AutoPipeline

from lmflow.models.vision2seq_model import CustomAutoVision2SeqModel
from lmflow.models.vision_encoder import build_vision_tower
from lmflow.datasets.multi_modal_dataset import DataCollatorForSupervisedDataset
from torch.utils.data import DataLoader


def main():
    # Parses arguments
    pipeline_name = "finetuner"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((VisModelArguments, MultiModalDatasetArguments, PipelineArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a
        # json file, let's parse it to get our arguments.
        model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses()

    # Initialization
    finetuner = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
    )
    # Do not register deepspeed in the model; the with_deepspeed flag may be
    # removed by modifying the tune strategy in the future.
    model = AutoModel.get_model(model_args, tune_strategy='none',
                                ds_config=pipeline_args.deepspeed,
                                custom_model=True,
                                with_deepspeed=False)
    # FIXME check if need to move this part to hf_encoder_decoder.py
    for param in model.backend_model.parameters():
        param.requires_grad = False
    if "language_projection" in pipeline_args.finetune_part:
        for param in model.backend_model.language_projection.parameters():
            param.requires_grad = True
    if "language_model" in pipeline_args.finetune_part:
        for param in model.backend_model.language_model.parameters():
            param.requires_grad = True
    if "vision_model" in pipeline_args.finetune_part:
        for param in model.backend_model.vision_model.parameters():
            param.requires_grad = True

    dataset = Dataset(data_args, backend="custom_multi_modal")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=model.tokenizer)

    # Finetuning
    tuned_model = finetuner.tune(
        model=model, dataset=dataset, data_collator=data_collator)


if __name__ == '__main__':
    main()
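The freezing block above hard-codes three component names. A generic helper, a sketch under the assumption that each component is an attribute of the backend model (not part of this PR), expresses the same idea reusably:

import torch.nn as nn

def freeze_except(model: nn.Module, trainable_parts: list) -> None:
    """Freeze all parameters, then re-enable those of the submodules
    named in trainable_parts, e.g. ["language_projection"]."""
    for param in model.parameters():
        param.requires_grad = False
    for name in trainable_parts:
        for param in getattr(model, name).parameters():
            param.requires_grad = True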
50 changes: 37 additions & 13 deletions examples/vis_chatbot.py
@@ -64,7 +64,7 @@ class ChatbotArguments:
             )
         }
     )
-    prompt_format: Optional[str] = field(
+    chatbot_type: Optional[str] = field(
         default="None",
         metadata={
             "help": (
@@ -80,7 +80,12 @@ class ChatbotArguments:
             "help": "whether to do the stream inference"
         }
     )
-
+    with_deepspeed: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "whether to use deepspeed"
+        }
+    )
 
 def main():
@@ -104,10 +104,11 @@ def main():
         ds_config=ds_config,
         device=pipeline_args.device,
         custom_model=model_args.custom_model,
+        with_deepspeed=chatbot_args.with_deepspeed,
     )
 
     data_args = DatasetArguments(dataset_path=None)
-    dataset = Dataset(data_args)
+    dataset = Dataset(data_args, backend="dict")
 
     inferencer = AutoPipeline.get_pipeline(
         pipeline_name=pipeline_name,
@@ -140,13 +146,21 @@ def main():
     #     " unconditionally."
     #     )
 
-    sep = "###"
-
     end_string = chatbot_args.end_string
-    if chatbot_args.prompt_format == "mini_gpt":
+    if chatbot_args.chatbot_type == "mini_gpt":
         context = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
+        user_name = "Human"
+        sep = "###"
+
+    elif chatbot_args.chatbot_type == "llava":
+        context = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+        user_name = "USER"
+        sep = " "
     else:
         context = ""
+        user_name = ""
+        sep = "###"
     prompt_structure = chatbot_args.prompt_structure
@@ -161,8 +175,10 @@ def main():
     input_text = chatbot_args.input_text
     if chatbot_args.task == "image_caption" and len(input_text) == 0:
         input_text = "a photography of"
-    if chatbot_args.prompt_format == "mini_gpt":
-        context += sep + "Human: " + "<Img><ImageHere></Img> "
+    if chatbot_args.chatbot_type == "mini_gpt":
+        context += sep + user_name + ": " + "<Img><ImageHere></Img> "
+    elif chatbot_args.chatbot_type == "llava":
+        context += sep + user_name + ": " + "<image>\n"
 
     # this flag is for determining if we need to add the ###Human: prompt
     # if text after loading image, we add it when loading image
@@ -179,7 +195,7 @@ def main():
         input_dataset = dataset.from_dict({
             "type": "image_text",
             "instances": [{"images": np.stack(image_list),
-                           "text": input_text,}]
+                            "text": input_text,}]
         })
         output = inferencer.inference(model, input_dataset)
         print(output.backend_dataset['text'])
@@ -200,7 +216,12 @@ def main():
                 # batch of image with different shape
                 raw_image = raw_image.resize(base_size)
             image_list.append(np.array(raw_image))
-            context += sep + "Human: " + "<Img><ImageHere></Img> "
+            if chatbot_args.chatbot_type == "mini_gpt":
+                context += sep + user_name + ": " + "<Img><ImageHere></Img> "
+            elif chatbot_args.chatbot_type == "llava":
+                context += sep + user_name + ": " + "<image>\n"
+            else:
+                raise NotImplementedError
             text_after_loading_image = True
             print("Finish loading image with path {}".format(image_path))
             continue
@@ -213,8 +234,7 @@ def main():
             continue
 
         if text_after_loading_image is False:
-            if chatbot_args.prompt_format == "mini_gpt":
-                context += sep + "Human: "
+            context += sep + user_name + ": "
         else:
             text_after_loading_image = False
@@ -229,14 +249,18 @@ def main():
             "instances": [{"images": np.stack(image_list),
                            "text": context,}]
         })
-        remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
+        if chatbot_args.chatbot_type in ["mini_gpt", "llava"]:
+            remove_image_flag = True
+        else:
+            remove_image_flag = False
         begin_time = time.time()
         if not chatbot_args.stream_inference:
             # directly inference the results
             output_dataset = inferencer.inference(
                 model,
                 input_dataset,
-                remove_image_flag=remove_image_flag)
+                remove_image_flag=remove_image_flag,
+                chatbot_type=chatbot_args.chatbot_type,)
             response = output_dataset.backend_dataset['text']
             print(response[0])
             print("\n", end="")
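For reference, the two chatbot_type values build differently shaped prompts. A small illustration (not code from the PR; the separators, role names, and image placeholders mirror the diff above):

# Illustrative prompt assembly for one image followed by one question.
SEP = {"mini_gpt": "###", "llava": " "}
USER = {"mini_gpt": "Human", "llava": "USER"}
IMAGE_TOKEN = {"mini_gpt": "<Img><ImageHere></Img> ", "llava": "<image>\n"}

def first_turn(chatbot_type: str, question: str) -> str:
    return (SEP[chatbot_type] + USER[chatbot_type] + ": "
            + IMAGE_TOKEN[chatbot_type] + question)

print(first_turn("mini_gpt", "What is in the photo?"))
# ###Human: <Img><ImageHere></Img> What is in the photo?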
8 changes: 4 additions & 4 deletions examples/vis_chatbot_gradio.py
@@ -95,7 +95,7 @@ class ChatbotArguments:
             "help": "task for reasoning",
         }
     )
-    prompt_format: Optional[str] = field(
+    chatbot_format: Optional[str] = field(
         default="None",
         metadata={
             "help": "prompt format"
@@ -122,7 +122,7 @@ def upload_image(image_file, history, text_input, chat_state, image_list):
     history = history + [((image_file.name,), None)]
 
     if chat_state is None:
-        if chatbot_args.prompt_format == "mini_gpt":
+        if chatbot_args.chatbot_format == "mini_gpt":
             chat_state = "Give the following image: <Img>ImageContent</Img>. " + "You will be able to see the image once I provide it to you. Please answer my questions."
         else:
             chat_state = ''
@@ -134,7 +134,7 @@ def upload_image(image_file, history, text_input, chat_state, image_list):
     else:
         image_list.append(image.resize(image_list[0].size))
 
-    if chatbot_args.prompt_format == "mini_gpt":
+    if chatbot_args.chatbot_format == "mini_gpt":
         chat_state += "### Human: " + "<Img><ImageHere></Img>"
     return (
         gr.update(interactive=True, placeholder='Enter text and press enter, or upload an image'),
@@ -170,7 +170,7 @@ def gradio_answer(chatbot, chat_state, image_list, num_beams=1, temperature=1.0):
         "instances": [{"images": np.stack([np.array(i) for i in image_list]),
                        "text": chat_state}]
     })
-    remove_image_flag = chatbot_args.prompt_format=="mini_gpt"
+    remove_image_flag = chatbot_args.chatbot_format=="mini_gpt"
 
     chatbot[-1][1] = ''
(Diffs for the remaining 24 changed files are not shown.)
