If you are not using Linux, do NOT proceed; see the instructions for macOS and Windows.
- Clone this repository and navigate to the FIRE folder
```bash
git clone https://github.com/MM-FIRE/FIRE
cd FIRE
```
- Install the package
```bash
conda create -n llava python=3.10 -y
conda activate llava
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
```
- Install additional packages for training cases
```bash
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
```
- Upgrade to the latest code base
```bash
git pull
pip install -e .

# if you see some import errors when you upgrade,
# please try running the command below (without #)
# pip install flash-attn --no-build-isolation --no-cache-dir
```
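A quick way to confirm the environment works is to import the core dependencies (a minimal sanity check, assuming the package is importable as `llava`, as in upstream LLaVA; `flash-attn` is only required for training):
```bash
# Sanity-check the environment; all three imports should succeed without errors.
conda activate llava
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import llava; print('llava import OK')"
python -c "import flash_attn; print('flash-attn import OK')"   # only needed for training
```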
The FIRE-100K, FIRE-1M, and FIRE-Bench datasets can be accessed on the Dataset page.
The checkpoints of FIRE-LLaVA can be accessed on the Model page.
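If the datasets and checkpoints are hosted on the Hugging Face Hub, they can be fetched with `huggingface-cli` once you know the repository IDs from the pages above (the IDs and local paths below are placeholders, not the real ones):
```bash
# Placeholder repository IDs: substitute the actual IDs from the Dataset and Model pages.
huggingface-cli download <dataset-repo-id> --repo-type dataset --local-dir data/FIRE-100K
huggingface-cli download <model-repo-id> --local-dir checkpoints/FIRE-LLaVA
```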
We used DeepSpeed ZeRO-3 to train our models.

Train the student model:
```bash
deepspeed --master_port 60000 llava/train/train_mem.py \
--lora_enable True --lora_r 64 --lora_alpha 256 \
--lora_modules q_proj,k_proj \
--deepspeed ./scripts/zero3.json \
--model_name_or_path Lin-Chen/open-llava-next-llama3-8b \
--version llama_v3_student \
--data_path data/path/to/FIRE-Dataset-Student \
--image_folder data/path/to/images \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio anyres \
--mm_patch_merge_type spatial_unpad \
--group_by_modality_length True \
--bf16 True \
--output_dir ./checkpoints/llava-next-llama-3-8b-student-lora-merged \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 5000 \
--save_total_limit 1 \
--learning_rate 2e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 3072 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
```

Train the teacher model:
```bash
deepspeed --master_port 60001 llava/train/train_mem.py \
--lora_enable True --lora_r 64 --lora_alpha 256 \
--lora_modules q_proj,k_proj \
--deepspeed ./scripts/zero3.json \
--model_name_or_path Lin-Chen/open-llava-next-llama3-8b \
--version llama_v3_teacher \
--data_path data/path/to/FIRE-Dataset-Teacher \
--image_folder data/path/to/images \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio anyres \
--mm_patch_merge_type spatial_unpad \
--group_by_modality_length True \
--bf16 True \
--output_dir ./checkpoints/llava-next-llama-3-8b-teacher-lora-merged \
--num_train_epochs 1 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 5 \
--learning_rate 2e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 3072 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--val_logging_steps 3000 \
--report_to wandb
```
Training the student and teacher models takes 16 hours on 8x A100-80GB GPUs for every 1 million data points.
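Since the commands above train LoRA adapters, the adapter weights generally need to be merged into the base model before inference (unless the training script already does this for you). A minimal sketch, assuming this fork keeps upstream LLaVA's `scripts/merge_lora_weights.py` and that the checkpoint path matches the `--output_dir` used above:
```bash
# Merge the student LoRA adapter into the base model (paths are illustrative;
# adjust them to your actual checkpoint locations).
python scripts/merge_lora_weights.py \
    --model-path ./checkpoints/llava-next-llama-3-8b-student-lora-merged \
    --model-base Lin-Chen/open-llava-next-llama3-8b \
    --save-model-path ./checkpoints/llava-next-llama-3-8b-student-merged
```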
We follow exactly the same evaluation scripts provided by the LLaVA repo. Please refer to Evaluation.md.
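As an illustration, upstream LLaVA's evaluation scripts are launched as shown below (a sketch assuming this fork keeps the same `scripts/v1_5/eval` layout; see Evaluation.md for the exact commands and data preparation):
```bash
# Example: run the TextVQA evaluation script on a single GPU,
# assuming the upstream LLaVA script layout and pre-downloaded evaluation data.
CUDA_VISIBLE_DEVICES=0 bash scripts/v1_5/eval/textvqa.sh
```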
Coming soon
Coming soon
We built upon the following codebases. Thanks for their brilliant contributions to the community!