# LLaVA Training Scripts for SageMaker

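Create a SageMaker training script adapted from `LLaVA/scripts/v1_5/finetune_task.sh`.
According to LLaVA, the reference recipe uses a global batch size of `per_device_train_batch_size * gradient_accumulation_steps * number_of_devices = 128`. Note that the script below uses `per_device_train_batch_size=1` and `gradient_accumulation_steps=2`, i.e. a global batch of 2 × the number of GPUs (16 on an 8-GPU instance); raise `gradient_accumulation_steps` if you want to match the reference value.
This setting was tested on ml.p4d.24xlarge (8 × A100 40 GB); the estimator cell below currently targets ml.g6e.48xlarge.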

Upload the training data to S3

In [None]:
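# Uncomment to upload the local ./data/ directory. The destination must match the
# BUCKET/PREFIX that the session cell below uses to build `s3uri`.
# !aws s3 sync ./data/ s3://YOUR_S3_BUCKET/data/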

In [None]:
%%writefile LLaVA/finetune-llava-video.sh
#!/bin/bash
# Fine-tune LLaVA-Video-7B-Qwen2 on SageMaker with torchrun + DeepSpeed ZeRO-3
# (scripts/zero3.json). Only the language model is trained
# (--mm_tunable_parts="mm_language_model").
#
# Effective global batch size =
#   per_device_train_batch_size (1) * gradient_accumulation_steps (2) * GPUS_PER_NODE * NNODES
#
# job_id is passed in through the estimator's `environment` and names the
# checkpoint directory under /opt/ml/checkpoints/.
# NOTE(review): NODE_NUMBER / NODE_INDEX / SM_MASTER_ADDR are presumably exported by
# the entry point (start.py) — confirm. They fall back to a single-node run below.
export WANDB_MODE=offline

# Multi-node cluster information. Defaults keep torchrun from receiving empty
# values when the launcher does not export these variables.
MASTER_ADDR="${SM_MASTER_ADDR:-127.0.0.1}"
MASTER_PORT="23456"
NNODES="${NODE_NUMBER:-1}"
NODE_RANK="${NODE_INDEX:-0}"
GPUS_PER_NODE="$SM_NUM_GPUS"

echo "NNODES: ${NNODES}"
echo "NODE_RANK: ${NODE_RANK}"
echo "GPUS_PER_NODE: ${GPUS_PER_NODE}"
echo "job_id: ${job_id}"

# Base LLM / vision encoder identifiers (the *_CLEAN variants are used in run names).
LLM_VERSION="Qwen/Qwen2-7B-Instruct"
LLM_VERSION_CLEAN="Qwen2-7B-Instruct"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="siglip-so400m-patch14-384"

BASE_RUN_NAME="LLaVA-Video-7B-Qwen2"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"

# Conversation template and the checkpoint this fine-tune starts from.
PROMPT_VERSION="qwen_1_5"
MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9_aug17"
PREV_STAGE_CHECKPOINT="lmms-lab/LLaVA-Video-7B-Qwen2"
echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}"
echo "MID_RUN_NAME: ${MID_RUN_NAME}"

export AV_LOG_LEVEL=error  # Suppress FFmpeg info/warning messages
export PYTHONWARNINGS="ignore::UserWarning"  # Filter Python warnings

# Notes on the arguments below (comments cannot be placed inside the
# backslash-continued command itself):
# - To also train the vision tower and projector, switch to
#   --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model".
#   --mm_vision_tower_lr presumably only matters in that configuration.
# - NOTE(review): --model_max_length 22768 may be a typo for 32768 (the value used
#   by the upstream LLaVA-Video recipe) — confirm.
ACCELERATE_CPU_AFFINITY=1 torchrun \
    --nproc_per_node="${GPUS_PER_NODE}" \
    --nnodes="${NNODES}" \
    --node_rank="${NODE_RANK}" \
    --master_addr="${MASTER_ADDR}" \
    --master_port="${MASTER_PORT}" \
    llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path "${PREV_STAGE_CHECKPOINT}" \
    --version "${PROMPT_VERSION}" \
    --data_path /opt/ml/input/data/training/train_formatted.json \
    --image_folder /opt/ml/input/data/training \
    --video_folder /opt/ml/input/data/training \
    --mm_tunable_parts="mm_language_model" \
    --mm_vision_tower_lr=2e-6 \
    --vision_tower "${VISION_MODEL_VERSION}" \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres_max_9 \
    --image_grid_pinpoints "(1x1),...,(6x6)" \
    --mm_patch_merge_type spatial_unpad \
    --bf16 True \
    --run_name "${MID_RUN_NAME}" \
    --output_dir "/opt/ml/checkpoints/${job_id}" \
    --num_train_epochs 20 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 25 \
    --save_total_limit 10 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 22768 \
    --gradient_checkpointing True \
    --dataloader_num_workers 2 \
    --lazy_preprocess True \
    --torch_compile True \
    --torch_compile_backend "inductor" \
    --dataloader_drop_last True \
    --frames_upbound 32 \
    --mm_newline_position grid \
    --add_time_instruction True \
    --force_sample True \
    --mm_spatial_pool_stride 2


In [None]:
# Initialize sagemaker session and get the training data s3 uri
import json
import time
import boto3
import numpy as np
import sagemaker
import sagemaker.huggingface
import os

# IAM role and SageMaker session used by the estimator further down.
ROLE = sagemaker.get_execution_role()
sess = sagemaker.Session()
# NOTE(review): bucket name (and AWS account id) is hard-coded; sess.default_bucket()
# would resolve the caller's own default SageMaker bucket if this notebook is shared.
BUCKET = "sagemaker-us-west-2-452145973879"
PREFIX = "datasets/hualai-video/hualai_sft_data/"
# Build the S3 URI with plain string formatting: os.path.join uses the host OS
# path separator (a backslash on Windows), which is not a valid S3 key separator.
s3uri = f"s3://{BUCKET}/{PREFIX}"
print(f"sagemaker role arn: {ROLE}")
print(f"sagemaker bucket: {BUCKET}")
print(f"sagemaker session region: {sess.boto_region_name}")
print(f"data uri: {s3uri}")

In [None]:
# Create a unique training job id
from time import gmtime, strftime
# Unique training-job name: a fixed prefix plus the current UTC timestamp.
job_id = "llava-video-task-llm-{}".format(strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
print(job_id)

In [4]:
# Environment variables exported into the training container. The launch script
# echoes ${job_id} and uses it to name its output directory under /opt/ml/checkpoints/.
environment = dict(job_id=job_id)

# Metric definitions: SageMaker applies each regex to the training job's printed
# logs and sends the captured value to CloudWatch.
#
# The number pattern accepts integers, decimals and scientific notation
# ("12", "0.85", "9.5e-06"). The outer group wraps the whole number so the
# exponent is part of the captured value. Raw strings avoid invalid "\-"-style
# escape sequences.
_NUMBER_PATTERN = r"([-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?)"

metric_definitions = [
    # The leading quote in "'name'" prevents e.g. 'loss' from matching inside 'train_loss'.
    {"Name": name, "Regex": rf"'{name}': {_NUMBER_PATTERN}"}
    for name in (
        "loss",
        "learning_rate",
        "epoch",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
        "train_loss",
    )
]

In [5]:
# Point the training data to the s3 uri. Use FastFile to "mount" the s3 files directly instead of copying to local disk
from sagemaker.inputs import TrainingInput

# The "training" channel: everything under s3uri, streamed via FastFile rather
# than copied to local disk before training starts.
training_input = TrainingInput(
    s3_data=s3uri,
    s3_data_type="S3Prefix",  # options: S3Prefix | ManifestFile | AugmentedManifestFile
    distribution="FullyReplicated",  # options: FullyReplicated | ShardedByS3Key
    input_mode="FastFile",
)

In [None]:
from sagemaker.huggingface import HuggingFace

# Custom training image (the commented line is the stock AWS PyTorch DLC alternative).
image_uri = "452145973879.dkr.ecr.us-west-2.amazonaws.com/llava-video"
# image_uri = f"763104351884.dkr.ecr.{sess.boto_region_name}.amazonaws.com/pytorch-training:2.4.0-gpu-py311-cu124-ubuntu22.04-sagemaker"

# Other candidates: 'ml.g6e.12xlarge', 'ml.p4d.24xlarge'.
instance_type = "ml.g6e.48xlarge"
use_spot_instances = False
max_run = 36000  # seconds, max 432,000 seconds (5 days)
# max_wait is only used for spot training and must be >= max_run.
max_wait = 40000 if use_spot_instances else None  # seconds, max 3,600,000 seconds (1,000 hours)
keep_alive_period_in_seconds = None

# Job artifacts live under s3://BUCKET/<job_id>/. Built with f-strings because
# os.path.join uses the host OS separator (a backslash on Windows), which is not
# valid in S3 URIs. The launch script writes checkpoints to /opt/ml/checkpoints/,
# which SageMaker syncs to checkpoint_s3_uri.
output_uri = f"s3://{BUCKET}/{job_id}/output"
checkpoint_uri = f"s3://{BUCKET}/{job_id}/checkpoints"

huggingface_estimator = HuggingFace(
    entry_point="start.py",
    source_dir="./LLaVA",
    instance_type=instance_type,
    instance_count=1,
    py_version="py310",
    image_uri=image_uri,
    role=ROLE,
    metric_definitions=metric_definitions,
    environment=environment,
    use_spot_instances=use_spot_instances,
    max_run=max_run,
    max_wait=max_wait,
    output_path=output_uri,
    checkpoint_s3_uri=checkpoint_uri,
    keep_alive_period_in_seconds=keep_alive_period_in_seconds,
)

# The S3 prefix is mounted as the "training" channel (/opt/ml/input/data/training).
huggingface_estimator.fit({"training": training_input}, job_name=job_id)
