Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ChatML template support for PackedSFT #9009

Closed
wants to merge 17 commits into from

Conversation

xingyaoww
Copy link

@xingyaoww xingyaoww commented Apr 23, 2024

What does this PR do ?

Support training of multi-turn interaction trajectories in ChatML template format.

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

The following example shows how to process a file with multi-turn chat trajectories in OpenAI Format to NeMo's packed SFT format:

Example script to download the data:

import os
import pathlib
from datasets import load_dataset

ds = load_dataset("xingyaoww/code-act")

if not os.path.exists("data/datasets"):
    pathlib.Path("data/datasets").mkdir(parents=True, exist_ok=True)
    print("Created data/datasets")

codeact_ds = ds["codeact"]
codeact_df = codeact_ds.to_pandas()
codeact_df.to_json("data/datasets/codeact.jsonl", orient="records", lines=True)
print(f"Saved {len(codeact_df)} examples to data/datasets/codeact.jsonl")

general_ds = ds["general"]
general_df = general_ds.to_pandas()
general_df.to_json("data/datasets/general.jsonl", orient="records", lines=True)
print(f"Saved {len(general_df)} examples to data/datasets/general.jsonl")
print("Done")

You will get two files data/datasets/codeact.jsonl,data/datasets/general.jsonl.

Then you run this example script to convert the data to codeact-mixture.nemo.jsonl:
python3 scripts/data/convert_openai_to_nemo_chat_format.py data/datasets/codeact.jsonl,data/datasets/general.jsonl --output_file data/datasets/codeact-mixture.nemo.jsonl

"""Convert OpenAI chat format to NeMo chat format.

Example usage:
python3 scripts/data/convert_openai_to_nemo_chat_format.py data/datasets/codeact.jsonl,data/datasets/general.jsonl --output_file data/datasets/codeact-mixture.nemo.jsonl

============
OpenAI chat format:
source = {
    'messages': [
        {'role': 'system', 'content': '{system message}'}, // this is optional
        {'role': 'user', 'content': '{turn 1 user message}'},
        {'role': 'assistant', 'content': '{turn 1 assistant message}'},
        {'role': 'user', 'content': '{turn 2 user message}'},
    ],
} 

Nemo chat format:
source = {
    'system': '{system message}',
    'conversations': [
        {'from': 'user', 'value': '{turn 1 user message}', 'label': None},
        {'from': 'assistant', 'value': '{turn 1 assistant message}', 'label': '{turn 1 assistant label}'},
        {'from': 'user', 'value': '{turn 2 user message}', 'label': None},
        {'from': 'assistant', 'value': '{turn 2 assistant message}', 'label': '{turn 2 assistant label}'},
    ],
    "mask": "user",
    "type": None,
}
"""

import json
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser(description='Convert OpenAI chat format to NeMo chat format')
parser.add_argument('input_file', type=str, help='Input file(s) in OpenAI chat format')
parser.add_argument('--output_file', type=str, default=None, help='Output file in NeMo chat format')
args = parser.parse_args()

def convert_openai_to_nemo_chat_format(openai_messages, mask_role='user'):
    nemo_instance = {}
    assert len(openai_messages) > 0, "OpenAI instance must have at least one conversation"
    if openai_messages[0]['role'] == 'system':
        nemo_instance['system'] = openai_messages[0]['content']
        openai_messages = openai_messages[1:]
    else:
        nemo_instance['system'] = None
    nemo_instance['conversations'] = []
    nemo_instance['mask'] = mask_role

    # https://github.com/xingyaoww/NeMo/blob/main/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py#L203-L206
    nemo_instance['type'] = None

    for message in openai_messages:
        if message['role'] == 'user':
            nemo_instance['conversations'].append({'from': 'user', 'value': message['content'], 'label': None})
        elif message['role'] == 'assistant':
            nemo_instance['conversations'].append({'from': 'assistant', 'value': message['content'], 'label': None})
        else:
            raise ValueError(f"Unknown role: {message['role']}")
    return nemo_instance

input_files = args.input_file.split(',')

n_total = 0
with open(args.output_file, 'w') as fout:
    for input_file in tqdm(input_files, desc='Converting files'):
        assert input_file.endswith('.jsonl'), "Input file must be in jsonl format"
        with open(input_file, 'r') as fin:
            for line in tqdm(fin, desc=f'Converting {input_file}'):
                openai_instance = json.loads(line)
                assert 'conversations' in openai_instance, "OpenAI instance must have 'conversations' key"
                nemo_instance = convert_openai_to_nemo_chat_format(openai_instance['conversations'])
                fout.write(json.dumps(nemo_instance) + '\n')
                n_total += 1
    print(f"Converted {input_file} to NeMo chat format and saved to {args.output_file}")

print(f"Converted {n_total} instances to NeMo chat format and saved to {args.output_file}")

Then you can script to convert such JSONL (data/datasets/codeact-mixture.nemo.jsonl) to NeMO format:

#!/bin/bash

FILEPATH=data/datasets/codeact-mixture.nemo.jsonl
NEMO_MODEL=data/models/nemo/mistral-7b-base.nemo

MAX_SEQ_LEN=16384
OUTPUT_DIR=data/datasets_packed/codeact_mixture_mistral_7b_16k
# final 4409 examples
mkdir -p $OUTPUT_DIR

# make all paths absolute
FILEPATH=$(realpath $FILEPATH)
NEMO_MODEL=$(realpath $NEMO_MODEL)
OUTPUT_DIR=$(realpath $OUTPUT_DIR)

export PYTHONPATH=$(pwd)/NeMo:$(pwd)/Megatron-LM:$PYTHONPATH

pushd NeMo/

python scripts/nlp_language_modeling/prepare_packed_ft_chat_dataset.py \
   model.data.train_ds.file_names=[$FILEPATH] \
   model.data.train_ds.max_seq_length=$MAX_SEQ_LEN \
   model.restore_from_path=$NEMO_MODEL \
   +output_dir=$OUTPUT_DIR \
   +pack_sizes=[$MAX_SEQ_LEN] \
   +seed=42 \
   +model.data.chat=True \
   '+model.data.chat_prompt_tokens.system_turn_start="<|im_start|>"' \
   '+model.data.chat_prompt_tokens.system_role="system"' \
   '+model.data.chat_prompt_tokens.turn_start="<|im_start|>"' \
   '+model.data.chat_prompt_tokens.label_start="<|im_start|>"' \
   '+model.data.chat_prompt_tokens.end_of_turn="<|im_end|>\n"' \
   '+model.data.chat_prompt_tokens.end_of_name="\n"' \
   '+model.data.chat_prompt_tokens.add_special_tokens=["<|im_start|>", "<|im_end|>"]'

Config File for Fine-tuning

# Modified from NeMo/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
name: megatron_mistral_sft_codeact

trainer:
  devices: 4
  accelerator: gpu
  num_nodes: 2
  precision: bf16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  # max_epochs: 5  # will be override by max_steps
  max_steps: ?? # 645 steps per epoch, 5 epochs, 3225 examples, 32 bsz, 3225/32=100.78125
  # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 1 # frequency with which training steps are logged 
  # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
  # val_check_interval: 10
  limit_val_batches: 0
  gradient_clip_val: 1.0

exp_manager:
  explicit_log_dir: null
  exp_dir: ???
  name: ${name}
  create_wandb_logger: True
  wandb_logger_kwargs:
    project: nemo-sft
    name: nemo-mistral-sft-codeact
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    save_nemo_on_train_end: True 
    every_n_train_steps: ??
    every_n_epochs: null
    filename: 'ckpt-step_{step}-{consumed_samples}'
    model_parallel_size: ${model.tensor_model_parallel_size}

model:
  seed: 42
  tensor_model_parallel_size: 8 # intra-layer model parallelism
  pipeline_model_parallel_size: 1 # inter-layer model parallelism
  global_batch_size: ??
  micro_batch_size: 1
  restore_from_path: ??? # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. 
  sync_batch_comm: False

  ## Copied from Baichuan2
  mcore_gpt: True
  gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
  # Megatron O2-style half-precision
  megatron_amp_O2: True # Enable O2-level automatic mixed precision using main parameters # False!!!
  # grad_allreduce_chunk_size_mb: 125
  overlap_p2p_comm: True # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
  batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1

  # Fusion
  grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism..
  gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
  bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
  bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition.
  masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
  get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.

  ## Sequence Parallelism
  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
  sequence_parallel: True

  ## Activation Checkpoint
  activations_checkpoint_granularity: full # null # 'selective' or 'full'
  activations_checkpoint_method: uniform # 'uniform', 'block', not used with 'selective'
  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
  # of each chunk at the specified granularity
  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
  activations_checkpoint_num_layers: 1 # not used with 'selective'
  activations_checkpoint_layers_per_pipeline: null
  answer_only_loss: True

  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0

  # FSDP
  fsdp: False # Enable training with torch FSDP.
  fsdp_sharding_strategy: 'full' # Method to shard model states. Available options are 'full', 'hybrid', and 'grad'.
  fsdp_grad_reduce_dtype: 'fp32' # Gradient reduction data type.
  fsdp_sharded_checkpoint: False # Store and load FSDP shared checkpoint.
  fsdp_use_orig_params: False # Set to True to use FSDP for specific peft scheme.

  peft:
    peft_scheme: null # "adapter"  # can be either adapter,ia3, or ptuning
    restore_from_path: null

  data:
    train_ds:
      packed_sequence: True
      # Example of how to specify paths to multiple datasets
      # file_names: 
      #   - /path/to/squad.jsonl
      #   - /path/to/mnli.jsonl
      #   - /path/to/boolq.jsonl
      # Example of how each dataset is formatted
      # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
      file_names: ??? # Path to a list of JSONL files corresponding to the source data.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: True
      num_workers: 8
      memmap_workers: null
      pin_memory: True
      max_seq_length: ??
      min_seq_length: 1
      pad_to_max_length: True
      drop_last: True
      # Example of how to specify concat_sampling_probabilities
      # concat_sampling_probabilities:
      #   - 0.5
      #   - 0.25
      #   - 0.25
      concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random'

    validation_ds:
      file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
      names: null # Names of the corresponding datasets used to log metrics.

      metric:
        name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss', 'rouge', 'token_f1']
        average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
        num_classes: null

  optim:
      name: distributed_fused_adam
      bucket_cap_mb: 60
      overlap_grad_sync: True
      overlap_param_sync: True
      contiguous_grad_buffer: True
      grad_sync_dtype: bf16
      
      lr: 1e-5
      weight_decay: 0.01 
      betas: 
      - 0.9
      - 0.98
      sched:
        name: CosineAnnealing
        # scheduler config override
        warmup_ratio: 0.1   # Warmup steps will be 10% of the training steps.
        min_lr: 1e-6

Example script for training:

#!/bin/bash

export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

NEMO_MODEL=data/models/nemo/mistral-7b-base.nemo

MAX_SEQ_LEN=16384
PACKED_DATA_FILEPATH=data/datasets_packed/codeact_mixture_mistral_7b_16k/packed_16384_seed42.npy
NUM_EXAMPLES=4409

EXP_ID=Mistral-7B-16k-sft-codeact
CONFIG_PATH=scripts/train/configs/$EXP_ID.yaml
OUTPUT_DIR=data/ckpts/$EXP_ID
mkdir -p $OUTPUT_DIR

GLOBAL_BATCH_SIZE=32
N_EPOCHS=5
echo "GLOBAL_BATCH_SIZE: $GLOBAL_BATCH_SIZE"
STEPS_PER_EPOCH=$((NUM_EXAMPLES / GLOBAL_BATCH_SIZE))
echo "STEPS_PER_EPOCH: $STEPS_PER_EPOCH"
MAX_STEPS=$((NUM_EXAMPLES * N_EPOCHS / GLOBAL_BATCH_SIZE))
echo "MAX_STEPS: $MAX_STEPS"

# make all paths absolute
OUTPUT_DIR=$(realpath $OUTPUT_DIR)
NEMO_MODEL=$(realpath $NEMO_MODEL)
PACKED_DATA_FILEPATH=$(realpath $PACKED_DATA_FILEPATH)

CONFIG_PATH=$(realpath $CONFIG_PATH)
CONFIG_DIR=$(dirname $CONFIG_PATH)
CONFIG_NAME=$(basename $CONFIG_PATH .yaml)
echo "NEMO_MODEL: $NEMO_MODEL"
echo "OUTPUT_DIR: $OUTPUT_DIR"
echo "CONFIG_NAME: $CONFIG_NAME"
echo "CONFIG_DIR: $CONFIG_DIR"
export PYTHONPATH=$(pwd)/NeMo:$(pwd)/Megatron-LM:$PYTHONPATH

# export WANDB_DISABLED=True
export CUDA_DEVICE_MAX_CONNECTIONS=1

pushd NeMo/
# https://docs.nvidia.com/nemo-framework/user-guide/latest/modelalignment/sft.html#step-2-sft-training
python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
   --config-path $CONFIG_DIR \
   --config-name $CONFIG_NAME \
   trainer.max_steps=$MAX_STEPS \
   exp_manager.exp_dir=$OUTPUT_DIR \
   exp_manager.checkpoint_callback_params.every_n_train_steps=$STEPS_PER_EPOCH \
   model.restore_from_path=$NEMO_MODEL \
   model.global_batch_size=$GLOBAL_BATCH_SIZE \
   model.data.train_ds.max_seq_length=$MAX_SEQ_LEN \
   model.data.train_ds.file_names=[$PACKED_DATA_FILEPATH] \
   'model.data.train_ds.concat_sampling_probabilities=[1.0]' \

Jenkins CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

There's no need to comment jenkins on the PR to trigger Jenkins CI.
The GitHub Actions CI will run automatically when the PR is opened.
To run CI on an untrusted fork, a NeMo user with write access must click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

Copy link
Contributor

github-actions bot commented May 8, 2024

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

@github-actions github-actions bot added the stale label May 8, 2024
Copy link
Contributor

This PR was closed because it has been inactive for 7 days since being marked as stale.

@github-actions github-actions bot closed this May 15, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant