-
Notifications
You must be signed in to change notification settings - Fork 819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Features] Support multi_modal training #628
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
4dc5fb8
support multi_modal training
lianqing11 588679d
support llava inference
lianqing11 68f8639
update second stage finetune
lianqing11 5fd7bbd
polish the code style and modify the dataset and model path
lianqing11 ab9598c
update script and the path to download the model
lianqing11 ddfa884
fix link in download llava dataset
lianqing11 76b01aa
add the script for downloading the dataset
lianqing11 a518629
update the num gpu-1 command for inference
lianqing11 10ac23e
update downloading script
lianqing11 7e93a25
update downloading dataset for multimodal
lianqing11 0ffa432
update the link to constans
lianqing11 a91bd3e
remove cd .., modify llava inference script
lianqing11 1bef01d
modify input arg
lianqing11 c529723
modify vicuna model
lianqing11 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,23 @@ | ||
{ | ||
"fp16": { | ||
"enabled": false | ||
"enabled": "auto", | ||
"loss_scale": 0, | ||
"loss_scale_window": 1000, | ||
"initial_scale_power": 16, | ||
"hysteresis": 2, | ||
"min_loss_scale": 1 | ||
}, | ||
"bf16": { | ||
"enabled": false | ||
"enabled": "auto" | ||
}, | ||
"comms_logger": { | ||
"enabled": false, | ||
"verbose": false, | ||
"prof_all": false, | ||
"debug": false | ||
}, | ||
"steps_per_print": 20000000000000000, | ||
"train_micro_batch_size_per_gpu": 1, | ||
"wall_clock_breakdown": false | ||
} | ||
"train_micro_batch_size_per_gpu": "auto", | ||
"train_batch_size": "auto", | ||
"gradient_accumulation_steps": "auto", | ||
"zero_optimization": { | ||
"stage": 2, | ||
"overlap_comm": true, | ||
"contiguous_gradients": true, | ||
"sub_group_size": 1e9, | ||
"reduce_bucket_size": "auto" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
{ | ||
"fp16": { | ||
"enabled": false | ||
}, | ||
"bf16": { | ||
"enabled": false | ||
}, | ||
"comms_logger": { | ||
"enabled": false, | ||
"verbose": false, | ||
"prof_all": false, | ||
"debug": false | ||
}, | ||
"steps_per_print": 20000000000000000, | ||
"train_micro_batch_size_per_gpu": 1, | ||
"wall_clock_breakdown": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved. | ||
# FIXME should merge with finetune.py | ||
"""A one-line summary of the module or program, terminated by a period. | ||
|
||
Leave one blank line. The rest of this docstring should contain an | ||
overall description of the module or program. Optionally, it may also | ||
contain a brief description of exported classes and functions and/or usage | ||
examples. | ||
|
||
Typical usage example: | ||
|
||
foo = ClassFoo() | ||
bar = foo.FunctionBar() | ||
""" | ||
|
||
import sys | ||
import os | ||
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0]))) | ||
from transformers import HfArgumentParser | ||
|
||
from lmflow.args import ( | ||
VisModelArguments, | ||
MultiModalDatasetArguments, | ||
AutoArguments, | ||
) | ||
|
||
from lmflow.datasets.dataset import Dataset | ||
from lmflow.models.auto_model import AutoModel | ||
from lmflow.pipeline.auto_pipeline import AutoPipeline | ||
|
||
from lmflow.models.vision2seq_model import CustomAutoVision2SeqModel | ||
from lmflow.models.vision_encoder import build_vision_tower | ||
from lmflow.datasets.multi_modal_dataset import DataCollatorForSupervisedDataset | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
def main(): | ||
# Parses arguments | ||
pipeline_name = "finetuner" | ||
PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) | ||
|
||
parser = HfArgumentParser((VisModelArguments, MultiModalDatasetArguments, PipelineArguments)) | ||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): | ||
# If we pass only one argument to the script and it's the path to a json file, | ||
# let's parse it to get our arguments. | ||
model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) | ||
else: | ||
model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() | ||
|
||
# Initialization | ||
finetuner = AutoPipeline.get_pipeline( | ||
pipeline_name=pipeline_name, | ||
model_args=model_args, | ||
data_args=data_args, | ||
pipeline_args=pipeline_args, | ||
) | ||
# do not resiger deepspeed in the model. | ||
# with_deepspeed flag may be removed | ||
# by modifying the tune strategy in the future. | ||
model = AutoModel.get_model(model_args, tune_strategy='none', | ||
ds_config=pipeline_args.deepspeed, | ||
custom_model=True, | ||
with_deepspeed=False) | ||
# FIXME check if need to move this part to hf_encoder_decoder.py | ||
for param in model.backend_model.parameters(): | ||
param.requires_grad = False | ||
if "language_projection" in pipeline_args.finetune_part: | ||
for param in model.backend_model.language_projection.parameters(): | ||
param.requires_grad = True | ||
if "language_model" in pipeline_args.finetune_part: | ||
for param in model.backend_model.language_model.parameters(): | ||
param.requires_grad = True | ||
if "vision_model" in pipeline_args.finetune_part: | ||
for param in model.backend_model.vision_model.parameters(): | ||
param.requires_grad = True | ||
|
||
dataset = Dataset(data_args, backend="custom_multi_modal") | ||
data_collator = DataCollatorForSupervisedDataset(tokenizer=model.tokenizer) | ||
|
||
# Finetuning | ||
tuned_model = finetuner.tune( | ||
model=model, dataset=dataset, data_collator=data_collator) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason to check whether to use deepspeed is that when loading the model with 8bit, using deepspeed would raise an error.
huggingface/transformers#24540
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No problem. Thanks!