Support for Megatron-VLM training #806

Open · wants to merge 10 commits into base: main
45 changes: 45 additions & 0 deletions examples/llava/README.md
@@ -0,0 +1,45 @@
# Minimal VLM training

This is minimal visual language model (VLM) training code written in pure Megatron style.

## Step 1

Download the Llama-2 weights from https://huggingface.co/meta-llama/Llama-2-70b-chat-hf.

The EvaViT model weights will be automatically downloaded by [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer).

You can install SAT with:

```
git clone https://github.com/THUDM/SwissArmyTransformer
cd SwissArmyTransformer
pip install . --no-deps
```

Then, convert the LLaMA and EvaViT models into Megatron (mcore) format:

```
python tools/checkpoint/convert.py --model-type GPT --loader llama2 --saver mcore --load-dir /path/to/Llama-2-70b-chat-hf --save-dir /path/to/save/llama2-70b-chat-megatron-mcore-tp8-pp2-first36 --tokenizer-model /path/to/Llama-2-70b-chat-hf/tokenizer.model --target-tensor-parallel-size 8 --target-pipeline-parallel-size 2 --model-size 70Bf --checkpoint-type hf --first-pipeline-num-layers 36
python tools/checkpoint/convert.py --model-type EVA --loader eva_sat --saver eva_mcore --load-dir eva-clip-4b-14-x-drop-last-layer --save-dir /path/to/save/eva2-clip-224-mcore-tp8-pp1 --target-tensor-parallel-size 8 --target-pipeline-parallel-size 1
```

## Step 2

Download the LLaVA dataset files [`metadata.json`](https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md) and `images.zip`.

Run training (replace `/path/to` with the actual paths on your system):

```
bash pretrain_llama2_70b_tp8_pp2.sh
```

If you want to use sequence parallelism and context parallelism, change `--image-seq-length` from 257 to 256, which omits the [CLS] token of the ViT.
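
For reference, 257 appears to come from the ViT patch grid: assuming a 224×224 input with 14×14 patches (suggested by the `eva2-clip-224` checkpoint name), the encoder emits 16 × 16 = 256 patch tokens plus one [CLS] token. A minimal sketch of that arithmetic:

```python
# Sketch of where --image-seq-length 257 comes from.
# The 224/14 patch geometry is an assumption based on the eva2-clip-224 checkpoint name.
image_size, patch_size = 224, 14
num_patches = (image_size // patch_size) ** 2  # 16 * 16 = 256
print(num_patches + 1)  # +1 for the ViT [CLS] token -> 257
```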

## Step 3

Run inference (replace `/path/to` with the actual paths on your system):

```
bash run_inference_70b_tp8_pp2.sh # server
python llava_inference_cli.py 127.0.0.1:5000 # client
```
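
The generation server can also be queried directly over HTTP. Below is a minimal client sketch based on the request format used by `llava_inference_cli.py` in this example; the prompt text and image path are placeholders:

```python
import json
import requests

# Minimal client sketch; mirrors the payload built by llava_inference_cli.py.
url = "http://127.0.0.1:5000/api"
headers = {"Content-Type": "application/json"}
data = {
    "prompts": ["[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nDescribe the image. [/INST]\n"],
    "image_path": "/path/to/example.jpg",  # placeholder path
    "tokens_to_generate": 64,
    "temperature": 0.8,
    "top_k": 1,
}
response = requests.put(url, data=json.dumps(data), headers=headers)
print(response.json()["text"][0])
```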
178 changes: 178 additions & 0 deletions examples/llava/dataset.py
@@ -0,0 +1,178 @@
import torch
from torch.utils.data import Dataset
from PIL import Image
import json
from copy import copy
import os
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import numpy as np
from functools import partial

class BlipImageBaseProcessor():
    def __init__(self, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

class BlipImageEvalProcessor(BlipImageBaseProcessor):
    def __init__(self, image_size=384, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)

def blip2_image_processor_func_megatron(image_seq_length, image_processor, image):
    return {
        'external_images': image_processor(image).unsqueeze(0),
        'external_input_ids': torch.zeros(1, image_seq_length, dtype=torch.long),
        'external_position_ids': torch.arange(image_seq_length, dtype=torch.long).unsqueeze(0),
    }

def _history_to_prompt(self, history, query, add_eoi_first=False):
    ret = []
    for i, (old_query, response) in enumerate(history):
        ret.append({"user": old_query, "assistant": response})
    ret.append({"user": query})
    return ret

def format_conversation(conversations, tokenizer, image_length, is_inference=False, is_text_only=False):
    # Note: `loss_mask` here means whether *the prediction* of the token should take loss
    tokens = (([0] * image_length) if not is_text_only else []) + [tokenizer.bos_token_id]  # For simplicity, just insert the image at the beginning for now.
    loss_masks = [0] * len(tokens)

    def _update(_tokens, value):
        value = int(value)
        tokens.extend(_tokens)
        loss_masks.extend([value] * len(_tokens))

    context_length = len(tokens)
    for idx, conv in enumerate(conversations):
        no_training_tokens = []
        # prompt
        if idx == 0:
            no_training_tokens.extend(tokenizer.encode("[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n", add_special_tokens=False))
        no_training_tokens.extend(tokenizer.encode("{} [/INST]\n".format(conv["user"]), add_special_tokens=False))
        _update(no_training_tokens, 0)
        # context_length
        if idx == len(conversations) - 1:
            context_length = len(tokens)
        # answer
        if not (is_inference and idx == len(conversations) - 1):
            # update answer
            ans_tokens = tokenizer.encode(conv["assistant"], add_special_tokens=False)
            _update(ans_tokens, 1)
            _update([tokenizer.eos_token_id], 1)
            suffix_tokens = tokenizer.encode("\n", add_special_tokens=False)
            _update(suffix_tokens, 0)
    assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}"
    return tokens, loss_masks, context_length

from transformers import AutoTokenizer

def llama2_tokenizer(tokenizer_path):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    return tokenizer

import re
import numpy as np

class llama2_text_processor:
    def __init__(self, tokenizer, max_target_length=1024, image_length=256, model=None):
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length
        self.image_length = image_length
        self.model = model

    def __call__(self, caption, prompt="", history=[], is_text_only=False):
        if is_text_only:
            cut = 5
        else:
            cut = self.image_length + 5
        if len(prompt) > self.max_target_length - cut:
            prompt = prompt[:self.max_target_length - cut]
        ret = self.history_to_prompt(history, prompt)
        ret[-1].update({"assistant": caption})
        input_ids, loss_masks, context_length = format_conversation(ret, self.tokenizer, self.image_length, is_text_only=is_text_only)

        if context_length >= self.max_target_length - 5:
            return None
        elif len(input_ids) > self.max_target_length:
            input_ids = input_ids[:self.max_target_length]
            loss_masks = loss_masks[:self.max_target_length]

        attention_mask = [1] * len(input_ids)
        pad_len = self.max_target_length - len(input_ids)
        input_ids = input_ids + [0] * pad_len
        loss_masks = loss_masks + [0] * pad_len
        labels = input_ids[1:] + [0]
        loss_masks = loss_masks[1:] + [0]  # shift together with labels so the mask aligns with the predicted tokens
        attention_mask = attention_mask + [1] * pad_len  # no need to pad for mask
        np_mask = np.tril(np.expand_dims(np.array(attention_mask), 0).repeat(len(attention_mask), 0))
        input_ids = torch.tensor(input_ids)
        loss_masks = torch.tensor(loss_masks)
        labels = torch.tensor(labels)
        attention_mask = torch.from_numpy(np_mask).unsqueeze(0)
        position_ids = torch.arange(input_ids.shape[-1])

        return {'tokens': input_ids, 'labels': labels, 'loss_mask': loss_masks.float(), 'attention_mask': attention_mask < 0.5, 'position_ids': position_ids}

    def history_to_prompt(self, history, query):
        return _history_to_prompt(self, history, query)

class ImageJsonDataset(Dataset):
    def __init__(self, json_path, single_process_fn=None):
        """
        Initializes the ImageJsonDataset.

        Args:
            json_path (str): Path to the JSON file containing a list of data items.
            single_process_fn (callable, optional): A callable function to process data items in __getitem__.
        """
        self.json_path = json_path
        self.image_path = '/'.join(json_path.split('/')[:-1])
        self.single_process_fn = single_process_fn

        # Load the JSON file and store the list of items
        self.data = []
        with open(json_path, 'r') as file:
            self.data = json.load(file)

    def __len__(self):
        """
        Returns the number of entries in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves an item from the dataset by index.

        Args:
            idx (int): Index of the data item in the dataset.

        Returns:
            Processed data item, which includes the image as a PIL object and other key-value pairs.
        """
        item = copy(self.data[idx])
        # Load the image
        image_path = item.get('image')
        try:
            image = Image.open(os.path.join(self.image_path, "images", image_path))
            item['image'] = image
        except IOError:
            print(f"Warning: Failed to open image at {image_path}. Returning None for 'image'.")
            item['image'] = None  # Indicate failure to load image

        # Apply the processing function if specified
        if self.single_process_fn:
            item = self.single_process_fn(item)

        return item
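
For context, here is a rough sketch of how the pieces in `dataset.py` could be wired together for one sample. The `conversations`/`value` field names follow the LLaVA metadata layout and are assumptions here, not something this file defines:

```python
# Usage sketch (not part of the diff), assuming LLaVA-style metadata entries.
from functools import partial

tokenizer = llama2_tokenizer("/path/to/Llama-2-70b-chat-hf")
text_processor = llama2_text_processor(tokenizer, max_target_length=1024, image_length=256)
image_processor = BlipImageEvalProcessor(image_size=224)
image_fn = partial(blip2_image_processor_func_megatron, 257, image_processor)

def process_item(item):
    # First human turn as the prompt, first assistant turn as the answer (assumed schema).
    conv = item["conversations"]
    prompt = conv[0]["value"].replace("<image>", "").strip()
    caption = conv[1]["value"]
    sample = text_processor(caption, prompt)
    if sample is not None and item.get("image") is not None:
        sample.update(image_fn(item["image"]))
    return sample

dataset = ImageJsonDataset("/path/to/llava_instruct_150k_datas/metadata.json", process_item)
```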
26 changes: 26 additions & 0 deletions examples/llava/llava_inference_cli.py
@@ -0,0 +1,26 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import sys
import json
import requests


if __name__ == "__main__":
url = sys.argv[1]
url = 'http://' + url + '/api'
headers = {'Content-Type': 'application/json'}

while True:
sentence = input("Enter prompt: ")
image_path = input("Enter image_path (leave blank if no image): ")
tokens_to_generate = int(eval(input("Enter number of tokens to generate: ")))

image_dict = {} if not image_path else {"image_path": image_path}

data = {**image_dict, "prompts": ["[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n{} [/INST]\n".format(sentence)], "tokens_to_generate": tokens_to_generate, "temperature": 0.8, "top_k": 1}
response = requests.put(url, data=json.dumps(data), headers=headers)

if response.status_code != 200:
print(f"Error {response.status_code}: {response.json()['message']}")
else:
print("Megatron Response: ")
print(response.json()['text'][0])
111 changes: 111 additions & 0 deletions examples/llava/pretrain_llama2_70b_tp8_pp2.sh
@@ -0,0 +1,111 @@
#!/bin/bash
if [ -z "$MLP_WORKER_NUM" ]; then
    GPUS_PER_NODE=8
    WORLD_SIZE=8
    NNODES=1
    NODE_RANK=0
else
    GPUS_PER_NODE=$MLP_GPU
    WORLD_SIZE=$(($MLP_WORKER_NUM * $MLP_GPU))
    NNODES=$MLP_WORKER_NUM
    NODE_RANK=$MLP_ROLE_INDEX
fi

# MASTER_ADDR defaults to localhost unless the platform provides MLP_WORKER_0_HOST
if [ -z "$MLP_WORKER_0_HOST" ]; then
    MASTER_ADDR=localhost
    MASTER_PORT=27878
else
    MASTER_ADDR=$MLP_WORKER_0_HOST
    MASTER_PORT=$MLP_WORKER_0_PORT
fi

NCCL_ALGO=RING
NCCL_IB_GID_INDEX=3
NCCL_IB_RETRY_CNT=7
NCCL_IB_TIME_OUT=32
NCCL_DEBUG=INFO
GLOO_SOCKET_IFNAME=eth1
NCCL_SOCKET_IFNAME=eth1
CUDA_DEVICE_MAX_CONNECTIONS=1
OPTIONS_NCCL="CUDA_DEVICE_MAX_CONNECTIONS=1 NCCL_ALGO=RING NCCL_IB_GID_INDEX=3 NCCL_IB_RETRY_CNT=7 NCCL_IB_TIME_OUT=32 NCCL_DEBUG=INFO GLOO_SOCKET_IFNAME=eth1 NCCL_SOCKET_IFNAME=eth1"

CHECKPOINT_PATH=/path/to/save/llama2-70b-chat-megatron-mcore-tp8-pp2-first36

DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"

GPT_ARGS="
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 2 \
--first-pipeline-num-layers 36 \
--num-layers 80 \
--hidden-size 8192 \
--num-attention-heads 64 \
--seq-length 1024 \
--image-seq-length 257 \
--max-position-embeddings 1024 \
--micro-batch-size 2 \
--global-batch-size 64 \
--swiglu \
--transformer-impl transformer_engine \
--normalization RMSNorm \
--untie-embeddings-and-output-weights \
--position-embedding-type rope \
--no-position-embedding \
--disable-bias-linear \
--group-query-attention \
--num-query-groups 8 \
--vocab-size 32000 \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model /path/to/Llama-2-70b-chat-hf/tokenizer.model \
--ffn-hidden-size 28672 \
--no-load-optim \
--no-load-rng \
--use-checkpoint-args \
--use-distributed-optimizer \
--no-save-optim \
--lr 0.00001 \
--train-iters 5000 \
--lr-decay-iters 320000 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--lr-warmup-fraction .01 \
--clip-grad 1.0 \
--bf16 \
--use-mcore-models
"

DATA_ARGS="
--train-data-path /path/to/llava_instruct_150k_datas/metadata.json \
"

OUTPUT_ARGS="
--log-interval 50 \
--save-interval 500 \
--eval-interval 500 \
--eval-iters 10
"

LOG_DIR="./output"
mkdir -p ${LOG_DIR}/${MLP_TASK_ID}_${JOB_NAME}
run_cmd="${OPTIONS_NCCL} torchrun $DISTRIBUTED_ARGS pretrain_llava.py \
$GPT_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save ./checkpoints/llama2_megatron_training_first36/ \
--load $CHECKPOINT_PATH \
--vit-load /path/to/save/eva2-clip-224-mcore-tp8-pp1"

echo ${run_cmd}
eval ${run_cmd} >${LOG_DIR}/${MLP_TASK_ID}_${JOB_NAME}/${MLP_ROLE_INDEX}.out 2>${LOG_DIR}/${MLP_TASK_ID}_${JOB_NAME}/${MLP_ROLE_INDEX}.err

set +x
echo "DONE with job $MLP_TASK_ID, index $MLP_ROLE_INDEX on `hostname`"