# LLaVA Training Scripts for SageMaker

This notebook creates a SageMaker training script adapted from `LLaVA/scripts/v1_5/finetune_task.sh`.
Following the LLaVA recipe, keep the global batch size at 128: `per_device_train_batch_size` × `gradient_accumulation_steps` × number of devices = 128 (here, 8 × 2 × 8).
This configuration was tested on ml.p4d.24xlarge (8 × A100 40 GB).

In [None]:
%%writefile LLaVA/finetune-full-piechart-QA.sh
#!/bin/bash
# SageMaker entry point: full fine-tuning of llava-v1.5-7b on the piechart-QA data
# with DeepSpeed ZeRO-3.
#
# Expects the `job_id` environment variable (passed through the estimator's
# `environment` argument); it names the checkpoint sub-directory.
#
# NOTE: the shebang must be the very first line of the written file, so no blank
# line may sit between the %%writefile magic and it.

# Keep Weights & Biases logging local; --report_to wandb below then needs no API key.
export WANDB_MODE=offline

# Install the LLaVA package shipped in source_dir without touching the
# dependencies already present in the training image.
cd /opt/ml/code
pip install -e . --no-deps

# Global batch size = per_device_train_batch_size (8) * gradient_accumulation_steps (2)
# * number of devices = 128 on an 8-GPU instance.
# --output_dir uses ${job_id} (variable expansion). The former $(job_id) was command
# substitution: it tried to execute a command called `job_id` and expanded to nothing.
deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path liuhaotian/llava-v1.5-7b \
    --version v1 \
    --data_path /opt/ml/input/data/piechart/piechart-QA.jsonl \
    --image_folder /opt/ml/input/data//piechart/piechart-QA \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir "/opt/ml/checkpoints/${job_id}" \
    --num_train_epochs 1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb




In [None]:
# Initialize sagemaker session and get the training data s3 uri
import json
import time
import boto3
import numpy as np
import sagemaker
import sagemaker.huggingface
import os

# Execution role and session used by the cells below.
ROLE = sagemaker.get_execution_role()
sess = sagemaker.Session()

# Bucket and key prefix under which the training data is stored.
BUCKET = "YOUR_S3_BUCKET"
PREFIX = "data"

# S3 URIs always use "/" as the separator. os.path.join would insert the
# platform path separator (a backslash on Windows), so build the URI explicitly.
s3uri = f"s3://{BUCKET}/{PREFIX}"

print(f"sagemaker role arn: {ROLE}")
print(f"sagemaker bucket: {BUCKET}")
print(f"sagemaker session region: {sess.boto_region_name}")
print(f"data uri: {s3uri}")

In [None]:
# Derive a unique training job name from the current UTC timestamp
from time import gmtime, strftime

job_id = f"llava-v15-7b-task-full-{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
print(job_id)

In [None]:
# Environment variables handed to the estimator; the training script
# reads `job_id` from here.
environment = dict(job_id=job_id)

# Metric definitions: SageMaker applies each regex to the training script's
# printed logs and publishes the first capture group to CloudWatch.
#
# The value pattern accepts plain, decimal and scientific notation
# (e.g. 1, 0.5, 2e-05, 1.9e-05). The outer group is deliberately the first
# group so the whole number is captured rather than just its mantissa.
_METRIC_VALUE_PATTERN = r"([0-9]+(\.[0-9]+)?(e[-+]?[0-9]+)?)"

# Keys as they appear in the printed log dicts.
_METRIC_NAMES = (
    'loss',
    'learning_rate',
    'epoch',
    'train_runtime',
    'train_samples_per_second',
    'train_steps_per_second',
    'train_loss',
)

# Generate one definition per metric so every regex is guaranteed to target
# its own key (previously the train_* metrics all matched 'epoch').
metric_definitions = [
    {'Name': name, 'Regex': f"'{name}': {_METRIC_VALUE_PATTERN},?"}
    for name in _METRIC_NAMES
]

In [None]:
# Training data channel backed by the S3 prefix. FastFile "mounts" the S3
# objects directly instead of copying them to local disk first.
from sagemaker.inputs import TrainingInput

training_input = TrainingInput(
    s3_data=s3uri,
    s3_data_type='S3Prefix',  # one of: S3Prefix | ManifestFile | AugmentedManifestFile
    input_mode='FastFile',
    distribution='FullyReplicated',  # one of: FullyReplicated | ShardedByS3Key
)

In [None]:
# Configure and launch the SageMaker training job that runs
# finetune-full-piechart-QA.sh from ./LLaVA on a single instance.
from sagemaker.huggingface import HuggingFace

instance_type = 'ml.p4d.24xlarge'
use_spot_instances = False
max_run = 6000  # maximum training time, in seconds
# For managed spot training SageMaker requires max_wait >= max_run, so allow
# the full run time plus 1200 s of waiting for spot capacity. The previous
# value (1200) was below max_run and would be rejected once spot is enabled.
max_wait = (max_run + 1200) if use_spot_instances else None
keep_alive_period_in_seconds = None

# S3 URIs always use "/" as the separator, so build them explicitly rather
# than with os.path.join (which uses the platform path separator).
output_uri = f"s3://{BUCKET}/{job_id}/output"
checkpoint_uri = f"s3://{BUCKET}/{job_id}/checkpoints"

huggingface_estimator = HuggingFace(
    entry_point='finetune-full-piechart-QA.sh',
    source_dir='./LLaVA',
    instance_type=instance_type,
    instance_count=1,
    py_version='py310',
    image_uri='YOUR_TRAINING_IMAGE_URI',
    role=ROLE,
    metric_definitions=metric_definitions,
    environment=environment,
    use_spot_instances=use_spot_instances,
    max_run=max_run,
    max_wait=max_wait,
    output_path=output_uri,
    checkpoint_s3_uri=checkpoint_uri,
    keep_alive_period_in_seconds=keep_alive_period_in_seconds,
)

# The channel name 'piechart' matches the /opt/ml/input/data/piechart paths
# used by the training script.
huggingface_estimator.fit({'piechart': training_input}, job_name=job_id)
