diff --git a/.gitignore b/.gitignore index 8702dda..0ce48f3 100644 --- a/.gitignore +++ b/.gitignore @@ -134,6 +134,7 @@ dmypy.json build .vscode/ +*.iml .attach_pid* src/neuronx_distributed.egg-info/ *.whl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..78d8d1f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +default_language_version: + # force all unspecified python hooks to run python3 + python: python3 +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: detect-aws-credentials +- repo: https://github.com/pocc/pre-commit-hooks + rev: v1.1.1 + hooks: + - id: clang-format + args: [--style=file, -i] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.0 + hooks: + - id: ruff + name: ruff + entry: ruff + args: [check, --fix, "--line-length=120", "--ignore=F401,E203"] + types: [python] + language: system + exclude: cases_update diff --git a/README.md b/README.md index 24cb860..1b8cd73 100644 --- a/README.md +++ b/README.md @@ -9,10 +9,10 @@ To install the library, please follow the instructions mentioned here: https://a To build from source, run the following command: ``` -python3 setup.py bdist_wheel +bash ./build.sh ``` - -It should place the wheel at `dist/` + +It should place the wheel at `build/` ## API Reference Guide diff --git a/build-tools/bin/custom-build b/build-tools/bin/custom-build index 09e9de0..a8b39bb 100755 --- a/build-tools/bin/custom-build +++ b/build-tools/bin/custom-build @@ -8,43 +8,12 @@ LICENSE_TXT_PATH=${BUILD_PATH}/private/LICENSE.txt BUILD_PATH_NEURONX_DISTRIBUTED=${BUILD_PATH}/public/NeuronxDistributed mkdir -p ${BUILD_PATH_NEURONX_DISTRIBUTED} -# check against flake8 linter -# Options used: -# --max-line-length=120 is used since a lot of docstrings -# contain lines longer than 120 that wouldn't make sense -# to split (ex. code snippets) -# -# Warnings that are ignored -# F401: unused import -# - Reason to ignore: Side effects might occur on import. -# Also, neuronx-cc check would trip this. -# W503/504: newline before/after binary operator. -# - Reason to Ignore: conditionals are often split into -# multiple lines for readability). -# -# More info in the following links: -# 1) https://flake8.pycqa.org/en/latest/user/error-codes.html -# 2) https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes - -FLAKE8_MSG=$(flake8 --max-line-length=120 --ignore=F401,W503,W504,E203 ${SRC_PATH}/src/neuronx_distributed || true) - -python3.8 -m pip install flake8==3.7 -if [[ ! -z $FLAKE8_MSG ]] -then - echo "FLAKE8 LINTING HAS DETECTED FORMATTING AND POTENTIALLY SOME SYNTAX ERRORS, PLEASE CHECK ABOVE OUTPUT!" - exit 1 -fi - -if [[ "$1" == "flake8" ]] -then - exit 0 -fi - -# # Copy Python source files +# Copy Python source files cp setup.py ${BUILD_PATH_NEURONX_DISTRIBUTED}/ cp -r src ${BUILD_PATH_NEURONX_DISTRIBUTED}/ cp $LICENSE_TXT_PATH ${BUILD_PATH_NEURONX_DISTRIBUTED}/ -## Build wheel -DIST_DIR=${BUILD_PATH}/pip/public/neuronx-distributed -python3.8 setup.py bdist_wheel --dist-dir ${DIST_DIR} + +export DIST_DIR=${BUILD_PATH}/pip/public/neuronx-distributed + +bash build.sh diff --git a/build.sh b/build.sh new file mode 100644 index 0000000..2ddcef2 --- /dev/null +++ b/build.sh @@ -0,0 +1,30 @@ +#! 
/bin/bash +set -e + +: ${DIST_DIR:=build} + +python3.8 -m pip install ruff +# removing cache fails in ToD +python3.8 -m ruff check --no-cache --line-length=120 --ignore=F401,E203 +# exit when asked to run `ruff` only +if [[ "$1" == "ruff" ]] +then + exit 0 +fi + +# Run static code analysis +python3.8 -m pip install mypy +# Install type bindings +python3.8 -m pip install types-requests boto3-stubs[s3] +# removing cache fails in ToD +python3.8 -m mypy --no-incremental || true +# exit when asked to run `mypy` only +if [[ "$1" == "mypy" ]] +then + exit 0 +fi + + + +# Build wheel +python3.8 setup.py bdist_wheel --dist-dir ${DIST_DIR} diff --git a/examples/inference/dbrx/dbrx_runner.py b/examples/inference/dbrx/dbrx_runner.py new file mode 100644 index 0000000..2622a04 --- /dev/null +++ b/examples/inference/dbrx/dbrx_runner.py @@ -0,0 +1,79 @@ +import torch +from dbrx.neuron_modeling_dbrx import ( + NeuronDbrxConfig, + NeuronDbrxForCausalLM, + NeuronDbrxModel, +) +from runner import InferenceRunner +from transformers import AutoTokenizer + +from neuronx_distributed.parallel_layers.checkpointing import _invoke_preshard_hook + + +class DbrxRunner(InferenceRunner): + def load_hf_model(self): + config = NeuronDbrxConfig.from_pretrained(self.model_path) + return NeuronDbrxForCausalLM.load_hf_model(self.model_path, config) + + def load_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): + # On CPU we can only run tensor parallelism with degree 1 + config = self.get_config_for_nxd( + batch_size, + 1, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, + enable_bucketing=False, + **kwargs) + config.torch_dtype = torch.float32 + + self.init_ditributed_env() + neuron_model = NeuronDbrxModel(config) + + state_dict = NeuronDbrxForCausalLM.get_state_dict(self.model_path, config) + + _invoke_preshard_hook(neuron_model, state_dict) + + neuron_model.load_state_dict(state_dict, strict=False) + + if config.torch_dtype == torch.bfloat16: + neuron_model.bfloat16() + + model = NeuronDbrxForCausalLM(None, config) + model.context_encoding_model.model = neuron_model + model.token_generation_model.model = neuron_model + return model + + def load_neuron_model(self, traced_model_path): + config = NeuronDbrxConfig.from_pretrained(traced_model_path) + model = NeuronDbrxForCausalLM.from_pretrained("", config) + + model.load(traced_model_path) + if config.torch_dtype == torch.bfloat16: + model.bfloat16() + + return model + + def load_tokenizer(self, padding_side=None): + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) + tokenizer.pad_token = tokenizer.unk_token + tokenizer.padding_side = padding_side if padding_side else self.get_padding_side() + return tokenizer + + def get_config_cls(self): + return NeuronDbrxConfig + + def get_model_cls(self): + return NeuronDbrxForCausalLM + + def get_padding_side(self): + return "right" + + def get_default_hf_generation_config_kwargs(self): + config = super().get_default_hf_generation_config_kwargs() + config['pad_token_id'] = 0 + + return config + + +if __name__ == "__main__": + DbrxRunner.cmd_execute() diff --git a/examples/inference/dbrx/neuron_modeling_dbrx.py b/examples/inference/dbrx/neuron_modeling_dbrx.py new file mode 100644 index 0000000..61b33ae --- /dev/null +++ b/examples/inference/dbrx/neuron_modeling_dbrx.py @@ -0,0 +1,337 @@ +# coding=utf-8 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Dbrx model for NXD inference."""
+import logging
+import warnings
+import gc
+from typing import Optional, Tuple, Union
+
+import torch
+from modules.gqa import (
+    GQA,
+    BaseGroupQueryAttention,
+)
+from modules.model_base import NeuronBaseModel, NeuronBaseForCausalLM
+from torch import nn
+
+try:
+    from neuronxcc.nki._private_kernels.attention import attention_isa_kernel
+except ImportError:
+    from neuronxcc.nki.kernels.attention import attention_isa_kernel
+from torch_neuronx.xla_impl.ops import nki_jit
+from transformers import DbrxForCausalLM, DbrxPreTrainedModel
+from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput
+from modules.attention.attention_base import NeuronAttentionBase
+from modules.attention.utils import RotaryEmbedding
+from modules.config import NeuronInferenceConfig
+from transformers.models.dbrx.configuration_dbrx import DbrxConfig
+
+
+from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPs
+from neuronx_distributed.modules.moe.model import MoE
+from neuronx_distributed.modules.moe.routing import RouterTopK
+from neuronx_distributed.parallel_layers import parallel_state, utils
+from neuronx_distributed.parallel_layers.layers import (
+    ColumnParallelLinear,
+    ParallelEmbedding,
+)
+from neuronx_distributed.utils.sampling import Sampler
+
+_flash_fwd_call = nki_jit()(attention_isa_kernel)
+
+SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
+
+GQA_SHARDING_STRATEGY = GQA.REPLICATE_TO_TP_DEGREE
+
+logger = logging.getLogger(__name__)
+
+
+def convert_dbrx_to_neuron_state_dict(dbrx_state_dict, cfg):
+    """
+    Helper function which returns the model weights from the dbrx model in a state dictionary compatible with the structure of the neuron MoE model.
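+
+    Mapping performed (see the code below):
+      * wte.weight, norm_f.weight and lm_head.weight are copied directly.
+      * Per layer, blocks.{l}.ffn.router.layer.weight is renamed to layers.{l}.ffn.router.linear_router.weight.
+      * Per layer, the expert weights w1 (gate) and v1 (up) are reshaped and concatenated into
+        layers.{l}.ffn.expert_mlps.mlp_op.gate_up_proj.weight, and w2 becomes down_proj.weight.
+      * Per layer, the attention weights (Wqkv, out_proj) and the two layer norms are renamed to
+        their Neuron module paths.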
+ """ + + assert cfg.glu_mlp is True, "Only GLU MLP is supported for Dbrx Top-K model" + neuron_state_dict = {} + neuron_state_dict["embed_tokens.weight"] = dbrx_state_dict["wte.weight"].clone().detach() + neuron_state_dict["norm.weight"] = dbrx_state_dict["norm_f.weight"].clone().detach() + neuron_state_dict["lm_head.weight"] = dbrx_state_dict["lm_head.weight"].clone().detach() + + for l in range(cfg.n_layers): # noqa: E741 + # Copy router weights + neuron_state_dict[f"layers.{l}.ffn.router.linear_router.weight"] = ( + dbrx_state_dict[f"blocks.{l}.ffn.router.layer.weight"].clone().detach() + ) + + num_experts = cfg.ffn_config.moe_num_experts + intermediate_size, hidden_size = cfg.ffn_config.ffn_hidden_size, cfg.d_model + + # Copy gate_proj and up_proj after concatenation + # [num_experts, hidden_size, 2 * intermediate_size] + gate_proj_weights = dbrx_state_dict[f"blocks.{l}.ffn.experts.mlp.w1"].view(num_experts, intermediate_size, hidden_size) + up_proj_weights = dbrx_state_dict[f"blocks.{l}.ffn.experts.mlp.v1"].view(num_experts, intermediate_size, hidden_size) + gate_up_proj = torch.cat([gate_proj_weights, up_proj_weights], dim=1).transpose(1, 2) + neuron_state_dict[f"layers.{l}.ffn.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + + # Copy down_proj + # [num_experts, intermediate_size, hidden_size] + down_proj = dbrx_state_dict[f"blocks.{l}.ffn.experts.mlp.w2"].view(num_experts, intermediate_size, hidden_size) + neuron_state_dict[f"layers.{l}.ffn.expert_mlps.mlp_op.down_proj.weight"] = down_proj + + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = ( + dbrx_state_dict[f"blocks.{l}.norm_attn_norm.attn.Wqkv.weight"].clone().detach() + ) + neuron_state_dict[f"layers.{l}.self_attn.o_proj.weight"] = ( + dbrx_state_dict[f"blocks.{l}.norm_attn_norm.attn.out_proj.weight"].clone().detach() + ) + neuron_state_dict[f"layers.{l}.input_layernorm.weight"] = ( + dbrx_state_dict[f"blocks.{l}.norm_attn_norm.norm_1.weight"].clone().detach() + ) + neuron_state_dict[f"layers.{l}.post_attention_layernorm.weight"] = ( + dbrx_state_dict[f"blocks.{l}.norm_attn_norm.norm_2.weight"].clone().detach() + ) + + dbrx_state_dict.clear() + gc.collect() + + return neuron_state_dict + + +def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: str) -> bool: + if isinstance(module, (BaseGroupQueryAttention,)): + return module.preshard_hook(model_state_dict, prefix) + return False + + +class NeuronDbrxConfig(NeuronInferenceConfig, DbrxConfig): + def __init__( + self, + batch_size: int = 1, + tp_degree: int = 1, + max_context_length: int = 128, + max_new_tokens: int = 128, + capacity_factor: float = None, + glu_mlp: bool = True, + padding_side: str = "right", + speculation_length: int = 0, + **kwargs, + ): + self.max_new_tokens = max_new_tokens + self.max_context_length = max_context_length + self.max_length = max_new_tokens + max_context_length + self.fused_qkv = True + + # capacity_factor = None corresponds to full capacity (no token dropping) + self.capacity_factor = float(capacity_factor) if capacity_factor is not None else None + self.glu_mlp = glu_mlp + + super().__init__( + tp_degree=tp_degree, + batch_size=batch_size, + seq_len=max_context_length+max_new_tokens, + padding_side=padding_side, + max_context_length=max_context_length, + speculation_length=speculation_length, + **kwargs, + ) + + +class NeuronDbrxAttention(NeuronAttentionBase): + + def __init__(self, config: DbrxConfig): + super().__init__() + self.config = config + self.hidden_size = config.d_model + 
self.num_attention_heads = config.n_heads + self.head_dim = self.hidden_size // self.num_attention_heads + self.max_position_embeddings = config.max_seq_len + self.torch_dtype = config.torch_dtype + self.padding_side = config.padding_side + self.num_key_value_heads = config.attn_config.kv_n_heads + self.rope_theta = config.attn_config.rope_theta + + if not parallel_state.model_parallel_is_initialized(): + raise ValueError( + "NeuronDbrxAttention has to be initialized in a distributed env. Please use neuronx_distributed" + " module to initialize a distributed env." + ) + self.tp_degree = parallel_state.get_tensor_model_parallel_size() + self.fused_qkv = config.fused_qkv + self.clip_qkv = config.attn_config.clip_qkv + + self.init_gqa_properties() + + self.rotary_emb = RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + +class NeuronDbrxBlock(nn.Module): + """ + Just replace the attention with the NXD version, and MLP with the NXD version + """ + + def __init__(self, config: NeuronDbrxConfig, block_idx: int): + super().__init__() + self.hidden_size = config.d_model + self.resid_pdrop = config.resid_pdrop + self.block_idx = block_idx + self.self_attn = NeuronDbrxAttention(config=config) + + ffn_config = config.ffn_config + router = RouterTopK( + num_experts=ffn_config.moe_num_experts, + top_k=ffn_config.moe_top_k, + hidden_size=config.d_model, + ) + expert_mlps = ExpertMLPs( + num_experts=ffn_config.moe_num_experts, + top_k=ffn_config.moe_top_k, + hidden_size=config.d_model, + intermediate_size=ffn_config.ffn_hidden_size, + hidden_act=ffn_config.ffn_act_fn['name'], + capacity_factor=config.capacity_factor, + glu_mlp=config.glu_mlp, + normalize_top_k_affinities=True, + ) + self.ffn = MoE( + router=router, + expert_mlps=expert_mlps, + ) + self.ffn.eval() # Set MoE module in eval mode + + self.input_layernorm = nn.LayerNorm(config.d_model, bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.d_model, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.FloatTensor`, *optional*): + position ids of size `(batch_size, sequence_length)`. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states).to(dtype=hidden_states.dtype) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states).to(dtype=hidden_states.dtype) + + # FFN + hidden_states = self.ffn(hidden_states)[0] + hidden_states = residual + hidden_states + + outputs = (hidden_states, present_key_value) + + return outputs + + +class NeuronDbrxModel(NeuronBaseModel, DbrxPreTrainedModel): + """Transformer decoder consisting of *config.num_hidden_layers*. Each layer is a [`DbrxBlock`] layer. + + Args: + config ([`DbrxConfig`]): Model configuration class with all parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + + _model_cls = DbrxPreTrainedModel + + def setup_attr_for_model(self, config: NeuronDbrxConfig): + self.emb_pdrop = config.emb_pdrop + + # Needed for init_inference_optimization() + self.on_device_sampling = config.on_device_sampling + self.tp_degree = config.tp_degree + self.hidden_size = config.d_model + self.num_attention_heads = config.n_heads + self.num_key_value_heads = config.attn_config.kv_n_heads + self.max_batch_size = config.max_batch_size + self.buckets = config.buckets + + def init_model(self, config: NeuronDbrxConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.d_model, + self.padding_idx, + dtype=config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList([NeuronDbrxBlock(config, block_idx) for block_idx in range(config.n_layers)]) + self.norm = nn.LayerNorm(config.d_model, bias=False) + self.lm_head = ColumnParallelLinear(config.d_model, config.vocab_size, bias=False) + + + +class NeuronDbrxForCausalLM(NeuronBaseForCausalLM, DbrxPreTrainedModel): + """ + This class can be used as DbrxForCausalLM + """ + _STATE_DICT_MODEL_PREFIX = "transformer." 
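+    # Keys in the HF Dbrx checkpoint are prefixed with "transformer." (e.g. "transformer.wte.weight").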
+ + _model_cls = NeuronDbrxModel + + def __init__(self, model_path: str, config: NeuronDbrxConfig): + super().__init__(model_path, config) + self.sampler = Sampler(self.config) + + @staticmethod + def load_hf_model(model_path, config): + return DbrxForCausalLM.from_pretrained(model_path, torch_dtype=config.torch_dtype) + + @classmethod + def get_state_dict(cls, model_path: str, config: DbrxConfig) -> dict: + model_sd = super().get_state_dict(model_path, config) + model_sd = convert_dbrx_to_neuron_state_dict(model_sd, config) + return model_sd + + def get_compiler_args(self): + compiler_args = "--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1" + # Add flags for cc-overlap + compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2'" + # Prevent auto-downcasting when running with fp32 + if self.config.torch_dtype == torch.float32: + compiler_args += " --auto-cast=none" + # TODO: Remove this flag after compiler fix is merged (NCC-2677) + compiler_args += " --internal-hlo2tensorizer-options=--expand-batch-norm-training" + return compiler_args diff --git a/examples/inference/llama2/llama2_runner.py b/examples/inference/llama2/llama2_runner.py index 9aa2be3..b98988a 100644 --- a/examples/inference/llama2/llama2_runner.py +++ b/examples/inference/llama2/llama2_runner.py @@ -1,18 +1,17 @@ -import os - import torch from llama2.neuron_modeling_llama import ( NeuronLlamaConfig, NeuronLlamaForCausalLM, NeuronLlamaModel, - preshard_hook_fn, ) from runner import InferenceRunner from transformers import AutoTokenizer from neuronx_distributed.parallel_layers.checkpointing import _invoke_preshard_hook +from neuronx_distributed.quantization.quantization_config import QuantizationType from neuronx_distributed.quantization.quantization_utils import ( - convert_float_model_to_pytorch_int8_model, + quantize_pytorch_model_per_channel_symmetric, + quantize_pytorch_model_per_tensor_symmetric, ) @@ -20,42 +19,69 @@ class LlamaRunner(InferenceRunner): def load_hf_model(self): return NeuronLlamaForCausalLM.load_hf_model(self.model_path) - def load_neuron_model_on_cpu(self, max_context_length, max_new_tokens, batch_size, **kwargs): - config = self.get_config_for_nxd(batch_size, 1, max_context_length, max_new_tokens, **kwargs) - config.torch_dtype = torch.float32 + def load_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): + self.config = self.get_config_for_nxd( + batch_size, + 1, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, + enable_bucketing=False, + **kwargs) + self.config.torch_dtype = torch.float32 - neuron_model = NeuronLlamaModel(config) + neuron_model = NeuronLlamaModel(self.config) - state_dict = NeuronLlamaForCausalLM.get_state_dict(self.model_path, config=config) + state_dict = NeuronLlamaForCausalLM.get_state_dict(self.model_path, config=self.config) _invoke_preshard_hook(neuron_model, state_dict) neuron_model.load_state_dict(state_dict, strict=False) - if config.torch_dtype == torch.bfloat16: + if self.config.torch_dtype == torch.bfloat16: neuron_model.bfloat16() - model = NeuronLlamaForCausalLM(None, config) + model = NeuronLlamaForCausalLM(None, self.config) model.context_encoding_model.model = neuron_model model.token_generation_model.model = neuron_model return model - def load_quantized_neuron_model_on_cpu(self, max_context_length, max_new_tokens, batch_size, **kwargs): - model = self.load_neuron_model_on_cpu(max_context_length, max_new_tokens, 
batch_size, **kwargs) - return convert_float_model_to_pytorch_int8_model(model, inplace=True) + def generate_quantized_hf_checkpoints_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): + config = self.get_config_for_nxd(batch_size, 1, max_prompt_length, sequence_length, **kwargs) + config.torch_dtype = torch.float32 + + quantized_state_dict = NeuronLlamaForCausalLM.generate_quantized_state_dict( + model_path=self.model_path, config=config + ) + return quantized_state_dict + + def load_quantized_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): + model = self.load_neuron_model_on_cpu(max_prompt_length, sequence_length, batch_size, **kwargs) + + quantization_type = QuantizationType(kwargs.get("quantization_type", "per_tensor_symmetric")) + if quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: + return quantize_pytorch_model_per_tensor_symmetric(model, inplace=True) + elif quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC: + return quantize_pytorch_model_per_channel_symmetric(model, inplace=True) + else: + raise RuntimeError(f"quantization_type: {quantization_type} not supported") def load_neuron_model(self, traced_model_path): config = NeuronLlamaConfig.from_pretrained(traced_model_path) model = NeuronLlamaForCausalLM.from_pretrained("", config) + self.config = config model.load(traced_model_path) if config.torch_dtype == torch.bfloat16: - os.environ["XLA_DOWNCAST_BF16"] = "1" + model.bfloat16() return model def load_tokenizer(self, padding_side=None): tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) - tokenizer.pad_token = tokenizer.unk_token + if not hasattr(self.config, 'pad_token_id') or self.config.pad_token_id is None: + # Use eos_token as pad_token which works for both llama2 and llama3 + tokenizer.pad_token = tokenizer.eos_token + else: + tokenizer.pad_token_id = self.config.pad_token_id tokenizer.padding_side = padding_side if padding_side else self.get_padding_side() return tokenizer @@ -68,6 +94,13 @@ def get_model_cls(self): def get_padding_side(self): return "right" + def get_default_hf_generation_config_kwargs(self): + config = super().get_default_hf_generation_config_kwargs() + # set to eos_token_id as that's done in load_tokenizer + config['pad_token_id'] = self.generation_config.eos_token_id + + return config + if __name__ == "__main__": LlamaRunner.cmd_execute() diff --git a/examples/inference/llama2/neuron_modeling_llama.py b/examples/inference/llama2/neuron_modeling_llama.py index f8e0999..ff24815 100644 --- a/examples/inference/llama2/neuron_modeling_llama.py +++ b/examples/inference/llama2/neuron_modeling_llama.py @@ -17,70 +17,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch LLaMA model for NXD inference.""" -import copy -import logging -import math -from typing import Any, Dict, List, Optional, Tuple, Type, Union +"""PyTorch LLaMA model for NXD inference.""" +from typing import Optional, Tuple, Type, Union import torch + +from modules.attention.attention_base import NeuronAttentionBase +from modules.attention.utils import RotaryEmbedding from modules.custom_calls import CustomRMSNorm from torch import nn from transformers import LlamaPreTrainedModel from transformers.activations import ACT2FN from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput -from transformers.generation.logits_process import LogitsProcessorList -from transformers.generation.stopping_criteria import ( - StoppingCriteriaList, - validate_stopping_criteria, -) -from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import ( - LlamaDynamicNTKScalingRotaryEmbedding, - LlamaLinearScalingRotaryEmbedding, LlamaRMSNorm, - LlamaRotaryEmbedding, -) - -SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] - -from dataclasses import dataclass - -from modules.autobucketing import slice_lhs, slice_rhs -from modules.gqa import ( - BaseGroupQueryAttention, - GroupQueryAttention_O, - GroupQueryAttention_QKV, - determine_sharding_strategy, - get_shardable_head_counts, ) -from modules.model_base import NeuronBaseForCausalLM -from modules.model_wrapper import ( - CONTEXT_ENCODING_MODEL_TAG, - SPECULATION_MODEL_TAG, - TOKEN_GENERATION_MODEL_TAG, - ModelWrapper, -) -from neuronxcc.nki.kernels.attention import attention_isa_kernel -from torch_neuronx.xla_impl.ops import nki_jit -from transformers import LlamaForCausalLM -from neuronx_distributed.parallel_layers import parallel_state, utils -from neuronx_distributed.parallel_layers.layers import ( - ColumnParallelLinear, - ParallelEmbedding, - RowParallelLinear, +from transformers.models.llama.modeling_llama import ( + LlamaRotaryEmbedding, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, ) -from neuronx_distributed.utils.sampling import Sampler +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] -_flash_fwd_call = nki_jit()(attention_isa_kernel) +from modules.autobucketing import slice_lhs, slice_rhs # noqa: E402 +from modules.gqa import ( # noqa: E402 + BaseGroupQueryAttention, # noqa: E402 + determine_sharding_strategy, # noqa: E402 + get_shardable_head_counts, # noqa: E402 +) # noqa: E402 +from modules.model_base import NeuronBaseModel, NeuronBaseForCausalLM # noqa: E402 +from modules.config import NeuronInferenceConfig # noqa: E402 + +from transformers import LlamaForCausalLM # noqa: E402 + +from neuronx_distributed.parallel_layers import parallel_state, utils # noqa: E402 +from neuronx_distributed.parallel_layers.layers import ( # noqa: E402 + ColumnParallelLinear, # noqa: E402 + ParallelEmbedding, # noqa: E402 + RowParallelLinear, # noqa: E402 +) # noqa: E402 +from neuronx_distributed.utils.sampling import Sampler # noqa: E402 _LLAMA_MODULE_MAP = {} def get_rmsnorm_cls(): - # Intialize to the approperiate implementation of RMSNorm + # Initialize to the appropriate implementation of RMSNorm # If infer on NXD -> CustomRMSNorm # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) return CustomRMSNorm if parallel_state.model_parallel_is_initialized() else LlamaRMSNorm @@ -117,41 +101,20 @@ 
def inner(cls: Type[nn.Module]): return inner -class NeuronLlamaConfig(LlamaConfig): +class NeuronLlamaConfig(NeuronInferenceConfig, LlamaConfig): def __init__( - self, max_batch_size=1, tp_degree=1, n_positions=128, padding_side="right", speculation_length=0, **kwargs + self, max_batch_size=1, tp_degree=1, n_positions=128, padding_side="right", speculation_length=0, **kwargs ): - self.max_batch_size = max_batch_size - self.tp_degree = tp_degree self.attn_cls = "NeuronLlamaAttention" - self.n_positions = n_positions - self.padding_side = padding_side - self.speculation_length = speculation_length - self.trace_tokengen_model = True - - self.ctx_batch_size = kwargs.pop("ctx_batch_size", max_batch_size) - self.tkg_batch_size = kwargs.pop("tkg_batch_size", max_batch_size) - - # decoder specific params - self.batch_size = max_batch_size - self.n_active_tokens = n_positions - - # bucketing specific params - self.enable_context_encoding_bucketing = False - self.enable_token_generation_bucketing = False - self.buckets = [n_positions] - self.bucket_n_active_tokens = self.enable_context_encoding_bucketing - - self.is_continuous_batching = kwargs.pop("is_continuous_batching", False) - self.on_device_sampling = kwargs.pop("on_device_sampling", False) - - # Quantization specific params - self.quantized = kwargs.get("quantized", False) - self.quantized_checkpoints_path = kwargs.get("quantized_checkpoints_path", None) - # TODO: Add validation for quantized_checkpoints_path after the design discussions - - super().__init__(**kwargs) + super().__init__( + tp_degree=tp_degree, + seq_len=n_positions, + padding_side=padding_side, + speculation_length=speculation_length, + max_batch_size=max_batch_size, + **kwargs, + ) class NeuronLlamaMLP(nn.Module): @@ -202,7 +165,7 @@ def forward(self, x): @register_module("NeuronLlamaAttention") -class NeuronLlamaAttention(nn.Module): +class NeuronLlamaAttention(NeuronAttentionBase): """ Compared with LlamaAttention, this class just 1. 
replaces the q_proj, k_proj, v_proj with column parallel layer @@ -214,48 +177,43 @@ class NeuronLlamaAttention(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() + self.config = config - self.tp_degree = config.tp_degree self.hidden_size = config.hidden_size - self.head_dim = self.hidden_size // config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_attention_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta - self.is_causal = True - - self.gqa_qkv = GroupQueryAttention_QKV( - hidden_size=self.hidden_size, - head_dim=config.hidden_size // config.num_attention_heads, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - tp_degree=self.tp_degree, - dtype=config.torch_dtype, - gather_output=False, - ) + self.padding_side = config.padding_side + self.torch_dtype = config.torch_dtype + self.is_medusa = config.is_medusa - self.o_proj = GroupQueryAttention_O( - hidden_size=self.hidden_size, - head_dim=config.hidden_size // config.num_attention_heads, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_key_value_heads, - tp_degree=self.tp_degree, - dtype=config.torch_dtype, - desired_sharding_strategy=self.gqa_qkv.get_sharding_strategy(), - input_is_parallel=True, - ) + if parallel_state.model_parallel_is_initialized(): + self.tp_degree = parallel_state.get_tensor_model_parallel_size() + else: + self.tp_degree = 1 + self.fused_qkv = False + self.clip_qkv = None - self.num_heads = utils.divide(self.gqa_qkv.get_num_attention_heads(), self.tp_degree) - self.num_key_value_heads = utils.divide(self.gqa_qkv.get_num_key_value_heads(), self.tp_degree) - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.init_gqa_properties() - self._init_rope() + self.init_rope() - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + def init_rope(self): + if not hasattr(self.config, "rope_scaling") or self.config.rope_scaling is None: + if self.is_medusa: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + self.rotary_emb = RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] @@ -276,162 +234,6 @@ def _init_rope(self): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - def _rotate_half(self, x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_rotary_pos_emb(self, q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors.""" - - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (self._rotate_half(q) * sin) - k_embed = (k * cos) + (self._rotate_half(k) * sin) - return q_embed, k_embed - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - active_mask: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - Q, K, V = self.gqa_qkv(hidden_states=hidden_states) - - # Divide hidden_dim across heads for MHA - # Change layout: BSHD -> BHSD - Q = Q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - K = K.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - V = V.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - # Rotate(Q) - # Rotate(K) - kv_seq_len = K.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(V, seq_len=kv_seq_len) - Q, K = self._apply_rotary_pos_emb(Q, K, cos, sin, position_ids) - - if past_key_value is None: - # Context encoding - K_active = self._repeat_kv(K, self.num_key_value_groups) - V_active = self._repeat_kv(V, self.num_key_value_groups) - - # use flash attention if (i) sequence length is large enough to get best performance, - # (ii) Q, K, and V have the same shape. Conditions can be changed in future. - - if q_len >= 4096 and Q.shape == K_active.shape == V_active.shape: - # original shape of q, k, v is BHSD, and expected output is also BHSD. - logging.debug(f"Using flash_fwd for Q.shape={Q.shape}") - # make sure to cast inputs to self.config.torch_dtype (this is needed because the downcast to bf16 might happen - # after the kernel hlo creation step). Also convert shapes as expected by the kernel. 
- Q = Q.permute(0, 1, 3, 2).reshape((bsz*self.num_heads, self.head_dim, q_len)).to(self.config.torch_dtype) - Q = Q / math.sqrt(self.head_dim) - K_active = K_active.permute(0, 1, 3, 2).reshape((bsz*self.num_heads, self.head_dim, q_len)).to(self.config.torch_dtype) - V_active = V_active.reshape((bsz*self.num_heads, q_len, self.head_dim)).to(self.config.torch_dtype) - attn_output = torch.zeros(bsz*self.num_heads, q_len, self.head_dim, dtype=Q.dtype, device=Q.device) - _flash_fwd_call(Q, K_active, V_active, 1.0, attn_output, kernel_name="CausalAttentionMMSoftmaxMMWithoutSwap") - attn_output = attn_output.reshape((bsz, self.num_heads, q_len, self.head_dim)) - else: - logging.debug(f"Not using flash_fwd for Q.shape={Q.shape}") - - # (Q.K'/√dkv) + mask - active_scores = torch.matmul(Q, K_active.transpose(2, 3)) / math.sqrt(self.head_dim) - active_scores = torch.where(attention_mask, active_scores, torch.finfo(active_scores.dtype).min) - - # Softmax - active_scores = nn.functional.softmax(active_scores, dim=-1, dtype=torch.float32).to(Q.dtype) - attn_output = torch.matmul(active_scores, V_active) - else: - is_speculation = position_ids.shape[-1] > 1 - - # Decomposed attention for token generation - K_prior = past_key_value[0] - V_prior = past_key_value[1] - - # Replicate KV for GQA/MQA - K_prior = self._repeat_kv(K_prior, self.num_key_value_groups) - V_prior = self._repeat_kv(V_prior, self.num_key_value_groups) - K_active = self._repeat_kv(K, self.num_key_value_groups) - V_active = self._repeat_kv(V, self.num_key_value_groups) - - # (Q.K'/√dkv) + mask - prior_scores = torch.matmul(Q, K_prior.transpose(2, 3)) / math.sqrt(self.head_dim) - - prior_scores = torch.where(attention_mask, prior_scores, torch.finfo(prior_scores.dtype).min) - - active_scores = torch.matmul(Q, K_active.transpose(2, 3)) / math.sqrt(self.head_dim) - - # Mask active scores for speculation - if is_speculation: - active_scores = torch.where(active_mask, active_scores, torch.finfo(active_scores.dtype).min) - - # Softmax across prior and active scores - prior_scores = prior_scores.to(torch.float32) - active_scores = active_scores.to(torch.float32) - - max_score = torch.max(prior_scores, dim=-1, keepdim=True)[0] - if is_speculation: - max_active_score = torch.max(active_scores, dim=-1, keepdim=True)[0] - max_score = torch.maximum(max_score, max_active_score) - else: - max_score = torch.maximum(max_score, active_scores) - - prior_scores = prior_scores - max_score - active_scores = active_scores - max_score - - prior_scores = torch.exp(prior_scores) - active_scores = torch.exp(active_scores) - - divisor = prior_scores.sum(dim=-1, keepdim=True) - if is_speculation: - divisor += active_scores.sum(dim=-1, keepdim=True) - else: - divisor += active_scores - - softmax_prior = prior_scores / divisor - softmax_active = active_scores / divisor - - softmax_prior = softmax_prior.to(Q.dtype) - softmax_active = softmax_active.to(Q.dtype) - - attn_prior = torch.matmul(softmax_prior, V_prior) - attn_active = torch.matmul(softmax_active, V_active) - - attn_output = attn_prior + attn_active - - # transpose BHSD -> BSHD - attn_output = attn_output.transpose(1, 2).contiguous() - - # merge multi head hidden - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - - # Z = Z.Wo - attn_output = self.o_proj(attn_output) - - past_key_value = (K, V) - - return attn_output, past_key_value - class NeuronLlamaDecoderLayer(nn.Module): """ @@ -447,12 +249,12 @@ def __init__(self, config: LlamaConfig): self.post_attention_layernorm = 
get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -478,12 +280,53 @@ def forward(
         return outputs


-class LlamaModel(LlamaPreTrainedModel):
-    def __init__(self, config: NeuronLlamaConfig):
-        super().__init__(config)
+class ResBlock(nn.Module):
+    """
+    Residual block used by the Medusa heads below: a linear projection followed by an
+    activation, added back to the input.
+    """
+
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.linear = nn.Linear(hidden_size, hidden_size)
+        # SiLU activation, following the Medusa reference ResBlock
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        """
+        Forward pass of the ResBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output after the residual connection and activation.
+        """
+        return x + self.act(self.linear(x))
+
+
+class NeuronLlamaModel(NeuronBaseModel, LlamaPreTrainedModel):
+    """
+    The neuron version of the LlamaModel
+    """
+    def setup_attr_for_model(self, config: NeuronLlamaConfig):
+        # Needed for init_inference_optimization()
+        self.on_device_sampling = config.on_device_sampling
+        self.tp_degree = config.tp_degree
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.max_batch_size = config.max_batch_size
+        self.buckets = config.buckets
+
+    def init_model(self, config: NeuronLlamaConfig):
+        self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        tp_degree = config.tp_degree

         if parallel_state.model_parallel_is_initialized():
             self.embed_tokens = ParallelEmbedding(
@@ -492,267 +335,34 @@ def __init__(self, config: NeuronLlamaConfig):
                 self.padding_idx,
                 dtype=config.torch_dtype,
                 shard_across_embedding=True,
-                # We choose to shard across embedding dimesion because this stops XLA from introducing
+                # We choose to shard across embedding dimension because this stops XLA from introducing
                 # rank specific constant parameters into the HLO. We could shard across vocab, but that
                 # would require us to use non SPMD parallel_model_trace.
pad=True, ) + self.lm_head = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False, pad=True) else: self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.layers = nn.ModuleList([NeuronLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - active_mask: Optional[List[torch.FloatTensor]] = None, - ): - batch_size, seq_length = input_ids.shape[:2] - - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - - inputs_embeds = self.embed_tokens(input_ids) - - # NeuronLlamaModel class manages the KV cache. So the attention_mask will be generated and passed - # through to LlamaModel. We override the HF's code that generates attention mask because HF does - # not support left aligned RHS padding. This enables Neuron to achieve higher performance and - # extensibility. - # - # 4d mask is passed through the layers - # attention_mask = _prepare_4d_causal_attention_mask( - # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - # ) - - # embed positions - hidden_states = inputs_embeds - - # decoder layers - next_decoder_cache = () - - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - active_mask=active_mask, - ) - - hidden_states = layer_outputs[0] - - next_decoder_cache += (layer_outputs[1],) + self.is_medusa = config.is_medusa + self.num_medusa_heads = config.num_medusa_heads + self.medusa_speculation_length = config.medusa_speculation_length - hidden_states = self.norm(hidden_states) - - return (hidden_states, next_decoder_cache) - - -class NeuronLlamaModel(LlamaModel): - """ - NeuronLlamaModel extends the LlamaModel to be traceable. - The forward function of this class is traced. 
- """ - - def __init__(self, config: NeuronLlamaConfig): - super().__init__(config) - tp_degree = config.tp_degree - self.batch_size = config.batch_size - self.n_positions = config.n_positions - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.speculation_length = config.speculation_length - self.padding_side = config.padding_side - self.on_device_sampling = config.on_device_sampling - if config.on_device_sampling: - self.sampler = Sampler(config) - self.hidden_dim_per_head = config.hidden_size // config.num_attention_heads - - gqa_sharding_strategy = determine_sharding_strategy(tp_degree, config.num_key_value_heads) - _, num_key_value_heads = get_shardable_head_counts( - tp_degree, config.num_attention_heads, config.num_key_value_heads, gqa_sharding_strategy - ) - - self.num_kv_heads_per_partition = num_key_value_heads - - if parallel_state.model_parallel_is_initialized(): - world_size = parallel_state.get_tensor_model_parallel_size() # Same as tp_degree - self.num_kv_heads_per_partition = utils.divide(num_key_value_heads, world_size) - self.lm_head = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False, pad=True) - else: - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - self.kv_shape = ( - self.config.max_batch_size, - self.num_kv_heads_per_partition, - self.config.buckets[-1], - # self.n_positions, - self.hidden_dim_per_head, - ) - self.past_key_values = nn.ParameterList( - [ - nn.Parameter(torch.zeros(self.kv_shape, dtype=config.torch_dtype), requires_grad=False) - for _ in range(config.num_hidden_layers * 2) - ] - ) - - def _bucket_slice_kv_cacheline(self, idx): - dim = 2 - if self.padding_side == "right": - return slice_lhs(self.past_key_values[idx], self.n_positions, dim) - else: - max_idx = self.past_key_values[idx].shape[dim] - return slice_rhs(self.past_key_values[idx], self.n_positions, max_idx, dim) - - def _gather_bucket_slice_into_kv_cacheline(self, idx, bucket_slice): - dim = 2 - max_idx = self.past_key_values[idx].shape[dim] - if self.padding_side == "right": - remaining = slice_rhs(self.past_key_values[idx], max_idx - self.n_positions, max_idx, dim) - return torch.cat([bucket_slice, remaining], dim=2) - else: - remaining = slice_lhs(self.past_key_values[idx], max_idx - self.n_positions, dim) - return torch.cat([remaining, bucket_slice], dim=2) - - def create_attn_mask(self, attention_mask, is_for_context_encoding, is_for_speculation, position_ids): - if is_for_context_encoding: - mask = torch.full((self.n_positions, self.n_positions), True, device=attention_mask.device).tril(diagonal=0) - mask = mask[None, None, :, :].expand(self.batch_size, 1, self.n_positions, self.n_positions) - - if self.padding_side == "right": - return mask + if self.is_medusa: + if parallel_state.model_parallel_is_initialized(): + medusa_head_cls = ColumnParallelLinear else: - expanded_mask = ( - attention_mask[:, None, None, :] - .expand(self.batch_size, 1, self.n_positions, self.n_positions) - .to(torch.bool) + medusa_head_cls = nn.Linear + for i in range(self.num_medusa_heads): + medusa_head = nn.Sequential( + *([ResBlock(config.hidden_size)] * 1), + medusa_head_cls(config.hidden_size, config.vocab_size, bias=False), ) - return torch.logical_and(mask, expanded_mask) - elif is_for_speculation: - return ( - attention_mask[:, None, None, :] - .expand(self.batch_size, 1, self.speculation_length, self.n_positions) - .to(torch.bool) - ) - else: - return attention_mask[:, None, None, :].expand(self.batch_size, 1, 
1, self.n_positions).to(torch.bool) - - def forward(self, input_ids, attention_mask, position_ids, seq_ids): - is_for_context_encoding = input_ids.shape[-1] > 1 and self.speculation_length != input_ids.shape[-1] - is_for_speculation = input_ids.shape[-1] == self.speculation_length - # It is either for context encoding or for token generation - if is_for_context_encoding: - past_key_values = None - else: - past_key_values = [] - for key_layer_idx in range(0, len(self.past_key_values), 2): - key_state = self._bucket_slice_kv_cacheline(key_layer_idx) - value_state = self._bucket_slice_kv_cacheline(key_layer_idx + 1) - past_key_values.append([key_state, value_state]) - - # Prepare attention mask(s) - attention_mask = self.create_attn_mask( - attention_mask, is_for_context_encoding, is_for_speculation, position_ids - ) - active_mask = None - if is_for_speculation: - active_mask = torch.full( - (self.speculation_length, self.speculation_length), True, device=attention_mask.device - ).tril(diagonal=0) - active_mask = active_mask[None, None, :, :].expand( - self.batch_size, 1, self.speculation_length, self.speculation_length - ) - - hidden_states, past_key_values = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - active_mask=active_mask, - ) - - updated_kv_cache = [] - for idx, kv_per_layer in enumerate(past_key_values): - k_cache = self._bucket_slice_kv_cacheline(idx * 2) - v_cache = self._bucket_slice_kv_cacheline(idx * 2 + 1) - - if is_for_context_encoding: - if self.config.is_continuous_batching: - # scatter back to the desired seq_ids - seq_id_index_shape = seq_ids.shape[:1] + k_cache.shape[1:] - seq_id_index = seq_ids.view(-1, 1, 1, 1).expand(seq_id_index_shape) - k_cache = torch.scatter(k_cache, 0, seq_id_index, kv_per_layer[0]) - v_cache = torch.scatter(v_cache, 0, seq_id_index, kv_per_layer[1]) - else: - # assign back to full kv_cacheline - k_cache = kv_per_layer[0] - v_cache = kv_per_layer[1] - else: - if self.padding_side == "left": - # TODO: fix it with scatter after right padding - k_cache = k_cache[:, :, 1:, :] - v_cache = v_cache[:, :, 1:, :] - k_cache = torch.cat([k_cache, kv_per_layer[0]], dim=2) - v_cache = torch.cat([v_cache, kv_per_layer[1]], dim=2) - else: - scatter_index = position_ids.view(-1, 1, position_ids.shape[-1], 1).expand_as(kv_per_layer[0]) - k_cache = torch.scatter(k_cache, 2, scatter_index, kv_per_layer[0]) - v_cache = torch.scatter(v_cache, 2, scatter_index, kv_per_layer[1]) - - k_cache = self._gather_bucket_slice_into_kv_cacheline(idx * 2, k_cache) - v_cache = self._gather_bucket_slice_into_kv_cacheline(idx * 2 + 1, v_cache) - - updated_kv_cache.append(k_cache) - updated_kv_cache.append(v_cache) - - if self.padding_side == "left": - index = torch.tensor([hidden_states.shape[1] - 1], device=hidden_states.device) - index = index.unsqueeze(1).expand(self.batch_size, 1, self.config.hidden_size) - hidden_states = torch.gather(hidden_states, dim=1, index=index) - else: - # simple token generation - if position_ids.shape[-1] != self.speculation_length: - index = torch.max(position_ids, dim=1, keepdim=True).indices - index = index.unsqueeze(1).expand(self.batch_size, 1, self.config.hidden_size) - hidden_states = torch.gather(hidden_states, dim=1, index=index) - # speculative decoding case; only batch_size=1 - # will need to extend the logic to support multi-batch later - # maybe just use position_ids for index? 
- else: - index = torch.min(position_ids) - index = torch.arange(index, index + self.speculation_length, device=hidden_states.device) - index = index[None, :, None].expand(self.batch_size, self.speculation_length, self.config.hidden_size) - hidden_states = torch.gather(hidden_states, dim=1, index=index) - - logits = self.lm_head(hidden_states) - logits = logits.float() - - logits_or_next_tokens = logits - if self.on_device_sampling: - # perform sampling on Neuron to get tokens - logits_or_next_tokens = self.sampler.sample(logits[:, -1, :]) - - return [logits_or_next_tokens] + updated_kv_cache + setattr(self, f"medusa_head_{i}", medusa_head) class NeuronLlamaForCausalLM(NeuronBaseForCausalLM, LlamaPreTrainedModel): @@ -764,362 +374,8 @@ class NeuronLlamaForCausalLM(NeuronBaseForCausalLM, LlamaPreTrainedModel): LlamaForCausalLM (_type_): _description_ """ - def __init__(self, model_path: str, config: NeuronLlamaConfig): - super().__init__(config) - self.config = config - self.vocab_size = config.vocab_size - self.padding_side = config.padding_side - self.kv_cache_populated = False - - self.sampler = None - - self.models = [] - self.enable_context_encoding() - if config.trace_tokengen_model: - self.enable_token_generation() - if config.speculation_length > 0: - self.enable_speculation() - self.model_path = model_path + _model_cls = NeuronLlamaModel @staticmethod def load_hf_model(model_path): return LlamaForCausalLM.from_pretrained(model_path) - - def enable_context_encoding(self): - new_config = copy.deepcopy(self.config) - new_config.batch_size = self.config.ctx_batch_size - new_config.n_active_tokens = self.config.n_positions - - if not new_config.enable_context_encoding_bucketing: - new_config.buckets = [new_config.buckets[-1]] - - self.context_encoding_model = ModelWrapper(new_config, NeuronLlamaModel, tag=CONTEXT_ENCODING_MODEL_TAG) - - self.models.append(self.context_encoding_model) - - def enable_token_generation(self): - new_config = copy.deepcopy(self.config) - new_config.batch_size = self.config.tkg_batch_size - new_config.n_active_tokens = 1 - new_config.bucket_n_active_tokens = False - - if not new_config.enable_token_generation_bucketing: - new_config.buckets = [new_config.buckets[-1]] - - self.token_generation_model = ModelWrapper(new_config, NeuronLlamaModel, tag=TOKEN_GENERATION_MODEL_TAG) - - self.models.append(self.token_generation_model) - - def enable_speculation(self): - new_config = copy.deepcopy(self.config) - new_config.batch_size = self.config.spec_batch_size - new_config.n_active_tokens = self.config.speculation_length - self.speculation_model = ModelWrapper(new_config, NeuronLlamaModel, tag=SPECULATION_MODEL_TAG) - - self.models.append(self.speculation_model) - - def forward( - self, - input_ids: torch.LongTensor = None, - seq_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # infer attention_mask from position_ids if not provided - if attention_mask is None: - assert position_ids is not None, "need to call forward with position_ids if attention_mask is not provided" - batch_size, seq_len = position_ids.shape - if position_ids.shape[-1] == 1: - seq_len = self.config.n_positions - position_ids_to_compare = position_ids.expand(batch_size, seq_len) - 1 - else: - seq_len = position_ids.shape[-1] - position_ids_to_compare = position_ids - mask = torch.arange(seq_len).view(1, -1).expand(batch_size, seq_len) - attention_mask = (position_ids_to_compare >= mask).to(dtype=position_ids.dtype) - - logging.debug("---input---") - logging.debug("input_ids shape = %s type=%s", input_ids.shape, input_ids.type()) - logging.debug("attention_mask shape = %s type=%s", attention_mask.shape, attention_mask.type()) - logging.debug("position_ids shape = %s type=%s", position_ids.shape, position_ids.type()) - logging.debug("input_ids =%s", input_ids) - logging.debug("attention_mask =%s", attention_mask) - logging.debug("position_ids =%s", position_ids) - logging.debug(f"seq_ids: {seq_ids}") - - if self.config.trace_tokengen_model and not self.token_generation_model.is_neuron(): - logging.debug(f"first layer kv_cache: {self.token_generation_model.model.past_key_values[0][:, 0, :, 0]}") - - if seq_ids is None: - seq_ids = torch.arange(input_ids.shape[0]) - - if input_ids.shape[-1] > 1 and input_ids.shape[-1] != self.config.speculation_length: - outputs = self.context_encoding_model(input_ids, attention_mask, position_ids, seq_ids) - - if self.context_encoding_model.is_neuron(): - # Copy the KV cache from the context_encoding_model to token generation model - if self.config.trace_tokengen_model: - for encoder_model, token_gen_model in zip( - self.context_encoding_model.model.models, self.token_generation_model.model.models - ): - encoder_kv_cache_line = encoder_model.states - token_gen_kv_cache_line = token_gen_model.states - for name, _ in token_gen_kv_cache_line._parameters.items(): - token_gen_kv_cache_line._parameters[name] = encoder_kv_cache_line._parameters[name] - # Also need to copy to the speculation model for speculation - if self.config.speculation_length > 0: - for encoder_model, speculation_model in zip( - self.context_encoding_model.model.models, self.speculation_model.model.models - ): - encoder_kv_cache_line = encoder_model.states - 
speculation_kv_cache_line = speculation_model.states - for name, _ in speculation_kv_cache_line._parameters.items(): - speculation_kv_cache_line._parameters[name] = encoder_kv_cache_line._parameters[name] - self.kv_cache_populated = True - elif input_ids.shape[-1] == self.config.speculation_length: - outputs = self.speculation_model(input_ids, attention_mask, position_ids, seq_ids) - else: - outputs = self.token_generation_model(input_ids, attention_mask, position_ids, seq_ids) - - if self.config.trace_tokengen_model and not self.token_generation_model.is_neuron(): - # When traced the output kv tensors are aliased to the kv parameter list. - # The code below mimicks that on CPU. - new_past_key_values = outputs[1:] - for i, new_past_key_value in enumerate(new_past_key_values): - self.token_generation_model.model.past_key_values[i].data = new_past_key_value - self.context_encoding_model.model.past_key_values[i].data = new_past_key_value - - logits_or_next_tokens, *_ = outputs - - logging.debug("---output---") - logging.debug(f"{'tokens' if self.config.on_device_sampling else 'logits'} = %s, ", logits_or_next_tokens) - - next_tokens = logits_or_next_tokens - - OutputParams = CausalLMOutputWithPast( - loss=0, - logits=None if self.config.on_device_sampling else logits_or_next_tokens, - past_key_values=[], - hidden_states=logits_or_next_tokens, - attentions=None, - ) - OutputParams.tokens = next_tokens - return OutputParams - - # We override this function because we want to change the way attention_mask - # is updated each iteration. - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_for_token_generation: Optional[bool] = False, - is_encoder_decoder: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - - if getattr(outputs, "state", None) is not None: - model_kwargs["state"] = outputs.state - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if is_for_token_generation: - if self.padding_side == "left": - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - attention_mask = attention_mask[:, 1:] - else: - attention_mask = torch.cat( - [attention_mask.new_ones((attention_mask.shape[0], 1)), attention_mask], dim=-1 - ) - attention_mask = attention_mask[:, :-1] - model_kwargs["attention_mask"] = attention_mask - return model_kwargs - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if self.kv_cache_populated: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if self.kv_cache_populated: - position_ids = torch.amax(position_ids, 1, keepdim=True) - position_ids = position_ids + 1 - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - 
model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - def reset(self): - # We need to reset the KV cache flag for a new batch of inference. - # When the flag is reset, the subsequent run will invoke the - # context encoding model. - self.kv_cache_populated = False - - def reset_kv_cache(self): - # Zero out kv cache for debug. - # For new batch inference, use reset() instead - if not self.context_encoding_model.is_neuron(): - for i, kv_tensor in enumerate(self.context_encoding_model.model.past_key_values): - self.context_encoding_model.model.past_key_values[i] = torch.zeros_like(kv_tensor) - - if not self.token_generation_model.is_neuron(): - for i, kv_tensor in enumerate(self.token_generation_model.model.past_key_values): - self.token_generation_model.model.past_key_values[i] = torch.zeros_like(kv_tensor) - - def sample( - self, - input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - **model_kwargs, - ) -> Union[SampleOutput, torch.LongTensor]: - r""" - We override the GenerationMixin sample function to add support for right side padding. - """ - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - - # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) - - this_peer_finished = False - # auto-regressive generation - while True: - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - is_for_token_generation = self.kv_cache_populated - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - ) - - if not self.config.on_device_sampling: - if self.sampler is None: - self.config.do_sample = True - self.sampler = Sampler(self.config) - next_tokens = self.sampler.sample(outputs.logits[:, -1, :]) - else: - next_tokens = outputs.tokens - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * 
unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - is_for_token_generation=is_for_token_generation, - ) - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True - - # stop if we exceed the maximum length - if stopping_criteria(input_ids, None): - this_peer_finished = True - - if this_peer_finished: - break - - return input_ids diff --git a/examples/inference/mixtral/mixtral_runner.py b/examples/inference/mixtral/mixtral_runner.py index bf109c6..07a9fef 100644 --- a/examples/inference/mixtral/mixtral_runner.py +++ b/examples/inference/mixtral/mixtral_runner.py @@ -1,5 +1,3 @@ -import os - import torch from mixtral.neuron_modeling_mixtral import ( NeuronMixtralConfig, @@ -17,10 +15,17 @@ class MixtralRunner(InferenceRunner): def load_hf_model(self): return NeuronMixtralForCausalLM.load_hf_model(self.model_path) - def load_neuron_model_on_cpu(self, max_context_length, max_new_tokens, batch_size, **kwargs): + def load_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): # On CPU we can only run tensor parallelism with degree 1 - config = self.get_config_for_nxd(batch_size, 1, max_context_length, max_new_tokens, **kwargs) + config = self.get_config_for_nxd( + batch_size, + 1, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, + enable_bucketing=False, + **kwargs) config.torch_dtype = torch.float32 + config.on_cpu = True # to avoid running custom RMSNorm on cpu self.init_ditributed_env() neuron_model = NeuronMixtralModel(config) @@ -45,7 +50,7 @@ def load_neuron_model(self, traced_model_path): model.load(traced_model_path) if config.torch_dtype == torch.bfloat16: - os.environ["XLA_DOWNCAST_BF16"] = "1" + model.bfloat16() return model @@ -64,6 +69,12 @@ def get_model_cls(self): def get_padding_side(self): return "right" + def get_default_hf_generation_config_kwargs(self): + config = super().get_default_hf_generation_config_kwargs() + config['pad_token_id'] = 0 + + return config + if __name__ == "__main__": MixtralRunner.cmd_execute() diff --git a/examples/inference/mixtral/neuron_modeling_mixtral.py b/examples/inference/mixtral/neuron_modeling_mixtral.py index 3d628d1..8644080 100644 --- a/examples/inference/mixtral/neuron_modeling_mixtral.py +++ b/examples/inference/mixtral/neuron_modeling_mixtral.py @@ -12,57 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
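# --- Editor's illustration (not part of the patch) ---------------------------------
# The runner change above (mixtral_runner.py) replaces the implicit XLA_DOWNCAST_BF16
# environment flag with an explicit `model.bfloat16()` cast. A minimal sketch of what
# that cast does, using a plain torch module as a stand-in for the traced Neuron model
# (the Linear module here is an assumption chosen only for illustration):

import torch

_example = torch.nn.Linear(8, 8)                  # fp32 parameters by default
assert _example.weight.dtype == torch.float32
_example.bfloat16()                               # casts floating-point params/buffers in place
assert _example.weight.dtype == torch.bfloat16
# XLA_DOWNCAST_BF16=1, by contrast, downcasts fp32 tensors inside torch-xla without
# changing the module's own parameter dtypes, which is harder to reason about per model.
# --- end editor's illustration ------------------------------------------------------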
""" PyTorch Mixtral model for NXD inference.""" -import copy import gc -import logging -import math import warnings -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch -from modules.autobucketing import slice_lhs, slice_rhs +from modules.custom_calls import CustomRMSNorm from modules.gqa import ( GQA, BaseGroupQueryAttention, - GroupQueryAttention_O, - GroupQueryAttention_QKV, - get_shardable_head_counts, ) -from modules.model_base import NeuronBaseForCausalLM -from modules.model_wrapper import ( - CONTEXT_ENCODING_MODEL_TAG, - SPECULATION_MODEL_TAG, - TOKEN_GENERATION_MODEL_TAG, - ModelWrapper, -) -from neuronxcc.nki.kernels.attention import attention_isa_kernel +from modules.model_base import NeuronBaseModel, NeuronBaseForCausalLM +from neuronx_distributed.utils.sampling import Sampler + +# Try except for the compatibility with older compiler version +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel from torch import nn from torch_neuronx.xla_impl.ops import nki_jit from transformers import MixtralForCausalLM, MixtralPreTrainedModel -from transformers.cache_utils import Cache from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput -from transformers.generation.logits_process import LogitsProcessorList -from transformers.generation.stopping_criteria import ( - StoppingCriteriaList, - validate_stopping_criteria, -) -from transformers.modeling_outputs import ( - CausalLMOutputWithPast, - ModelOutput, - MoeModelOutputWithPast, -) +from modules.attention.attention_base import NeuronAttentionBase +from modules.attention.utils import RotaryEmbedding +from modules.config import NeuronInferenceConfig from transformers.models.mixtral.configuration_mixtral import MixtralConfig from transformers.models.mixtral.modeling_mixtral import ( MixtralRMSNorm, - MixtralRotaryEmbedding, - apply_rotary_pos_emb, - repeat_kv, ) -from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPsCapacityFactor +from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPs from neuronx_distributed.modules.moe.model import MoE -from neuronx_distributed.modules.moe.model_utils import MoESequenceParallelMode from neuronx_distributed.modules.moe.routing import RouterTopK from neuronx_distributed.parallel_layers import parallel_state, utils from neuronx_distributed.parallel_layers.layers import ( @@ -81,9 +62,9 @@ def convert_mixtral_to_neuron_state_dict(neuron_state_dict, cfg): """ Helper function which returns the model weights from the mixtral model in a state dictionary compatible with the stucture of the neuron MoE model. 
""" - assert cfg.glu_mlp == True, f"Only GLU MLP is supported for Mixtral Top-K model" + assert cfg.glu_mlp is True, "Only GLU MLP is supported for Mixtral Top-K model" - for l in range(cfg.num_hidden_layers): + for l in range(cfg.num_hidden_layers): # noqa: E741 # Copy router weights neuron_state_dict[f"layers.{l}.mlp.router.linear_router.weight"] = ( neuron_state_dict[f"layers.{l}.block_sparse_moe.gate.weight"].detach().clone() @@ -113,7 +94,7 @@ def convert_mixtral_to_neuron_state_dict(neuron_state_dict, cfg): del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w1.weight"] del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w3.weight"] - neuron_state_dict[f"layers.{l}.mlp.expert_mlps.gate_up_proj.weight"] = gate_up_proj + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj down_proj = torch.empty(cfg.num_local_experts, intermediate_size, hidden_size, dtype=dtype, device=device) for e in range(cfg.num_local_experts): @@ -124,7 +105,7 @@ def convert_mixtral_to_neuron_state_dict(neuron_state_dict, cfg): down_proj_slice = torch.narrow(down_proj, 0, e, 1) down_proj_slice.copy_(down_proj_weights) del neuron_state_dict[f"layers.{l}.block_sparse_moe.experts.{e}.w2.weight"] - neuron_state_dict[f"layers.{l}.mlp.expert_mlps.down_proj.weight"] = down_proj + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj gc.collect() @@ -137,244 +118,77 @@ def preshard_hook_fn(module: torch.nn.Module, model_state_dict: dict, prefix: st return False -class NeuronMixtralConfig(MixtralConfig): +def get_rmsnorm_cls(config): + # Initialize to the appropriate implementation of RMSNorm + # If infer on NXD -> CustomRMSNorm + # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) + return MixtralRMSNorm if config.on_cpu else CustomRMSNorm + + +class NeuronMixtralConfig(NeuronInferenceConfig, MixtralConfig): def __init__( - self, - batch_size: int = 1, - tp_degree: int = 1, - max_context_length: int = 128, - max_new_tokens: int = 128, - permute_strategy: str = "matmul", - moe_sequence_parallel_mode: MoESequenceParallelMode = MoESequenceParallelMode.NO_SP, - capacity_factor: float = 4.0, - glu_mlp: bool = True, - padding_side: str = "right", - speculation_length: int = 0, - **kwargs, + self, + batch_size: int = 1, + tp_degree: int = 1, + max_context_length: int = 128, + max_new_tokens: int = 128, + capacity_factor: float = None, + glu_mlp: bool = True, + padding_side: str = "right", + speculation_length: int = 0, + **kwargs, ): - self.batch_size = batch_size - self.tp_degree = tp_degree self.max_new_tokens = max_new_tokens self.max_context_length = max_context_length self.max_length = max_new_tokens + max_context_length - self.permute_strategy = permute_strategy - self.moe_sequence_parallel_mode = moe_sequence_parallel_mode - self.capacity_factor = float(capacity_factor) + # capacity_factor = None corresponds to full capacity (no token dropping) + self.capacity_factor = float(capacity_factor) if capacity_factor is not None else None self.glu_mlp = glu_mlp - self.padding_side = padding_side - self.speculation_length = speculation_length - self.trace_tokengen_model = True - self.n_positions = self.max_length - self.n_active_tokens = self.max_length - self.max_batch_size = batch_size - self.ctx_batch_size = kwargs.pop("ctx_batch_size", self.max_batch_size) - self.tkg_batch_size = kwargs.pop("tkg_batch_size", self.max_batch_size) - - # bucketing specific params - self.enable_context_encoding_bucketing = False - 
self.enable_token_generation_bucketing = False - self.buckets = [self.n_positions] - self.bucket_n_active_tokens = self.enable_context_encoding_bucketing - - self.is_continuous_batching = kwargs.pop("is_continuous_batching", False) - - super().__init__(**kwargs) + self.on_cpu = False + super().__init__( + tp_degree=tp_degree, + batch_size=batch_size, + padding_side=padding_side, + seq_len=max_context_length+max_new_tokens, + max_context_length=max_context_length, + speculation_length=speculation_length, + **kwargs, + ) -class NeuronMixtralAttention(nn.Module): - """ - Compared with MixtralAttention, this class just - 1. replaces the linear layers in attention with NxD parallel linear layers - 2. updates attention heads and KV heads to work with given TP degree - 3. uses decomposed attention during token generation for lower latency - """ - def __init__(self, config: MixtralConfig, layer_idx: int): +class NeuronMixtralAttention(NeuronAttentionBase): + def __init__(self, config: MixtralConfig): super().__init__() self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logging.warning( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead" - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads + self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_attention_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) + self.padding_side = config.padding_side + self.torch_dtype = config.torch_dtype if not parallel_state.model_parallel_is_initialized(): raise ValueError( - f"NeuronMixtralAttention has to be initialized in a distributed env. Please use neuronx_distributed" - f" module to initialize a distributed env." + "NeuronMixtralAttention has to be initialized in a distributed env. Please use neuronx_distributed" + " module to initialize a distributed env." 
) - self.world_size = parallel_state.get_tensor_model_parallel_size() + self.tp_degree = parallel_state.get_tensor_model_parallel_size() + self.fused_qkv = False + self.clip_qkv = None - self.qkv_proj = GroupQueryAttention_QKV( - hidden_size=self.hidden_size, - head_dim=self.head_dim, - num_attention_heads=self.num_heads, - num_key_value_heads=self.num_key_value_heads, - tp_degree=self.world_size, - dtype=config.torch_dtype, - gather_output=False, - desired_sharding_strategy=GQA_SHARDING_STRATEGY, - ) - self.o_proj = GroupQueryAttention_O( - hidden_size=self.hidden_size, - head_dim=self.head_dim, - num_attention_heads=self.num_heads, - num_key_value_heads=self.num_key_value_heads, - tp_degree=self.world_size, - dtype=config.torch_dtype, - input_is_parallel=True, - desired_sharding_strategy=GQA_SHARDING_STRATEGY, - ) - self.num_heads = utils.divide(self.qkv_proj.get_num_attention_heads(), self.world_size) - self.num_key_value_heads = utils.divide(self.qkv_proj.get_num_key_value_heads(), self.world_size) - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.init_gqa_properties() - self.rotary_emb = MixtralRotaryEmbedding( + self.rotary_emb = RotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - active_mask: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - 1. replace the q_proj, k_proj, v_proj with qkv_proj - 2. replace the `attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)` with - `attn_output = attn_output.reshape(bsz, q_len, self.hidden_size // self.world_size)` - """ - bsz, q_len, _ = hidden_states.size() - - Q, K, V = self.qkv_proj(hidden_states) - - Q = Q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - K = K.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - V = V.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = K.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(V, seq_len=kv_seq_len) - Q, K = apply_rotary_pos_emb(Q, K, cos, sin, position_ids) - - if past_key_value is None: - # Context encoding - K_active = repeat_kv(K, self.num_key_value_groups) - V_active = repeat_kv(V, self.num_key_value_groups) - - # use flash attention if (i) sequence length is large enough to get best performance, - # (ii) Q, K, and V have the same shape. Conditions can be changed in future. - - if q_len >= 4096 and Q.shape == K_active.shape == V_active.shape: - # original shape of q, k, v is BHSD, and expected output is also BHSD. - logging.debug(f"Using flash_fwd for Q.shape={Q.shape}") - # make sure to cast inputs to self.config.torch_dtype (this is needed because the downcast to bf16 might happen - # after the kernel hlo creation step). Also convert shapes as expected by the kernel. 
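# --- Editor's illustration (not part of the patch) ---------------------------------
# Shape bookkeeping for the flash kernel call below, shown with small dummy tensors
# (sizes are illustrative only): Q and K go from BHSD to (B*H, D, S), V goes to
# (B*H, S, D), and Q is pre-scaled by 1/sqrt(head_dim) because the kernel is invoked
# with a scale of 1.0.

import math
import torch

_bsz, _heads, _q_len, _head_dim = 1, 2, 8, 4
_q = torch.randn(_bsz, _heads, _q_len, _head_dim)                       # BHSD layout
_q_kernel = _q.permute(0, 1, 3, 2).reshape(_bsz * _heads, _head_dim, _q_len) / math.sqrt(_head_dim)
_v = torch.randn(_bsz, _heads, _q_len, _head_dim)
_v_kernel = _v.reshape(_bsz * _heads, _q_len, _head_dim)
assert _q_kernel.shape == (_bsz * _heads, _head_dim, _q_len)
assert _v_kernel.shape == (_bsz * _heads, _q_len, _head_dim)
# --- end editor's illustration ------------------------------------------------------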
- Q = Q.permute(0, 1, 3, 2).reshape((bsz*self.num_heads, self.head_dim, q_len)).to(self.config.torch_dtype) - Q = Q / math.sqrt(self.head_dim) - K_active = K_active.permute(0, 1, 3, 2).reshape((bsz*self.num_heads, self.head_dim, q_len)).to(self.config.torch_dtype) - V_active = V_active.reshape((bsz*self.num_heads, q_len, self.head_dim)).to(self.config.torch_dtype) - attn_output = torch.zeros(bsz*self.num_heads, q_len, self.head_dim, dtype=Q.dtype, device=Q.device) - _flash_fwd_call(Q, K_active, V_active, 1.0, attn_output, kernel_name="CausalAttentionMMSoftmaxMMWithoutSwap") - attn_output = attn_output.reshape((bsz, self.num_heads, q_len, self.head_dim)) - else: - logging.debug(f"Not using flash_fwd for Q.shape={Q.shape}") - - # (Q.K'/√dkv) + mask - active_scores = torch.matmul(Q, K_active.transpose(2, 3)) / math.sqrt(self.head_dim) - active_scores = torch.where(attention_mask, active_scores, torch.finfo(active_scores.dtype).min) - - # Softmax - active_scores = nn.functional.softmax(active_scores, dim=-1, dtype=torch.float32).to(Q.dtype) - attn_output = torch.matmul(active_scores, V_active) - else: - is_speculation = position_ids.shape[-1] > 1 - - # Decomposed attention for token generation - K_prior = past_key_value[0] - V_prior = past_key_value[1] - - # Replicate KV for GQA/MQA - K_prior = repeat_kv(K_prior, self.num_key_value_groups) - V_prior = repeat_kv(V_prior, self.num_key_value_groups) - K_active = repeat_kv(K, self.num_key_value_groups) - V_active = repeat_kv(V, self.num_key_value_groups) - - # (Q.K'/√dkv) + mask - prior_scores = torch.matmul(Q, K_prior.transpose(2, 3)) / math.sqrt(self.head_dim) - prior_scores = torch.where(attention_mask, prior_scores, torch.finfo(prior_scores.dtype).min) - active_scores = torch.matmul(Q, K_active.transpose(2, 3)) / math.sqrt(self.head_dim) - - # Mask active scores for speculation - if is_speculation: - active_scores = torch.where(active_mask, active_scores, torch.finfo(active_scores.dtype).min) - - # Softmax across prior and active scores - prior_scores = prior_scores.to(torch.float32) - active_scores = active_scores.to(torch.float32) - - max_score = torch.max(prior_scores, dim=-1, keepdim=True)[0] - if is_speculation: - max_active_score = torch.max(active_scores, dim=-1, keepdim=True)[0] - max_score = torch.maximum(max_score, max_active_score) - else: - max_score = torch.maximum(max_score, active_scores) - - prior_scores = prior_scores - max_score - active_scores = active_scores - max_score - - prior_scores = torch.exp(prior_scores) - active_scores = torch.exp(active_scores) - - divisor = prior_scores.sum(dim=-1, keepdim=True) - if is_speculation: - divisor += active_scores.sum(dim=-1, keepdim=True) - else: - divisor += active_scores - - softmax_prior = prior_scores / divisor - softmax_active = active_scores / divisor - - softmax_prior = softmax_prior.to(Q.dtype) - softmax_active = softmax_active.to(Q.dtype) - - attn_prior = torch.matmul(softmax_prior, V_prior) - attn_active = torch.matmul(softmax_active, V_active) - - attn_output = attn_prior + attn_active - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) - attn_output = self.o_proj(attn_output) - - past_key_value = (K, V) - - return attn_output, past_key_value - class NeuronMixtralDecoderLayer(nn.Module): """ @@ -384,43 +198,39 @@ class NeuronMixtralDecoderLayer(nn.Module): def __init__(self, config: NeuronMixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - 
self.self_attn = NeuronMixtralAttention(config=config, layer_idx=layer_idx) + self.self_attn = NeuronMixtralAttention(config=config) router = RouterTopK( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - sequence_parallel_mode=config.moe_sequence_parallel_mode, ) - expert_mlps = ExpertMLPsCapacityFactor( + expert_mlps = ExpertMLPs( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - capacity_factor=config.capacity_factor, glu_mlp=config.glu_mlp, - sequence_parallel_mode=config.moe_sequence_parallel_mode, - permute_strategy=config.permute_strategy, + capacity_factor=config.capacity_factor, normalize_top_k_affinities=True, ) self.mlp = MoE( router=router, expert_mlps=expert_mlps, - sequence_parallel_mode=config.moe_sequence_parallel_mode, ) self.mlp.eval() # Set MoE module in eval mode - self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = get_rmsnorm_cls(config)(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = get_rmsnorm_cls(config)(config.hidden_size, eps=config.rms_norm_eps) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - **kwargs, + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -462,30 +272,26 @@ def forward( return outputs -class NeuronMixtralModel(MixtralPreTrainedModel): +class NeuronMixtralModel(NeuronBaseModel, MixtralPreTrainedModel): """ NeuronMixtralModel extends the MixtralModel to be traceable. The forward function of this class is traced. """ - def __init__(self, config: NeuronMixtralConfig): - # Initialization to ensure proper processing from Mixtral model - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.batch_size = config.batch_size - self.max_length = config.max_length - self.padding_side = config.padding_side + _model_cls = MixtralPreTrainedModel - self.speculation_length = config.speculation_length - self.n_positions = config.n_positions + def setup_attr_for_model(self, config: NeuronMixtralConfig): + self.on_device_sampling = config.on_device_sampling + self.tp_degree = config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.max_batch_size + self.buckets = config.buckets - if not parallel_state.model_parallel_is_initialized(): - raise ValueError( - f"NeuronMixtralAttention has to be initialized in a distributed env. Please use neuronx_distributed" - f" to initialize a distributed env." 
- ) + def init_model(self, config: NeuronMixtralConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size self.embed_tokens = ParallelEmbedding( config.vocab_size, @@ -497,224 +303,8 @@ def __init__(self, config: NeuronMixtralConfig): self.layers = nn.ModuleList( [NeuronMixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - # Initialization to ensure proper KV cache management - world_size = parallel_state.get_tensor_model_parallel_size() - _, num_key_value_heads = get_shardable_head_counts( - world_size, config.num_attention_heads, config.num_key_value_heads, GQA_SHARDING_STRATEGY - ) - num_kv_heads_per_partition = utils.divide(num_key_value_heads, world_size) - head_dim = config.hidden_size // config.num_attention_heads - kv_shape = (config.max_batch_size, num_kv_heads_per_partition, config.buckets[-1], head_dim) - self.past_key_values = nn.ParameterList( - [ - nn.Parameter(torch.zeros(kv_shape, dtype=config.torch_dtype), requires_grad=False) - for _ in range(config.num_hidden_layers * 2) - ] - ) - - self.lm_head = ColumnParallelLinear(config.hidden_size, config.vocab_size, bias=False) - - self.post_init() - - def _bucket_slice_kv_cacheline(self, idx): - dim = 2 - if self.padding_side == "right": - return slice_lhs(self.past_key_values[idx], self.n_positions, dim) - else: - max_idx = self.past_key_values[idx].shape[dim] - return slice_rhs(self.past_key_values[idx], self.n_positions, max_idx, dim) - - def _gather_bucket_slice_into_kv_cacheline(self, idx, bucket_slice): - dim = 2 - max_idx = self.past_key_values[idx].shape[dim] - if self.padding_side == "right": - remaining = slice_rhs(self.past_key_values[idx], max_idx - self.n_positions, max_idx, dim) - return torch.cat([bucket_slice, remaining], dim=2) - else: - remaining = slice_lhs(self.past_key_values[idx], max_idx - self.n_positions, dim) - return torch.cat([remaining, bucket_slice], dim=2) - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def create_attn_mask(self, attention_mask, is_for_context_encoding, is_for_speculation, position_ids): - if is_for_context_encoding: - mask = torch.full((self.n_positions, self.n_positions), True, device=attention_mask.device).tril(diagonal=0) - mask = mask[None, None, :, :].expand(self.batch_size, 1, self.n_positions, self.n_positions) - - if self.padding_side == "right": - return mask - else: - expanded_mask = ( - attention_mask[:, None, None, :] - .expand(self.batch_size, 1, self.n_positions, self.n_positions) - .to(torch.bool) - ) - return torch.logical_and(mask, expanded_mask) - elif is_for_speculation: - return ( - attention_mask[:, None, None, :] - .expand(self.batch_size, 1, self.speculation_length, self.n_positions) - .to(torch.bool) - ) - else: - return attention_mask[:, None, None, :].expand(self.batch_size, 1, 1, self.n_positions).to(torch.bool) - - def forward(self, input_ids, attention_mask, position_ids, seq_ids): - """ - This function is to maintain the KV cache during inference - """ - is_for_context_encoding = input_ids.shape[-1] > 1 and self.speculation_length != input_ids.shape[-1] - is_for_speculation = input_ids.shape[-1] == self.speculation_length - # It is either for context encoding or for token generation - if is_for_context_encoding: - past_key_values = None - else: - past_key_values = [] - for key_layer_idx in range(0, 
len(self.past_key_values), 2): - key_state = self._bucket_slice_kv_cacheline(key_layer_idx) - value_state = self._bucket_slice_kv_cacheline(key_layer_idx + 1) - past_key_values.append([key_state, value_state]) - - attention_mask = self.create_attn_mask( - attention_mask, is_for_context_encoding, is_for_speculation, position_ids - ) - active_mask = None - if is_for_speculation: - active_mask = torch.full( - (self.speculation_length, self.speculation_length), True, device=attention_mask.device - ).tril(diagonal=0) - active_mask = active_mask[None, None, :, :].expand( - self.batch_size, 1, self.speculation_length, self.speculation_length - ) - - hidden_states, past_key_values = self._forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - active_mask=active_mask, - ) - - updated_kv_cache = [] - for idx, kv_per_layer in enumerate(past_key_values): - k_cache = self._bucket_slice_kv_cacheline(idx * 2) - v_cache = self._bucket_slice_kv_cacheline(idx * 2 + 1) - - if is_for_context_encoding: - # scatter back to the desired seq_ids - seq_id_index_shape = seq_ids.shape[:1] + k_cache.shape[1:] - seq_id_index = seq_ids.view(-1, 1, 1, 1).expand(seq_id_index_shape) - k_cache = torch.scatter(k_cache, 0, seq_id_index, kv_per_layer[0]) - v_cache = torch.scatter(v_cache, 0, seq_id_index, kv_per_layer[1]) - else: - if self.padding_side == "left": - # TODO: fix it with scatter after right padding - k_cache = k_cache[:, :, 1:, :] - v_cache = v_cache[:, :, 1:, :] - k_cache = torch.cat([k_cache, kv_per_layer[0]], dim=2) - v_cache = torch.cat([v_cache, kv_per_layer[1]], dim=2) - else: - scatter_index = position_ids.view(-1, 1, position_ids.shape[-1], 1).expand_as(kv_per_layer[0]) - k_cache = torch.scatter(k_cache, 2, scatter_index, kv_per_layer[0]) - v_cache = torch.scatter(v_cache, 2, scatter_index, kv_per_layer[1]) - - k_cache = self._gather_bucket_slice_into_kv_cacheline(idx * 2, k_cache) - v_cache = self._gather_bucket_slice_into_kv_cacheline(idx * 2 + 1, v_cache) - - updated_kv_cache.append(k_cache) - updated_kv_cache.append(v_cache) - - if self.padding_side == "left": - index = torch.tensor([hidden_states.shape[1] - 1], device=hidden_states.device) - index = index.unsqueeze(1).expand(self.batch_size, 1, self.config.hidden_size) - hidden_states = torch.gather(hidden_states, dim=1, index=index) - else: - # simple token generation - if position_ids.shape[-1] != self.speculation_length: - index = torch.max(position_ids, dim=1, keepdim=True).indices - index = index.unsqueeze(1).expand(self.batch_size, 1, self.config.hidden_size) - hidden_states = torch.gather(hidden_states, dim=1, index=index) - # speculative decoding case; only batch_size=1 - # will need to extend the logic to support multi-batch later - # maybe just use position_ids for index? 
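# --- Editor's illustration (not part of the patch) ---------------------------------
# Worked example of the speculation gather in the `else` branch below (values are
# illustrative): with speculation_length = 4 and position_ids = [[7, 8, 9, 10]], the
# gather pulls hidden states at sequence positions 7..10 for the single batch entry.

import torch

_hidden = torch.randn(1, 32, 3)                       # (batch, seq_len, hidden_size)
_position_ids = torch.tensor([[7, 8, 9, 10]])
_start = int(torch.min(_position_ids))                # the traced code keeps this as a tensor
_index = torch.arange(_start, _start + 4)
_index = _index[None, :, None].expand(1, 4, 3)
_gathered = torch.gather(_hidden, dim=1, index=_index)
assert torch.equal(_gathered, _hidden[:, 7:11, :])
# --- end editor's illustration ------------------------------------------------------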
- else: - index = torch.min(position_ids) - index = torch.arange(index, index + self.speculation_length, device=hidden_states.device) - index = index[None, :, None].expand(self.batch_size, self.speculation_length, self.config.hidden_size) - hidden_states = torch.gather(hidden_states, dim=1, index=index) - - logits = self.lm_head(hidden_states) - logits = logits.float() - - return [logits] + updated_kv_cache - - def _forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - active_mask: Optional[List[torch.FloatTensor]] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: - """ - This function is similar to the forward function of Huggingface MixtralModel - """ - _, seq_length = input_ids.shape - - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - inputs_embeds = self.embed_tokens(input_ids) - - # NeuronMixtralModel class manages the KV cache. So the attention_mask will be generated and passed - # through to NeuronMixtralModel. We override the HF's code that generates attention mask because HF does - # not support left aligned RHS padding. This enables Neuron to achieve higher performance and - # extensibility. - # - # 4d mask is passed through the layers - # attention_mask = _prepare_4d_causal_attention_mask( - # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - # ) - - # embed positions - hidden_states = inputs_embeds - - # decoder layers - next_decoder_cache = () - - for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] if past_key_values is not None else None - - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - active_mask=active_mask, - ) - - hidden_states = layer_outputs[0] - next_decoder_cache += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - return (hidden_states, next_decoder_cache) + self.norm = get_rmsnorm_cls(config)(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear(config.hidden_size, self.vocab_size, bias=False) class NeuronMixtralForCausalLM(NeuronBaseForCausalLM, MixtralPreTrainedModel): @@ -722,20 +312,11 @@ class NeuronMixtralForCausalLM(NeuronBaseForCausalLM, MixtralPreTrainedModel): This class can be used as MixtralForCausalLM """ - def __init__(self, model_path: str, config: NeuronMixtralConfig): - super().__init__(config) - self.config = config - self.vocab_size = config.vocab_size - self.padding_side = config.padding_side - self.kv_cache_populated = False + _model_cls = NeuronMixtralModel - self.models = [] - self.enable_context_encoding() - if config.trace_tokengen_model: - self.enable_token_generation() - if config.speculation_length > 0: - self.enable_speculation() - self.model_path = model_path + def __init__(self, model_path: str, config: NeuronMixtralConfig): + super().__init__(model_path, config) + self.sampler = Sampler(self.config) @staticmethod def 
load_hf_model(model_path): @@ -749,320 +330,10 @@ def get_state_dict(cls, model_path: str, config: MixtralConfig) -> dict: return model_sd def get_compiler_args(self): - if self.config.torch_dtype == torch.bfloat16: - return "--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1" - else: - return "--enable-saturate-infinity --enable-mixed-precision-accumulation --auto-cast=none --model-type transformer -O1" - - def enable_context_encoding(self): - new_config = copy.deepcopy(self.config) - new_config.batch_size = self.config.ctx_batch_size - new_config.n_active_tokens = self.config.n_positions - - if not new_config.enable_context_encoding_bucketing: - new_config.buckets = [new_config.buckets[-1]] - - self.context_encoding_model = ModelWrapper( - config=new_config, - model_cls=NeuronMixtralModel, - tag=CONTEXT_ENCODING_MODEL_TAG, - compiler_args=self.get_compiler_args(), - ) - - self.models.append(self.context_encoding_model) - - def enable_token_generation(self): - new_config = copy.deepcopy(self.config) - new_config.batch_size = self.config.tkg_batch_size - new_config.n_active_tokens = 1 - new_config.bucket_n_active_tokens = False - - if not new_config.enable_token_generation_bucketing: - new_config.buckets = [new_config.buckets[-1]] - - self.token_generation_model = ModelWrapper( - config=new_config, - model_cls=NeuronMixtralModel, - tag=TOKEN_GENERATION_MODEL_TAG, - compiler_args=self.get_compiler_args(), - ) - - self.models.append(self.token_generation_model) - - def enable_speculation(self): - new_config = copy.deepcopy(self.config) - new_config.batch_size = self.config.spec_batch_size - new_config.n_active_tokens = self.config.speculation_length - self.speculation_model = ModelWrapper( - config=new_config, - model_cls=NeuronMixtralModel, - tag=SPECULATION_MODEL_TAG, - compiler_args=self.get_compiler_args(), - ) - - self.models.append(self.speculation_model) - - def forward( - self, - input_ids: torch.LongTensor = None, - seq_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - logging.debug("---input---") - logging.debug("input_ids shape = %s type=%s", input_ids.shape, input_ids.type()) - logging.debug("attention_mask shape = %s type=%s", attention_mask.shape, attention_mask.type()) - logging.debug("position_ids shape = %s type=%s", position_ids.shape, position_ids.type()) - logging.debug("input_ids =%s", input_ids) - logging.debug("attention_mask =%s", attention_mask) - logging.debug("position_ids =%s", position_ids) - logging.debug(f"seq_ids: {seq_ids}") - - if self.config.trace_tokengen_model and not self.token_generation_model.is_neuron(): - logging.debug(f"first layer kv_cache: {self.token_generation_model.model.past_key_values[0][:, 0, :, 0]}") - - if seq_ids is None: - seq_ids = 
torch.arange(input_ids.shape[0]) - - if input_ids.shape[-1] > 1 and input_ids.shape[-1] != self.config.speculation_length: - outputs = self.context_encoding_model(input_ids, attention_mask, position_ids, seq_ids) - - if self.context_encoding_model.is_neuron(): - # Copy the KV cache from the context_encoding_model to token generation model - if self.config.trace_tokengen_model: - for encoder_model, token_gen_model in zip( - self.context_encoding_model.model.models, self.token_generation_model.model.models - ): - encoder_kv_cache_line = encoder_model.states - token_gen_kv_cache_line = token_gen_model.states - for name, _ in token_gen_kv_cache_line._parameters.items(): - token_gen_kv_cache_line._parameters[name] = encoder_kv_cache_line._parameters[name] - # Also need to copy to the speculation model for speculation - if self.config.speculation_length > 0: - for encoder_model, speculation_model in zip( - self.context_encoding_model.model.models, self.speculation_model.model.models - ): - encoder_kv_cache_line = encoder_model.states - speculation_kv_cache_line = speculation_model.states - for name, _ in speculation_kv_cache_line._parameters.items(): - speculation_kv_cache_line._parameters[name] = encoder_kv_cache_line._parameters[name] - self.kv_cache_populated = True - elif input_ids.shape[-1] == self.config.speculation_length: - outputs = self.speculation_model(input_ids, attention_mask, position_ids, seq_ids) - else: - outputs = self.token_generation_model(input_ids, attention_mask, position_ids, seq_ids) - - logits = outputs[0] - - if self.config.trace_tokengen_model and not self.token_generation_model.is_neuron(): - # When traced the output kv tensors are aliased to the kv parameter list. - # The code below mimicks that on CPU. - new_past_key_values = outputs[1:] - for i, new_past_key_value in enumerate(new_past_key_values): - self.token_generation_model.model.past_key_values[i].data = new_past_key_value - self.context_encoding_model.model.past_key_values[i].data = new_past_key_value - - logging.debug("---output---") - logging.debug("logits = %s", logits) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=[], - hidden_states=logits, - attentions=None, - ) - - # We override this function because we want to change the way attention_mask - # is updated each iteration. 
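# --- Editor's illustration (not part of the patch) ---------------------------------
# Worked example of the per-step attention-mask update implemented below (values are
# illustrative). The mask keeps a fixed width; each generated token flips one slot to 1:
#   left padding : [0, 0, 1, 1, 1] -> append 1, drop first -> [0, 1, 1, 1, 1]
#   right padding: [1, 1, 1, 0, 0] -> prepend 1, drop last -> [1, 1, 1, 1, 0]

import torch

_mask = torch.tensor([[1, 1, 1, 0, 0]])                                  # right padding
_mask = torch.cat([_mask.new_ones((_mask.shape[0], 1)), _mask], dim=-1)[:, :-1]
assert _mask.tolist() == [[1, 1, 1, 1, 0]]
# --- end editor's illustration ------------------------------------------------------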
- def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_for_token_generation: Optional[bool] = False, - is_encoder_decoder: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - - if getattr(outputs, "state", None) is not None: - model_kwargs["state"] = outputs.state - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if is_for_token_generation: - if self.padding_side == "left": - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - attention_mask = attention_mask[:, 1:] - else: - attention_mask = torch.cat( - [attention_mask.new_ones((attention_mask.shape[0], 1)), attention_mask], dim=-1 - ) - attention_mask = attention_mask[:, :-1] - model_kwargs["attention_mask"] = attention_mask - return model_kwargs - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if self.kv_cache_populated: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if self.kv_cache_populated: - position_ids = torch.amax(position_ids, 1, keepdim=True) - position_ids = position_ids + 1 - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - def reset(self): - # We need to reset the KV cache flag for a new batch of inference. - # When the flag is reset, the subsequent run will invoke the - # context encoding model. - self.kv_cache_populated = False - - def reset_kv_cache(self): - # Zero out kv cache for debug. 
- # For new batch inference, use reset() instead - if not self.context_encoding_model.is_neuron(): - for i, kv_tensor in enumerate(self.context_encoding_model.model.past_key_values): - self.context_encoding_model.model.past_key_values[i] = torch.zeros_like(kv_tensor) - - if not self.token_generation_model.is_neuron(): - for i, kv_tensor in enumerate(self.token_generation_model.model.past_key_values): - self.token_generation_model.model.past_key_values[i] = torch.zeros_like(kv_tensor) - - def sample( - self, - input_ids: torch.LongTensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - **model_kwargs, - ) -> Union[SampleOutput, torch.LongTensor]: - r""" - We override the GenerationMixin sample function to add support for right side padding. - """ - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - - # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) - - this_peer_finished = False - # auto-regressive generation - while True: - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - is_for_token_generation = self.kv_cache_populated - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - is_encoder_decoder=self.config.is_encoder_decoder, - is_for_token_generation=is_for_token_generation, - ) - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - 
) - - # stop when each sentence is finished - if unfinished_sequences.max() == 0: - this_peer_finished = True - - # stop if we exceed the maximum length - if stopping_criteria(input_ids, None): - this_peer_finished = True - - if this_peer_finished: - break - - return input_ids + compiler_args = "--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1" + # Add flags for cc-overlap + compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2'" + # Prevent auto-down casting when running with fp32 + if self.config.torch_dtype == torch.float32: + compiler_args += " --auto-cast=none" + return compiler_args diff --git a/examples/inference/modules/attention/attention_base.py b/examples/inference/modules/attention/attention_base.py new file mode 100644 index 0000000..a2450da --- /dev/null +++ b/examples/inference/modules/attention/attention_base.py @@ -0,0 +1,200 @@ +import logging +import math +from typing import Optional, Tuple + +import torch +from torch import nn, Tensor + +from modules.attention.utils import apply_rotary_pos_emb, repeat_kv, manual_softmax, move_heads_front + +# Try except for the compatibility with older compiler version +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel # noqa: E402 +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel # noqa: E402 +from torch_neuronx.xla_impl.ops import nki_jit # noqa: E402 + +from modules.gqa import ( # noqa: E402 + GroupQueryAttention_O, # noqa: E402 + GroupQueryAttention_QKV, # noqa: E402 +) # noqa: E402 + +from neuronx_distributed.parallel_layers import utils # noqa: E402 + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + + +class NeuronAttentionBase(nn.Module): + """ + This base attention class implements the core Neuron related adaptation including + 1. replaces the q_proj, k_proj, v_proj with column parallel layer + 2. replaces the o_proj with row parallel layer + 3. update self.num_head to be self.num_head / tp_degree + 4. update self.num_key_value_heads to be self.num_key_value_heads / tp_degree + 5. update forward() method to adjust to changes from self.num_head + """ + + def __init__(self): + super().__init__() + self.is_causal = True + self.num_key_value_groups = None + self.num_key_value_heads = None + self.num_heads = None + self.rotary_emb = None + self.o_proj = None + self.qkv_proj = None + + def init_gqa_properties(self): + if (self.head_dim * self.num_attention_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_attention_heads})." 
+ ) + self.qkv_proj = GroupQueryAttention_QKV( + hidden_size=self.hidden_size, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + tp_degree=self.tp_degree, + dtype=self.torch_dtype, + gather_output=False, + fused_qkv=self.fused_qkv, + clip_qkv=self.clip_qkv + ) + self.o_proj = GroupQueryAttention_O( + hidden_size=self.hidden_size, + head_dim=self.head_dim, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + tp_degree=self.tp_degree, + dtype=self.torch_dtype, + input_is_parallel=True, + ) + self.num_heads = utils.divide(self.qkv_proj.get_num_attention_heads(), self.tp_degree) + self.num_key_value_heads = utils.divide(self.qkv_proj.get_num_key_value_heads(), self.tp_degree) + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + def scaled_qk(self, Q, K, attention_mask): + QK = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(self.head_dim) + QK = torch.where(attention_mask, QK, torch.finfo(QK.dtype).min) + return QK + + def prep_qkv_tensors(self, position_ids, hidden_states, past_key_value): + """ take care of the shape, layout, group query, custom position encoding, etc. """ + Q, K, V = self.qkv_proj(hidden_states=hidden_states) + + # Divide hidden_dim across heads for MHA + # Change layout: BSHD -> BHSD + bsz, q_len, _ = hidden_states.size() + Q = move_heads_front(Q, bsz, q_len, self.num_heads, self.head_dim) + K = move_heads_front(K, bsz, q_len, self.num_key_value_heads, self.head_dim) + V = move_heads_front(V, bsz, q_len, self.num_key_value_heads, self.head_dim) + + # Rotate Q and K + cos, sin = self.rotary_emb(V, position_ids) + Q, K = apply_rotary_pos_emb(Q, K, cos, sin) + return Q, K, V + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask) -> Tensor: + """ attention computation at prefilling (context encoding) phase """ + K_active = repeat_kv(K, self.num_key_value_groups) + V_active = repeat_kv(V, self.num_key_value_groups) + + # use flash attention if (i) sequence length is large enough to get the best performance, + # (ii) Q, K, and V have the same shape. Conditions can be changed in the future. + flash_attention_eligible = q_len >= 4096 and Q.shape == K_active.shape == V_active.shape + + if flash_attention_eligible: + # if we are using left padding, then the bzs needs be 1 (otherwise we get wrong result + # because flash attention does not use attention_mask). In practice, we use right + # padding so this is unlikely to cause issues + assert self.padding_side == "right" or bsz == 1 + + # original shape of q, k, v is BHSD, and expected output is also BHSD. + logging.debug(f"Using flash_fwd for Q.shape={Q.shape}") + # make sure to cast inputs to self.config.torch_dtype (this is needed because the downcast to bf16 + # might happen after the kernel hlo creation step). Also convert shapes as expected by the kernel. 
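# --- Editor's illustration (not part of the patch) ---------------------------------
# Plain-torch reference for what the causal kernel invoked below computes (this is only
# the mathematical shape of the operation, not the Neuron kernel's implementation):

import math
import torch

def _reference_causal_attention(q, k, v):
    # q, k, v: BHSD; returns BHSD
    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.shape[-1])
    causal = torch.ones(q.shape[2], k.shape[2]).tril().bool()
    scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
    return torch.matmul(torch.nn.functional.softmax(scores, dim=-1), v)

_x = torch.randn(1, 2, 8, 4)
assert _reference_causal_attention(_x, _x, _x).shape == (1, 2, 8, 4)
# --- end editor's illustration ------------------------------------------------------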
+ Q = ( + Q.permute(0, 1, 3, 2) + .reshape((bsz * self.num_heads, self.head_dim, q_len)) + .to(self.config.torch_dtype) + ) + Q = Q / math.sqrt(self.head_dim) + K_active = ( + K_active.permute(0, 1, 3, 2) + .reshape((bsz * self.num_heads, self.head_dim, q_len)) + .to(self.config.torch_dtype) + ) + V_active = V_active.reshape((bsz * self.num_heads, q_len, self.head_dim)).to(self.config.torch_dtype) + attn_output = torch.zeros(bsz * self.num_heads, q_len, self.head_dim, dtype=Q.dtype, device=Q.device) + _flash_fwd_call( + Q, K_active, V_active, 1.0, attn_output, kernel_name="CausalAttentionMMSoftmaxMMWithoutSwap" + ) + attn_output = attn_output.reshape((bsz, self.num_heads, q_len, self.head_dim)) + else: + logging.debug(f"Not using flash_fwd for Q.shape={Q.shape}") + active_scores = self.scaled_qk(Q, K_active, attention_mask) + active_scores = nn.functional.softmax(active_scores, dim=-1, dtype=torch.float32).to(Q.dtype) + attn_output = torch.matmul(active_scores, V_active) + return attn_output + + def compute_for_token_gen(self, Q, K, V, position_ids, past_key_value, attention_mask, active_mask) -> Tensor: + """ attention computation at token generation phase """ + is_speculation = position_ids.shape[-1] > 1 + + # Attention computation: softmax((Q.K/√dkv) + mask).V + # i. prior (cached) KV + K_prior = past_key_value[0] + V_prior = past_key_value[1] + K_prior = repeat_kv(K_prior, self.num_key_value_groups) + V_prior = repeat_kv(V_prior, self.num_key_value_groups) + prior_scores = torch.matmul(Q, K_prior.transpose(2, 3)) / math.sqrt(self.head_dim) + prior_scores = torch.where(attention_mask, prior_scores, torch.finfo(prior_scores.dtype).min) + prior_scores = prior_scores.to(torch.float32) + + # ii. active (current/new) KV + K_active = repeat_kv(K, self.num_key_value_groups) + V_active = repeat_kv(V, self.num_key_value_groups) + active_scores = torch.matmul(Q, K_active.transpose(2, 3)) / math.sqrt(self.head_dim) + if is_speculation: + active_scores = torch.where(active_mask, active_scores, torch.finfo(active_scores.dtype).min) + active_scores = active_scores.to(torch.float32) + + # iii. attention scores + softmax_prior, softmax_active = manual_softmax(prior_scores, active_scores, is_speculation) + softmax_prior, softmax_active = softmax_prior.to(Q.dtype), softmax_active.to(Q.dtype) + attn_prior = torch.matmul(softmax_prior, V_prior) + attn_active = torch.matmul(softmax_active, V_active) + attn_output = attn_prior + attn_active + + return attn_output + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + active_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]: + """ Implements each layer's forward pass for the attention block. 
""" + bsz, q_len, _ = hidden_states.size() + Q, K, V = self.prep_qkv_tensors(position_ids, hidden_states, past_key_value) + + if past_key_value is None: + attn_output = self.perform_prefill(Q, K, V, q_len, bsz, attention_mask) + else: + attn_output = self.compute_for_token_gen(Q, K, V, position_ids, past_key_value, attention_mask, active_mask) + + # transpose BHSD -> BSHD + attn_output = attn_output.transpose(1, 2).contiguous() + + # merge multi head hidden + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # Z = Z.Wo + attn_output = self.o_proj(attn_output) + + past_key_value: Tuple[Tensor, Tensor] = (K, V) + + return attn_output, past_key_value diff --git a/examples/inference/modules/attention/utils.py b/examples/inference/modules/attention/utils.py new file mode 100644 index 0000000..303d07e --- /dev/null +++ b/examples/inference/modules/attention/utils.py @@ -0,0 +1,86 @@ +from typing import Tuple + +import torch +from torch import Tensor +from torch import nn + +torch.manual_seed(0) + + +def move_heads_front(tensor: Tensor, bsz: int, seq_len: int, num_head: int, head_dim: int) -> Tensor: + """ BSHD -> BHSD """ + return tensor.view(bsz, seq_len, num_head, head_dim).transpose(1, 2).contiguous() + + +def repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _rotate_half(x) -> Tensor: + """ Rotates half the hidden dims of the input. """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1) -> Tuple[Tensor, Tensor]: + """ Applies Rotary Position Embedding to the query and key tensors. 
""" + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (_rotate_half(q) * sin) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +def manual_softmax(prior_scores, active_scores, is_speculation) -> Tuple[Tensor, Tensor]: + """ + simple softmax computation: denominator is the sum of exp over all vocab and only need compute numerator (exp) + """ + max_score = torch.max(prior_scores, dim=-1, keepdim=True)[0] + max_active_score = torch.max(active_scores, dim=-1, keepdim=True)[0] + max_score = torch.maximum(max_score, max_active_score) if is_speculation else torch.maximum(max_score, active_scores) + + exp_prior = torch.exp(prior_scores - max_score) + exp_active = torch.exp(active_scores - max_score) + denominator = exp_prior.sum(dim=-1, keepdim=True) + exp_active.sum(dim=-1, keepdim=True) + + softmax_prior = exp_prior / denominator + softmax_active = exp_active / denominator + return softmax_prior, softmax_active + + +class RotaryEmbedding(nn.Module): + """ + Adapted from Llama 4.0 impl https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models + /llama/modeling_llama.py#L96-L145 + """ + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/examples/inference/modules/autobucketing.py b/examples/inference/modules/autobucketing.py index 7870743..8b16815 100644 --- a/examples/inference/modules/autobucketing.py +++ b/examples/inference/modules/autobucketing.py @@ -1,7 +1,21 @@ +from math import log2 from typing import List import torch +def generate_buckets(min_length: int, max_length: int): + if (min_length == max_length): + return [max_length] + + min_bound = int(log2(min_length)) + max_bound = round(log2(max_length)) # we use round because it creates optimal bucket spacing + + # NOTE: because range operates on [a,b), and we rounded the log2 result + # we won't get 2**i results close to the max_length. + # ex. we won't see bucket spacing of [128,256,512,513] or [128,256,510,512] + buckets = [2**i for i in range(min_bound,max_bound)] + [max_length] + return buckets + def slice_rhs(tensor, bucket: int, max_idx: int, dim: int): tensor = torch.ops.aten.slice(tensor, dim, max_idx - bucket, max_idx, 1) @@ -15,20 +29,29 @@ def slice_lhs(tensor, bucket: int, dim: int): @torch.jit.script def token_generation_bk(tensors: List[torch.Tensor], buckets: torch.Tensor, padding_side: str): + """ + The Bucket Kernel for Token Generation Models. 
+ + 1) tensors: A list of torch tensors after running through the flattener + 2) buckets: A torch.tensor of the bucket sizes + 3) padding_side: A string specifying padding side, must be "left" or "right" + """ attention_mask = tensors[1] position_ids = tensors[2] - seq_ids = tensors[3] - # DO NOT USE argmax since we monkeypatch it, causing issues with torch.jit.script - bucket_idx = torch.argmin(((buckets - position_ids) <= 0).to(torch.int)) + # Refer to the context_encoding_bk comments on selecting a bucket_idx + # The difference is that this single line of code is valid for all batch sizes + bucket_mask = (buckets <= position_ids).to(torch.int) + bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1)) bucket = buckets[bucket_idx] + # slice the attention mask based on the selected bucket size if padding_side == "right": tensors[1] = slice_lhs(attention_mask, bucket, 1) else: tensors[1] = slice_rhs(attention_mask, bucket, buckets[-1], 1) - return tensors, torch.tensor(bucket_idx).to(torch.int) + return tensors, bucket_idx.to(torch.int) def get_token_generation_bk(): @@ -36,58 +59,68 @@ def get_token_generation_bk(): @torch.jit.script -def context_encoder_bk(tensors: List[torch.Tensor], buckets, padding_side: str): +def context_encoder_bk(tensors: List[torch.Tensor], buckets, padding_side: str, pad_token: int): + """ + The Bucket Kernel for Context Encoding Models. + + 1) tensors: A list of torch tensors after running through the flattener + 2) buckets: A torch.tensor of the bucket sizes + 3) padding_side: A string specifying padding side, must be "left" or "right" + 4) pad_token: An integer representing the pad token id. Typically this is 0. + """ input_ids = tensors[0] - position_idx = (input_ids > 0).sum(dim=1)[0] - bucket_idx = torch.argmin(((buckets - position_idx) < 0).to(torch.int)) - bucket = buckets[bucket_idx] + # -----Remarks for calculating position_idx----- + # finds the number of non pad tokens and that is the active sequence_length + # The resulting tensor is of shape (batch_size,) + # + # NOTE: We derive position_ids from input_ids because + # position_ids is eliminated from the flattener for context encoding models. + # ---------------------------------------------- + position_idx = (input_ids != pad_token).sum(dim=1) + position_idx = position_idx[:, None] # shape (batch_size, 1) + buckets = buckets[None, :] # shape (1, seq_len) + + # -----Remarks for choosing the bucket_idx----- + # 1. (buckets < position_idx) produces a bucket_mask where invalid buckets are 0 + # 2. We convert the boolean tensor to int because argmin doesn't support + # boolean tensors + # 3. We choose the minimum valid bucket, which is the first 1 value + # 4. From the minimum valid buckets, we choose the largest bucket, otherwise + # we'd be truncating generated tokens from longer sequences. + # 5. DO NOT USE argmax since we monkeypatch it, + # causing issues with torch.jit.script + # --------------------------------------------- + bucket_mask = (buckets < position_idx).to(torch.int) # shape (batch_size, seq_len) + bucket_idx = torch.max(torch.argmin(bucket_mask, dim=1)) + + # select the chosen bucket after squeezing back to original form + bucket = buckets.squeeze(0)[bucket_idx] new_tensors = [] + + # ---------Remarks on handling padding sides------- + # 1. slice from the opposite side for padding + # 2. 
Identify seq_id tensors by shape and don't slice it + # ------------------------------------------------- if padding_side == "right": for i, tens in enumerate(tensors): - if tens.shape[-1] == 1: + # identifies the seq_ids, which don't need to be sliced + if len(tens.shape) == 1: new_tensors.append(tens) - else: + else: # all other tensors are of shape (batch_size,seq_len) so we slice on seq_len new_tensors.append(slice_lhs(tens, bucket, 1)) else: max_idx = buckets[-1] for i, tens in enumerate(tensors): - if i == len(tensors) - 1: + # identifies the seq_ids, which don't need to be sliced + if len(tens.shape) == 1: new_tensors.append(tens) else: new_tensors.append(slice_rhs(tens, bucket, max_idx, 1)) - return new_tensors, torch.tensor(bucket_idx) + return new_tensors, bucket_idx.to(torch.int) def get_context_encoder_bk(): return context_encoder_bk - - -@torch.jit.script -def state_preprocessor( - shapes_collection: List[List[List[int]]], - states: List[torch.Tensor], - bucket_idx_tensor: torch.Tensor, - padding_side: str, -) -> List[torch.Tensor]: - bucket_idx = torch.ops.aten.Int(bucket_idx_tensor) - shapes = shapes_collection[bucket_idx] - sliced_state_tensors = [] - for i in range(len(shapes)): - expected_shape = shapes[i] - state_tensor = states[i] - state_tensor_shape = state_tensor.shape - for j, npos in enumerate(expected_shape): - state_tensor_dim_length = state_tensor_shape[j] - if padding_side == "right": - state_tensor = slice_lhs(state_tensor, npos, j) - else: - state_tensor = slice_rhs(state_tensor, npos, state_tensor_dim_length, j) - sliced_state_tensors.append(state_tensor) - return sliced_state_tensors - - -def get_state_preprocessor(layout="bsh"): - return state_preprocessor diff --git a/examples/inference/modules/benchmark.py b/examples/inference/modules/benchmark.py index 59c422c..853a93e 100644 --- a/examples/inference/modules/benchmark.py +++ b/examples/inference/modules/benchmark.py @@ -7,7 +7,7 @@ class Benchmark: - def __init__(self, benchmark_func, input_param, config, num_runs=20, preprocess_func=None) -> None: + def __init__(self, benchmark_func, input_param, config, num_runs=20, preprocess_func=None, post_warmup_func=None) -> None: if isinstance(input_param, (tuple, list)): self.benchmark_func = partial(benchmark_func, *input_param) elif isinstance(input_param, dict): @@ -18,40 +18,55 @@ def __init__(self, benchmark_func, input_param, config, num_runs=20, preprocess_ self.config = config self.num_runs = num_runs self.preprocess_func = preprocess_func + self.post_warmup_func = post_warmup_func + self.latency_list = None def run(self): - # Warmp up + # Warm up if self.preprocess_func: self.preprocess_func() self.benchmark_func() - latency_list = [] - e2e_start = time.time() + if self.post_warmup_func: + self.post_warmup_func() + + latency_collector = LatencyCollector() for _ in range(self.num_runs): - start = time.time() + latency_collector.pre_hook() if self.preprocess_func: self.preprocess_func() self.benchmark_func() - latency_list.append(time.time() - start) - e2e_time = time.time() - e2e_start - - return self.process_metrics(latency_list, e2e_time, self.config) - - def process_metrics(self, latency_list, e2e_time, config): - latency_array = np.array(latency_list) - - max_length = config.max_length - batch_size = config.max_batch_size - n_runs = self.num_runs - throughput = (max_length * n_runs * batch_size) / e2e_time - - metrics = { - "latency_ms_p50": np.percentile(latency_array, 50) * 1000, - "latency_ms_p90": np.percentile(latency_array, 90) * 1000, - 
"latency_ms_p95": np.percentile(latency_array, 95) * 1000, - "latency_ms_p99": np.percentile(latency_array, 99) * 1000, - "latency_ms_p100": np.percentile(latency_array, 100) * 1000, - "latency_ms_avg": np.average(latency_array) * 1000, - "throughput": throughput, - } - return metrics + latency_collector.hook() + self.latency_list = latency_collector.latency_list + + +class LatencyCollector: + def __init__(self): + self.start = None + self.latency_list = [] + + def pre_hook(self, *args): + self.start = time.time() + + def hook(self, *args): + self.latency_list.append(time.time() - self.start) + + +def generate_report(latency_list, config): + latency_array = np.array(latency_list) + + n_runs = len(latency_list) + max_length = config.max_length + batch_size = config.max_batch_size + total_time = np.sum(latency_array) + throughput = (n_runs * max_length * batch_size) / total_time + + return { + "latency_ms_p50": np.percentile(latency_array, 50) * 1000, + "latency_ms_p90": np.percentile(latency_array, 90) * 1000, + "latency_ms_p95": np.percentile(latency_array, 95) * 1000, + "latency_ms_p99": np.percentile(latency_array, 99) * 1000, + "latency_ms_p100": np.percentile(latency_array, 100) * 1000, + "latency_ms_avg": np.average(latency_array) * 1000, + "throughput": throughput, + } diff --git a/examples/inference/modules/config.py b/examples/inference/modules/config.py new file mode 100644 index 0000000..5198848 --- /dev/null +++ b/examples/inference/modules/config.py @@ -0,0 +1,68 @@ +from transformers import PretrainedConfig + + +class NeuronInferenceConfig(PretrainedConfig): + """ + Base config class for inference in NxD. + + This class contains attributes that are needed for various inference + optimization/features in NxD. + """ + + def __init__( + self, + tp_degree: int = 1, + batch_size: int = 1, + seq_len: int = 128, + padding_side: str = "right", + **kwargs + ) -> None: + # Basic config for inference in NxD + self.tp_degree = tp_degree + self.batch_size = batch_size + self.padding_side = padding_side + # TODO: see if we can consolidate n_active_tokens and n_positions into one + self.n_active_tokens = seq_len # Need to provide example input shape for tracing + self.n_positions = seq_len + + # fallback to seq_len is for compatibility with vllm + self.max_context_length = kwargs.pop("max_context_length", seq_len) + self.max_new_tokens = seq_len - self.max_context_length + if self.max_new_tokens == 0: + self.max_new_tokens = None + self.max_length = seq_len + + # Continuous batching + # TODO: Check if we really need different batch size for CTE and TKG, given + # that we anyway provide two different config instance for them. 
+ self.ctx_batch_size = kwargs.get("ctx_batch_size", batch_size) + self.tkg_batch_size = kwargs.get("tkg_batch_size", batch_size) + self.max_batch_size = kwargs.get("max_batch_size", batch_size) + self.is_continuous_batching = kwargs.get("is_continuous_batching", False) + + # On-device sampling + self.on_device_sampling = kwargs.get("on_device_sampling", False) + + # Bucketing + self.enable_bucketing = kwargs.get("enable_bucketing", False) + self.buckets = [seq_len] + self.bucket_n_active_tokens = False + + # Quantization + self.quantized = kwargs.get("quantized", False) + self.quantized_checkpoints_path = kwargs.get("quantized_checkpoints_path") + self.quantization_type = kwargs.get("quantization_type", "per_tensor_symmetric") + # TODO: Add validation for quantized_checkpoints_path after the design discussions + + # Speculative decoding + self.trace_tokengen_model = kwargs.get("trace_tokengen_model", True) + self.speculation_length = kwargs.get("speculation_length", 0) + self.spec_batch_size = batch_size + + # Medusa decoding + self.is_medusa = kwargs.get("is_medusa", False) + self.medusa_speculation_length = kwargs.get("medusa_speculation_length", 0) + self.num_medusa_heads = kwargs.get("num_medusa_heads", 0) + self.medusa_tree = kwargs.get("medusa_tree", 0) + + super().__init__(**kwargs) diff --git a/examples/inference/modules/gqa.py b/examples/inference/modules/gqa.py index 6a641c7..343e083 100644 --- a/examples/inference/modules/gqa.py +++ b/examples/inference/modules/gqa.py @@ -11,6 +11,9 @@ RowParallelLinear, ) from neuronx_distributed.parallel_layers.pad import get_number_of_extra_heads +from neuronx_distributed.quantization.quantization_layers import ( + BaseQuantizeParallelLinear, +) class GQA(enum.Enum): @@ -74,7 +77,55 @@ def get_shardable_head_counts( return updated_num_attention_heads, updated_num_key_value_heads -def maybe_pad_interleaved(tensor, pad_dim: int, source_heads: int, target_heads: int, source_group_size: int): +def is_per_channel(scale: torch.Tensor) -> bool: + """See if the scale is per channel""" + if scale.shape == (1,): + return False + return True + + +def get_tensor_per_channel_scale_axis(scale: torch.Tensor) -> int: + """Get the channel axis for the per channel scale""" + scale_shape = scale.shape + # Only one dimension would have scale values + for i, dim_length in enumerate(scale_shape): + if dim_length > 1: + return i + raise RuntimeError(f"Cannot get channel axis for the scale: {scale}") + + +def should_pad_scale(tensor_scale: torch.Tensor, pad_dim: int) -> bool: + """Should scale be padded""" + if ( + (tensor_scale is not None) + and (is_per_channel(tensor_scale)) + and (get_tensor_per_channel_scale_axis(tensor_scale) == pad_dim) + ): + return True + return False + + +def verify_scale_dimension(tensor: torch.Tensor, tensor_scale: torch.Tensor): + channel_axis = get_tensor_per_channel_scale_axis(scale=tensor_scale) + assert tensor_scale.shape[channel_axis] == tensor.shape[channel_axis] + + +def maybe_pad_interleaved( + tensor, + pad_dim: int, + source_heads: int, + target_heads: int, + source_group_size: int, + tensor_scale: torch.Tensor = None, +): + tensor = _maybe_pad_interleaved(tensor, pad_dim, source_heads, target_heads, source_group_size) + if should_pad_scale(tensor_scale=tensor_scale, pad_dim=pad_dim): + tensor_scale = _maybe_pad_interleaved(tensor_scale, pad_dim, source_heads, target_heads, source_group_size) + + return tensor, tensor_scale + + +def _maybe_pad_interleaved(tensor, pad_dim: int, source_heads: int, target_heads: int, 
source_group_size: int): if tensor is None: return tensor shape = tensor.shape[:pad_dim] + (source_heads, tensor.shape[pad_dim] // source_heads) + tensor.shape[pad_dim + 1 :] @@ -93,7 +144,14 @@ def maybe_pad_interleaved(tensor, pad_dim: int, source_heads: int, target_heads: return tensor.view(shape) -def maybe_pad_tail(tensor, source_heads: int, target_heads: int, pad_dim: int): +def maybe_pad_tail(tensor, source_heads: int, target_heads: int, pad_dim: int, tensor_scale=None): + tensor = _maybe_pad_tail(tensor, source_heads, target_heads, pad_dim) + if should_pad_scale(tensor_scale=tensor_scale, pad_dim=pad_dim): + tensor_scale = _maybe_pad_tail(tensor_scale, source_heads, target_heads, pad_dim) + return tensor, tensor_scale + + +def _maybe_pad_tail(tensor, source_heads: int, target_heads: int, pad_dim: int): if tensor is None: return tensor size_to_pad = int((tensor.shape[pad_dim] // source_heads) * target_heads - tensor.shape[pad_dim]) @@ -105,7 +163,14 @@ def maybe_pad_tail(tensor, source_heads: int, target_heads: int, pad_dim: int): return F.pad(tensor, pad) -def replicate_kv(tensor, source_heads: int, repeats: int, head_dim=0): +def replicate_kv(tensor, source_heads: int, repeats: int, head_dim=0, tensor_scale=None): + tensor = _replicate_kv(tensor=tensor, source_heads=source_heads, repeats=repeats, head_dim=head_dim) + if should_pad_scale(tensor_scale=tensor_scale, pad_dim=head_dim): + tensor_scale = _replicate_kv(tensor=tensor_scale, source_heads=source_heads, repeats=repeats, head_dim=head_dim) + return tensor, tensor_scale + + +def _replicate_kv(tensor, source_heads: int, repeats: int, head_dim=0): if tensor is None: return tensor shape = ( @@ -185,6 +250,8 @@ def __init__( bias: bool = False, desired_sharding_strategy: Optional[GQA] = None, gather_output: bool = True, + fused_qkv: bool = False, + clip_qkv: Optional[float] = None, ): super().__init__( hidden_size=hidden_size, @@ -196,40 +263,86 @@ def __init__( bias=bias, desired_sharding_strategy=desired_sharding_strategy, ) + if fused_qkv and gather_output: + raise ValueError( + "Gathering states followed by fused qkv is not allowed as it has a different weight sharding scheme." 
+ ) self.gather_output = gather_output + self.fused_qkv = fused_qkv + self.clip_qkv = clip_qkv if parallel_state.model_parallel_is_initialized(): - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_attention_heads * self.head_dim, - bias=self.bias, - gather_output=self.gather_output, - dtype=dtype, - ) - self.k_proj = ColumnParallelLinear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=self.bias, - gather_output=self.gather_output, - dtype=dtype, - ) - self.v_proj = ColumnParallelLinear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=self.bias, - gather_output=self.gather_output, - dtype=dtype, - ) + if self.fused_qkv: + self.Wqkv = ColumnParallelLinear( + self.hidden_size, + (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=self.bias, + gather_output=self.gather_output, + dtype=dtype, + ) + # Set heads info as weight parameter attributes to be used in weights sharding + setattr(self.Wqkv.weight, "fused_qkv", True) + setattr(self.Wqkv.weight, "num_attention_heads", self.num_attention_heads) + setattr(self.Wqkv.weight, "num_key_value_heads", self.num_key_value_heads) + setattr(self.Wqkv.weight, "head_dim", self.head_dim) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_attention_heads * self.head_dim, + bias=self.bias, + gather_output=self.gather_output, + dtype=dtype, + ) + self.k_proj = ColumnParallelLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=self.bias, + gather_output=self.gather_output, + dtype=dtype, + ) + self.v_proj = ColumnParallelLinear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=self.bias, + gather_output=self.gather_output, + dtype=dtype, + ) else: - self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=self.bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias) + if self.fused_qkv: + self.Wqkv = nn.Linear( + self.hidden_size, + (self.num_attention_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=self.bias, + ) + else: + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=self.bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.bias) def forward(self, hidden_states: torch.Tensor): - Q = self.q_proj(hidden_states) - K = self.k_proj(hidden_states) - V = self.v_proj(hidden_states) + if self.fused_qkv: + QKV = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + QKV = QKV.clamp(min=-self.clip_qkv, max=self.clip_qkv) + # torch.split has accuracy issue and leads to more reshapes in hlo. + # Using torch.tensor_split here. 
NAPP-3145 + Q, K, V = torch.tensor_split( + QKV, + ( + self.num_attention_heads * self.head_dim // self.tp_degree, + (self.num_attention_heads + self.num_key_value_heads) * self.head_dim // self.tp_degree, + ), + dim=2, + ) + else: + Q = self.q_proj(hidden_states) + K = self.k_proj(hidden_states) + V = self.v_proj(hidden_states) + if self.clip_qkv is not None: + Q = Q.clamp(min=-self.clip_qkv, max=self.clip_qkv) + K = K.clamp(min=-self.clip_qkv, max=self.clip_qkv) + V = V.clamp(min=-self.clip_qkv, max=self.clip_qkv) return Q, K, V def get_weight( @@ -237,9 +350,17 @@ def get_weight( ) -> Tuple[torch.Tensor]: if hasattr(layer, "get_weight_from_state_dict"): weight = layer.get_weight_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict) + if isinstance(layer, BaseQuantizeParallelLinear): + scale = layer.get_scale_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict) + else: + scale = None else: weight = model_state_dict[f"{prefix}.{layer_name}.weight"] - return weight + if isinstance(layer, BaseQuantizeParallelLinear): + scale = model_state_dict[f"{prefix}.{layer_name}.scale"] + else: + scale = None + return weight, scale def get_bias( self, prefix: str, layer: torch.nn.Module, layer_name: str, model_state_dict: dict @@ -251,12 +372,19 @@ def get_bias( return bias def set_weight( - self, tensor: torch.Tensor, prefix: str, layer: torch.nn.Module, layer_name, model_state_dict: dict + self, + tensor: torch.Tensor, + prefix: str, + layer: torch.nn.Module, + layer_name, + model_state_dict: dict, + scale: torch.Tensor = None, ) -> Tuple[torch.Tensor]: - if hasattr(layer, "set_weight_to_state_dict"): - layer.set_weight_to_state_dict(prefix=f"{prefix}.{layer_name}.", tensor=tensor, state_dict=model_state_dict) - else: - model_state_dict[f"{prefix}.{layer_name}.weight"] = tensor + # TODO: set weight to state dict support is pending. 
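# Scale layout reminder (illustrative, assumed shapes): a per-tensor scale has shape (1,)
# and passes through the padding helpers untouched, while a per-channel scale keeps exactly
# one non-unit axis, e.g. (out_features, 1) for a weight of shape (out_features, in_features).
# Only per-channel scales whose channel axis matches the pad/replicate dim are transformed
# together with the weight, and verify_scale_dimension() asserts that the scale's channel
# length still matches the weight written back below.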
+ model_state_dict[f"{prefix}.{layer_name}.weight"] = tensor + if scale is not None: + model_state_dict[f"{prefix}.{layer_name}.scale"] = scale + verify_scale_dimension(tensor=tensor, tensor_scale=scale) def set_bias( self, tensor: torch.Tensor, prefix: str, layer: torch.nn.Module, layer_name: str, model_state_dict: dict @@ -270,66 +398,103 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: prefix_parts = prefix.split(".") prefix = ".".join(prefix_parts[:-1]) hf_prefix = ".".join(prefix_parts[:-2]) + if self.fused_qkv: + self.replace_prefixes( + old_prefix=f"{hf_prefix}.Wqkv", new_prefix=f"{prefix}.Wqkv", model_state_dict=model_state_dict + ) + qkv_weight, _ = self.get_weight( + prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict + ) + q_proj_weight, k_proj_weight, v_proj_weight = qkv_weight.split( + [ + self._src_num_attention_heads * self.head_dim, + self._src_num_key_value_heads * self.head_dim, + self._src_num_key_value_heads * self.head_dim, + ], + dim=0, + ) + q_proj_scale, k_proj_scale, v_proj_scale = None, None, None + qkv_bias = self.get_bias( + prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict + ) + if qkv_bias is not None: + q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias.split( + [ + self._src_num_attention_heads * self.head_dim, + self._src_num_key_value_heads * self.head_dim, + self._src_num_key_value_heads * self.head_dim, + ], + dim=0, + ) + else: + q_proj_bias, k_proj_bias, v_proj_bias = None, None, None + else: + self.replace_prefixes( + old_prefix=f"{hf_prefix}.q_proj", new_prefix=f"{prefix}.q_proj", model_state_dict=model_state_dict + ) + self.replace_prefixes( + old_prefix=f"{hf_prefix}.k_proj", new_prefix=f"{prefix}.k_proj", model_state_dict=model_state_dict + ) + self.replace_prefixes( + old_prefix=f"{hf_prefix}.v_proj", new_prefix=f"{prefix}.v_proj", model_state_dict=model_state_dict + ) - # import pdb;pdb.set_trace() - - self.replace_prefixes( - old_prefix=f"{hf_prefix}.q_proj", new_prefix=f"{prefix}.q_proj", model_state_dict=model_state_dict - ) - self.replace_prefixes( - old_prefix=f"{hf_prefix}.k_proj", new_prefix=f"{prefix}.k_proj", model_state_dict=model_state_dict - ) - self.replace_prefixes( - old_prefix=f"{hf_prefix}.v_proj", new_prefix=f"{prefix}.v_proj", model_state_dict=model_state_dict - ) - - q_proj_weight = self.get_weight( - prefix=prefix, layer=self.q_proj, layer_name="q_proj", model_state_dict=model_state_dict - ) - k_proj_weight = self.get_weight( - prefix=prefix, layer=self.k_proj, layer_name="k_proj", model_state_dict=model_state_dict - ) - v_proj_weight = self.get_weight( - prefix=prefix, layer=self.v_proj, layer_name="v_proj", model_state_dict=model_state_dict - ) + q_proj_weight, q_proj_scale = self.get_weight( + prefix=prefix, layer=self.q_proj, layer_name="q_proj", model_state_dict=model_state_dict + ) + k_proj_weight, k_proj_scale = self.get_weight( + prefix=prefix, layer=self.k_proj, layer_name="k_proj", model_state_dict=model_state_dict + ) + v_proj_weight, v_proj_scale = self.get_weight( + prefix=prefix, layer=self.v_proj, layer_name="v_proj", model_state_dict=model_state_dict + ) - q_proj_bias = self.get_bias( - prefix=prefix, layer=self.q_proj, layer_name="q_proj", model_state_dict=model_state_dict - ) - k_proj_bias = self.get_bias( - prefix=prefix, layer=self.k_proj, layer_name="k_proj", model_state_dict=model_state_dict - ) - v_proj_bias = self.get_bias( - prefix=prefix, layer=self.v_proj, layer_name="v_proj", model_state_dict=model_state_dict 
- ) + q_proj_bias = self.get_bias( + prefix=prefix, layer=self.q_proj, layer_name="q_proj", model_state_dict=model_state_dict + ) + k_proj_bias = self.get_bias( + prefix=prefix, layer=self.k_proj, layer_name="k_proj", model_state_dict=model_state_dict + ) + v_proj_bias = self.get_bias( + prefix=prefix, layer=self.v_proj, layer_name="v_proj", model_state_dict=model_state_dict + ) if self.num_key_value_heads != self._src_num_key_value_heads: if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE: repeats = self.tp_degree // self._src_num_key_value_heads elif self.sharding_strategy == GQA.CONVERT_TO_MHA: repeats = self._src_num_attention_heads // self._src_num_key_value_heads - k_proj_weight = replicate_kv( - k_proj_weight, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0 + k_proj_weight, k_proj_scale = replicate_kv( + k_proj_weight, + source_heads=self._src_num_key_value_heads, + repeats=repeats, + head_dim=0, + tensor_scale=k_proj_scale, ) - k_proj_bias = replicate_kv( + k_proj_bias, _ = replicate_kv( k_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0 ) - v_proj_weight = replicate_kv( - v_proj_weight, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0 + v_proj_weight, v_proj_scale = replicate_kv( + v_proj_weight, + source_heads=self._src_num_key_value_heads, + repeats=repeats, + head_dim=0, + tensor_scale=v_proj_scale, ) - v_proj_bias = replicate_kv( + v_proj_bias, _ = replicate_kv( v_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0 ) if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE: - q_proj_weight = maybe_pad_interleaved( + q_proj_weight, q_proj_scale = maybe_pad_interleaved( q_proj_weight, pad_dim=0, source_heads=self._src_num_attention_heads, target_heads=self.num_attention_heads, source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads, + tensor_scale=q_proj_scale, ) - q_proj_bias = maybe_pad_interleaved( + q_proj_bias, _ = maybe_pad_interleaved( q_proj_bias, pad_dim=0, source_heads=self._src_num_attention_heads, @@ -338,88 +503,113 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: ) if self.sharding_strategy == GQA.CONVERT_TO_MHA: - q_proj_weight = maybe_pad_tail( + q_proj_weight, q_proj_scale = maybe_pad_tail( q_proj_weight, source_heads=self._src_num_attention_heads, target_heads=self.num_attention_heads, pad_dim=0, + tensor_scale=q_proj_scale, ) - q_proj_bias = maybe_pad_tail( + q_proj_bias, _ = maybe_pad_tail( q_proj_bias, source_heads=self._src_num_attention_heads, target_heads=self.num_attention_heads, pad_dim=0, ) - k_proj_weight = maybe_pad_tail( + k_proj_weight, k_proj_scale = maybe_pad_tail( k_proj_weight, - source_heads=self._src_num_attention_heads, - target_heads=self.num_attention_heads, + source_heads=self._src_num_key_value_heads, + target_heads=self.num_key_value_heads, pad_dim=0, + tensor_scale=k_proj_scale, ) - k_proj_bias = maybe_pad_tail( + k_proj_bias, _ = maybe_pad_tail( k_proj_bias, - source_heads=self._src_num_attention_heads, - target_heads=self.num_attention_heads, + source_heads=self._src_num_key_value_heads, + target_heads=self.num_key_value_heads, pad_dim=0, ) - v_proj_weight = maybe_pad_tail( + v_proj_weight, v_proj_scale = maybe_pad_tail( v_proj_weight, - source_heads=self._src_num_attention_heads, - target_heads=self.num_attention_heads, + source_heads=self._src_num_key_value_heads, + target_heads=self.num_key_value_heads, pad_dim=0, + tensor_scale=v_proj_scale, ) - v_proj_bias = 
maybe_pad_tail( + v_proj_bias, _ = maybe_pad_tail( v_proj_bias, - source_heads=self._src_num_attention_heads, - target_heads=self.num_attention_heads, + source_heads=self._src_num_key_value_heads, + target_heads=self.num_key_value_heads, pad_dim=0, ) - self.set_weight( - tensor=q_proj_weight, - prefix=prefix, - layer=self.q_proj, - layer_name="q_proj", - model_state_dict=model_state_dict, - ) - self.set_weight( - tensor=k_proj_weight, - prefix=prefix, - layer=self.k_proj, - layer_name="k_proj", - model_state_dict=model_state_dict, - ) - self.set_weight( - tensor=v_proj_weight, - prefix=prefix, - layer=self.v_proj, - layer_name="v_proj", - model_state_dict=model_state_dict, - ) - - if self.bias: - self.set_bias( - tensor=q_proj_bias, + if self.fused_qkv: + qkv_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0) + self.set_weight( + tensor=qkv_weight, + prefix=prefix, + layer=self.Wqkv, + layer_name="Wqkv", + model_state_dict=model_state_dict, + ) + if self.bias: + qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=0) + self.set_bias( + tensor=qkv_bias, + prefix=prefix, + layer=self.Wqkv, + layer_name="Wqkv", + model_state_dict=model_state_dict, + ) + else: + self.set_weight( + tensor=q_proj_weight, prefix=prefix, layer=self.q_proj, layer_name="q_proj", model_state_dict=model_state_dict, + scale=q_proj_scale, ) - self.set_bias( - tensor=k_proj_bias, + self.set_weight( + tensor=k_proj_weight, prefix=prefix, layer=self.k_proj, layer_name="k_proj", model_state_dict=model_state_dict, + scale=k_proj_scale, ) - self.set_bias( - tensor=v_proj_bias, + self.set_weight( + tensor=v_proj_weight, prefix=prefix, layer=self.v_proj, layer_name="v_proj", model_state_dict=model_state_dict, + scale=v_proj_scale, ) + if self.bias: + self.set_bias( + tensor=q_proj_bias, + prefix=prefix, + layer=self.q_proj, + layer_name="q_proj", + model_state_dict=model_state_dict, + ) + self.set_bias( + tensor=k_proj_bias, + prefix=prefix, + layer=self.k_proj, + layer_name="k_proj", + model_state_dict=model_state_dict, + ) + self.set_bias( + tensor=v_proj_bias, + prefix=prefix, + layer=self.v_proj, + layer_name="v_proj", + model_state_dict=model_state_dict, + ) + return True @@ -461,8 +651,8 @@ def __init__( self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=self.bias) def forward(self, attention_output: torch.Tensor): - O = self.o_proj(attention_output) - return O + o = self.o_proj(attention_output) + return o def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: prefix_parts = prefix.split(".") @@ -473,22 +663,30 @@ def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: old_prefix=f"{hf_prefix}.o_proj", new_prefix=f"{prefix}.o_proj", model_state_dict=model_state_dict ) o_proj_weight = model_state_dict[f"{prefix}.o_proj.weight"] + o_proj_scale = model_state_dict.get(f"{prefix}.o_proj.scale", None) + if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE: - o_proj_weight = maybe_pad_interleaved( + o_proj_weight, o_proj_scale = maybe_pad_interleaved( o_proj_weight, pad_dim=1, source_heads=self._src_num_attention_heads, target_heads=self.num_attention_heads, source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads, + tensor_scale=o_proj_scale, ) if self.sharding_strategy == GQA.CONVERT_TO_MHA: - o_proj_weight = maybe_pad_tail( + o_proj_weight, o_proj_scale = maybe_pad_tail( o_proj_weight, source_heads=self._src_num_attention_heads, target_heads=self.num_attention_heads, pad_dim=1, + 
tensor_scale=o_proj_scale, ) + model_state_dict[f"{prefix}.o_proj.weight"] = o_proj_weight + if o_proj_scale is not None: + model_state_dict[f"{prefix}.o_proj.scale"] = o_proj_scale + verify_scale_dimension(tensor=o_proj_weight, tensor_scale=o_proj_scale) return True diff --git a/examples/inference/modules/model_base.py b/examples/inference/modules/model_base.py index d22cedf..b63652b 100644 --- a/examples/inference/modules/model_base.py +++ b/examples/inference/modules/model_base.py @@ -1,26 +1,644 @@ import os +import copy import tempfile import warnings +from typing import Any, Dict, List, Optional, Tuple, Union +import logging + import torch -from transformers import PretrainedConfig +from torch import nn +from modules.autobucketing import generate_buckets +from modules.checkpoint import load_state_dict +from transformers import PretrainedConfig, PreTrainedModel +from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput +from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput +from transformers.generation.logits_process import LogitsProcessorList +from transformers.generation.stopping_criteria import ( + StoppingCriteriaList, + validate_stopping_criteria, +) + +from neuronx_distributed.quantization.quantization_config import QuantizationType +from safetensors.torch import load_file from neuronx_distributed.quantization.quantization_utils import ( - convert_float_model_to_pytorch_int8_model, convert_qint8_to_int8_state_dict, + quantize_pytorch_model_per_channel_symmetric, + quantize_pytorch_model_per_tensor_symmetric, ) +from neuronx_distributed.parallel_layers import parallel_state, utils # noqa: E402 +from neuronx_distributed.trace.model_builder import ModelBuilder from neuronx_distributed.utils.speculative_decoding import NeuronSpeculation -from modules.checkpoint import load_state_dict +from neuronx_distributed.utils.sampling import Sampler # noqa: E402 + +from modules.model_wrapper import ( # noqa: E402 + CONTEXT_ENCODING_MODEL_TAG, # noqa: E402 + SPECULATION_MODEL_TAG, # noqa: E402 + MEDUSA_MODEL_TAG, # noqa: E402 + TOKEN_GENERATION_MODEL_TAG, # noqa: E402 + ModelWrapper, # noqa: E402 +) +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] +from modules.autobucketing import slice_lhs, slice_rhs # noqa: E402 +from modules.gqa import ( # noqa: E402 + determine_sharding_strategy, # noqa: E402 + get_shardable_head_counts, # noqa: E402 +) # noqa: E402 -class NeuronBaseForCausalLM(NeuronSpeculation): +class NeuronBaseModel(PreTrainedModel): + """ + Base model that NeuronXXXModel classes inherit from. + + The forward() function will be traced and compiled by NxD. 
+ """ + + SEQ_DIM = 2 + + def __init__(self, config: PretrainedConfig): + super().__init__(config) + + self.batch_size = config.batch_size + self.n_positions = config.n_positions + self.vocab_size = config.vocab_size + self.speculation_length = config.speculation_length + self.padding_side = config.padding_side + self.max_length = config.max_length + + self.setup_attr_for_model(config) + self.init_model(config) + self.init_inference_optimization(config) + self.post_init() + + def setup_attr_for_model(self, config: PretrainedConfig): + """ + Please provide model-specific definition for the following attributes + self.on_device_sampling + self.tp_degree + self.hidden_size + self.num_attention_heads + self.num_key_value_heads + self.max_batch_size + self.buckets + """ + raise NotImplementedError("setup_attr_for_model() is not implemented") + + def init_model(self, config: PretrainedConfig): + """ + Please provide definition for the following components: + self.embed_tokens + self.layers + self.norm + self.lm_head + """ + raise NotImplementedError("init_model() is not implemented") + + def init_inference_optimization(self, config: PretrainedConfig): + if self.on_device_sampling: + self.sampler = Sampler(config) + + gqa_sharding_strategy = determine_sharding_strategy(self.tp_degree, self.num_key_value_heads) + _, num_key_value_heads = get_shardable_head_counts( + self.tp_degree, self.num_attention_heads, self.num_key_value_heads, gqa_sharding_strategy + ) + if parallel_state.model_parallel_is_initialized(): + num_kv_heads_per_partition = utils.divide(num_key_value_heads, self.tp_degree) + else: + num_kv_heads_per_partition = num_key_value_heads + + hidden_dim_per_head = self.hidden_size // self.num_attention_heads + + self.kv_shape = ( + self.max_batch_size, + num_kv_heads_per_partition, + self.max_length, + hidden_dim_per_head, + ) + self.past_key_values = nn.ParameterList( + [ + nn.Parameter(torch.zeros(self.kv_shape, dtype=config.torch_dtype), requires_grad=False) + for _ in range(config.num_hidden_layers * 2) + ] + ) + + def _bucket_slice_kv_cacheline(self, cache): + + if self.padding_side == "right": + return slice_lhs(cache, self.n_positions, self.SEQ_DIM) + else: + max_idx = cache.shape[self.SEQ_DIM] + return slice_rhs(cache, self.n_positions, max_idx, self.SEQ_DIM) + + def _gather_bucket_slice_into_kv_cacheline(self, idx, bucket_slice): + max_idx = self.past_key_values[idx].shape[self.SEQ_DIM] + if self.padding_side == "right": + remaining = slice_rhs(self.past_key_values[idx], max_idx - self.n_positions, max_idx, self.SEQ_DIM) + return torch.cat([bucket_slice, remaining], dim=self.SEQ_DIM) + else: + remaining = slice_lhs(self.past_key_values[idx], max_idx - self.n_positions, self.SEQ_DIM) + return torch.cat([remaining, bucket_slice], dim=self.SEQ_DIM) + + def _create_context_attn_mask(self, attention_mask): + mask = torch.full((self.n_positions, self.n_positions), True, device=attention_mask.device).tril(diagonal=0) + mask = mask[None, None, :, :].expand(self.batch_size, 1, self.n_positions, self.n_positions) + + if self.padding_side == "right": + return mask + else: + expanded_mask = ( + attention_mask[:, None, None, :] + .expand(self.batch_size, 1, self.n_positions, self.n_positions) + .to(torch.bool) + ) + return torch.logical_and(mask, expanded_mask) + + def _create_spec_attn_mask(self, attention_mask): + return ( + attention_mask[:, None, None, :] + .expand(self.batch_size, 1, self.speculation_length, self.n_positions) + .to(torch.bool) + ) + + def _create_simple_attn_mask(self, 
attention_mask): + return attention_mask[:, None, None, :].expand(self.batch_size, 1, 1, self.n_positions).to(torch.bool) + + def create_attn_mask(self, attention_mask, is_for_context_encoding, is_for_speculation, position_ids): + if is_for_context_encoding: + return self._create_context_attn_mask(attention_mask) + elif is_for_speculation: + return self._create_spec_attn_mask(attention_mask) + else: + return self._create_simple_attn_mask(attention_mask) + + def _medusa_forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + accepted_indices = None, + current_length = None, + medusa_mask = None, + scatter_index = None, + ): + is_for_context_encoding = ( + input_ids.shape[-1] > 1 + and self.medusa_speculation_length != input_ids.shape[-1] + ) + is_for_medusa_speculation = input_ids.shape[-1] == self.medusa_speculation_length + + # It is either for context encoding or for token generation + if is_for_context_encoding: + past_key_values = None + else: + past_key_values = [] + if is_for_medusa_speculation: + index = current_length.view(-1, 1, current_length.shape[-1], 1).expand_as( + self.past_key_values[0][:, :, 0 : self.config.num_medusa_heads + 1, :] + ) + gather_index = accepted_indices.view(-1, 1, accepted_indices.shape[-1], 1).expand_as( + self.past_key_values[0][:, :, 0 : self.config.num_medusa_heads + 1, :] + ) + + for key_layer_idx in range(0, len(self.past_key_values), 2): + k_cache = self.past_key_values[key_layer_idx] + v_cache = self.past_key_values[key_layer_idx + 1] + + accepted_k_cache = torch.gather(k_cache, dim=2, index=gather_index) + accepted_v_cache = torch.gather(v_cache, dim=2, index=gather_index) + k_cache = torch.scatter(k_cache, 2, index, accepted_k_cache) + v_cache = torch.scatter(v_cache, 2, index, accepted_v_cache) + + key_state = self._bucket_slice_kv_cacheline(k_cache) + value_state = self._bucket_slice_kv_cacheline(v_cache) + + past_key_values.append([key_state, value_state]) + + else: + for key_layer_idx in range(0, len(self.past_key_values), 2): + k_cache = self.past_key_values[key_layer_idx] + v_cache = self.past_key_values[key_layer_idx + 1] + key_state = self._bucket_slice_kv_cacheline(k_cache) + value_state = self._bucket_slice_kv_cacheline(v_cache) + + past_key_values.append([key_state, value_state]) + + # Prepare attention mask(s) + attention_mask = self.create_attn_mask( + attention_mask, is_for_context_encoding, False, position_ids + ) + active_mask = None + if is_for_medusa_speculation: + medusa_mask = medusa_mask[0].bool() + active_mask = medusa_mask[None, None, :, :].expand( + self.batch_size, 1, self.medusa_speculation_length, self.medusa_speculation_length + ) + + hidden_states, past_key_values = self.get_model_output( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + active_mask=active_mask, + ) + + updated_kv_cache = [] + for idx, kv_per_layer in enumerate(past_key_values): + k_cache = self.past_key_values[idx * 2] + v_cache = self.past_key_values[idx * 2 + 1] + + if is_for_context_encoding: + if self.config.is_continuous_batching: + # scatter back to the desired seq_ids + seq_id_index_shape = seq_ids.shape[:1] + k_cache.shape[1:] + seq_id_index = seq_ids.view(-1, 1, 1, 1).expand(seq_id_index_shape) + k_cache = torch.scatter(k_cache, 0, seq_id_index, kv_per_layer[0]) + v_cache = torch.scatter(v_cache, 0, seq_id_index, kv_per_layer[1]) + else: + # assign back to full kv_cacheline + k_cache = kv_per_layer[0] + v_cache = kv_per_layer[1] + k_cache = 
self._gather_bucket_slice_into_kv_cacheline(idx * 2, k_cache) + v_cache = self._gather_bucket_slice_into_kv_cacheline(idx * 2 + 1, v_cache) + else: + if self.padding_side == "left": + # TODO: fix it with scatter after right padding + k_cache = k_cache[:, :, 1:, :] + v_cache = v_cache[:, :, 1:, :] + k_cache = torch.cat([k_cache, kv_per_layer[0]], dim=2) + v_cache = torch.cat([v_cache, kv_per_layer[1]], dim=2) + else: + if is_for_medusa_speculation: + scatter_index_new = scatter_index.view(-1, 1, scatter_index.shape[-1], 1).expand_as( + kv_per_layer[0] + ) + else: + scatter_index_new = position_ids.view(-1, 1, position_ids.shape[-1], 1).expand_as( + kv_per_layer[0] + ) + k_cache = torch.scatter(k_cache, 2, scatter_index_new, kv_per_layer[0]) + v_cache = torch.scatter(v_cache, 2, scatter_index_new, kv_per_layer[1]) + + updated_kv_cache.append(k_cache) + updated_kv_cache.append(v_cache) + + if self.padding_side == "left": + index = torch.tensor([hidden_states.shape[1] - 1], device=hidden_states.device) + index = index.unsqueeze(1).expand(self.batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + else: + if position_ids.shape[-1] == self.medusa_speculation_length: + index = torch.min(position_ids) + index = torch.arange(index, index + self.medusa_speculation_length, device=hidden_states.device) + index = index[None, :, None].expand( + self.batch_size, self.medusa_speculation_length, self.hidden_size + ) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + else: + # simple token generation + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(self.batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + medusa_logits = [logits] + [ + head(hidden_states).float() + for head in [getattr(self, f"medusa_head_{i}") for i in range(self.num_medusa_heads)] + ] + stacked_logits = torch.stack(medusa_logits, dim=0) + + res = logits + if is_for_context_encoding: + result = [ + self.sampler.sample(stacked_logits[i : i + 1, -1, :].squeeze(0)) + for i in range(self.config.num_medusa_heads + 1) + ] + res = torch.stack(result, dim=0) # 5, 1, 10 + else: + results = [] + for i in range(stacked_logits.shape[1]): + result = [ + self.sampler.sample(stacked_logits[j : j + 1, i, :].squeeze(0)) + for j in range(self.config.num_medusa_heads + 1) + ] + res = torch.stack(result, dim=0) + results.append(res) + + return [res] + updated_kv_cache + + def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + accepted_indices = None, + current_length = None, + medusa_mask = None, + scatter_index = None, + ): + if self.config.is_medusa: + return self._medusa_forward(input_ids, attention_mask, position_ids, seq_ids, accepted_indices, current_length, medusa_mask, scatter_index) + + is_for_context_encoding = ( + input_ids.shape[-1] > 1 + and self.speculation_length != input_ids.shape[-1] + ) + is_for_speculation = input_ids.shape[-1] == self.speculation_length + + # It is either for context encoding or for token generation + if is_for_context_encoding: + past_key_values = None + else: + past_key_values = [] + for key_layer_idx in range(0, len(self.past_key_values), 2): + k_cache = self.past_key_values[key_layer_idx] + v_cache = self.past_key_values[key_layer_idx + 1] + key_state = self._bucket_slice_kv_cacheline(k_cache) + value_state = self._bucket_slice_kv_cacheline(v_cache) + + 
past_key_values.append([key_state, value_state]) + + # Prepare attention mask(s) + attention_mask = self.create_attn_mask( + attention_mask, is_for_context_encoding, is_for_speculation, position_ids + ) + active_mask = None + if is_for_speculation: + active_mask = torch.full( + (self.speculation_length, self.speculation_length), True, device=attention_mask.device + ).tril(diagonal=0) + active_mask = active_mask[None, None, :, :].expand( + self.batch_size, 1, self.speculation_length, self.speculation_length + ) + + hidden_states, past_key_values = self.get_model_output( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + active_mask=active_mask, + ) + + updated_kv_cache = [] + for idx, kv_per_layer in enumerate(past_key_values): + k_cache = self._bucket_slice_kv_cacheline(self.past_key_values[idx*2]) + v_cache = self._bucket_slice_kv_cacheline(self.past_key_values[idx*2+1]) + + if is_for_context_encoding: + if self.config.is_continuous_batching: + # scatter back to the desired seq_ids + seq_id_index_shape = seq_ids.shape[:1] + k_cache.shape[1:] + seq_id_index = seq_ids.view(-1, 1, 1, 1).expand(seq_id_index_shape) + k_cache = torch.scatter(k_cache, 0, seq_id_index, kv_per_layer[0]) + v_cache = torch.scatter(v_cache, 0, seq_id_index, kv_per_layer[1]) + else: + # assign back to full kv_cacheline + k_cache = kv_per_layer[0] + v_cache = kv_per_layer[1] + else: + if self.padding_side == "left": + # TODO: fix it with scatter after right padding + k_cache = k_cache[:, :, 1:, :] + v_cache = v_cache[:, :, 1:, :] + k_cache = torch.cat([k_cache, kv_per_layer[0]], dim=2) + v_cache = torch.cat([v_cache, kv_per_layer[1]], dim=2) + else: + scatter_index_new = position_ids.view(-1, 1, position_ids.shape[-1], 1).expand_as( + kv_per_layer[0] + ) + k_cache = torch.scatter(k_cache, 2, scatter_index_new, kv_per_layer[0]) + v_cache = torch.scatter(v_cache, 2, scatter_index_new, kv_per_layer[1]) + + k_cache = self._gather_bucket_slice_into_kv_cacheline(idx * 2, k_cache) + v_cache = self._gather_bucket_slice_into_kv_cacheline(idx * 2 + 1, v_cache) + + updated_kv_cache.append(k_cache) + updated_kv_cache.append(v_cache) + + if self.padding_side == "left": + index = torch.tensor([hidden_states.shape[1] - 1], device=hidden_states.device) + index = index.unsqueeze(1).expand(self.batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + else: + # speculative decoding case; only batch_size=1 + # will need to extend the logic to support multi-batch later + # maybe just use position_ids for index? + if position_ids.shape[-1] == self.speculation_length: + index = torch.min(position_ids) + index = torch.arange(index, index + self.speculation_length, device=hidden_states.device) + index = index[None, :, None].expand(self.batch_size, self.speculation_length, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + else: + # simple token generation + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(self.batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + res = logits + if self.on_device_sampling: + # perform sampling on Neuron to get tokens + res = self.sampler.sample(logits[:, -1, :]) + + + + # NeuronLlamaModel class manages the KV cache. So the attention_mask will be generated and passed + # through to LlamaModel. 
We override the HF's code that generates attention mask because HF does + # not support left aligned RHS padding. This enables Neuron to achieve higher performance and + # extensibility. + # + # 4d mask is passed through the layers + # attention_mask = _prepare_4d_causal_attention_mask( + # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + # ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + ) + + hidden_states = layer_outputs[0] + + next_decoder_cache += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + return (hidden_states, next_decoder_cache) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_model_output( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + active_mask: Optional[List[torch.FloatTensor]] = None, + ): + batch_size, seq_length = input_ids.shape[:2] + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device #noqa + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + + + # NeuronLlamaModel class manages the KV cache. So the attention_mask will be generated and passed + # through to LlamaModel. We override the HF's code that generates attention mask because HF does + # not support left aligned RHS padding. This enables Neuron to achieve higher performance and + # extensibility. + # + # 4d mask is passed through the layers + # attention_mask = _prepare_4d_causal_attention_mask( + # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + # ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + ) + + hidden_states = layer_outputs[0] + + next_decoder_cache += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + return (hidden_states, next_decoder_cache) + + +class NeuronBaseForCausalLM(NeuronSpeculation): _STATE_DICT_MODEL_PREFIX = "model." 
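# Note: concrete models are expected to point _model_cls at their NeuronBaseModel subclass;
# it is the class handed to ModelWrapper for every traced variant (context encoding, token
# generation, speculation, medusa). _STATE_DICT_MODEL_PREFIX is stripped from checkpoint keys
# in get_state_dict() below, dropping the Hugging Face "model." prefix from parameter names.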
+ _model_cls = None + + def __init__(self, model_path: str, config: PretrainedConfig): + super().__init__(config) + + self.config = config + self.vocab_size = config.vocab_size + self.padding_side = config.padding_side + self.kv_cache_populated = False + + self.sampler = None + + self.models = [] + self.enable_context_encoding() + if config.trace_tokengen_model: + self.enable_token_generation() + if config.speculation_length > 0: + self.enable_speculation() + if config.medusa_speculation_length > 0: + self.enable_medusa_speculation() + self.model_path = model_path + @staticmethod def load_hf_model(model_path): - raise NotImplementedError(f"load_hf_model is not implemented") + raise NotImplementedError("load_hf_model is not implemented") + + def get_compiler_args(self): + return None + + def enable_context_encoding(self): + new_config = copy.deepcopy(self.config) + new_config.batch_size = self.config.ctx_batch_size + new_config.n_active_tokens = self.config.max_context_length + new_config.bucket_n_active_tokens = True + + if not new_config.enable_bucketing: + new_config.buckets = generate_buckets(new_config.max_context_length,new_config.max_context_length) + else: + new_config.buckets = generate_buckets(128, new_config.max_context_length) + + self.context_encoding_model = ModelWrapper( + config=new_config, + model_cls=self._model_cls, + tag=CONTEXT_ENCODING_MODEL_TAG, + compiler_args=self.get_compiler_args(), + ) + self.models.append(self.context_encoding_model) + + def enable_token_generation(self): + new_config = copy.deepcopy(self.config) + new_config.batch_size = self.config.tkg_batch_size + new_config.n_active_tokens = 1 + new_config.bucket_n_active_tokens = False + + if not new_config.enable_bucketing: + new_config.buckets = generate_buckets(new_config.max_length,new_config.max_length) + else: + new_config.buckets = generate_buckets(128, new_config.max_length) + + + self.token_generation_model = ModelWrapper( + config=new_config, + model_cls=self._model_cls, + tag=TOKEN_GENERATION_MODEL_TAG, + compiler_args=self.get_compiler_args(), + ) + self.models.append(self.token_generation_model) + + def enable_speculation(self): + new_config = copy.deepcopy(self.config) + new_config.batch_size = self.config.spec_batch_size + new_config.n_active_tokens = self.config.speculation_length + self.speculation_model = ModelWrapper(new_config, self._model_cls, tag=SPECULATION_MODEL_TAG) + + self.models.append(self.speculation_model) + + def enable_medusa_speculation(self): + new_config = copy.deepcopy(self.config) + new_config.batch_size = self.config.spec_batch_size + new_config.n_active_tokens = self.config.medusa_speculation_length + self.medusa_speculation_model = ModelWrapper(new_config, self._model_cls, tag=MEDUSA_MODEL_TAG) + + self.models.append(self.medusa_speculation_model) @classmethod def get_state_dict(cls, model_path: str, config: PretrainedConfig) -> dict: @@ -31,12 +649,21 @@ def get_state_dict(cls, model_path: str, config: PretrainedConfig) -> dict: updated_param_name = param_name.replace(cls._STATE_DICT_MODEL_PREFIX, "", 1) model_sd[updated_param_name] = model_sd[param_name] del model_sd[param_name] + if os.path.exists(model_path + "/medusa_heads.pt"): + medusa_head = torch.load(model_path + "/medusa_heads.pt", map_location="cpu") + model_sd.update(medusa_head) return model_sd @classmethod - def get_quantized_state_dict(cls, model_path: str, config: PretrainedConfig) -> dict: + def generate_quantized_state_dict(cls, model_path: str, config: PretrainedConfig) -> dict: hf_model = 
cls.load_hf_model(model_path) - hf_model_quant = convert_float_model_to_pytorch_int8_model(float_model=hf_model) + quantization_type = QuantizationType(config.quantization_type) + if quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: + hf_model_quant = quantize_pytorch_model_per_tensor_symmetric(float_model=hf_model, inplace=True) + elif quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC: + hf_model_quant = quantize_pytorch_model_per_channel_symmetric(float_model=hf_model, inplace=True) + else: + raise RuntimeError(f"{config.quantization_type} not supported") model_quant_sd = hf_model_quant.model.state_dict() lm_head_quant_sd = hf_model_quant.lm_head.state_dict() @@ -53,7 +680,7 @@ def from_pretrained(cls, model_path: str, config: PretrainedConfig): return cls(model_path, config) def checkpoint_loader_fn(self, mmap: bool = False): - # this function loads the model's state dcitionary and weights from + # this function loads the model's state dictionary and weights from # the hf model if self.config.quantized is False: model_sd = self.get_state_dict(self.model_path, self.config) @@ -78,20 +705,55 @@ def get_quantized_checkpoints(self, mmap: bool = False): if self.config.torch_dtype == torch.bfloat16: for name, param in model_quant_sd.items(): if param is not None and param.dtype == torch.float32: - warnings.warn(f"Found float32 weights in quantized checkpoint: {name}. Will convert to bfloat16") - model_quant_sd[name] = param.bfloat16() + if name.endswith(".scale"): + warnings.warn(f"Found float32 weights in quantized checkpoint: {name}. Will skip converting to bfloat16 as its scale") + else: + warnings.warn(f"Found float32 weights in quantized checkpoint: {name}. Will convert to bfloat16") + model_quant_sd[name] = param.bfloat16() return model_quant_sd def compile(self, serialize_base_path=None): - if serialize_base_path: - self.config.save_pretrained(serialize_base_path) + + base_compile_work_dir = os.environ.get("BASE_COMPILE_WORK_DIR", "/tmp/nxd_model/") + + builder = ModelBuilder( + router=None, + tp_degree=self.config.tp_degree, + checkpoint_loader=self.checkpoint_loader_fn, + compiler_workdir=base_compile_work_dir + ) + for model in self.models: - model.compile(self.checkpoint_loader_fn, serialize_base_path=serialize_base_path) + builder.add( + key=model.tag, + model_instance=model.get_model_instance(), + example_inputs=model.input_generator(), + compiler_args=model.compiler_args, + bucket_config=model.bucket_config, + priority_model_idx=model.priority_model_idx, + ) + + traced_model = builder.trace(initialize_model_weights=False) + torch.jit.save(traced_model, serialize_base_path + "model.pt") + del traced_model + + builder.shard_checkpoint(serialize_path=os.path.join(serialize_base_path, "weights/")) + self.is_loaded_to_neuron = True def load(self, serialize_base_path): - for model in self.models: - model.load(serialize_base_path) + + traced_model = torch.jit.load(serialize_base_path + "model.pt") + + weights = [] + for rank in range(self.config.tp_degree): + ckpt = load_file(os.path.join(serialize_base_path, f"weights/tp{rank}_sharded_checkpoint.safetensors")) + weights.append(ckpt) + + traced_model.nxd_model.initialize(weights) + + for model_wrapper in self.models: + model_wrapper.model = traced_model def to_neuron(self, serialize_base_path=None): if serialize_base_path is None: @@ -101,3 +763,410 @@ def to_neuron(self, serialize_base_path=None): else: self.compile(serialize_base_path) self.load(serialize_base_path) + + @property + def device(self) -> 
torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + # We dont want HF to move parameters to device + return torch.device("cpu") + + def forward( + self, + input_ids: torch.LongTensor = None, + seq_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + medusa_args = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + + output_attentions, output_hidden_states, return_dict = self._setup_func_config(output_attentions, + output_hidden_states, + return_dict) + + # infer attention_mask from position_ids if not provided + if attention_mask is None: + attention_mask = self._infer_attention_mask(position_ids) + + self._log_input(input_ids, attention_mask, position_ids, seq_ids) + + if seq_ids is None: + seq_ids = torch.arange(input_ids.shape[0]) + + outputs, is_run_on_neuron = self._get_model_outputs(input_ids, attention_mask, position_ids, seq_ids, medusa_args) + + if self.config.trace_tokengen_model and not self.token_generation_model.is_neuron(): + self._copy_past_key_values(outputs) + + if is_run_on_neuron: + # When run on neuron, KV cache remains on device + logits_or_next_tokens = outputs + else: + # When run on cpu, KV cache is returned which has to be ignored + logits_or_next_tokens, *_ = outputs + + logging.debug("---output---") + logging.debug(f"{'tokens' if self.config.on_device_sampling else 'logits'} = %s, ", logits_or_next_tokens) + + return self._construct_output(logits_or_next_tokens) + + def _setup_func_config(self, output_attentions, output_hidden_states, return_dict): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return output_attentions, output_hidden_states, return_dict + + def _infer_attention_mask(self, position_ids): + assert position_ids is not None, "need to call forward with position_ids if attention_mask is not provided" + batch_size, seq_len = position_ids.shape + if position_ids.shape[-1] == 1: + seq_len = self.config.n_positions + position_ids_to_compare = position_ids.expand(batch_size, seq_len) - 1 + else: + seq_len = position_ids.shape[-1] + position_ids_to_compare = position_ids + mask = torch.arange(seq_len).view(1, -1).expand(batch_size, seq_len) + attention_mask = (position_ids_to_compare >= mask).to(dtype=position_ids.dtype) + return attention_mask + + def _log_input(self, input_ids, attention_mask, position_ids, seq_ids): + logging.debug("---input---") + logging.debug("input_ids shape = %s type=%s", 
input_ids.shape, input_ids.type())
+        logging.debug("attention_mask shape = %s type=%s", attention_mask.shape, attention_mask.type())
+        logging.debug("position_ids shape = %s type=%s", position_ids.shape, position_ids.type())
+        logging.debug("input_ids =%s", input_ids)
+        logging.debug("attention_mask =%s", attention_mask)
+        logging.debug("position_ids =%s", position_ids)
+        logging.debug(f"seq_ids: {seq_ids}")
+
+        if self.config.trace_tokengen_model and not self.token_generation_model.is_neuron():
+            logging.debug(f"first layer kv_cache: {self.token_generation_model.model.past_key_values[0][:, 0, :, 0]}")
+
+    def _get_model_outputs(self, input_ids, attention_mask, position_ids, seq_ids, medusa_args):
+        if (
+            input_ids.shape[-1] > 1
+            and input_ids.shape[-1] != self.config.speculation_length
+            and input_ids.shape[-1] != self.config.medusa_speculation_length
+        ):
+            if self.config.is_medusa:
+                medusa_args = self._prepare_inputs()
+                outputs = self.context_encoding_model(
+                    input_ids,
+                    attention_mask,
+                    position_ids,
+                    seq_ids,
+                    *medusa_args,
+                )
+            else:
+                outputs = self.context_encoding_model(
+                    input_ids,
+                    attention_mask,
+                    position_ids,
+                    seq_ids,
+                )
+            self.kv_cache_populated = True
+            is_run_on_neuron = self.context_encoding_model.is_neuron()
+        elif input_ids.shape[-1] == self.config.speculation_length:
+            outputs = self.speculation_model(
+                input_ids,
+                attention_mask,
+                position_ids,
+                seq_ids,
+            )
+            is_run_on_neuron = self.speculation_model.is_neuron()
+        elif input_ids.shape[-1] == self.config.medusa_speculation_length:
+            outputs = self.medusa_speculation_model(
+                input_ids,
+                attention_mask,
+                position_ids,
+                seq_ids,
+                *medusa_args,
+            )
+            is_run_on_neuron = self.medusa_speculation_model.is_neuron()
+        else:
+            outputs = self.token_generation_model(
+                input_ids,
+                attention_mask,
+                position_ids,
+                seq_ids,
+            )
+            is_run_on_neuron = self.token_generation_model.is_neuron()
+
+        return outputs, is_run_on_neuron
+
+    def _copy_kv_cache(self, source_model, target_model):
+        for source, target in zip(source_model.model.models, target_model.model.models):
+            encoder_kv_cache_line = source.states
+            token_gen_kv_cache_line = target.states
+            for name, _ in token_gen_kv_cache_line._parameters.items():
+                token_gen_kv_cache_line._parameters[name] = encoder_kv_cache_line._parameters[name]
+
+    def _copy_past_key_values(self, outputs):
+        new_past_key_values = outputs[1:]
+        for i, new_past_key_value in enumerate(new_past_key_values):
+            self.token_generation_model.model.past_key_values[i].data = new_past_key_value
+            self.context_encoding_model.model.past_key_values[i].data = new_past_key_value
+
+    def _construct_output(self, logits_or_next_tokens):
+        if self.config.is_medusa:
+            next_tokens = logits_or_next_tokens[:1, :, :]
+        else:
+            next_tokens = logits_or_next_tokens
+
+        OutputParams = CausalLMOutputWithPast(
+            logits=None if self.config.on_device_sampling else logits_or_next_tokens,
+            hidden_states=logits_or_next_tokens,
+            attentions=None,
+        )
+        OutputParams.tokens = next_tokens
+        return OutputParams
+
+    # We override this function because we want to change the way attention_mask
+    # is updated each iteration.
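    # With left padding the mask keeps a constant width: a 1 is appended for the freshly
    # generated token and the oldest (left-most) column is dropped, e.g. [0, 0, 1, 1] -> [0, 1, 1, 1].
    # With right padding a 1 is prepended instead, keeping the valid positions contiguous on the left.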
+ def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_for_token_generation: Optional[bool] = False, + is_encoder_decoder: bool = False, + ) -> Dict[str, Any]: + + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if is_for_token_generation: + if self.padding_side == "left": + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + attention_mask = attention_mask[:, 1:] + else: + attention_mask = torch.cat( + [attention_mask.new_ones((attention_mask.shape[0], 1)), attention_mask], dim=-1 + ) + model_kwargs["attention_mask"] = attention_mask + return model_kwargs + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if self.kv_cache_populated: + input_ids = input_ids[:, -1:] + + accepted_indices = kwargs.get("accepted_indices", None) + current_length = kwargs.get("current_length", None) + medusa_mask = kwargs.get("medusa_mask", None) + scatter_index = kwargs.get("scatter_index", None) + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if self.kv_cache_populated: + position_ids = torch.amax(position_ids, 1, keepdim=True) + position_ids = position_ids + 1 + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", False), + "attention_mask": attention_mask, + "medusa_args": (accepted_indices, current_length, medusa_mask, scatter_index), + } + ) + return model_inputs + + def prepare_medusa_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if self.kv_cache_populated: + input_ids = input_ids[:, -self.config.medusa_speculation_length :] + position_ids = kwargs.get("position_ids") + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "medusa_args": ( + kwargs.get("accepted_indices"), + kwargs.get("current_length"), + kwargs.get("medusa_mask"), + kwargs.get("scatter_index"), + ), + } + ) + return model_inputs + + def _prepare_inputs(self): + accepted_indices = torch.zeros((self.config.batch_size, self.config.num_medusa_heads + 1), dtype=torch.int64) + current_length = torch.zeros((self.config.batch_size, self.config.num_medusa_heads + 1), 
dtype=torch.int64) + medusa_mask = torch.zeros( + (self.config.batch_size, self.config.medusa_speculation_length, self.config.medusa_speculation_length), + dtype=torch.int64, + ) + scatter_index = torch.zeros((self.config.batch_size, self.config.medusa_speculation_length), dtype=torch.int64) + return accepted_indices, current_length, medusa_mask, scatter_index + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + def reset(self): + # We need to reset the KV cache flag for a new batch of inference. + # When the flag is reset, the subsequent run will invoke the + # context encoding model. + self.kv_cache_populated = False + + def reset_kv_cache(self): + # Zero out kv cache for debug. + # For new batch inference, use reset() instead + if not self.context_encoding_model.is_neuron(): + for i, kv_tensor in enumerate(self.context_encoding_model.model.past_key_values): + self.context_encoding_model.model.past_key_values[i] = torch.zeros_like(kv_tensor) + + if not self.token_generation_model.is_neuron(): + for i, kv_tensor in enumerate(self.token_generation_model.model.past_key_values): + self.token_generation_model.model.past_key_values[i] = torch.zeros_like(kv_tensor) + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + We override the GenerationMixin sample function (_sample for transformers>=4.39.0) to add support for right side padding. 
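        When config.on_device_sampling is enabled, the loop reads next tokens directly from
        outputs.tokens instead of sampling from logits on the host.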
+ """ + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + + this_peer_finished = False + # auto-regressive generation + while not this_peer_finished: + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + is_for_token_generation = self.kv_cache_populated + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + if not self.config.on_device_sampling: + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + + if not self.config.on_device_sampling: + if self.sampler is None: + self.config.do_sample = True + self.sampler = Sampler(self.config) + next_tokens = self.sampler.sample(outputs.logits[:, -1, :]) + else: + next_tokens = outputs.tokens + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + is_for_token_generation=is_for_token_generation, + ) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None) + this_peer_finished = unfinished_sequences.max() == 0 + + if return_dict_in_generate: + return SampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + ) + else: + return input_ids diff --git a/examples/inference/modules/model_wrapper.py b/examples/inference/modules/model_wrapper.py index 128442e..b9e2392 100644 --- a/examples/inference/modules/model_wrapper.py +++ b/examples/inference/modules/model_wrapper.py @@ -7,6 +7,11 @@ from modules.autobucketing import get_context_encoder_bk, get_token_generation_bk from torch_neuronx import BucketModelConfig +from neuronx_distributed.quantization.quantization_config import ( + QuantizationType, + get_default_custom_qconfig_dict, + get_default_per_channel_custom_qconfig_dict, +) from 
neuronx_distributed.quantization.quantize import convert from neuronx_distributed.trace import ( parallel_model_load, @@ -14,10 +19,12 @@ parallel_model_trace, ) from neuronx_distributed.trace.trace import ParallelModel +from neuronx_distributed.trace.model_builder import BaseModelInstance, NxDModelExecutor CONTEXT_ENCODING_MODEL_TAG = "context_encoding_model" TOKEN_GENERATION_MODEL_TAG = "token_generation_model" SPECULATION_MODEL_TAG = "speculation_model" +MEDUSA_MODEL_TAG = "medusa_speculation_model" def get_bucket_model_config_from_tag(tag, config): @@ -25,13 +32,15 @@ def get_bucket_model_config_from_tag(tag, config): if bucket_degree == 1: return None + pad_token = config.pad_token_id + # NOTE: KV Cache preprocessing is done within the model and not the # shared buffer preprocessor due to lack of support of non-contiguous # slicing of nrt tensors via the NRT API. if tag == CONTEXT_ENCODING_MODEL_TAG: return BucketModelConfig( bucket_kernel=get_context_encoder_bk, - bucket_kernel_constant_args=(torch.tensor(config.buckets), config.padding_side), + bucket_kernel_constant_args=(torch.tensor(config.buckets), config.padding_side, pad_token), shared_state_buffer=None, func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)], ) @@ -49,27 +58,33 @@ def get_bucket_model_config_from_tag(tag, config): class ModelWrapper(torch.nn.Module): - def __init__(self, config, model_cls, tag="", compiler_args: str = None) -> None: + def __init__(self, config, model_cls, tag="", compiler_args: str = None, priority_model_idx: int = None) -> None: super().__init__() self.config = config if not self.config.torch_dtype: self.config.torch_dtype = torch.float32 + if self.config.pad_token_id is None: + self.config.pad_token_id = 0 + self.model_cls = model_cls self.model = None self.is_compiled = False self.serialize_base_path = None self.tag = tag + self.is_medusa = config.is_medusa if compiler_args is None: - self.compiler_args = "--enable-saturate-infinity --auto-cast=none --model-type=transformer -O1" + self.compiler_args = "--enable-saturate-infinity --auto-cast=none --model-type=transformer --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2' -O1 " + else: self.compiler_args = compiler_args self.bucket_config = get_bucket_model_config_from_tag(tag, self.config) + self.priority_model_idx = priority_model_idx def is_neuron(self): - return self.model is not None and isinstance(self.model, ParallelModel) + return self.model is not None and isinstance(self.model, torch.jit.ScriptModule) def compile(self, checkpoint_loader, serialize_base_path): inputs = self.input_generator() @@ -111,12 +126,44 @@ def input_generator( position_ids = torch.zeros((self.config.batch_size, n_active_tokens), dtype=torch.int64) seq_ids = torch.zeros((self.config.batch_size), dtype=torch.int64) - inputs.append((input_ids, attention_mask, position_ids, seq_ids)) + if self.is_medusa: + accepted_indices = torch.zeros( + (self.config.batch_size, self.config.num_medusa_heads + 1), dtype=torch.int64 + ) + current_length = torch.zeros((self.config.batch_size, self.config.num_medusa_heads + 1), dtype=torch.int64) + medusa_mask = torch.zeros( + (self.config.batch_size, self.config.medusa_speculation_length, self.config.medusa_speculation_length), + dtype=torch.int64, + ) + scatter_index = torch.zeros( + (self.config.batch_size, self.config.medusa_speculation_length), dtype=torch.int64 + ) + inputs.append( + ( + input_ids, + attention_mask, + position_ids, + seq_ids, + accepted_indices, + current_length, + 
medusa_mask, + scatter_index, + ) + ) + else: + inputs.append((input_ids, attention_mask, position_ids, seq_ids)) + + return inputs - return inputs if len(inputs) > 1 else inputs[0] + def get_model_instance(self): + return DecoderModelInstance(model_cls=self.model_cls, config=self.config) def _forward_with_pad(self, *args): - tensor, *_, seq_ids = args + seq_ids = args[3] + if len(args) > 4: + medusa_args = args[4:8] + else: + medusa_args = None # pad the inputs up to the compiled batch size in the end def pad_helper(tensor): @@ -130,7 +177,8 @@ def pad_helper(tensor): return padded_tensor padded_args = [] - for arg in args[:-1]: + # pad input_ids, attn_mask and postition_ids + for arg in args[0:3]: padded_args.append(pad_helper(arg)) # need to handle seq_ids seperately, when compiled batch is 4, if we pad seq_ids from [0,2,1] to [0,2,1,0]. @@ -142,15 +190,24 @@ def pad_helper(tensor): ) padded_args.append(padded_seq_ids) - logits, *kv_cache = self._forward(*padded_args) + if medusa_args is not None: + for arg in medusa_args: + padded_args.append(pad_helper(arg)) + + outputs = self._forward(*padded_args) # note that we don't do index select here as it should already be handled, simply sliced out padding here - return [logits[: seq_ids.shape[0]], *kv_cache] + if self.is_neuron(): + logits = outputs + return logits[: seq_ids.shape[0]] + else: + logits, *kv_cache = outputs + return [logits[: seq_ids.shape[0]], *kv_cache] def reorder_helper(self, *args): # we then reorder the other inputs based on padded_seq_ids # because there are issue with compiler to do gather, we cannot fully support artibrary order of seq_ids for now - *_, seq_ids = args + seq_ids = args[3] reorder_args = [] @@ -161,7 +218,7 @@ def reorder_helper(self, *args): def _forward(self, *args): if self.config.is_continuous_batching and self.config.batch_size == self.config.max_batch_size: - logging.debug(f"running forward and reorder the inputs based on seq_ids") + logging.debug("running forward and reorder the inputs based on seq_ids") seq_ids, *args = self.reorder_helper(*args) logging.debug("Processed inputs to the model", self.tag, args) @@ -169,32 +226,31 @@ def _forward(self, *args): outputs = self.model(*args) if self.config.is_continuous_batching and self.config.batch_size == self.config.max_batch_size: - return [torch.index_select(outputs[0], 0, seq_ids), *outputs[1:]] + if self.is_neuron(): + return torch.index_select(outputs, 0, seq_ids) + else: + return [torch.index_select(outputs[0], 0, seq_ids), *outputs[1:]] return outputs - def pad_on_seq(self, *args): - """ - Pad on the right, to make the inputs (input_ids, position_ids, attention_mask) - on sequence dimension to match n_active_tokens, mainly apply to context encoding model - """ - *to_pad_args, seq_id = args - - padded_args = [] - - input_ids, *_ = to_pad_args - - pad_len = self.config.n_active_tokens - input_ids.shape[1] - - if padded_args == 0: - return args - - logging.debug(f"padding inputs by {pad_len}") - for arg in to_pad_args: - arg = F.pad(arg, (0, pad_len), "constant", 0) - padded_args.append(arg) + def pad_to_max_compiled_seq(self, *args): + if self.tag == CONTEXT_ENCODING_MODEL_TAG: + to_pad = args[:3] + pad_lengths = [self.config.max_context_length - arg.shape[1] for arg in to_pad] + tensor_pad_vals = [ + self.config.pad_token_id, + 0, + 1 + ] + padded_args = [F.pad(arg, (0, pad_len), "constant", pad_val) for arg, pad_val, pad_len in zip(to_pad, tensor_pad_vals, pad_lengths)] + args = (*padded_args,*args[3:]) + else: + 
input_ids,attention_mask,*rest_of_args = args + pad_len = self.config.max_length - attention_mask.shape[1] + padded_attention_mask = F.pad(attention_mask, (0, pad_len), "constant", 0) + args = (input_ids,padded_attention_mask,*rest_of_args) - return [*padded_args, seq_id] + return args def forward(self, *args): logging.debug(f"calling forward on network {self.tag}") @@ -202,9 +258,9 @@ def forward(self, *args): if self.model is None: raise RuntimeError("Forward called before load. Run load() or load_state_dict() making calling forward") - args = self.pad_on_seq(*args) + args = self.pad_to_max_compiled_seq(*args) - *_, seq_ids = args + seq_ids = args[3] input_batch_size = seq_ids.shape[0] @@ -219,7 +275,7 @@ def forward(self, *args): if cur_batch + self.config.batch_size <= input_batch_size: # we only process part of the input to run logging.debug(f"running foward on batch {cur_batch}:{cur_batch+self.config.batch_size}") - logits, *kv_caches = self._forward( + outputs = self._forward( *[arg[cur_batch : cur_batch + self.config.batch_size] for arg in args] ) else: @@ -227,16 +283,62 @@ def forward(self, *args): logging.debug( f"running forward on batch {cur_batch}:{input_batch_size}, padded up to {self.config.batch_size}" ) - logits, *kv_caches = self._forward_with_pad(*[arg[cur_batch:input_batch_size] for arg in args]) + outputs = self._forward_with_pad(*[arg[cur_batch:input_batch_size] for arg in args]) - if not self.is_neuron(): + if self.is_neuron(): + logits = outputs + else: + logits, *kv_caches = outputs for i, kv_cache in enumerate(kv_caches): self.model.past_key_values[i].data = kv_cache output_logits.append(logits) cur_batch += self.config.batch_size - return [torch.cat(output_logits, dim=0), *kv_caches] + if self.is_neuron(): + return torch.cat(output_logits, dim=0) + else: + return [torch.cat(output_logits, dim=0), *kv_caches] + +class DecoderModelInstance(BaseModelInstance): + + def __init__(self, model_cls, config): + self.model_cls = model_cls + self.module = None + self.input_output_aliases = None + self.config = config + + def load_module(self): + float_model = self.model_cls(self.config) + float_model.eval() + + if self.config.torch_dtype == torch.bfloat16: + float_model.bfloat16() + + if self.config.quantized is True: + quantization_type = QuantizationType(self.config.quantization_type) + if quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC: + q_config = get_default_per_channel_custom_qconfig_dict() + elif quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: + q_config = get_default_custom_qconfig_dict() + else: + raise RuntimeError(f"{self.config.quantization_type} is not supported") + self.module = convert(float_model, q_config=q_config, inplace=False, mapping=None) + else: + self.module = float_model + + def get(self, bucket_rank, **kwargs): + if bucket_rank is not None: + self.module.n_positions = self.config.buckets[bucket_rank] + + # Currently we have to init an input_output_aliases map for + # each buckets, otherwise it will fail the aliasing setup when + # generating HLO + self.input_output_aliases = {} + num_output_from_trace = 1 + for i in range(len(self.module.past_key_values)): + self.input_output_aliases[self.module.past_key_values[i]] = num_output_from_trace + i + return self.module, self.input_output_aliases def get_trace_callable(model_cls, config, bucket_rank=None): @@ -245,10 +347,17 @@ def get_trace_callable(model_cls, config, bucket_rank=None): float_model = model_cls(config) float_model.eval() if config.torch_dtype == torch.bfloat16: - 
os.environ["XLA_DOWNCAST_BF16"] = "1" + float_model.bfloat16() if config.quantized is True: - model = convert(float_model, q_config=None, inplace=False, mapping=None) + quantization_type = QuantizationType(config.quantization_type) + if quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC: + q_config = get_default_per_channel_custom_qconfig_dict() + elif quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: + q_config = get_default_custom_qconfig_dict() + else: + raise RuntimeError(f"{config.quantization_type} is not supported") + model = convert(float_model, q_config=q_config, inplace=False, mapping=None) else: model = float_model diff --git a/examples/inference/requirements.txt b/examples/inference/requirements.txt index c48208b..c12f855 100644 --- a/examples/inference/requirements.txt +++ b/examples/inference/requirements.txt @@ -1,2 +1,2 @@ -transformers==4.36.2 +transformers==4.40.0 sentencepiece diff --git a/examples/inference/run_dbrx.py b/examples/inference/run_dbrx.py new file mode 100644 index 0000000..db1ece5 --- /dev/null +++ b/examples/inference/run_dbrx.py @@ -0,0 +1,52 @@ +import torch +from dbrx.dbrx_runner import DbrxRunner +from transformers import GenerationConfig + +model_path = "/data/model_hf/dbrx-base/" +traced_model_path = "/data/traced_model/dbrx-base/" + +torch.manual_seed(0) + +def dbrx_sample(): + # Compile the model for a specific configuration + generation_config = GenerationConfig.from_pretrained(model_path) + generation_config.top_k = 1 + generation_config.do_sample = True + + runner = DbrxRunner(model_path=model_path, tokenizer_path=model_path, generation_config=generation_config) + + batch_size = 1 + max_prompt_length = 1024 + sequence_length = 1024 + 128 + + runner.trace( + traced_model_path=traced_model_path, + tp_degree=32, + batch_size=batch_size, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, + ) + # Load model weights into Neuron devise + # We will use the returned model to run accuracy and perf tests + print("\nLoading model to Neuron device ..") + neuron_model = runner.load_neuron_model(traced_model_path) + + # Confirm the traced model matches the huggingface model run on cpu + print("\nChecking accuracy ..") + runner.check_accuracy_logits(neuron_model, batch_size, sequence_length) + + # Perform inference + prompt = ["I believe the meaning of life is", "The color of the sky is"] + print("\nGenerating ..") + _, outputs = runner.generate_on_neuron(prompt, neuron_model) + print("Generated outputs:") + for idx, output in enumerate(outputs): + print(f"output {idx}: {output}") + + print("\nBenchmarking ..") + # Now lets benchmark + runner.benchmark_sampling(neuron_model) + + +if __name__ == "__main__": + dbrx_sample() diff --git a/examples/inference/run_llama.py b/examples/inference/run_llama.py index ca08520..88c22f4 100644 --- a/examples/inference/run_llama.py +++ b/examples/inference/run_llama.py @@ -1,8 +1,12 @@ +import torch from llama2.llama2_runner import LlamaRunner from transformers import GenerationConfig -model_path = "/home/ubuntu/model_hf/Llama-2-7b-hf/" -traced_model_path = "/home/ubuntu/traced_model/Llama-2-7b-hf/" +model_path = "/home/ubuntu/model_hf/Llama-2-7b/" +traced_model_path = "/home/ubuntu/traced_model/Llama-2-7b/" + +torch.manual_seed(0) + def llama_sample(): # Compile the model for a specific configuration @@ -13,17 +17,18 @@ def llama_sample(): runner = LlamaRunner(model_path=model_path, tokenizer_path=model_path, generation_config=generation_config) batch_size = 2 - max_context_length = 
1024 - max_new_tokens = 1024 + max_prompt_length = 1024 + sequence_length = 2048 runner.trace( traced_model_path=traced_model_path, tp_degree=32, batch_size=batch_size, - context_lengths=max_context_length, - new_token_counts=max_new_tokens, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, on_device_sampling=True, ) + # Load model weights into Neuron device # We will use the returned model to run accuracy and perf tests print("\nLoading model to Neuron device ..") @@ -31,8 +36,8 @@ def llama_sample(): # Confirm the traced model matches the huggingface model run on cpu print("\nChecking accuracy ..") - runner.check_accuracy(neuron_model, batch_size, max_context_length, max_new_tokens) - + runner.check_accuracy(neuron_model, batch_size, sequence_length) + # Perform inference prompt = ["I believe the meaning of life is", "The color of the sky is"] print("\nGenerating ..") diff --git a/examples/inference/run_llama_quantized.py b/examples/inference/run_llama_quantized.py index c9c3986..529da16 100644 --- a/examples/inference/run_llama_quantized.py +++ b/examples/inference/run_llama_quantized.py @@ -1,8 +1,50 @@ +import os + +import torch from llama2.llama2_runner import LlamaRunner from transformers import GenerationConfig model_path = "/home/ubuntu/LLama7b/" -traced_model_path = "/home/ubuntu/traced_model/quantized_Llama-2-7b/" +traced_model_path = "/home/ubuntu/traced_model/LLama7b_quantized/" + + +def llama_get_quantized_checkpoint(path_to_save): + """ + This example generates the quantized checkpoints and returns a state dict + """ + # Compile the model for a specific configuration + generation_config = GenerationConfig.from_pretrained(model_path) + generation_config.top_k = 1 + generation_config.do_sample = True + + runner = LlamaRunner(model_path=model_path, tokenizer_path=model_path, generation_config=generation_config) + batch_size = 2 + max_prompt_length = 128 + sequence_length = 512 + + quantized_state_dict = runner.generate_quantized_hf_checkpoints_on_cpu( + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, + batch_size=batch_size, + quantized=True, + quantized_checkpoints_path="", + quantization_type="per_channel_symmetric", + ) + + # delete None values in the quantized_state_dict + keys_to_delete = [] + for key in quantized_state_dict: + if quantized_state_dict[key] is None: + keys_to_delete.append(key) + + print(f"Will be deleting following keys as its Value is None: {keys_to_delete}") + + for key in keys_to_delete: + del quantized_state_dict[key] + + torch.save(quantized_state_dict, path_to_save) + + return quantized_state_dict def llama_cpu_sample(): @@ -17,23 +59,25 @@ def llama_cpu_sample(): runner = LlamaRunner(model_path=model_path, tokenizer_path=model_path, generation_config=generation_config) batch_size = 2 - max_context_length = 128 - max_new_tokens = 384 + max_prompt_length = 128 + sequence_length = 512 prompt = ["I believe the meaning of life is", "The color of the sky is"] _, outputs = runner.generate_on_cpu( prompt=prompt, batch_size=batch_size, - max_context_length=max_context_length, - max_new_tokens=max_new_tokens, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, quantized=True, + quantized_checkpoints_path="", + quantization_type="per_channel_symmetric", ) print("\nGenerating ..") for idx, output in enumerate(outputs): print(f"output {idx}: {output}") -def llama_sample(): +def llama_sample(generate_checkpoint=False): # Compile the model for a specific configuration generation_config = 
GenerationConfig.from_pretrained(model_path) generation_config.top_k = 1 @@ -42,18 +86,26 @@ def llama_sample(): runner = LlamaRunner(model_path=model_path, tokenizer_path=model_path, generation_config=generation_config) batch_size = 2 - max_context_length = 128 - max_new_tokens = 384 + max_prompt_length = 128 + sequence_length = 512 + + if generate_checkpoint: + quantized_checkpoints_path = os.path.join(model_path, "model_quant.pt") + quantized_state_dict = runner.generate_quantized_hf_checkpoints_on_cpu( + max_prompt_length=max_prompt_length, sequence_length=sequence_length, batch_size=batch_size + ) + torch.save(quantized_state_dict, quantized_checkpoints_path) runner.trace( traced_model_path=traced_model_path, tp_degree=32, batch_size=batch_size, - context_lengths=max_context_length, - new_token_counts=max_new_tokens, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, on_device_sampling=True, quantized=True, - quantized_checkpoints_path="/home/ubuntu/LLama7b/model_quant.pt", + quantized_checkpoints_path=os.path.join(model_path, "model_quant.pt"), + quantization_type="per_channel_symmetric", ) # Perform inference @@ -67,7 +119,10 @@ def llama_sample(): for idx, output in enumerate(outputs): print(f"output {idx}: {output}") + runner.benchmark_sampling(neuron_model) + if __name__ == "__main__": # llama_cpu_sample() llama_sample() + # llama_get_quantized_checkpoint(os.path.join(model_path, "model_quant.pt")) diff --git a/examples/inference/run_llama_speculative.py b/examples/inference/run_llama_speculative.py index 183df32..bb4e9d9 100644 --- a/examples/inference/run_llama_speculative.py +++ b/examples/inference/run_llama_speculative.py @@ -13,8 +13,8 @@ def llama_sample(): # Batch size must be 1 for speculative decoding batch_size = 1 - max_context_length = 256 - max_new_tokens = 256 + max_prompt_length = 256 + sequence_length = 512 # Need to trace both target and draft models # We don't need to trace token generation model for target @@ -23,8 +23,8 @@ def llama_sample(): traced_model_path=traced_target_model_path, tp_degree=32, batch_size=batch_size, - context_lengths=max_context_length, - new_token_counts=max_new_tokens, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, speculation_length=5, trace_tokengen_model=False, ) @@ -32,8 +32,8 @@ def llama_sample(): traced_model_path=traced_draft_model_path, tp_degree=32, batch_size=batch_size, - context_lengths=max_context_length, - new_token_counts=max_new_tokens, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, ) target_model = target_runner.load_neuron_model(traced_target_model_path) @@ -44,8 +44,7 @@ def llama_sample(): target_runner.check_accuracy( target_model, batch_size, - max_context_length, - max_new_tokens, + sequence_length, traced_draft_model=draft_model, speculation_length=5, ) diff --git a/examples/inference/run_mixtral.py b/examples/inference/run_mixtral.py index d237e42..d614aac 100644 --- a/examples/inference/run_mixtral.py +++ b/examples/inference/run_mixtral.py @@ -1,9 +1,12 @@ +import torch from mixtral.mixtral_runner import MixtralRunner from transformers import GenerationConfig model_path = "/home/ubuntu/model_hf/Mixtral-8x7B-v0.1/" traced_model_path = "/home/ubuntu/traced_model/Mixtral-8x7B-v0.1/" +torch.manual_seed(0) + def mixtral_sample(): # Compile the model for a specific configuration @@ -14,29 +17,29 @@ def mixtral_sample(): runner = MixtralRunner(model_path=model_path, tokenizer_path=model_path, generation_config=generation_config) 
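    # sequence_length below is the total token budget (prompt + generated tokens), so with
    # max_prompt_length=1024 the run can generate up to 2048 - 1024 = 1024 new tokens;
    # trace() rejects configurations where sequence_length <= max_prompt_length.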
batch_size = 2 - max_context_length = 1024 - max_new_tokens = 1024 + max_prompt_length = 1024 + sequence_length = 2048 runner.trace( traced_model_path=traced_model_path, tp_degree=32, batch_size=batch_size, - context_lengths=max_context_length, - new_token_counts=max_new_tokens, + max_prompt_length=max_prompt_length, + sequence_length=sequence_length, ) # Load model weights into Neuron devise # We will use the returned model to run accuracy and perf tests - print("\ Loading model to Neuron device ..") + print("\nLoading model to Neuron device ..") neuron_model = runner.load_neuron_model(traced_model_path) # Confirm the traced model matches the huggingface model run on cpu print("\nChecking accuracy ..") - runner.check_accuracy(neuron_model, batch_size, max_context_length, max_new_tokens) + runner.check_accuracy(neuron_model, batch_size, sequence_length) # Perform inference - prompt = ["I believe the meaning of life is", "The color of the sky is"] + prompts = ["I believe the meaning of life is", "The color of the sky is"] print("\nGenerating ..") - _, outputs = runner.generate_on_neuron(prompt, neuron_model) + _, outputs = runner.generate_on_neuron(prompts, neuron_model) print("Generated outputs:") for idx, output in enumerate(outputs): print(f"output {idx}: {output}") diff --git a/examples/inference/runner.py b/examples/inference/runner.py index d9c37e6..774423c 100644 --- a/examples/inference/runner.py +++ b/examples/inference/runner.py @@ -2,19 +2,23 @@ import logging import os from contextlib import contextmanager -from typing import List +from functools import partial +from typing import List, Union import torch -from modules.benchmark import BENCHMARK_REPORT_FILENAME, Benchmark +from modules.benchmark import BENCHMARK_REPORT_FILENAME, Benchmark, LatencyCollector, generate_report from torch.profiler import ProfilerActivity, profile from transformers import AutoTokenizer, GenerationConfig, PreTrainedModel, set_seed +from transformers.generation import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput +from torch_neuronx.testing.validation import logit_validation import neuronx_distributed as nxd END_TO_END_MODEL = "e2e_model" CONTEXT_ENCODING_MODEL = "context_encoding_model" TOKEN_GENERATION_MODEL = "token_generation_model" SPECULATION_MODEL = "speculation_model" +MEDUSA_MODEL = "medusa_speculation_model" LM_HEAD_NAME = "lm_head.pt" @@ -24,6 +28,11 @@ SPEC_MODEL_COMPILER_WORK_DIR = BASE_COMPILER_WORK_DIR + SPECULATION_MODEL + "/" +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] + +TEST_PROMPT = "I believe the meaning of life is" + + class InferenceRunner: """ Use the runner class to trace the model and perform inference. 
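# A condensed usage sketch of the runner flow, assembled from the run_llama.py / run_dbrx.py
# examples in this change; the paths and the LlamaRunner subclass are illustrative placeholders.
from llama2.llama2_runner import LlamaRunner
from transformers import GenerationConfig

model_path = "/home/ubuntu/model_hf/Llama-2-7b/"             # placeholder
traced_model_path = "/home/ubuntu/traced_model/Llama-2-7b/"  # placeholder

generation_config = GenerationConfig.from_pretrained(model_path)
generation_config.top_k = 1
generation_config.do_sample = True

runner = LlamaRunner(model_path=model_path, tokenizer_path=model_path,
                     generation_config=generation_config)

# Compile for a fixed shape; sequence_length is the total budget (prompt + generated tokens).
runner.trace(
    traced_model_path=traced_model_path,
    tp_degree=32,
    batch_size=2,
    max_prompt_length=1024,
    sequence_length=2048,
)

neuron_model = runner.load_neuron_model(traced_model_path)

# Optional: validate logits against the HF CPU model (requires on-device sampling disabled).
runner.check_accuracy_logits(neuron_model, batch_size=2, max_length=2048)

# Generate for a batch of two prompts (the prompt count must match the compiled batch size),
# then collect end-to-end and per-submodule latency reports.
_, outputs = runner.generate_on_neuron(
    ["I believe the meaning of life is", "The color of the sky is"], neuron_model
)
print(outputs)
runner.benchmark_sampling(neuron_model)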
@@ -40,10 +49,11 @@ def __init__(self, model_path: str = None, tokenizer_path: str = None, generatio self.tokenizer_path = tokenizer_path self._is_torch_profile_enabled = False - if generation_config == None: + if generation_config is None: generation_config = GenerationConfig.from_pretrained(model_path) generation_config.top_k = 1 generation_config.do_sample = True + generation_config.pad_token_id = 0 self.generation_config = generation_config @@ -51,11 +61,18 @@ def load_hf_model(self): # Implement per model raise NotImplementedError - def load_neuron_model_on_cpu(self, max_context_length, max_new_tokens, batch_size, **kwargs): + def load_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): + # Implement per model + raise NotImplementedError + + def generate_quantized_hf_checkpoints_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): # Implement per model + """ + This is a utility function that quantizes a HF model and returns the checkpoints + """ raise NotImplementedError - def load_quantized_neuron_model_on_cpu(self, max_context_length, max_new_tokens, batch_size, **kwargs): + def load_quantized_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): # Implement per model raise NotImplementedError @@ -79,6 +96,13 @@ def get_padding_side(self): # Implement per model raise NotImplementedError + def get_default_hf_generation_config_kwargs(self) -> dict: + return { + 'do_sample': self.generation_config.do_sample, + 'top_k': self.generation_config.top_k, + 'pad_token_id': self.generation_config.pad_token_id + } + def enable_torch_profile(self): self._is_torch_profile_enabled = True @@ -115,25 +139,34 @@ def get_config_for_nxd( self, batch_size, tp_degree, - context_lengths, - new_token_counts, + max_prompt_length, + sequence_length, + enable_bucketing, **kwargs, ): - if isinstance(context_lengths, int): - context_lengths = [context_lengths] - - if isinstance(new_token_counts, int): - new_token_counts = [new_token_counts] + """ + Set up the value for config attributes if needed. + Please don't add new config attribute here. Instead, please add new + attributes in NeuronInferenceConfig or model-specific config class. 
+ """ config_cls = self.get_config_cls() - config = config_cls.from_pretrained(self.model_path, **kwargs) + + merged_kwargs = self.get_default_hf_generation_config_kwargs() + if kwargs is not None: + merged_kwargs.update(kwargs) + config = config_cls.from_pretrained(self.model_path, **merged_kwargs) + config.tp_degree = tp_degree - config.max_context_length = context_lengths[-1] - config.max_new_tokens = new_token_counts[-1] - max_length = config.max_context_length + config.max_new_tokens + config.max_context_length = max_prompt_length + config.max_new_tokens = sequence_length - max_prompt_length + if config.max_new_tokens == 0: + config.max_new_tokens = None + max_length = sequence_length config.max_length = max_length config.n_positions = max_length + config.n_active_tokens = max_length if config.max_position_embeddings <= max_length: logging.warning( @@ -147,11 +180,8 @@ def get_config_for_nxd( config.batch_size = batch_size # bucketing specific - config.enable_context_encoding_bucketing, config.enable_token_generation_bucketing = [ - len(context_lengths) > 1, - ] * 2 - config.buckets = [cl + tc for cl, tc in zip(context_lengths, new_token_counts)] - config.bucket_n_active_tokens = config.enable_context_encoding_bucketing + config.enable_bucketing = enable_bucketing + config.buckets = [max_length] config.padding_side = self.get_padding_side() config.on_device_sampling = kwargs.get("on_device_sampling", False) @@ -160,24 +190,27 @@ def get_config_for_nxd( config.speculation_length = kwargs.get("speculation_length", 0) config.trace_tokengen_model = kwargs.get("trace_tokengen_model", True) - config.do_sample = self.generation_config.do_sample - config.top_k = self.generation_config.top_k config.quantized = kwargs.get("quantized", False) config.quantized_checkpoints_path = kwargs.get("quantized_checkpoints_path", None) if config.quantized is True: assert config.quantized_checkpoints_path is not None, "quantized_checkpoints_path is required" + config.quantization_type = kwargs.get("quantization_type", "per_tensor_symmetric") + config.is_medusa = kwargs.get("is_medusa", False) + config.medusa_speculation_length = kwargs.get("medusa_speculation_length", 0) + config.num_medusa_heads = kwargs.get("num_medusa_heads", 0) + config.pad_token_id = kwargs.get("pad_token_id", None) return config - def generate_with_hf(self, prompt, max_context_length: int, max_new_tokens: int, do_sample=True): + def generate_with_hf(self, prompts: List[str], max_length: int, **kwargs): """ Use this to generate CPU goldens against which the trace is validated. """ model = self.load_hf_model() tokenizer = self.load_tokenizer(padding_side="left") - return self.generate(model, tokenizer, prompt, max_context_length, max_new_tokens, do_sample=do_sample) + return self.generate(model, tokenizer, prompts, max_length, **kwargs) - def generate_on_neuron(self, prompt, model: PreTrainedModel, draft_model: PreTrainedModel = None): + def generate_on_neuron(self, prompts: List[str], model: PreTrainedModel, draft_model: PreTrainedModel = None, **kwargs): """ Runs the trace on Neuron. 
""" @@ -186,87 +219,90 @@ def generate_on_neuron(self, prompt, model: PreTrainedModel, draft_model: PreTra raise ValueError(f"Model should be of type PreTrainedModel, got type {type(model)}") tokenizer = self.load_tokenizer() - if len(prompt) != model.config.max_batch_size: + if len(prompts) != model.config.max_batch_size: raise ValueError(f"Number of prompts should match batch size {model.config.max_batch_size}") + max_length = kwargs.pop("max_length", model.config.max_length) + if (max_length > model.config.max_length): + ValueError(f"Found user supplied {max_length=} exceeds the compiled model sequence_length={model.config.max_length}") + with self.torch_profile(chrome_trace_path="generate-on-neuron.torch-trace.json"): - generate_ids, outputs = self.generate( - model, tokenizer, prompt, model.config.max_context_length, model.config.max_new_tokens, draft_model + outputs, output_tokens = self.generate( + model, tokenizer, prompts, max_length, draft_model, **kwargs ) model.reset() if draft_model is not None: draft_model.reset() - return generate_ids, outputs + return outputs, output_tokens - def generate_on_cpu(self, prompt: str, batch_size: int, max_context_length: int, max_new_tokens: int, **kwargs): + def generate_on_cpu(self, prompts: List[str], batch_size: int, max_prompt_length: int, sequence_length: int, **kwargs): """ Use generate_on_cpu to confirm the neuron wrapper is correct. If the wrapper works on CPU, then the trace should work too. If it does not, it indicates a problem with the trace itself. """ if kwargs.get("quantized", False) is False: - model = self.load_neuron_model_on_cpu(max_context_length, max_new_tokens, batch_size, **kwargs) + model = self.load_neuron_model_on_cpu(max_prompt_length, sequence_length, batch_size, **kwargs) else: - model = self.load_quantized_neuron_model_on_cpu(max_context_length, max_new_tokens, batch_size, **kwargs) + model = self.load_quantized_neuron_model_on_cpu(max_prompt_length, sequence_length, batch_size) tokenizer = self.load_tokenizer() - generate_ids, outputs = self.generate(model, tokenizer, prompt, max_context_length, max_new_tokens) + outputs, output_tokens = self.generate(model, tokenizer, prompts, sequence_length) model.reset() - return generate_ids, outputs + return outputs, output_tokens def generate( self, model: PreTrainedModel, tokenizer: AutoTokenizer, - prompt: str, - max_context_length: int, - max_new_tokens: int, + prompts: List[str], + max_length: int, draft_model: PreTrainedModel = None, - do_sample=True, + **kwargs ): set_seed(0) # to avoid randomness in sampling if any - max_length = max_context_length + max_new_tokens - inputs = tokenizer(prompt, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt") + inputs = tokenizer(prompts, padding=True, return_tensors="pt") for idx, input in enumerate(inputs["input_ids"]): - logging.debug("padded tokenized input %s : %s", idx, tokenizer.decode(input)) + logging.debug("tokenized input %s : %s", idx, tokenizer.decode(input)) if draft_model is not None: - generate_ids = model.generate( - inputs.input_ids, - attention_mask=inputs.attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - assistant_model=draft_model, - pad_token_id=tokenizer.eos_token_id, # Set `pad_token_id` to `eos_token_id` for open-end generation - ) + kwargs.update({ + "assistant_model": draft_model, + "do_sample": False + }) + + outputs = model.generate( + inputs.input_ids, + generation_config=self.generation_config, + attention_mask=inputs.attention_mask, + 
max_length=max_length, + **kwargs, + ) + + if isinstance(outputs, SampleOutput.__args__): + # Get token ids from output when return_dict_in_generate=True + output_ids = outputs.sequences else: - generate_ids = model.generate( - inputs.input_ids, - attention_mask=inputs.attention_mask, - max_new_tokens=max_new_tokens, - top_k=1, - do_sample=do_sample, - pad_token_id=tokenizer.eos_token_id, # Set `pad_token_id` to `eos_token_id` for open-end generation - ) - outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - return generate_ids, outputs + output_ids = outputs + output_tokens = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + return outputs, output_tokens def check_accuracy( self, traced_model: PreTrainedModel, batch_size: int, - max_context_length: int, - max_new_tokens: int, - expected_token_ids: List=None, - on_cpu: bool=False, - do_sample: bool=True, - traced_draft_model: PreTrainedModel=None, - speculation_length: int=0, + max_length: int, + expected_token_ids: List = None, + on_cpu: bool = False, + do_sample: bool = True, + traced_draft_model: PreTrainedModel = None, + speculation_length: int = 0, + **kwargs, ): """ Function to compare outputs from huggingface model and neuronx NxD model """ - prompt = ["I believe the meaning of life is"] * batch_size + prompts = [TEST_PROMPT] * batch_size tokenizer = self.load_tokenizer() if expected_token_ids is not None: @@ -276,19 +312,22 @@ def check_accuracy( else: # Generate goldens with HF on CPU expected_token_ids, outputs_expected = self.generate_with_hf( - prompt, max_context_length, max_new_tokens, do_sample=do_sample + prompts, max_length, do_sample=do_sample ) print(f"Expected output: {outputs_expected}") # Generate outputs with NxD - prompt = ["I believe the meaning of life is"] * batch_size if on_cpu: + max_prompt_length = kwargs.pop("max_prompt_length") output_token_ids, outputs_actual = self.generate_on_cpu( - prompt, batch_size, max_context_length, max_new_tokens + prompts, + batch_size, + max_prompt_length=max_prompt_length, + sequence_length=max_length ) else: output_token_ids, outputs_actual = self.generate_on_neuron( - prompt, traced_model, traced_draft_model + prompts, traced_model, traced_draft_model, do_sample=do_sample, max_length=max_length ) print(f"Actual output : {outputs_actual}") @@ -304,25 +343,82 @@ def check_accuracy( expected_token_ids = expected_token_ids[: tokens_to_compare] output_token_ids = output_token_ids[: tokens_to_compare] - device = "cpu" if on_cpu else "neuron" assert torch.equal( output_token_ids, expected_token_ids ), f"\nActual: ({device}) {output_token_ids} \nExpected (hf-cpu): {expected_token_ids}" print(f"The output from Neuronx NxD on {device} is accurate!") + def check_accuracy_logits( + self, + traced_model: PreTrainedModel, + batch_size: int, + max_length: int, + expected_logits: torch.Tensor = None, + divergence_difference_tol: float = 0.001, + remove_shift: bool = True, + tol_map: dict = None, + ): + if traced_model.config.on_device_sampling: + raise ValueError("Logits validation is not supported with on-device sampling.") + + prompts = [TEST_PROMPT] * batch_size + tokenizer = self.load_tokenizer() + inputs = tokenizer(prompts, padding=True, return_tensors="pt") + + if not expected_logits: + # logit_validation assumes greedy sampling + expected_outputs, _ = self.generate_with_hf( + prompts, max_length, do_sample=False, output_logits=True, return_dict_in_generate=True, + ) + 
expected_logits = torch.stack(expected_outputs.logits) + expected_token_ids = expected_logits.argmax(dim=2).T + expected_tokens = tokenizer.batch_decode( + expected_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + print("Expected Output: ", expected_tokens, expected_token_ids) + print("Expected Logits Shape: ", expected_logits.shape) + + def generate_logits(model, tokenizer, input_ids): + prompt = tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + actual_outputs, actual_tokens = self.generate_on_neuron( + prompt, traced_model, do_sample=True, output_logits=True, return_dict_in_generate=True, + max_length=max_length + ) + actual_logits = torch.stack(actual_outputs.logits) + actual_token_ids = actual_logits.argmax(dim=2).T + print("Actual Output: ", actual_tokens, actual_token_ids) + print("Actual Logits Shape: ", actual_logits.shape) + model.reset() + return actual_logits + + generate_fn = partial(generate_logits, traced_model, tokenizer) + passed, result, status_msg = logit_validation(inputs.input_ids, + generate_fn, + expected_logits, + divergence_difference_tol=divergence_difference_tol, + tol_map=tol_map, + pad_token_id=tokenizer.pad_token_id, + padding_side=tokenizer.padding_side) + assert passed, status_msg + print("Passed logits validation") + def trace( self, traced_model_path, tp_degree, batch_size, - context_lengths, - new_token_counts, + max_prompt_length=128, + sequence_length=256, + enable_bucketing=True, **kwargs, ): """ Function to trace a model with neuronx NxD """ + if (sequence_length <= max_prompt_length): + raise ValueError(f"Found {sequence_length=} which is less than or equal to {max_prompt_length=}. Please make sure sequence_length is strictly greater than max_prompt_length") + if traced_model_path is not None: if not os.path.exists(traced_model_path): os.makedirs(traced_model_path) @@ -331,8 +427,9 @@ def trace( config = self.get_config_for_nxd( batch_size, tp_degree, - context_lengths, - new_token_counts, + max_prompt_length, + sequence_length, + enable_bucketing, **kwargs, ) if config.torch_dtype != torch.float32 and config.torch_dtype != torch.bfloat16: @@ -351,11 +448,14 @@ def trace( tokenizer.save_pretrained(traced_model_path) model = self.get_model_cls().from_pretrained(self.model_path, config) + model.compile(serialize_base_path=traced_model_path) - def benchmark_sampling(self, model: PreTrainedModel, draft_model: PreTrainedModel=None, target: str=None): + def benchmark_sampling(self, model: PreTrainedModel, draft_model: PreTrainedModel = None, target: str = None): config = model.config tokenizer = self.load_tokenizer() + tokenizer.pad_token = tokenizer.eos_token + target = target if target is not None else "all" report = {} @@ -371,26 +471,54 @@ def benchmark_sampling(self, model: PreTrainedModel, draft_model: PreTrainedMode "do_sample": draft_model is None, "assistant_model": draft_model, } - e2e_benchmark = Benchmark(model.generate, input_param, config, preprocess_func=model.reset) - report[END_TO_END_MODEL] = e2e_benchmark.run() - # Benchmark context encoding model - if target in ["all", "context_encode"]: + if target == "all": + latency_collectors = self.create_submodule_latency_collectors(model) + + # Register latency collectors after warm-up to avoid recording warm-up metrics. 
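            # The hooks are attached inside post_warmup_func, which Benchmark is expected to
            # call once its warm-up iterations finish; the per-submodule latencies gathered this
            # way are turned into reports via generate_submodule_reports below.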
+ def register_latency_collectors(): + if target == "all": + self.register_latency_collectors(latency_collectors, model) + + e2e_benchmark = Benchmark(model.generate, input_param, config, preprocess_func=model.reset, + post_warmup_func=register_latency_collectors) + e2e_benchmark.run() + report[END_TO_END_MODEL] = generate_report(e2e_benchmark.latency_list, config) + + if target == "all": + report.update(self.generate_submodule_reports(latency_collectors, config)) + + # Benchmark context encoding model only + if target == "context_encode": input_param = self.get_sample_inputs(CONTEXT_ENCODING_MODEL, config) ctx_enc_benchmark = Benchmark(model.context_encoding_model, input_param, config) - report[CONTEXT_ENCODING_MODEL] = ctx_enc_benchmark.run() + ctx_enc_benchmark.run() + report[CONTEXT_ENCODING_MODEL] = generate_report(ctx_enc_benchmark.latency_list, config) - # Benchmark token generation model - if hasattr(model, "token_generation_model") and target in ["all", "token_gen"]: + # Benchmark token generation model only + if hasattr(model, "token_generation_model") and target == "token_gen": input_param = self.get_sample_inputs(TOKEN_GENERATION_MODEL, config) tkn_gen_benchmark = Benchmark(model.token_generation_model, input_param, config) - report[TOKEN_GENERATION_MODEL] = tkn_gen_benchmark.run() + tkn_gen_benchmark.run() + report[TOKEN_GENERATION_MODEL] = generate_report(tkn_gen_benchmark.latency_list, config) - # Benchmark speculation model - if hasattr(model, "speculation_model") and target in ["all", "speculation"]: + # Benchmark speculation model only + if hasattr(model, "speculation_model") and target == "speculation": input_param = self.get_sample_inputs(SPECULATION_MODEL, config) spec_benchmark = Benchmark(model.speculation_model, input_param, config) - report[SPECULATION_MODEL] = spec_benchmark.run() + spec_benchmark.run() + report[SPECULATION_MODEL] = generate_report(spec_benchmark.latency_list, config) + + # Benchmark Medusa speculation model + if hasattr(model, "medusa_speculation_model") and target == "speculation": + input_param = self.get_sample_inputs(MEDUSA_MODEL, config) + spec_benchmark = Benchmark(model.medusa_speculation_model, input_param, config) + spec_benchmark.run() + report[MEDUSA_MODEL] = generate_report(spec_benchmark.latency_list, config) + + model.reset() + if draft_model is not None: + draft_model.reset() print("Benchmark completed and its result is as following") print(json.dumps(report, indent=4)) @@ -403,11 +531,13 @@ def benchmark_sampling(self, model: PreTrainedModel, draft_model: PreTrainedMode def get_sample_inputs(self, model_type, config, tokenizer=None): max_length = config.max_length batch_size = config.batch_size + num_medusa_heads = config.num_medusa_heads if config.num_medusa_heads else 4 + medusa_speculation_length = config.medusa_speculation_length if config.medusa_speculation_length else 64 sample_inputs = None if model_type == END_TO_END_MODEL: sample_inputs = tokenizer( - ["I believe the meaning of life is"] * batch_size, + [TEST_PROMPT] * batch_size, max_length=max_length, truncation=True, padding="max_length", @@ -419,21 +549,101 @@ def get_sample_inputs(self, model_type, config, tokenizer=None): attention_mask = torch.zeros((batch_size, max_length), dtype=torch.int64) position_ids = torch.zeros((batch_size, max_length), dtype=torch.int64) seq_ids = torch.zeros((batch_size), dtype=torch.int64) - sample_inputs = (input_ids, attention_mask, position_ids, seq_ids) + if config.is_medusa: + accepted_indices = torch.zeros((batch_size, 
num_medusa_heads + 1), dtype=torch.int64) + current_length = torch.zeros((batch_size, num_medusa_heads + 1), dtype=torch.int64) + medusa_mask = torch.zeros( + (batch_size, medusa_speculation_length, medusa_speculation_length), dtype=torch.int64 + ) + scatter_index = torch.zeros((batch_size, medusa_speculation_length), dtype=torch.int64) + sample_inputs = ( + input_ids, + attention_mask, + position_ids, + seq_ids, + accepted_indices, + current_length, + medusa_mask, + scatter_index, + ) + else: + sample_inputs = ( + input_ids, + attention_mask, + position_ids, + seq_ids, + ) elif model_type == TOKEN_GENERATION_MODEL: input_ids = torch.zeros((batch_size, 1), dtype=torch.int64) attention_mask = torch.zeros((batch_size, max_length), dtype=torch.int64) position_ids = torch.zeros((batch_size, 1), dtype=torch.int64) seq_ids = torch.zeros((batch_size), dtype=torch.int64) - sample_inputs = (input_ids, attention_mask, position_ids, seq_ids) - + sample_inputs = ( + input_ids, + attention_mask, + position_ids, + seq_ids, + ) elif model_type == SPECULATION_MODEL: spec_len = config.speculation_length input_ids = torch.zeros((batch_size, spec_len), dtype=torch.int64) attention_mask = torch.zeros((batch_size, max_length), dtype=torch.int64) position_ids = torch.zeros((batch_size, spec_len), dtype=torch.int64) seq_ids = torch.zeros((batch_size), dtype=torch.int64) - sample_inputs = (input_ids, attention_mask, position_ids, seq_ids) + sample_inputs = ( + input_ids, + attention_mask, + position_ids, + seq_ids, + ) + + elif model_type == MEDUSA_MODEL: + spec_len = config.medusa_speculation_length + input_ids = torch.zeros((batch_size, spec_len), dtype=torch.int64) + attention_mask = torch.zeros((batch_size, max_length), dtype=torch.int64) + position_ids = torch.zeros((batch_size, spec_len), dtype=torch.int64) + seq_ids = torch.zeros((batch_size), dtype=torch.int64) + accepted_indices = torch.zeros((batch_size, num_medusa_heads + 1), dtype=torch.int64) + current_length = torch.zeros((batch_size, num_medusa_heads + 1), dtype=torch.int64) + medusa_mask = torch.zeros( + (batch_size, medusa_speculation_length, medusa_speculation_length), dtype=torch.int64 + ) + scatter_index = torch.zeros((batch_size, medusa_speculation_length), dtype=torch.int64) + sample_inputs = ( + input_ids, + attention_mask, + position_ids, + seq_ids, + accepted_indices, + current_length, + medusa_mask, + scatter_index, + ) return sample_inputs + + def create_submodule_latency_collectors(self, model): + collectors = {} + collectors[CONTEXT_ENCODING_MODEL] = LatencyCollector() + if hasattr(model, "token_generation_model"): + collectors[TOKEN_GENERATION_MODEL] = LatencyCollector() + if hasattr(model, "speculation_model"): + collectors[SPECULATION_MODEL] = LatencyCollector() + return collectors + + def register_latency_collectors(self, latency_collectors, model): + self.register_forward_latency_collector(latency_collectors[CONTEXT_ENCODING_MODEL], + model.context_encoding_model) + if TOKEN_GENERATION_MODEL in latency_collectors: + self.register_forward_latency_collector(latency_collectors[TOKEN_GENERATION_MODEL], + model.token_generation_model) + if SPECULATION_MODEL in latency_collectors: + self.register_forward_latency_collector(latency_collectors[SPECULATION_MODEL], model.speculation_model) + + def register_forward_latency_collector(self, latency_collector, model): + model.register_forward_pre_hook(latency_collector.pre_hook) + model.register_forward_hook(latency_collector.hook) + + def generate_submodule_reports(self, latency_collectors, 
config): + return {key : generate_report(collector.latency_list, config) for key, collector in latency_collectors.items()} diff --git a/examples/training/checkpoint_converter.py b/examples/training/checkpoint_converter.py deleted file mode 100644 index 4d2d6df..0000000 --- a/examples/training/checkpoint_converter.py +++ /dev/null @@ -1,493 +0,0 @@ -import argparse -import json -import os -import re - -import torch -import torch_xla.utils.serialization as xser - -from neuronx_distributed.pipeline.partition import ( - create_partitions, - stage_to_pipeline_parallel_rank, -) - - -class CheckpointConverterBase: - - # ParallelEmbedding - embedding_partition_dim = 0 - # ColumnParallelLinear or GQAQKVColumnParallelLinear - qkv_partition_dim = 0 - # ColumnParallelLinear - gate_up_proj_partition_dim = 0 - # RowParallelLinear - down_proj_partition_dim = 1 - # RowParallelLinear - o_proj_partition_dim = 1 - - def get_partition_dim(self, name): - if "embed_tokens" in name or "lm_head" in name: - partition_dim = self.embedding_partition_dim - elif self.is_qkv_weight(name): - partition_dim = self.qkv_partition_dim - elif "gate_proj" in name or "up_proj" in name or "gate_up_proj" in name: - partition_dim = self.gate_up_proj_partition_dim - elif "down_proj" in name: - partition_dim = self.down_proj_partition_dim - elif "o_proj" in name: - partition_dim = self.o_proj_partition_dim - else: - raise AssertionError(f"Unknown partition_dim for {name}") - return partition_dim - - # QKV Helper functions - def get_hf_to_nxd_model_keys(self, qkv_linear=True, is_gqa=True): - if qkv_linear: - keys_hf_to_nxd = { - "q_proj.weight": "qkv_proj.weight_q", - "k_proj.weight": "qkv_proj.weight_k", - "v_proj.weight": "qkv_proj.weight_v", - } - elif is_gqa: - keys_hf_to_nxd = { - "q_proj.weight": "q_proj.weight", - "k_proj.weight": "k_proj.weight", - "v_proj.weight": "v_proj.weight", - } - else: - keys_hf_to_nxd = { - "q_proj.weight": "qkv_proj.weight", - "k_proj.weight": "qkv_proj.weight", - "v_proj.weight": "qkv_proj.weight", - } - keys_nxd_to_hf = {v: k for k, v in keys_hf_to_nxd.items()} - return keys_hf_to_nxd, keys_nxd_to_hf - - def is_qkv_weight(self, name): - return "q_proj" in name or "k_proj" in name or "v_proj" in name or "qkv_proj" in name - - def coalesce_qkv(self, state_dict, config, tp_degree): - for i in range(config["num_hidden_layers"]): - q = state_dict.pop(f"model.layers.{i}.self_attn.q_proj.weight") - k = state_dict.pop(f"model.layers.{i}.self_attn.k_proj.weight") - v = state_dict.pop(f"model.layers.{i}.self_attn.v_proj.weight") - partition_size = config["hidden_size"] // tp_degree - tp_partititons = [] - for tp_rank in range(tp_degree): - q_split = q.narrow(0, tp_rank * partition_size, partition_size).detach().clone() - k_split = k.narrow(0, tp_rank * partition_size, partition_size).detach().clone() - v_split = v.narrow(0, tp_rank * partition_size, partition_size).detach().clone() - tp_partititons.append(torch.cat([q_split, k_split, v_split], dim=self.qkv_partition_dim)) - - state_dict[f"model.layers.{i}.self_attn.qkv_proj.weight"] = torch.cat(tp_partititons, dim=self.qkv_partition_dim) - - return state_dict - - def get_weight_key(self, keys_hf_to_nxd, keys_nxd_to_hf, name, hf_to_nxd): - if not self.is_qkv_weight(name): - return name - - keys = keys_hf_to_nxd if hf_to_nxd else keys_nxd_to_hf - return ".".join(name.split(".")[:-2]) + "." 
+ keys[".".join(name.split(".")[-2:])] - - # Helper function for convert_to_full_state() - def merge_tp_checkpoints(self, args): - full_state = {} - with open(args.config, "r") as f: - config = json.load(f) - q_heads = config["num_attention_heads"] - kv_heads = config["num_key_value_heads"] - head_dim = config["hidden_size"] // q_heads - is_gqa = q_heads != kv_heads - keys_hf_to_nxd, keys_nxd_to_hf = self.get_hf_to_nxd_model_keys(args.qkv_linear, is_gqa) - - for tp_rank in range(args.tp_size): - for pp_rank in range(args.pp_size): - if args.load_xser: - partial_state = self.load_partial_xser(args, tp_rank, pp_rank) - else: - partial_state = self.load_partial_no_xser(args, tp_rank, pp_rank) - if args.model_key is not None and args.model_key in partial_state: - partial_state = partial_state[args.model_key] - for name, param in partial_state.items(): - if (self.is_qkv_weight(name) or "o_proj" in name) and args.qkv_linear: - # qkv_proj would be a key if we are using the QKVLinear layer - partition_dim = self.get_partition_dim(name) - name = self.get_weight_key(keys_hf_to_nxd, keys_nxd_to_hf, name, False) - - if name not in full_state: - full_state[name] = [] - - full_state[name].append(param) - if tp_rank != (args.tp_size - 1): - continue - - full_weight = torch.cat(full_state[name], dim=partition_dim) - if "k" in name or "v" in name: - # If kv_multiplier is set, the kv heads are repeated. So we need to - # take only the first chunk - full_state[name] = torch.chunk(full_weight, args.kv_size_multiplier)[0].detach().clone() - else: - # Since we do the replication of KV heads, the Q heads are placed as: - # Q0Q1Q8Q9...Q2Q3Q10Q11... - # Hence when creating the merged checkpoint, we need to bring the Q heads and o_proj in order. - if "o_proj" in name: - # The shuffling is same for both o_proj and q, but o_proj is sharded on column. 
- # Hence to reuse the same shuffling code, we just transpose, do the shuffling and - # transpose back - full_weight = torch.transpose(full_weight, 0, 1) - weights = full_weight.reshape(q_heads, head_dim, -1) - weights_shape = weights.size() - weights = weights.reshape( - -1, q_heads // (kv_heads * args.kv_size_multiplier), head_dim, weights_shape[-1] - ) - weight_splits = [] - indicies = torch.arange(0, args.tp_size // kv_heads) * kv_heads - for i in range(kv_heads): - weight_splits.append(weights[indicies + i].reshape(-1, weights_shape[-1])) - full_weight = torch.cat(weight_splits, dim=self.qkv_partition_dim) - full_state[name] = ( - torch.transpose(full_weight, 0, 1).detach().clone() - if "o_proj" in name - else full_weight.detach().clone() - ) - elif "qkv_proj" in name and not is_gqa: - partition_dim = self.get_partition_dim(name) - partition_size = config["hidden_size"] // args.tp_size - q, k, v = torch.split(param, partition_size, dim=partition_dim) - q_name = name.replace("qkv", "q") - k_name = name.replace("qkv", "k") - v_name = name.replace("qkv", "v") - for name, weight in zip([q_name, k_name, v_name], [q, k, v]): - if name not in full_state: - full_state[name] = [] - full_state[name].append(weight) - if tp_rank == (args.tp_size - 1): - full_weight = torch.cat(full_state[name], dim=partition_dim) - full_state[name] = full_weight.detach().clone() - elif ( - "embed_tokens" in name - or self.is_qkv_weight(name) - or "o_proj" in name - or "down_proj" in name - or "lm_head" in name - ): - partition_dim = self.get_partition_dim(name) - name = self.get_weight_key(keys_hf_to_nxd, keys_nxd_to_hf, name, False) - if name not in full_state: - full_state[name] = [] - full_state[name].append(param) - if tp_rank == (args.tp_size - 1): - full_weight = torch.cat(full_state[name], dim=partition_dim) - full_state[name] = full_weight.detach().clone() - elif "gate_up_proj" in name: - partition_dim = self.get_partition_dim(name) - dim_size = param.size()[partition_dim] // 2 - gate_proj_name = name.replace("gate_up_proj", "gate_proj") - up_proj_name = name.replace("gate_up_proj", "up_proj") - gate_proj_weight = param.narrow(partition_dim, 0, dim_size).detach().clone() - up_proj_weight = param.narrow(partition_dim, dim_size, dim_size).detach().clone() - if gate_proj_name not in full_state: - full_state[gate_proj_name] = [] - if up_proj_name not in full_state: - full_state[up_proj_name] = [] - full_state[gate_proj_name].append(gate_proj_weight) - full_state[up_proj_name].append(up_proj_weight) - if tp_rank == (args.tp_size - 1): - full_gate_proj_weight = torch.cat(full_state[gate_proj_name], dim=partition_dim) - full_up_proj_weight = torch.cat(full_state[up_proj_name], dim=partition_dim) - full_state[gate_proj_name] = full_gate_proj_weight - full_state[up_proj_name] = full_up_proj_weight - else: - if name not in full_state: - full_state[name] = param - full_state = self.post_process_full_state_after_tp_conversion(full_state, args) - return full_state - - # Helper function for convert_from_full_state() - def convert_full_state_to_tp(self, full_state, args, tp_rank, pp_rank, partitions, config): - tp_size = args.tp_size - pp_size = args.pp_size - kv_size_multiplier = args.kv_size_multiplier - - partial_state = {} - q_heads = config["num_attention_heads"] - kv_heads = config["num_key_value_heads"] - head_dim = config["hidden_size"] // q_heads - - is_gqa = q_heads != kv_heads - keys_hf_to_nxd, keys_nxd_to_hf = self.get_hf_to_nxd_model_keys(args.qkv_linear, is_gqa) - - for name, full_p in full_state.items(): 
- ##################### PP Slice ######################################### - # Embedding only in first PP - if pp_rank != 0 and "embed_tokens" in name: - continue - # LMhead and final layer norm only in last PP rank - if pp_rank != pp_size - 1 and ("lm_head" in name or "model.norm.weight" in name): - continue - if "layers" in name: - layer_idx = int(name.split(".")[2]) - current_stage = len(partitions) - # iterate through the pp cuts and find the current stage - for stage, pp_cut in enumerate(partitions): - cut_layer_idx = int(pp_cut.split(".")[2]) - if layer_idx <= cut_layer_idx: - current_stage = stage - break - current_pp_rank = stage_to_pipeline_parallel_rank(current_stage, pp_size=pp_size) - if current_pp_rank != pp_rank: - continue - - ##################### TP Slice ######################################### - if (self.is_qkv_weight(name) or "o_proj" in name) and args.qkv_linear: - name = self.get_weight_key(keys_hf_to_nxd, keys_nxd_to_hf, name, True) - if "weight_k" in name or "weight_v" in name: - repeated_kv = full_p.repeat(kv_size_multiplier, 1) - - dim_size = repeated_kv.size()[0] - assert dim_size % tp_size == 0, "0th dim after KV replication is not divisible by tp_size" - partition_size = dim_size // tp_size - with torch.no_grad(): - to_load = repeated_kv.narrow(0, tp_rank * partition_size, partition_size).detach().clone() - # Cloning the tensor is really important, since we have performed slice and reshape operations. - # These operations are just views and if we don't clone, we would end up saving the entire tensor - partial_state[name] = to_load.detach().clone() - else: - # When GQAQKV linear with kv_multiplier is used, we need to reshuffle the order of Q heads - # so they interact with the right KV heads. Now since the heads are shuffled, we have to - # shuffle the o_proj rows since that translates the heads to hidden dim - if "o_proj" in name: - # The shuffling is same for both o_proj and q, but o_proj is sharded on column. - # Hence to reuse the same shuffling code, we just transpose, do the shuffling and - # transpose back - full_p = torch.transpose(full_p, 0, 1) - weights = full_p.reshape(q_heads, head_dim, -1) - weights_shape = weights.size() - weights = weights.reshape(-1, q_heads // (kv_heads * kv_size_multiplier), head_dim, weights_shape[-1]) - weight_splits = [] - indicies = torch.arange(0, kv_heads) * tp_size // kv_heads - for i in range(tp_size // kv_heads): - weight_splits.append(weights[indicies + i]) - weights = torch.cat(weight_splits, dim=self.qkv_partition_dim) - with torch.no_grad(): - to_load = weights[tp_rank].reshape(-1, weights_shape[-1]) - if "o_proj" in name: - to_load = torch.transpose(to_load, 0, 1) - # Cloning the tensor is really important, since we have performed slice and reshape operations. 
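The TP-slice branches in this (now relocated) converter all follow the same narrow-and-clone pattern: each tensor-parallel rank takes a contiguous slice along the layer's partition dimension, and merging concatenates the slices back. A minimal sketch with made-up sizes and a hypothetical `tp_size`, shown only to make that round trip explicit:

```
import torch

# Made-up sizes for illustration; partition_dim=0 mimics a ColumnParallelLinear weight.
tp_size = 4
full_weight = torch.randn(32, 16)
partition_dim = 0
dim_size = full_weight.size(partition_dim)
assert dim_size % tp_size == 0
partition_size = dim_size // tp_size

# convert_full_state_to_tp direction: slice out each rank's shard and clone it
# so only the shard (not a view of the full tensor) gets saved.
shards = [
    full_weight.narrow(partition_dim, rank * partition_size, partition_size).clone()
    for rank in range(tp_size)
]

# merge_tp_checkpoints direction: concatenating along the same dim restores the weight.
assert torch.equal(torch.cat(shards, dim=partition_dim), full_weight)
```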
- # These operations are just views and if we don't clone, we would end up saving the entire tensor - partial_state[name] = to_load.detach().clone() - elif ( - "embed_tokens" in name - or self.is_qkv_weight(name) - or "o_proj" in name - or "down_proj" in name - or "lm_head" in name - ): - partition_dim = self.get_partition_dim(name) - dim_size = full_p.size()[partition_dim] - assert dim_size % tp_size == 0, "vocab size is not divisiable" - partition_size = dim_size // tp_size - with torch.no_grad(): - to_load = full_p.narrow(partition_dim, tp_rank * partition_size, partition_size) - partial_state[name] = to_load.detach().clone() - elif "gate_proj" in name or "up_proj" in name: - partition_dim = self.get_partition_dim(name) - dim_size = full_p.size()[partition_dim] - assert dim_size % tp_size == 0, "vocab size is not divisiable" - partition_size = dim_size // tp_size - with torch.no_grad(): - to_load = full_p.narrow(partition_dim, tp_rank * partition_size, partition_size).detach().clone() - token = "gate_proj" if "gate_proj" in name else "up_proj" - updated_name = name.replace(token, "gate_up_proj") - if updated_name in partial_state: - if token == "gate_proj": - partial_state[updated_name] = ( - torch.cat([to_load, partial_state[updated_name]], dim=partition_dim).detach().clone() - ) - else: - partial_state[updated_name] = ( - torch.cat([partial_state[updated_name], to_load], dim=partition_dim).detach().clone() - ) - else: - partial_state[updated_name] = to_load.detach().clone() - else: - # no TP - partial_state[name] = full_p - return partial_state - - # Placeholder functions for additional processing of full_state - def pre_process_full_state_before_tp_conversion(self, full_state, args): - """Child classes can override this function to implement custom logic.""" - return full_state - - def post_process_full_state_after_tp_conversion(self, full_state, args): - """Child classes can override this function to implement custom logic.""" - return full_state - - # Helper functions for save/load - def load_full_state(self, args): - full_state = torch.load(args.input_dir) - return full_state - - def get_input_filename(self, args, tp_rank, pp_rank, xser): - if xser: - old_api_filename = os.path.join(args.input_dir, "tp_rank_{:02d}_pp_rank_{:02d}".format(tp_rank, pp_rank)) - else: - old_api_filename = os.path.join( - args.input_dir, "tp_rank_{:02d}_pp_rank_{:02d}".format(tp_rank, pp_rank), "checkpoint.pt" - ) - - new_api_filename = os.path.join( - args.input_dir, "dp_rank_00_tp_rank_{:02d}_pp_rank_{:02d}.pt".format(tp_rank, pp_rank) - ) - - if os.path.exists(old_api_filename): - return old_api_filename - - if os.path.exists(new_api_filename): - return new_api_filename - - raise RuntimeError(f"Error: neither {old_api_filename} nor {new_api_filename} exist") - - def get_output_filename(self, args, tp_rank, pp_rank, xser): - return os.path.join( - args.output_dir, "model", "dp_rank_00_tp_rank_{:02d}_pp_rank_{:02d}.pt".format(tp_rank, pp_rank) - ) - - def load_partial_xser(self, args, tp_rank, pp_rank): - filename = self.get_input_filename(args, tp_rank, pp_rank, 1) - partial_state = xser.load(filename) - return partial_state - - def load_partial_no_xser(self, args, tp_rank, pp_rank): - filename = self.get_input_filename(args, tp_rank, pp_rank, 0) - partial_state = torch.load(filename) - return partial_state - - def save_full(self, args, full_state): - save_path = args.output_dir - os.makedirs(save_path, exist_ok=True) - if os.path.isdir(save_path): - save_path = os.path.join(save_path, 
"checkpoint.pt") - print(f"Saving full checkpoint to {save_path}") - torch.save(full_state, save_path) - - def save_partial_xser(self, args, partial_state, tp_rank, pp_rank): - filename = self.get_output_filename(args, tp_rank, pp_rank, 1) - os.makedirs(args.output_dir + "/model", exist_ok=True) - print(f"Saving to {filename}") - xser.save(partial_state, filename) - - def save_partial_no_xser(self, args, partial_state, tp_rank, pp_rank): - filename = self.get_output_filename(args, tp_rank, pp_rank, 0) - os.makedirs(args.output_dir + "/model", exist_ok=True) - print(f"Saving to {filename}") - torch.save(partial_state, filename) - - # Main functions to run checkpoint conversion - def convert_from_xser(self, args): - for tp_rank in range(args.tp_size): - for pp_rank in range(args.pp_size): - partial_state = self.load_partial_xser(args, tp_rank, pp_rank) - self.save_partial_no_xser(args, partial_state, tp_rank, pp_rank) - - def convert_to_xser(self, args): - for tp_rank in range(args.tp_size): - for pp_rank in range(args.pp_size): - partial_state = self.load_partial_no_xser(args, tp_rank, pp_rank) - self.save_partial_xser(args, partial_state, tp_rank, pp_rank) - - def convert_from_full_state(self, args): - full_state = self.load_full_state(args) - layer_name_pattern = r"^(model\.layers\.\d+)" - model_layer_names = sorted( - list( - set( - [ - re.match(layer_name_pattern, key).group(1) - for key in full_state.keys() - if re.match(layer_name_pattern, key) - ] - ) - ), - key=lambda x: int(re.search(r"\d+", x).group()), - ) - partitions = create_partitions(args.pp_size * args.virtual_pp_size, model_layer_names) - print(f"pipeline_cuts {partitions}") - with open(args.config, "r") as f: - config = json.load(f) - if args.coalesce_qkv: - full_state = self.coalesce_qkv(full_state, config, args.tp_size) - - full_state = self.pre_process_full_state_before_tp_conversion(full_state, args) - - for tp_rank in range(args.tp_size): - for pp_rank in range(args.pp_size): - partial_state = self.convert_full_state_to_tp( - full_state, - args, - tp_rank, - pp_rank, - partitions, - config, - ) - if args.save_xser: - self.save_partial_xser(args, partial_state, tp_rank, pp_rank) - else: - self.save_partial_no_xser(args, partial_state, tp_rank, pp_rank) - - def convert_to_full_state(self, args): - full_state = self.merge_tp_checkpoints(args) - self.save_full(args, full_state) - - # Argument parsing and execution - def get_arg_parser(self): - """Child classes can override this to add new arguments.""" - - parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="Path to input model/weights") - parser.add_argument("--output_dir", type=str, required=True, help="Path to save converted model/weights") - parser.add_argument("--config", type=str, help="Config.json") - parser.add_argument( - "--model_key", type=str, default="model", help="Key of the model state dict in the checkpoint object" - ) - parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel degree for the model") - parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel degree for the model") - parser.add_argument("--virtual_pp_size", type=int, default=1, help="Virtual Pipeline Parallel degree for the model") - parser.add_argument("--n_layers", type=int, default=0, help="Number of Layers") - parser.add_argument("--coalesce_qkv", type=bool, default=False, help="whether to coalesce qkv") - parser.add_argument( - "--kv_size_multiplier", type=int, default=1, help="Factor by which the 
KV heads were replicated" - ) - parser.add_argument( - "--qkv_linear", type=bool, default=False, help="Factor by which the KV heads were replicated" - ) - parser.add_argument("--load_xser", type=bool, default=False, help="Load from xser saved checkpoints") - parser.add_argument("--save_xser", type=bool, default=False, help="Save with xser") - parser.add_argument( - "--convert_from_xser", action="store_true", help="Convert xser saved checkpoint to normal torch checkpoint" - ) - parser.add_argument( - "--convert_to_xser", action="store_true", help="Convert normal torch checkpoint to xser checkpoint" - ) - parser.add_argument("--convert_from_full_state", action="store_true", help="Convert full model to sharded model") - parser.add_argument("--convert_to_full_state", action="store_true", help="Convert sharded model to full model") - return parser - - def run(self, args): - """Main function used to run checkpoint conversion.""" - - assert sum( - int(getattr(args, flag)) - for flag in ["convert_from_full_state", "convert_to_full_state", "convert_from_xser", "convert_to_xser"] - ) == 1, "Exactly one '--convert_*' flag must be specified" - - if args.convert_from_full_state: - self.convert_from_full_state(args) - elif args.convert_to_full_state: - self.convert_to_full_state(args) - elif args.convert_from_xser: - self.convert_from_xser(args) - elif args.convert_to_xser: - self.convert_to_xser(args) diff --git a/examples/training/codegen25/config.json b/examples/training/codegen25/config.json index 4b32bca..fb1bdba 100644 --- a/examples/training/codegen25/config.json +++ b/examples/training/codegen25/config.json @@ -25,4 +25,4 @@ "sequence_parallel_enabled": false, "selective_checkpoint_enabled": false, "move_model_to_device": true -} \ No newline at end of file +} diff --git a/examples/training/codegen25/get_dataset_infill.py b/examples/training/codegen25/get_dataset_infill.py index a0299ee..1a25ba0 100644 --- a/examples/training/codegen25/get_dataset_infill.py +++ b/examples/training/codegen25/get_dataset_infill.py @@ -7,6 +7,7 @@ from datasets import load_dataset from transformers import AutoTokenizer + def get_args(): parser = argparse.ArgumentParser() group = parser.add_argument_group(title='input data') @@ -30,6 +31,7 @@ def get_args(): args = parser.parse_args() return args + def sample_num_spans(rng, max_num_spans=16, num_samples=1): # Choose at leat 1 span num_spans = rng.poisson(1.0, size=(num_samples)) + 1 @@ -39,10 +41,10 @@ def sample_num_spans(rng, max_num_spans=16, num_samples=1): def format_to_infill(tokens, num_spans, tokenizer, rng): # Choose at leat 1 span - #num_spans = rng.poisson(1.0) + 1 - #num_spans = np.min([num_spans, max_num_spans]) + # num_spans = rng.poisson(1.0) + 1 + # num_spans = np.min([num_spans, max_num_spans]) - # Based on the number of spans, a span can + # Based on the number of spans, a span can # have a size of maximum (len(text) - num_spans) // num_spans # This assumes we want at least one letter between two spans. 
max_len = (len(tokens) - num_spans) // num_spans @@ -52,7 +54,7 @@ def format_to_infill(tokens, num_spans, tokenizer, rng): return None # Set first start and end index based `max_len` - max_start_idx = 1 # Very first letter will be skipped + max_start_idx = 1 # Very first letter will be skipped max_end_idx = max_len prefix_tokens = [] @@ -60,12 +62,7 @@ def format_to_infill(tokens, num_spans, tokenizer, rng): for span_idx in range(num_spans): # Randomly sample the length of the span - try: - sampled_length = rng.integers(1, max_len + 1) - except: - print("Error") - print(len(tokens), num_spans, max_len) - import pdb; pdb.set_trace() + sampled_length = rng.integers(1, max_len + 1) # Define low and high to sample a start position. # The first `+ 1` is due to `.integers` expecting a value @@ -83,7 +80,7 @@ def format_to_infill(tokens, num_spans, tokenizer, rng): suffix_tokens += tokenizer.encode(f'') + span + tokenizer.encode('') # Update start and end indices - max_start_idx = chosen_start + sampled_length + 1 + max_start_idx = chosen_start + sampled_length + 1 max_end_idx = chosen_start + sampled_length + max_len # Append leftover to prefix string @@ -92,6 +89,7 @@ def format_to_infill(tokens, num_spans, tokenizer, rng): return prefix_tokens + tokenizer.encode('<|endoftext|>') + suffix_tokens + def main(args): block_size = args.block_size tokenizer_path = os.path.expanduser(args.tokenizer_model) @@ -193,6 +191,7 @@ def group_texts(examples): test_dataset.save_to_disk(test_save_path) valid_dataset.save_to_disk(valid_save_path) + if __name__ == '__main__': args = get_args() - main(args) \ No newline at end of file + main(args) diff --git a/examples/training/codegen25/get_dataset_infill.sh b/examples/training/codegen25/get_dataset_infill.sh index f324fe0..2e1664f 100644 --- a/examples/training/codegen25/get_dataset_infill.sh +++ b/examples/training/codegen25/get_dataset_infill.sh @@ -7,4 +7,4 @@ python get_dataset_infill.py \ --json-keys=content \ --tokenizer-model=$TOKENIZER \ --output_path="~/example_datasets/bigcode-stack-java_tokenized_infill" \ - --block_size=2048 \ No newline at end of file + --block_size=2048 diff --git a/examples/training/codegen25/tp_zero1_codegen25_7b_hf_pretrain.sh b/examples/training/codegen25/tp_zero1_codegen25_7b_hf_pretrain.sh index 8460df5..d3f3ceb 100644 --- a/examples/training/codegen25/tp_zero1_codegen25_7b_hf_pretrain.sh +++ b/examples/training/codegen25/tp_zero1_codegen25_7b_hf_pretrain.sh @@ -121,4 +121,4 @@ torchrun $DISTRIBUTED_ARGS \ --seq_len $SEQ_LEN \ --sequence_parallel_enabled \ --selective_checkpoint_enabled \ - $EXTRA_ARGS |& tee $OUTPUT_LOG \ No newline at end of file + $EXTRA_ARGS |& tee $OUTPUT_LOG diff --git a/examples/training/llama/convert_checkpoints.py b/examples/training/llama/convert_checkpoints.py index 5995ca2..838aa4c 100644 --- a/examples/training/llama/convert_checkpoints.py +++ b/examples/training/llama/convert_checkpoints.py @@ -1,6 +1,6 @@ import argparse -from checkpoint_converter import CheckpointConverterBase +from neuronx_distributed.scripts.checkpoint_converter import CheckpointConverterBase class CheckpointConverterLlama(CheckpointConverterBase): diff --git a/examples/training/llama/get_dataset.py b/examples/training/llama/get_dataset.py index 45d43b4..d46b5d8 100644 --- a/examples/training/llama/get_dataset.py +++ b/examples/training/llama/get_dataset.py @@ -36,7 +36,7 @@ def tokenize_function(examples): return tokenizer(examples[text_column_name]) - + tokenized_datasets = raw_datasets.map( tokenize_function, diff --git 
a/examples/training/llama/lightning/module_llama.py b/examples/training/llama/lightning/module_llama.py index 9c5a9e1..c7e5796 100644 --- a/examples/training/llama/lightning/module_llama.py +++ b/examples/training/llama/lightning/module_llama.py @@ -69,7 +69,7 @@ def get_model(model_config): print("Failed to import optimum-neuron dependency, generation will not work on Neuron.") # Load Pretrained checkpoint if hasattr(self.model_args[0], "pretrained_ckpt") and self.model_args[0].pretrained_ckpt: - user_content = nxd.load_checkpoint( + nxd.load_checkpoint( self.model_args[0].pretrained_ckpt, tag="pretrained_weight", model=self.model, diff --git a/examples/training/llama/lightning/run_llama_7b_tp_ptl.sh b/examples/training/llama/lightning/run_llama_7b_tp_ptl.sh index d0a5a42..d24a9a4 100755 --- a/examples/training/llama/lightning/run_llama_7b_tp_ptl.sh +++ b/examples/training/llama/lightning/run_llama_7b_tp_ptl.sh @@ -106,7 +106,7 @@ torchrun $DISTRIBUTED_ARGS \ --data_dir $DATA_PATH \ --tensor_parallel_size $TP_DEGREE \ --train_batch_size $MBS \ - --steps_this_run $STEPS_THIS_RUN\ + --steps_this_run $STEPS_THIS_RUN \ --max_steps $TOTAL_STEPS \ --warmup_steps $WARMUP_STEPS \ --lr $LR \ diff --git a/examples/training/llama/lightning/run_llama_nxd_ptl.py b/examples/training/llama/lightning/run_llama_nxd_ptl.py index 3b1995b..c801044 100644 --- a/examples/training/llama/lightning/run_llama_nxd_ptl.py +++ b/examples/training/llama/lightning/run_llama_nxd_ptl.py @@ -21,6 +21,24 @@ import torch import torch_xla.core.xla_model as xm +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from transformers import LlamaConfig, set_seed +# Workaround for NaNs seen with transformers version >= 4.21.0 +# https://github.com/aws-neuron/aws-neuron-sdk/issues/593 +import transformers.modeling_utils as modeling_utils +from transformers.optimization import get_linear_schedule_with_warmup + +import neuronx_distributed as nxd +from neuronx_distributed.lightning import ( + NeuronTensorBoardLogger, + NeuronTQDMProgressBar, + NeuronXLAPrecisionPlugin, + NeuronXLAStrategy, +) +from neuronx_distributed.parallel_layers import mappings +from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams + from data_module import NeuronLightningDataModule from modeling_llama_nxd import ( CoreAttention, @@ -30,8 +48,6 @@ init_weights, ) from module_llama import NeuronLlamaLTModule -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint from training_utils import ( create_llama_pretraining_dataset, get_learning_rate_scheduler, @@ -39,26 +55,10 @@ get_param_groups_by_weight_decay, get_sin_cos_matrix, ) -from transformers import LlamaConfig, set_seed -from transformers.optimization import get_linear_schedule_with_warmup - -import neuronx_distributed as nxd -from neuronx_distributed.lightning import ( - NeuronTensorBoardLogger, - NeuronTQDMProgressBar, - NeuronXLAPrecisionPlugin, - NeuronXLAStrategy, -) -from neuronx_distributed.parallel_layers import mappings -from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams # For PT autocast. 
torch.cuda.is_bf16_supported = lambda: True -# Workaround for NaNs seen with transformers version >= 4.21.0 -# https://github.com/aws-neuron/aws-neuron-sdk/issues/593 -import transformers.modeling_utils as modeling_utils - if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"): modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16 @@ -103,6 +103,7 @@ def train_llama(args): "use_zero1_optimizer": args.use_zero1_optimizer > 0, "use_optimizer_wrapper": True, "broadcast_and_average_loss": args.log_rank0 > 0, + "fuse_microbatches": args.fuse_microbatches > 0, } # Create model with different options @@ -223,7 +224,7 @@ def configure_scheduler(optimizer, warmup_steps, max_steps): # PTLTODO: check l else: trainer.fit(model=model, datamodule=dm) - print(f"Training finished!") + print("Training finished!") def _mp_fn(index, args): @@ -376,10 +377,16 @@ def _mp_fn(index, args): ) parser.add_argument( "--use_gpu_compatible_precision", - default=0, + default=1, type=int, help="Use gpu compatible precision", ) + parser.add_argument( + "--fuse_microbatches", + type=int, + default=0, + help="Fuse microbatches into a single graph" + ) args = parser.parse_args(sys.argv[1:]) diff --git a/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.py b/examples/training/llama/lightning/tp_llama_hf_finetune_ptl.py similarity index 94% rename from examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.py rename to examples/training/llama/lightning/tp_llama_hf_finetune_ptl.py index a04a164..70b3465 100644 --- a/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.py +++ b/examples/training/llama/lightning/tp_llama_hf_finetune_ptl.py @@ -21,11 +21,6 @@ import sys import numpy as np - -current = os.path.dirname(os.path.realpath(__file__)) -parent = os.path.dirname(current) -sys.path.append(parent) - import torch import torch_xla.core.xla_model as xm from data_module import NeuronLightningDataModule @@ -35,9 +30,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint from torchmetrics.text.rouge import ROUGEScore from training_utils import create_instruction_based_dataset, get_mixed_precision_config -from transformers import AdamW, LlamaConfig, LlamaTokenizer, set_seed +from transformers import AdamW, LlamaConfig, LlamaTokenizer, set_seed, AutoTokenizer from transformers.optimization import get_linear_schedule_with_warmup - import neuronx_distributed as nxd from neuronx_distributed.lightning import ( NeuronTensorBoardLogger, @@ -49,13 +43,13 @@ from neuronx_distributed.parallel_layers import parallel_state from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams -# For PT autocast. -torch.cuda.is_bf16_supported = lambda: True - # Workaround for NaNs seen with transformers version >= 4.21.0 # https://github.com/aws-neuron/aws-neuron-sdk/issues/593 import transformers.modeling_utils as modeling_utils +# For PT autocast. 
+torch.cuda.is_bf16_supported = lambda: True + if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"): modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16 @@ -64,6 +58,7 @@ def train_llama(flags): print(f"Namespace: {flags}") set_seed(flags.seed) + target_modules = ["q_proj", "v_proj", "k_proj"] if flags.qkv_linear == 0 else ["qkv_proj"] lora_config = LoraConfig( enable_lora=flags.enable_lora, lora_rank=16, @@ -71,7 +66,7 @@ def train_llama(flags): lora_dropout=0.05, bias="none", lora_verbose=True, - target_modules=["q_proj", "v_proj", "k_proj"], + target_modules=target_modules, ) mixed_precision_config = get_mixed_precision_config(flags.use_gpu_compatible_precision > 0) @@ -89,8 +84,8 @@ def train_llama(flags): model_config.pretrained_ckpt = flags.pretrained_ckpt model_config.use_cache = False model_config.separate_qkv = flags.separate_qkv - model_config.kv_shared_group_size = args.kv_replicator - model_config.qkv_linear = args.qkv_linear + model_config.kv_shared_group_size = flags.kv_replicator + model_config.qkv_linear = flags.qkv_linear model_config.max_position_embeddings = max(model_config.max_position_embeddings, flags.seq_len) if flags.num_layers > 0: model_config.num_hidden_layers = flags.num_layers @@ -113,9 +108,7 @@ def configure_scheduler(optimizer, warmup_steps, max_steps): # PTLTODO: check l last_epoch=-1, ) - BASE_MODEL = "NousResearch/Llama-2-7b-hf" - tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL) - + tokenizer = AutoTokenizer.from_pretrained(flags.model_name, token=flags.hf_token) model = NeuronLlamaLTModule( tokenizer=tokenizer, model_fn=LlamaForCausalLM, @@ -181,7 +174,7 @@ def configure_scheduler(optimizer, warmup_steps, max_steps): # PTLTODO: check l else: trainer.fit(model=model, datamodule=dm) - xm.master_print(f"Training finished!") + xm.master_print("Training finished!") if flags.do_eval and not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None): evaluate(model, tokenizer, dm.test_dataloader(), args.golden_rouge_score_path) @@ -215,7 +208,7 @@ def evaluate(model, tokenizer, test_loader, golden_rouge_score_path): label_text = tokenizer.decode(labels[0].cpu(), clean_up_tokenization_spaces=True) rouge.update(predicted_text, label_text) if parallel_state.get_tensor_model_parallel_rank() == 0: - print(f"=== PROMPT ===") + print("=== PROMPT ===") print(prompt) print("=== GENERATED SEQUENCE ===") print(predicted_text) @@ -254,6 +247,18 @@ def _mp_fn(index, flags): type=str, help="Model weight and config path.", ) + parser.add_argument( + "--model_name", + type=str, + default='meta-llama/Meta-Llama-3-8B', + help="Base model name.", + ) + parser.add_argument( + "--hf_token", + type=str, + default=None, + help="Huggingface token to access base model and tokenizer.", + ) parser.add_argument( "--data_dir", type=str, @@ -379,7 +384,7 @@ def _mp_fn(index, flags): ) parser.add_argument( "--use_gpu_compatible_precision", - default=0, + default=1, type=int, help="Use gpu compatible precision", ) diff --git a/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh b/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh index 5c7005a..8ec71ca 100755 --- a/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh +++ b/examples/training/llama/lightning/tp_zero1_llama2_7b_hf_finetune_ptl.sh @@ -37,6 +37,8 @@ WARMUP_STEPS=20 LR=5.0e-5 # model path MODEL_PATH=$SCRIPT_DIR +# base model name +BASE_MODEL="NousResearch/Llama-2-7b-hf" # sequence length SEQ_LEN=4096 # Path to dataset @@ -116,8 +118,9 @@ echo 
STEPS_THIS_RUN=$STEPS_THIS_RUN echo OUTPUT_LOG=$OUTPUT_LOG torchrun $DISTRIBUTED_ARGS \ - tp_zero1_llama2_7b_hf_finetune_ptl.py \ + tp_llama_hf_finetune_ptl.py \ --model_path $MODEL_PATH \ + --model_name $BASE_MODEL \ --data_dir $DATA_PATH \ --task "open_qa" \ --tensor_parallel_size $TP_DEGREE \ diff --git a/examples/training/llama/modeling_llama_nxd.py b/examples/training/llama/modeling_llama_nxd.py index bbbf499..f1ec2f6 100644 --- a/examples/training/llama/modeling_llama_nxd.py +++ b/examples/training/llama/modeling_llama_nxd.py @@ -764,19 +764,11 @@ def init_weights(module): Re-init weights after partition Referred from HF transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L690 """ - if isinstance(module, torch.nn.Linear): - module.weight.data.normal_(mean=0.0, std=config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, torch.nn.Embedding): - module.weight.data.normal_(mean=0.0, std=config.initializer_range) - if module.padding_idx: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, LlamaRMSNorm): + if isinstance(module, LlamaRMSNorm): module.weight.data.fill_(1.0) elif isinstance(module, (ParallelEmbedding, RowParallelLinear, ColumnParallelLinear)): module.init_weight_cpu() if hasattr(module, "bias") and module.bias is not None: module.bias.data.zero_() elif isinstance(module, GQAQKVColumnParallelLinear): - module.initialize_weight_biases() \ No newline at end of file + module.initialize_weight_biases() diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama2/config.json b/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama2/config.json index d096b4a..7b1797e 100644 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama2/config.json +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama2/config.json @@ -26,4 +26,3 @@ "selective_checkpoint_enabled": false, "move_model_to_device":false } - diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3.1/config.json b/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3.1/config.json index 1c76c52..9546f1c 100644 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3.1/config.json +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3.1/config.json @@ -1,32 +1,32 @@ { - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "hidden_act": "silu", - "hidden_size": 8192, - "initializer_range": 0.02, - "intermediate_size": 28672, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 64, - "num_hidden_layers": 80, - "num_key_value_heads": 8, - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.31.0", - "use_cache": true, - "vocab_size": 128256, - "sequence_parallel_enabled": false, - "selective_checkpoint_enabled": false, - "move_model_to_device":false - } \ No newline at end of file + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 
131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.31.0", + "use_cache": true, + "vocab_size": 128256, + "sequence_parallel_enabled": false, + "selective_checkpoint_enabled": false, + "move_model_to_device":false +} diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3/config.json b/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3/config.json index 86c4863..cbfd6ab 100644 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3/config.json +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/70B_config_llama3/config.json @@ -28,4 +28,3 @@ "selective_checkpoint_enabled": false, "move_model_to_device":false } - \ No newline at end of file diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/activation_checkpoint.py b/examples/training/llama/tp_pp_llama_hf_pretrain/activation_checkpoint.py index 641db47..cac319c 100644 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/activation_checkpoint.py +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/activation_checkpoint.py @@ -93,7 +93,7 @@ def apply_checkpoint(dist_model, layers_to_checkpoint=None): # checkpoint layers that are provided in input # if layers not provide in input, then checkpoint if it is transformer layer if (layers_to_checkpoint and name in layers_to_checkpoint) or ( - not layers_to_checkpoint and type(module) == dist_model.transformer_layer_cls + not layers_to_checkpoint and type(module) is dist_model.transformer_layer_cls ): # add_module replaces old module with our own custom module. # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module @@ -104,6 +104,8 @@ def apply_checkpoint(dist_model, layers_to_checkpoint=None): elif layers_to_checkpoint is None and not checkpoint_wrapper_added: logger.warning( rmsg( - f"During applying activation checkpointing, transformer_layer_cls {dist_model.transformer_layer_cls.__name__} can not be found in stage {dist_model.pipeline_parallel_rank}, skipping..." + f"During applying activation checkpointing, transformer_layer_cls " + f"{dist_model.transformer_layer_cls.__name__} cannot be found in stage " + f"{dist_model.pipeline_parallel_rank}, skipping ..." 
) ) diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_13B_tp_pp.sh b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_13B_tp_pp.sh index 130d72e..7dd8172 100644 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_13B_tp_pp.sh +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_13B_tp_pp.sh @@ -59,7 +59,7 @@ echo $DISTRIBUTED_ARGS # Input sequence length SEQ_LEN=4096 # Pipeline parallel degree -PP_DEGREE=4 +: ${PP_DEGREE:=4} # Tensor parallel degree TP_DEGREE=8 # : paralell size diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_70B_tp_pp.sh b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_70B_tp_pp.sh index 8e6b62d..890b7d7 100755 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_70B_tp_pp.sh +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama2_70B_tp_pp.sh @@ -66,7 +66,7 @@ echo $DISTRIBUTED_ARGS # Input sequence length SEQ_LEN=4096 # Pipeline parallel degree -PP_DEGREE=4 +: ${PP_DEGREE:=8} # Virtual pipeline degree : ${VPP_DEGREE:=1} # Tensor parallel degree @@ -78,6 +78,10 @@ BS=$(($GBS / $DP)) # Number microbatches for pipeline execution # Setting same as BS so each microbatch contains a single datasample NUM_MICROBATCHES=$BS + +# Turn on the GPU compatible precision by default +: ${GPU_COMPATIBLE_PRECISION:=1} + DATA_PATH="$HOME/examples_datasets/wikicorpus_llama${LLAMA_VERSION}_tokenized_4k" @@ -117,5 +121,6 @@ torchrun $DISTRIBUTED_ARGS run_llama_nxd.py \ --use_selective_checkpoint 1 \ --qkv_linear 1 \ --kv_replicator 4 \ + --use_gpu_compatible_precision $GPU_COMPATIBLE_PRECISION \ --tb_dir $tb_dir |& tee $LOG_PATH/log exit ${PIPESTATUS[0]} diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama31_70B_tp_pp.sh b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama31_70B_tp_pp.sh new file mode 100755 index 0000000..e7e5737 --- /dev/null +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama31_70B_tp_pp.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +LLAMA_CONFIG_VERISON="3.1" exec ./run_llama3_70B_tp_pp.sh diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama3_70B_tp_pp.sh b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama3_70B_tp_pp.sh index fa3634d..f7b0df6 100755 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama3_70B_tp_pp.sh +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama3_70B_tp_pp.sh @@ -6,6 +6,7 @@ sudo sysctl -w net.ipv4.ip_local_reserved_ports=44000 MODEL_SIZE="70B" LLAMA_VERSION='3' +: ${LLAMA_CONFIG_VERISON:=3} export FI_EFA_USE_DEVICE_RDMA=1 export FI_PROVIDER=efa @@ -87,7 +88,7 @@ torchrun $DISTRIBUTED_ARGS run_llama_nxd.py \ --train_batch_size $BS \ --use_meta_device_init 1 \ --training_dir $DATA_PATH \ - --training_config $SCRIPT_DIR/${MODEL_SIZE}_config_llama${LLAMA_VERSION} \ + --training_config $SCRIPT_DIR/${MODEL_SIZE}_config_llama${LLAMA_CONFIG_VERISON} \ --max_steps $max_steps \ --seq_len $SEQ_LEN \ --pipeline_parallel_size $PP_DEGREE \ @@ -113,4 +114,3 @@ torchrun $DISTRIBUTED_ARGS run_llama_nxd.py \ --loading_step -1 \ --tb_dir $tb_dir |& tee $LOG_PATH/log exit ${PIPESTATUS[0]} - diff --git a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama_nxd.py b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama_nxd.py index 71d9e3f..b46804e 100644 --- a/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama_nxd.py +++ b/examples/training/llama/tp_pp_llama_hf_pretrain/run_llama_nxd.py @@ -134,6 +134,7 @@ def train_llama(args): "use_zero1_optimizer": 
args.use_zero1_optimizer > 0, "use_optimizer_wrapper": True, "deallocate_pipeline_outputs": args.deallocate_pipeline_outputs > 0, + "fuse_microbatches": args.fuse_microbatches > 0, }, optimizer_config={ "zero_one_enabled": args.use_zero1_optimizer > 0, @@ -167,8 +168,8 @@ def get_model(config): # Create NxD model model = nxd.initialize_parallel_model(nxd_config, get_model, config) + world_size = parallel_state.get_data_parallel_size() - # model_dtype = get_dtype(model) param_groups = get_param_groups_by_weight_decay(model) @@ -283,11 +284,9 @@ def get_loading_tag(args): attention_mask=attention_mask, labels=labels, ) - total_steps += 1 optimizer.step() global_norm = optimizer.grad_norm # Global norm before clipping optimizer.zero_grad() - lr_scheduler.step() if should_print: if total_steps % args.logging_interval == 0: xm.add_step_closure( @@ -302,6 +301,8 @@ def get_loading_tag(args): start, ), ) + total_steps += 1 + lr_scheduler.step() xm.mark_step() # Saving checkpoints if (args.checkpoint_freq > 0) and (total_steps % args.checkpoint_freq == 0): @@ -416,11 +417,17 @@ def _mp_fn(index, args): ) parser.add_argument( "--use_gpu_compatible_precision", - default=0, + default=1, type=int, help="Use gpu compatible precision", ) parser.add_argument("--metrics_file", type=str, default="results.json", help="training metrics results file") + parser.add_argument( + "--fuse_microbatches", + type=int, + default=0, + help="Fuse microbatches into a single graph" + ) # optimization opt_grp = parser.add_argument_group(title="optimization", description="arguments for optimization") opt_grp.add_argument("--weight_decay", default=0.01, type=float, help="weight decay") diff --git a/examples/training/llama/tp_zero1_llama_hf_pretrain/7B_config_llama2/config.json b/examples/training/llama/tp_zero1_llama_hf_pretrain/7B_config_llama2/config.json index 621797d..d4458ce 100644 --- a/examples/training/llama/tp_zero1_llama_hf_pretrain/7B_config_llama2/config.json +++ b/examples/training/llama/tp_zero1_llama_hf_pretrain/7B_config_llama2/config.json @@ -27,4 +27,3 @@ "selective_checkpoint_enabled": false, "move_model_to_device":true } - diff --git a/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3.1/config.json b/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3.1/config.json index ecb7d77..535f383 100644 --- a/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3.1/config.json +++ b/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3.1/config.json @@ -1,33 +1,32 @@ - { - "architectures": [ - "LlamaForCausalLM" - ], - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 14336, - "max_position_embeddings": 131072, - "mlp_bias": false, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 8, - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.31.0", - "use_cache": true, - "vocab_size": 128256, - "sequence_parallel_enabled": false, - "selective_checkpoint_enabled": false, - "move_model_to_device":false - } \ No newline at end of file + "architectures": [ + "LlamaForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": 
"silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 131072, + "mlp_bias": false, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": null, + "rope_theta": 500000.0, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.31.0", + "use_cache": true, + "vocab_size": 128256, + "sequence_parallel_enabled": false, + "selective_checkpoint_enabled": false, + "move_model_to_device":false +} diff --git a/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3/config.json b/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3/config.json index 807ca1a..43556c8 100644 --- a/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3/config.json +++ b/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3/config.json @@ -28,4 +28,4 @@ "sequence_parallel_enabled": false, "selective_checkpoint_enabled": false, "move_model_to_device":false -} \ No newline at end of file +} diff --git a/examples/training/llama/tp_zero1_llama_hf_pretrain/logger.py b/examples/training/llama/tp_zero1_llama_hf_pretrain/logger.py index 9f13900..4a120bb 100644 --- a/examples/training/llama/tp_zero1_llama_hf_pretrain/logger.py +++ b/examples/training/llama/tp_zero1_llama_hf_pretrain/logger.py @@ -52,7 +52,7 @@ def get_instance_type(self): headers={"X-aws-ec2-metadata-token": token.text}, ) return data.text - except: + except Exception: return os.environ.get("HOSTNAME", "unknown") def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): @@ -60,7 +60,8 @@ def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None) grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" print( f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " - f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"learning_rate : {learning_rate:.2e} " + f"throughput : {throughput:.2f} seq/s " f"{grad_norm_msg}", flush=True, ) diff --git a/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama2_7B_hf_pretrain.sh b/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama2_7B_hf_pretrain.sh index 80c96ec..3a4b085 100644 --- a/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama2_7B_hf_pretrain.sh +++ b/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama2_7B_hf_pretrain.sh @@ -32,7 +32,7 @@ TOTAL_STEPS=10000 # warmup steps WARMUP_STEPS=100 # learning rate -LR=3.0e-4 +LR=1.5e-4 # model path MODEL_PATH=$SCRIPT_DIR/${MODEL_SIZE}_config_llama${LLAMA_VERSION} # data path @@ -52,7 +52,7 @@ if [ ! -z "$SLURM_NTASKS" ]; then MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`) DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000" if [ $NODE_ID -eq 0 ]; then - echo "WORLD_SIZE=$WORLD_SIZE" + echo "WORLD_SLURM_NTASKS=$WORLD_SIZE" echo "NODE_ID=$NODE_ID" echo "MASTER_ADDRESS=$MASTER_ADDRESS" echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS" @@ -61,7 +61,7 @@ if [ ! 
-z "$SLURM_NTASKS" ]; then export FI_PROVIDER=efa fi -echo "WORLD_SIZE=$WORLD_SIZE" +echo "WORLD_SLURM_NTASKS=$WORLD_SIZE" echo "NODE_ID=$NODE_ID" echo "MASTER_ADDRESS=$MASTER_ADDRESS" @@ -130,5 +130,6 @@ torchrun $DISTRIBUTED_ARGS \ --sequence_parallel_enabled \ --selective_checkpoint_enabled \ --logging_interval 10 \ + --qkv_linear \ $EXTRA_ARGS |& tee $OUTPUT_LOG exit ${PIPESTATUS[0]} diff --git a/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama3_8B_hf_pretrain.sh b/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama3_8B_hf_pretrain.sh index 2be0234..92274a0 100644 --- a/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama3_8B_hf_pretrain.sh +++ b/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama3_8B_hf_pretrain.sh @@ -52,7 +52,7 @@ if [ ! -z "$SLURM_NTASKS" ]; then MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`) DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000" if [ $NODE_ID -eq 0 ]; then - echo "WORLD_SIZE=$WORLD_SIZE" + echo "WORLD_SLURM_NTASKS=$WORLD_SIZE" echo "NODE_ID=$NODE_ID" echo "MASTER_ADDRESS=$MASTER_ADDRESS" echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS" @@ -61,7 +61,7 @@ if [ ! -z "$SLURM_NTASKS" ]; then export FI_PROVIDER=efa fi -echo "WORLD_SIZE=$WORLD_SIZE" +echo "WORLD_SLURM_NTASKS=$WORLD_SIZE" echo "NODE_ID=$NODE_ID" echo "MASTER_ADDRESS=$MASTER_ADDRESS" @@ -129,9 +129,9 @@ torchrun $DISTRIBUTED_ARGS \ --seq_len $SEQ_LEN \ --sequence_parallel_enabled \ --selective_checkpoint_enabled \ - --logging_interval 10 \ + --logging_interval 1 \ --qkv_linear \ --kv_replicator 4 \ --use_flash_attention 1 \ $EXTRA_ARGS |& tee $OUTPUT_LOG -exit ${PIPESTATUS[0]} \ No newline at end of file +exit ${PIPESTATUS[0]} diff --git a/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama_hf_pretrain.py b/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama_hf_pretrain.py index 37b743a..6d2c3ee 100644 --- a/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama_hf_pretrain.py +++ b/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama_hf_pretrain.py @@ -31,28 +31,23 @@ import torch_xla.distributed.parallel_loader as pl import torch_xla.distributed.xla_multiprocessing as xmp from logger import Logger -from modeling_llama_nxd import CoreAttention, LlamaForCausalLM -from training_utils import Throughput, create_llama_pretraining_dataset, get_mixed_precision_config +from modeling_llama_nxd import CoreAttention, LlamaForCausalLM, init_weights +from training_utils import Throughput, create_llama_pretraining_dataset, get_mixed_precision_config, get_sin_cos_matrix from transformers import AdamW, LlamaConfig, set_seed from transformers.optimization import get_linear_schedule_with_warmup import neuronx_distributed as nxd -from neuronx_distributed.parallel_layers import ( - checkpointing, - grads, - layers, - parallel_state, -) +from neuronx_distributed.parallel_layers import parallel_state from neuronx_distributed.parallel_layers.utils import requires_init_pg_override from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams -# For PT autocast. -torch.cuda.is_bf16_supported = lambda: True - # Workaround for NaNs seen with transformers version >= 4.21.0 # https://github.com/aws-neuron/aws-neuron-sdk/issues/593 import transformers.modeling_utils as modeling_utils +# For PT autocast. 
+torch.cuda.is_bf16_supported = lambda: True + if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"): modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16 @@ -153,16 +148,6 @@ def get_model(flags): xm.master_print(config) model = LlamaForCausalLM(config) - def get_sin_cos_matrix(config): - head_dim = config.hidden_size // config.num_attention_heads - base = 10000 - inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) - t = torch.arange(config.max_position_embeddings, dtype=inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - return emb.cos()[None, None, :, :].to(torch.float32), emb.sin()[None, None, :, :].to(torch.float32) - # Here we make sure we use the same sine and cosine matrices for all layers. # Making use of same tensors would make the CSE algorithm eliminate the lookup call # from layers, keeping only lookup from first layer. @@ -642,7 +627,7 @@ def _mp_fn(index, flags): ) parser.add_argument( "--use_gpu_compatible_precision", - default=0, + default=1, type=int, help="Use gpu compatible precision", ) diff --git a/examples/training/llama/training_utils.py b/examples/training/llama/training_utils.py index 4a67723..c2095f1 100644 --- a/examples/training/llama/training_utils.py +++ b/examples/training/llama/training_utils.py @@ -76,7 +76,7 @@ def get_learning_rate_scheduler(optimizer, args, last_epoch=-1): def get_param_groups_by_weight_decay(model): """Get param groups.""" - if hasattr(model, "local_named_parameters"): + if hasattr(model, "local_named_parameters") and hasattr(model, "partitioned") and model.partitioned: # Zero1 use the first param in opt to decide the device param_optimizer = list(model.local_named_parameters()) else: @@ -177,7 +177,7 @@ def __call__(self, id): def preprocess_test_dataset(sample): instruction = f"### Instruction\n{sample['instruction']}" context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None - response = f"### Answer\n" + response = "### Answer\n" # join all the parts together prompt = "\n".join([i for i in [instruction, context, response] if i is not None]) model_input = tokenizer(prompt, add_special_tokens=False) diff --git a/examples/training/mixtral/convert_checkpoints.py b/examples/training/mixtral/convert_checkpoints.py index b44975b..ce49feb 100644 --- a/examples/training/mixtral/convert_checkpoints.py +++ b/examples/training/mixtral/convert_checkpoints.py @@ -1,13 +1,10 @@ -import argparse import json -import torch - -from checkpoint_converter import CheckpointConverterBase +import torch +from neuronx_distributed.scripts.checkpoint_converter import CheckpointConverterBase class CheckpointConverterMixtral(CheckpointConverterBase): - # ExpertFusedColumnParallelLinear gate_up_proj_partition_dim = 2 # ExpertFusedRowParallelLinear @@ -34,9 +31,9 @@ def pre_process_full_state_before_tp_conversion(self, state_dict, args): down_proj = torch.stack(down_proj_per_expert) state_dict[f"model.layers.{i}.mlp.router.linear_router.weight"] = router_weight - state_dict[f"model.layers.{i}.mlp.expert_mlps.gate_proj.weight"] = gate_proj - state_dict[f"model.layers.{i}.mlp.expert_mlps.up_proj.weight"] = up_proj - state_dict[f"model.layers.{i}.mlp.expert_mlps.down_proj.weight"] = down_proj + state_dict[f"model.layers.{i}.mlp.expert_mlps.mlp_op.gate_proj.weight"] = gate_proj + 
state_dict[f"model.layers.{i}.mlp.expert_mlps.mlp_op.up_proj.weight"] = up_proj + state_dict[f"model.layers.{i}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj return state_dict @@ -48,9 +45,9 @@ def post_process_full_state_after_tp_conversion(self, state_dict, args): for i in range(config["num_hidden_layers"]): router_weight = state_dict.pop(f"model.layers.{i}.mlp.router.linear_router.weight") - gate_proj = state_dict.pop(f"model.layers.{i}.mlp.expert_mlps.gate_proj.weight") - up_proj = state_dict.pop(f"model.layers.{i}.mlp.expert_mlps.up_proj.weight") - down_proj = state_dict.pop(f"model.layers.{i}.mlp.expert_mlps.down_proj.weight") + gate_proj = state_dict.pop(f"model.layers.{i}.mlp.expert_mlps.mlp_op.gate_proj.weight") + up_proj = state_dict.pop(f"model.layers.{i}.mlp.expert_mlps.mlp_op.up_proj.weight") + down_proj = state_dict.pop(f"model.layers.{i}.mlp.expert_mlps.mlp_op.down_proj.weight") gate_proj_per_expert = torch.unbind(gate_proj) up_proj_per_expert = torch.unbind(up_proj) diff --git a/examples/training/mixtral/mixtral_pretrain/configs/8x7b_config.json b/examples/training/mixtral/mixtral_pretrain/configs/8x7b_config.json index 3e05fad..f1e5290 100644 --- a/examples/training/mixtral/mixtral_pretrain/configs/8x7b_config.json +++ b/examples/training/mixtral/mixtral_pretrain/configs/8x7b_config.json @@ -28,4 +28,4 @@ "vocab_size": 32000, "pretraining_tp": 1, "move_model_to_device": false -} \ No newline at end of file +} diff --git a/examples/training/mixtral/mixtral_pretrain/module_mixtral.py b/examples/training/mixtral/mixtral_pretrain/module_mixtral.py index 3377e3e..905682b 100644 --- a/examples/training/mixtral/mixtral_pretrain/module_mixtral.py +++ b/examples/training/mixtral/mixtral_pretrain/module_mixtral.py @@ -120,7 +120,7 @@ def on_train_batch_end(self, *args, **kwargs): and self.trainer.strategy.pipeline_parallel_rank == self.print_pp_rank ): print( - f"step {self.global_step} loss is {self.loss.detach().cpu().item()}, lr is {self.lr}, input_ids {torch.sum(self.input_ids.detach().cpu()).item()}, norm {self.global_norm}, global rank {xm.get_ordinal()}" + f"step {self.global_step} loss is {self.loss.detach().cpu().item()}, lr is {self.lr}, throughput {self.tps} seq/s, input_ids {torch.sum(self.input_ids.detach().cpu()).item()}, norm {self.global_norm}, global rank {xm.get_ordinal()}" ) # Logging, need to revisit when automatic_optimization enabled @@ -159,7 +159,7 @@ def on_train_start(self, *args, **kwargs): and self.trainer.strategy.tensor_parallel_rank == 0 and self.trainer.strategy.pipeline_parallel_rank == self.print_pp_rank ): - print(f"Training started!") + print("Training started!") # record training start time self.start_time = time.time() @@ -169,7 +169,7 @@ def on_train_end(self, *args, **kwargs): and self.trainer.strategy.tensor_parallel_rank == 0 and self.trainer.strategy.pipeline_parallel_rank == self.print_pp_rank ): - print(f"Training finished!") + print("Training finished!") extract_graphs_only = os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None) if not extract_graphs_only: # record aggregate & final statistics in the metrics file diff --git a/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.py b/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.py index 4b2bcdf..d17a10d 100644 --- a/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.py +++ b/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.py @@ -23,43 +23,43 @@ parent = os.path.dirname(current) sys.path.append(parent) -import 
torch -import torch_xla.core.xla_model as xm -from data_module import NeuronLightningDataModule -from modeling_mixtral_moe_nxd import ( - CoreAttention, - MixtralDecoderLayer, - MixtralForCausalLM, - MixtralRMSNorm, - init_weights, -) -from module_mixtral import NeuronMixtralLTModule, NeuronMixtralPPLTModule -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary -from training_utils import ( - create_mixtral_pretraining_dataset, - get_learning_rate_scheduler, - get_mixed_precision_config, -) -from transformers import AdamW, MixtralConfig, set_seed - -import neuronx_distributed as nxd -from neuronx_distributed.lightning import ( - NeuronTensorBoardLogger, - NeuronTQDMProgressBar, - NeuronXLAPrecisionPlugin, - NeuronXLAStrategy, -) -from neuronx_distributed.modules.moe.loss_function import load_balancing_loss_func -from neuronx_distributed.parallel_layers import mappings -from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams +import torch # noqa: E402 +import torch_xla.core.xla_model as xm # noqa: E402 +from data_module import NeuronLightningDataModule # noqa: E402 +from modeling_mixtral_moe_nxd import ( # noqa: E402 + CoreAttention, # noqa: E402 + MixtralDecoderLayer, # noqa: E402 + MixtralForCausalLM, # noqa: E402 + MixtralRMSNorm, # noqa: E402 + init_weights, # noqa: E402 +) # noqa: E402 +from module_mixtral import NeuronMixtralLTModule, NeuronMixtralPPLTModule # noqa: E402 +from pytorch_lightning import Trainer # noqa: E402 +from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary # noqa: E402 +from training_utils import ( # noqa: E402 + create_mixtral_pretraining_dataset, # noqa: E402 + get_learning_rate_scheduler, # noqa: E402 + get_mixed_precision_config, # noqa: E402 +) # noqa: E402 +from transformers import AdamW, MixtralConfig, set_seed # noqa: E402 +import neuronx_distributed as nxd # noqa: E402 +from neuronx_distributed.modules.moe.model import MoE # noqa: E402 +from neuronx_distributed.lightning import ( # noqa: E402 + NeuronTensorBoardLogger, # noqa: E402 + NeuronTQDMProgressBar, # noqa: E402 + NeuronXLAPrecisionPlugin, # noqa: E402 + NeuronXLAStrategy, # noqa: E402 +) # noqa: E402 +from neuronx_distributed.modules.moe.loss_function import load_balancing_loss_func # noqa: E402 +from neuronx_distributed.parallel_layers import mappings # noqa: E402 +from neuronx_distributed.utils.adamw_fp32_optim_params import AdamW_FP32OptimParams # noqa: E402 # For PT autocast. 
torch.cuda.is_bf16_supported = lambda: True # Workaround for NaNs seen with transformers version >= 4.21.0 # https://github.com/aws-neuron/aws-neuron-sdk/issues/593 -import transformers.modeling_utils as modeling_utils +import transformers.modeling_utils as modeling_utils # noqa: E402 if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"): modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16 @@ -127,10 +127,9 @@ def train_mixtral(flags): def _setup_model_config(flags): model_config = MixtralConfig.from_pretrained(flags.model_path) - model_config.capacity_factor = flags.capacity_factor + # capacity_factor = None corresponds to full capacity (no token dropping) + model_config.capacity_factor = float(flags.capacity_factor) if flags.capacity_factor is not None else None model_config.sequence_parallel_enabled = flags.sequence_parallel_enabled > 0 - model_config.moe_sequence_parallel_mode = flags.moe_sequence_parallel_mode - model_config.expert_mlps_permute_strategy = flags.expert_mlps_permute_strategy model_config.qkv_linear = flags.qkv_linear > 0 model_config.selective_checkpoint_enabled = flags.selective_checkpoint_enabled > 0 model_config.kv_shared_group_size = flags.kv_replicator @@ -147,6 +146,7 @@ def _setup_nxd_config(flags): else { "meta_device_init": True, "param_init_fn": init_weights, + "sequential_move_factor": 11, } ) @@ -175,10 +175,11 @@ def _setup_nxd_config(flags): return nxd.neuronx_distributed_config( tensor_parallel_size=flags.tensor_parallel_size, pipeline_parallel_size=flags.pipeline_parallel_size, + expert_parallel_size=flags.expert_parallel_size, pipeline_config=pipeline_config, optimizer_config={"zero_one_enabled": flags.use_zero_1, "grad_clipping": True, "max_grad_norm": 1.0}, sequence_parallel=flags.sequence_parallel_enabled, - activation_checkpoint_config=CoreAttention if flags.selective_checkpoint_enabled else "full", + activation_checkpoint_config=(CoreAttention, MoE) if flags.selective_checkpoint_enabled else "full", model_init_config=model_init_config, mixed_precision_config=mixed_precision_config, ) @@ -289,6 +290,7 @@ def _mp_fn(index, flags): parser.add_argument("--tensor_parallel_size", default=2, type=int, help="Tensor parallel size") parser.add_argument("--pipeline_parallel_size", type=int, default=1, help="PP size") + parser.add_argument("--expert_parallel_size", type=int, default=1, help="EP size") parser.add_argument("--num_microbatches", type=int, default=8, help="num_microbatches used for PP") parser.add_argument("--seq_len", default=2048, type=int, help="Sequence length") parser.add_argument("--use_mix_precision", action="store_true", help="Use mix precision.") @@ -309,18 +311,6 @@ def _mp_fn(index, flags): action="store_true", help="Enable sequence parallel", ) - parser.add_argument( - "--moe_sequence_parallel_mode", - default="EXIT_SP_ON_ENTRY", - type=str, - help="MoE layer sequence parallel mode", - ) - parser.add_argument( - "--expert_mlps_permute_strategy", - default="matmul", - type=str, - help="ExpertMLPs permute strategy (either 'matmul' or 'index')", - ) parser.add_argument( "--selective_checkpoint_enabled", default=False, @@ -341,8 +331,7 @@ def _mp_fn(index, flags): ) parser.add_argument( "--capacity_factor", - default=4.0, - type=float, + default=None, help="MoE capacity factor", ) parser.add_argument( @@ -350,7 +339,7 @@ def _mp_fn(index, flags): ) parser.add_argument( "--use_gpu_compatible_precision", - default=0, + default=1, type=int, help="Use gpu compatible precision", ) diff --git 
a/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.sh b/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.sh old mode 100755 new mode 100644 index 2843a37..85f13e2 --- a/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.sh +++ b/examples/training/mixtral/mixtral_pretrain/run_mixtral_pretrain_ptl.sh @@ -5,7 +5,8 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) # Neuron Compiler Flags # TODO: temporarily disable "--distribution-strategy=llm-training" because of a compiler bug. Ideally we should enable it for modular flow -export NEURON_CC_FLAGS="--model-type=transformer --retry_failed_compilation" +# --enable-saturate-infinity: convert inf to max_float to avoid nan (e.g. in transpose) +export NEURON_CC_FLAGS="--model-type=transformer --retry_failed_compilation --enable-saturate-infinity" export NEURON_FUSE_SOFTMAX=1 # Async Runtime @@ -14,10 +15,12 @@ export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=7 # HOST OOM export MALLOC_ARENA_MAX=64 -# TP degree -TP_DEGREE=32 +# Tensor parallel degree +: ${TP_DEGREE:=32} # Pipeline parallel degree -PP_DEGREE=4 +: ${PP_DEGREE:=4} +# Expert parallel degree +: ${EP_DEGREE:=1} # SP SEQUENCE_PARALLEL_ENABLED=1 # 0: bf16; 1: mixed precision @@ -25,9 +28,11 @@ USE_MIX_PRECISION=1 # 0: use pure DP; 1: use ZeRO-1 USE_ZERO_1=1 # global batch size -GBS=32 +: ${GBS:=32} # micro batch size MBS=8 +# Enable selective checkpointing in integration tests. +: ${SELECTIVE_CHECKPOINT_ENABLED:=1} # number of steps to run : ${TOTAL_STEPS:=22500} # warmup steps @@ -39,15 +44,11 @@ MIN_LR=3.0e-5 # model path MODEL_PATH=$SCRIPT_DIR/configs/8x7b_config.json # data path -DATA_PATH="$HOME/examples_datasets/wikicorpus_llama2_7B_tokenized_4k" +DATA_PATH="$HOME/examples_datasets/wikicorpus_llama2_tokenized_4k" # sequence length SEQ_LEN=4096 # capacity factor CAPACITY_FACTOR=2.0 -# MoE sequence parallel mode -MOE_SEQUENCE_PARALLEL_MODE="OPTIMIZED_SP_MATMUL" -# ExpertMLPs permute strategy ("matmul" or "index") -EXPERT_MLPS_PERMUTE_STRATEGY="matmul" # Use meta init META_DEVICE_INIT=1 ############################################# @@ -96,6 +97,9 @@ fi if [ $SEQUENCE_PARALLEL_ENABLED -eq 1 ]; then EXTRA_ARGS+=" --sequence_parallel_enabled" fi +if [ $SELECTIVE_CHECKPOINT_ENABLED -eq 1 ]; then + EXTRA_ARGS+=" --selective_checkpoint_enabled" +fi if [ $PP_DEGREE -gt 1 ]; then # Data paralell size @@ -124,6 +128,7 @@ else fi echo TP_DEGREE=$TP_DEGREE +echo EP_DEGREE=$EP_DEGREE echo SEQUENCE_PARALLEL_ENABLED=$SEQUENCE_PARALLEL_ENABLED echo PP_DEGREE=$PP_DEGREE echo USE_MIX_PRECISION=$USE_MIX_PRECISION @@ -138,8 +143,6 @@ echo MODEL_PATH=$MODEL_PATH echo DATA_PATH=$DATA_PATH echo SEQ_LEN=$SEQ_LEN echo CAPACITY_FACTOR=$CAPACITY_FACTOR -echo MOE_SEQUENCE_PARALLEL_MODE=$MOE_SEQUENCE_PARALLEL_MODE -echo EXPERT_MLPS_PERMUTE_STRATEGY=$EXPERT_MLPS_PERMUTE_STRATEGY echo EXTRA_ARGS=$EXTRA_ARGS echo DP=$DP echo STEPS_THIS_RUN=$STEPS_THIS_RUN @@ -151,6 +154,7 @@ torchrun $DISTRIBUTED_ARGS \ --data_dir $DATA_PATH \ --tensor_parallel_size $TP_DEGREE \ --pipeline_parallel_size $PP_DEGREE \ + --expert_parallel_size $EP_DEGREE \ --batch_size $MBS \ --steps_this_run $STEPS_THIS_RUN\ --max_steps $TOTAL_STEPS \ @@ -159,6 +163,4 @@ torchrun $DISTRIBUTED_ARGS \ --min_lr $MIN_LR \ --seq_len $SEQ_LEN \ --capacity_factor $CAPACITY_FACTOR \ - --moe_sequence_parallel_mode $MOE_SEQUENCE_PARALLEL_MODE \ - --expert_mlps_permute_strategy $EXPERT_MLPS_PERMUTE_STRATEGY \ $EXTRA_ARGS |& tee $OUTPUT_LOG diff --git 
a/examples/training/mixtral/modeling_mixtral_moe_nxd.py b/examples/training/mixtral/modeling_mixtral_moe_nxd.py index 0553d39..1e4a29a 100644 --- a/examples/training/mixtral/modeling_mixtral_moe_nxd.py +++ b/examples/training/mixtral/modeling_mixtral_moe_nxd.py @@ -68,10 +68,10 @@ ) import neuronx_distributed.parallel_layers.utils as neuronx_dist_utils -from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPsCapacityFactor +from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPs from neuronx_distributed.modules.moe.loss_function import load_balancing_loss_func -from neuronx_distributed.modules.moe.model import MoE, MoESequenceParallelMode -from neuronx_distributed.modules.moe.moe_parallel_layers import InputParallelLinear +from neuronx_distributed.modules.moe.model import MoE +from neuronx_distributed.modules.moe.moe_parallel_layers import LinearRouter from neuronx_distributed.modules.moe.routing import RouterTopK from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear from neuronx_distributed.parallel_layers import mappings @@ -348,42 +348,34 @@ def forward( def initialize_mixtral_moe_layer(config): - if config.sequence_parallel_enabled: - assert ( - config.moe_sequence_parallel_mode != MoESequenceParallelMode.NO_SP - ), "sequence_parallel_enabled=true, but moe_sequence_parallel_mode set to NO_SP" - sequence_parallel_mode = MoESequenceParallelMode[config.moe_sequence_parallel_mode] - else: - sequence_parallel_mode = MoESequenceParallelMode.NO_SP - # Default to RouterTopK (without Sinkhorn) router = RouterTopK( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - sequence_parallel_mode=sequence_parallel_mode, ) init_method = partial(_init_normal, config.initializer_range) # TODO: Potentially add activation checkpointing in the ExpertMLPs, depending on profile/performance needs - expert_mlps = ExpertMLPsCapacityFactor( + expert_mlps = ExpertMLPs( num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + glu_mlp=True, capacity_factor=config.capacity_factor, + normalize_top_k_affinities=True, init_method=init_method, output_layer_init_method=init_method, - glu_mlp=True, - sequence_parallel_mode=sequence_parallel_mode, - permute_strategy=config.expert_mlps_permute_strategy, - top_k=config.num_experts_per_tok, - normalize_top_k_affinities=True, ) moe_layer = MoE( - router=router, expert_mlps=expert_mlps, return_router_logits=True, sequence_parallel_mode=sequence_parallel_mode + router=router, + expert_mlps=expert_mlps, + return_router_logits=True, + sequence_parallel_enabled=config.sequence_parallel_enabled, ) return moe_layer @@ -525,14 +517,19 @@ def forward( hidden_states = residual + hidden_states # Fully Connected + if type(self.mlp).__name__ == "NxDCheckpointWrapper": + mlp_class = type(self.mlp._checkpoint_wrapped_module).__name__ + else: + mlp_class = type(self.mlp).__name__ + residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - if isinstance(self.mlp, LlamaMLP): + if mlp_class == "LlamaMLP": hidden_states = self.mlp(hidden_states) - elif isinstance(self.mlp, MoE): + elif mlp_class == "MoE": hidden_states, router_logits = self.mlp(hidden_states) else: - raise Exception(f"MLP Layer type must be either LlamaMLP or MoE, got {type(self.mlp).__name__}.") + raise TypeError(f"MLP Layer type must be either LlamaMLP or MoE, got 
{type(self.mlp).__name__}.") hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -546,10 +543,13 @@ def forward( if output_router_logits: # Concatenate the router logits with previous router logits if past_router_logits is not None: - if isinstance(self.mlp, MoE): + if mlp_class == "LlamaMLP": + router_logits = past_router_logits + elif mlp_class == "MoE": router_logits = torch.cat((past_router_logits, router_logits), dim=0) else: - router_logits = past_router_logits + raise TypeError(f"MLP Layer type must be either LlamaMLP or MoE, got {type(self.mlp).__name__}.") + outputs += (router_logits,) # TODO: Return a tuple here to workaround a NxD PP issue, can revert once the issue is fixed. @@ -881,7 +881,7 @@ def init_weights(module): """ if isinstance(module, MixtralRMSNorm): module.weight.data.fill_(1.0) - elif isinstance(module, (ParallelEmbedding, RowParallelLinear, ColumnParallelLinear, InputParallelLinear)): + elif isinstance(module, (ParallelEmbedding, RowParallelLinear, ColumnParallelLinear, LinearRouter)): module.init_weight_cpu() if hasattr(module, "bias") and module.bias is not None: module.bias.data.zero_() diff --git a/examples/training/tp_dp_bert_hf_pretrain/tp_dp_bert_large_hf_pretrain_hdf5.py b/examples/training/tp_dp_bert_hf_pretrain/tp_dp_bert_large_hf_pretrain_hdf5.py index 44bae0f..c6805e2 100644 --- a/examples/training/tp_dp_bert_hf_pretrain/tp_dp_bert_large_hf_pretrain_hdf5.py +++ b/examples/training/tp_dp_bert_hf_pretrain/tp_dp_bert_large_hf_pretrain_hdf5.py @@ -78,7 +78,7 @@ # Workaround for NaNs seen with transformers version >= 4.21.0 # https://github.com/aws-neuron/aws-neuron-sdk/issues/593 -import transformers.modeling_utils as modeling_utils +import transformers.modeling_utils as modeling_utils # noqa: E402 if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"): modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16 @@ -218,7 +218,7 @@ def get_instance_type(self): headers={"X-aws-ec2-metadata-token": token.text}, ) return data.text - except: + except Exception: return os.environ.get("HOSTNAME", "unknown") def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): @@ -226,7 +226,7 @@ def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None) grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" print( f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " - f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} seq/s " f"{grad_norm_msg}", flush=True, ) @@ -564,7 +564,6 @@ def _print_logs(running_loss_reduced_detached, total_norm): scheduler_state_dict = None if flags.resume_ckpt: - step = flags.resume_step state_dict = checkpointing.load(flags.output_dir, model) optimizer.load_state_dict(state_dict["optimizer"]) global_step = state_dict["global_step"] diff --git a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.py b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.py index f29db3e..21ccc90 100644 --- a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.py +++ b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.py @@ -178,7 +178,7 @@ def get_instance_type(self): headers={"X-aws-ec2-metadata-token": token.text}, ) return data.text - except: + except 
Exception: return os.environ.get("HOSTNAME", "unknown") def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): @@ -186,7 +186,7 @@ def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None) grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" print( f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " - f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} seq/s " f"{grad_norm_msg}", flush=True, ) @@ -508,7 +508,7 @@ def _print_logs(running_loss_reduced_detached, total_norm): ), ] metric_writer.store_metrics(metric_data) - if not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None): # Do not save checkpoint during pre-compile + if not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None): # Do not save checkpoint during pre-compile state_dict = { "model": model.state_dict(), "global_step": global_step, diff --git a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.sh b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.sh index 6a17512..cae228c 100644 --- a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.sh +++ b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain/tp_dp_gpt_neox_20b_hf_pretrain.sh @@ -38,7 +38,7 @@ if [ ! -z "$SLURM_NTASKS" ]; then MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`) DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000" if [ $NODE_ID -eq 0 ]; then - echo "WORLD_SIZE=$WORLD_SIZE" + echo "WORLD_SLURM_NTASKS=$WORLD_SIZE" echo "NODE_ID=$NODE_ID" echo "MASTER_ADDRESS=$MASTER_ADDRESS" echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS" diff --git a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.py b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.py index 4af04b0..d724198 100644 --- a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.py +++ b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.py @@ -194,7 +194,7 @@ def get_instance_type(self): headers={"X-aws-ec2-metadata-token": token.text}, ) return data.text - except: + except Exception: return os.environ.get("HOSTNAME", "unknown") def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): @@ -202,7 +202,7 @@ def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None) grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" print( f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " - f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} seq/s " f"{grad_norm_msg}", flush=True, ) @@ -361,8 +361,6 @@ def train_gpt_neox(flags): ) def train_loop_fn(model, optimizer, train_loader, epoch, global_step, training_ustep, running_loss): - max_grad_norm = 1.0 - for _, data in enumerate(train_loader): training_ustep += 1 input_ids = data["input_ids"] @@ -536,7 +534,7 @@ def _print_logs(running_loss_reduced_detached, total_norm): ), ] metric_writer.store_metrics(metric_data) - if not 
os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None): # Do not save checkpoint during pre-compile + if not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None): # Do not save checkpoint during pre-compile state_dict = { "model": model.state_dict(), "global_step": global_step, diff --git a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.sh b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.sh index 0d2092b..b3b7332 100644 --- a/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.sh +++ b/examples/training/tp_dp_gpt_neox_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain/tp_dp_gpt_neox_6.9b_hf_pretrain.sh @@ -38,7 +38,7 @@ if [ ! -z "$SLURM_NTASKS" ]; then MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`) DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000" if [ $NODE_ID -eq 0 ]; then - echo "WORLD_SIZE=$WORLD_SIZE" + echo "WORLD_SLURM_NTASKS=$WORLD_SIZE" echo "NODE_ID=$NODE_ID" echo "MASTER_ADDRESS=$MASTER_ADDRESS" echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS" diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..d3a48a7 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,35 @@ +[mypy] +# files = src/**/*.py, examples/**/*.py +files = src/**/*.py +pretty = True +explicit_package_bases = True + +[mypy-torch_xla.*] +ignore_missing_imports = True + +[mypy-torchdistx.*] +ignore_missing_imports = True + +[mypy-transformers.*] +ignore_missing_imports = True + +[mypy-datasets.*] +ignore_missing_imports = True + +[mypy-pytorch_lightning.*] +ignore_missing_imports = True + +[mypy-lightning_fabric.*] +ignore_missing_imports = True + +[mypy-torch_neuronx.*] +ignore_missing_imports = True + +[mypy-neuronxcc.*] +ignore_missing_imports = True + +[mypy-accelerate.*] +ignore_missing_imports = True + +[mypy-omegaconf.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index b98c217..7416a29 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,8 @@ def get_tag(self): exec(open('src/neuronx_distributed/_version.py').read()) setup( - name='neuronx-distributed', - version=__version__, + name="neuronx-distributed", + version=__version__, #noqa classifiers=[ 'Development Status :: 3 - Alpha', 'Intended Audience :: Developers', @@ -51,8 +51,14 @@ def get_tag(self): "console_scripts": ["nxd_convert_zero_checkpoints=neuronx_distributed.optimizer.convert_zero_checkpoints:main"], }, install_requires=[ +<<<<<<< HEAD 'torch-neuronx', 'torch-xla', +======= + "torch-neuronx", + "torch-xla", + "safetensors" +>>>>>>> 5470500 (Neuron SDK Release 2.20.0 (#21)) ], distclass=BinaryDistribution, cmdclass={ diff --git a/src/neuronx_distributed/_version.py b/src/neuronx_distributed/_version.py index f6fe8f4..886d341 100644 --- a/src/neuronx_distributed/_version.py +++ b/src/neuronx_distributed/_version.py @@ -1,3 +1,3 @@ # Copyright Amazon Web Services and its Affiliates. All Rights Reserved. 
# ============================================================================== -__version__ = "0.8.0" +__version__ = "0.9.0" diff --git a/src/neuronx_distributed/lightning/checkpoint_io.py b/src/neuronx_distributed/lightning/checkpoint_io.py index fc52e1b..f3de889 100644 --- a/src/neuronx_distributed/lightning/checkpoint_io.py +++ b/src/neuronx_distributed/lightning/checkpoint_io.py @@ -11,16 +11,22 @@ class NeuronCheckpointIO(XLACheckpointIO): - def __init__(self, save_load_xser: bool = True, *args: Any, **kwargs: Any) -> None: + def __init__(self, save_load_xser: bool = True, weights_only: bool = False, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.save_load_xser = save_load_xser + self.weights_only = weights_only def load_checkpoint( self, checkpoint_path: _PATH, master_dp_only: bool = True, ) -> Dict[str, Any]: - return load(chkpt_path=checkpoint_path, load_xser=self.save_load_xser, master_dp_only=master_dp_only) + return load( + chkpt_path=checkpoint_path, + load_xser=self.save_load_xser, + master_dp_only=master_dp_only, + weights_only=self.weights_only, + ) def save_checkpoint( self, diff --git a/src/neuronx_distributed/lightning/launcher.py b/src/neuronx_distributed/lightning/launcher.py index 2d3e6c4..eb216de 100644 --- a/src/neuronx_distributed/lightning/launcher.py +++ b/src/neuronx_distributed/lightning/launcher.py @@ -65,7 +65,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] self._recover_results_in_main_process(worker_output, trainer) return worker_output.trainer_results else: # Neuron change for launch with torchrun - process_idx = int(os.environ.get("LOCAL_RANK")) + process_idx = int(os.environ["LOCAL_RANK"]) self._strategy._local_rank = process_idx results = function(*args, **kwargs) _rank_teardown(process_idx) @@ -76,6 +76,7 @@ def _wrapping_function( # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing # https://github.com/pytorch/xla/blob/v1.13.0/torch_xla/distributed/xla_multiprocessing.py#L321 process_idx: int, + trainer: Optional["pl.Trainer"], function: Callable, args: Any, kwargs: Any, diff --git a/src/neuronx_distributed/lightning/logger.py b/src/neuronx_distributed/lightning/logger.py index b76ccdb..16ace1f 100644 --- a/src/neuronx_distributed/lightning/logger.py +++ b/src/neuronx_distributed/lightning/logger.py @@ -1,21 +1,25 @@ import os -from typing import Any, Callable, Mapping, Optional +from typing import Any, Callable, Mapping, Optional, TYPE_CHECKING from lightning_fabric.utilities.cloud_io import _is_dir from lightning_fabric.utilities.logger import _add_prefix +import pytorch_lightning as pl from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.tensorboard import TensorBoardLogger +from pytorch_lightning.utilities.rank_zero import rank_zero_warn from torch import Tensor from neuronx_distributed.parallel_layers.parallel_state import ( get_data_parallel_rank, - get_data_parallel_size, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_size, get_tensor_model_parallel_rank, model_parallel_is_initialized, ) +if TYPE_CHECKING: + from torch.utils import tensorboard + class NeuronTensorBoardLogger(TensorBoardLogger): def __init__(self, log_rank0: bool = False, **kwargs): @@ -32,13 +36,11 @@ def print_step(self, value: int = -1): self._print_step = value @property - def experiment(self) -> "SummaryWriter": + def experiment(self) -> "tensorboard.SummaryWriter": """Actual tensorboard object. 
To use TensorBoard features anywhere in your code, do the following. Example:: - logger.experiment.some_tensorboard_function() - """ """Neuron change, log on the last PP rank""" if not self.should_print(): @@ -96,7 +98,8 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) ) from ex def log_graph( # type: ignore[override] - self, model: "pl.LightningModule", input_array: Optional[Tensor] = None + self, model: "pl.LightningModule", + input_array: Optional[Tensor] = None ) -> None: """Neuron change, log on the last PP rank""" if not self.should_print(): @@ -125,15 +128,12 @@ def log_graph( # type: ignore[override] with pl.core.module._jit_is_scripting(): self.experiment.add_graph(model, input_array) - def should_print(self): + def should_print(self) -> bool: # For NxD we should log on the last PP - assert model_parallel_is_initialized(), f"NxD model parallel not initialized" + assert model_parallel_is_initialized(), "NxD model parallel not initialized" print_pp_rank = 0 if self.log_rank0 else get_pipeline_model_parallel_size() - 1 return ( - get_data_parallel_rank() == 0 - and get_tensor_model_parallel_rank() == 0 - and get_pipeline_model_parallel_rank() == print_pp_rank - and self.print_step >= 0 + get_data_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0 and get_pipeline_model_parallel_rank() == print_pp_rank and self.print_step >= 0 ) @@ -143,7 +143,7 @@ class _DummyExperiment: def nop(self, *args: Any, **kw: Any) -> None: pass - def __getattr__(self, _: Any) -> Callable: + def __getattr__(self, _: Any) -> Callable[..., Any]: return self.nop def __getitem__(self, idx: int) -> "_DummyExperiment": diff --git a/src/neuronx_distributed/lightning/module.py b/src/neuronx_distributed/lightning/module.py index 4c09829..14c5adc 100644 --- a/src/neuronx_distributed/lightning/module.py +++ b/src/neuronx_distributed/lightning/module.py @@ -4,6 +4,7 @@ import torch import torch_xla.core.xla_model as xm from lightning_utilities.core.apply_func import apply_to_collection +from lightning_utilities.core.rank_zero import rank_zero_warn from pytorch_lightning import LightningModule from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import ( _FxValidator, @@ -32,7 +33,7 @@ def __init__( opt_kwargs: Dict = {}, scheduler_args: Tuple = (), scheduler_kwargs: Dict = {}, - model_fn: Callable = None, + model_fn: Optional[Callable[..., Any]] = None, grad_accum_steps: int = 1, train_batch_size: int = 16, logging_interval: int = 1, @@ -116,7 +117,7 @@ def load_state_dict(self, state_dict): def get_param_groups_by_weight_decay(self): """Get param groups. 
Customers can override this to have their own way of weight_decay""" - if hasattr(self.model, "local_named_parameters"): + if hasattr(self.model, "local_named_parameters") and hasattr(self.model, "partitioned") and self.model.partitioned: # Zero1 use the first param in opt to decide the device param_optimizer = list(self.model.local_named_parameters()) else: diff --git a/src/neuronx_distributed/lightning/precision_plugin.py b/src/neuronx_distributed/lightning/precision_plugin.py index 635984e..99491c4 100644 --- a/src/neuronx_distributed/lightning/precision_plugin.py +++ b/src/neuronx_distributed/lightning/precision_plugin.py @@ -1,9 +1,12 @@ -from typing import Any, Callable +from typing import Any, Callable, TYPE_CHECKING from lightning_fabric.accelerators.xla import _XLA_AVAILABLE from lightning_fabric.utilities.types import Optimizable from pytorch_lightning.plugins.precision import XLAPrecisionPlugin +if TYPE_CHECKING: + import pytorch_lightning as pl + class NeuronXLAPrecisionPlugin(XLAPrecisionPlugin): def __init__(self, mixed_precision_enabled: bool = False) -> None: @@ -19,7 +22,5 @@ def optimizer_step( # type: ignore[override] closure: Callable[[], Any], **kwargs: Any, ) -> Any: - pass - # TODO: currently using manual optimization, need further modification here for auto optimization optimizer.step() diff --git a/src/neuronx_distributed/lightning/strategy.py b/src/neuronx_distributed/lightning/strategy.py index d8aa4cc..ddc33e7 100644 --- a/src/neuronx_distributed/lightning/strategy.py +++ b/src/neuronx_distributed/lightning/strategy.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, TYPE_CHECKING import torch from lightning_fabric.plugins.environments import ( @@ -7,6 +7,7 @@ XLAEnvironment, ) from lightning_fabric.utilities.types import _PATH, ReduceOp +from lightning_fabric.utilities import move_data_to_device from pytorch_lightning.strategies import XLAStrategy from torch import Tensor @@ -23,16 +24,20 @@ from .checkpoint_io import NeuronCheckpointIO from .launcher import _NeuronXLALauncher +if TYPE_CHECKING: + from pytorch_lightning.strategies.strategy import TBroadcast + class NeuronXLAStrategy(XLAStrategy): def __init__( self, - nxd_config: Dict = None, + nxd_config: Optional[Dict[str, Any]] = None, tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, + expert_parallel_size: int = 1, debug: bool = False, sync_module_states: bool = False, - checkpoint_io: bool = None, + checkpoint_io: Optional[NeuronCheckpointIO] = None, save_load_xser: bool = True, ): if os.environ.get("TORCHELASTIC_RUN_ID") is not None: @@ -51,7 +56,7 @@ def __init__( elif isinstance(checkpoint_io, NeuronCheckpointIO): self.checkpoint_io = checkpoint_io else: - raise NotImplementedError(f"NeuronXLAStrategy only supports NeuronCheckpointIO") + raise NotImplementedError("NeuronXLAStrategy only supports NeuronCheckpointIO") self.debug = debug self._launched = False @@ -62,14 +67,16 @@ def __init__( if self.nxd_config is not None: self.tensor_parallel_size = self.nxd_config["tensor_parallel_size"] self.pipeline_parallel_size = self.nxd_config["pipeline_parallel_size"] + self.expert_parallel_size = self.nxd_config["expert_parallel_size"] else: self.tensor_parallel_size = tensor_parallel_size self.pipeline_parallel_size = pipeline_parallel_size + self.expert_parallel_size = expert_parallel_size def _configure_launcher(self) -> None: self._launcher = _NeuronXLALauncher(self) - def broadcast(self, obj, src: int = 0): + def 
broadcast(self, obj: "TBroadcast", src: int = 0) -> "TBroadcast": return obj @property @@ -80,7 +87,7 @@ def setup_distributed(self) -> None: import torch.distributed as dist if self.cluster_environment.creates_processes_externally: - global_rank = int(os.environ.get("RANK")) + global_rank = int(os.environ["RANK"]) else: import torch_xla.core.xla_model as xm @@ -98,6 +105,7 @@ def setup_distributed(self) -> None: initialize_model_parallel( tensor_model_parallel_size=self.tensor_parallel_size, pipeline_model_parallel_size=self.pipeline_parallel_size, + expert_model_parallel_size=self.expert_parallel_size, ) self.data_parallel_rank = get_data_parallel_rank() @@ -105,7 +113,7 @@ def setup_distributed(self) -> None: self.tensor_parallel_rank = get_tensor_model_parallel_rank() self.pipeline_parallel_rank = get_pipeline_model_parallel_rank() - def teardown(self): + def teardown(self) -> None: assert self.cluster_environment is not None self.cluster_environment.teardown() self.precision_plugin.teardown() @@ -167,13 +175,13 @@ def reduce( xm.mark_step() return output.cpu() - def process_dataloader(self, dataloader: object): + def process_dataloader(self, dataloader: object) -> object: from torch_xla.distributed.parallel_loader import MpDeviceLoader # PP requires input data on CPU if self.pipeline_parallel_size > 1: if isinstance(dataloader, MpDeviceLoader): - print(f"convertine dataloader {dataloader} to {dataloader._loader}") + print(f"converting dataloader {dataloader} to {dataloader._loader}") return dataloader._loader return dataloader diff --git a/src/neuronx_distributed/modules/lora/layer.py b/src/neuronx_distributed/modules/lora/layer.py index 64cc207..87f40ae 100644 --- a/src/neuronx_distributed/modules/lora/layer.py +++ b/src/neuronx_distributed/modules/lora/layer.py @@ -45,8 +45,11 @@ def __init__(self, base_layer: torch.nn.Module, lora_config: LoraConfig) -> None # QuantLinear in_features, out_features = base_layer.infeatures, base_layer.outfeatures elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): - # Megatron ColumnParallelLinear,RowParallelLinear + # ColumnParallelLinear, RowParallelLinear in_features, out_features = base_layer.input_size, base_layer.output_size + elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_sizes"): + # GQAQKVColumnParallelLinear + in_features, out_features = base_layer.input_size, base_layer.output_sizes else: if is_hf_transformers_available(): from transformers.pytorch_utils import Conv1D @@ -63,6 +66,10 @@ def __init__(self, base_layer: torch.nn.Module, lora_config: LoraConfig) -> None self.in_features = in_features self.out_features = out_features self.merged = False + if lora_config.use_rslora: + self.scaling = self.lora_alpha / math.sqrt(self.lora_rank) + else: + self.scaling = self.lora_alpha / self.lora_rank def get_base_layer(self) -> torch.nn.Module: r""" @@ -94,7 +101,7 @@ def merge(self, safe_merge: bool = False) -> None: orig_weights += self.get_delta_weight() if not torch.isfinite(orig_weights).all(): - raise ValueError(f"NaNs detected in the merged weights. The adapter seems to be broken") + raise ValueError("NaNs detected in the merged weights. 
The adapter seems to be broken") base_layer.weight.data = orig_weights else: base_layer.weight.data += self.get_delta_weight() @@ -132,12 +139,12 @@ def init_lora_parameters(self, init_lora_weights): # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) elif init_lora_weights == "gaussian": - nn.init.normal_(self.lora_A.weight, std=1 / self.r) + nn.init.normal_(self.lora_A.weight, std=1 / self.lora_rank) else: raise ValueError(f"Unknown LoRA parameters initialization with {init_lora_weights}") nn.init.zeros_(self.lora_B.weight) - if self.lora_embedding_A: + if self.lora_embedding_A is not None: # initialize a the same way as the default for nn.linear and b to zero nn.init.zeros_(self.lora_embedding_A) nn.init.normal_(self.lora_embedding_B) @@ -177,10 +184,6 @@ def update_layer(self, lora_config: LoraConfig): # Actual trainable parameters self.lora_A = nn.Linear(self.in_features, self.lora_rank, bias=False) self.lora_B = nn.Linear(self.lora_rank, self.out_features, bias=False) - if lora_config.use_rslora: - self.scaling = self.lora_alpha / math.sqrt(self.lora_rank) - else: - self.scaling = self.lora_alpha / self.lora_rank self.init_lora_parameters(lora_config.init_lora_weights) base_layer = self.get_base_layer() @@ -257,10 +260,6 @@ def update_layer(self, lora_config: LoraConfig): weight_B = torch.randn(self.out_features, self.lora_rank) self.lora_embedding_A = nn.Parameter(weight_A) self.lora_embedding_B = nn.Parameter(weight_B) - if lora_config.use_rslora: - self.scaling = self.lora_alpha / math.sqrt(self.lora_rank) - else: - self.scaling = self.lora_alpha / self.lora_rank self.init_lora_parameters(lora_config.init_lora_weights) base_layer = self.get_base_layer() @@ -352,10 +351,6 @@ def update_layer(self, lora_config: LoraConfig): padding = base_layer.padding self.lora_A = nn.Conv2d(self.in_features, self.lora_rank, kernel_size, stride, padding, bias=False) self.lora_B = nn.Conv2d(self.lora_rank, self.out_features, (1, 1), (1, 1), bias=False) - if lora_config.use_rslora: - self.scaling = self.lora_alpha / math.sqrt(self.lora_rank) - else: - self.scaling = self.lora_alpha / self.lora_rank self.init_lora_parameters(lora_config.init_lora_weights) weight = getattr(base_layer, "weight", None) diff --git a/src/neuronx_distributed/modules/lora/model.py b/src/neuronx_distributed/modules/lora/model.py index d7b111a..713b0d3 100644 --- a/src/neuronx_distributed/modules/lora/model.py +++ b/src/neuronx_distributed/modules/lora/model.py @@ -3,12 +3,14 @@ import json import os import re +from typing import Optional, Dict, Mapping, Any, Tuple, TYPE_CHECKING from dataclasses import asdict import torch import torch_xla.core.xla_model as xm from neuronx_distributed.parallel_layers import ColumnParallelLinear, RowParallelLinear +from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear from neuronx_distributed.parallel_layers.parallel_state import ( model_parallel_is_initialized, ) @@ -19,7 +21,10 @@ from .config import LoraConfig from .layer import LoraConv2d, LoraEmbedding, LoraLayer, LoraLinear -from .tp_layer import LoraParallelLinear +from .tp_layer import LoraParallelLinear, LoraGQAQKVParallelLinear + +if TYPE_CHECKING: + import transformers # The mapping is based on https://github.com/huggingface/peft/blob/main/src/peft/utils/constants.py @@ -103,7 +108,7 @@ class LoraModel(torch.nn.Module): prefix: str = "lora_" - def __init__(self, module, config: LoraConfig) -> None: 
+ def __init__(self, module: "transformers.PreTrainedModel", config: LoraConfig) -> None: assert config is not None super().__init__() @@ -114,6 +119,7 @@ def __init__(self, module, config: LoraConfig) -> None: self.is_verbose_enabled = config.lora_verbose self.modules_to_save = config.modules_to_save self.is_lora_enabled = False + self.is_optimum_enabled = False self.is_checkpoint_loaded = False self.lora_config = config self.is_base_model_loaded = False @@ -150,7 +156,7 @@ def forward( **kwargs, ) - def _set_optimum_generate(self): + def _set_optimum_generate(self) -> None: if not self.is_optimum_enabled: try: from optimum.neuron.utils.training_utils import ( @@ -159,14 +165,14 @@ def _set_optimum_generate(self): patch_generation_mixin_to_general_neuron_generation_mixin(self.module) self.is_optimum_enabled = True - except: + except Exception: raise ImportError("Failed to import optimum-neuron, generation will not work on Neuron.") def generate(self, *args, **kwargs): self._set_optimum_generate() return self.module.generate(*args, **kwargs) - def inject_adapter(self): + def inject_adapter(self) -> None: r""" Creates adapter layers and replaces the target modules with the adapter layers. It involves the following four steps: @@ -200,14 +206,14 @@ def inject_adapter(self): self._mark_only_adapters_as_trainable() self.is_lora_enabled = True - def _get_submodules(self, key): + def _get_submodules(self, key: str): module = self.module target_name = key.split(".")[-1] parent = module.get_submodule(".".join(key.split(".")[:-1])) target = module.get_submodule(key) return parent, target, target_name - def _set_target_modules(self): + def _set_target_modules(self) -> None: config = self.lora_config if config.target_modules is not None: return @@ -224,7 +230,7 @@ def _set_target_modules(self): else: self.lora_config.target_modules = MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type] - def _check_target_module_exists(self, key): + def _check_target_module_exists(self, key: str) -> bool: r"""A helper method to check if the passed module's key name matches any of the target modules. 
Args: @@ -248,17 +254,17 @@ def _check_target_module_exists(self, key): def _create_and_replace( self, target, - target_name, + target_name: str, parent, current_key, - ): + ) -> None: if current_key is None: raise ValueError("Current Key shouldn't be `None`") new_module = self._create_new_module(target) self._replace_module(parent, target_name, new_module, target) - def _replace_module(self, parent, child_name, new_module, child): + def _replace_module(self, parent, child_name: str, new_module, child) -> None: setattr(parent, child_name, new_module) # child layer wraps the original module, unpack it if hasattr(child, "base_layer"): @@ -274,13 +280,12 @@ def _replace_module(self, parent, child_name, new_module, child): new_module.base_layer.state = child.state else: new_module.state = child.state - new_module.to(child.weight.device) + new_module.to("xla") # dispatch to correct device for name, module in new_module.named_modules(): if self.prefix in name: - weight = child.weight - module.to(weight.device) + module.to("xla") def _mark_only_adapters_as_trainable(self) -> None: module = self.module @@ -323,17 +328,17 @@ def _create_new_module(self, target): new_module = LoraConv2d(target, lora_config) elif isinstance(target, torch.nn.Linear): new_module = LoraLinear(target, lora_config) + elif isinstance(target, (ColumnParallelLinear, RowParallelLinear)): + # check NxD model + new_module = LoraParallelLinear(base_layer=target, lora_config=lora_config) + elif isinstance(target, GQAQKVColumnParallelLinear): + new_module = LoraGQAQKVParallelLinear(base_layer=target, lora_config=lora_config) elif is_hf_transformers_available(): from transformers.pytorch_utils import Conv1D if isinstance(target, Conv1D): new_module = LoraLinear(target, lora_config, is_conv_1d_layer=True) - # check NxD model - if model_parallel_is_initialized(): - if isinstance(target, (ColumnParallelLinear, RowParallelLinear)): - new_module = LoraParallelLinear(base_layer=target, lora_config=lora_config) - if new_module is None: # no module could be matched raise ValueError( @@ -344,18 +349,19 @@ def _create_new_module(self, target): transformers.pytorch_utils.Conv1D, nxd.parallel_layers.ColumnParallelLinear, nxd.parallel_layers.RowParallelLinear, + nxd.modules.qkv_linear.GQAQKVColumnParallelLinear, """ ) return new_module - def merge_lora(self): + def merge_lora(self) -> None: if not self.is_lora_merged: for module in self.module.modules(): if isinstance(module, LoraLayer): module.merge() self.is_lora_merged = True - def unmerge_lora(self): + def unmerge_lora(self) -> None: if self.is_lora_merged: for module in self.module.modules(): if isinstance(module, LoraLayer): @@ -368,14 +374,14 @@ def get_base_model(self) -> torch.nn.Module: """ return self.module - def _restore_module_name(self, key: str): + def _restore_module_name(self, key: str) -> str: key_word = ".base_layer" if key_word in key: return key.replace(key_word, "") else: return key - def _get_lora_adapter_state_dict(self, save_dir: str = None): + def _get_lora_adapter_state_dict(self, save_dir: Optional[str] = None) -> Dict[str, Any]: """ Return the state dict of the LoRA model and the modules specified by modules_to_save. 
There are three cases: @@ -449,7 +455,7 @@ def _save_config_to_json(self, filename: str = None) -> None: with open(filename, "w") as writer: writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) - def save_config(self, save_dir: str = None): + def save_config(self, save_dir: Optional[str] = None) -> None: if save_dir is None: save_dir = self.lora_config.lora_save_dir @@ -458,7 +464,7 @@ def save_config(self, save_dir: str = None): self._save_config_to_json(config_filename) self.is_config_saved = True - def save_lora(self, save_dir: str = None, adapter_tag: str = None) -> None: + def save_lora(self, save_dir: Optional[str] = None, adapter_tag: Optional[str] = None) -> None: r"""for single-device LoRA saving only.""" if not model_parallel_is_initialized(): self._save_single_device_lora(save_dir, adapter_tag) @@ -467,8 +473,8 @@ def save_lora(self, save_dir: str = None, adapter_tag: str = None) -> None: def _save_single_device_lora( self, - save_dir: str = None, - adapter_tag: str = None, + save_dir: Optional[str] = None, + adapter_tag: Optional[str] = None, ) -> None: r""" Only the master device saves the checkpoint. @@ -515,7 +521,7 @@ def load_checkpoint(self, lora_config: LoraConfig) -> None: self.lora_ckpt = ckpt else: config_filename = os.path.join(save_dir, CONFIG_NAME) - logger.info(f"LoRA configuration is not save in the checkpoint. Try to load it from {config_filename}") + logger.info("LoRA configuration is not save in the checkpoint. Try to load it from %s", config_filename) if not os.path.isfile(config_filename): raise FileNotFoundError(f"Please name the file for LoRA confiugration as {CONFIG_NAME}.") self.lora_config = self._load_config_from_json(lora_config, config_filename) @@ -539,7 +545,7 @@ def _from_json_file(path_json_file: str): lora_config_dict[key] = loaded_attributes[key] return LoraConfig(**lora_config_dict) - def _load_config_from_ckpt(self, lora_config: LoraConfig, ckpt) -> LoraConfig: + def _load_config_from_ckpt(self, lora_config: LoraConfig, ckpt: Dict[str, Any]) -> LoraConfig: config = ckpt.get("lora_config", None) if config is None: logger.warn("No LoRA configuration is found in checkpoint.") @@ -551,7 +557,7 @@ def _load_config_from_ckpt(self, lora_config: LoraConfig, ckpt) -> LoraConfig: return LoraConfig(**lora_config_dict) def load_lora( - self, save_dir: str = None, adapter_tag: str = None, ckpt_path: str = None, adapter_only: bool = True + self, save_dir: Optional[str] = None, adapter_tag: Optional[str] = None, ckpt_path: Optional[str] = None, adapter_only: bool = True ) -> None: r""" for single-device LoRA load only. 
@@ -563,7 +569,7 @@ def load_lora( raise RuntimeError("Please use nxd.load_checkpoint() to load LoRA adapter when the base model is NxDModel.") def _load_single_device_lora( - self, save_dir: str = None, adapter_tag: str = None, ckpt_path: str = None, adapter_only: bool = True + self, save_dir: Optional[str] = None, adapter_tag: Optional[str] = None, ckpt_path: Optional[str] = None, adapter_only: bool = True ): if not adapter_only: if ckpt_path is None: @@ -596,7 +602,7 @@ def load_lora_adapter(self): self.print_model_info() return load_result - def update_state_dict_keys(self, state_dict): + def update_state_dict_keys(self, state_dict: Dict[str, Any]) -> Dict[str, Any]: modules_keys = self.module_state_dict().keys() key_word = ".base_layer" @@ -614,13 +620,13 @@ def named_parameters(self, *args, **kwargs): for n, p in self.module.named_parameters(*args, **kwargs): yield n, p - def module_state_dict(self): + def module_state_dict(self) -> Dict[str, Any]: return self.module.state_dict() def state_dict(self, *args, **kwargs): return self._get_lora_adapter_state_dict() - def load_state_dict(self, state_dict=None, strict=True): + def load_state_dict(self, state_dict: Mapping[str, Any] = None, strict: bool = True, assign: bool = False): r""" There are two steps to load state dict for LoRA model. Step 1: load the state dict for the base model @@ -634,8 +640,7 @@ def load_state_dict(self, state_dict=None, strict=True): load_result = self.module.load_state_dict(state_dict, strict=False) self.is_base_model_loaded = True - if lora_config.load_lora_from_ckpt: - load_result = self.load_lora_adapter() + load_result = self.load_lora_adapter() if lora_config.load_lora_from_ckpt else None return load_result @@ -651,14 +656,14 @@ def dtype(self): def config(self): return self.original_module().lora_config - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: """Forward missing attributes to the wrapped module.""" try: return super().__getattr__(name) # defer to nn.Module's logic except AttributeError: return getattr(self.module, name) - def get_nb_trainable_parameters(self) -> tuple[int, int]: + def get_nb_trainable_parameters(self) -> Tuple[int, int]: r""" Returns the number of trainable parameters and the number of all parameters in the model. 
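A quick, hypothetical way to surface what get_nb_trainable_parameters() returns; the wrapper instance name and the message format below are illustrative, not the library's actual output:
```
trainable_params, all_params = lora_model.get_nb_trainable_parameters()  # hypothetical wrapper instance
print(
    f"trainable params: {trainable_params:,} || all params: {all_params:,} "
    f"|| trainable%: {100 * trainable_params / all_params:.4f}"
)
```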
""" @@ -684,6 +689,6 @@ def print_trainable_parameters(self) -> None: def print_model_info(self) -> None: if self.is_verbose_enabled: - logger.info(f"LoRA model: {self.module}") - logger.info(f"LoRA configuration: {self.lora_config}") + logger.info("LoRA model: %s", self.module) + logger.info("LoRA configuration: %s", self.lora_config) self.print_trainable_parameters() diff --git a/src/neuronx_distributed/modules/lora/tp_layer.py b/src/neuronx_distributed/modules/lora/tp_layer.py index 95d3f24..0843d75 100644 --- a/src/neuronx_distributed/modules/lora/tp_layer.py +++ b/src/neuronx_distributed/modules/lora/tp_layer.py @@ -1,11 +1,19 @@ import torch import torch.nn as nn +import warnings +import math from torch.nn.init import xavier_normal_ from neuronx_distributed.parallel_layers import ColumnParallelLinear, RowParallelLinear +from neuronx_distributed.modules.qkv_linear import ( + GQAQKVColumnParallelLinear, + gqa_qkv_linear_with_async_allreduce, + gather_from_tensor_model_parallel_region +) from .config import LoraConfig -from .layer import LoraLinear +from .layer import LoraLinear, LoraLayer +from typing import Any, Optional, Tuple class LoraParallelLinear(LoraLinear): @@ -20,11 +28,8 @@ def __init__(self, base_layer: nn.Module, lora_config: LoraConfig) -> None: super().__init__(base_layer, lora_config) def update_layer(self, lora_config, init_method=xavier_normal_): - input_is_parallel = ( - self.base_layer.input_is_parallel if isinstance(self.base_layer, RowParallelLinear) else True - ) - gather_output = self.base_layer.gather_output if isinstance(self.base_layer, ColumnParallelLinear) else False - + base_layer = self.get_base_layer() + sequence_parallel_enabled = base_layer.sequence_parallel_enabled lora_dropout = lora_config.lora_dropout self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() @@ -33,9 +38,10 @@ def update_layer(self, lora_config, init_method=xavier_normal_): input_size=self.in_features, output_size=self.lora_rank, bias=False, - input_is_parallel=input_is_parallel, + input_is_parallel=base_layer.input_is_parallel, skip_bias_add=True, init_method=init_method, + sequence_parallel_enabled = sequence_parallel_enabled, ) self.lora_B = nn.Linear( in_features=self.lora_rank, out_features=self.out_features, bias=False, dtype=torch.float32 @@ -48,20 +54,154 @@ def update_layer(self, lora_config, init_method=xavier_normal_): input_size=self.lora_rank, output_size=self.out_features, bias=False, - gather_output=gather_output, + gather_output=base_layer.gather_output, init_method=init_method, + sequence_parallel_enabled = sequence_parallel_enabled, ) - if lora_config.use_rslora: - self.scaling = self.lora_alpha / (self.lora_rank**0.5) - else: - self.scaling = self.lora_alpha / self.lora_rank + self.init_lora_parameters(lora_config.init_lora_weights) + + + +class LoraGQAQKVParallelLinear(LoraLayer): + r""" + When the target layer parallel_linear is GQAQKVColumnParallelLinear, in order to keep the input and output shapes + consistent, we perform column segmentation on lora_B, while lora_A is still a complete linear layer. 
+ """ + def __init__(self, base_layer: nn.Module, lora_config: LoraConfig) -> None: + super().__init__(base_layer, lora_config) + self.update_layer(lora_config) + + def update_layer(self, lora_config, init_method=xavier_normal_): + base_layer = self.get_base_layer() + lora_dropout = lora_config.lora_dropout + self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() + self.sequence_parallel_enabled = base_layer.sequence_parallel_enabled + self.kv_size_multiplier = base_layer.kv_size_multiplier + self.gather_output = base_layer.gather_output + self.lora_A = nn.Linear( + in_features=self.in_features, out_features=self.lora_rank, bias=False, dtype=torch.float32 + ) + self.lora_B = GQAQKVColumnParallelLinear( + input_size=self.lora_rank, + output_sizes=self.out_features, + bias=False, + gather_output=self.gather_output, + dtype=torch.float32, + init_method=init_method, + kv_size_multiplier=self.kv_size_multiplier, + sequence_parallel_enabled = self.sequence_parallel_enabled, + ) self.init_lora_parameters(lora_config.init_lora_weights) - weight = getattr(self.get_base_layer(), "weight", None) - if weight is not None: - # the layer is already completely initialized, this is an update - if weight.dtype.is_floating_point or weight.dtype.is_complex: - self.to(weight.device, dtype=weight.dtype) - else: - self.to(weight.device) + + + def init_lora_parameters(self, init_lora_weights): + init_lora_weights = init_lora_weights.lower() + assert init_lora_weights in ["default", "gaussian"] + + if init_lora_weights == "default": + # initialize A the same way as the default for nn.Linear and B to zero + # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 + nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5)) + elif init_lora_weights == "gaussian": + nn.init.normal_(self.lora_A.weight, std=1 / self.lora_rank) + else: + raise ValueError(f"Unknown LoRA parameters initialization with {init_lora_weights}") + + q, k, v = self.get_qkv(self.lora_B) + nn.init.zeros_(q.data) + nn.init.zeros_(k.data) + nn.init.zeros_(v.data) + + + def merge(self) -> None: + """ + Merge the adapter weights into the base weights + """ + weight_q, weight_k, weight_v = self.get_qkv(self.base_layer) + delta_weight_q, delta_weight_k, delta_weight_v = self.get_delta_weight() + + weight_q.data += delta_weight_q + weight_k.data += delta_weight_k + weight_v.data += delta_weight_v + self.merged = True + + + def unmerge(self) -> None: + """ + This method unmerges merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + + q, k, v = self.get_qkv(self.base_layer) + delta_weight_q, delta_weight_k, delta_weight_v = self.get_delta_weight() + + q.data -= delta_weight_q + k.data -= delta_weight_k + v.data -= delta_weight_v + self.merged = False + + + def get_qkv(self, layer): + return layer.weight_q, layer.weight_k, layer.weight_v + + + def get_delta_weight(self) -> torch.Tensor: + weight_A = self.lora_A.weight + q_lora_B, k_lora_B, v_lora_B = self.get_qkv(self.lora_B) + + output_q = (q_lora_B @ weight_A) * self.scaling + output_k = (k_lora_B @ weight_A) * self.scaling + output_v = (v_lora_B @ weight_A) * self.scaling + + return output_q, output_k, output_v + + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + previous_dtype = x.dtype + if self.merged: + output_q, output_k, output_v = self.base_layer(x, *args, **kwargs) + else: + output_q, output_k, output_v = self.base_layer(x, *args, **kwargs) + lora_A = self.lora_A + dropout = self.lora_dropout + scaling = self.scaling + x = x.to(lora_A.weight.dtype) + + q_lora_B, k_lora_B, v_lora_B = self.get_qkv(self.lora_B) + dropout_input = lora_A(dropout(x)) + lora_q_output, lora_k_output, lora_v_output = self._lora_forward(dropout_input, q_lora_B, k_lora_B, v_lora_B) + + output_q += lora_q_output * scaling + output_k += lora_k_output * scaling + output_v += lora_v_output * scaling + + return output_q.to(previous_dtype), output_k.to(previous_dtype), output_v.to(previous_dtype) + + + def _lora_forward(self, input, weight_q, weight_k, weight_v): + # Matrix multiply. + output_parallel_q, output_parallel_k, output_parallel_v = gqa_qkv_linear_with_async_allreduce( + input=input, + weight_q=weight_q, + weight_k=weight_k, + weight_v=weight_v, + bias_q=None, + bias_k=None, + bias_v=None, + async_grad_allreduce=not self.sequence_parallel_enabled, + sequence_parallel_enabled=self.sequence_parallel_enabled, + kv_size_multiplier=self.kv_size_multiplier, + ) + if self.gather_output: + # All-gather across the partitions. 
+ assert not self.sequence_parallel_enabled + output_q = gather_from_tensor_model_parallel_region(output_parallel_q) + output_k = gather_from_tensor_model_parallel_region(output_parallel_k) + output_v = gather_from_tensor_model_parallel_region(output_parallel_v) + else: + output_q, output_k, output_v = output_parallel_q, output_parallel_k, output_parallel_v + return output_q, output_k, output_v diff --git a/src/neuronx_distributed/modules/moe/__init__.py b/src/neuronx_distributed/modules/moe/__init__.py index f7c7d7a..8a3d369 100644 --- a/src/neuronx_distributed/modules/moe/__init__.py +++ b/src/neuronx_distributed/modules/moe/__init__.py @@ -1,5 +1,5 @@ -from .expert_mlps import ExpertMLPsCapacityFactor +from .expert_mlps import ExpertMLPs from .loss_function import load_balancing_loss_func from .model import MoE -from .model_utils import ACT2FN, MoESequenceParallelMode +from .model_utils import ACT2FN from .routing import RouterSinkhorn, RouterTopK diff --git a/src/neuronx_distributed/modules/moe/expert_mlps.py b/src/neuronx_distributed/modules/moe/expert_mlps.py index aa2dadb..eb64cf9 100644 --- a/src/neuronx_distributed/modules/moe/expert_mlps.py +++ b/src/neuronx_distributed/modules/moe/expert_mlps.py @@ -1,44 +1,40 @@ import math -from abc import ABC, abstractmethod -from typing import Union +from typing import Union, Optional, Callable, Any import torch import torch.nn.functional as F -from neuronx_distributed.modules.moe.model_utils import ACT2FN, MoESequenceParallelMode -from neuronx_distributed.modules.moe.moe_parallel_layers import ( - ExpertFusedColumnParallelLinear, - ExpertFusedRowParallelLinear, -) -from neuronx_distributed.parallel_layers import mappings +from neuronx_distributed.modules.moe.experts import Experts +from neuronx_distributed.modules.moe.model_utils import ACT2FN from neuronx_distributed.utils.tensor_utils import cumsum +from neuronx_distributed.parallel_layers.parallel_state import get_expert_model_parallel_size -class ExpertMLPsBase(torch.nn.Module, ABC): - """Base class for ExpertMLPs, which are used for obtaining the output from passing the token hidden states through the assigned expert(s). - - This class is used to set common initialization parameters, and define the function signature of the forward pass of child classes. +class ExpertMLPs(torch.nn.Module): + """Class which obtains the output from passing the token hidden states through the assigned expert(s). Arguments: num_experts: Total number of experts. + top_k: Number of experts activated per token. Should be less than or equal to num_experts. hidden_size: Hidden dimension. intermediate_size: Intermediate dimension used in the MLPs. hidden_act: Activation function. See ACT2FN for supported activations. - capacity_factor: Hyperparameter which controls the expert capacity, and determines the rate of token dropping. - init_method: Function used for initializing the gate and up projection linear layer weights. - output_layer_init_method:Function used for initializing the down projection linear layer weights. glu_mlp: Whether to use the Gated Linear Unit in the MLP. If True, then a combination of gate and up projection is performed in the MLP. Otherwise, a simple up projection is performed. - sequence_parallel_mode: SP mode being used for the MoE layer. - permute_strategy: Specifies how to perform the token permute and un-permute. Must be one of 'matmul' or 'index. - top_k: Number of experts activated per token. Should be less than or equal to num_experts. 
+ capacity_factor: Hyperparameter which controls the expert capacity, and determines the rate of token dropping. + If None, then assumed to be running with 'full capacity' (i.e. no tokens dropped). normalize_top_k_affinities: Whether to normalize the affinities of the chosen experts before combining with the MLP outputs. Should be used only with top_k > 1. return_bias: Whether to return the bias in the forward pass. Currently not supported. + init_method: Function used for initializing the gate and up projection linear layer weights. + output_layer_init_method: Function used for initializing the down projection linear layer weights. dtype: Datatype for the layer weights. device: Device for the layer weights. """ + # Used to determine when to use selective loading for token generation. See forward() for more details. + SELECTIVE_LOADING_THRESHOLD = 1.0 + def __init__( self, num_experts: int, @@ -47,9 +43,13 @@ def __init__( intermediate_size: int, hidden_act: str, glu_mlp: bool, - sequence_parallel_mode: Union[str, MoESequenceParallelMode], - dtype: torch.dtype, - device: torch.device, + capacity_factor: Union[None, float], + normalize_top_k_affinities: bool = False, + return_bias: bool = False, + init_method: Optional[Callable[..., Any]] = torch.nn.init.kaiming_uniform_, + output_layer_init_method: Optional[Callable[..., Any]] = torch.nn.init.kaiming_uniform_, + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), ): super().__init__() @@ -64,89 +64,10 @@ def __init__( self.act_fn = ACT2FN[hidden_act] self.glu_mlp = glu_mlp - if sequence_parallel_mode not in MoESequenceParallelMode.__members__: - raise TypeError(f"Unknown sequence_parallel_mode: {sequence_parallel_mode}") - self.sequence_parallel_mode = MoESequenceParallelMode[sequence_parallel_mode] - - self.dtype = dtype - self.device = device - - @abstractmethod - def forward(self, hidden_states, expert_affinities, expert_index): - """Forward pass of the ExpertMLPs. - - This function should internally account for whether the hidden_states are in SP or not, and return the output accordingly, - i.e. the output should be in SP iff the hidden_states are in SP. - - Common nomenclature: - S: Sequence length, B: Batch size, H: Hidden Size - S': Sequence length (when the input is in SP) - T: Tokens = S * B (token dimension obtained by flattening S and B) - T': Tokens (when the input is in SP) = S' * B - - Arguments: - hidden_states: Tensor of shape (S, B, H) or (S', B, H). - expert_affinities: Tensor of shape (T, E), containing the normalized affinities of each token for each expert. - expert_index: Tensor of shape (T, top_k), containing the 'chosen' experts for each token. - - Returns: - output: Output tensor of the same shape as hidden_states, obtained by passing each token through its assigned experts, - combined with the corresponding expert affinities. - """ - - -class ExpertMLPsCapacityFactor(ExpertMLPsBase): - """ExpertMLPs where each expert has a fixed 'expert capacity', i.e. maximum number of tokens that it can process. - This is necessary for maintaining static shapes in the compilation graph, but may lead to dropped tokens in the computation. - - Arguments: - capacity_factor: Hyperparameter which controls the expert capacity, and determines the rate of token dropping. - init_method: Function used for initializing the gate and up projection linear layer weights. - output_layer_init_method:Function used for initializing the down projection linear layer weights. 
- permute_strategy: Specifies how to perform the token permute and un-permute. Must be one of 'matmul' or 'index. - normalize_top_k_affinities: Whether to normalize the affinities of the chosen experts before combining with the MLP outputs. - Should be used only with top_k > 1. - return_bias: Whether to return the bias in the forward pass. Currently not supported. - """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - capacity_factor: float, - glu_mlp: bool, - sequence_parallel_mode: Union[str, MoESequenceParallelMode], - permute_strategy: str, - normalize_top_k_affinities: bool = False, - return_bias: bool = False, - init_method: torch.nn.init = torch.nn.init.kaiming_uniform_, - output_layer_init_method: torch.nn.init = torch.nn.init.kaiming_uniform_, - dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu"), - ): - super().__init__( - num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - hidden_act=hidden_act, - glu_mlp=glu_mlp, - sequence_parallel_mode=sequence_parallel_mode, - dtype=dtype, - device=device, - ) - + if capacity_factor is None or capacity_factor >= num_experts / top_k: + capacity_factor = None # Denotes full capacity self.capacity_factor = capacity_factor - if permute_strategy not in {"matmul", "index"}: - raise ValueError(f"Unknown permute_strategy: {permute_strategy}") - if self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL and permute_strategy != "matmul": - raise ValueError("SP mode OPTIMIZED_SP_MATMUL can only be used with the 'matmul' permute strategy") - self.permute_strategy = permute_strategy - if normalize_top_k_affinities and top_k == 1: raise ValueError("top_k must be greater than 1 for normalizing top-k expert affinities") self.normalize_top_k_affinities = normalize_top_k_affinities @@ -155,93 +76,30 @@ def __init__( raise NotImplementedError("bias is currently unsupported for MoE") self.return_bias = return_bias - # Define the layers for expert MLP operations - if self.glu_mlp: - # Combine the gate and up projections into a single large tensor multiplication for efficiency - self.gate_up_proj = ExpertFusedColumnParallelLinear( - num_experts=num_experts, - input_size=hidden_size, - output_size=2 * intermediate_size, - gather_output=False, - dtype=dtype, - device=device, - stride=2, - init_method=init_method, - ) - else: - self.up_proj = ExpertFusedColumnParallelLinear( - num_experts=num_experts, - input_size=hidden_size, - output_size=intermediate_size, - gather_output=False, - dtype=dtype, - device=device, - init_method=init_method, - ) - - self.down_proj = ExpertFusedRowParallelLinear( + self.mlp_op = Experts( num_experts=num_experts, - input_size=intermediate_size, - output_size=hidden_size, - input_is_parallel=True, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + glu=glu_mlp, + activation_fn=self.act_fn, dtype=dtype, device=device, init_method=output_layer_init_method, - sequence_parallel_mode=self.sequence_parallel_mode, ) - def mlp_op(self, expert_aligned_hidden_states): - """Helper function which performs the expert MLP computations for the given hidden states. - - Common nomenclature: - E: Total number of experts - C: Expert capacity - H: Hidden Size - - Arguments: - expert_aligned_hidden_states: Input tensor of shape (E, C, H) containing the token hidden states for each expert. 
- - Returns: - expert_aligned_output: Output tensor of shape (E, C, H) obtained after the gate/up projection + activation + down - projection operations. - """ - - if self.glu_mlp: - # gate_up_proj_op: (E, C, H) @ (E, H, 2I) -> (E, C, 2I) - gate_up_proj_op = self.gate_up_proj(expert_aligned_hidden_states) - # Split into gate_proj and up_proj, both (E, C, I) - gate_proj_op, up_proj_op = torch.tensor_split(gate_up_proj_op, 2, dim=2) - # intermediate_op: (E, C, I) - intermediate_op = self.act_fn(gate_proj_op) * up_proj_op - else: - # up_proj_op: (E, C, H) @ (E, H, I) -> (E, C, I) - up_proj_op = self.up_proj(expert_aligned_hidden_states) - # intermediate_op: (E, C, I) - intermediate_op = self.act_fn(up_proj_op) - - # down projection: (E, C, I) @ (E, I, H) -> (E, C, H) - expert_aligned_output = self.down_proj(intermediate_op) - return expert_aligned_output + self.dtype = dtype + self.device = device - def compute_position_in_expert(self, expert_index, total_tokens): - """Helper function used for computing the expert capacity, expert mask and position in expert, - corresponding to the input expert_index. + def get_expert_mask(self, expert_index): + """Helper function which computes top_k-hot encoded expert_mask from the given expert_index. Arguments: expert_index: Tensor of shape (T, top_k), containing the 'chosen' experts for each token. - total_tokens: Integer specifying the number of input tokens to the forward() function - Returns: - expert_capacity: Integer indicating the capacity of each expert - expert_mask: top_k-hot tensor of shape (T, E), computed using expert_index - position_in_expert: Tensor of shape (T, E), specifying the position of a given token within each chosen expert. + expert_mask: Tensor of shape (T, E), containing top_k-hot encoded experts for each token derived from + expert_index. """ - # compute expert capacity C = (total_tokens * top_k * Cf) / E - expert_capacity = math.ceil(total_tokens * self.top_k * self.capacity_factor / self.num_experts) - # expert_capacity can be upper bounded by total number of tokens, for the case when every token is routed to an expert - expert_capacity = min(expert_capacity, total_tokens) - # Construct expert_mask from expert_index, using efficient version of one-hot encoding for xla device # Perform operation in float64 to prevent precision issues due to auto-downcasting to bf16 # (Use float dtype to perform computations in the vector engine for efficiency) @@ -253,179 +111,90 @@ def compute_position_in_expert(self, expert_index, total_tokens): for e in range(self.top_k): expert_mask += (expert_index[:, e].unsqueeze(1) == expert_num_idx_arr).to(torch.float64) - # Compute the position of each token in experts, by a cumulative sum over the T dimension - # position in expert: (T, E) - position_in_expert = cumsum(expert_mask) + return expert_mask - # Update expert_mask by accounting for capacity factor (i.e. tokens exceeding capacity are dropped) - expert_mask.masked_fill_(torch.gt(position_in_expert, expert_capacity), 0) + def get_expert_affinities_masked(self, expert_affinities, expert_mask): + """Helper function which computes the masked expert_affinities by selecting the chosen experts for each token, + and normalizes the affinities if needed. - # Mask out those positions which exceed capacity - position_in_expert.masked_fill_(torch.eq(expert_mask, 0), 0) + Arguments: + expert_affinities: Tensor of shape (T, E), containing the normalized affinities of each token for each expert. 
+ expert_mask: Tensor of shape (T, E), containing top_k-hot encoded experts for each token derived from + expert_index. + Returns: + expert_affinities_masked: Tensor of shape (T, E) containing the affinities of just the chosen experts for + each token (after normalization if required). + """ - return expert_capacity, expert_mask, position_in_expert + # Apply expert_mask obtain the affinities for the chosen experts + # expert_affinities_masked -> (T, E) + expert_affinities_masked = expert_affinities.masked_fill(torch.eq(expert_mask, 0), 0) + if self.normalize_top_k_affinities: + # Normalize the affinities across the chosen experts + expert_affinities_masked = F.normalize(expert_affinities_masked, p=1.0, dim=1) - def forward(self, hidden_states, expert_affinities, expert_index): - """Lightweight wrapper which directs the computation to the forward function for the required permute_strategy. + return expert_affinities_masked - Common nomenclature: - S: Sequence length, B: Batch size, H: Hidden Size - S': Sequence length (when the input is in SP) - T: Tokens = S * B (token dimension obtained by flattening S and B) - E: Total number of experts - C: Expert capacity + def forward_all_experts(self, hidden_states, expert_affinities, expert_index): + """Forward pass where all tokens are computed by all experts. + This is equivalent to running forward_capacity_factor with full capacity (i.e. no token dropping), but + by avoiding the permute/unpermute overhead. """ - # hidden_states: (S, B, H) in training, (B, S, H) in inference - # expert_affinities: (T, E) - # expert_index: (T, top_k) - - # In token generation mode if running inference with seq_len = 1 - is_token_gen = (not self.training) and (hidden_states.shape[1] == 1) - - if is_token_gen or self.capacity_factor >= self.num_experts / self.top_k: - # Token generation or Training/Context encoding with full capacity (no tokens dropped) - return self.forward_full_capacity(hidden_states, expert_affinities, expert_index) - elif self.permute_strategy == "matmul": - return self.forward_permute_matmul(hidden_states, expert_affinities, expert_index) - else: - return self.forward_permute_index(hidden_states, expert_affinities, expert_index) + if get_expert_model_parallel_size() > 1: + raise NotImplementedError("Expert parallelism is not supported without capacity factor.") - def forward_full_capacity(self, hidden_states, expert_affinities, expert_index): - """Forward pass where all tokens are computed by all experts. - This is equivalent to running 'matmul' or 'index' with full capacity (i.e. no token dropping), but - by avoiding the permute/unpermute overhead. 
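A tiny worked example of the mask and affinity-masking helpers above, with invented values (T=3 tokens, E=4 experts, top_k=2):
```
import torch
import torch.nn.functional as F

expert_index = torch.tensor([[0, 2], [1, 2], [3, 0]])       # (T, top_k) chosen experts
expert_ids = torch.arange(4, dtype=torch.float64)           # E = 4

expert_mask = torch.zeros(3, 4, dtype=torch.float64)        # (T, E), top_k-hot
for e in range(expert_index.shape[1]):
    expert_mask += (expert_index[:, e].unsqueeze(1) == expert_ids).to(torch.float64)
# rows: [1,0,1,0], [0,1,1,0], [1,0,0,1]

expert_affinities = torch.tensor([[0.5, 0.1, 0.3, 0.1],
                                  [0.2, 0.4, 0.3, 0.1],
                                  [0.3, 0.2, 0.1, 0.4]])
masked = expert_affinities.masked_fill(expert_mask.eq(0), 0.0)   # keep only chosen experts
normalized = F.normalize(masked, p=1.0, dim=1)                   # each row re-sums to 1 (top_k > 1)
```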
- """ - - hidden_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_shape[-1]) # (T, H) + # expert_mask: (T, E) + expert_mask = self.get_expert_mask(expert_index) + # expert_affinities_masked: (T, E) + expert_affinities_masked = self.get_expert_affinities_masked(expert_affinities, expert_mask) # Pass all tokens through all experts - # gate_up_proj: 1TH @ EHI -> ETI - # down_proj: ETI @ EIH -> ETH + # gate_up_proj: (1, T, H) @ (E, H, I) -> (E, T, I) + # down_proj: (E, T, I) @ (E, I, H) -> (E, T, H) mlp_output = self.mlp_op(hidden_states.unsqueeze(0)) - # expert_mask: (T, E) (top_k-hot encoded) - expert_mask = torch.zeros( - expert_index.shape[0], self.num_experts, device=expert_index.device, dtype=torch.float64 - ) - expert_num_idx_arr = torch.arange(self.num_experts, device=expert_index.device, dtype=torch.float64) - for e in range(self.top_k): - expert_mask += (expert_index[:, e].unsqueeze(1) == expert_num_idx_arr).to(torch.float64) - - # expert_affinities_masked: (T, E) - expert_affinities_masked = expert_affinities.to(dtype=hidden_states.dtype).masked_fill( - torch.eq(expert_mask, 0), 0 - ) - if self.normalize_top_k_affinities: - # Normalize the affinities across the chosen experts - expert_affinities_masked = F.normalize(expert_affinities_masked, p=1.0, dim=1) - + # TODO: Modify to use multiplication + torch.sum instead # Scale by output affinity - output = torch.zeros(hidden_states.shape[0], hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype) + output = torch.zeros( + hidden_states.shape[0], hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype + ) for e in range(self.num_experts): # TH * T1 -> TH output += mlp_output[e] * expert_affinities_masked[:, e].unsqueeze(1) - # Reshape output to original hidden_shape - output = output.view(hidden_shape) return output - def forward_permute_matmul(self, hidden_states, expert_affinities, expert_index): - """Forward pass of the 'matmul' permute strategy, which uses matrix-multiplication to permute and un-permute the tokens.""" - - hidden_shape = hidden_states.shape - seq_len = hidden_shape[0] - hidden_states = hidden_states.view(-1, hidden_shape[-1]) # (T, H) or (T', H) - - # Due to different SP setup between training and inference, we have to implement SP for them differently here. - # For inference, we only perform sequence parallelism when it is context encoding, because we partition on - # the sequence dimension which can't be partitioned during token generation. - is_context_encoding = not self.training and seq_len > 1 - - if self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - if self.training or is_context_encoding: - # hidden_states: (T', H) - total_tokens = hidden_states.shape[0] * mappings.get_tensor_model_parallel_size() - else: - # hidden_states: (T, H) - total_tokens = hidden_states.shape[0] - else: - # hidden_states: (T, H) - total_tokens = hidden_states.shape[0] - - assert total_tokens == expert_affinities.shape[0] - assert total_tokens == expert_index.shape[0] + def forward_capacity_factor(self, hidden_states, expert_affinities, expert_index): + """Forward pass for performing Expert MLP computations, where each expert has a fixed 'expert capacity', + i.e. maximum number of tokens that it can process. This is necessary for maintaining static shapes in the + compilation graph, but may lead to dropped tokens in the computation. 
- # Compute expert_capacity, expert_mask and position_in_expert - expert_capacity, expert_mask, position_in_expert = self.compute_position_in_expert(expert_index, total_tokens) - - if self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - if self.training or is_context_encoding: - # Obtain the position_in_expert corresponding to just the tokens at this rank - # position_in_expert: (T, E) -> (T', E) - position_in_expert = mappings._split_along_first_dim(position_in_expert) - - # position_mask: one-hot encode position_in_expert (T, E) into (T, E, C) - # Perform operation in float64 to prevent precision issues due to auto-downcasting to bf16 - # (Use float dtype to perform computations in the vector engine for efficiency) - expert_capacity_idx_arr = torch.arange(expert_capacity + 1, device=hidden_states.device, dtype=torch.float64) - position_mask = (position_in_expert.unsqueeze(2) == expert_capacity_idx_arr).to(hidden_states.dtype) - # Account for 1-indexing of position_in_expert - position_mask = position_mask[:, :, 1:] - - # expert_aligned_hidden_states: (T, E, C) @ (T, H) -> (E, C, H) or (T', E, C) @ (T', H) -> (E, C, H) - expert_aligned_hidden_states = torch.einsum("tec,th->ech", position_mask, hidden_states) - - if self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - if self.training or is_context_encoding: - # All-reduce across ranks, since expert_aligned_hidden_states was computed in SP (i.e. using the T' tokens at each rank) - expert_aligned_hidden_states = mappings.reduce_from_tensor_model_parallel_region( - expert_aligned_hidden_states - ) - - # Perform MLP operations - # expert_aligned_output: (E, C, H) - expert_aligned_output = self.mlp_op(expert_aligned_hidden_states) - - # Apply expert_mask obtain the affinities for the chosen experts - # expert_affinities_masked -> (T, E) - expert_affinities_masked = expert_affinities.to(dtype=hidden_states.dtype).masked_fill( - torch.eq(expert_mask, 0), 0 - ) - if self.normalize_top_k_affinities: - # Normalize the affinities across the chosen experts - expert_affinities_masked = F.normalize(expert_affinities_masked, p=1.0, dim=1) - - if self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - if self.training or is_context_encoding: - # Obtain the expert affinities corresponding to just the tokens at this rank - # expert_affinities_masked: (T, E) -> (T', E) - expert_affinities_masked = mappings.scatter_to_sequence_parallel_region(expert_affinities_masked) - # Since the einsum operation below (with position_mask_with_affinities) is computed in SP (i.e. using T'), - # we need an all-reduce in the backward pass to obtain the complete gradients for expert_aligned_output. - expert_aligned_output = mappings.copy_to_tensor_model_parallel_region(expert_aligned_output) + Expert capacity C is defined as: + C = min(total_tokens, (total_tokens * top_k * capacity_factor) / num_experts) + Note that when capacity_factor >= num_experts / top_k, C = total_tokens (i.e. each expert can hold all + input tokens, and therefore no tokens are dropped). 
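Plugging illustrative numbers into the capacity formula above:
```
import math

total_tokens, top_k, num_experts = 256, 2, 8

for capacity_factor in (1.0, 2.0, 4.0):   # 4.0 == num_experts / top_k, i.e. full capacity
    C = min(math.ceil(total_tokens * top_k * capacity_factor / num_experts), total_tokens)
    print(capacity_factor, C)
# 1.0 -> 64 slots per expert, 2.0 -> 128, 4.0 -> 256 (no token can be dropped)
```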
+ """ - # position_mask_with_affinities: (T, E, C) * (T, E, 1) -> (T, E, C) or (T', E, C) * (T', E, 1) -> (T', E, C) - position_mask_with_affinities = position_mask * expert_affinities_masked.unsqueeze(-1) + total_tokens = hidden_states.shape[0] - # Unpermute and scale output with expert affinities - # output: (T, E, C) @ (E, C, H) -> (T, H) or (T', E, C) @ (E, C, H) -> (T', H) - output = torch.einsum("tec,ech->th", position_mask_with_affinities, expert_aligned_output) + # compute expert capacity C = (total_tokens * top_k * Cf) / E + expert_capacity = math.ceil(total_tokens * self.top_k * self.capacity_factor / self.num_experts) + # expert_capacity can be upper bounded by total number of tokens, for the case when every token is routed to an expert + expert_capacity = min(expert_capacity, total_tokens) - # Reshape output to original hidden_shape - output = output.view(hidden_shape) - return output + # expert_mask: (T, E) + expert_mask = self.get_expert_mask(expert_index) - def forward_permute_index(self, hidden_states, expert_affinities, expert_index): - """Forward pass of the 'index' permute strategy, which uses indexing-based lookups to permute and un-permute the tokens.""" + # Compute the position of each token in experts, by a cumulative sum over the T dimension + # position in expert: (T, E) + position_in_expert = cumsum(expert_mask) - hidden_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_shape[-1]) # (T, H) - total_tokens = hidden_states.shape[0] + # Update expert_mask by accounting for capacity factor (i.e. tokens exceeding capacity are dropped) + expert_mask.masked_fill_(torch.gt(position_in_expert, expert_capacity), 0) - # Compute expert_capacity, expert_mask and position_in_expert - expert_capacity, expert_mask, position_in_expert = self.compute_position_in_expert(expert_index, total_tokens) + # expert_affinities_masked: (T, E) + expert_affinities_masked = self.get_expert_affinities_masked(expert_affinities, expert_mask) # Add expert offset to the position_in_expert # Perform operation in float64 to prevent precision issues due to auto-downcasting to bf16 @@ -435,6 +204,7 @@ def forward_permute_index(self, hidden_states, expert_affinities, expert_index): ) # position_in_expert_with_offset: (T, E) position_in_expert_with_offset = position_in_expert + expert_index_offsets + # Mask out those positions which exceed capacity position_in_expert_with_offset.masked_fill_(torch.eq(expert_mask, 0), 0) # Determine the index (with offset) of each token @@ -476,19 +246,13 @@ def forward_permute_index(self, hidden_states, expert_affinities, expert_index): expert_aligned_output = self.mlp_op(expert_aligned_hidden_states) # convert back (E, C, H) into (C*E, H) - permuted_output = expert_aligned_output.view(expert_capacity * self.num_experts, hidden_shape[2]) - - # Apply expert_mask obtain the affinities for the chosen experts - # expert_affinities_masked -> (T, E) - expert_affinities_masked = expert_affinities.to(dtype=hidden_states.dtype).masked_fill( - torch.eq(expert_mask, 0), 0 - ) - if self.normalize_top_k_affinities: - # Normalize the affinities across the chosen experts - expert_affinities_masked = F.normalize(expert_affinities_masked, p=1.0, dim=1) + permuted_output = expert_aligned_output.view(expert_capacity * self.num_experts, -1) + # TODO: Modify to use multiplication + torch.sum instead # output: (T, H) - output = torch.zeros(total_tokens, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype) + output = torch.zeros( + 
total_tokens, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype + ) for k in range(self.top_k): # Unpermute output from the kth chosen expert for each token using token_permutation_idx output_k = permuted_output[token_permutation_idx[:, k]] @@ -497,6 +261,96 @@ def forward_permute_index(self, hidden_states, expert_affinities, expert_index): # (T, H) * (T, 1) output += output_k * expert_affinities_k - # Reshape output to original hidden_shape - output = output.view(hidden_shape) return output + + def forward_selective_loading(self, hidden_states, expert_affinities, expert_index): + """Forward pass which selectively loads only the experts chosen for each input token, during token generation.""" + + # hidden_states: (T, H) + # expert_affinities: (T, E) + # expert_index: (T, top_k) + + T = hidden_states.shape[0] + + # chosen_expert_affinities: (T, top_k) + chosen_expert_affinities = expert_affinities[ + torch.arange(T, device=hidden_states.device).unsqueeze(1), expert_index + ] + if self.normalize_top_k_affinities: + # Normalize the affinities across the chosen experts + chosen_expert_affinities = F.normalize(chosen_expert_affinities, p=1.0, dim=1) + + output_list = [] + for t in range(T): + # gate_up_proj: (1, 1, H) @ (top_k, H, I) -> (top_k, 1, I) + # down_proj: (top_k, 1, I) @ (top_k, I, H) -> (top_k, 1, H) + mlp_output_t = self.mlp_op(hidden_states[t].unsqueeze(0).unsqueeze(1), expert_indices=expert_index[t]) + # output_t: sum((top_k, H) * (top_k, 1), dim=0) -> H + output_t = torch.sum(mlp_output_t.squeeze(1) * chosen_expert_affinities[t].unsqueeze(1), dim=0) + output_list.append(output_t) + + # output: (T, H) + output = torch.stack(output_list, dim=0) + + return output + + def forward(self, hidden_states, expert_affinities, expert_index, seq_len): + """Forward pass of the ExpertMLPs. + + For training: + 1. If capacity_factor is None (full capacity), run forward_all_experts(). + 2. Else run forward_capacity_factor(). + + For inference: + 1. If context encoding: + a. If capacity_factor is None (full capacity), run forward_all_experts(). + b. Else run forward_capacity_factor(). + 2. Else (token generation): + Run forward_selective_loading() or forward_all_experts() depending on the following logic. + Let T be the total number of tokens. Using selective loading, T*top_k experts will be loaded. + If (T*top_k/num_experts) is less than SELECTIVE_LOADING_THRESHOLD, then we use selective loading. + Otherwise, we use forward_all_experts (for better performance). + + Note on the SELECTIVE_LOADING_THRESHOLD: + This parameter determines when forward_selective_loading is used for token-gen (in favor of + forward_all_experts), and should be a float <= 1. + + Common nomenclature: + S: Sequence length, B: Batch size, H: Hidden Size + T: Tokens = S * B (token dimension obtained by flattening S and B) + + Arguments: + hidden_states: Tensor of shape (T, H). + expert_affinities: Tensor of shape (T, E), containing the normalized affinities of each token for each expert. + expert_index: Tensor of shape (T, top_k), containing the 'chosen' experts for each token. + seq_len: Sequence length S. Used to infer context encoding vs token generation in inference. + + Returns: + output: Output tensor of the same shape as hidden_states, obtained by passing each token through its assigned experts, + combined with the corresponding expert affinities. 
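The token-generation branch described above reduces to a ratio test against SELECTIVE_LOADING_THRESHOLD (1.0); a standalone restatement of that decision:
```
def uses_selective_loading(num_tokens: int, top_k: int, num_experts: int, threshold: float = 1.0) -> bool:
    # Selective loading touches num_tokens * top_k expert copies; it is preferred only
    # while that is cheaper than materializing all num_experts experts.
    return (num_tokens * top_k / num_experts) < threshold

assert uses_selective_loading(num_tokens=1, top_k=2, num_experts=8)       # 0.25 < 1.0 -> selective
assert not uses_selective_loading(num_tokens=4, top_k=2, num_experts=8)   # 1.0 >= 1.0 -> all experts
```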
+ """ + + if self.training: + # Training flow + if self.capacity_factor is None: + return self.forward_all_experts(hidden_states, expert_affinities, expert_index) + else: + return self.forward_capacity_factor(hidden_states, expert_affinities, expert_index) + else: + # Inference flow + if seq_len > 1: + # Context encoding + if self.capacity_factor is None: + return self.forward_all_experts(hidden_states, expert_affinities, expert_index) + else: + return self.forward_capacity_factor(hidden_states, expert_affinities, expert_index) + else: + if get_expert_model_parallel_size() > 1: + raise NotImplementedError("Expert parallelism is not supported in token generation.") + + # Token generation + perc_experts_loaded = hidden_states.shape[0] * self.top_k / self.num_experts + if perc_experts_loaded >= self.SELECTIVE_LOADING_THRESHOLD: + return self.forward_all_experts(hidden_states, expert_affinities, expert_index) + else: + return self.forward_selective_loading(hidden_states, expert_affinities, expert_index) diff --git a/src/neuronx_distributed/modules/moe/experts.py b/src/neuronx_distributed/modules/moe/experts.py new file mode 100644 index 0000000..ff88f8d --- /dev/null +++ b/src/neuronx_distributed/modules/moe/experts.py @@ -0,0 +1,163 @@ +from typing import Callable + +import torch +from torch import Tensor +from torch.nn import Module + +from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedColumnParallelLinear, + ExpertFusedRowParallelLinear, +) +from neuronx_distributed.parallel_layers.mappings import ( + enter_expert_parallel_region, + exit_expert_parallel_region, + reduce_from_tensor_model_parallel_region, + reduce_scatter_to_tensor_model_parallel_region_with_dim, +) +from neuronx_distributed.parallel_layers.parallel_state import ( + get_expert_model_parallel_size, + get_tensor_model_parallel_size, +) + + +class Experts(Module): + """ Module which performs the expert MLP computations for the given hidden states. + Expert Parallelism (EP), if enabled, is applied through scatter-gather optimization + across TP ranks. + """ + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + glu: bool, + activation_fn: Callable[[Tensor], Tensor], + dtype: torch.dtype, + device: torch.device, + input_layer_init_method=None, + output_layer_init_method=None, + ) -> None: + super().__init__() + + self._glu = glu + self._activation_fn = activation_fn + + self.num_experts = num_experts + + if self._glu: + self.gate_up_proj = ExpertFusedColumnParallelLinear( + # we pass the global number of experts. the linear layer will itself + # decide to initialize a subset of them if EP is applied. + num_experts=num_experts, + input_size=hidden_size, + # we fuse up and gate projections to a single matmul. Later on in code + # we'll split the resulting output to yield up and gate matrices. + output_size=intermediate_size * 2, + dtype=dtype, + device=device, + stride=2, + init_method=input_layer_init_method, + ) + else: + self.up_proj = ExpertFusedColumnParallelLinear( + # we pass the global number of experts. the linear layer will itself + # decide to initialize a subset of them if EP is applied. + num_experts=num_experts, + input_size=hidden_size, + output_size=intermediate_size, + dtype=dtype, + device=device, + init_method=input_layer_init_method, + ) + + self.down_proj = ExpertFusedRowParallelLinear( + # we pass the global number of experts. the linear layer will itself + # decide to initialize a subset of them if EP is applied. 
+ num_experts=num_experts, + input_size=intermediate_size, + output_size=hidden_size, + reduce_output=get_expert_model_parallel_size() > 1, + dtype=dtype, + device=device, + init_method=output_layer_init_method, + ) + + def forward(self, hidden_states: Tensor, expert_indices: Tensor = None) -> Tensor: + """ + Common nomenclature: + E: Total number of experts, C: Expert capacity, H: Hidden Size + + If expert_indices is None, then the mlp_op is computed on all E experts. + If specified, then the mlp_op is computed only on the experts specified. + + Let num_experts_computed E' = E if expert_indices is None else expert_indices.shape[0] + + Arguments: + hidden_states: Input tensor containing the token hidden states. + Its shape must be broadcastable with (E', C, H). + expert_indices: (Optional) 1D Tensor containing the indices of experts to perform the mlp_op on. + Returns: + output: Output tensor of shape (E', C, H) obtained after the gate/up projection + + activation + down projection operations. + """ + + if expert_indices is not None and get_expert_model_parallel_size() > 1: + raise ValueError("Selective expert loading is not supported with expert parallelism.") + + # Verify shapes + assert len(hidden_states.shape) == 3 + num_experts_computed = self.num_experts if expert_indices is None else expert_indices.shape[0] + assert hidden_states.shape[0] in {1, num_experts_computed} + + e, c, h = hidden_states.shape + + # Apply scatter-gather optimization in EP only when the number of tokens + # are divisible by TP. Note that this will exclude the token-generation scenario. + # Also in training, performance will be much worse with EP if the local expert + # capacity is not divisible by TP degree. + # num_tokens_divisible_by_tp = c % get_tensor_model_parallel_size() == 0 + + if get_expert_model_parallel_size() > 1: + # (e, c, h) -> (e/ep, ep, c, h) + dispatched_hidden_states = enter_expert_parallel_region( + hidden_states, + # temporarily disable scatter_gather optimization + scatter_gather=False, + #scatter_gather=num_tokens_divisible_by_tp, + ) + else: + dispatched_hidden_states = hidden_states.view(e, 1, c, h) + + if self._glu: + # (e/ep, ep, c, 2i/tp) + intermediate_states = self.gate_up_proj.forward(dispatched_hidden_states, expert_indices=expert_indices) + else: + # (e/ep, ep, c, i/tp) + intermediate_states = self.up_proj.forward(dispatched_hidden_states, expert_indices=expert_indices) + + # (e/ep, ep, c, i/tp) + intermediate_states = self._activation(intermediate_states) + + # (e/ep, ep, c, h) + projected_states = self.down_proj.forward(intermediate_states, expert_indices=expert_indices) + + if get_expert_model_parallel_size() > 1: + # (e/ep, ep, c, h) -> (e, c, h) + output = exit_expert_parallel_region( + projected_states, + # temporarily disable scatter_gather optimization + scatter_gather=False, + #scatter_gather=num_tokens_divisible_by_tp, + ) + else: + output = projected_states.squeeze(1) + + return output + + def _activation(self, x: Tensor) -> Tensor: + if self._glu: + gate, up = torch.chunk(x, chunks=2, dim=-1) + return self._activation_fn(gate) * up + else: + return self._activation_fn(x) diff --git a/src/neuronx_distributed/modules/moe/model.py b/src/neuronx_distributed/modules/moe/model.py index b9c2a59..c597b66 100644 --- a/src/neuronx_distributed/modules/moe/model.py +++ b/src/neuronx_distributed/modules/moe/model.py @@ -1,9 +1,7 @@ import torch -from typing import Union from neuronx_distributed.modules.moe import expert_mlps, routing -from 
neuronx_distributed.modules.moe.model_utils import MoESequenceParallelMode -from neuronx_distributed.parallel_layers import mappings +from neuronx_distributed.parallel_layers import mappings, parallel_state class MoE(torch.nn.Module): @@ -21,54 +19,46 @@ class MoE(torch.nn.Module): normalize_top_k_affinities = True # Other configurations - capacity_factor = 4.0 # Full capacity with no token dropping, set to num_experts/top_k - sequence_parallel_mode = MoESequenceParallelMode.EXIT_SP_ON_ENTRY - permute_strategy = "matmul" + capacity_factor = None # Full capacity with no token dropping init_method = lambda weight: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) output_layer_init_method = lambda weight: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) + sequence_parallel_enabled = True # Initialize router router = routing.RouterTopK( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, - sequence_parallel_mode=sequence_parallel_mode, ) # Initialize expert_mlps - expert_mlps_cf = expert_mlps.ExpertMLPsCapacityFactor( + expert_mlps = expert_mlps.ExpertMLPs( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act, + glu_mlp=glu_mlp, capacity_factor=capacity_factor, + normalize_top_k_affinities=normalize_top_k_affinities, init_method=init_method, output_layer_init_method=init_method, - glu_mlp=glu_mlp, - sequence_parallel_mode=sequence_parallel_mode, - permute_strategy=permute_strategy, - normalize_top_k_affinities=normalize_top_k_affinities, ) # Initial moe_layer moe_layer = MoE( router=router, - expert_mlps=expert_mlps_cf, + expert_mlps=expert_mlps, return_router_logits=True, # Required downstream for the load balancing loss function - sequence_parallel_mode=sequence_parallel_mode, + sequence_parallel_enabled=sequence_parallel_enabled, ) ``` - Due to difference between training and inference on SP, SP is implementated differently for them in MoE. - Note that the NO_SP mode is equivalent to the EXIT_SP_ON_ENTRY mode under the inference assumptions. (There - are no additional collectives for EXIT_SP_ON_ENTRY). - Arguments: router: Determines expert routing for input tokens expert_mlps: Obtains the output of the MoE layer by passing tokens through the chosen experts + sequence_parallel_enabled: Whether the model is running in sequence parallel or not return_router_logits: Whether to return the router logits in the forward pass - sequence_parallel_mode: SP mode being used for the MoE layer. """ # Flag used in testing. Should not be used in production. 
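Continuing the docstring example above (not standalone: `moe_layer` and an initialized model-parallel state are assumed, and the sizes are illustrative), the forward pass returns a tuple whose tail depends on return_bias / return_router_logits:
```
import torch

S, B, H = 128, 4, 1024                  # illustrative sizes
hidden_states = torch.randn(S, B, H)    # (S', B, H) when sequence_parallel_enabled

output, router_logits = moe_layer(hidden_states)   # moe_layer built as in the example above
# output:        same shape as hidden_states
# router_logits: (S * B, num_experts), consumed downstream by load_balancing_loss_func
```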
@@ -77,24 +67,21 @@ class MoE(torch.nn.Module): def __init__( self, router: routing.RouterBase, - expert_mlps: expert_mlps.ExpertMLPsBase, - sequence_parallel_mode: Union[str, MoESequenceParallelMode], + expert_mlps: expert_mlps.ExpertMLPs, + sequence_parallel_enabled: bool = False, return_router_logits: bool = False, ): super().__init__() - if sequence_parallel_mode not in MoESequenceParallelMode.__members__: - raise TypeError(f"Unknown sequence_parallel_mode: {sequence_parallel_mode}") - if len({sequence_parallel_mode, router.sequence_parallel_mode, expert_mlps.sequence_parallel_mode}) > 1: - raise ValueError("Inconsistent SP modes across router, expert_mlps and MoE modules") for attr in ["num_experts", "top_k", "hidden_size"]: if getattr(router, attr) != getattr(expert_mlps, attr): raise ValueError("Inconsistent {attr} across the router and expert_mlps") self.router = router self.expert_mlps = expert_mlps - self.sequence_parallel_mode = MoESequenceParallelMode[sequence_parallel_mode] + self.sequence_parallel_enabled = sequence_parallel_enabled self.return_router_logits = return_router_logits + self.ep_enabled = parallel_state.get_expert_model_parallel_size() > 1 def forward(self, hidden_states): """Forward pass of the MoE layer. @@ -105,7 +92,7 @@ def forward(self, hidden_states): T: Tokens = S * B (token dimension obtained by flattening S and B) Arguments: - hidden_states: Input tensor of shape (S, B, H) or (S', B, H) + hidden_states: Input tensor of shape (S, B, H) or (S', B, H) in training, (B, S, H) in inference. Returns: output: Output tensor of the same shape as hidden_states, containing the output of the MoE layer. @@ -116,91 +103,58 @@ def forward(self, hidden_states): Returned if self.is_test is True. """ - # Sequence parallelism is supported for training, but not for inference, so we need different branches for them. - # However, we may still want to run the MoE module in a particular SP mode, which requires the collective - # operations to be adjusted (compared to the base class). 
- if self.training: - output, router_logits, expert_index = self.forward_for_training(hidden_states) - else: - output, router_logits, expert_index = self.forward_for_inference(hidden_states) + # hidden_states: (S, B, H) or (S', B, H) in training, (B, S, H) in inference - return_op = (output,) - if self.expert_mlps.return_bias: - return_op += (None,) - if self.return_router_logits: - return_op += (router_logits,) - if self.is_test: - return_op += (expert_index,) + if not self.training: + # Sequence parallelism is only supported for training + assert self.sequence_parallel_enabled is False, "SP is not currently supported for inference" - return return_op - - def forward_for_training(self, hidden_states): - if self.sequence_parallel_mode in { - MoESequenceParallelMode.EXIT_SP_ON_ENTRY, - MoESequenceParallelMode.EXIT_SP_ON_ENTRY_DELAY_MLP_AR, - }: + if self.sequence_parallel_enabled: # All-Gather the hidden_states to exit sequence parallel - # hidden_states: (S', B, H) -> (S, B, H) - hidden_states = mappings.gather_from_sequence_parallel_region(hidden_states, to_model_parallel=False) + # full_hidden_states: (S', B, H) -> (S, B, H) + full_hidden_states = mappings.gather_from_sequence_parallel_region(hidden_states, to_model_parallel=False) + else: + full_hidden_states = hidden_states + + # full_hidden_states: (S, B, H) in training, (B, S, H) in inference + full_hidden_states_shape = full_hidden_states.shape + seq_len = full_hidden_states_shape[0] if self.training else full_hidden_states_shape[1] + + # full_hidden_states: (T, H) + full_hidden_states = full_hidden_states.view(-1, full_hidden_states.shape[-1]) # Get the router_logits, expert_affinities and expert_index from the router # router_logits: (T, E), expert_affinities: (T, E), expert_index: (T, top_k) - router_logits, expert_affinities, expert_index = self.router(hidden_states) + router_logits, expert_affinities, expert_index = self.router(full_hidden_states) + + # Get the output from the ExpertMLPs: (T, H) + output = self.expert_mlps( + hidden_states=full_hidden_states, + expert_affinities=expert_affinities, + expert_index=expert_index, + seq_len=seq_len, + ) - # Get the output from the ExpertMLPs: (S, B, H) - output = self.expert_mlps(hidden_states, expert_affinities, expert_index) + # output: (S, B, H) in training, (B, S, H) in inference + output = output.view(full_hidden_states_shape) - if self.sequence_parallel_mode == MoESequenceParallelMode.EXIT_SP_ON_ENTRY: - # Scatter back to sequence parallel (as the hidden_states were in sequence parallel) - # output: (S, B, H) -> (S', B, H) + if self.sequence_parallel_enabled and self.ep_enabled: + # Reduction is done earlier in the case of EP output = mappings.scatter_to_sequence_parallel_region(output) - - if self.sequence_parallel_mode == MoESequenceParallelMode.EXIT_SP_ON_ENTRY_DELAY_MLP_AR: + elif self.sequence_parallel_enabled: # Reduce-scatter back to sequence parallel (as the hidden_states were in sequence parallel) # output: (S, B, H) -> (S', B, H) output = mappings.reduce_scatter_to_sequence_parallel_region(output) - - return output, router_logits, expert_index - - def forward_for_inference(self, hidden_states): - """ - The collective ops for inference differ from training because the rest of the model does not support - sequence parallelism in inference. Moreover, the input in inference is (B, S, H) instead of (S, B, H), - which leads to differences in scatter/gather operations for SP. - The implementation differences are summarized as follows: - 1. 
EXIT_SP_ON_ENTRY is equivalent to NO_SP because there are no additional collectives (input is not in SP). - 2. The reduce-scatter used in training for EXIT_SP_ON_ENTRY_DELAY_MLP_AR is modified to an all-reduce. - 3. In OPTIMIZED_SP_MATMUL, - a. We run the router on the entire sequence (avoiding the all-gather of router logits). - b. We scatter/gather the sequence dimension to enter/exit SP before/after ExpertMLPs. - c. Note that OPTIMIZED_SP_MATMUL is not an SPMD workload, and is therefore not supported currently for inference. - - We run in SP mode only for context encoding, and not for token generation (since sequence length is 1). - """ - - assert self.sequence_parallel_mode != MoESequenceParallelMode.OPTIMIZED_SP_MATMUL, "OPTIMIZED_SP_MATMUL is unsupported for inference" - - seq_len, _, _ = hidden_states.shape - is_context_encoding = seq_len > 1 - - # Get the router_logits, expert_affinities and expert_index from the router - # router_logits: (T, E), expert_affinities: (T, E), expert_index: (T, top_k) - router_logits, expert_affinities, expert_index = self.router(hidden_states) - - if is_context_encoding and self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - # Scatter the sequence dimension to enter SP - # hidden_states: (B, S, H) -> (B, S', H) - hidden_states = mappings.scatter_input_channels_to_tensor_model_parallel_region(hidden_states) - - # Get the output from the ExpertMLPs: (B, S, H) - output = self.expert_mlps(hidden_states, expert_affinities, expert_index) - - if self.sequence_parallel_mode == MoESequenceParallelMode.EXIT_SP_ON_ENTRY_DELAY_MLP_AR: - # Perform delayed all-reduce (required since the MLP is in TP) + elif not self.ep_enabled: + # output: (S, B, H) in training, (B, S, H) in inference output = mappings.reduce_from_tensor_model_parallel_region(output) - if is_context_encoding and self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - # output: (B, S', H) -> (B, S', H) - output = mappings.gather_from_tensor_model_parallel_region_second_dim(output) + return_op = (output,) + if self.expert_mlps.return_bias: + return_op += (None,) + if self.return_router_logits: + return_op += (router_logits,) + if self.is_test: + return_op += (expert_index,) - return output, router_logits, expert_index + return return_op diff --git a/src/neuronx_distributed/modules/moe/model_utils.py b/src/neuronx_distributed/modules/moe/model_utils.py index 8e0ec75..df26f17 100644 --- a/src/neuronx_distributed/modules/moe/model_utils.py +++ b/src/neuronx_distributed/modules/moe/model_utils.py @@ -1,4 +1,3 @@ -import enum import torch import torch.nn.functional as F @@ -10,19 +9,3 @@ "silu": F.silu, "tanh": torch.tanh, } - - -class MoESequenceParallelMode(str, enum.Enum): - """Defines the modes of sequence parallelism used by in MoE.""" - - # No sequence parallel - NO_SP = "NO_SP" - - # Exit SP on entry to MoE layer, scatter before exiting - EXIT_SP_ON_ENTRY = "EXIT_SP_ON_ENTRY" - - # Exit SP on entry to MoE layer, don't do the all-reduce in down_proj MLP, reduce-scatter before exiting - EXIT_SP_ON_ENTRY_DELAY_MLP_AR = "EXIT_SP_ON_ENTRY_DELAY_MLP_AR" - - # Use SP optimizations for matmul-based permute/unpermute - OPTIMIZED_SP_MATMUL = "OPTIMIZED_SP_MATMUL" diff --git a/src/neuronx_distributed/modules/moe/moe_parallel_layers.py b/src/neuronx_distributed/modules/moe/moe_parallel_layers.py index fecb098..456d304 100644 --- a/src/neuronx_distributed/modules/moe/moe_parallel_layers.py +++ b/src/neuronx_distributed/modules/moe/moe_parallel_layers.py @@ -1,12 
+1,13 @@ import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Callable, Any import torch +import torch.nn as nn import torch.distributed from torch import Tensor -from neuronx_distributed.modules.moe.model_utils import MoESequenceParallelMode from neuronx_distributed.parallel_layers import layers, mappings, parallel_state, utils +from neuronx_distributed.parallel_layers.parallel_state import get_expert_model_parallel_size class ExpertFusedLinearWithAsyncCommunication(torch.autograd.Function): @@ -14,6 +15,23 @@ class ExpertFusedLinearWithAsyncCommunication(torch.autograd.Function): Mixture of Experts. The implementation largely mimics LinearWithAsyncCommunication, but is modified for the 3D weights. + + notation used for shapes: + e: number of experts + h: input dim + i: output dim + + shapes: + * input/input.grad (e, ..., h) + * output/output.grad (e, ..., i) + * weight (e, i, h) + + NOTE: that we have inner dimensions denoted by '...', which can be an arbitrary + number of dimensions. In general, the product of inner dimensions can be + thought of as the number of tokens. + Sometimes with MoE workloads it is convenient to have tokens laid out in multiple + dimensions to facilitate tracking when they are partitioned using multiple + parallelism dimensions. """ @staticmethod @@ -24,47 +42,55 @@ def forward( bias: Optional[Tensor], async_grad_allreduce: bool, sequence_parallel_enabled: bool, - save_for_backwards: bool = True, + save_for_backward: bool = True, ): if bias is not None: raise NotImplementedError("Bias is not currently supported for MoE") if sequence_parallel_enabled: - raise NotImplementedError("Since ExpertMLPs is executed only in TP mode, SP is not implemented") + raise NotImplementedError( + "sequence parallelism (SP) is not currently supported for expert " + "fused linear layers. If SP is in use for the model, then we " + "currently expect SP to be exited before linear layers are applied." + ) + if input.shape[0] != weight.shape[0] and input.shape[0] > 1: + raise RuntimeError( + f"input and weight disagree on number of experts (first dimension). " + f"input_shape={tuple(input.shape)}, weight_shape={tuple(weight.shape)}" + ) ctx.async_grad_allreduce = async_grad_allreduce - ctx.sequence_parallel_enabled = sequence_parallel_enabled ctx.compute_weight_gradient = weight.requires_grad - # E: num_experts, C: expert_capacity, H: input_size, I: intermediate/output_size - # input: (E, C, H), weight: (E, H, I) - - if save_for_backwards: + if save_for_backward: if ctx.compute_weight_gradient: ctx.save_for_backward(input, weight) else: ctx.save_for_backward(weight) - # output: (E, C, I) - output = torch.matmul(input, weight) + # E: num_experts, H: input_size, I: intermediate/output_size + # ... 
might refer to 1 or more dimensions, including C dimension (expert capacity) + # input: (E, ..., H), weight: (E, H, I) + output = torch.einsum("e...h,ehi->e...i", input, weight) + + # output: (E, ..., I) return output @staticmethod - def backward(ctx, grad_output: Tensor): - # grad_output: (E, C, I) - - # input: (E, C, H), weight: (E, H, I) + def backward(ctx, grad_output: Tensor) -> Tuple[Tensor]: + # grad_output: (E, ..., I) + # input: (E, ..., H), weight: (E, H, I) if ctx.compute_weight_gradient: input, weight = ctx.saved_tensors else: weight = ctx.saved_tensors[0] input = None - # grad_input: (E, C, H) - grad_input = grad_output.matmul(weight.transpose(-1, -2)) + # grad_input: (E, ..., H) + grad_input = torch.einsum("e...i,ehi->e...h", grad_output, weight) if ctx.async_grad_allreduce: # Asynchronous all-reduce - handle = torch.distributed.all_reduce( + torch.distributed.all_reduce( grad_input, group=parallel_state.get_tensor_model_parallel_group(), ) @@ -74,16 +100,45 @@ def backward(ctx, grad_output: Tensor): return grad_input, None, None, None, None, None, None # grad_weight: (E, H, I) - grad_weight = torch.matmul(input.transpose(-1, -2), grad_output) + grad_weight = torch.einsum("e...h,e...i->ehi", input, grad_output) return grad_input, grad_weight, None, None, None, None, None -class ExpertFusedColumnParallelLinear(layers.ColumnParallelLinear): +class ExpertFusedLinear(nn.Module): + def _mark_expert_parallel_weights(self, iterable=None): + """ Register expert parallel parameters """ + + if get_expert_model_parallel_size() > 1: + if iterable is None: + iterable = self.parameters() + + for p in iterable: + p.expert_model_parallel = True + + def _apply(self, fn, *args, **kwargs): + """ Moving parameters from cpu to device creates new parameters. to() method + internally calls the _apply method for all the submodules, which we override + here to make sure ep parameters are marked on device as well """ + + out = super()._apply(fn, *args, **kwargs) + self._mark_expert_parallel_weights() + return out + + def _save_to_state_dict(self, destination, *args, **kwargs): + initial_states = {id(v) for v in destination.values()} + out = super()._save_to_state_dict(destination, *args, **kwargs) + new_states = [v for v in destination.values() if id(v) not in initial_states] + self._mark_expert_parallel_weights(new_states) + return out + + +class ExpertFusedColumnParallelLinear(layers.ColumnParallelLinear, ExpertFusedLinear): """Specialized linear layer for MoE, supporting column parallelism for all experts simultaneously. This class inherits from ColumnParallelLinear, and over-rides certain attributes and functions needed to enable - column-parallel linear layer computation for 3D weights. + column-parallel linear layer computation for 3D weights. The forward pass of the parent class is over-ridden + to to support selective computations on a subset of experts. Bias is not currently supported for MoE. Sequence parallelism is handled independently of MLP computations in MoE, and therefore defaults to False. 
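For reference, a minimal shape-level sketch of the fused expert matmul that `ExpertFusedLinearWithAsyncCommunication` implements, using the same einsum contractions as the forward and backward above. The sizes are illustrative and the snippet is not part of the diff.

```python
import torch

# Standalone shape check of the fused expert matmul (sizes are illustrative).
E, T, H, I = 4, 6, 8, 16                       # experts, tokens per expert, input dim, output dim
x = torch.randn(E, T, H, requires_grad=True)   # input:  (E, ..., H)
w = torch.randn(E, H, I, requires_grad=True)   # weight: (E, H, I)

# Forward: one independent linear layer per expert, batched over the expert dim.
out = torch.einsum("e...h,ehi->e...i", x, w)   # output: (E, ..., I)

# Backward: the same contractions used in the autograd function above.
g = torch.randn_like(out)
grad_x = torch.einsum("e...i,ehi->e...h", g, w)   # grad wrt input:  (E, ..., H)
grad_w = torch.einsum("e...h,e...i->ehi", x, g)   # grad wrt weight: (E, H, I)

out.backward(g)
assert torch.allclose(x.grad, grad_x, atol=1e-4)
assert torch.allclose(w.grad, grad_w, atol=1e-4)
```

Here the product of the inner `...` dimensions plays the role of the token count, matching the note in the docstring above.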
@@ -96,20 +151,20 @@ def __init__( num_experts: int, input_size: int, output_size: int, - gather_output: bool = True, dtype: torch.dtype = torch.float32, - device: torch.device = None, + device: Optional[torch.device] = None, stride: int = 1, - init_method: torch.nn.init = None, + init_method: Optional[Callable[..., Any]] = None, keep_master_weight: bool = False, - ): + ) -> None: self.num_experts = num_experts + self._n_local_experts = utils.divide(num_experts, parallel_state.get_expert_model_parallel_size()) super().__init__( input_size=input_size, output_size=output_size, bias=False, - gather_output=gather_output, + gather_output=False, dtype=dtype, device=device, stride=stride, @@ -118,30 +173,59 @@ def __init__( keep_master_weight=keep_master_weight, skip_bias_add=False, ) + self._mark_expert_parallel_weights() def set_weight_and_bias_config(self): # Define 3D weight tensor, one linear layer per expert - self.weight_shape = (self.num_experts, self.input_size, self.output_size_per_partition) + self.weight_shape = ( + self._n_local_experts, + self.input_size, + self.output_size_per_partition, + ) # Column parallel partitioning for each expert self.weight_partition_dim = 2 self.bias_shape = None def _init_weight(self, weight): # Initialize the linear layer of each expert separately - assert len(weight.shape) == 3 and weight.shape[0] == self.num_experts + assert len(weight.shape) == 3 for e in range(weight.shape[0]): if self.arg_init_method is None: torch.nn.init.kaiming_uniform_(weight[e], a=math.sqrt(5)) else: self.arg_init_method(weight[e]) + def forward( + self, input_: torch.Tensor, expert_indices: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """If expert_indices is provided, then the computations are performed only on the specified experts. + Otherwise, the input is passed through all experts in the layer.""" + + if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled: + input_parallel = input_ + else: + input_parallel = mappings.copy_to_tensor_model_parallel_region(input_) -class ExpertFusedRowParallelLinear(layers.RowParallelLinear): + # Matrix multiply. + weight = self.weight[expert_indices, :, :] if expert_indices is not None else self.weight + output = self._forward_impl( + input=input_parallel, + weight=weight, + bias=None, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel_enabled=self.sequence_parallel_enabled, + autograd_func_class=self.autograd_func_class, + ) + return output + + +class ExpertFusedRowParallelLinear(layers.RowParallelLinear, ExpertFusedLinear): """Specialized linear layer for MoE, supporting row parallelism for all experts simultaneously. This class inherits from RowParallelLinear, and over-rides certain attributes and functions needed to enable row-parallel linear layer computation for 3D weights. The forward pass of the parent class is over-ridden - to avoid the output all-reduce depending on the sequence parallel mode. + to optionally avoid the output all-reduce depending on the sequence parallel mode, and to support selective + computations on a subset of experts. Bias is not currently supported for MoE. Sequence parallelism is handled independently of MLP computations in MoE, and therefore defaults to False. 
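A small sketch (made-up sizes, not from the diff) of how the 3D expert weights are partitioned across expert parallelism and tensor parallelism, and how `expert_indices` selects a subset of the local experts:

```python
import torch

# Illustrative partitioning of the 3D expert weights (made-up sizes).
num_experts, ep_size, tp_size = 8, 2, 4
hidden, intermediate = 1024, 4096

n_local_experts = num_experts // ep_size          # experts owned by this EP rank
col_weight = torch.empty(n_local_experts, hidden, intermediate // tp_size)  # column parallel: (E_local, H, I/tp)
row_weight = torch.empty(n_local_experts, intermediate // tp_size, hidden)  # row parallel:    (E_local, I/tp, H)

# Selective expert computation: index only the (local) experts that received tokens.
expert_indices = torch.tensor([0, 2])             # hypothetical active local experts
active_weight = col_weight[expert_indices, :, :]  # (2, H, I/tp)
```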
@@ -154,24 +238,24 @@ def __init__( num_experts: int, input_size: int, output_size: int, - sequence_parallel_mode: MoESequenceParallelMode, - input_is_parallel: bool = False, + reduce_output: bool = True, dtype: torch.dtype = torch.float32, - device: torch.device = None, + device: Optional[torch.device] = None, stride: int = 1, - init_method: torch.nn.init = None, + init_method: Optional[Callable[..., Any]] = None, keep_master_weight: bool = False, - ): + ) -> None: self.num_experts = num_experts - if sequence_parallel_mode not in MoESequenceParallelMode: - raise TypeError(f"Unknown sequence_parallel_mode: {sequence_parallel_mode}") - self.sequence_parallel_mode = sequence_parallel_mode + self._n_local_experts = utils.divide(num_experts, parallel_state.get_expert_model_parallel_size()) + + # Whether to all-reduce the output across TP ranks or not + self.reduce_output = reduce_output super().__init__( input_size=input_size, output_size=output_size, bias=False, - input_is_parallel=input_is_parallel, + input_is_parallel=True, dtype=dtype, device=device, stride=stride, @@ -180,89 +264,94 @@ def __init__( keep_master_weight=keep_master_weight, skip_bias_add=False, ) + self._mark_expert_parallel_weights() def set_weight_and_bias_config(self): # Define 3D weight tensor, one linear layer per expert - self.weight_shape = (self.num_experts, self.input_size_per_partition, self.output_size) + self.weight_shape = ( + self._n_local_experts, + self.input_size_per_partition, + self.output_size, + ) # Row parallel partitioning for each expert self.weight_partition_dim = 1 self.bias_shape = None def _init_weight(self, weight): # Initialize the linear layer of each expert separately - assert len(weight.shape) == 3 and weight.shape[0] == self.num_experts + assert len(weight.shape) == 3 for e in range(weight.shape[0]): if self.arg_init_method is None: torch.nn.init.kaiming_uniform_(weight[e], a=math.sqrt(5)) else: self.arg_init_method(weight[e]) - def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - assert not self.sequence_parallel_enabled - input_parallel = mappings.scatter_to_tensor_model_parallel_region(input_) + def forward( + self, input_: torch.Tensor, expert_indices: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """If expert_indices is provided, then the computations are performed only on the specified experts. + Otherwise, the input is passed through all experts in the layer.""" # Matrix multiply. + weight = self.weight[expert_indices, :, :] if expert_indices is not None else self.weight output_parallel = self._forward_impl( - input=input_parallel, - weight=self.weight, + input=input_, + weight=weight, bias=None, async_grad_allreduce=False, sequence_parallel_enabled=False, autograd_func_class=self.autograd_func_class, ) - if self.sequence_parallel_mode == MoESequenceParallelMode.EXIT_SP_ON_ENTRY_DELAY_MLP_AR: - # Avoid the output all-reduce, in favor of a reduce-scatter at the end of the MoE layer instead - output = output_parallel - else: - # All-reduce across all the partitions. + if self.reduce_output: output = mappings.reduce_from_tensor_model_parallel_region(output_parallel) - - return output + return output + else: + # Return without output all-reduce, in favor of an all-reduce or reduce-scatter after the MoE output combine. 
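The `reduce_output=False` path relies on the linearity of the expert-output combine. The toy, single-process sketch below (illustrative sizes; `affinities` and `partial` are hypothetical stand-ins for the real tensors and the collectives are simulated with a plain sum) shows that a single reduction after the combine matches reducing each expert's output first:

```python
import torch

# Toy illustration of why the down-projection all-reduce can be deferred.
tp, E, T, H = 4, 3, 5, 8                      # TP degree, experts, tokens, hidden
affinities = torch.rand(T, E)                 # per-token expert affinities
partial = torch.randn(tp, E, T, H)            # each TP rank's partial expert outputs

# (1) "all-reduce" each expert's output first, then combine
reduce_then_combine = torch.einsum("te,eth->th", affinities, partial.sum(dim=0))

# (2) combine each rank's partial outputs, then do a single reduction at the end
#     (this is what reduce_output=False enables)
combine_then_reduce = torch.einsum("te,reth->rth", affinities, partial).sum(dim=0)

assert torch.allclose(reduce_then_combine, combine_then_reduce, atol=1e-4)
```

This is why one all-reduce (or reduce-scatter, under SP) after the MoE output combine is sufficient.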
+ return output_parallel -class LinearWithParallelInput(torch.autograd.Function): - """Linear layer execution where the input is potentially parallel. - Implements an all-reduce of weight gradients in the backward pass if necessary. +class LinearWithWeightGradAR(torch.autograd.Function): + """Linear layer which implements an all-reduce of weight gradients in the backward pass. + Used for the MoE router, see LinearRouter for more details. """ @staticmethod def forward(ctx, input, weight, reduce_weight_grad): - # input: (S, B, H), weight: (E, H) - ctx.reduce_weight_grad = reduce_weight_grad + # input: (T, H), weight: (E, H) assert weight.requires_grad + ctx.reduce_weight_grad = reduce_weight_grad ctx.save_for_backward(input, weight) - # output: (S, B, E) + # output: (T, E) output = torch.matmul(input, weight.t()) return output @staticmethod def backward(ctx, grad_output): - # grad_output: (S, B, E) + # grad_output: (T, E) input, weight = ctx.saved_tensors - reduce_weight_grad = ctx.reduce_weight_grad - # grad_input: (S, B, H) + # grad_input: (T, H) grad_input = grad_output.matmul(weight) # grad_weight: (E, H) - grad_weight = torch.einsum("sbe,sbh->eh", grad_output, input) - if reduce_weight_grad and parallel_state.get_tensor_model_parallel_size() > 1: + grad_weight = torch.einsum("te,th->eh", grad_output, input) + if parallel_state.get_tensor_model_parallel_size() > 1 and ctx.reduce_weight_grad: # All-reduce the gradients of the weight torch.distributed.all_reduce(grad_weight, group=parallel_state.get_tensor_model_parallel_group()) return grad_input, grad_weight, None -class InputParallelLinear(torch.nn.Module): - """Linear layer where the input is potentially in parallel. - Used for defining the router in MoE when in certain SP modes. See routing.py for details. +class LinearRouter(torch.nn.Module): + """Specialized torch module for MoE Router, which implements an all-reduce of the weight gradients in the backward pass. + + Reason: + In the forward pass of ExpertMLPs, we avoid the all-reduce in the down projection step (see ExpertFusedRowParallelLinear). + Instead, we perform an all-reduce or reduce-scatter after combining outputs from different experts into the original SBH or BSH hidden_states. + Therefore, the router weight gradients (which flow backwards from the expert affinity scaling) are computed using only the partial output + at each rank. Therefore, we need to perform a small all-reduce collective (size HE) on the router gradients in the backward pass. Arguments: input_size: Dimensionality of the input to the linear layer. output_size: Dimensionality of the output of the linear layer. - reduce_weight_grad: Whether to all-reduce the gradients of the weights in the backward pass. dtype: Datatype for the layer weights. device: Device for the layer weights. 
""" @@ -271,7 +360,6 @@ def __init__( self, input_size: int, output_size: int, - reduce_weight_grad: bool, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu"), ): @@ -287,14 +375,15 @@ def __init__( if self.weight.device != torch.device("meta"): self.init_weight_cpu() - self.reduce_weight_grad = reduce_weight_grad + def forward(self, input_): + """Lightweight wrapper around the LinearRouterWithWeightGradAR autograd function.""" - def forward(self, input): - """Lightweight wrapper around the LinearWithParallelInput autograd function.""" - args = utils.cast_if_autocast_enabled(input, self.weight, self.reduce_weight_grad) + # if ep is enabled do not reduce weight grad because we don't do delayed allreduce + reduce_weight_grad = get_expert_model_parallel_size() == 1 + args = utils.cast_if_autocast_enabled(input_, self.weight, reduce_weight_grad) utils.verify_casted_dtype(args) with torch.cuda.amp.autocast(enabled=False): - return LinearWithParallelInput.apply(*args) + return LinearWithWeightGradAR.apply(*args) def init_weight_cpu(self): torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) diff --git a/src/neuronx_distributed/modules/moe/routing.py b/src/neuronx_distributed/modules/moe/routing.py index 60e0a20..44e7d21 100644 --- a/src/neuronx_distributed/modules/moe/routing.py +++ b/src/neuronx_distributed/modules/moe/routing.py @@ -1,68 +1,19 @@ from abc import ABC, abstractmethod -from typing import Union import torch import torch.nn.functional as F -from neuronx_distributed.modules.moe.model_utils import MoESequenceParallelMode -from neuronx_distributed.modules.moe.moe_parallel_layers import InputParallelLinear -from neuronx_distributed.parallel_layers import mappings - - -def get_linear_router( - hidden_size: int, - num_experts: int, - sequence_parallel_mode: MoESequenceParallelMode, - dtype: torch.dtype, - device: torch.device, -) -> InputParallelLinear: - """Helper function which returns a linear layer for the router, which implements any necessary collectives (depending on the - sequence_parallel_mode). - - - In EXIT_SP_ON_ENTRY_DELAY_MLP_AR mode, the gradients of the router weights need to be all-reduced across ranks, because - they have been computed using only the partial permuted_output at a given rank. This is a consequence of avoiding the all-reduce - in the down projection step (see ExpertFusedRowParallelLinear), and combining it with the scatter at the end of the MoE layer - (for SP). With this optimization, we replace a potentially large All-Gather (tensor size ECH) and medium-sized Scatter - (tensor size SBH), with a medium Reduce-Scatter (tensor size SBH) and a small All-Gather in the reverse pass (tensor size HE). - - - In OPTIMIZED_SP_MATMUL, we run the majority of the ExpertMLPs computation in sequence-parallel. Therefore, the router weight - gradients need to be all-reduced across ranks because they have been computed using only a chunk of the output at each rank. - - Arguments: - hidden_size: Hidden dimension. - num_experts: Total number of experts. - sequence_parallel_mode: SP mode being used for the MoE layer. - dtype: Datatype for the layer weights. - device: Device for the layer weights. - - Returns: - linear_router: InputParallelLinear which performs a linear projection from hidden_size to num_experts. 
- """ - - reduce_weight_grad = sequence_parallel_mode in { - MoESequenceParallelMode.EXIT_SP_ON_ENTRY_DELAY_MLP_AR, - MoESequenceParallelMode.OPTIMIZED_SP_MATMUL, - } - linear_router = InputParallelLinear( - input_size=hidden_size, - output_size=num_experts, - reduce_weight_grad=reduce_weight_grad, - dtype=dtype, - device=device, - ) - return linear_router - +from neuronx_distributed.modules.moe.moe_parallel_layers import LinearRouter +from neuronx_distributed.parallel_layers.parallel_state import get_expert_model_parallel_size class RouterBase(torch.nn.Module, ABC): """Base class for various routing strategies used in MoE. - This class is used to set common initialization parameters, and define the function signature of the forward pass of child classes. Arguments: num_experts: Total number of experts. top_k: Number of experts activated per token. Should be less than or equal to num_experts. hidden_size: Hidden dimension of the input sequence. act_fn: Activation used to obtain expert affinities from router logits. One of 'sigmoid' or 'softmax'. - sequence_parallel_mode: SP mode being used for the MoE layer. dtype: Datatype for the layer weights. device: Device for the layer weights. """ @@ -72,7 +23,6 @@ def __init__( num_experts: int, top_k: int, hidden_size: int, - sequence_parallel_mode: Union[str, MoESequenceParallelMode], act_fn: str, dtype: torch.dtype, device: torch.device, @@ -83,21 +33,27 @@ def __init__( raise ValueError(f"Invalid top_k={top_k} for num_experts={num_experts}") self.top_k = top_k self.hidden_size = hidden_size - if sequence_parallel_mode not in MoESequenceParallelMode.__members__: - raise TypeError(f"Unknown sequence_parallel_mode: {sequence_parallel_mode}") - self.sequence_parallel_mode = MoESequenceParallelMode[sequence_parallel_mode] if act_fn not in {"sigmoid", "softmax"}: raise ValueError("act_fn must be either 'sigmoid' or 'softmax'") self.act_fn = act_fn self.dtype = dtype self.device = device - def get_expert_affinities(self, router_logits): - """Applies the required activation function on the router logits to return the expert affinities.""" - - # router_logits: (T, E) - # Perform activation in fp64 to prevent auto-downcasting of operation to bf16, for numerical accuracy + # Create router + self.linear_router = LinearRouter( + input_size=hidden_size, + output_size=num_experts, + dtype=dtype, + device=device, + ) + + def get_router_logits_and_expert_affinities(self, hidden_states): + """Returns the router logits and expert affinities from the given hidden_states.""" + # router_logits: (T, H) @ (H, E) -> (T, E) + router_logits = self.linear_router(hidden_states) + + # Perform activation in fp64 to prevent auto-downcasting of operation to bf16, for numerical accuracy # expert_affinities: (T, E) if self.act_fn == "sigmoid": expert_affinities = torch.sigmoid(router_logits.to(dtype=torch.float64)) @@ -105,22 +61,21 @@ def get_expert_affinities(self, router_logits): expert_affinities = F.softmax(router_logits, dim=1, dtype=torch.float64) else: raise ValueError("act_fn must be either 'sigmoid' or 'softmax'") - return expert_affinities + + # Cast to required dtype + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + return router_logits, expert_affinities @abstractmethod def forward(self, hidden_states): """Forward pass of the router. - If the input hidden_states are in sequence parallel, then the router implementation should perform the required - collectives to return the outputs for the complete sequence. 
- Common nomenclature: S: Sequence length, B: Batch size, H: Hidden Size - S': Sequence length (when the input is in SP) T: Tokens = S * B (token dimension obtained by flattening S and B) Arguments: - hidden_states: Input tensor of shape (S, B, H) or (S', B, H). + hidden_states: Input tensor of shape (T, H). Returns: router_logits: Tensor of shape (T, E) containing the router logits for each token for each expert. @@ -141,7 +96,6 @@ def __init__( num_experts: int, top_k: int, hidden_size: int, - sequence_parallel_mode: Union[str, MoESequenceParallelMode], dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu"), ): @@ -149,34 +103,19 @@ def __init__( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, - sequence_parallel_mode=sequence_parallel_mode, act_fn="softmax", # Always use softmax activation for TopK router dtype=dtype, device=device, ) - # Create router - self.linear_router = get_linear_router(hidden_size, num_experts, self.sequence_parallel_mode, dtype, device) def forward(self, hidden_states): - # router_logits: (S, B, H) @ (H, E) -> (S, B, E) or (S', B, H) @ (H, E) -> (S', B, E) - router_logits = self.linear_router(hidden_states) - if self.training: - if self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - # Gather the router_logits - # router_logits: (S', B, E) -> (S, B, E) - router_logits = mappings.gather_from_sequence_parallel_region(router_logits, to_model_parallel=False) - - # router_logits: (S, B, E) -> (T, E) - router_logits = router_logits.view(-1, self.num_experts) - - # Apply activation function to get expert_affinities - expert_affinities = self.get_expert_affinities(router_logits) + # Get router_logits and expert_affinities + router_logits, expert_affinities = self.get_router_logits_and_expert_affinities(hidden_states) # For each token, get the top_k experts # expert_index: (T, top_k) _, expert_index = torch.topk(router_logits, self.top_k) expert_index = expert_index.detach().to(dtype=torch.long) - expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) return router_logits, expert_affinities, expert_index @@ -199,7 +138,6 @@ def __init__( num_experts: int, top_k: int, hidden_size: int, - sequence_parallel_mode: Union[str, MoESequenceParallelMode], act_fn: str = "sigmoid", dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu"), @@ -213,33 +151,20 @@ def __init__( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, - sequence_parallel_mode=sequence_parallel_mode, act_fn=act_fn, dtype=dtype, device=device, ) # Create router - self.linear_router = get_linear_router(hidden_size, num_experts, self.sequence_parallel_mode, dtype, device) self.sinkhorn_iterations = ( sinkhorn_iterations if sinkhorn_iterations is not None else self.DEFAULT_SINKHORN_ITERS ) self.sinkhorn_tol = sinkhorn_tol def forward(self, hidden_states): - # router_logits: (S, B, H) @ (H, E) -> (S, B, E) or (S', B, H) @ (H, E) -> (S', B, E) - router_logits = self.linear_router(hidden_states) - if self.training: - if self.sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - # Gather the router_logits - # router_logits: (S', B, E) -> (S, B, E) - router_logits = mappings.gather_from_sequence_parallel_region(router_logits, to_model_parallel=False) - - # router_logits: (S, B, E) -> (T, E) - router_logits = router_logits.view(-1, self.num_experts) - - # Apply activation function to get expert_affinities - expert_affinities = self.get_expert_affinities(router_logits) + # Get router_logits 
and expert_affinities + router_logits, expert_affinities = self.get_router_logits_and_expert_affinities(hidden_states) with torch.no_grad(): if self.training: @@ -256,7 +181,6 @@ def forward(self, hidden_states): expert_index = torch.argmax(sinkroute, dim=1, keepdim=True) expert_index = expert_index.to(dtype=torch.long) - expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) return router_logits, expert_affinities, expert_index @staticmethod diff --git a/src/neuronx_distributed/modules/qkv_linear.py b/src/neuronx_distributed/modules/qkv_linear.py index c874c45..f27479c 100644 --- a/src/neuronx_distributed/modules/qkv_linear.py +++ b/src/neuronx_distributed/modules/qkv_linear.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Callable, Any import torch import torch_xla.core.xla_model as xm @@ -128,14 +128,14 @@ def _initialize_affine_weight( return None -def _linear_forward(input, weight, bias): +def _linear_forward(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: output = torch.matmul(input, weight.t()) if bias is not None: output = output + bias return output -def _compute_gradients(input, weight, grad_output, use_bias): +def _compute_gradients(input: torch.Tensor, weight: torch.Tensor, grad_output: torch.Tensor, use_bias: bool) -> Tuple[torch.Tensor, torch.Tensor]: """ This method computes the gradients for the weight and bias, given the output gradient and input. @@ -150,6 +150,14 @@ def _compute_gradients(input, weight, grad_output, use_bias): return grad_weight, grad_bias +def check_requires_grad(weight_qkv, fuse_qkv, weight_q): + return (weight_qkv.requires_grad if fuse_qkv else weight_q.requires_grad) + + +def check_use_bias(weight_qkv, fuse_qkv, weight_q, bias_q, bias_qkv): + return (bias_qkv is not None if fuse_qkv else bias_q is not None) and check_requires_grad(weight_qkv, fuse_qkv, weight_q) + + class GQAQKVLinearWithAsyncCommunication(torch.autograd.Function): """Linear layer execution with asynchronous communication.""" @@ -157,26 +165,38 @@ class GQAQKVLinearWithAsyncCommunication(torch.autograd.Function): def forward( ctx, input: torch.Tensor, - weight_q: torch.Tensor, - weight_k: torch.Tensor, - weight_v: torch.Tensor, + weight_q: Optional[torch.Tensor], + weight_k: Optional[torch.Tensor], + weight_v: Optional[torch.Tensor], bias_q: Optional[torch.Tensor], bias_k: Optional[torch.Tensor], bias_v: Optional[torch.Tensor], async_grad_allreduce: bool, sequence_parallel_enabled: bool, kv_size_multiplier: int, - ): - ctx.use_bias = bias_q is not None and weight_q.requires_grad + weight_qkv: Optional[torch.Tensor] = None, + bias_qkv: Optional[torch.Tensor] = None, + fuse_qkv: bool = False, + output_size_q: int = None, + output_size_kv: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ctx.use_bias = check_use_bias(weight_qkv, fuse_qkv, weight_q, bias_q, bias_qkv) ctx.async_grad_allreduce = async_grad_allreduce ctx.sequence_parallel_enabled = sequence_parallel_enabled - ctx.compute_weight_gradient = weight_q.requires_grad + ctx.compute_weight_gradient = check_requires_grad(weight_qkv, fuse_qkv, weight_q) ctx.kv_size_multiplier = kv_size_multiplier + ctx.fuse_qkv = fuse_qkv if ctx.compute_weight_gradient: - ctx.save_for_backward(input, weight_q, weight_k, weight_v) + if ctx.fuse_qkv: + ctx.save_for_backward(input, weight_qkv) + else: + ctx.save_for_backward(input, weight_q, weight_k, weight_v) else: - ctx.save_for_backward(weight_q, weight_k, 
weight_v) + if ctx.fuse_qkv: + ctx.save_for_backward(weight_qkv) + else: + ctx.save_for_backward(weight_q, weight_k, weight_v) if ctx.sequence_parallel_enabled: # `input` is supposed to be 3D and its order of dimension is [sequence, batch, hidden] @@ -188,23 +208,35 @@ def forward( else: total_input = input - output_q = _linear_forward(total_input, weight_q, bias_q) - output_k = _linear_forward(total_input, weight_k, bias_k) - output_v = _linear_forward(total_input, weight_v, bias_v) + + if ctx.fuse_qkv: + output_qkv = _linear_forward(total_input, weight_qkv, bias_qkv) + # Split the outputs + output_dimensions = [output_size_q, output_size_kv, output_size_kv] + output_q, output_k, output_v = torch.split(output_qkv, output_dimensions, dim=-1) + else: + output_q = _linear_forward(total_input, weight_q, bias_q) + output_k = _linear_forward(total_input, weight_k, bias_k) + output_v = _linear_forward(total_input, weight_v, bias_v) return output_q, output_k, output_v @staticmethod def backward(ctx, grad_output_q, grad_output_k, grad_output_v): if ctx.compute_weight_gradient: - input, weight_q, weight_k, weight_v = ctx.saved_tensors + if ctx.fuse_qkv: + input, weight_qkv = ctx.saved_tensors + else: + input, weight_q, weight_k, weight_v = ctx.saved_tensors else: - weight_q, weight_k, weight_v = ctx.saved_tensors[:3] + if ctx.fuse_qkv: + weight_qkv = ctx.saved_tensors[:1] + else: + weight_q, weight_k, weight_v = ctx.saved_tensors[:3] input = None use_bias = ctx.use_bias - handle = None if ctx.compute_weight_gradient: if ctx.sequence_parallel_enabled: total_input = xm.all_gather( @@ -220,20 +252,29 @@ def backward(ctx, grad_output_q, grad_output_k, grad_output_v): # sum up the gradients from the repeated portions. get_kv_shared_group() # returns the ranks which have the same K and V heads, and hence allows us to # sum up from the distributed ranks. - handlek = torch.distributed.all_reduce(grad_output_k, group=get_kv_shared_group()) - handlev = torch.distributed.all_reduce(grad_output_v, group=get_kv_shared_group()) - - grad_input_q = grad_output_q.matmul(weight_q) - grad_input_k = grad_output_k.matmul(weight_k) - grad_input_v = grad_output_v.matmul(weight_v) - # Here we need to divide the grad_input_k and grad_input_v by a factor of kv_size_multipler, - # because after this step we are going to do an all-reduce over the entire tp group which - # would cause the K and V duplicate factor to be counted twice. - grad_input = grad_input_q + (grad_input_k + grad_input_v) / ctx.kv_size_multiplier + torch.distributed.all_reduce(grad_output_k, group=get_kv_shared_group()) + torch.distributed.all_reduce(grad_output_v, group=get_kv_shared_group()) + + if ctx.fuse_qkv: + # Divide grad_output_k and grad_output_v by the kv replication factor + # because after this step we are going to do an all-reduce over the entire tp group which + # would cause the K and V duplicate factor to be counted twice. 
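The scaling in the fused path can be checked in isolation: with the fused weight laid out as `[W_q; W_k; W_v]`, dividing the K/V output gradients by `kv_size_multiplier` before the concatenation reproduces the unfused input gradient. A standalone sketch with illustrative sizes, ignoring the collectives (not part of the diff):

```python
import torch

T, H, Dq, Dkv, m = 7, 16, 12, 4, 2      # tokens, hidden, Q/KV output dims, kv_size_multiplier
wq, wk, wv = torch.randn(Dq, H), torch.randn(Dkv, H), torch.randn(Dkv, H)
gq, gk, gv = torch.randn(T, Dq), torch.randn(T, Dkv), torch.randn(T, Dkv)

# Unfused input gradient, with the 1/m factor applied to the K/V contributions.
unfused = gq @ wq + (gk @ wk + gv @ wv) / m

# Fused input gradient: scale the K/V output grads, concatenate, one matmul.
w_qkv = torch.cat([wq, wk, wv], dim=0)           # (Dq + 2*Dkv, H)
g_qkv = torch.cat([gq, gk / m, gv / m], dim=-1)  # (T, Dq + 2*Dkv)
fused = g_qkv @ w_qkv

assert torch.allclose(unfused, fused, atol=1e-4)
```

Note that the weight gradient in the fused path is computed from the unscaled output gradients, as in the code further below.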
+ grad_output_k_scaled = grad_output_k / ctx.kv_size_multiplier + grad_output_v_scaled = grad_output_v / ctx.kv_size_multiplier + grad_output_qkv = torch.cat([grad_output_q, grad_output_k_scaled, grad_output_v_scaled], dim=-1) + grad_input = grad_output_qkv.matmul(weight_qkv) + else: + grad_input_q = grad_output_q.matmul(weight_q) + grad_input_k = grad_output_k.matmul(weight_k) + grad_input_v = grad_output_v.matmul(weight_v) + # Here we need to divide the grad_input_k and grad_input_v by a factor of kv_size_multipler, + # because after this step we are going to do an all-reduce over the entire tp group which + # would cause the K and V duplicate factor to be counted twice. + grad_input = grad_input_q + (grad_input_k + grad_input_v) / ctx.kv_size_multiplier if ctx.async_grad_allreduce: # Asynchronous all-reduce - handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group()) + torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group()) # if no weight gradient, immediately return if not ctx.compute_weight_gradient: @@ -262,8 +303,9 @@ def backward(ctx, grad_output_q, grad_output_k, grad_output_v): pin_layout=False, ) - return sub_grad_input, None, None, None, None, None, None - return grad_input, None, None, None, None, None, None + return sub_grad_input, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None, None, None, None, None, None, None + # Convert the tensor shapes to 2D for execution compatibility total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) @@ -283,9 +325,53 @@ def backward(ctx, grad_output_q, grad_output_k, grad_output_v): pin_layout=False, ) - grad_weight_q, grad_bias_q = _compute_gradients(total_input, weight_q, grad_output_q, use_bias) - grad_weight_k, grad_bias_k = _compute_gradients(total_input, weight_k, grad_output_k, use_bias) - grad_weight_v, grad_bias_v = _compute_gradients(total_input, weight_v, grad_output_v, use_bias) + if ctx.fuse_qkv: + # Use grad_output_qkv without scaling by the kv replication factor + grad_output_qkv_not_scaled = torch.cat([grad_output_q, grad_output_k, grad_output_v], dim=-1) + grad_weight_qkv, grad_bias_qkv = _compute_gradients(total_input, weight_qkv, grad_output_qkv_not_scaled, use_bias) + + + if ctx.sequence_parallel_enabled: + return ( + sub_grad_input, + None, + None, + None, + None, + None, + None, + None, + None, + None, + grad_weight_qkv, + grad_bias_qkv, + None, + None, + None, + ) + + return ( + grad_input, + None, + None, + None, + None, + None, + None, + None, + None, + None, + grad_weight_qkv, + grad_bias_qkv, + None, + None, + None, + ) + + else: + grad_weight_q, grad_bias_q = _compute_gradients(total_input, weight_q, grad_output_q, use_bias) + grad_weight_k, grad_bias_k = _compute_gradients(total_input, weight_k, grad_output_k, use_bias) + grad_weight_v, grad_bias_v = _compute_gradients(total_input, weight_v, grad_output_v, use_bias) if ctx.sequence_parallel_enabled: return ( @@ -300,6 +386,11 @@ def backward(ctx, grad_output_q, grad_output_k, grad_output_v): None, None, None, + None, + None, + None, + None, + None, ) return ( @@ -314,20 +405,29 @@ def backward(ctx, grad_output_q, grad_output_k, grad_output_v): None, None, None, + None, + None, + None, + None, + None, ) - def gqa_qkv_linear_with_async_allreduce( input: torch.Tensor, - weight_q: torch.Tensor, - weight_k: torch.Tensor, - weight_v: torch.Tensor, + weight_q: 
Optional[torch.Tensor], + weight_k: Optional[torch.Tensor], + weight_v: Optional[torch.Tensor], bias_q: Optional[torch.Tensor], bias_k: Optional[torch.Tensor], bias_v: Optional[torch.Tensor], async_grad_allreduce: bool, sequence_parallel_enabled: bool, kv_size_multiplier: int = 1, + weight_qkv: Optional[torch.Tensor] = None, + bias_qkv: Optional[torch.Tensor] = None, + fuse_qkv: bool=False, + output_size_q: int = None, + output_size_kv: int = None, ) -> torch.Tensor: args = cast_if_autocast_enabled( input, @@ -340,6 +440,11 @@ def gqa_qkv_linear_with_async_allreduce( async_grad_allreduce, sequence_parallel_enabled, kv_size_multiplier, + weight_qkv, + bias_qkv, + fuse_qkv, + output_size_q, + output_size_kv, ) verify_casted_dtype(args) with torch.cuda.amp.autocast(enabled=False): @@ -402,11 +507,12 @@ def __init__( bias: bool = True, gather_output: bool = True, dtype: torch.dtype = torch.float32, - device: torch.device = None, - init_method: torch.nn.init = None, + device: Optional[torch.device] = None, + init_method: Optional[Callable[..., Any]] = None, sequence_parallel_enabled: bool = False, keep_master_weight: bool = False, kv_size_multiplier: int = 1, + fuse_qkv: bool = False, ): super().__init__() @@ -428,6 +534,7 @@ def __init__( self.keep_master_weight = keep_master_weight self.device = device self.use_bias = bias + self.fuse_qkv = fuse_qkv self._create_weights_biases() self.initialize_weight_biases() @@ -445,41 +552,119 @@ def __init__( self._forward_impl = gqa_qkv_linear_with_async_allreduce def _create_weights_biases(self): - self.weight_q = Parameter( - torch.empty(self.q_output_size_per_partition, self.input_size, dtype=self.dtype, device=self.device) - ) - self.weight_k = Parameter( - torch.empty(self.kv_output_size_per_partition, self.input_size, dtype=self.dtype, device=self.device) - ) - self.weight_v = Parameter( - torch.empty(self.kv_output_size_per_partition, self.input_size, dtype=self.dtype, device=self.device) - ) - if self.use_bias: - bias_size = self.output_sizes[0] if self.gather_output else self.q_output_size_per_partition - self.bias_q = Parameter(torch.empty(bias_size, device=self.device, dtype=self.dtype)) - bias_size = self.output_sizes[1] if self.gather_output else self.kv_output_size_per_partition - self.bias_k = Parameter(torch.empty(bias_size, device=self.device, dtype=self.dtype)) - self.bias_v = Parameter(torch.empty(bias_size, device=self.device, dtype=self.dtype)) - else: + if self.fuse_qkv: + self.weight_qkv = Parameter( + torch.empty(self.q_output_size_per_partition + 2*self.kv_output_size_per_partition, self.input_size, dtype=self.dtype, device=self.device) + ) + if self.use_bias: + bias_size_q = self.output_sizes[0] if self.gather_output else self.q_output_size_per_partition + bias_size_kv = self.output_sizes[1] if self.gather_output else self.kv_output_size_per_partition + self.bias_qkv = Parameter(torch.empty(bias_size_q + 2*bias_size_kv, device=self.device, dtype=self.dtype)) + else: + self.register_parameter("bias_qkv", None) + self.register_parameter("bias_q", None) self.register_parameter("bias_k", None) self.register_parameter("bias_v", None) + else: + self.weight_q = Parameter( + torch.empty(self.q_output_size_per_partition, self.input_size, dtype=self.dtype, device=self.device) + ) + self.weight_k = Parameter( + torch.empty(self.kv_output_size_per_partition, self.input_size, dtype=self.dtype, device=self.device) + ) + self.weight_v = Parameter( + torch.empty(self.kv_output_size_per_partition, self.input_size, dtype=self.dtype, 
device=self.device) + ) + if self.use_bias: + bias_size = self.output_sizes[0] if self.gather_output else self.q_output_size_per_partition + self.bias_q = Parameter(torch.empty(bias_size, device=self.device, dtype=self.dtype)) + bias_size = self.output_sizes[1] if self.gather_output else self.kv_output_size_per_partition + self.bias_k = Parameter(torch.empty(bias_size, device=self.device, dtype=self.dtype)) + self.bias_v = Parameter(torch.empty(bias_size, device=self.device, dtype=self.dtype)) + else: + self.register_parameter("bias_q", None) + self.register_parameter("bias_k", None) + self.register_parameter("bias_v", None) def initialize_weight_biases(self): # Initialize weight. + if self.fuse_qkv: + # Split weight_qkv in to components for init + dimensions = [self.q_output_size_per_partition, self.kv_output_size_per_partition, self.kv_output_size_per_partition] + weight_q, weight_k, weight_v = torch.split(self.weight_qkv, dimensions, dim=0) + else: + weight_q = self.weight_q + weight_k = self.weight_k + weight_v = self.weight_v self.master_weight_q = self._init_per_layer_weight( - self.weight_q, self.output_sizes[0], self.q_output_size_per_partition, 1 + weight_q, self.output_sizes[0], self.q_output_size_per_partition, 1 ) self.master_weight_k = self._init_per_layer_weight( - self.weight_k, self.output_sizes[1], self.kv_output_size_per_partition, self.kv_size_multiplier + weight_k, self.output_sizes[1], self.kv_output_size_per_partition, self.kv_size_multiplier ) self.master_weight_v = self._init_per_layer_weight( - self.weight_v, self.output_sizes[1], self.kv_output_size_per_partition, self.kv_size_multiplier + weight_v, self.output_sizes[1], self.kv_output_size_per_partition, self.kv_size_multiplier ) + if self.fuse_qkv: + # Concat and update self.weight_qkv + with torch.no_grad(): + self.weight_qkv = torch.nn.Parameter(torch.cat([weight_q, weight_k, weight_v], dim=0)) + else: + self.weight_q = weight_q + self.weight_k = weight_k + self.weight_v = weight_v if self.use_bias: - self.master_bias_q = self._init_per_layer_bias(self.bias_q) - self.master_bias_k = self._init_per_layer_bias(self.bias_k) - self.master_bias_v = self._init_per_layer_bias(self.bias_v) + if self.fuse_qkv: + bias_size_q = self.output_sizes[0] if self.gather_output else self.q_output_size_per_partition + bias_size_kv = self.output_sizes[1] if self.gather_output else self.kv_output_size_per_partition + dimensions = [bias_size_q, bias_size_kv, bias_size_kv] + bias_q, bias_k, bias_v = torch.split(self.bias_qkv, dimensions, dim=0) + else: + bias_q = self.bias_q + bias_k = self.bias_k + bias_v = self.bias_v + self.master_bias_q = self._init_per_layer_bias( + bias_q, self.output_sizes[0], + torch.nn.init._calculate_fan_in_and_fan_out(weight_q), + self.kv_size_multiplier + ) + self.master_bias_k = self._init_per_layer_bias( + bias_k, self.output_sizes[1], + torch.nn.init._calculate_fan_in_and_fan_out(weight_k), + self.kv_size_multiplier + ) + self.master_bias_v = self._init_per_layer_bias( + bias_v, self.output_sizes[1], + torch.nn.init._calculate_fan_in_and_fan_out(weight_q), + self.kv_size_multiplier + ) + if self.fuse_qkv: + # Concat and update self.bias_qkv + with torch.no_grad(): + self.bias_qkv = torch.nn.Parameter(torch.cat([bias_q, bias_k, bias_v], dim=0)) + else: + self.bias_q = bias_q + self.bias_k = bias_k + self.bias_v = bias_v + + if self.fuse_qkv: + # Fuse weights and biases + if self.master_weight_q is not None: + self.master_weight_qkv = torch.cat([self.master_weight_q, self.master_weight_k, 
self.master_weight_v], dim=0) + else: + self.master_weight_qkv = None + self.master_weight_q = None + self.master_weight_k = None + self.master_weight_v = None + if self.use_bias: + if self.master_bias_q is not None: + self.master_bias_qkv = torch.cat([self.master_bias_q, self.master_bias_k, self.master_bias_v], dim=0) + else: + self.master_bias_qkv = None + self.master_bias_q = None + self.master_bias_k = None + self.master_bias_v = None def _init_per_layer_weight(self, weight, output_size, output_size_per_partition, kv_size_multiplier=1): master_weight = None @@ -498,7 +683,7 @@ def _init_per_layer_weight(self, weight, output_size, output_size_per_partition, return master_weight - def _init_per_layer_bias(self, bias): + def _init_per_layer_bias(self, bias, output_size, fan_in, kv_size_multiplier=1): master_bias = None if bias.device != torch.device("meta"): bound = 1 / math.sqrt(self.input_size) if fan_in > 0 else 0 @@ -526,18 +711,42 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Ten input_parallel = copy_to_tensor_model_parallel_region(input) # Matrix multiply. - output_parallel_q, output_parallel_k, output_parallel_v = self._forward_impl( - input=input_parallel, - weight_q=self.weight_q, - weight_k=self.weight_k, - weight_v=self.weight_v, - bias_q=None, - bias_k=None, - bias_v=None, - async_grad_allreduce=self.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=self.sequence_parallel_enabled, - kv_size_multiplier=self.kv_size_multiplier, - ) + if self.fuse_qkv: + output_parallel_q, output_parallel_k, output_parallel_v = self._forward_impl( + input=input_parallel, + weight_q=None, + weight_k=None, + weight_v=None, + bias_q=None, + bias_k=None, + bias_v=None, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel_enabled=self.sequence_parallel_enabled, + kv_size_multiplier=self.kv_size_multiplier, + weight_qkv=self.weight_qkv, + bias_qkv=None, + fuse_qkv=self.fuse_qkv, + output_size_q=self.q_output_size_per_partition, + output_size_kv=self.kv_output_size_per_partition, + ) + else: + output_parallel_q, output_parallel_k, output_parallel_v = self._forward_impl( + input=input_parallel, + weight_q=self.weight_q, + weight_k=self.weight_k, + weight_v=self.weight_v, + bias_q=None, + bias_k=None, + bias_v=None, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel_enabled=self.sequence_parallel_enabled, + kv_size_multiplier=self.kv_size_multiplier, + weight_qkv=None, + bias_qkv=None, + fuse_qkv=self.fuse_qkv, + output_size_q=self.output_sizes[0] if self.gather_output else self.q_output_size_per_partition, + output_size_kv=self.output_sizes[1] if self.gather_output else self.kv_output_size_per_partition, + ) if self.gather_output: # All-gather across the partitions. 
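A minimal sketch of the `fuse_qkv` forward path with illustrative sizes (not part of the diff): one matmul against the concatenated weight, then `torch.split` into Q/K/V slices, with the fused bias split the same way before it is added:

```python
import torch

T, H = 7, 16                                 # tokens, hidden size
q_out, kv_out = 12, 4                        # per-partition Q and K/V output sizes
w_qkv = torch.randn(q_out + 2 * kv_out, H)   # fused weight, laid out as [Q; K; V]
b_qkv = torch.randn(q_out + 2 * kv_out)      # fused bias, same layout

x = torch.randn(T, H)
qkv = x @ w_qkv.t()                                          # (T, q_out + 2*kv_out)
q, k, v = torch.split(qkv, [q_out, kv_out, kv_out], dim=-1)  # (T, q_out), (T, kv_out), (T, kv_out)

# Bias is added after the (optional) gather, by splitting the fused bias the same way.
bq, bk, bv = torch.split(b_qkv, [q_out, kv_out, kv_out], dim=0)
q, k, v = q + bq, k + bk, v + bv
```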
assert not self.sequence_parallel_enabled @@ -546,7 +755,18 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Ten output_v = gather_from_tensor_model_parallel_region(output_parallel_v) else: output_q, output_k, output_v = output_parallel_q, output_parallel_k, output_parallel_v - output_q = (output_q + self.bias_q) if self.bias_q is not None else output_q - output_k = (output_k + self.bias_k) if self.bias_k is not None else output_k - output_v = (output_v + self.bias_v) if self.bias_v is not None else output_v - return output_q, output_k, output_v + + if self.fuse_qkv: + if self.bias_qkv is not None: + bias_size_q = self.output_sizes[0] if self.gather_output else self.q_output_size_per_partition + bias_size_kv = self.output_sizes[1] if self.gather_output else self.kv_output_size_per_partition + dimensions = [bias_size_q, bias_size_kv, bias_size_kv] + bias_q, bias_k, bias_v = torch.split(self.bias_qkv, dimensions, dim=0) + output_q = (output_q + bias_q) + output_k = (output_k + bias_k) + output_v = (output_v + bias_v) + else: + output_q = (output_q + self.bias_q) if self.bias_q is not None else output_q + output_k = (output_k + self.bias_k) if self.bias_k is not None else output_k + output_v = (output_v + self.bias_v) if self.bias_v is not None else output_v + return output_q, output_k, output_v \ No newline at end of file diff --git a/src/neuronx_distributed/optimizer/__init__.py b/src/neuronx_distributed/optimizer/__init__.py index 850ff51..fd68ab1 100644 --- a/src/neuronx_distributed/optimizer/__init__.py +++ b/src/neuronx_distributed/optimizer/__init__.py @@ -1 +1 @@ -from .zero_redundancy_optimizer import NeuronZero1Optimizer +from .zero_redundancy_optimizer import NeuronZero1Optimizer, NeuronEPZero1Optimizer diff --git a/src/neuronx_distributed/optimizer/zero_dcp_utils.py b/src/neuronx_distributed/optimizer/zero_dcp_utils.py new file mode 100644 index 0000000..887661b --- /dev/null +++ b/src/neuronx_distributed/optimizer/zero_dcp_utils.py @@ -0,0 +1,519 @@ +import copy +import functools +import itertools +import logging +import os +import pickle +import re +import time +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dist_cp +import torch.nn.functional as F + +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, DefaultLoadPlanner +from torch.distributed.checkpoint.metadata import ( + BytesStorageMetadata, + Metadata, + TensorProperties as MetadataTensorProperties +) +from torch.distributed.checkpoint.planner import LoadPlan, SavePlan +from torch.distributed.checkpoint._nested_dict import ( + flatten_state_dict, + unflatten_state_dict, +) +from torch.distributed.fsdp._shard_utils import _get_remove_device_str +from torch.distributed._shard.sharding_spec import ShardMetadata +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) +import torch_xla.core.xla_model as xm + +from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_group, + get_data_parallel_rank, + get_data_parallel_size, + get_pipeline_model_parallel_group, + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_size, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_size, + rmsg, +) + +# avoid to log out `_dedup_tensors`, it's just 'step's and it's too long 
+logging.getLogger("torch.distributed.checkpoint._dedup_tensors").setLevel(logging.WARNING) + +MAX_RETRY = 100 + + +def _alloc_tensor(props: MetadataTensorProperties, size: Sequence[int]) -> torch.Tensor: + return torch.empty( + size=size, + dtype=props.dtype, + layout=props.layout, + requires_grad=props.requires_grad, + pin_memory=props.pin_memory, + device="cpu", + ) + + +def _get_optim_pid_to_params(optim: torch.optim.Optimizer) -> Dict[int, torch.nn.Parameter]: + ret = {pid: param for param_group in optim.param_groups for pid, param in enumerate(param_group["params"])} + return ret + + +def _get_param_to_param_names(model: torch.nn.Module) -> Dict[torch.nn.Parameter, str]: + ret = {param: name for name, param in model.named_parameters()} + return ret + + +def _get_optim_pid_to_param_names(model: torch.nn.Module, optim: torch.optim.Optimizer) -> Dict[int, str]: + optim_pid_to_params = _get_optim_pid_to_params(optim) + param_to_param_names = _get_param_to_param_names(model) + ret = {k: param_to_param_names[v] for k, v in optim_pid_to_params.items()} + return ret + + +def _tensor_to_sharded_tensor( + tensor: torch.Tensor, + param_shape: Sequence[int], + dp_rank: Optional[int] = None, +) -> ShardedTensor: + # quick path for scalars + if tensor.dim() == 0: + return tensor + + dp_size = get_data_parallel_size() + if dp_rank is None: + dp_rank = get_data_parallel_rank() + + shard_shape = list(tensor.shape) + padded_shape = shard_shape.copy() + padded_shape[0] = padded_shape[0] * dp_size + if dp_rank == dp_size - 1 and padded_shape[0] != param_shape[0]: + # unpad + tensor = tensor[: tensor.shape[0] - padded_shape[0] + param_shape[0]].clone() + + offsets = [0] * len(padded_shape) + offsets[0] = padded_shape[0] // dp_size * dp_rank + local_shards = [Shard.from_tensor_and_offsets(tensor, offsets, dp_rank)] + + # Create a ShardedTensor without invoking communication. 
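The chunk metadata assembled below follows the same dim-0 padding arithmetic as the unpad step above. A small standalone sketch (made-up sizes, not part of the diff) of how the per-rank chunk sizes and offsets come out:

```python
import itertools

# Per-rank chunk sizes/offsets for a parameter whose dim 0 is padded to a multiple
# of the DP size (made-up sizes: 10 rows sharded across 4 DP ranks).
dp_size = 4
param_shape = [10, 6]                                   # full (unsharded) parameter
shard_rows = (param_shape[0] + dp_size - 1) // dp_size  # 3 rows per rank after padding
padded_rows = shard_rows * dp_size                      # 12 rows once padded

chunk_sizes = []
for rank in range(dp_size):
    rows = shard_rows
    if rank == dp_size - 1 and padded_rows != param_shape[0]:
        rows -= padded_rows - param_shape[0]            # last rank keeps 3 - 2 = 1 real row
    chunk_sizes.append([rows, param_shape[1]])

dim0_offsets = [0] + list(itertools.accumulate(c[0] for c in chunk_sizes))[:-1]
print(chunk_sizes)    # [[3, 6], [3, 6], [3, 6], [1, 6]]
print(dim0_offsets)   # [0, 3, 6, 9]
```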
+ chunk_sizes = [] + for i in range(dp_size): + if i == dp_size - 1 and padded_shape[0] != param_shape[0]: + shape = shard_shape.copy() + shape[0] -= padded_shape[0] - param_shape[0] + chunk_sizes.append(shape) + else: + chunk_sizes.append(shard_shape) + dim0_offsets = [0] + list(itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes]))[:-1] + offsets = [0] * (len(chunk_sizes[0]) - 1) + chunk_offsets = [[d0] + offsets for d0 in dim0_offsets] + device_type = "cpu" + placements = [_get_remove_device_str(r, device_type, None) for r in range(len(chunk_sizes))] + assert len(chunk_sizes) == len(chunk_offsets) == len(placements) + shards_metadata = [ + ShardMetadata(offset, size, placement) + for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements) + ] + sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=shards_metadata, + size=param_shape, + tensor_properties=TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ), + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=get_data_parallel_group() + ) + + +def _sharded_tensor_to_tensor( + tensor: Union[torch.Tensor, ShardedTensor], + param_shape: Sequence[int], +) -> torch.Tensor: + # quick path for scalars + if not isinstance(tensor, ShardedTensor): + return tensor + + dp_rank = get_data_parallel_rank() + dp_size = get_data_parallel_size() + + tensor = tensor.local_shards()[0].tensor + shard_shape = param_shape.copy() + shard_shape[0] = (param_shape[0] + dp_size - 1) // dp_size + padded_shape = shard_shape.copy() + padded_shape[0] = padded_shape[0] * dp_size + if dp_rank == dp_size - 1 and padded_shape[0] != param_shape[0]: + # pad + pad_size = padded_shape[0] - param_shape[0] + tensor = F.pad(tensor, [0, 0] * (tensor.dim() - 1) + [0, pad_size]) + return tensor + + +def _wrap_optim_state_dict( + state_dict: Dict[str, Any], + aux_infos: Dict[str, Any], + dedup: bool = False, + pp_rank: Optional[int] = None, + tp_rank: Optional[int] = None, + dp_rank: Optional[int] = None, +) -> Dict[str, Any]: + pp_rank = get_pipeline_model_parallel_rank() if pp_rank is None else pp_rank + tp_rank = get_tensor_model_parallel_rank() if tp_rank is None else tp_rank + + optim_pid_to_params = aux_infos["optim_pid_to_params"] + optim_pid_to_pnames = aux_infos["optim_pid_to_pnames"] + pnames_to_optim_pids = {v: k for k, v in optim_pid_to_pnames.items()} + + # replace pid with pname + new_state_dict = copy.copy(state_dict) + new_state_dict["base_state"] = {optim_pid_to_pnames[k]: v for k, v in state_dict["base_state"].items()} + shape_info = new_state_dict["shape_info"] + + # flatten state dict + new_state_dict, mappings = flatten_state_dict(new_state_dict) + for k, v in new_state_dict.items(): + if isinstance(v, torch.Tensor): + # tensor to sharded tensor + pname = mappings[k][1] + pid = pnames_to_optim_pids[pname] + param_shape = shape_info[pid] + new_state_dict[k] = _tensor_to_sharded_tensor(v, param_shape, dp_rank) + # add tp and pp info + if v.dim() > 0: # TODO: merge constant scalars + fqn = [] + fqn.append(mappings[k][-1]) + if get_pipeline_model_parallel_size() > 1: + # TODO: deduplicate shared params + fqn.append("|pp-{:04d}".format(pp_rank)) + if get_tensor_model_parallel_size() > 1: + param = optim_pid_to_params[pid] + if not dedup or hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel: + 
fqn.append("|tp-{:04d}".format(tp_rank)) + fqn = "".join(fqn) + new_path = list(mappings[k]) + new_path[-1] = fqn + mappings[k] = tuple(new_path) + + new_state_dict = unflatten_state_dict(new_state_dict, mappings) + return new_state_dict + + +def _unwrap_optim_state_dict( + state_dict: Dict[str, Any], + aux_infos: Dict[str, Any], +) -> Dict[str, Any]: + optim_pid_to_pnames = aux_infos["optim_pid_to_pnames"] + pnames_to_optim_pids = {v: k for k, v in optim_pid_to_pnames.items()} + + # flatten state dict + state_dict, mappings = flatten_state_dict(state_dict) + shape_info = state_dict["shape_info"] + for k, v in state_dict.items(): + if isinstance(v, torch.Tensor): + # sharded tensor to tensor + pname = mappings[k][1] + param_shape = shape_info[pnames_to_optim_pids[pname]] + state_dict[k] = _sharded_tensor_to_tensor(v, param_shape) + # remove tp and pp info + fqn = mappings[k][-1] + origin_key = fqn[: fqn.index("|")] if "|" in fqn else fqn + new_path = list(mappings[k]) + new_path[-1] = origin_key + mappings[k] = tuple(new_path) + state_dict = unflatten_state_dict(state_dict, mappings) + + # replace pname with pid + state_dict["base_state"] = {pnames_to_optim_pids[k]: v for k, v in state_dict["base_state"].items()} + return state_dict + + +def _prepare_optim_state_dict( + metadata: Metadata, + aux_infos: Dict[str, Any], + dedup: bool = False, + pp_rank: Optional[int] = None, + tp_rank: Optional[int] = None, + dp_rank: Optional[int] = None, +) -> Dict[str, Any]: + pp_rank = get_pipeline_model_parallel_rank() if pp_rank is None else pp_rank + tp_rank = get_tensor_model_parallel_rank() if tp_rank is None else tp_rank + + optim_pid_to_pnames = aux_infos["optim_pid_to_pnames"] + pnames_to_optim_pids = {v: k for k, v in optim_pid_to_pnames.items()} + + shape_info = aux_infos["shape_info"] + + new_state_dict = {} + for fqn, value in metadata.state_dict_metadata.items(): + if isinstance(value, BytesStorageMetadata): + new_state_dict[fqn] = "" + continue + # value: TensorStorageMetadata + if value.size.numel() == 1: + new_state_dict[fqn] = _alloc_tensor(value.properties, value.size) + else: + pp_rank_from_fqn = 0 + if get_pipeline_model_parallel_size() > 1: + pp_rank_from_fqn = int(re.search(r"\|pp-(\d{4})", fqn).group(1)) + tp_rank_from_fqn = 0 + if get_tensor_model_parallel_size() > 1: + tp_rank_from_fqn = int(re.search(r"\|tp-(\d{4})", fqn).group(1)) + if pp_rank_from_fqn == pp_rank and tp_rank_from_fqn == tp_rank: + origin_key = fqn[: fqn.index("|")] if "|" in fqn else fqn + assert origin_key.startswith("base_state.") + pname = origin_key[11:] # cut "base_state." 
+ pname = pname[: pname.rindex(".")] + pid = pnames_to_optim_pids[pname] + param_shape = shape_info[pid] + shard_shape = list(param_shape) + shard_shape[0] = (shard_shape[0] + get_data_parallel_size() - 1) // get_data_parallel_size() + new_state_dict[fqn] = _tensor_to_sharded_tensor(_alloc_tensor(value.properties, shard_shape), param_shape, dp_rank=dp_rank) + + return new_state_dict + + +def _generate_all_local_save_plans( + state_dict: Dict[str, Any], + aux_infos: Dict[str, Any], + dedup: bool = False, +) -> List[SavePlan]: + def _generate_one_local_save_plan(global_rank): + # calc pp rank + pp_rank = None + for group in get_pipeline_model_parallel_group(as_list=True): + if global_rank in group: + if not isinstance(group, list): + group = list(group) + pp_rank = group.index(global_rank) + break + # calc tp rank + tp_rank = None + for group in get_tensor_model_parallel_group(as_list=True): + if global_rank in group: + if not isinstance(group, list): + group = list(group) + tp_rank = group.index(global_rank) + break + # calc dp rank + dp_rank = None + for group in get_data_parallel_group(as_list=True): + if global_rank in group: + if not isinstance(group, list): + group = list(group) + dp_rank = group.index(global_rank) + break + + wrapped_state_dict = _wrap_optim_state_dict(state_dict, aux_infos, dedup=dedup, pp_rank=pp_rank, tp_rank=tp_rank, dp_rank=dp_rank) + planner = DefaultSavePlanner() + planner.set_up_planner(wrapped_state_dict, global_rank == 0) + local_plan = planner.create_local_plan() + return local_plan + + all_plans = [_generate_one_local_save_plan(i) for i in range(dist.get_world_size())] + return all_plans + + +def _generate_all_local_load_plans( + state_dict: Dict[str, Any], + aux_infos: Dict[str, Any], + metadata: Metadata, + dedup: bool = False, +) -> List[LoadPlan]: + def _generate_one_local_load_plan(global_rank): + # calc pp rank + pp_rank = None + for group in get_pipeline_model_parallel_group(as_list=True): + if global_rank in group: + if not isinstance(group, list): + group = list(group) + pp_rank = group.index(global_rank) + break + # calc tp rank + tp_rank = None + for group in get_tensor_model_parallel_group(as_list=True): + if global_rank in group: + if not isinstance(group, list): + group = list(group) + tp_rank = group.index(global_rank) + break + # calc dp rank + dp_rank = None + for group in get_data_parallel_group(as_list=True): + if global_rank in group: + if not isinstance(group, list): + group = list(group) + dp_rank = group.index(global_rank) + break + + wrapped_state_dict = _prepare_optim_state_dict(metadata, aux_infos, dedup=dedup, pp_rank=pp_rank, tp_rank=tp_rank, dp_rank=dp_rank) + wrapped_state_dict = unflatten_state_dict(wrapped_state_dict, metadata.planner_data) + planner = DefaultLoadPlanner() + planner.set_up_planner(wrapped_state_dict, metadata, global_rank == 0) + local_plan = planner.create_local_plan() + return local_plan + + all_plans = [_generate_one_local_load_plan(i) for i in range(dist.get_world_size())] + return all_plans + + +@functools.lru_cache(maxsize=None) # equal to `@functools.cache` +def get_dcp_aux_infos( + model: torch.nn.Module, + optim: torch.optim.Optimizer, +) -> Dict[str, Any]: + return { + "optim_pid_to_params": _get_optim_pid_to_params(optim), + "optim_pid_to_pnames": _get_optim_pid_to_param_names(model, optim), + "shape_info": optim.state_dict()["shape_info"], + } + + +def save_optim_state_dict( + path: str, + state_dict: Dict[str, Any], + aux_infos: Dict[str, Any], + dedup: bool = False, +) -> Metadata: + """ + 
Wrap the optimizer state dict into a DCP-friendly format and save it to the given path.
+
+    Parameters:
+        path (str):
+            save path.
+        state_dict (dict):
+            optimizer state dict.
+        aux_infos (dict):
+            auxiliary information extracted from model and optimizer.
+        dedup (bool):
+            whether to deduplicate tensor parallel parameters.
+
+    Returns:
+        Metadata or None.
+    """
+    wrapped_state_dict = _wrap_optim_state_dict(state_dict, aux_infos, dedup=dedup)
+
+    storage_writer = dist_cp.FileSystemWriter(path)
+    is_coordinator = (dist.get_rank() == 0)
+    planner = DefaultSavePlanner()
+
+    global_metadata = None
+
+    # get SavePlans
+    planner.set_up_planner(wrapped_state_dict, is_coordinator)
+    storage_writer.set_up_storage_writer(is_coordinator)
+    local_plan = planner.create_local_plan()
+    local_plan = storage_writer.prepare_local_plan(local_plan)
+
+    all_local_plans = _generate_all_local_save_plans(state_dict, aux_infos, dedup=dedup)
+
+    all_local_plans, global_metadata = planner.create_global_plan(
+        all_local_plans
+    )
+    all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
+    central_plan = all_local_plans[dist.get_rank()]
+
+    # write data and global metadata
+    final_local_plan = planner.finish_plan(central_plan)
+    all_writes = storage_writer.write_data(final_local_plan, planner)
+    all_writes.wait()
+    write_results = all_writes.value()
+
+    # save + rename
+    with open(os.path.join(path, "write_results.tmp.tmp.{}".format(dist.get_rank())), "wb") as f:
+        pickle.dump(write_results, f)
+        os.fsync(f.fileno())
+    os.rename(os.path.join(path, "write_results.tmp.tmp.{}".format(dist.get_rank())), os.path.join(path, "write_results.tmp.{}".format(dist.get_rank())))
+
+    if dist.get_rank() == 0:
+        file_paths = [os.path.join(path, "write_results.tmp.{}".format(i)) for i in range(dist.get_world_size())]
+        count = 0
+        success = False
+        while count < MAX_RETRY and not success:
+            try:
+                if all(os.path.exists(file_path) for file_path in file_paths):
+                    success = True
+                else:
+                    time.sleep(1)  # Wait for specified interval before retrying
+                    count += 1
+            except Exception as e:
+                logging.debug(f"An error occurred: {e}")
+                count += 1
+            logging.debug("count: {}, rank: {}".format(count, dist.get_rank()))
+        if not success:
+            raise RuntimeError(rmsg("Failed to check write_results files exist."))
+
+        all_write_results = []
+        for global_rank in range(dist.get_world_size()):
+            tmp_path = os.path.join(path, "write_results.tmp.{}".format(global_rank))
+            with open(tmp_path, "rb") as f:
+                all_write_results.append(pickle.load(f))
+            os.remove(tmp_path)
+
+        storage_writer.finish(metadata=global_metadata, results=all_write_results)
+
+    return global_metadata
+
+
+def load_optim_state_dict(
+    path: str,
+    optimizer: torch.optim.Optimizer,
+    aux_infos: Dict[str, Any],
+    dedup: bool = False,
+) -> None:
+    """
+    Load an optimizer checkpoint saved in the DCP-friendly format and restore it into the given optimizer.
+
+    Parameters:
+        path (str):
+            path the checkpoint was saved to.
+        optimizer (torch.optim.Optimizer):
+            optimizer whose state will be restored in place.
+        aux_infos (dict):
+            auxiliary information extracted from model and optimizer.
+        dedup (bool):
+            whether tensor parallel parameters were deduplicated when saving.
+
+    Returns:
+        None.
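+
+    Example:
+        An illustrative sketch; assumes a `model` and an NxD ZeRO-1 `optimizer`
+        whose state dict carries the "base_state" and "shape_info" entries used
+        by the helpers above:
+
+            aux_infos = get_dcp_aux_infos(model, optimizer)
+            save_optim_state_dict("ckpt/optim", optimizer.state_dict(), aux_infos)
+            ...
+            load_optim_state_dict("ckpt/optim", optimizer, aux_infos)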
+ """ + storage_reader = dist_cp.FileSystemReader(path) + is_coordinator = (dist.get_rank() == 0) + planner = DefaultLoadPlanner() + + metadata = storage_reader.read_metadata() + + wrapped_state_dict = _prepare_optim_state_dict(metadata, aux_infos, dedup=dedup) + wrapped_state_dict = unflatten_state_dict(wrapped_state_dict, metadata.planner_data) + + planner.set_up_planner(wrapped_state_dict, metadata, is_coordinator) + storage_reader.set_up_storage_reader(metadata, is_coordinator) + local_plan = planner.create_local_plan() + local_plan = storage_reader.prepare_local_plan(local_plan) + + all_local_plans = _generate_all_local_load_plans(wrapped_state_dict, aux_infos, metadata, dedup=dedup) + + all_local_plans = planner.create_global_plan(all_local_plans) + all_local_plans = storage_reader.prepare_global_plan(all_local_plans) + central_plan = all_local_plans[dist.get_rank()] + + final_local_plan = planner.finish_plan(central_plan) + all_reads = storage_reader.read_data(final_local_plan, planner) + all_reads.wait() + xm.rendezvous("neuronx_distributed.optimizer.zero_dcp_utils.load_optim_state_dict") + + state_dict = _unwrap_optim_state_dict(wrapped_state_dict, aux_infos) + optimizer.load_state_dict(state_dict) diff --git a/src/neuronx_distributed/optimizer/zero_redundancy_optimizer.py b/src/neuronx_distributed/optimizer/zero_redundancy_optimizer.py index 26a43c2..d776c72 100644 --- a/src/neuronx_distributed/optimizer/zero_redundancy_optimizer.py +++ b/src/neuronx_distributed/optimizer/zero_redundancy_optimizer.py @@ -1,17 +1,22 @@ import gc import math import os -from typing import Union +from typing import Union, Optional, Callable, List, Any, Dict import torch import torch_xla.core.xla_model as xm from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer from ..parallel_layers.checkpointing import ensure_directory_exists -from ..parallel_layers.grads import clip_grad_norm +from ..parallel_layers.grads import get_grad_norm, clip_grads_with_norm +from ..utils.model_utils import recursive_filter from ..parallel_layers.parallel_state import ( get_data_parallel_group, get_data_parallel_rank, + get_expert_data_parallel_group, + get_expert_data_parallel_size, + get_expert_model_parallel_size, + get_expert_model_parallel_group, get_tensor_model_parallel_rank, model_parallel_is_initialized, ) @@ -30,16 +35,21 @@ def __init__(self, *args, **kwargs): if "sharding_groups" not in kwargs or kwargs["sharding_groups"] is None: kwargs["sharding_groups"] = get_data_parallel_group(as_list=True) - # If use dp groups for sharding, calculate the grad norm with world group - if kwargs["sharding_groups"] == get_data_parallel_group(as_list=True): - self._use_world_for_grad_norm = True - else: - self._use_world_for_grad_norm = False + # hard-code since use_world_for_grad_norm = True does not work with TP + self._use_world_for_grad_norm = False super().__init__(*args, **kwargs) + if kwargs.get("lazy_init"): + from neuronx_distributed.trainer import hooks + from neuronx_distributed.trainer.trainer import ( + filter_to_local_parameter_group, + ) + + hooks.register_post_partition_hook(filter_to_local_parameter_group, [self]) + hooks.register_post_partition_hook(self.init_zero) @property - def grad_norm(self): + def grad_norm(self) -> Optional[torch.Tensor]: return self._grad_norm def _shard_parameters(self): @@ -54,12 +64,7 @@ def _shard_parameters(self): else: raise e - @torch.no_grad() - def _clip_grad_norm( - self, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - ) -> 
torch.Tensor: + def _get_params_and_grad_norm(self, norm_type): all_parameters = [] for param_group, sharded_param_group in zip(self.param_groups, self.base_optimizer.param_groups): for param, shard in zip(param_group["params"], sharded_param_group["params"]): @@ -68,16 +73,30 @@ def _clip_grad_norm( shard.shared = param.shared if hasattr(param, "tensor_model_parallel"): shard.tensor_model_parallel = param.tensor_model_parallel + if hasattr(param, "expert_model_parallel"): + shard.expert_model_parallel = True all_parameters.append(shard) zero1_optimizer_groups = None if self._use_world_for_grad_norm else self._sharding_groups - self._grad_norm = clip_grad_norm( + + # Get norm + grad_norm = get_grad_norm( all_parameters, - max_norm=max_norm, norm_type=norm_type, zero1_optimizer=True, zero1_optimizer_groups=zero1_optimizer_groups, ) + return all_parameters, grad_norm + + @torch.no_grad() + def _clip_grad_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + ) -> None: + + all_parameters, self._grad_norm = self._get_params_and_grad_norm(norm_type) + clip_grads_with_norm(all_parameters, self._grad_norm, max_norm) # [TODO] Remove this method def save_sharded_state_dict(self, output_dir: str, num_workers_per_step: int = 8) -> None: @@ -86,11 +105,7 @@ def save_sharded_state_dict(self, output_dir: str, num_workers_per_step: int = 8 "`NeuronZero1Optimizer.save_sharded_state_dict` is deprecated, please use `nxd.save_checkpoint` instead." ) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - logger.debug("optimizer.saving checkpoint to {}".format(output_dir)) - else: - logger.debug("optimizer.saving checkpoint to {}".format(output_dir)) + logger.debug("optimizer.saving checkpoint to %s", output_dir) state_dict = self.state_dict() state_dict["dp_rank"] = get_data_parallel_rank() @@ -106,7 +121,7 @@ def save_sharded_state_dict(self, output_dir: str, num_workers_per_step: int = 8 local_rank = xm.get_local_ordinal() for worker in range(math.ceil(get_local_world_size() / num_workers_per_step)): if local_rank // num_workers_per_step == worker: - logger.debug(f"optimizer.worker {local_rank} saving checkpoint {chkpt_path}") + logger.debug("optimizer.worker %d saving checkpoint %s", local_rank, chkpt_path) cpu_data = move_all_tensor_to_cpu(state_dict) torch.save(cpu_data, chkpt_path) del cpu_data @@ -127,18 +142,221 @@ def load_sharded_state_dict(self, output_dir: str, num_workers_per_step: int = 8 "optim.dp_rank_{:02d}.tp_rank_{:02d}".format(get_data_parallel_rank(), get_tensor_model_parallel_rank()), ) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - logger.debug(f"optimizer.loading checkpoint from {chkpt_path}") - else: - logger.debug(f"optimizer.loading checkpoint from {chkpt_path}") + logger.debug("optimizer.loading checkpoint from %s", chkpt_path) local_rank = xm.get_local_ordinal() for worker in range(math.ceil(get_local_world_size() / num_workers_per_step)): if local_rank // num_workers_per_step == worker: - logger.debug(f"optimizer.worker {local_rank} resuming from checkpoint {chkpt_path}") + logger.debug("optimizer.worker %d resuming from checkpoint %s", local_rank, chkpt_path) check_point = torch.load(chkpt_path, map_location="cpu") self.load_state_dict(check_point) del check_point gc.collect() xm.rendezvous("optimizer.load_checkpoint" + str(worker)) + + +class NeuronEPZero1Optimizer(NeuronZero1Optimizer): + def __init__(self, *args, **kwargs): + parameters = args[0] if len(args) > 0 else 
kwargs["params"] + ep_parameters = recursive_filter(parameters, self._is_ep_param) + non_ep_parameters = recursive_filter(parameters, lambda x: not self._is_ep_param(x)) + + # Default to use DP groups for sharding + if "sharding_groups" not in kwargs or kwargs["sharding_groups"] is None: + kwargs["sharding_groups"] = get_data_parallel_group(as_list=True) + elif kwargs["sharding_groups"] != get_data_parallel_group(as_list=True): + raise ValueError("Custom sharding group for Zero-1 with expert parallelism is not supported.") + + if "params" in kwargs: + kwargs.pop("params") + + self.non_ep_zero_optimizer = NeuronZero1Optimizer(non_ep_parameters, *args[1:], **kwargs) + + kwargs["sharding_groups"] = get_expert_data_parallel_group(as_list=True) + self.ep_zero_optimizer = NeuronZero1Optimizer(ep_parameters, *args[1:], **kwargs) + + # Avoid this optimizer to create hooks for grad accumulation + kwargs["use_grad_acc_hook"] = False + + super(NeuronEPZero1Optimizer, self).__init__(*args, **kwargs) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ Reset the gradients of the two optimizers and hooked grad accumulators + when GPU-compatible precision is enabled. """ + self.ep_zero_optimizer.zero_grad(set_to_none=set_to_none) + self.non_ep_zero_optimizer.zero_grad(set_to_none=set_to_none) + + def _is_ep_param(self, param): + return hasattr(param, "expert_model_parallel") and param.expert_model_parallel + + def _filter_param_groups(self, groups: List[Dict[str, List[Any]]], predicate: Callable[..., List[Any]]) -> List[Dict[str, List[Any]]]: + filtered = [{k: v if k != "params" else [p for p in filter(predicate, v)]} for group in groups for k, v in group.items()] + return filtered + + @property + def sharding_groups(self): + return self.non_ep_zero_optimizer._sharding_groups + + @sharding_groups.setter + def sharding_groups(self, new_sharding_groups): + assert not self.inited, "already inited, cannot change sharding_groups" + self.non_ep_zero_optimizer._sharding_groups = new_sharding_groups + + def _combine_grad_norms(self, norms, norm_type): + if torch.isinf(torch.tensor(norm_type)): + return max(norms) + + norm_squares = [torch.pow(n, norm_type) for n in norms] + return torch.pow(sum(norm_squares), 1.0 / norm_type) + + @torch.no_grad() + def _clip_grad_norm( + self, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + ) -> None: + + non_ep_parameters, non_ep_grad_norm = self.non_ep_zero_optimizer._get_params_and_grad_norm(norm_type) + + # break the graph between the two grad norm calls to avoid runtime error + # TODO remove this + xm.mark_step() + + ep_parameters, ep_grad_norm = self.ep_zero_optimizer._get_params_and_grad_norm(norm_type) + + self._grad_norm = self._combine_grad_norms([non_ep_grad_norm, ep_grad_norm], norm_type) + + clip_grads_with_norm(non_ep_parameters + ep_parameters, self._grad_norm, max_norm) + + def _get_sharding_schemes(self) -> Dict[str, Any]: + # sequentially reduce over expert-data-parallel and expert-model-parallel groups + non_ep_sharding_scheme = [ + { + "sharding_group": get_expert_data_parallel_group(as_list=True), + "group_size": get_expert_data_parallel_size(), + "scale_factor": 1.0, + }, + { + "sharding_group": get_expert_model_parallel_group(as_list=True), + "group_size": get_expert_model_parallel_size(), + "scale_factor": 1.0, + }, + ] + + ep_sharding_scheme = [ + { + "sharding_group": get_expert_data_parallel_group(as_list=True), + "group_size": get_expert_data_parallel_size(), + # EP grads further need to be scaled down by EP degree 
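+                # (EP grads are reduced only over the expert-data-parallel group of
+                # size DP / EP, so the extra 1 / EP factor keeps their effective
+                # averaging consistent with the full data-parallel degree used for
+                # non-EP grads)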
+ "scale_factor": 1.0 / float(get_expert_model_parallel_size()), + }, + ] + + return non_ep_sharding_scheme, ep_sharding_scheme + + def _reduce_gradients(self, **kwargs) -> None: + + non_ep_sharding_scheme, ep_sharding_scheme = self._get_sharding_schemes() + + # LR scheduler will modify some training parameters like learning rate + # of this optimizer, so we need to propagate these parameters to the + # base optimizers. + self._sync_param_groups(self.param_groups, self.non_ep_zero_optimizer.param_groups) + self.non_ep_zero_optimizer._reduce_gradients(sharding_scheme=non_ep_sharding_scheme, **kwargs) # noqa: W0212 + + self._sync_param_groups(self.param_groups, self.ep_zero_optimizer.param_groups) + self.ep_zero_optimizer._reduce_gradients(sharding_scheme=ep_sharding_scheme, **kwargs) # noqa: W0212 + + def _update_parameters(self, **kwargs) -> None: + non_ep_sharding_scheme, ep_sharding_scheme = self._get_sharding_schemes() + + self.non_ep_zero_optimizer._update_parameters(sharding_scheme=non_ep_sharding_scheme, **kwargs) # noqa: W0212 + self.ep_zero_optimizer._update_parameters(sharding_scheme=ep_sharding_scheme, **kwargs) # noqa: W0212 + + def _get_offset(self, dct: Dict[int, int]) -> int: + if len(dct) > 0: + return max(k for k in dct) + 1 + return 1 + + def state_dict(self) -> Dict[str, Any]: + """ Combine the state_dicts of the two base optimizers """ + + non_ep_state_dict = self.non_ep_zero_optimizer.state_dict() + ep_state_dict = self.ep_zero_optimizer.state_dict() + ep_param_id_offset = self._get_offset(non_ep_state_dict["state"]) + ep_base_state_offset = self._get_offset(non_ep_state_dict["base_state"]) + ep_shape_info_offset = self._get_offset(non_ep_state_dict["shape_info"]) + ep_param_group_offset = len(non_ep_state_dict["param_groups"]) + + state_dict = {"ep_param_id_offset": ep_param_id_offset, + "ep_param_group_offset": ep_param_group_offset, + "ep_base_state_offset": ep_base_state_offset, + "ep_shape_info_offset": ep_shape_info_offset, + } + + # combine param_groups + state_dict["param_groups"] = non_ep_state_dict["param_groups"] + ep_state_dict["param_groups"] + + # combine state + state_dict["state"] = non_ep_state_dict["state"] + state_dict["base_state"] = non_ep_state_dict["base_state"] + state_dict["shape_info"] = non_ep_state_dict["shape_info"] + for k, v in ep_state_dict["state"].items(): + state_dict["state"][ep_param_id_offset + k] = v + + for k, v in ep_state_dict["base_state"].items(): + state_dict["base_state"][ep_base_state_offset + k] = v + + for k, v in ep_state_dict["shape_info"].items(): + state_dict["shape_info"][ep_shape_info_offset + k] = v + + return state_dict + + + def load_state_dict(self, state_dict): + """ Split the state_dict for the two base optimizers and load them individually """ + + if ("ep_param_id_offset" not in state_dict or + "ep_param_group_offset" not in state_dict or + "ep_base_state_offset" not in state_dict or + "ep_shape_info_offset" not in state_dict): + raise ValueError("state_dict is not compatible with expert parallelism and Zero-1.") + + ep_param_id_offset = state_dict["ep_param_id_offset"] + ep_param_group_offset = state_dict["ep_param_group_offset"] + ep_base_state_offset = state_dict["ep_base_state_offset"] + ep_shape_info_offset = state_dict["ep_shape_info_offset"] + + # split param_groups + non_ep_param_groups = state_dict["param_groups"][:ep_param_group_offset] + ep_param_groups = state_dict["param_groups"][ep_param_group_offset:] + + def _split_states(state_key, offset): + non_ep_states, ep_states = {}, {} + for k, v in 
state_dict[state_key].items(): + if k < offset: + non_ep_states[k] = v + else: + ep_states[k - offset] = v + return non_ep_states, ep_states + + non_ep_state, ep_state = _split_states("state", ep_param_id_offset) + non_ep_base_state, ep_base_state = _split_states("base_state", ep_base_state_offset) + non_ep_shape_info, ep_shape_info = _split_states("shape_info", ep_shape_info_offset) + + non_ep_state_dict = { + "state": non_ep_state, + "base_state": non_ep_base_state, + "shape_info": non_ep_shape_info, + "param_groups": non_ep_param_groups, + } + + ep_state_dict = { + "state": ep_state, + "base_state": ep_base_state, + "shape_info": ep_shape_info, + "param_groups": ep_param_groups, + } + + self.non_ep_zero_optimizer.load_state_dict(non_ep_state_dict) + self.ep_zero_optimizer.load_state_dict(ep_state_dict) diff --git a/src/neuronx_distributed/parallel_layers/checkpointing.py b/src/neuronx_distributed/parallel_layers/checkpointing.py index 2668381..1212582 100644 --- a/src/neuronx_distributed/parallel_layers/checkpointing.py +++ b/src/neuronx_distributed/parallel_layers/checkpointing.py @@ -1,6 +1,6 @@ import gc import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Dict import torch import torch_xla.core.xla_model as xm @@ -32,8 +32,8 @@ def ensure_directory_exists(filename: str) -> None: PreShardHookFn = Callable[[torch.nn.Module, dict, str], bool] -def _invoke_preshard_hook(module: torch.nn.Module, model_state_dict: dict, prefix: str = "") -> dict: - if module == None: +def _invoke_preshard_hook(module: torch.nn.Module, model_state_dict: Dict[str, Any], prefix: str = "") -> None: + if module is None: return # This is temporary until we formailze the preshard_hook in src @@ -45,7 +45,7 @@ def _invoke_preshard_hook(module: torch.nn.Module, model_state_dict: dict, prefi _invoke_preshard_hook(child, model_state_dict, prefix + name + ".") -def get_sharded_model_dict(model: torch.nn.Module, model_state_dict: dict) -> dict: +def get_sharded_model_dict(model: torch.nn.Module, model_state_dict: Dict[str, Any]) -> Dict[str, Any]: from ..pipeline.model import NxDPPModel tp_size = get_tensor_model_parallel_size() @@ -89,18 +89,11 @@ def save( - checkpoint.pt """ - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - if rank == 0: - logger.info("saving checkpoint to {}".format(output_dir)) - else: - logger.info("saving checkpoint to {}".format(output_dir)) - rank = 0 - - chkpt_path = output_dir + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + logger.info("saving checkpoint to %s", output_dir) chkpt_path = os.path.join( - chkpt_path, + output_dir, "tp_rank_{:02d}_pp_rank_{:02d}".format(get_tensor_model_parallel_rank(), get_pipeline_model_parallel_rank()), ) if not master_dp_only: @@ -151,13 +144,14 @@ def save( def load( chkpt_path: str, - model: torch.nn.Module = None, + model: Optional[torch.nn.Module] = None, model_or_optimizer: Any = None, model_key: Optional[str] = "model", load_xser: bool = False, sharded: bool = True, strict: bool = True, master_dp_only: bool = True, + weights_only: bool = False, ) -> dict: """Load a checkpoint (model or optimizer) and return. In case the model/optimizer object is provided, it will load the model weights/optimizer stats. 
For large models/optimizers, to avoid @@ -175,13 +169,7 @@ def load( model_or_optimizer is not None ), "When checkpoint is not shareded, kwarg `model_or_optimizer` needs to be passed" # noqa: E501 - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - if rank == 0: - logger.info("loading checkpoint from {}".format(chkpt_path)) - else: - logger.info("loading checkpoint from {}".format(chkpt_path)) - rank = 0 + logger.info("loading checkpoint from %s", chkpt_path) skip_rendezvous = os.environ.get(NXD_SKIP_RENDEZVOUS, None) == "1" @@ -190,7 +178,7 @@ def load( if isinstance(model_or_optimizer, NxDPPModel) and not load_xser: logger.warning( - f"[Warning] It's recommended to use save_xser \ + "[Warning] It's recommended to use save_xser \ to save NxDPPModel to reduce saving time and redundant graphss" ) @@ -203,7 +191,7 @@ def load( if isinstance(model_or_optimizer, torch.nn.Module) and not model_moved_to_device and load_xser: logger.warning( - f"[Warning] For save_xser case it is recommended to call load \ + "[Warning] For save_xser case it is recommended to call load \ after moving model to device to reduce redundant graphs." ) @@ -222,8 +210,7 @@ def load( else: checkpoint_name = chkpt_path - if rank == 0: - logger.debug(f" loading checkpoint from {chkpt_path}") + logger.debug(" loading checkpoint from %s", chkpt_path) tp_size = get_tensor_model_parallel_size() tp_rank = get_tensor_model_parallel_rank() @@ -232,9 +219,9 @@ def load( if tp_rank == worker_start: logger.debug(f"Worker {tp_rank} resuming from checkpoint {checkpoint_name}") if load_xser: - check_point = _xser_load(checkpoint_name) + check_point = _xser_load(checkpoint_name, weights_only=weights_only) else: - check_point = torch.load(checkpoint_name, map_location="cpu") + check_point = torch.load(checkpoint_name, weights_only=weights_only, map_location="cpu") if model_or_optimizer: if model_key is not None: model_state_dict = check_point[model_key] @@ -262,22 +249,21 @@ def load( return check_point -def _xser_load(path): +def _xser_load(path: torch.serialization.FILE_LIKE, weights_only: bool = False) -> Any: """ Modify from xla serialization load https://github.com/pytorch/xla/blob/master/torch_xla/utils/serialization.py#L79-L100, with casting tensors to xla device to prevent OOM """ - ref_data = torch.load(path) + ref_data = torch.load(path, weights_only=weights_only) def convert_fn(tensors): rewritten_tensors = [] for t in tensors: - rewritten_tensors.append( - torch.load(os.path.join(path + ".tensors", "tensor_{}.pt".format(t.tid))).to(device=xm.xla_device()) - ) + tensor_path = os.path.join(path + ".tensors", "tensor_{}.pt".format(t.tid)) + rewritten_tensors.append(torch.load(tensor_path, weights_only=weights_only).to(device=xm.xla_device())) return rewritten_tensors def select_fn(v): - return type(v) == xser.TensorReference + return isinstance(v, xser.TensorReference) return xm.ToXlaTensorArena(convert_fn, select_fn).transform(ref_data) diff --git a/src/neuronx_distributed/parallel_layers/grads.py b/src/neuronx_distributed/parallel_layers/grads.py index 4ddf3cb..c1a7749 100644 --- a/src/neuronx_distributed/parallel_layers/grads.py +++ b/src/neuronx_distributed/parallel_layers/grads.py @@ -9,6 +9,10 @@ from .parallel_state import ( get_data_parallel_group, get_data_parallel_size, + get_expert_data_parallel_group, + get_expert_data_parallel_size, + get_expert_model_parallel_group, + get_expert_model_parallel_size, get_pipeline_model_parallel_group, get_pipeline_model_parallel_size, 
get_tensor_model_parallel_group, @@ -52,17 +56,33 @@ def get_grad_norm(parameters, norm_type=2, zero1_optimizer=False, zero1_optimize if not zero1_optimizer and zero1_optimizer_groups is not None: raise ValueError( - f"Getting zero1_optimizer_groups while zero1_optimizer is False. When using zero-1 optimizer grad clipping is handled by optimizer." + "Getting zero1_optimizer_groups while zero1_optimizer is False. When using zero-1 optimizer grad clipping is handled by optimizer." ) # noqa - def _allreduce_norm_across_parallel_groups(total_norm, reduce_op): + def _allreduce_norm_across_parallel_groups(total_norm, ep_total_norm, reduce_op): """ - zero1 without groups: allreduce across world groups - otherwise allreduce across each parallel group """ if zero1_optimizer and zero1_optimizer_groups is None: - torch.distributed.all_reduce(total_norm, op=reduce_op) + torch.distributed.all_reduce(total_norm, op=reduce_op, group=get_expert_data_parallel_group()) + if get_expert_model_parallel_size() > 1: + torch.distributed.all_reduce(total_norm, op=reduce_op, group=get_expert_model_parallel_group()) + torch.distributed.all_reduce(ep_total_norm, op=reduce_op, group=get_expert_data_parallel_group()) + total_norm += ep_total_norm + else: + if get_expert_model_parallel_size() > 1: + torch.distributed.all_reduce( + ep_total_norm, + op=reduce_op, + group=get_expert_model_parallel_group(), + ) + if reduce_op == torch.distributed.ReduceOp.MAX: + total_norm = max(total_norm, ep_total_norm) + else: + # reduce_op will be SUM for Lp-norms + total_norm += ep_total_norm if get_tensor_model_parallel_size() > 1: torch.distributed.all_reduce( total_norm, @@ -82,11 +102,12 @@ def _allreduce_norm_across_parallel_groups(total_norm, reduce_op): groups=zero1_optimizer_groups, pin_layout=True, ) + return total_norm if isinstance(parameters, torch.Tensor): parameters = [parameters] - device = xm.xla_device() + device = parameters[0].device dtype = parameters[0].dtype grads = [] @@ -97,10 +118,13 @@ def _allreduce_norm_across_parallel_groups(total_norm, reduce_op): is_not_shared = param_is_not_shared(param) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) is_tp_param = hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel + is_ep_param = _is_ep_param(param) if grad_not_none: grad = param.grad.detach() grads.append(grad) if grad_not_none and is_not_shared: + if is_ep_param: + grad.expert_model_parallel = True if is_tp_param or (is_not_tp_duplicate and not force_spmd): # TP parallelized parameters # (not force_spmd) only tp rank 0 will add non-tp paramaters @@ -126,14 +150,21 @@ def _allreduce_norm_across_parallel_groups(total_norm, reduce_op): # sum the non-tp grad norm and scale by the tp_size for grad in grads_for_norm_tp_duplicated: grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm**norm_type + if _is_ep_grad(grad): + ep_total_norm += grad_norm**norm_type + else: + total_norm += grad_norm**norm_type total_norm /= get_tensor_model_parallel_size() + ep_total_norm /= get_tensor_model_parallel_size() for grad in grads_for_norm: grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm**norm_type + if _is_ep_grad(grad): + ep_total_norm += grad_norm**norm_type + else: + total_norm += grad_norm**norm_type - _allreduce_norm_across_parallel_groups(total_norm, torch.distributed.ReduceOp.SUM) + total_norm = _allreduce_norm_across_parallel_groups(total_norm, ep_total_norm, torch.distributed.ReduceOp.SUM) total_norm = torch.pow(total_norm, 1.0 / norm_type) return total_norm 
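With expert parallelism, `get_grad_norm` above accumulates two partial sums, `total_norm` for non-EP parameters and `ep_total_norm` for EP parameters, each all-reduced over a different process group before being combined. Because the two parameter sets are disjoint, the combined Lp norm is simply the p-th root of the sum of the per-group p-th powers (or the max of the per-group norms for the infinity norm). A minimal standalone sketch of that arithmetic, using a hypothetical `combine_group_norms` helper rather than the library API:

```python
import torch

def combine_group_norms(norms, norm_type: float = 2.0) -> torch.Tensor:
    """Combine Lp norms computed over disjoint sets of gradients."""
    if norm_type == float("inf"):
        # infinity norm of the union is the max of the per-group norms
        return max(norms)
    # ||g||_p over the union = (sum_i ||g_i||_p ** p) ** (1 / p)
    return torch.pow(sum(torch.pow(n, norm_type) for n in norms), 1.0 / norm_type)

# e.g. combining a non-EP norm of 3.0 and an EP norm of 4.0 under the 2-norm gives 5.0
print(combine_group_norms([torch.tensor(3.0), torch.tensor(4.0)]))  # tensor(5.)
```

The same combination appears in `NeuronEPZero1Optimizer._combine_grad_norms`, which merges the non-EP and EP norms before a single `clip_grads_with_norm` call.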
@@ -177,8 +208,13 @@ def clip_grad_norm( force_spmd=force_spmd, ) - # Scale. - device = xm.xla_device() + clip_grads_with_norm(parameters, total_norm, max_norm) + return total_norm + + +def clip_grads_with_norm(parameters, total_norm, max_norm): + assert len(parameters) > 0, "Parameters should be a non-empty list for gradient clipping" + device = parameters[0].device grads = [] for param in parameters: if param.grad is not None: @@ -197,7 +233,7 @@ def clip_grad_norm( return total_norm -def bucket_allreduce_gradients(grads_list): +def bucket_allreduce_gradients(grads_list, reduce_over_ep_group=False): """ All reduce bucket gradients for data parallelism. Referred from https://github.com/aws-neuron/neuronx-nemo-megatron/blob/main/nemo/nemo/collections/nlp/models/language_modeling/megatron_base_model.py#L58 # noqa: E501 @@ -221,10 +257,24 @@ def bucket_allreduce_gradients(grads_list): gradients = reversed(grads) total = 0 tensor_bucket = [] - groups = get_data_parallel_group(as_list=True) + + # if reduce_over_ep_group == False, the assumption is that we are allreducing + # all gradients over the expert-data-parallel group. if there is no ep, this + # is the only allreduce that we do, since data-parallel-group == expert-data-parallel-group. + # otherwise, non-expert-parallel gradients will go through an additional allreduce + # with reduce_over_ep_group == True, so that they are reduced over the full dp group. + if reduce_over_ep_group: + groups = get_expert_model_parallel_group(as_list=True) + else: + groups = get_expert_data_parallel_group(as_list=True) + + # the assumption is that if we are reducing over ep group, then we + # will also separately reduce over expert data parallel group, and the + # normalization has already happened. + size = get_data_parallel_size() if not reduce_over_ep_group else 1.0 for grad in gradients: - grad.data /= get_data_parallel_size() + grad.data /= size grad_bytes = grad.numel() * grad.element_size() # Gradient is larger than bucket_cap, don't bucketize diff --git a/src/neuronx_distributed/parallel_layers/layers.py b/src/neuronx_distributed/parallel_layers/layers.py index ed7c40b..fd33994 100644 --- a/src/neuronx_distributed/parallel_layers/layers.py +++ b/src/neuronx_distributed/parallel_layers/layers.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Any, Callable, List, Dict import torch import torch.nn.functional as F @@ -12,7 +12,7 @@ from .mappings import ( copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region, - gather_from_tensor_model_parallel_region_second_dim, + gather_from_tensor_model_parallel_region_with_dim, reduce_from_tensor_model_parallel_region, reduce_scatter_to_sequence_parallel_region, scatter_input_channels_to_tensor_model_parallel_region, @@ -29,10 +29,10 @@ cast_if_autocast_enabled, divide, get_padding_length, - param_is_not_tensor_parallel_duplicate, set_tensor_model_parallel_attributes, verify_casted_dtype, ) +from .utils import param_is_not_tensor_parallel_duplicate # noqa: F401 # pylint: disable=W0611 if "reduce_scatter_tensor" not in dir(torch.distributed): torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base @@ -40,7 +40,7 @@ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base -def _initialize_affine_weight_neuron(weight, init_method, partition_dim, stride=1): +def _initialize_affine_weight_neuron(weight: torch.Tensor, init_method: Callable[[torch.Tensor], None], 
partition_dim: int, stride: int = 1) -> None: """Initialize affine weight for model parallel on Neuron device. Args: @@ -55,7 +55,7 @@ def _initialize_affine_weight_neuron(weight, init_method, partition_dim, stride= init_method(weight) -def create_local_weight(full_weight, partition_dim, per_partition_size, stride, out_weight=None): +def create_local_weight(full_weight: torch.Tensor, partition_dim: int, per_partition_size: Union[int, List[int]], stride: int, out_weight: Optional[torch.Tensor] = None) -> torch.Tensor: per_partition_per_stride_size = divide(per_partition_size, stride) weight_list = torch.split(full_weight, per_partition_per_stride_size, dim=partition_dim) rank = get_tensor_model_parallel_rank() @@ -69,14 +69,14 @@ def create_local_weight(full_weight, partition_dim, per_partition_size, stride, # Initialize a parameter with a given init_method on CPU # Optionally return the un-partitioned parameter def _initialize_parameter_cpu( - param, # shape should already be partitioned - partition_dim, - init_method, + param: torch.Tensor, # shape should already be partitioned + partition_dim: int, + init_method: Callable[[torch.Tensor], None], return_master_param=False, *, param_dtype=torch.float32, - stride=1, -): + stride: int = 1, +) -> Optional[torch.Tensor]: """Initialize a parameter for tensor parallelism Build a master copy of the parameter on all processes and scatter to the @@ -115,13 +115,13 @@ def __init__( self, num_embeddings: int, embedding_dim: int, - padding_idx: int = None, - max_norm: float = None, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq: bool = False, sparse: bool = False, - init_method: torch.nn.init = init.normal_, - device: torch.device = None, + init_method: Callable[..., torch.Tensor] = init.normal_, + device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, shard_across_embedding: bool = False, pad: bool = False, @@ -204,7 +204,7 @@ def __init__( ) _initialize_affine_weight_neuron(self.weight, init_method, partition_dim=self.weight_partition_dim) - def init_weight_cpu(self): + def init_weight_cpu(self) -> None: _initialize_parameter_cpu( param=self.weight, partition_dim=self.weight_partition_dim, @@ -212,7 +212,7 @@ def init_weight_cpu(self): param_dtype=self.dtype, ) - def _forward_shard_across_vocab(self, input_): + def _forward_shard_across_vocab(self, input_: torch.Tensor) -> Any: if self.tensor_model_parallel_size > 1: input_mask = (input_ >= self.start_index) & (input_ < self.end_index) # Mask the input. @@ -237,7 +237,7 @@ def _forward_shard_across_vocab(self, input_): return reduce_from_tensor_model_parallel_region(output_parallel) - def _forward_shard_across_embed(self, input_): + def _forward_shard_across_embed(self, input_: torch.Tensor) -> torch.Tensor: output_parallel = F.embedding( input_.long(), self.weight, @@ -249,7 +249,7 @@ def _forward_shard_across_embed(self, input_): ) return gather_from_tensor_model_parallel_region(output_parallel) - def forward(self, input_): + def forward(self, input_: torch.Tensor) -> torch.Tensor: if self.pad and self.training: raise RuntimeError("`pad=True` is only supported for inference. 
Set model.eval()") @@ -262,7 +262,7 @@ def forward(self, input_): output = self._forward_shard_across_vocab(input_) return output - def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: + def preshard_hook(self, model_state_dict: Dict[str, Any], prefix: str) -> bool: if not self.pad or self.pad_size == 0: return @@ -297,7 +297,7 @@ def forward( async_grad_allreduce: bool, sequence_parallel_enabled: bool, save_for_backward: bool = True, - ): + ) -> torch.Tensor: ctx.use_bias = bias is not None and weight.requires_grad ctx.async_grad_allreduce = async_grad_allreduce ctx.sequence_parallel_enabled = sequence_parallel_enabled @@ -325,7 +325,7 @@ def forward( return output @staticmethod - def backward(ctx, grad_output): + def backward(ctx: Any, grad_output: Any) -> Any: if ctx.compute_weight_gradient: input, weight = ctx.saved_tensors else: @@ -445,13 +445,13 @@ class BaseParallelLinear(torch.nn.Module): def __init__(self): super().__init__() - def _init_weight(self, weight): + def _init_weight(self, weight: torch.Tensor) -> None: if self.arg_init_method is None: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) else: self.arg_init_method(weight) - def _init_bias(self): + def _init_bias(self) -> None: fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 torch.nn.init.uniform_(self.bias, -bound, bound) @@ -485,9 +485,9 @@ def __init__( bias: bool = True, gather_output: bool = True, dtype: torch.dtype = torch.float32, - device: torch.device = None, + device: Optional[torch.device] = None, stride: int = 1, - init_method: torch.nn.init = None, + init_method: Optional[Callable[..., torch.Tensor]] = None, sequence_parallel_enabled: bool = False, keep_master_weight: bool = False, skip_bias_add: bool = False, @@ -530,7 +530,7 @@ def __init__( self._forward_impl = linear_with_async_allreduce - def set_weight_and_bias_config(self): + def set_weight_and_bias_config(self) -> None: # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. 
self.weight_shape = (self.output_size_per_partition, self.input_size) @@ -575,7 +575,7 @@ def initialize_weight_and_bias(self): else: self.register_parameter("bias", None) - def init_weight_cpu(self): + def init_weight_cpu(self) -> None: self.master_weight = _initialize_parameter_cpu( param=self.weight, partition_dim=self.weight_partition_dim, @@ -624,7 +624,7 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Ten output = (output + self.bias) if self.bias is not None else output return output - def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool: + def preshard_hook(self, model_state_dict: Dict[str, Any], prefix: str) -> bool: if not self.pad or self.pad_size == 0: return if self.output_size != model_state_dict[prefix].shape[0] + self.pad_size: @@ -669,9 +669,9 @@ def __init__( bias: bool = True, input_is_parallel: bool = False, dtype: torch.dtype = torch.float32, - device: torch.device = None, + device: Optional[torch.device] = None, stride: int = 1, - init_method: torch.nn.init = None, + init_method: Optional[Callable[..., Any]] = None, sequence_parallel_enabled: bool = False, keep_master_weight: bool = False, skip_bias_add: bool = False, @@ -705,7 +705,7 @@ def __init__( self._forward_impl = linear_with_async_allreduce - def set_weight_and_bias_config(self): + def set_weight_and_bias_config(self) -> None: # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. self.weight_shape = (self.output_size, self.input_size_per_partition) @@ -716,7 +716,7 @@ def set_weight_and_bias_config(self): else: self.bias_shape = None - def initialize_weight_and_bias(self): + def initialize_weight_and_bias(self) -> None: self.set_weight_and_bias_config() init_device = self.device @@ -747,7 +747,7 @@ def initialize_weight_and_bias(self): else: self.register_parameter("bias", None) - def init_weight_cpu(self): + def init_weight_cpu(self) -> None: self.master_weight = _initialize_parameter_cpu( param=self.weight, partition_dim=self.weight_partition_dim, @@ -757,7 +757,7 @@ def init_weight_cpu(self): return_master_param=self.keep_master_weight, ) - def _init_bias(self): + def _init_bias(self) -> None: bound = 1 / math.sqrt(self.input_size_per_partition) if self.input_size_per_partition > 0 else 0 torch.nn.init.uniform_(self.bias, -bound, bound) @@ -838,7 +838,7 @@ def forward( return output @staticmethod - def backward(ctx, grad_output): + def backward(ctx: Any, grad_output: Any) -> Any: # Adapted from https://stackoverflow.com/questions/74949892/implementing-a-conv2d-backward-in-pytorch input, weight = ctx.saved_tensors @@ -916,7 +916,7 @@ def __init__( partition_dim: int, dtype: torch.dtype, device: torch.device, - init_method: torch.nn.init, + init_method: Callable[..., torch.Tensor], keep_master_params: bool, ): if not all(d == 1 for d in dilation): @@ -1022,10 +1022,10 @@ def _init_bias(self, bias): # Convolutions can take an int or tuple for most of their __init__ args # This function broadcasts the given arg to a tuple if it's not a tuple already -def _convert_conv_arg_to_tuple_if_needed(arg: Union[int, Tuple[int, ...]], dimensions: int): - if type(arg) == tuple: +def _convert_conv_arg_to_tuple_if_needed(arg: Union[int, Tuple[int, ...]], dimensions: int) -> Tuple[int, ...]: + if isinstance(arg, tuple): return arg - if type(arg) == int: + if isinstance(arg, int): return tuple(arg for _ in range(dimensions)) raise TypeError(f"Arg should be int or tuple of int, but received {type(arg)}") @@ -1067,8 +1067,8 @@ def 
__init__( padding_mode: str = "zeros", gather_output: bool = True, dtype: torch.dtype = torch.float32, - device: torch.device = None, - init_method: torch.nn.init = None, + device: Optional[torch.device] = None, + init_method: Optional[Callable[..., Any]] = None, keep_master_weight: bool = False, ): # Base class expects these all to be tuples so it can support N-dimensional convs @@ -1125,7 +1125,7 @@ def forward(self, in_tensor: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch # This way, each worker only has to do 1/world_size of the bias add if self.gather_output: # All-gather across the partitions - output = gather_from_tensor_model_parallel_region_second_dim(output_parallel) + output = gather_from_tensor_model_parallel_region_with_dim(output_parallel, gather_dim=1) else: output = output_parallel return output @@ -1168,8 +1168,8 @@ def __init__( padding_mode: str = "zeros", input_is_parallel=False, dtype: torch.dtype = torch.float32, - device: torch.device = None, - init_method: torch.nn.init = None, + device: Optional[torch.device] = None, + init_method: Optional[Callable[..., Any]] = None, keep_master_weight: bool = False, ): # Base class expects these all to be tuples so it can support N-dimensional convs diff --git a/src/neuronx_distributed/parallel_layers/mappings.py b/src/neuronx_distributed/parallel_layers/mappings.py index 3ea7f4d..6951dd4 100644 --- a/src/neuronx_distributed/parallel_layers/mappings.py +++ b/src/neuronx_distributed/parallel_layers/mappings.py @@ -6,7 +6,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_size, ) -from .utils import split_tensor_along_last_dim, split_tensor_along_second_dim +from .utils import split_tensor_along_last_dim, split_tensor_along_dim if "all_gather_into_tensor" not in dir(torch.distributed): torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base @@ -14,6 +14,24 @@ torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base +def nonzero_partition_dim_swap( + func: Callable[[Tensor, int], Tensor], +) -> Callable[[Tensor, int], Tensor]: + """ Decorator that internally swaps the partition/gather dim with 0-dimension. To the + outside the partition_dim appears to be the (arbitrary) partition dimension. Internally, + partition/split dimension is always 0 - which is achieved by pre- and post-transpose. """ + + @functools.wraps(func) + def wrapped_fn(x: Tensor, partition_dim: int) -> Tensor: + x_t = x.transpose(0, partition_dim) if partition_dim != 0 else x + y_t: Tensor = func(x_t, partition_dim=0) + y = y_t.transpose(0, partition_dim) if partition_dim != 0 else y_t + + return y + + return wrapped_fn + + def _reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" @@ -31,13 +49,25 @@ def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor: """Split the tensor along its last dimension and keep the corresponding slice.""" + return _split_along_dim(input_, len(input_.shape)-1) + +def _split_along_first_dim(input_: Tensor) -> Tensor: + """Split the tensor along its first dimension and keep the + corresponding slice.""" + + return _split_along_dim(input_, 0) + +def _split_along_dim(input_: Tensor, partition_dim: int) -> Tensor: + """Split the tensor along its first dimension and keep the + corresponding slice.""" + world_size = get_tensor_model_parallel_size() # Bypass the function if we are using only 1 device. if world_size == 1: return input_ - # Split along last dimension. 
- input_list = split_tensor_along_last_dim(input_, world_size) + # Split along partition dimension. + input_list = split_tensor_along_dim(input_, partition_dim, world_size) # Note: torch.split does not create contiguous tensors by default. rank = get_tensor_model_parallel_rank() @@ -46,27 +76,40 @@ def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor: return output -def _split_along_second_dim(input_: torch.Tensor) -> torch.Tensor: - """Split the tensor along its second dimension and keep the - corresponding slice.""" +@nonzero_partition_dim_swap +def _gather_along_dim(x: Tensor, partition_dim: int) -> Tensor: + """Given a tensor partitioned across the specified dimension, + gather and concatenate along partition dimension (using TP/SP group). + """ + tp_group = get_tensor_model_parallel_group() - world_size = get_tensor_model_parallel_size() - # Bypass the function if we are using only 1 device. - if world_size == 1: - return input_ + # bpyass the function if we only have 1 TP rank. + if tp_group.size() == 1: + return x - # Split along second dimension, numbered starting from 1 - input_list = split_tensor_along_second_dim(input_, world_size) + output = xm.all_gather( + x, + dim=partition_dim, + groups=get_tensor_model_parallel_group(as_list=True), + pin_layout=False, + ) - # Note: torch.split does not create contiguous tensors by default. - rank = get_tensor_model_parallel_rank() - output = input_list[rank].contiguous() + return output.contiguous() - return output +def _gather_along_last_dim(x: Tensor) -> Tensor: + return _gather_along_dim(x, partition_dim=len(x.shape)-1) -def _gather_along_second_dim(input_: torch.Tensor) -> torch.Tensor: - """Gather tensors and concatenate along the last dimension.""" +def _reduce_scatter_along_first_dim(x: Tensor) -> Tensor: + return _reduce_scatter_along_dim(x, 0) + +def _reduce_scatter_along_last_dim(x: Tensor) -> Tensor: + return _reduce_scatter_along_dim(x, len(x.shape)-1) + +@nonzero_partition_dim_swap +def _reduce_scatter_along_dim(x: Tensor, partition_dim: int) -> Tensor: + """Reduce-scatter the input tensor across model parallel group.""" + tp_group = get_tensor_model_parallel_group() world_size = get_tensor_model_parallel_size() # Bypass the function if we are using only 1 device. @@ -76,9 +119,16 @@ def _gather_along_second_dim(input_: torch.Tensor) -> torch.Tensor: # Size and dimension. rank = get_tensor_model_parallel_rank() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) + xm.reduce_scatter( + xm.REDUCE_SUM, + x.contiguous(), + scatter_dim=partition_dim, + shard_count=tp_group.size(), + scale=1, + output=output, + groups=get_tensor_model_parallel_group(as_list=True), + pin_layout=False, + ) # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=1).contiguous() @@ -112,7 +162,6 @@ def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor: # Size and dimension. 
last_dim = input_.dim() - 1 - rank = get_tensor_model_parallel_rank() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group()) @@ -132,30 +181,12 @@ def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor: output = xm.all_gather(input_, groups=get_tensor_model_parallel_group(as_list=True), pin_layout=False) - return output - - -def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor: - """Reduce-scatter the input tensor across model parallel group.""" - world_size = get_tensor_model_parallel_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - shape = list(input_.shape) - assert shape[0] % world_size == 0 - shape[0] //= world_size - output = torch.empty(shape, dtype=input_.dtype, device=input_.device) - groups = get_tensor_model_parallel_group(as_list=True) - - xm.reduce_scatter( - xm.REDUCE_SUM, - input_.contiguous(), - scatter_dim=0, - shard_count=len(groups[0]), - scale=1, - output=output, - groups=groups, + return xm.all_to_all( + x, + split_dimension=split_dim, + concat_dimension=concat_dim, + split_count=ep_group.size(), + groups=get_expert_model_parallel_group(as_list=True), pin_layout=False, ) @@ -234,8 +265,27 @@ def backward(ctx, grad_output): return _split_along_last_dim(grad_output) -class _GatherFromModelParallelRegionSecondDim(torch.autograd.Function): - """Gather the input from tensor model parallel region along the 2nd dim of the tensor and concatenate.""" +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Split the input into TP/SP partitions along specified sequence dimension, + only keep the corresponding chunk for the current TP rank.""" + + @staticmethod + def symbolic(graph, input_: Tensor, partition_dim: int) -> Tensor: + return _split_along_dim(input_, partition_dim=partition_dim) + + @staticmethod + def forward(ctx, input_: Tensor, partition_dim: int) -> Tensor: + ctx.partition_dim = partition_dim + return _split_along_dim(input_, partition_dim=partition_dim) + + @staticmethod + def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, None]: + return _gather_along_dim(grad_output, partition_dim=ctx.partition_dim), None + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """Gather input partitions across TP/SP group and concatenate along specified + sequence dimension.""" # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method @@ -324,6 +374,22 @@ def backward(ctx, grad_output): return _gather_along_second_dim(grad_output) +class _ScatterInputChannelsToModelParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_along_dim(input_, 1) + + @staticmethod + def forward(ctx, input_): + return _split_along_dim(input_, 1) + + @staticmethod + def backward(ctx, grad_output): + return _gather_along_dim(grad_output, 1) + + # ----------------- # Helper functions. 
# ----------------- @@ -349,10 +415,6 @@ def gather_from_tensor_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) -def gather_from_tensor_model_parallel_region_second_dim(input_): - return _GatherFromModelParallelRegionSecondDim.apply(input_) - - def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor: return _ScatterToSequenceParallelRegion.apply(input_) @@ -361,5 +423,100 @@ def gather_from_sequence_parallel_region(input_, to_model_parallel=True): return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel) -def reduce_scatter_to_sequence_parallel_region(input_): - return _ReduceScatterToSequenceParallelRegion.apply(input_) +def reduce_scatter_to_sequence_parallel_region(input_: Tensor) -> Tensor: + return _ReduceScatterToSequenceParallelRegion.apply(input_, 0) + + +def reduce_scatter_to_tensor_model_parallel_region_with_dim( + input_: Tensor, + partition_dim: int, +) -> Tensor: + """performs a reduce-scatter within TP group, with the scatter happening across + the user-specified dimension.""" + return _ReduceScatterToSequenceParallelRegion.apply(input_, partition_dim) + + +def gather_from_tensor_model_parallel_region_with_dim( + input_: Tensor, + gather_dim: int, +) -> Tensor: + """performs a all-gather within TP group, with the gather happening across + the user-specified dimension.""" + return _GatherFromSequenceParallelRegion.apply(input_, gather_dim, False) + + +def enter_expert_parallel_region(x: Tensor, scatter_gather: bool) -> Tensor: + """used to enter expert-parallel (EP) region. + + parallelism dimensions: + * before (non-expert region): [PP, DP, TP] + * after (expert region) : [PP, DPEXP, EP, TP] + * satisfy DP == DPEXP * EP + + Args: + x: (e, c, h) or (e, c/sp, h) if SP. routed activations, where index along + the e dimension determines which expert the activation needs to go to. \ + contains a subset of tokens to be handled by each expert. + scatter_gather: whether to apply scatter-gather optimization to reduce + communication volume. currently this should be set to True when sequence + length is divisible by tp degree. + + Returns: + x: (e/ep, ep, c, h): contains a subset of tokens (partitioned w/ DP_EXP) \ + only for the experts that are associated with this EP rank. + + """ + e, c, h = x.shape + + # add dimension to make it easier to track tokens + x = x.view(e, 1, c, h) + + # DROP DUPLICATE_TOKENS: (e, 1, c, h) -> (e, 1, c/sp, h) + if scatter_gather: + x = _ScatterToSequenceParallelRegion.apply(x, 2) + + # SWAP PARTITION DIMENSION, ENTER EP: (e, 1, c/sp, h) -> (e/ep, ep, c/sp, h) + x = _AllToAllInExpertParallelRegion.apply(x, 0, 1) + + # REGATHER DUPLICATE TOKENS: (e/ep, ep, c/sp, h) -> (e/ep, ep, c, h) + if scatter_gather: + x = _GatherFromSequenceParallelRegion.apply(x, 2, False) + + return x + + +def exit_expert_parallel_region(x: Tensor, scatter_gather: bool) -> Tensor: + """used to exit expert-parallel (EP) region. + + parallelism dimensions: + * before (expert region) : [PP, DPEXP, EP, TP] + * after (non-expert region): [PP, DP, TP] + * and satisfy DP == DPEXP * EP + + Args: + x: (e/ep, ep, c, h): contains a subset of tokens \ + that are assigned to the subset of experts that are associated with \ + this EP rank. + scatter_gather: whether to apply scatter-gather optimization to reduce + communication volume. 
currently this should be set to True when sequence + + Returns: + x: (e, c, h) + """ + e, p, c, h = x.shape + + # DROP DUPLICATE_TOKENS: (e/ep, ep, c, h) -> (e/ep, ep, c/sp, h) + if scatter_gather: + x = _ScatterToSequenceParallelRegion.apply(x, 2) + + # SWAP PARTITION DIMENSION, EXIT EP: (e/ep, ep, c/sp, h) -> (e, 1, c/sp, h) + x = _AllToAllInExpertParallelRegion.apply(x, 1, 0) + + # REGATHER DUPLICATE TOKENS: (e, 1, c/sp, h) -> (e, 1, c, h) + if scatter_gather: + x = _GatherFromSequenceParallelRegion.apply(x, 2, False) + + # drop the extra dimension: (e, c, h) + x = x.squeeze(1) + + return x diff --git a/src/neuronx_distributed/parallel_layers/parallel_state.py b/src/neuronx_distributed/parallel_layers/parallel_state.py index 1a2a838..1707938 100644 --- a/src/neuronx_distributed/parallel_layers/parallel_state.py +++ b/src/neuronx_distributed/parallel_layers/parallel_state.py @@ -1,6 +1,14 @@ import os +import itertools +from typing import Any, List, Optional, TYPE_CHECKING import torch +import torch.distributed +from torch.distributed import ProcessGroup +from ..utils.logger import get_logger + +if TYPE_CHECKING: + from torch._C._distributed_c10d import Store try: # Method exists at least from PT 1.13-2.1 @@ -10,34 +18,43 @@ except ImportError: TCP_STORE_AVAILABLE = False -from ..utils.logger import get_logger - logger = get_logger() # Intra-layer model parallel group that the current rank belongs to. -_TENSOR_MODEL_PARALLEL_GROUP = None -_TENSOR_MODEL_PARALLEL_GROUP_SPMD = None +_TENSOR_MODEL_PARALLEL_GROUP: Optional[ProcessGroup] = None +_TENSOR_MODEL_PARALLEL_GROUP_SPMD: Optional[ProcessGroup] = None + +# Expert model parallel group that the current rank belongs to. +_EXPERT_MODEL_PARALLEL_GROUP: Optional[ProcessGroup] = None +_EXPERT_MODEL_PARALLEL_GROUP_SPMD: Optional[ProcessGroup] = None # Inter-layer model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None -_PIPELINE_GLOBAL_RANKS = None -_PIPELINE_MODEL_PARALLEL_GROUP_SPMD = None -_NEXT_RANK_GROUP_SPMD = None -_PREV_RANK_GROUP_SPMD = None -_NEXT_RANK_GROUP = None -_PREV_RANK_GROUP = None +_PIPELINE_MODEL_PARALLEL_GROUP: Optional[ProcessGroup] = None +_PIPELINE_GLOBAL_RANKS: Optional[List[int]] = None +_PIPELINE_MODEL_PARALLEL_GROUP_SPMD: Optional[ProcessGroup] = None +_NEXT_RANK_GROUP_SPMD: Optional[ProcessGroup] = None +_PREV_RANK_GROUP_SPMD: Optional[ProcessGroup] = None +_NEXT_RANK_GROUP: Optional[ProcessGroup] = None +_PREV_RANK_GROUP: Optional[ProcessGroup] = None # Data parallel group that the current rank belongs to. -_DATA_PARALLEL_GROUP = None -_DATA_PARALLEL_GROUP_SPMD = None +_DATA_PARALLEL_GROUP: Optional[ProcessGroup] = None +_DATA_PARALLEL_GROUP_SPMD: Optional[ProcessGroup] = None + +# Expert data parallel group that the current rank belongs to. +_EXP_DATA_PARALLEL_GROUP: Optional[ProcessGroup] = None +_EXP_DATA_PARALLEL_GROUP_SPMD: Optional[ProcessGroup] = None # These values enable us to change the mpu sizes on the fly. 
-_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_TENSOR_MODEL_PARALLEL_RANK = None +_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE: Optional[int] = None +_MPU_TENSOR_MODEL_PARALLEL_RANK: Optional[int] = None + +_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE: Optional[int] = None +_MPU_EXPERT_MODEL_PARALLEL_RANK: Optional[int] = None # A CPU group that contains ranks from current rank's PP group\ # Used for PP metadata transmission -PP_GROUP_PG_GLOO = None +PP_GROUP_PG_GLOO: Optional[ProcessGroup] = None def initialize_model_parallel(tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1) -> None: @@ -48,20 +65,119 @@ def initialize_model_parallel(tensor_model_parallel_size: int = 1, pipeline_mode pipeline_model_parallel_size: number of Neuron devices used to parallelize model layer. tensor_model_parallel_size: number of Neuron devices used to parallelize model tensor. - Let's say we have a total of 32 Neuron devices denoted by n0 ... n32 and we - use 8 Neuron devices to do tensor parallelism and 4 for pipeline parallelism. The present - function will create 32 data-parallel groups (meaning no data parallelism), 8 pipeline - model-parallel groups and 4 tensor model-parallel groups as: - 32 data_parallel groups: - [n0], [n1], [n2], [n3], [n4], [n5], [n6], [n7], [n8], [n9], [n10], [n11], - [n12], [n13], [n14], [n15], [n16], [n17], [n18], [n19], [n20], [n21], [n22], - [n23], [n24], [n25], [n26], [n27], [n28], [n29], [n30], [n31] - 4 tensor model-parallel groups: - [n0, n1, n2, n3, n4, n5, n6, n7], [n8, n9, n10, n11, n12, n13, n14, n15], - [n16, n17, n18, n19, n20, n21, n22, n23], [n24, n25, n26, n27, n28, n29, n30, n31] - 8 pipeline model-parallel groups: - [n0, n8, n16, n24], [n1, n9, n17, n25], [n2, n10, n18, n26], [n3, n11, n19, n27], - [n4, n12, n20, n28], [n5, n13, n21, n29], [n6, n14, n22, n30], [n7, n15, n23, n31] + expert_model_parallel_size: number of Neuron devices used to parallelize MoE experts. + + mental model: + WITHOUT EXPERT PARALLELISM (EP) + imagine an array filled with worker global ranks [1, 2, .., PP*DP*TP], + reshaped to a (contiguous, row-major) tensor of shape [PP, DP, TP] + (for now we are ignoring EP by using EP = 1, this will be discussed later). + indices along the final dimension (TP) have stride of 1 (contiguous) + NOTE: this is important because it ensures as much TP communication as + possible is intra-node, as workers in the same node have contiguous + global ranks. + indices along the 2nd to last dimension (DP) have stride of TP + indices along the 3rd to last dimension (PP) have stride of DP * TP + + WITH EXPERT PARALLELISM (EP): + the tensor from before can have two shapes + [PP, DP_exp, EP, TP] - in expert regions (MLP) + [PP, DP_nonexp , TP] - everywhere else. + since DP_exp * EP == DP_nonexp, we can view switches between expert and nonexpert + regions as a reshaping of this tensor, and regardless of which mode we're in: + * the stride of earlier dimensions (in this case only PP) remains DP_exp * EP * TP. + * the stride of later dimensions (in this case only TP) remains 1. + importantly, this means that when switching between nonexpert and expert regions, + any given worker will retain the same PP and TP ranks. + + EXAMPLE 1 (NO EP) + ---------------------------------------------------------------------------------- + Let's say: + * we have a total of 32 Neuron devices denoted by n0 ... 
n32 + * user specifies TP=8, PP=4 + From this we can derive that DP = N / (TP * PP) = 1 + + The function will create: + * 8 pipeline model-parallel groups of size PP=4. + Stride is 8, since the product of all subsequent parallelism dimensions is 8. + [ + [n00, n08, n16, n24], # (DP=0, TP=0) + [n01, n09, n17, n25], # (DP=0, TP=1) + ... + [n06, n14, n22, n30], # (DP=0, TP=6) + [n07, n15, n23, n31] # (DP=0, TP=7) + ] + * 32 data-parallel groups of size DP=1 (meaning no data parallelism). + [ + [n00], # (PP=0, TP=0) + [n01], # (PP=0, TP=1) + ... + [n30], # (PP=3, TP=6) + [n31] # (PP=3, TP=7) + ] + * 4 tensor model-parallel groups of size TP=8 + Stride is 1 since this is the final parallelism dimension. + [ + [n00, n01, n02, n03, n04, n05, n06, n07], # (PP=0, DP=0) + [n08, n09, n10, n11, n12, n13, n14, n15], # (PP=1, DP=0) + [n16, n17, n18, n19, n20, n21, n22, n23], # (PP=2, DP=0) + [n24, n25, n26, n27, n28, n29, n30, n31], # (PP=3, DP=0) + ] + + EXAMPLE 2 (WITH EP) + ---------------------------------------------------------------------------------- + Lets say: + * we have a total of 128 neuron devices denoted by n0 ... n128 + * user specifies TP=8, PP=4, EP=2 + From this we can derive that DP_nonexp = 4, and DP_exp = 2 + + The function will create: + * 32 pipeline model parallel groups of size PP=4 each. + stride is 32, because product of all subsequent parallelism dimensions is 32. + [ + [n000, n032, n064, n096], # (DP=0, TP=0) or (DP_EXP=0, EP=0, TP=0) + [n001, n033, n065, n097], # (DP=0, TP=1) or (DP_EXP=0, EP=0, TP=1) + ... + [n030, n062, n094, n126], # (DP=3, TP=6) or (DP_EXP=1, EP=1, TP=6) + [n031, n063, n095, n127] # (DP=3, TP=7) or (DP_EXP=1, EP=1, TP=7) + ] + * 32 DP_nonexp groups of size DP_nonexp=4 each. + stride is 8 (TP) + [ + [n000, n008, n016, n024], # (PP=0, TP=0) + [n001, n009, n017, n025], # (PP=0, TP=1) + ... + [n102, n110, n118, n126], # (PP=3, TP=6) + [n103, n111, n119, n127], # (PP=3, TP=7) + ] + * 64 DP_exp groups of size DP_exp=2 each. + stride is 16 (EP * TP) + [ + [n000, n016], # (PP=0, EP=0, TP=0) + [n001, n017], # (PP=0, EP=0, TP=1) + ... + [n110, n126], # (PP=3, EP=1, TP=6) + [n111, n127] # (PP=3, EP=1, TP=7) + ] + * 64 expert model parallel groups of size EP=2 each. + stride is 8 (TP) + [ + [n000, n008], # (PP=0, DP_EXP=0, TP=0) + [n001, n009], # (PP=0, DP_EXP=0, TP=1) + ... + [n118, n126], # (PP=3, DP_EXP=1, TP=6) + [n119, n127] # (PP=3, DP_EXP=1, TP=7) + ] + * 16 TP groups of size TP=8 each. + stride is 1, contiguousness prioritizes TP communication happening within + ranks on same node. + [ + [n000, n001, n002, n003, n004, n005, n006, n007], # (PP=0, DP=0) or (PP=0, DP_EXP=0, EP=0) + [n008, n009, n010, n011, n012, n013, n014, n015], # (PP=0, DP=1) or (PP=0, DP_EXP=0, EP=1) + ... + [n112, n113, n114, n115, n116, n117, n118, n119], # (PP=3, DP=2) or (PP=3, DP_EXP=1, EP=0) + [n120, n121, n122, n123, n124, n125, n126, n127] # (PP=3, DP=3) or (PP=3, DP_EXP=1, EP=1) + ] """ # Get world size and rank. Ensure some consistencies. 
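Editor's illustration (not part of this patch): the rank meshes in EXAMPLE 2 fall out of a plain arange plus reshape, which mirrors the cluster_ranks_exp / cluster_ranks_nonexp construction below.

```
import torch

# EXAMPLE 2 layout: 128 workers, PP=4, EP=2, TP=8 -> DP_exp=2, DP_nonexp=4.
pp, dp_exp, ep, tp = 4, 2, 2, 8
ranks_exp = torch.arange(pp * dp_exp * ep * tp).reshape(pp, dp_exp, ep, tp)  # [PP, DP_exp, EP, TP]
ranks_nonexp = ranks_exp.reshape(pp, dp_exp * ep, tp)                        # [PP, DP, TP]

assert ranks_exp[0, :, 0, 0].tolist() == [0, 16]          # DP_exp group (PP=0, EP=0, TP=0), stride EP*TP=16
assert ranks_exp[0, 0, :, 0].tolist() == [0, 8]           # EP group (PP=0, DP_exp=0, TP=0), stride TP=8
assert ranks_nonexp[0, 0, :].tolist() == list(range(8))   # TP group (PP=0, DP=0), stride 1
assert ranks_nonexp[:, 0, 0].tolist() == [0, 32, 64, 96]  # PP group (DP=0, TP=0), stride DP*TP=32
```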
assert torch.distributed.is_initialized() @@ -77,112 +193,166 @@ def initialize_model_parallel(tensor_model_parallel_size: int = 1, pipeline_mode ) ) data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size) - if torch.distributed.get_rank() == 0: - print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size)) - print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size)) - print("> initializing data parallel with size {}".format(data_parallel_size)) - + + if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size * expert_model_parallel_size) != 0: + raise RuntimeError( + f"invalid implied expert data parallel degree: " + f"`world_size` ({world_size}) is not divisible by " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size}) x " + f"expert_model_parallel_size ({expert_model_parallel_size})" + ) + exp_data_parallel_size: int = world_size // ( + tensor_model_parallel_size * pipeline_model_parallel_size * expert_model_parallel_size + ) + + if tensor_model_parallel_size == 4: + # On trn1, TP=4 is a special case where each TP group consists of locally connected, + # non-contiguous ranks grouped within each node to avoid cross-node TP. + # Ex: for TP=4 PP=1 on 2 trn1.32xl nodes (64 NeuronCores): + # 16 TP groups: [ [0, 8, 16, 24], [1, 9, 17, 25], [2, 10, 18, 26], ... [7, 15, 23, 31], + # [32, 40, 48, 56], [33, 41, 49, 57], [34, 42, 50, 58], ... [39, 47, 55, 63] ] + # 4 DP groups: [ [0, 1, 2, 3, 4, 5, 6, 7, 32, 33, 34, 35, 36, 37, 38, 39] + # [8, 9, 10, 11, 12, 13, 14, 15, 40, 41, 42, 43, 44, 45, 46, 47] + # [16, 17, 18, 19, 20, 21, 22, 23, 48, 49, 50, 51, 52, 53, 54, 55] + # [24, 25, 26, 27, 28, 29, 30, 31, 56, 57, 58, 59, 60, 61, 62, 63] ] + # 64 PP groups: [ [0], [1], [2] .. 
[63] ] (No pipeline parallelism) + if expert_model_parallel_size > 1: + raise NotImplementedError("TP=4 case not yet implemented for expert parallelism") + + cluster_ranks = torch.arange(0, world_size) + cluster_ranks_exp = ( + cluster_ranks.reshape([pipeline_model_parallel_size, data_parallel_size // 8, 4, 8]) + .transpose(-1, -2) + .reshape( + pipeline_model_parallel_size, data_parallel_size, expert_model_parallel_size, tensor_model_parallel_size + ) + ) + cluster_ranks_nonexp = ( + cluster_ranks.reshape([pipeline_model_parallel_size, data_parallel_size // 8, 4, 8]) + .transpose(-1, -2) + .reshape(pipeline_model_parallel_size, data_parallel_size, tensor_model_parallel_size) + ) + else: + cluster_ranks = torch.arange(0, world_size) + cluster_ranks_exp = cluster_ranks.reshape( + [ + pipeline_model_parallel_size, + exp_data_parallel_size, + expert_model_parallel_size, + tensor_model_parallel_size, # important: contiguous parallelism dimension + ] + ) + cluster_ranks_nonexp = cluster_ranks.reshape( + [ + pipeline_model_parallel_size, + data_parallel_size, + tensor_model_parallel_size, # important: contiguous parallelism dimension + ] + ) + + logger.info("> initializing tensor model parallel with size %d", tensor_model_parallel_size) + logger.info("> initializing pipeline model parallel with size %d", pipeline_model_parallel_size) + logger.info("> initializing data parallel with size %d", data_parallel_size) + logger.info("> initializing world size to %d", world_size) + if expert_model_parallel_size > 1: + logger.info("> initializing expert model parallel with size %d", expert_model_parallel_size) + logger.info("> initializing data parallel (exp) with size %d", exp_data_parallel_size) + # We create a dummy neff and execute it across all workers in the world. # This is done to initialize the collectives. Collectives initialization # requires all workers in the world to participate and this soometimes - # may not be guranteed. Hence as a workaround, we run this dummy neff, and + # may not be guranteed. Hence as a workaround, we run this dummy neff, and # get the collectives initialized. - temp = torch.rand([1], device='xla') + temp = torch.rand([1], device="xla") torch.distributed.all_reduce(temp, group=torch.distributed.group.WORLD) import torch_xla.core.xla_model as xm - xm.mark_step() - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + xm.mark_step() rank = torch.distributed.get_rank() compress_rg = int(os.getenv("NEURON_EXPERIMENTAL_COMPRESS_RG", "0")) # Build the data-parallel groups. - global _DATA_PARALLEL_GROUP - global _DATA_PARALLEL_GROUP_SPMD - assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" - all_data_parallel_group_ranks = [] - - # On trn1, TP=4 is a special case where each TP group consists of locally connected, non-contiguous - # ranks grouped within each node to avoid cross-node TP. - # Ex: for TP=4 PP=1 on 2 trn1.32xl nodes (64 NeuronCores): - # 16 TP groups: [ [0, 8, 16, 24], [1, 9, 17, 25], [2, 10, 18, 26], ... [7, 15, 23, 31], - # [32, 40, 48, 56], [33, 41, 49, 57], [34, 42, 50, 58], ... [39, 47, 55, 63] ] - # 4 DP groups: [ [0, 1, 2, 3, 4, 5, 6, 7, 32, 33, 34, 35, 36, 37, 38, 39] - # [8, 9, 10, 11, 12, 13, 14, 15, 40, 41, 42, 43, 44, 45, 46, 47] - # [16, 17, 18, 19, 20, 21, 22, 23, 48, 49, 50, 51, 52, 53, 54, 55] - # [24, 25, 26, 27, 28, 29, 30, 31, 56, 57, 58, 59, 60, 61, 62, 63] ] - # 64 PP groups: [ [0], [1], [2] .. 
[63] ] (No pipeline parallelism) - if tensor_model_parallel_size == 4: - for p in range(pipeline_model_parallel_size): - start_rank = p * num_pipeline_model_parallel_groups - end_rank = (p + 1) * num_pipeline_model_parallel_groups - for i in range(tensor_model_parallel_size): - ranks = [] - for j in range(start_rank + i * 8, end_rank, 32): - ranks += range(j, j + 8) - all_data_parallel_group_ranks.append(list(ranks)) - else: - for i in range(pipeline_model_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups - end_rank = (i + 1) * num_pipeline_model_parallel_groups - for j in range(tensor_model_parallel_size): - ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) - all_data_parallel_group_ranks.append(list(ranks)) - - _DATA_PARALLEL_GROUP_SPMD = all_data_parallel_group_ranks - for ranks in all_data_parallel_group_ranks: - pg_options = {"xla_pg_options": {"mesh": _DATA_PARALLEL_GROUP_SPMD}} - group = torch.distributed.new_group(ranks, pg_options=pg_options) - if rank in ranks: - _DATA_PARALLEL_GROUP = group + all_data_parallel_group_ranks = [ + cluster_ranks_nonexp[pp_rank, :, tp_rank].tolist() + for pp_rank, tp_rank in itertools.product( + range(pipeline_model_parallel_size), + range(tensor_model_parallel_size), + ) + ] + _build_and_assign_groups( + group_name="_DATA_PARALLEL_GROUP", + spmd_group_name="_DATA_PARALLEL_GROUP_SPMD", + mesh=all_data_parallel_group_ranks, + compress_rg=False, + ) + + # Build the expert data-parallel groups. + all_exp_data_parallel_group_ranks = [ + cluster_ranks_exp[pp_rank, :, ep_rank, tp_rank].tolist() + for pp_rank, ep_rank, tp_rank in itertools.product( + range(pipeline_model_parallel_size), + range(expert_model_parallel_size), + range(tensor_model_parallel_size), + ) + ] + _build_and_assign_groups( + group_name="_EXP_DATA_PARALLEL_GROUP", + spmd_group_name="_EXP_DATA_PARALLEL_GROUP_SPMD", + mesh=all_exp_data_parallel_group_ranks, + compress_rg=False, + ) # Build the tensor model-parallel groups. - global _TENSOR_MODEL_PARALLEL_GROUP - global _TENSOR_MODEL_PARALLEL_GROUP_SPMD - all_tensor_parallel_group_ranks = [] - assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized" - # See note re: TP=4, above - if tensor_model_parallel_size == 4: - for i in range(0, world_size, 32): - for j in range(8): - ranks = range(i + j, i + 32, 8) - all_tensor_parallel_group_ranks.append(list(ranks)) - else: - for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - all_tensor_parallel_group_ranks.append(list(ranks)) - - _TENSOR_MODEL_PARALLEL_GROUP_SPMD = all_tensor_parallel_group_ranks - if compress_rg: - # When scaling to large number of nodes, the size of the replica groups becomes huge. - # This increases the overall HLO hashing time which in turn causes framework overhead. - # This can be reduced by passing the first tp replica only. All the other ranks would - # infer their groups depending on the size of the replica group and the start and end ranks - # Note: this works only for cases where the ranks are continuous. It won't work for TP=4 case. 
- _TENSOR_MODEL_PARALLEL_GROUP_SPMD = [all_tensor_parallel_group_ranks[0]] - for ranks in all_tensor_parallel_group_ranks: - pg_options = {"xla_pg_options": {"mesh": _TENSOR_MODEL_PARALLEL_GROUP_SPMD}} - group = torch.distributed.new_group(ranks, pg_options=pg_options) - if rank in ranks: - _TENSOR_MODEL_PARALLEL_GROUP = group + all_tensor_parallel_group_ranks = [ + cluster_ranks_nonexp[pp_rank, dp_rank, :].tolist() + for pp_rank, dp_rank in itertools.product( + range(pipeline_model_parallel_size), + range(data_parallel_size), + ) + ] + _build_and_assign_groups( + group_name="_TENSOR_MODEL_PARALLEL_GROUP", + spmd_group_name="_TENSOR_MODEL_PARALLEL_GROUP_SPMD", + mesh=all_tensor_parallel_group_ranks, + compress_rg=compress_rg, + ) + + # Build the expert model-parallel groups + all_expert_parallel_group_ranks = [ + cluster_ranks_exp[pp_rank, dp_exp_rank, :, tp_rank].tolist() + for pp_rank, dp_exp_rank, tp_rank in itertools.product( + range(pipeline_model_parallel_size), + range(exp_data_parallel_size), + range(tensor_model_parallel_size), + ) + ] + _build_and_assign_groups( + group_name="_EXPERT_MODEL_PARALLEL_GROUP", + spmd_group_name="_EXPERT_MODEL_PARALLEL_GROUP_SPMD", + mesh=all_expert_parallel_group_ranks, + compress_rg=False, + ) # Build the pipeline model-parallel groups. - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - global _PIPELINE_MODEL_PARALLEL_GROUP_SPMD - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized" - all_pipeline_parallel_group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) - all_pipeline_parallel_group_ranks.append(list(ranks)) - _PIPELINE_MODEL_PARALLEL_GROUP_SPMD = all_pipeline_parallel_group_ranks + all_pipeline_parallel_group_ranks = [ + cluster_ranks_nonexp[:, dp_rank, tp_rank].tolist() + for dp_rank, tp_rank in itertools.product( + range(data_parallel_size), + range(tensor_model_parallel_size), + ) + ] + _build_and_assign_groups( + group_name="_PIPELINE_MODEL_PARALLEL_GROUP", + spmd_group_name="_PIPELINE_MODEL_PARALLEL_GROUP_SPMD", + mesh=all_pipeline_parallel_group_ranks, + compress_rg=False, + ) + for ranks in _PIPELINE_MODEL_PARALLEL_GROUP_SPMD: - pg_options = {"xla_pg_options": {"mesh": _PIPELINE_MODEL_PARALLEL_GROUP_SPMD}} if rank in ranks: - group = torch.distributed.new_group(ranks, pg_options=pg_options) - _PIPELINE_MODEL_PARALLEL_GROUP = group + global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = ranks # Only create pre/next groups if PP is enabled @@ -204,14 +374,46 @@ def initialize_model_parallel(tensor_model_parallel_size: int = 1, pipeline_mode if rank in ranks: group = torch.distributed.new_group(ranks, pg_options=pg_options) _PREV_RANK_GROUP = group - if torch.distributed.get_rank() == 0: - logger.debug(rmsg(f"_PIPELINE_MODEL_PARALLEL_GROUP_SPMD {_PIPELINE_MODEL_PARALLEL_GROUP_SPMD}")) - logger.debug(rmsg(f"_TENSOR_MODEL_PARALLEL_GROUP_SPMD {_TENSOR_MODEL_PARALLEL_GROUP_SPMD}")) - logger.debug(rmsg(f"_DATA_PARALLEL_GROUP_SPMD {_DATA_PARALLEL_GROUP_SPMD}")) + + +def _build_and_assign_groups( + group_name: str, + spmd_group_name: str, + mesh: List[List[int]], + compress_rg: bool, +) -> None: + def __set_global_var(key: str, val: Any) -> None: + if key not in globals(): + raise RuntimeError(f"expected {key} to be in globals but was undefined") + # if globals()[key] is not None: + # raise RuntimeError(f"expected {key} to be uninitialized but was set to {globals()[key]}") + + globals()[key] 
= val + + if compress_rg: + # When scaling to large number of nodes, the size of the replica groups becomes huge. + # This increases the overall HLO hashing time which in turn causes framework overhead. + # This can be reduced by passing the first tp replica only. All the other ranks would + # infer their groups depending on the size of the replica group and the start and end ranks + # Note: this works only for cases where the ranks are continuous. It won't work for TP=4 case. + mesh = [mesh[0]] + + __set_global_var(key=spmd_group_name, val=mesh) + for group_ranks in mesh: + group = torch.distributed.new_group( + group_ranks, + pg_options={"xla_pg_options": {"mesh": mesh}}, + ) + if torch.distributed.get_rank() in group_ranks: + __set_global_var(key=group_name, val=group) + + if globals()[group_name] is None: + raise RuntimeError(f"expected {group_name} to be initialized but was not. mesh: {mesh}") + try_set_nki_parallel_state() -def try_set_nki_parallel_state(): +def try_set_nki_parallel_state() -> None: """ Inject parallel state information into NkiKernel, if compatible torch_neuronx exists. """ @@ -223,31 +425,47 @@ def try_set_nki_parallel_state(): rank=get_tensor_model_parallel_rank(), world_size=get_tensor_model_parallel_size(), ) - logger.debug(rmsg(f"Successfully initialized NKI parallel state.")) + logger.debug(rmsg("Successfully initialized NKI parallel state.")) except Exception as e: - logger.warning(rmsg(f"Failed to initialize NKI parallel state with exception {e}." - "Proceeding without distributed NKI support.")) + logger.warning( + rmsg( + f"Failed to initialize NKI parallel state with exception {e}." + "Proceeding without distributed NKI support." + ) + ) -def model_parallel_is_initialized(): + +def model_parallel_is_initialized() -> bool: """Check if model and data parallel groups are initialized.""" if _TENSOR_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: return False return True -def get_tensor_model_parallel_group(as_list=False): +def get_tensor_model_parallel_group(as_list: bool = False) -> ProcessGroup: """Get the tensor model parallel group the caller rank belongs to.""" assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized" return _TENSOR_MODEL_PARALLEL_GROUP._mesh if as_list else _TENSOR_MODEL_PARALLEL_GROUP -def get_data_parallel_group(as_list=False): +def get_expert_model_parallel_group(as_list: bool = False) -> ProcessGroup: + assert _EXPERT_MODEL_PARALLEL_GROUP is not None, "expert model parallel group is not initialized" + return _EXPERT_MODEL_PARALLEL_GROUP._mesh if as_list else _EXPERT_MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(as_list: bool = False) -> ProcessGroup: """Get the data parallel group the caller rank belongs to.""" assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" return _DATA_PARALLEL_GROUP._mesh if as_list else _DATA_PARALLEL_GROUP -def get_tensor_model_parallel_size(): +def get_expert_data_parallel_group(as_list: bool = False) -> ProcessGroup: + """Get the expert data parallel group the caller rank belongs to.""" + assert _EXP_DATA_PARALLEL_GROUP is not None, "expert data parallel group is not initialized" + return _EXP_DATA_PARALLEL_GROUP._mesh if as_list else _EXP_DATA_PARALLEL_GROUP + + +def get_tensor_model_parallel_size() -> int: """Return world size for the tensor model parallel group.""" global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: @@ -255,19 +473,19 @@ def 
get_tensor_model_parallel_size(): return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) -def set_tensor_model_parallel_size(world_size): +def set_tensor_model_parallel_size(world_size: int) -> None: """Set the tensor model parallel size""" global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size -def set_tensor_model_parallel_rank(rank): +def set_tensor_model_parallel_rank(rank: int) -> None: """Set tensor model parallel rank.""" global _MPU_TENSOR_MODEL_PARALLEL_RANK _MPU_TENSOR_MODEL_PARALLEL_RANK = rank -def get_tensor_model_parallel_rank(): +def get_tensor_model_parallel_rank() -> int: """Return my rank for the tensor model parallel group.""" global _MPU_TENSOR_MODEL_PARALLEL_RANK if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: @@ -275,7 +493,7 @@ def get_tensor_model_parallel_rank(): return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) -def get_tensor_model_parallel_src_rank(): +def get_tensor_model_parallel_src_rank() -> int: """Calculate the global rank corresponding to the first local rank in the tensor model parallel group.""" global_rank = torch.distributed.get_rank() @@ -283,7 +501,35 @@ def get_tensor_model_parallel_src_rank(): return (global_rank // local_world_size) * local_world_size -def get_data_parallel_src_rank(): +def set_expert_model_parallel_size(world_size: int) -> None: + """Set the expert model parallel size.""" + global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE + _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_expert_model_parallel_size() -> int: + """Return world size for the expert model parallel group.""" + global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE + if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE + return torch.distributed.get_world_size(group=get_expert_model_parallel_group()) + + +def set_expert_model_parallel_rank(rank: int) -> None: + """Set the expert model parallel rank.""" + global _MPU_EXPERT_MODEL_PARALLEL_RANK + _MPU_EXPERT_MODEL_PARALLEL_RANK = rank + + +def get_expert_model_parallel_rank() -> int: + """Return my rank for the expert model parallel group.""" + global _MPU_EXPERT_MODEL_PARALLEL_RANK + if _MPU_EXPERT_MODEL_PARALLEL_RANK is not None: + return _MPU_EXPERT_MODEL_PARALLEL_RANK + return torch.distributed.get_rank(group=get_expert_model_parallel_group()) + + +def get_data_parallel_src_rank() -> int: """Calculate the global rank corresponding to the first local rank in the data parallel group.""" global_rank = torch.distributed.get_rank() data_parallel_size: int = get_data_parallel_size() @@ -291,28 +537,38 @@ def get_data_parallel_src_rank(): return global_rank % num_data_parallel_groups -def get_data_parallel_size(): +def get_data_parallel_size() -> int: """Return world size for the data parallel group.""" return torch.distributed.get_world_size(group=get_data_parallel_group()) -def get_data_parallel_rank(): +def get_data_parallel_rank() -> int: """Return my rank for the data parallel group.""" return torch.distributed.get_rank(group=get_data_parallel_group()) -def get_pipeline_model_parallel_group(as_list=False): +def get_expert_data_parallel_size() -> int: + """Return world size for the expert data parallel group.""" + return torch.distributed.get_world_size(group=get_expert_data_parallel_group()) + + +def get_expert_data_parallel_rank() -> int: + """Return my rank for the expert data parallel group.""" + return torch.distributed.get_rank(group=get_expert_data_parallel_group()) + + 
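With the expert-parallel getters above in place, the invariant stated in the docstring (DP == DP_exp * EP) and the row-major rank layout can be checked at runtime. A hedged sketch (editor's addition; assumes the default contiguous layout, not the TP=4 special case):

```
import torch.distributed
from neuronx_distributed.parallel_layers import parallel_state as ps

def check_expert_parallel_invariants() -> None:
    # The non-expert data-parallel degree must factor into DP_exp * EP.
    assert ps.get_data_parallel_size() == (
        ps.get_expert_data_parallel_size() * ps.get_expert_model_parallel_size()
    )
    # The global rank decomposes as [PP, DP_exp, EP, TP] in row-major order.
    tp = ps.get_tensor_model_parallel_size()
    ep = ps.get_expert_model_parallel_size()
    dp = ps.get_data_parallel_size()
    expected = (
        ps.get_pipeline_model_parallel_rank() * dp * tp
        + ps.get_expert_data_parallel_rank() * ep * tp
        + ps.get_expert_model_parallel_rank() * tp
        + ps.get_tensor_model_parallel_rank()
    )
    assert expected == torch.distributed.get_rank()
```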
+def get_pipeline_model_parallel_group(as_list: bool = False) -> ProcessGroup: """Get the pipeline model parallel group the caller rank belongs to.""" assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized" return _PIPELINE_MODEL_PARALLEL_GROUP._mesh if as_list else _PIPELINE_MODEL_PARALLEL_GROUP -def get_pipeline_model_parallel_rank(): +def get_pipeline_model_parallel_rank() -> int: """Return my rank for the pipeline model parallel group.""" return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) -def get_pipeline_model_parallel_sr_group(parity: bool): +def get_pipeline_model_parallel_sr_group(parity: bool) -> ProcessGroup: assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" world_size = get_pipeline_model_parallel_size() @@ -326,38 +582,38 @@ def subgroup(r, ranks): return group -def get_pipeline_model_parallel_size(): +def get_pipeline_model_parallel_size() -> int: """Return world size for the pipeline model parallel group.""" return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) -def get_next_rank_group(as_list=False): +def get_next_rank_group(as_list: bool = False) -> ProcessGroup: """Get the next tensor model parallel group the caller rank belongs to.""" assert _NEXT_RANK_GROUP is not None, "intra_layer_model parallel group is not initialized" return _NEXT_RANK_GROUP._mesh if as_list else _NEXT_RANK_GROUP -def get_prev_rank_group(as_list=False): +def get_prev_rank_group(as_list: bool = False) -> ProcessGroup: """Get the previous tensor model parallel group the caller rank belongs to.""" assert _PREV_RANK_GROUP is not None, "intra_layer_model parallel group is not initialized" return _PREV_RANK_GROUP._mesh if as_list else _PREV_RANK_GROUP -def get_pipeline_model_parallel_next_rank(): +def get_pipeline_model_parallel_next_rank() -> int: assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] -def get_pipeline_model_parallel_prev_rank(): +def get_pipeline_model_parallel_prev_rank() -> int: assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized" rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_size() return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] -def destroy_model_parallel(): +def destroy_model_parallel() -> None: """Set the groups to none.""" global _TENSOR_MODEL_PARALLEL_GROUP _TENSOR_MODEL_PARALLEL_GROUP = None @@ -379,22 +635,22 @@ def destroy_model_parallel(): _PREV_RANK_GROUP_SPMD = None -def is_tcp_store_available(): +def is_tcp_store_available() -> bool: return TCP_STORE_AVAILABLE -def get_tcp_store(): +def get_tcp_store() -> "Store": """ Getting the default tcp_store from the global group initialization """ - assert is_tcp_store_available(), f"Can not import _get_default_store from distributed_c10d" + assert is_tcp_store_available(), "Can not import _get_default_store from distributed_c10d" return _get_default_store() -def initialize_pp_gloo_groups(): +def initialize_pp_gloo_groups() -> None: global PP_GROUP_PG_GLOO assert PP_GROUP_PG_GLOO is None, "pp gloo groups are already initialized!" 
- logger.info(f"initialize_pp_gloo_groups...") + logger.info("initialize_pp_gloo_groups...") pp_group_spmd = get_pipeline_model_parallel_group(as_list=True) rank = torch.distributed.get_rank() for pp_group in pp_group_spmd: @@ -403,19 +659,21 @@ def initialize_pp_gloo_groups(): PP_GROUP_PG_GLOO = pg -def get_pp_gloo_group(): +def get_pp_gloo_group() -> ProcessGroup: global PP_GROUP_PG_GLOO assert PP_GROUP_PG_GLOO is not None, "pp gloo groups are not initialized!" return PP_GROUP_PG_GLOO -def is_global_rank_zero(): + +def is_global_rank_zero() -> bool: # TODO: Change this to torch.distributed.get_rank when PTL fix of init_process # before nxd_config is added. import torch_xla.core.xla_model as xm + return xm.get_ordinal() == 0 -def create_pg_with_ranks(ranks): +def create_pg_with_ranks(ranks: List[int]) -> ProcessGroup: """ Create a SPMD process group based on input pp ranks. This can be used to create process group to average grads for shared weights betweenn PP ranks @@ -450,7 +708,7 @@ def create_pg_with_ranks(ranks): return group -def gather_python_object(obj, group): +def gather_python_object(obj: Any, group: ProcessGroup) -> List[Any]: """ Eagerly gather python object for a group Usually used to collect timeline events @@ -462,7 +720,7 @@ def gather_python_object(obj, group): return object_gather_list -def rmsg(msg): +def rmsg(msg: str) -> str: """ Return a message with parallel ranking information """ @@ -481,3 +739,11 @@ def rmsg(msg): global_rank = xm.get_ordinal() return f"[rank_{global_rank}_pp{pp_rank}_tp{tp_rank}_dp{dp_rank}] {msg}" + + +def rmsg_ep(msg: str) -> str: + pp_rank = get_pipeline_model_parallel_rank() + ep_rank = get_expert_model_parallel_rank() + tp_rank = get_tensor_model_parallel_rank() + dp_rank = get_data_parallel_rank() + return f"[pp{pp_rank}|ep{ep_rank}|tp{tp_rank}|dp{dp_rank}] {msg}" diff --git a/src/neuronx_distributed/parallel_layers/utils.py b/src/neuronx_distributed/parallel_layers/utils.py index 901cb89..e667cd1 100644 --- a/src/neuronx_distributed/parallel_layers/utils.py +++ b/src/neuronx_distributed/parallel_layers/utils.py @@ -1,6 +1,6 @@ import collections import os -from typing import List, Sequence +from typing import List, Sequence, Any import numpy as np import torch @@ -75,19 +75,44 @@ def maybe_copy(attribute): maybe_copy(attribute) -def ensure_divisibility(numerator, denominator): +def ensure_divisibility(numerator: int, denominator: int) -> int: """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) -def divide(numerator, denominator): +def divide(numerator: int, denominator: int) -> int: """Ensure that numerator is divisible by the denominator and return the division value.""" ensure_divisibility(numerator, denominator) return numerator // denominator -def get_padding_length(numerator, denominator): +def split_tensor_along_dim( + tensor: torch.Tensor, + dim: int, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along the dimension 'dim'. + Arguments: + tensor: input tensor. + dim: the dimension to split over. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + dim_size = divide(tensor.size()[dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, dim_size, dim=dim) + # Note: torch.split does not create contiguous tensors by default. 
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def get_padding_length(numerator: int, denominator: int) -> int: """Value to pad numerator with to make it evenly divisible by the denominator""" if numerator % denominator != 0: mod = numerator % denominator @@ -108,17 +133,7 @@ def split_tensor_along_last_dim( contiguous_split_chunks: If True, make each chunk contiguous in memory. """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = divide(tensor.size()[last_dim], num_partitions) - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - + return split_tensor_along_dim(tensor, tensor.dim() - 1, num_partitions, contiguous_split_chunks) def split_tensor_along_second_dim( tensor: torch.Tensor, @@ -132,34 +147,25 @@ def split_tensor_along_second_dim( contiguous_split_chunks: If True, make each chunk contiguous in memory. """ - # Get the size and dimension. - last_dim_size = divide(tensor.size()[1], num_partitions) - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=1) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - + return split_tensor_along_dim(tensor, 1, num_partitions, contiguous_split_chunks) -def is_torch_version_greater_than_2(): +def is_torch_version_greater_than_2() -> bool: return torch.__version__.startswith("2") -def is_pjrt_device(): +def is_pjrt_device() -> bool: return os.environ.get("PJRT_DEVICE", None) == "NEURON" -def requires_init_pg_override(): +def requires_init_pg_override() -> bool: return torch.__version__.startswith("2.0") -def cast_tensor(tensor, from_dtype=torch.float32, to_dtype=torch.bfloat16): +def cast_tensor(tensor: torch.Tensor, from_dtype: torch.dtype = torch.float32, to_dtype: torch.dtype = torch.bfloat16) -> Any: return tensor.to(dtype=to_dtype) if tensor.dtype == from_dtype else tensor -def cast_all(state, from_dtype=torch.float32, to_dtype=torch.bfloat16): +def cast_all(state: Any, from_dtype: torch.dtype = torch.float32, to_dtype: torch.dtype = torch.bfloat16) -> Any: if isinstance(state, torch.Tensor): return cast_tensor(state, from_dtype=from_dtype, to_dtype=to_dtype) else: @@ -176,7 +182,7 @@ def cast_all(state, from_dtype=torch.float32, to_dtype=torch.bfloat16): # Refering to https://github.com/NVIDIA/apex/blob/master/apex/_autocast_utils.py#L22 -def cast_if_autocast_enabled(*args): +def cast_if_autocast_enabled(*args: Any) -> Any: if not torch.is_autocast_enabled(): return args else: @@ -184,7 +190,7 @@ def cast_if_autocast_enabled(*args): # Modifying from https://github.com/pytorch/pytorch/blob/main/torch/cuda/amp/autocast_mode.py#L57, removed check for cuda device -def _cast(value, dtype): +def _cast(value: Any, dtype: Any) -> Any: if isinstance(value, torch.Tensor): is_eligible = value.is_floating_point() and (value.dtype is not torch.float64) return value.to(dtype) if is_eligible else value @@ -204,7 +210,7 @@ def _cast(value, dtype): return value -def move_all_tensor_to_cpu(data, convert=True): +def move_all_tensor_to_cpu(data: Any, convert: bool = True) -> Any: def is_xla_tensor(tensor): return tensor.device.type == "xla" @@ -215,12 +221,12 @@ def convert_fn(tensors): return [tensor.to("cpu") for 
tensor in tensors] def select_fn(v): - return type(v) == torch.Tensor and is_xla_tensor(v) + return isinstance(v, torch.Tensor) and is_xla_tensor(v) return xm.ToXlaTensorArena(convert_fn, select_fn).transform(data) -def get_local_world_size(): +def get_local_world_size() -> int: if is_torch_version_greater_than_2(): # With pjrt this only works after init_process_group() import torch_xla.experimental.pjrt as pjrt @@ -240,7 +246,7 @@ def move_model_to_device(model: torch.nn.Module, device: torch.device) -> None: model_utils.move_model_to_device(model, device) -def verify_casted_dtype(value): +def verify_casted_dtype(value: Any) -> None: """Veryfy whether the input values have been casted correctly""" if not torch.is_autocast_enabled(): return diff --git a/src/neuronx_distributed/pipeline/comm.py b/src/neuronx_distributed/pipeline/comm.py index 9309a5d..c57f182 100644 --- a/src/neuronx_distributed/pipeline/comm.py +++ b/src/neuronx_distributed/pipeline/comm.py @@ -78,10 +78,8 @@ def recv_from(tensor_meta, recv_prev=True, tracing=False, all_reduce_send_recv=F # Use all_gather instead of all_reduce for send/recv rank = xm.get_ordinal() split_index = 0 - gr = [] for group in groups: if rank in group: - gr = group split_index = sorted(group).index(rank) break split_index = 1 - split_index @@ -167,7 +165,7 @@ def _recv_with_gloo_group(src_rank): data = tensor.cpu().numpy().tobytes() # get original length length = int.from_bytes(data[:4], "big") - data = data[4 : length + 4] + data = data[4: length + 4] return pickle.loads(data) @@ -189,7 +187,7 @@ def _recv_with_tcp_store(src_rank): # Need to decode since tcp_store.get returns byte types obj_str = tcp_store.get(key).decode("utf-8") success = True - except: + except Exception: count += 1 if not success: raise RuntimeError(rmsg(f"Failed to receive object with key {key}")) diff --git a/src/neuronx_distributed/pipeline/model.py b/src/neuronx_distributed/pipeline/model.py index 52f97ee..2d425bd 100644 --- a/src/neuronx_distributed/pipeline/model.py +++ b/src/neuronx_distributed/pipeline/model.py @@ -40,6 +40,7 @@ maybe_materalize_model, move_model_to_device, reinit_model, + get_delay_tracing, ) from neuronx_distributed.utils.serialization import ( SerializationManager, @@ -75,6 +76,8 @@ def __init__( return_loss_on_cpu: Optional[bool] = True, deallocate_pipeline_outputs: bool = False, auto_partition: Optional[bool] = False, + fuse_microbatches: Optional[bool] = False, + _delay_tracing: Optional[bool] = False, _turn_off_odd_even_scheduler: Optional[bool] = False, _all_reduce_send_recv: Optional[bool] = False, _fused_send_recv: Optional[bool] = False, @@ -203,7 +206,8 @@ def __init__( super(NxDPPModel, self).__init__() if not parallel_state.model_parallel_is_initialized() and not _debug_mode: raise RuntimeError( - "Model parallelism needs to be initialzed before applying NxDPPModel wrapper. Please call neuronx_distributed.parallel_layers.initialize_model_parallel(pipeline_model_parallel_size, tensor_model_parallel_size)" # noqa: E501 + "Model parallelism needs to be initialzed before applying NxDPPModel wrapper. 
Please call neuronx_distributed.parallel_layers.initialize_model_parallel(pipeline_model_parallel_size, tensor_model_parallel_size)" + # noqa: E501 ) if transformer_layer_cls is None: raise ValueError("NxDPPModel requires transformer_layer_cls as input") @@ -223,6 +227,11 @@ def __init__( self.virtual_pipeline_size = virtual_pipeline_size self.param_init_fn = param_init_fn self._all_reduce_send_recv = _all_reduce_send_recv + self.fuse_microbatches = fuse_microbatches + if self.fuse_microbatches: + if self.return_loss_on_cpu: + logger.warning("return_loss_on_cpu will be set to False when fuse_microbatches is set to True.") + self.return_loss_on_cpu = False # Non-public config self._mark_step_before_pp_runtime = _mark_step_before_pp_runtime @@ -244,7 +253,7 @@ def __init__( self._metadata_comm_type = "gloo" if _use_gloo_for_metadata_comm else "tcp" if not parallel_state.is_tcp_store_available() and self._metadata_comm_type == "tcp": - logger.warning(f"Can not get default tcp_store, fall back to use gloo for metadata communication") + logger.warning("Can not get default tcp_store, fall back to use gloo for metadata communication") self._metadata_comm_type = "gloo" # Use Odd/Even schedule when num_microbatches==pipeline_parallel_size @@ -278,6 +287,14 @@ def __init__( self.local_name_to_original_name = {} self.original_name_to_local_name = {} self.local_stage_modules = [] + self._delay_tracing = _delay_tracing + self.meta_device_parameter_map = None + self.leaf_module_cls = leaf_module_cls + self.autowrap_functions = autowrap_functions + self.autowrap_modules = autowrap_modules + self.autowrap_obj_methods = autowrap_obj_methods + self.tracer_cls = tracer_cls + self.meta_named_parameters = {} self.clear_minibatch_state() if not _debug_mode: @@ -295,16 +312,20 @@ def __init__( model_layers = self.get_model_layers(self.original_torch_module, self.transformer_layer_cls) if len(model_layers) == 0: raise ValueError(f"No modules of type {self.transformer_layer_cls} found in the model.") - if torch.distributed.get_rank() == 0: - logger.info("Model transformer layers are: \n{}".format(model_layers)) + logger.info("Model transformer layers are: \n%s", model_layers) num_partitions = self.pipeline_parallel_size * self.virtual_pipeline_size pipeline_cuts = create_partitions(num_partitions, model_layers) - if torch.distributed.get_rank() == 0: - logger.info("Pipeline cuts are: \n{}".format(pipeline_cuts)) + logger.info("Pipeline cuts are: \n%s", pipeline_cuts) - # If pipeline cuts are set, directly run tracing and partition - if pipeline_cuts is not None and len(pipeline_cuts) > 0: + # If pipeline cuts are set and input names are provided, directly run tracing and partition + if pipeline_cuts is not None and len(pipeline_cuts) > 0 and not self._delay_tracing: + if input_names is None: + raise ValueError( + "Input names are required for tracing when the NxD optimizer and model wrapper are not used" + ) self.trace( + args=None, + kwargs=None, input_names=input_names, leaf_modules=leaf_module_cls, autowrap_functions=autowrap_functions, @@ -371,6 +392,8 @@ def _empty_send_tensor_buffer(self): def trace( self, + args: Optional[List[str]] = None, + kwargs: Optional[Dict[Any, Any]] = None, input_names: Optional[List[str]] = None, leaf_modules: Optional[List[Any]] = None, autowrap_functions: Optional[List[Callable]] = None, @@ -406,6 +429,8 @@ def trace( leaf_modules.append(self.transformer_layer_cls.__name__) self.traced_model = trace_model( self.original_torch_module, + args=args, + kwargs=kwargs, 
input_names=input_names, leaf_modules=leaf_modules, autowrap_functions=autowrap_functions, @@ -432,7 +457,8 @@ def partition(self): if (self.pipeline_parallel_size * self.virtual_pipeline_size) != num_stages: raise ValueError( - f"User cut stages {num_stages} mismatch the initialized pipeline parallel size {self.pipeline_parallel_size}" # noqa: E501 + f"User cut stages {num_stages} mismatch the initialized pipeline parallel size {self.pipeline_parallel_size}" + # noqa: E501 ) assert ( self.traced_model is not None @@ -472,6 +498,18 @@ def partition(self): self._post_partition(qualname_map) + # Execute steps to move model to device in case of delayed tracing flow + if get_delay_tracing(self): + self._create_meta_parameter_map() + self.maybe_materialize_local_module() + self.move_model_to_device() + self._get_meta_device_parameter_map() + + # Execute post partition hooks + from neuronx_distributed.trainer import hooks + + hooks.execute_all_hooks(self) + def _post_partition(self, qualname_map): # Create name mapping between original parameter to partitioned parameter self._build_parameter_buffer_name_mapping(qualname_map) @@ -602,6 +640,23 @@ def _sync_shared_weights(self): torch.distributed.all_reduce(p.data, group=pg) break + def _create_meta_parameter_map(self): + for n, p in self.named_parameters(): + self.meta_named_parameters[n] = p + return self.meta_named_parameters + + def _get_meta_device_parameter_map(self): + if self.meta_device_parameter_map is not None: + return self.meta_device_parameter_map + else: + self.meta_device_parameter_map = {} + for name, param in self.named_parameters(): + if param.device.type == "xla": + self.meta_device_parameter_map[self.meta_named_parameters[name]] = param + else: + self.meta_device_parameter_map[self.meta_named_parameters[name]] = self.meta_named_parameters[name] + return self.meta_device_parameter_map + def _mark_pipeline_cuts(self, cut_point): # Internal API to mark the cut in the graph for node in self.traced_model.graph.nodes: @@ -632,7 +687,8 @@ def _verify_inputs(self, kwargs): if self.input_names is not None: if set(kwargs.keys()) != set(self.input_names): raise RuntimeError( - f"train/eval inputs ({set(kwargs.keys())}) must be same as the tracing input names {set(self.input_names)}" # noqa: E501 + f"train/eval inputs ({set(kwargs.keys())}) must be same as the tracing input names {set(self.input_names)}" + # noqa: E501 ) def _disable_grad_for_nonlocal(self): @@ -679,23 +735,50 @@ def _prepare_inputs_and_infer_shape(self, kwargs): self.should_graph_break = True self.shape_traced = True end = time.time() - logger.info(rmsg(f"Tensor shapes inference finished, total consumed time {end-start}s")) + logger.info(rmsg(f"Tensor shapes inference finished, total consumed time {end - start}s")) for i in range(self.virtual_pipeline_size): stage_id = self.get_current_stage(i) logger.debug( rmsg( - f"After tracing model chunk {i}'s stage_id_to_IO_input_names {self.stage_id_to_IO_input_names[stage_id]}" # noqa: E501 + f"After tracing model chunk {i}'s stage_id_to_IO_input_names {self.stage_id_to_IO_input_names[stage_id]}" + # noqa: E501 ) ) # Need to create input iters again since the old one is garbage collected self._create_model_inputs_iter(kwargs) - def run_train(self, **kwargs): + def perform_delayed_tracing_and_partition(self, args, kwargs): + # Perform tracing + model_layers = self.get_model_layers(self.original_torch_module, self.transformer_layer_cls) + num_partitions = self.pipeline_parallel_size * self.virtual_pipeline_size + pipeline_cuts = 
create_partitions(num_partitions, model_layers) + if pipeline_cuts is not None and len(pipeline_cuts) > 0: + self.trace( + args=args, + kwargs=kwargs, + leaf_modules=self.leaf_module_cls, + autowrap_functions=self.autowrap_functions, + autowrap_modules=self.autowrap_modules, + autowrap_obj_methods=self.autowrap_obj_methods, + tracer_cls=self.tracer_cls, + ) + for pp_cut in pipeline_cuts: + self.cut_pipeline_stage(pp_cut) + + # Perform partition + if not self.partitioned: + self.partition() + + def run_train(self, *args, **kwargs): if self._mark_step_before_pp_runtime: self._mark_step() self._autocast_enabled = torch.is_autocast_enabled() self._autocast_dtype = torch.get_autocast_gpu_dtype() + + if not self.partitioned: + self.perform_delayed_tracing_and_partition(args, kwargs) + with torch.cuda.amp.autocast(enabled=False): loss = self._run_train(**kwargs) self._autocast_enabled = False @@ -709,6 +792,10 @@ def run_eval(self, **kwargs): self._mark_step() self._autocast_enabled = torch.is_autocast_enabled() self._autocast_dtype = torch.get_autocast_gpu_dtype() + + if not self.partitioned: + self.perform_delayed_tracing_and_partition({}, kwargs) + with torch.cuda.amp.autocast(enabled=False): loss = self._run_eval(**kwargs) self._autocast_enabled = False @@ -825,7 +912,8 @@ def _handle_stage_outputs(self, outputs, stage): else: if len(outputs) != self.stage_id_to_output_count[stage]: raise RuntimeError( - f"Stage {stage} number outputs ({len(outputs)}) mismatches with compiled result ({self.stage_id_to_output_count[stage]})" # noqa: E501 + f"Stage {stage} number outputs ({len(outputs)}) mismatches with compiled result ({self.stage_id_to_output_count[stage]})" + # noqa: E501 ) return outputs @@ -897,13 +985,15 @@ def _fwd_step_task(self): if self.current_mb_stage_input[1] != self.current_mb: raise RuntimeError( rmsg( - f"Running ForwardStepTask for mb {self.current_mb} but current_mb_stage_input contains mb {self.current_mb_stage_input[1]}" # noqa: E501 + f"Running ForwardStepTask for mb {self.current_mb} but current_mb_stage_input contains mb {self.current_mb_stage_input[1]}" + # noqa: E501 ) ) if self.current_mb_stage_input[2] != self.current_model_chunk: raise RuntimeError( rmsg( - f"Running ForwardStepTask for model chunk {self.current_model_chunk} but current_mb_stage_input contains model chunk {self.current_mb_stage_input[2]}" # noqa: E501 + f"Running ForwardStepTask for model chunk {self.current_model_chunk} but current_mb_stage_input contains model chunk {self.current_mb_stage_input[2]}" + # noqa: E501 ) ) if self.current_mb_stage_output is not None: @@ -959,11 +1049,13 @@ def _fwd_step_task(self): # [TODO] Add support, requires for cross attention if pass_along_io and t.requires_grad: raise RuntimeError( - f"Does not support tensors that require grads to pass along! IO name {name} current stage {next_stage - 1}" # noqa: E501 + f"Does not support tensors that require grads to pass along! 
IO name {name} current stage {next_stage - 1}" + # noqa: E501 ) logger.debug( rmsg( - f"fwd mb {self.current_mb} model chunk {self.current_model_chunk} collect {name}'s {idx}th tensor meta {tensor_meta[idx]} for bwd" # noqa: E501 + f"fwd mb {self.current_mb} model chunk {self.current_model_chunk} collect {name}'s {idx}th tensor meta {tensor_meta[idx]} for bwd" + # noqa: E501 ) ) # Collect the outputs that require grad for bwd @@ -995,11 +1087,13 @@ def _bwd_step_task(self): raise RuntimeError("Running BackwardStepTask but current_mb_grads is None") if self.current_mb_grads[1] != self.current_mb: raise RuntimeError( - f"Running BackwardStepTask for mb {self.current_mb} but current_mb_grads contains mb {self.current_mb_grads[1]}" # noqa: E501 + f"Running BackwardStepTask for mb {self.current_mb} but current_mb_grads contains mb {self.current_mb_grads[1]}" + # noqa: E501 ) if self.current_mb_grads[2] != self.current_model_chunk: raise RuntimeError( - f"Running BackwardStepTask for model chunk {self.current_model_chunk} but current_mb_grads contains model chunk {self.current_mb_grads[2]}" # noqa: E501 + f"Running BackwardStepTask for model chunk {self.current_model_chunk} but current_mb_grads contains model chunk {self.current_mb_grads[2]}" + # noqa: E501 ) if len(self.mb_to_outputs_for_grads[self.current_model_chunk][self.current_mb]) == 0: raise RuntimeError( @@ -1037,7 +1131,8 @@ def _fwd_preprocess_task(self): if len(self.mb_to_inputs_for_grads[self.current_model_chunk][self.current_mb]) != 0: raise RuntimeError( rmsg( - "Running ForwardPreprocessTask but mb_to_inputs_for_grads already contains inputs for current mb" # noqa: E501 + "Running ForwardPreprocessTask but mb_to_inputs_for_grads already contains inputs for current mb" + # noqa: E501 ) ) @@ -1109,13 +1204,15 @@ def _fwd_postprocess_task(self): if self.current_mb_stage_output[1] != self.current_mb: raise RuntimeError( rmsg( - f"Running ForwardPostprocessTask for mb {self.current_mb} but current_mb_stage_output contains mb {self.current_mb_stage_output[1]}" # noqa: E501 + f"Running ForwardPostprocessTask for mb {self.current_mb} but current_mb_stage_output contains mb {self.current_mb_stage_output[1]}" + # noqa: E501 ) ) if self.current_mb_stage_output[2] != self.current_model_chunk: raise RuntimeError( rmsg( - f"Running ForwardPostprocessTask for model chunk {self.current_model_chunk} but current_mb_stage_output contains model chunk {self.current_mb_stage_output[2]}" # noqa: E501 + f"Running ForwardPostprocessTask for model chunk {self.current_model_chunk} but current_mb_stage_output contains model chunk {self.current_mb_stage_output[2]}" + # noqa: E501 ) ) outputs = self.current_mb_stage_output[0] @@ -1132,7 +1229,8 @@ def _fwd_postprocess_task(self): # Tensors that needs to pass along if name not in self.mb_pass_along_io[self.current_mb]: raise RuntimeError( - f"Pass along io {name} is missing, current_mb_pass_along_io {self.mb_pass_along_io[self.current_mb].keys()}" # noqa: E501 + f"Pass along io {name} is missing, current_mb_pass_along_io {self.mb_pass_along_io[self.current_mb].keys()}" + # noqa: E501 ) current_output, model_chunk = self.mb_pass_along_io[self.current_mb].pop(name) if model_chunk != self.current_model_chunk: @@ -1292,7 +1390,7 @@ def _exec_schedule(self, pipe_schedule): logger.debug(rmsg(f"Run task {task}")) self.current_mb = task.mb self.current_model_chunk = task.model_chunk - self.should_graph_break = task.graph_break + self.should_graph_break = task.graph_break and (not self.fuse_microbatches) # Equivalent 
to: self._fwd_step_task() self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(task)], self) self._exec_instr() @@ -1490,7 +1588,7 @@ def _create_pg_with_ranks(self, ranks): return pg def _get_microbatch_dataloader(self, all_batches): - return MpDeviceLoader(all_batches, xm.xla_device(), batches_per_execution=self.num_microbatches) + return MpDeviceLoader(all_batches, xm.xla_device(), batches_per_execution=self.num_microbatches + 1) def maybe_materialize_local_module(self): """ diff --git a/src/neuronx_distributed/pipeline/partition.py b/src/neuronx_distributed/pipeline/partition.py index 89b4ca2..d8e99ee 100644 --- a/src/neuronx_distributed/pipeline/partition.py +++ b/src/neuronx_distributed/pipeline/partition.py @@ -285,6 +285,8 @@ def create_partitions(pipeline_parallel_size, model_layer_names): """ num_hidden_layers = len(model_layer_names) num_layer_per_partition = num_hidden_layers // pipeline_parallel_size + assert num_layer_per_partition >= 1, f"The partition cannot be created; num_hidden_layers should be atleast equal to \ + pipeline_parallel_size, but found num_hidden_layers = {num_hidden_layers} and pipeline_parallel_size = {pipeline_parallel_size}" layers_per_partition = [num_layer_per_partition for x in range(pipeline_parallel_size)] remainder = num_hidden_layers % pipeline_parallel_size if remainder > 0: diff --git a/src/neuronx_distributed/pipeline/scheduler.py b/src/neuronx_distributed/pipeline/scheduler.py index df8f559..3ae14f4 100644 --- a/src/neuronx_distributed/pipeline/scheduler.py +++ b/src/neuronx_distributed/pipeline/scheduler.py @@ -9,7 +9,7 @@ def __init__(self, mb, model_chunk=0, graph_break=True): def __eq__(self, other) -> bool: return ( - type(self) == type(other) + type(self) is type(other) and self.mb == other.mb and self.model_chunk == other.model_chunk and self.graph_break == other.graph_break @@ -64,10 +64,10 @@ def __init__(self, graph_break=True): class ReduceGradsTask(PostProcessTask): def __repr__(self): - return f"ReduceGradsTask" + return "ReduceGradsTask" def __eq__(self, other) -> bool: - return type(self) == type(other) + return type(self) is type(other) class PipeSchedule(ABC): diff --git a/src/neuronx_distributed/pipeline/trace.py b/src/neuronx_distributed/pipeline/trace.py index 0891c23..3c62be4 100644 --- a/src/neuronx_distributed/pipeline/trace.py +++ b/src/neuronx_distributed/pipeline/trace.py @@ -74,8 +74,24 @@ def __init__(self, **config) -> None: self.name = "pytorch" -def get_concrete_args(model: nn.Module, input_names: List[str]): +def get_concrete_args( + model: nn.Module, + input_names: Optional[List[str]] = None, + args: Optional[List[Any]] = None, + kwargs: Optional[Dict[Any, Any]] = None +): sig = inspect.signature(model.forward) + if input_names is None and (kwargs is not None or args is not None): + input_names = [] + # Handle args given without keywords + for i in range(len(args)): + param_name = list(sig.parameters.keys())[i] + input_names.append(param_name) + + if kwargs is not None: + for k, v in kwargs.items(): + input_names.append(k) + # Get the names of all provided args from customer and pass those to input_names if not (set(input_names) <= set(sig.parameters.keys())): formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names) @@ -134,6 +150,8 @@ def patch_obj_method(autowrap_obj_methods): def trace_model( model: nn.Module, + args: Optional[List[Any]] = None, + kwargs: Optional[Dict[Any, Any]] = None, input_names: Optional[List[str]] = None, tracer_cls: Union[Any, str] = None, 
leaf_modules: Optional[List[Any]] = None, @@ -143,20 +161,12 @@ def trace_model( ): if _create_wrapped_func is None and autowrap_obj_methods is not None: logger.warning( - f"Can not import _create_wrapped_func from torch.fx.__symbolic_trace, autowrap_obj_method will be ignored" + "Can not import _create_wrapped_func from torch.fx.__symbolic_trace, autowrap_obj_method will be ignored" ) tracer_cls = get_tracer_class(model, tracer_cls=tracer_cls) - if input_names is None: - logger.warning(f"Getting input_names None. It is recommending to set up input names for tracing.") - if is_hf_pretrained_model(model): - input_names = model.dummy_inputs.keys() - else: - input_names = [] - - input_names = list(input_names) - concrete_args = get_concrete_args(model, input_names) + concrete_args = get_concrete_args(model, input_names, args, kwargs) # User specified leaf modules to skip if leaf_modules is None: diff --git a/src/neuronx_distributed/quantization/dequantize.py b/src/neuronx_distributed/quantization/dequantize.py index 8db0afd..bb196d2 100644 --- a/src/neuronx_distributed/quantization/dequantize.py +++ b/src/neuronx_distributed/quantization/dequantize.py @@ -1,7 +1,20 @@ import torch +def direct_cast_dequantize(tensor: torch.Tensor, upcast_dtype: torch.dtype) -> torch.Tensor: + """ + A utility function to dequantize a tensor from lower dtype to upcast dtype without any scaling factor + + Args: + tensor (torch.Tensor): tensor to be dequantized + upcast_dtype (torch.dtype): upcast dtype + + Returns: + torch.Tensor: upcasted tensor + """ + upcast_tensor = tensor.to(upcast_dtype) + return upcast_tensor -def dequantize(tensor: torch.Tensor, scale: torch.Tensor, upcast_dtype: torch.dtype) -> torch.Tensor: +def scale_dequantize(tensor: torch.Tensor, scale: torch.Tensor, upcast_dtype: torch.dtype) -> torch.Tensor: """ A utility function to dequantize a tensor from lower dtype to upcast dtype based on its corresponding scale Note: It will not convert back the tensor to its existing dtype @@ -11,9 +24,9 @@ def dequantize(tensor: torch.Tensor, scale: torch.Tensor, upcast_dtype: torch.dt scale (torch.Tensor): scale to be used for dequantization Returns: - torch.Tensor: upcasted tensor with the same dtype as the input tensor - torch.Tensor: the scale used to dequantize the input tensor + torch.Tensor: upcasted tensor multiplied with scale """ - upcast_tensor = tensor.to(upcast_dtype) + upcast_tensor = tensor.to(torch.float32) upcast_tensor *= scale + upcast_tensor = upcast_tensor.to(upcast_dtype) return upcast_tensor diff --git a/src/neuronx_distributed/quantization/observer.py b/src/neuronx_distributed/quantization/observer.py new file mode 100644 index 0000000..9ae7ae6 --- /dev/null +++ b/src/neuronx_distributed/quantization/observer.py @@ -0,0 +1,166 @@ +""" +This module implements observers which are used to collect statistics about +the values observed during calibration (PTQ) or training (QAT). +""" + +from typing import Any, Dict, List, Tuple + +import torch +from torch.ao.quantization.observer import UniformQuantizationObserverBase + + +class PerChannelAbsMaxObserver(UniformQuantizationObserverBase): + r"""Observer module for computing the quantization parameters based on the + running per channel abs max values. + + This observer uses the tensor abs max statistics to compute the per channel + quantization parameters. The module records the running abs maximum + of incoming tensors, and uses this statistic to compute the quantization + parameters. 
+ + Args: + ch_axis: Channel axis + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used + reduce_range: Reduces the range of the quantized data type by 1 bit + quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`. + + The quantization parameters are computed the same way as in + :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference + that the running abs max values are stored per channel. + Scales and zero points are thus computed per channel as well. + + """ + max_val: torch.Tensor + + def __init__( + self, + ch_axis=0, + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, + ) -> None: + + # Currently only suppport torch.quint8 + assert dtype == torch.qint8, "Only torch.qint8 is supported" + + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + ) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.ch_axis = ch_axis + self.register_buffer("max_val", torch.tensor([], **factory_kwargs)) + + def forward(self, x_orig): + return self._forward(x_orig) + + def _forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.max_val.dtype) + y = torch.flatten(y, start_dim=1) + + if max_val.numel() == 0: + max_val = torch.amax(torch.abs(y), dim=1, keepdim=True) + else: + max_val_cur = torch.amax(torch.abs(y), dim=1, keepdim=True) + max_val = torch.max(max_val_cur, max_val) + self.max_val.resize_(max_val.shape) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters, given max + value tensors + + Returns: + scales: Scales tensor of shape (#channels,) + zero_points: Zero points tensor of shape (#channels,) + """ + + quant_max = self.quant_max + # After profiling remove the assertion if taking too much time + assert torch.all(self.max_val >= 0) + max_val_pos = self.max_val # self.max_val already is the absolute maximum + + device = max_val_pos.device + scale = torch.ones(max_val_pos.size(), dtype=torch.float32, device=device) + zero_point = torch.zeros(max_val_pos.size(), dtype=torch.int64, device=device) + + if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric: + scale = max_val_pos / (float(quant_max)) + scale = torch.max(scale, self.eps) + else: + raise ValueError(f"Only support {torch.per_tensor_symmetric} and {torch.per_channel_symmetric}") + + # For scalar values, cast them to Tensors of size 1 to keep the shape + # consistent with default values in FakeQuantize. 
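# --- Illustrative aside (not part of this diff): what calculate_qparams() above produces.
# With the default torch.qint8 setup (quant_max = 127 when reduce_range=False), the scale
# per channel is the running abs max over the flattened non-channel dims divided by 127,
# and the zero points are all zero for the symmetric scheme. A small sketch, assuming the
# module above is importable once this diff is applied:
import torch
from neuronx_distributed.quantization.observer import PerChannelAbsMaxObserver

obs = PerChannelAbsMaxObserver(ch_axis=0)        # torch.qint8, per_channel_symmetric defaults
w = torch.tensor([[1.0, -2.0], [0.5, 4.0]])
obs(w)                                           # records per-row abs max: [2.0, 4.0]
scale, zero_point = obs.calculate_qparams()      # scale ~ [2/127, 4/127], zero_point = [0, 0]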
+ if len(scale.shape) == 0: + # TODO: switch to scale.item() after adding JIT support + scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) + if len(zero_point.shape) == 0: + # TODO: switch to zero_point.item() after adding JIT support + zero_point = torch.tensor([int(zero_point)], dtype=zero_point.dtype, device=device) + + return scale.squeeze(1), zero_point.squeeze(1) + + def extra_repr(self): + return f"abs_max_val={self.max_val}" + + def _load_from_state_dict( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, torch.Tensor], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ): + raise NotImplementedError() + + def _load_from_state_dict_script( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, torch.Tensor], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ): + raise NotImplementedError() + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + # This used to be torch.ones but that does not work because + # JIT compiler can optimize it via common subexpression elimination + # in which case both min_val and max_val point to the same tensor. + self.max_val = torch.rand( + 0, + ) diff --git a/src/neuronx_distributed/quantization/quantization_config.py b/src/neuronx_distributed/quantization/quantization_config.py new file mode 100644 index 0000000..67e9ba8 --- /dev/null +++ b/src/neuronx_distributed/quantization/quantization_config.py @@ -0,0 +1,56 @@ +### Create Enum to define the type of quantization possible +import enum +from enum import Enum +from typing import TypedDict + +import torch + + +class MyEnumMeta(enum.EnumMeta): + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + else: + return True + + +class QuantizationType(Enum, metaclass=MyEnumMeta): + PER_TENSOR_SYMMETRIC = "per_tensor_symmetric" + PER_CHANNEL_SYMMETRIC = "per_channel_symmetric" + + +class QuantizedDtype(Enum, metaclass=MyEnumMeta): + INT8 = torch.int8 + + +class BASE_QCONFIG_DICT_TYPE(TypedDict): + quantization_type: QuantizationType + quantized_dtype: QuantizedDtype + + +class PER_CHANNEL_QCONFIG_DICT_TYPE(BASE_QCONFIG_DICT_TYPE): + quantization_per_channel_axis: int + + +_DEFAULT_CUSTOM_QCONFIG_DICT: BASE_QCONFIG_DICT_TYPE = { + "quantization_type": QuantizationType.PER_TENSOR_SYMMETRIC, + "quantized_dtype": QuantizedDtype.INT8, +} + +_DEFAULT_PER_CHANNEL_QCONFIG_DICT: PER_CHANNEL_QCONFIG_DICT_TYPE = { + "quantization_type": QuantizationType.PER_CHANNEL_SYMMETRIC, + "quantized_dtype": QuantizedDtype.INT8, + "quantization_per_channel_axis": 0, +} + + +def get_default_custom_qconfig_dict() -> BASE_QCONFIG_DICT_TYPE: + r"""Defines the default custom config dict.""" + return dict(_DEFAULT_CUSTOM_QCONFIG_DICT) + + +def get_default_per_channel_custom_qconfig_dict() -> PER_CHANNEL_QCONFIG_DICT_TYPE: + """Defines the default custom per channel config dict""" + return dict(_DEFAULT_PER_CHANNEL_QCONFIG_DICT) diff --git a/src/neuronx_distributed/quantization/quantization_layers.py b/src/neuronx_distributed/quantization/quantization_layers.py index e09fbec..0a08816 100644 --- a/src/neuronx_distributed/quantization/quantization_layers.py +++ b/src/neuronx_distributed/quantization/quantization_layers.py @@ -12,16 +12,19 @@ We would be creating neuronx_distributed.functional API for this purpose. 
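# --- Illustrative aside (not part of this diff): the q_config dicts defined in
# quantization_config.py above are what the from_float() methods of the quantized layers
# below consume. The getter returns a copy, so callers can override fields (for example
# the per-channel axis) without mutating the module-level default, and MyEnumMeta makes
# raw strings testable for membership. A minimal sketch:
from neuronx_distributed.quantization.quantization_config import (
    QuantizationType,
    get_default_per_channel_custom_qconfig_dict,
)

q_config = get_default_per_channel_custom_qconfig_dict()
q_config["quantization_per_channel_axis"] = 0          # default is already 0: the output-channel dim of a column-parallel weight
assert "per_channel_symmetric" in QuantizationType     # string membership via MyEnumMeta
assert q_config["quantization_type"] is QuantizationType.PER_CHANNEL_SYMMETRIC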
NAPP-2202 """ -import enum + import warnings from abc import ABCMeta, abstractmethod -from enum import Enum from typing import Optional, Tuple, Union import torch from torch.nn.parameter import Parameter +from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedLinearWithAsyncCommunication, ExpertFusedLinear +) from neuronx_distributed.parallel_layers.layers import ( + LinearWithAsyncCommunication, _initialize_affine_weight_neuron, _initialize_parameter_cpu, linear_with_async_allreduce, @@ -34,41 +37,32 @@ scatter_to_tensor_model_parallel_region, ) from neuronx_distributed.parallel_layers.parallel_state import ( - get_tensor_model_parallel_size, + get_tensor_model_parallel_size, get_expert_model_parallel_size ) from neuronx_distributed.parallel_layers.utils import ( divide, set_tensor_model_parallel_attributes, ) -from neuronx_distributed.quantization.dequantize import dequantize +from neuronx_distributed.quantization.dequantize import direct_cast_dequantize, scale_dequantize +from neuronx_distributed.quantization.quantization_config import ( + _DEFAULT_CUSTOM_QCONFIG_DICT, + BASE_QCONFIG_DICT_TYPE, + PER_CHANNEL_QCONFIG_DICT_TYPE, + QuantizationType, + QuantizedDtype, +) +from neuronx_distributed.quantization.quantization_utils import extract_q_scale from neuronx_distributed.utils.logger import get_logger logger = get_logger() -### Create Enum to define the type of quantization possible -class MyEnumMeta(enum.EnumMeta): - def __contains__(cls, item): - try: - cls(item) - except ValueError: - return False - else: - return True - - -class QuantizationType(Enum, metaclass=MyEnumMeta): - SCALAR = "scalar" - - -class QuantizedDtype(Enum, metaclass=MyEnumMeta): - INT8 = torch.int8 - - class BaseQuantizeParallelLinear(torch.nn.Module, metaclass=ABCMeta): + autograd_func_class = LinearWithAsyncCommunication + def __init__( self, - quantization_type: Union[QuantizationType, str] = "scalar", + quantization_type: Union[QuantizationType, str] = "per_tensor_symmetric", dequantized_dtype: torch.dtype = torch.bfloat16, quantized_dtype: torch.dtype = torch.int8, device: torch.device = None, @@ -76,7 +70,7 @@ def __init__( """_summary_ Args: - quantization_type (Union[QuantizationType, str], optional): Quantization type. Defaults to "scalar". + quantization_type (Union[QuantizationType, torch.qscheme], optional): Quantization type. Defaults to per_tensor_symmetric. dequantized_dtype (torch.dtype, optional): Detype to dequantize the weight to. Defaults to torch.bfloat16. quantized_dtype (torch.dtype, optional): Dtype to qunatize the weight to. Defaults to torch.int8. device (torch.device, optional): Device to which initialize the Parameters. Defaults to None. @@ -94,6 +88,13 @@ def __init__( self.device = device self.register_parameter("scale", None) + self.keep_master_weight = None + + self.weight_shape = None + self.weight_partition_dim = None + self.stride = None + self.bias_shape = None + def _init_weight(self, weight: torch.Tensor): """Init the weight in Quantized Parallel layers with zeroes. 
@@ -112,14 +113,135 @@ def _init_bias(self, bias: torch.Tensor): """ torch.nn.init._no_grad_fill_(bias, 0.0) + def _setup_for_weight(self): + init_device = self.device + weight = torch.empty(*self.weight_shape, dtype=self.quantized_dtype, device=init_device) + self.weight = Parameter(weight, requires_grad=False) + self.device = self.weight.device + + if self.device.type == "cpu": + self.master_weight = _initialize_parameter_cpu( + param=self.weight, + partition_dim=self.weight_partition_dim, + init_method=self._init_weight, + param_dtype=self.quantized_dtype, + stride=self.stride, + return_master_param=self.keep_master_weight, + ) + elif self.device.type == "meta": + set_tensor_model_parallel_attributes( + tensor=self.weight, + is_parallel=True, + dim=self.weight_partition_dim, + stride=self.stride, + ) + else: + _initialize_affine_weight_neuron( + weight=self.weight, + init_method=self._init_weight, + partition_dim=self.weight_partition_dim, + stride=self.stride, + ) + + setattr(self.weight, "get_tensor_from_state_dict", self.get_weight_from_state_dict) + setattr(self.weight, "set_tensor_to_state_dict", self.set_weight_to_state_dict) + + def _base_setup_for_bias(self, bias: bool): + if bias: + if self.device is None or self.device.type == "cpu": + self.bias = Parameter(torch.empty(*self.bias_shape, dtype=self.dequantized_dtype), requires_grad=False) + else: + self.bias = Parameter( + torch.empty(*self.bias_shape, device=self.device, dtype=self.dequantized_dtype), requires_grad=False + ) + if self.bias.device != torch.device("meta"): + self._init_bias(self.bias) + + setattr(self.bias, "get_tensor_from_state_dict", self.get_bias_from_state_dict) + setattr(self.bias, "set_tensor_to_state_dict", self.set_bias_to_state_dict) + else: + self.register_parameter("bias", None) + + def _setup_for_scale( + self, + weight_shape: tuple, + quantization_type: QuantizationType, + weight_partition_dim: Optional[int] = None, + per_channel_axis: Optional[int] = None, + ): + """Setup required for scale + + Args: + weight_shape (tuple): Weight shape + quantization_type (QuantizationType): Quantization Type + weight_partition_dim (Optional[int], optional): Weight partition dimension. Defaults to None. + This is required if per channel quantization is used. + per_channel_axis (Optional[int], optional): Scale dimension. Defaults to None. + This is required if per channel quantization is used. + + Raises: + ValueError: If quantization_type is not within QuantizationType.PER_TENSOR_SYMMETRIC and QuantizationType.PER_CHANNEL_SYMMETRIC + + NOTE: Currently we are setting the attribute for tensor model parallel even for per tensor symmetric case. + This is to make it uniform. After KVCache quantization is implemented(as K and V have different quantization schemes) and if the uniformity is not required, remove it. 
+ """ + if quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: + self.scale = Parameter(torch.tensor([1.0]), requires_grad=False) + set_tensor_model_parallel_attributes(tensor=self.scale, is_parallel=False, dim=None, stride=None) + elif quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC: + assert ( + per_channel_axis is not None + ), "per_channel_axis cannot be None for per_channel_symmetric quantization" + scale_shape = [1] * len(weight_shape) + scale_shape[per_channel_axis] = weight_shape[per_channel_axis] + self.scale = Parameter(torch.ones(scale_shape, device=self.weight.device), requires_grad=False) + + # Only when the weight partition dim is the per_channel_axis, we need to partition the scale + if weight_partition_dim == per_channel_axis: + # we need to partition scale as well + set_tensor_model_parallel_attributes( + tensor=self.scale, is_parallel=True, dim=weight_partition_dim, stride=self.stride + ) + else: + set_tensor_model_parallel_attributes(tensor=self.scale, is_parallel=False, dim=None, stride=None) + else: + raise ValueError(f"scale for quantization_type: {quantization_type} not supported") + + setattr(self.scale, "get_tensor_from_state_dict", BaseQuantizeParallelLinear.get_scale_from_state_dict) + + @staticmethod + def get_weight_from_state_dict(prefix: str, state_dict: dict) -> torch.Tensor: + return QuantizedParallelLinearLayerStateDictAdaptor.get_weight_from_state_dict( + prefix=prefix, state_dict=state_dict + ) + + @staticmethod + def set_weight_to_state_dict(prefix: str, tensor: torch.Tensor, state_dict: dict) -> None: + return QuantizedParallelLinearLayerStateDictAdaptor.set_weight_to_state_dict( + prefix=prefix, tensor=tensor, state_dict=state_dict + ) + + @staticmethod + def get_bias_from_state_dict(prefix: str, state_dict: dict) -> torch.Tensor: + return QuantizedParallelLinearLayerStateDictAdaptor.get_bias_from_state_dict( + prefix=prefix, state_dict=state_dict + ) + + @staticmethod + def set_bias_to_state_dict(prefix: str, tensor: torch.Tensor, state_dict: dict) -> torch.Tensor: + return QuantizedParallelLinearLayerStateDictAdaptor.set_bias_to_state_dict( + prefix=prefix, tensor=tensor, state_dict=state_dict + ) + + @staticmethod + def get_scale_from_state_dict(prefix: str, state_dict): + return QuantizedParallelLinearLayerStateDictAdaptor.get_scale_from_state_dict( + prefix=prefix, state_dict=state_dict + ) + @classmethod @abstractmethod - def from_float( - cls, - mod, - quantization_type: Union[QuantizationType, str] = QuantizationType.SCALAR, - quantized_dtype: Union[QuantizedDtype, torch.device] = QuantizedDtype.INT8, - ): + def from_float(cls, mod, q_config: BASE_QCONFIG_DICT_TYPE = _DEFAULT_CUSTOM_QCONFIG_DICT): """Create Quantized class from non quantized version Args: @@ -127,11 +249,6 @@ def from_float( """ -class TensorParallelDim(enum.Enum): - QUANTIZED_COLUMN_PARALLEL = 0 - QUANTIZED_ROW_PARALLEL = 1 - - class QuantizedParallelLinearLayerStateDictAdaptor(object): """ A utility class that modifies the state dict to the form required by the Quantized Linear Layers @@ -214,16 +331,10 @@ def get_scale_from_state_dict(prefix: str, state_dict) -> torch.Tensor: if (prefix + "_packed_params.dtype") in state_dict and state_dict[prefix + "_packed_params._packed_params"][ 0 ].dtype == torch.qint8: - return torch.tensor([state_dict[prefix + "_packed_params._packed_params"][0].q_scale()]) + return extract_q_scale(state_dict[prefix + "_packed_params._packed_params"][0]) elif (prefix + "scale") in state_dict: scale: torch.Tensor = 
state_dict[prefix + "scale"] - # If dict already contains the scale in the form of torch tensor of dimension 1 - if scale.shape == (1,): - return scale - elif scale.shape == (): - return scale.unsqueeze(0) - else: - raise RuntimeError(f"Scale shape is not valid {(prefix + 'scale')}: {scale}") + return scale else: raise RuntimeError(f"Cannot find {(prefix + 'scale')} in state_dict") @@ -254,7 +365,7 @@ def __init__( input_size: int, output_size: int, bias: bool = True, - quantization_type: Union[QuantizationType, str] = "scalar", + quantization_type: Union[QuantizationType, str] = "per_tensor_symmetric", gather_output: bool = True, dtype: torch.dtype = torch.float32, quantized_dtype: Union[QuantizedDtype, torch.dtype] = QuantizedDtype.INT8, @@ -262,14 +373,12 @@ def __init__( stride: int = 1, sequence_parallel_enabled: bool = False, keep_master_weight: bool = False, + quantization_per_channel_axis: Optional[int] = None, ): super().__init__( quantization_type=quantization_type, dequantized_dtype=dtype, quantized_dtype=quantized_dtype, device=device ) - if self.quantization_type == QuantizationType.SCALAR: - self.scale = Parameter(torch.tensor([1.0]), requires_grad=False) - # Keep input parameters self.input_size = input_size self.output_size = output_size @@ -284,76 +393,43 @@ def __init__( # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. + ###### Weight config and bias config setup ##### + self._setup_for_weight_and_bias_config(bias=bias) ###### Weight setup ##### self._setup_for_weight() ###### Bias setup ##### self._setup_for_bias(bias=bias) + ###### Scale setup ##### + self._setup_for_scale( + weight_shape=self.weight_shape, + quantization_type=self.quantization_type, + weight_partition_dim=self.weight_partition_dim, + per_channel_axis=quantization_per_channel_axis, + ) ##### Parallelism setup ##### self._setup_for_parallelism(world_size=world_size) - ###### Quantization Scale setup ##### - setattr(self.scale, "get_tensor_from_state_dict", QuantizedColumnParallel.get_scale_from_state_dict) - set_tensor_model_parallel_attributes(tensor=self.scale, is_parallel=False, dim=None, stride=None) - self._forward_impl = linear_with_async_allreduce - def _setup_for_weight(self): - init_device = self.device - - weight = torch.empty( - self.output_size_per_partition, self.input_size, dtype=self.quantized_dtype, device=init_device + def _setup_for_weight_and_bias_config(self, bias: bool): + self.weight_shape = ( + self.output_size_per_partition, + self.input_size, ) - self.weight = Parameter(weight, requires_grad=False) + self.weight_partition_dim = 0 - self.device = self.weight.device - - if self.device.type == "cpu": - self.master_weight = _initialize_parameter_cpu( - param=self.weight, - partition_dim=TensorParallelDim.QUANTIZED_COLUMN_PARALLEL.value, - init_method=self._init_weight, - param_dtype=self.quantized_dtype, - stride=self.stride, - return_master_param=self.keep_master_weight, - ) - elif self.device.type == "meta": - set_tensor_model_parallel_attributes( - tensor=self.weight, - is_parallel=True, - dim=TensorParallelDim.QUANTIZED_COLUMN_PARALLEL.value, - stride=self.stride, - ) + if bias: + self.bias_size = self.output_size if self.gather_output else self.output_size_per_partition + self.bias_shape = (self.bias_size,) else: - _initialize_affine_weight_neuron( - weight=self.weight, - init_method=self._init_weight, - partition_dim=TensorParallelDim.QUANTIZED_COLUMN_PARALLEL.value, - stride=self.stride, - ) - - setattr(self.weight, 
"get_tensor_from_state_dict", QuantizedColumnParallel.get_weight_from_state_dict) - setattr(self.weight, "set_tensor_to_state_dict", QuantizedColumnParallel.set_weight_to_state_dict) + self.bias_shape = None def _setup_for_bias(self, bias: bool): + self._base_setup_for_bias(bias=bias) if bias: - self.bias_size = self.output_size if self.gather_output else self.output_size_per_partition - if self.device is None or self.device.type == "cpu": - self.bias = Parameter(torch.empty(self.bias_size, dtype=self.dequantized_dtype), requires_grad=False) - else: - self.bias = Parameter( - torch.empty(self.bias_size, device=self.device, dtype=self.dequantized_dtype), requires_grad=False - ) - if self.bias.device != torch.device("meta"): - self._init_bias(bias=self.bias) - if not self.gather_output: set_tensor_model_parallel_attributes(self.bias, True, 0, stride=self.stride) - setattr(self.bias, "get_tensor_from_state_dict", QuantizedColumnParallel.get_scale_from_state_dict) - setattr(self.bias, "set_tensor_to_state_dict", QuantizedColumnParallel.set_bias_to_state_dict) - else: - self.register_parameter("bias", None) - def _setup_for_parallelism(self, world_size: int): self.async_tensor_model_parallel_allreduce = not self.sequence_parallel_enabled and world_size > 1 if self.sequence_parallel_enabled: @@ -381,15 +457,17 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Ten input_parallel = copy_to_tensor_model_parallel_region(input) # Matrix multiply. - weight_for_matmul = dequantize(self.weight, self.scale, input_parallel.dtype) + weight_for_matmul = direct_cast_dequantize(tensor=self.weight, upcast_dtype=input_parallel.dtype) output_parallel = self._forward_impl( input=input_parallel, weight=weight_for_matmul, bias=None, async_grad_allreduce=self.async_tensor_model_parallel_allreduce, sequence_parallel_enabled=self.sequence_parallel_enabled, + autograd_func_class=self.autograd_func_class, save_for_backward=False, ) + output_parallel = scale_dequantize(tensor=output_parallel, scale=self.scale.T, upcast_dtype=output_parallel.dtype) if self.gather_output: # All-gather across the partitions. 
assert not self.sequence_parallel_enabled @@ -399,57 +477,29 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Ten output = (output + self.bias) if self.bias is not None else output return output - @staticmethod - def get_weight_from_state_dict(prefix: str, state_dict: dict) -> torch.Tensor: - return QuantizedParallelLinearLayerStateDictAdaptor.get_weight_from_state_dict( - prefix=prefix, state_dict=state_dict - ) - - @staticmethod - def set_weight_to_state_dict(prefix: str, tensor: torch.Tensor, state_dict: dict) -> None: - return QuantizedParallelLinearLayerStateDictAdaptor.set_weight_to_state_dict( - prefix=prefix, tensor=tensor, state_dict=state_dict - ) - - @staticmethod - def get_bias_from_state_dict(prefix: str, state_dict: dict) -> torch.Tensor: - return QuantizedParallelLinearLayerStateDictAdaptor.get_bias_from_state_dict( - prefix=prefix, state_dict=state_dict - ) - - @staticmethod - def set_bias_to_state_dict(prefix: str, tensor: torch.Tensor, state_dict: dict) -> torch.Tensor: - return QuantizedParallelLinearLayerStateDictAdaptor.set_bias_to_state_dict( - prefix=prefix, tensor=tensor, state_dict=state_dict - ) - - @staticmethod - def get_scale_from_state_dict(prefix: str, state_dict): - return QuantizedParallelLinearLayerStateDictAdaptor.get_scale_from_state_dict( - prefix=prefix, state_dict=state_dict - ) - @classmethod def from_float( - cls, - mod, - quantization_type: Union[QuantizationType, str] = QuantizationType.SCALAR, - quantized_dtype: Union[QuantizedDtype, torch.device] = QuantizedDtype.INT8, + cls, mod, q_config: Union[BASE_QCONFIG_DICT_TYPE, PER_CHANNEL_QCONFIG_DICT_TYPE] = _DEFAULT_CUSTOM_QCONFIG_DICT ): """Create a QuantizedColumnParallel from a float module.""" assert mod.__class__.__name__ == "ColumnParallelLinear", "ColumnParallelLinear expected" + if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC: + assert q_config["quantization_per_channel_axis"] is not None + else: + q_config["quantization_per_channel_axis"] = None new_mod = QuantizedColumnParallel( input_size=mod.input_size, output_size=mod.output_size, - quantization_type=quantization_type, + quantization_type=q_config["quantization_type"], bias=mod.bias is not None, - quantized_dtype=quantized_dtype, + quantized_dtype=q_config["quantized_dtype"], gather_output=mod.gather_output, dtype=mod.dtype, device=mod.weight.device, stride=mod.stride, sequence_parallel_enabled=mod.sequence_parallel_enabled, keep_master_weight=mod.keep_master_weight, + quantization_per_channel_axis=q_config["quantization_per_channel_axis"], ) return new_mod @@ -481,7 +531,7 @@ def __init__( input_size: int, output_size: int, bias: bool = True, - quantization_type: Union[QuantizationType, str] = "scalar", + quantization_type: Union[QuantizationType, str] = "per_tensor_symmetric", input_is_parallel: bool = False, dtype: torch.dtype = torch.float32, quantized_dtype: Union[QuantizedDtype, torch.dtype] = QuantizedDtype.INT8, @@ -489,12 +539,11 @@ def __init__( stride: int = 1, sequence_parallel_enabled: bool = False, keep_master_weight: bool = False, + quantization_per_channel_axis: Optional[int] = None, ): super().__init__( quantization_type=quantization_type, dequantized_dtype=dtype, quantized_dtype=quantized_dtype, device=device ) - if self.quantization_type == QuantizationType.SCALAR: - self.scale = Parameter(torch.tensor([1.0]), requires_grad=False) # Keep input parameters self.input_size = input_size @@ -512,78 +561,40 @@ def __init__( # Note: torch.nn.functional.linear performs 
XA^T + b and as a result # we allocate the transpose. + ###### Weight config and bias config setup ##### + self._setup_for_weight_and_bias_config(bias=bias) ###### Weight setup ##### self._setup_for_weight() ###### Bias setup ##### self._setup_for_bias(bias=bias) ###### Quantization Scale setup ##### - setattr(self.scale, "get_tensor_from_state_dict", QuantizedRowParallel.get_scale_from_state_dict) - set_tensor_model_parallel_attributes(tensor=self.scale, is_parallel=False, dim=None, stride=None) + self._setup_for_scale( + weight_shape=self.weight_shape, + quantization_type=self.quantization_type, + weight_partition_dim=self.weight_partition_dim, + per_channel_axis=quantization_per_channel_axis, + ) self._forward_impl = linear_with_async_allreduce - def _setup_for_weight(self): - init_device = self.device - weight = torch.empty( - self.output_size, self.input_size_per_partition, device=init_device, dtype=self.quantized_dtype + def _setup_for_weight_and_bias_config(self, bias: bool): + self.weight_shape = ( + self.output_size, + self.input_size_per_partition, ) - self.weight = Parameter(weight, requires_grad=False) - self.device = self.weight.device + self.weight_partition_dim = 1 - if self.device.type == "cpu": - self.master_weight = _initialize_parameter_cpu( - param=self.weight, - partition_dim=TensorParallelDim.QUANTIZED_ROW_PARALLEL.value, - init_method=self._init_weight, - param_dtype=self.quantized_dtype, - stride=self.stride, - return_master_param=self.keep_master_weight, - ) - elif self.device.type == "meta": - set_tensor_model_parallel_attributes( - tensor=self.weight, - is_parallel=True, - dim=TensorParallelDim.QUANTIZED_ROW_PARALLEL.value, - stride=self.stride, - ) + if bias: + self.bias_size = self.output_size + self.bias_shape = (self.bias_size,) else: - _initialize_affine_weight_neuron( - weight=self.weight, - init_method=self._init_weight, - partition_dim=TensorParallelDim.QUANTIZED_ROW_PARALLEL.value, - stride=self.stride, - ) - - setattr(self.weight, "get_tensor_from_state_dict", QuantizedRowParallel.get_weight_from_state_dict) - setattr(self.weight, "set_tensor_to_state_dict", QuantizedRowParallel.set_weight_to_state_dict) + self.bias_shape = None def _setup_for_bias(self, bias: bool): + self._base_setup_for_bias(bias=bias) if bias: - if self.device is None or self.device.type == "cpu": - self.bias = Parameter( - torch.empty( - self.output_size, - dtype=self.dequantized_dtype, - ), - requires_grad=False, - ) - else: - self.bias = Parameter( - torch.empty( - self.output_size, - device=self.device, - dtype=self.dequantized_dtype, - ), - requires_grad=False, - ) - if self.bias.device != torch.device("meta"): - self._init_bias(self.bias) setattr(self.bias, "sequence_parallel_enabled", self.sequence_parallel_enabled) - setattr(self.bias, "get_tensor_from_state_dict", QuantizedRowParallel.get_scale_from_state_dict) - setattr(self.bias, "set_tensor_to_state_dict", QuantizedRowParallel.set_bias_to_state_dict) - else: - self.register_parameter("bias", None) def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Forward of QuantizedRowParallel @@ -601,15 +612,17 @@ def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Te assert not self.sequence_parallel_enabled input_parallel = scatter_to_tensor_model_parallel_region(input_) # Matrix multiply. 
- weight_for_matmul = dequantize(self.weight, self.scale, input_parallel.dtype) + weight_for_matmul = direct_cast_dequantize(tensor=self.weight, upcast_dtype=input_parallel.dtype) output_parallel = self._forward_impl( input=input_parallel, weight=weight_for_matmul, bias=None, async_grad_allreduce=False, sequence_parallel_enabled=False, + autograd_func_class=self.autograd_func_class, save_for_backward=False, ) + output_parallel = scale_dequantize(tensor=output_parallel, scale=self.scale.T, upcast_dtype=output_parallel.dtype) # All-reduce across all the partitions. if self.sequence_parallel_enabled: output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) @@ -618,59 +631,260 @@ def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Te output = (output_ + self.bias) if self.bias is not None else output_ return output - @staticmethod - def get_weight_from_state_dict(prefix: str, state_dict: dict) -> torch.Tensor: - return QuantizedParallelLinearLayerStateDictAdaptor.get_weight_from_state_dict( - prefix=prefix, state_dict=state_dict + @classmethod + def from_float( + cls, + mod, + q_config: Union[BASE_QCONFIG_DICT_TYPE, PER_CHANNEL_QCONFIG_DICT_TYPE] = _DEFAULT_CUSTOM_QCONFIG_DICT, + ): + """Create a QuantizedRowParallel from a float module + + Args: + mod: float module + """ + assert mod.__class__.__name__ == "RowParallelLinear", "RowParallelLinear expected" + + if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC: + assert q_config["quantization_per_channel_axis"] is not None + else: + q_config["quantization_per_channel_axis"] = None + + return QuantizedRowParallel( + input_size=mod.input_size, + output_size=mod.output_size, + bias=mod.bias is not None, + quantization_type=q_config["quantization_type"], + input_is_parallel=mod.input_is_parallel, + dtype=mod.dtype, + quantized_dtype=q_config["quantized_dtype"], + device=mod.weight.device, + stride=mod.stride, + sequence_parallel_enabled=mod.sequence_parallel_enabled, + keep_master_weight=mod.keep_master_weight, + quantization_per_channel_axis=q_config["quantization_per_channel_axis"], ) - @staticmethod - def set_weight_to_state_dict(prefix: str, tensor: torch.Tensor, state_dict: dict) -> None: - return QuantizedParallelLinearLayerStateDictAdaptor.set_weight_into_state_dict( - prefix=prefix, tensor=tensor, state_dict=state_dict + +class QuantizedExpertFusedColumnParallel(QuantizedColumnParallel, ExpertFusedLinear): + """ + Quantized version of the ExpertFusedColumnParallelLinear class + """ + + autograd_func_class = ExpertFusedLinearWithAsyncCommunication + + def __init__( + self, + num_experts: int, + input_size: int, + output_size: int, + quantization_type: Union[QuantizationType, str] = "per_tensor_symmetric", + dtype: torch.dtype = torch.float32, + quantized_dtype: Union[QuantizedDtype, torch.dtype] = QuantizedDtype.INT8, + device: torch.device = None, + stride: int = 1, + keep_master_weight: bool = False, + quantization_per_channel_axis: Optional[int] = None, + ): + self.num_experts = num_experts + self._n_local_experts = divide(num_experts, get_expert_model_parallel_size()) + + if quantization_per_channel_axis is not None: + assert ( + quantization_per_channel_axis != 0 + ), "For QuantizedExpertFusedColumnParallel, quantization_per_channel_axis cannot be the dimension 0, which is expert dimension" + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=False, + quantization_type=quantization_type, + gather_output=False, + dtype=dtype, + 
quantized_dtype=quantized_dtype, + device=device, + stride=stride, + sequence_parallel_enabled=False, + keep_master_weight=keep_master_weight, + quantization_per_channel_axis=quantization_per_channel_axis, ) - @staticmethod - def get_bias_from_state_dict(prefix: str, state_dict: dict) -> torch.Tensor: - return QuantizedParallelLinearLayerStateDictAdaptor.get_bias_from_state_dict( - prefix=prefix, state_dict=state_dict + def _setup_for_weight_and_bias_config(self, bias: bool): + """ + Same as the ExpertFusedColumnParallelLinear.set_weight_and_bias_config() + + TODO: modularize both quantization layers and moe layers + """ + # Define 3D weight tensor, one linear layer per expert + self.weight_shape = (self._n_local_experts, self.input_size, self.output_size_per_partition) + # Column parallel partitioning for each expert + self.weight_partition_dim = 2 + self.bias_shape = None + + def forward( + self, input_: torch.Tensor, expert_indices: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Same as the forward of ExpertFusedColumnParallelLinear, except with weight dequantization, + and save_for_backward=False.""" + + if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled: + input_parallel = input_ + else: + input_parallel = copy_to_tensor_model_parallel_region(input_) + + # Matrix multiply. + weight = self.weight[expert_indices, :, :] if expert_indices is not None else self.weight + weight_for_matmul = scale_dequantize(weight, self.scale, input_parallel.dtype) + output = self._forward_impl( + input=input_parallel, + weight=weight_for_matmul, + bias=None, + async_grad_allreduce=self.async_tensor_model_parallel_allreduce, + sequence_parallel_enabled=self.sequence_parallel_enabled, + autograd_func_class=self.autograd_func_class, + save_for_backward=False, ) + return output - @staticmethod - def set_bias_to_state_dict(prefix: str, tensor: torch.Tensor, state_dict: dict) -> torch.Tensor: - return QuantizedParallelLinearLayerStateDictAdaptor.set_bias_to_state_dict( - prefix=prefix, tensor=tensor, state_dict=state_dict + @classmethod + def from_float( + cls, + mod, + q_config: Union[BASE_QCONFIG_DICT_TYPE, PER_CHANNEL_QCONFIG_DICT_TYPE] = _DEFAULT_CUSTOM_QCONFIG_DICT, + ): + """Create a QuantizedExpertFusedColumnParallel from a float module.""" + assert mod.__class__.__name__ == "ExpertFusedColumnParallelLinear", "ExpertFusedColumnParallelLinear expected" + + if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC: + assert q_config["quantization_per_channel_axis"] is not None + else: + q_config["quantization_per_channel_axis"] = None + + new_mod = QuantizedExpertFusedColumnParallel( + num_experts=mod.num_experts, + input_size=mod.input_size, + output_size=mod.output_size, + quantization_type=q_config["quantization_type"], + quantized_dtype=q_config["quantized_dtype"], + dtype=mod.dtype, + device=mod.weight.device, + stride=mod.stride, + keep_master_weight=mod.keep_master_weight, + quantization_per_channel_axis=q_config["quantization_per_channel_axis"], ) + return new_mod - @staticmethod - def get_scale_from_state_dict(prefix: str, state_dict): - return QuantizedParallelLinearLayerStateDictAdaptor.get_scale_from_state_dict( - prefix=prefix, state_dict=state_dict + +class QuantizedExpertFusedRowParallel(QuantizedRowParallel, ExpertFusedLinear): + """ + Quantized version of the ExpertFusedRowParallelLinear class + """ + + autograd_func_class = ExpertFusedLinearWithAsyncCommunication + + def __init__( + self, + num_experts: int, + 
input_size: int, + output_size: int, + reduce_output: bool = False, + quantization_type: Union[QuantizationType, str] = "per_tensor_symmetric", + quantized_dtype: Union[QuantizedDtype, torch.dtype] = QuantizedDtype.INT8, + dtype: torch.dtype = torch.float32, + device: torch.device = None, + stride: int = 1, + keep_master_weight: bool = False, + quantization_per_channel_axis: Optional[int] = None, + ): + self.num_experts = num_experts + self._n_local_experts = divide(num_experts, get_expert_model_parallel_size()) + + # Whether to all-reduce the output across TP ranks or not + self.reduce_output = reduce_output + + if quantization_per_channel_axis is not None: + assert ( + quantization_per_channel_axis != 0 + ), "For QuantizedExpertFusedRowParallel, quantization_per_channel_axis cannot be the dimension 0, which is expert dimension" + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=False, + quantization_type=quantization_type, + input_is_parallel=True, + dtype=dtype, + quantized_dtype=quantized_dtype, + device=device, + stride=stride, + sequence_parallel_enabled=False, + keep_master_weight=keep_master_weight, + quantization_per_channel_axis=quantization_per_channel_axis, ) + def _setup_for_weight_and_bias_config(self, bias: bool): + """ + Same as the ExpertFusedRowParallelLinear.set_weight_and_bias_config() + + TODO: modularize both quantization layers and moe layers + """ + # Define 3D weight tensor, one linear layer per expert + self.weight_shape = (self._n_local_experts, self.input_size_per_partition, self.output_size) + # Row parallel partitioning for each expert + self.weight_partition_dim = 1 + self.bias_shape = None + + def forward( + self, input_: torch.Tensor, expert_indices: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Same as the forward of ExpertFusedRowParallelLinear, except with weight dequantization, + and save_for_backward=False.""" + + # Matrix multiply. + weight = self.weight[expert_indices, :, :] if expert_indices is not None else self.weight + weight_for_matmul = scale_dequantize(weight, self.scale, input_.dtype) + output_parallel = self._forward_impl( + input=input_, + weight=weight_for_matmul, + bias=None, + async_grad_allreduce=False, + sequence_parallel_enabled=False, + autograd_func_class=self.autograd_func_class, + save_for_backward=False, + ) + + if self.reduce_output: + output = reduce_from_tensor_model_parallel_region(output_parallel) + return output + else: + # Return without output all-reduce, in favor of an all-reduce or reduce-scatter after the MoE output combine. 
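# --- Illustrative aside (not part of this diff): the expert_indices gather in the forward
# above picks one (in, out) slice per requested expert out of the fused 3D weight before
# the matmul; dim 0 is the expert dimension, which is also why the constructors above
# reject quantization_per_channel_axis == 0. A small indexing sketch:
import torch

n_local_experts, in_dim, out_dim = 4, 8, 16
weight = torch.randint(-128, 128, (n_local_experts, in_dim, out_dim), dtype=torch.int8)
expert_indices = torch.tensor([1, 3])           # experts selected by the router
selected = weight[expert_indices, :, :]         # shape (2, 8, 16): one slice per index
assert selected.shape == (2, in_dim, out_dim)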
+ return output_parallel + @classmethod def from_float( cls, mod, - quantization_type: Union[QuantizationType, str] = QuantizationType.SCALAR, - quantized_dtype: Union[QuantizedDtype, torch.device] = QuantizedDtype.INT8, + q_config: Union[BASE_QCONFIG_DICT_TYPE, PER_CHANNEL_QCONFIG_DICT_TYPE] = _DEFAULT_CUSTOM_QCONFIG_DICT, ): - """Create a QuantizedRowParallel from a float module - - Args: - mod: float module """ - assert mod.__class__.__name__ == "RowParallelLinear", "RowParallelLinear expected" - return QuantizedRowParallel( + Create a QuantizedExpertFusedRowParallel from a float module + """ + assert mod.__class__.__name__ == "ExpertFusedRowParallelLinear", "ExpertFusedRowParallelLinear expected" + + if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC: + assert q_config["quantization_per_channel_axis"] is not None + else: + q_config["quantization_per_channel_axis"] = None + + return QuantizedExpertFusedRowParallel( + num_experts=mod.num_experts, input_size=mod.input_size, output_size=mod.output_size, - bias=mod.bias is not None, - quantization_type=quantization_type, - input_is_parallel=mod.input_is_parallel, + reduce_output=mod.reduce_output, + quantization_type=q_config["quantization_type"], dtype=mod.dtype, - quantized_dtype=quantized_dtype, + quantized_dtype=q_config["quantized_dtype"], device=mod.weight.device, stride=mod.stride, - sequence_parallel_enabled=mod.sequence_parallel_enabled, keep_master_weight=mod.keep_master_weight, + quantization_per_channel_axis=q_config["quantization_per_channel_axis"], ) diff --git a/src/neuronx_distributed/quantization/quantization_mappings.py b/src/neuronx_distributed/quantization/quantization_mappings.py index de8d61f..bae130a 100644 --- a/src/neuronx_distributed/quantization/quantization_mappings.py +++ b/src/neuronx_distributed/quantization/quantization_mappings.py @@ -3,13 +3,16 @@ """ from typing import Any, Callable, Dict +import neuronx_distributed.modules.moe.moe_parallel_layers as moe_parallel_layers import neuronx_distributed.parallel_layers.layers as parallel_layers -import neuronx_distributed.quantization.quantization_layers as quantization_parallel_layers +import neuronx_distributed.quantization.quantization_layers as q_layers # Default map for swapping dynamic modules DEFAULT_QUANT_MODULE_MAPPINGS: Dict[Callable, Any] = { - parallel_layers.ColumnParallelLinear: quantization_parallel_layers.QuantizedColumnParallel, - parallel_layers.RowParallelLinear: quantization_parallel_layers.QuantizedRowParallel, + parallel_layers.ColumnParallelLinear: q_layers.QuantizedColumnParallel, + parallel_layers.RowParallelLinear: q_layers.QuantizedRowParallel, + moe_parallel_layers.ExpertFusedColumnParallelLinear: q_layers.QuantizedExpertFusedColumnParallel, + moe_parallel_layers.ExpertFusedRowParallelLinear: q_layers.QuantizedExpertFusedRowParallel, } diff --git a/src/neuronx_distributed/quantization/quantization_utils.py b/src/neuronx_distributed/quantization/quantization_utils.py index 3cee266..3ba4b94 100644 --- a/src/neuronx_distributed/quantization/quantization_utils.py +++ b/src/neuronx_distributed/quantization/quantization_utils.py @@ -1,7 +1,55 @@ import torch import torch.ao.nn.quantized.dynamic as nnqd import torch.nn as nn -from torch.ao.quantization.qconfig import default_dynamic_qconfig +from torch.ao.nn.quantized.dynamic.modules.linear import _quantize_weight +from torch.ao.quantization.qconfig import QConfig, default_dynamic_qconfig +from torch.quantization import MinMaxObserver, default_observer + +from 
neuronx_distributed.quantization.observer import PerChannelAbsMaxObserver + + +def extract_q_scale_per_tensor(q_tensor: torch.Tensor) -> torch.Tensor: + """Extract scales per tensor. + + Args: + q_tensor (torch.Tensor): Input torch.qint8/torch.quint8 + + Returns: + torch.Tensor: returns the tensor of shape torch.Size([1]) + """ + assert q_tensor.qscheme() == torch.per_tensor_affine + return torch.tensor([q_tensor.q_scale()]) + + +def extract_q_scale_per_channel(q_tensor: torch.Tensor) -> torch.Tensor: + """Extract the scale for per channel quantization + + Ideally scales would be a 1D tensor. But we want to multiply the scales with the weight to dequantize. + So we simply view the scale so that its broadcastable to the weight shape. + + Args: + q_tensor (torch.Tensor): Input quantized tensor + + Returns: + torch.Tensor: scale with all the shape broadcastable wrt weight + """ + assert q_tensor.qscheme() == torch.per_channel_affine + per_channel_axis = q_tensor.q_per_channel_axis() + q_tensor_shape = q_tensor.shape + # The shape here would be [1, 1, ....] + scale_shape = [1] * len(q_tensor_shape) + # The shape now would be [1, C, 1, .....], so makes it broadcastble + scale_shape[per_channel_axis] = q_tensor_shape[per_channel_axis] + return q_tensor.q_per_channel_scales().to(torch.float32).view(scale_shape) + + +def extract_q_scale(q_tensor: torch.Tensor): + if q_tensor.qscheme() == torch.per_tensor_affine: + return extract_q_scale_per_tensor(q_tensor) + elif q_tensor.qscheme() == torch.per_channel_affine: + return extract_q_scale_per_channel(q_tensor) + else: + raise (f"qscheme: {q_tensor.qscheme()} is not supported") def convert_qint8_to_int8_state_dict(state_dict: dict) -> dict: @@ -24,7 +72,7 @@ def convert_qint8_to_int8_state_dict(state_dict: dict) -> dict: for prefix in prefixes: state_dict[prefix + "weight"] = torch.int_repr(state_dict[prefix + "_packed_params._packed_params"][0]) - state_dict[prefix + "scale"] = torch.tensor([state_dict[prefix + "_packed_params._packed_params"][0].q_scale()]) + state_dict[prefix + "scale"] = extract_q_scale(state_dict[prefix + "_packed_params._packed_params"][0]) if len(state_dict[prefix + "_packed_params._packed_params"]) == 2: # Bias is included @@ -38,10 +86,41 @@ def convert_qint8_to_int8_state_dict(state_dict: dict) -> dict: state_dict.pop(prefix + "zero_point") -def convert_float_model_to_pytorch_int8_model(float_model: torch.nn.Module, inplace=False) -> torch.nn.Module: +def quantize_pytorch_model_per_tensor_symmetric(float_model: torch.nn.Module, inplace=False) -> torch.nn.Module: qconfig_spec = {torch.nn.Linear: default_dynamic_qconfig} mapping = {nn.Linear: nnqd.Linear} quant_model = torch.quantization.quantize_dynamic( float_model, qconfig_spec=qconfig_spec, mapping=mapping, dtype=torch.qint8, inplace=inplace ) return quant_model + + +def quantize_pytorch_model_per_channel_symmetric(float_model: torch.nn.Module, inplace=False) -> torch.nn.Module: + q_config = QConfig( + activation=default_observer, + weight=PerChannelAbsMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0), + ) + + qconfig_spec = {nn.Linear: q_config} + mapping = {nn.Linear: nnqd.Linear} + + quant_model = torch.quantization.quantize_dynamic( + float_model, qconfig_spec=qconfig_spec, mapping=mapping, dtype=torch.qint8, inplace=inplace + ) + return quant_model + + +def quantize_per_tensor_symmetric(tensor): + tensor_observer = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)() + tensor_observer(tensor) + 
q_tensor = _quantize_weight(tensor, tensor_observer) + return q_tensor + + +def quantize_per_channel_symmetric(tensor: torch.Tensor, channel_axis: int): + tensor_observer = PerChannelAbsMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=channel_axis + )() + tensor_observer(tensor) + q_tensor = _quantize_weight(tensor, tensor_observer) + return q_tensor diff --git a/src/neuronx_distributed/quantization/quantize.py b/src/neuronx_distributed/quantization/quantize.py index 68b3c04..9da5c68 100644 --- a/src/neuronx_distributed/quantization/quantize.py +++ b/src/neuronx_distributed/quantization/quantize.py @@ -1,33 +1,17 @@ import copy -from typing import Any, Callable, Dict, TypedDict +from typing import Any, Callable, Dict -from neuronx_distributed.quantization.quantization_layers import ( - QuantizationType, - QuantizedDtype, +from neuronx_distributed.quantization.quantization_config import ( + BASE_QCONFIG_DICT_TYPE, + get_default_custom_qconfig_dict, ) from neuronx_distributed.quantization.quantization_mappings import ( get_default_quant_module_mappings, ) -class CONFIG_DICT_TYPE(TypedDict): - quantization_type: QuantizationType - quantized_dtype: QuantizedDtype - - -_DEFAULT_CUSTOM_CONFIG_DICT: CONFIG_DICT_TYPE = { - "quantization_type": QuantizationType.SCALAR, - "quantized_dtype": QuantizedDtype.INT8, -} - - -def get_default_custom_config_dict() -> CONFIG_DICT_TYPE: - r"""Defines the default custom config dict.""" - return _DEFAULT_CUSTOM_CONFIG_DICT - - def convert( - module: Any, q_config: CONFIG_DICT_TYPE = None, inplace: bool = False, mapping: Dict[Callable, Any] = None + module: Any, q_config: BASE_QCONFIG_DICT_TYPE = None, inplace: bool = False, mapping: Dict[Callable, Any] = None ) -> Any: """Funtion to convert a Non quantized module to its quantized version based on the q_config @@ -43,7 +27,7 @@ def convert( if not inplace: module = copy.deepcopy(module) if q_config is None: - q_config = get_default_custom_config_dict() + q_config = get_default_custom_qconfig_dict() if mapping is None: mapping = get_default_quant_module_mappings() @@ -60,14 +44,13 @@ def _convert_initialized_float_to_initialized_quantized(module, q_config, mappin reassign = {} for name, mod in module.named_children(): - if not type(mod) in mapping: + if type(mod) not in mapping: _convert_initialized_float_to_initialized_quantized(module=mod, q_config=q_config, mapping=mapping) if type(mod) in mapping: quantized_class = mapping[type(mod)] reassign[name] = quantized_class.from_float( mod=mod, - quantization_type=q_config.get("quantization_type"), - quantized_dtype=q_config.get("quantized_dtype"), + q_config=q_config, ) # Currently there is a bug in quantize.convert function where even though # Parallel embedding has set for tensor_model_parallel attribute, it does not show diff --git a/src/neuronx_distributed/scripts/checkpoint_converter.py b/src/neuronx_distributed/scripts/checkpoint_converter.py new file mode 100644 index 0000000..5217071 --- /dev/null +++ b/src/neuronx_distributed/scripts/checkpoint_converter.py @@ -0,0 +1,739 @@ +# Note : This file location may change in future. 
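# --- Illustrative aside (not part of this diff): how the new per-channel pieces fit
# together. quantize_per_channel_symmetric() (quantization_utils.py) produces a torch.qint8
# tensor, extract_q_scale() pulls its scale out in a shape that broadcasts against the
# weight, and scale_dequantize() (dequantize.py) approximately recovers the float weight.
# A minimal round-trip sketch, assuming this diff is applied and installed:
import torch
from neuronx_distributed.quantization.dequantize import scale_dequantize
from neuronx_distributed.quantization.quantization_utils import (
    extract_q_scale,
    quantize_per_channel_symmetric,
)

w = torch.randn(16, 8)                                   # float weight, channel axis 0
q_w = quantize_per_channel_symmetric(w, channel_axis=0)  # torch.qint8 with per-channel scales
scale = extract_q_scale(q_w)                             # shape (16, 1), broadcastable to w
w_hat = scale_dequantize(torch.int_repr(q_w), scale, torch.float32)
assert torch.allclose(w, w_hat, atol=scale.max().item()) # off by at most one quantization step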
+import argparse +import json +import os +import re + +from numpy import format_float_scientific + +import torch +import torch_xla.utils.serialization as xser + +from neuronx_distributed.pipeline.partition import ( + create_partitions, + stage_to_pipeline_parallel_rank, +) +from neuronx_distributed.trainer.checkpoint import _xser_load_data +from neuronx_distributed.trainer.checkpoint_storage import BaseCheckpointStorage, create_checkpoint_storage + + +class CheckpointConverterBase: + + # ParallelEmbedding + embedding_partition_dim = 0 + # ColumnParallelLinear or GQAQKVColumnParallelLinear + qkv_partition_dim = 0 + # ColumnParallelLinear + gate_up_proj_partition_dim = 0 + # RowParallelLinear + down_proj_partition_dim = 1 + # RowParallelLinear + o_proj_partition_dim = 1 + + def get_partition_dim(self, name): + if "embed_tokens" in name or "lm_head" in name: + partition_dim = self.embedding_partition_dim + elif self.is_qkv_weight(name): + partition_dim = self.qkv_partition_dim + elif "gate_proj" in name or "up_proj" in name or "gate_up_proj" in name: + partition_dim = self.gate_up_proj_partition_dim + elif "down_proj" in name: + partition_dim = self.down_proj_partition_dim + elif "o_proj" in name: + partition_dim = self.o_proj_partition_dim + else: + raise AssertionError(f"Unknown partition_dim for {name}") + return partition_dim + + # QKV Helper functions + def get_hf_to_nxd_model_keys(self, qkv_linear=True, is_gqa=True): + if qkv_linear: + keys_hf_to_nxd = { + "q_proj.weight": "qkv_proj.weight_q", + "k_proj.weight": "qkv_proj.weight_k", + "v_proj.weight": "qkv_proj.weight_v", + } + elif is_gqa: # shouldnt hit this case as it qkv linear is used for gqa + keys_hf_to_nxd = { + "q_proj.weight": "q_proj.weight", + "k_proj.weight": "k_proj.weight", + "v_proj.weight": "v_proj.weight", + } + else: + keys_hf_to_nxd = { + "q_proj.weight": "qkv_proj.weight", + "k_proj.weight": "qkv_proj.weight", + "v_proj.weight": "qkv_proj.weight", + } + keys_nxd_to_hf = {v: k for k, v in keys_hf_to_nxd.items()} + return keys_hf_to_nxd, keys_nxd_to_hf + + + def get_fused_qkv_key(self): + return "qkv_proj.weight_qkv" + + def is_qkv_weight(self, name): + return "q_proj" in name or "k_proj" in name or "v_proj" in name or "qkv_proj" in name or "query_key_value" in name + + def coalesce_qkv(self, state_dict, config, tp_degree): + for i in range(config["num_hidden_layers"]): + q = state_dict.pop(f"model.layers.{i}.self_attn.q_proj.weight") + k = state_dict.pop(f"model.layers.{i}.self_attn.k_proj.weight") + v = state_dict.pop(f"model.layers.{i}.self_attn.v_proj.weight") + partition_size = config["hidden_size"] // tp_degree + tp_partititons = [] + for tp_rank in range(tp_degree): + q_split = q.narrow(0, tp_rank * partition_size, partition_size).detach().clone() + k_split = k.narrow(0, tp_rank * partition_size, partition_size).detach().clone() + v_split = v.narrow(0, tp_rank * partition_size, partition_size).detach().clone() + tp_partititons.append(torch.cat([q_split, k_split, v_split], dim=self.qkv_partition_dim)) + + state_dict[f"model.layers.{i}.self_attn.qkv_proj.weight"] = torch.cat(tp_partititons, dim=self.qkv_partition_dim) + + return state_dict + + def get_weight_key(self, keys_hf_to_nxd, keys_nxd_to_hf, name, hf_to_nxd): + if not self.is_qkv_weight(name): + return name + + keys = keys_hf_to_nxd if hf_to_nxd else keys_nxd_to_hf + return ".".join(name.split(".")[:-2]) + "." 
+ keys[".".join(name.split(".")[-2:])] + + def rename_keys_for_megatron(self, key, model_style, hf_to_nxdt=False): + if model_style != 'megatron': + return key + + megatron_name_to_hf_name = { + 'language_model.embedding.word_embeddings.weight' : 'model.embed_tokens.weight', + 'language_model.encoder.final_layernorm.weight' : 'model.norm.weight', + 'language_model.encoder.' : 'model.', + 'self_attention.' : 'self_attn.', + 'core_attention.rotary_emb.inv_freq' : 'rotary_emb.inv_freq', + 'dense.weight' : 'o_proj.weight', + 'dense_h_to_4h.weight' : 'gate_up_proj.weight', + 'dense_4h_to_h.weight' : 'down_proj.weight', + 'language_model.output_layer.weight' : 'lm_head.weight', + 'query_key_value.weight' : 'qkv_proj.weight' + } + + def check_replace_complete(strings, key, meg_str, hf_str): + for string in strings: + if string in key: + if string in meg_str+hf_str : # check whther its present atleast in one of them + return True + return False + + for meg_str,hf_str in megatron_name_to_hf_name.items(): + if not hf_to_nxdt: + key = key.replace(meg_str,hf_str) + else: + key = key.replace(hf_str,meg_str) + if check_replace_complete(['embed','final_layernorm','model.norm'], key, meg_str, hf_str): + break + + return key + + def modify_qkv_for_megatron(self,partial_state,args): + if args.model_style != 'megatron': + return + if not args.qkv_linear: + if args.convert_from_full_state: + # merge k_proj,q_proj and v_proj to query_key_value.weight + pkeys = list(partial_state.keys()) + for key in pkeys: + if 'q_proj' in key: + q = partial_state[key] + k = partial_state[key.replace('q_proj','k_proj')] + v = partial_state[key.replace('q_proj','v_proj')] + partial_state[key.replace('q_proj','query_key_value')] = torch.cat((q,k,v),dim=0).detach().clone() + del partial_state[key], partial_state[key.replace('q_proj','k_proj')], partial_state[key.replace('q_proj','v_proj')] + else: + if args.convert_from_full_state: + # Opposite of :: query.weight and key_value.weight to qkv_proj.weight_q, qkv_proj.weight_k, qkv_proj.weight_v + pkeys = list(partial_state.keys()) + for key in pkeys: + # Reverse weight projection renaming + original_key = None + if 'query_key_value.weight_q' in key: + original_key = key.replace('query_key_value.weight_q', 'query.weight') + partial_state[original_key] = partial_state[key].detach().clone() + del partial_state[key] + # Reverse weight tensor splitting + elif ('query_key_value.weight_k' in key or 'query_key_value.weight_v' in key) and key in partial_state: # partial state will not have removed keys + if key.endswith('query_key_value.weight_k'): + weight_k_key = key + weight_v_key = key.replace('query_key_value.weight_k', 'query_key_value.weight_v') + original_key = key.replace('query_key_value.weight_k', 'key_value.weight') + else: + weight_k_key = key.replace('query_key_value.weight_v', 'query_key_value.weight_k') + weight_v_key = key + original_key = key.replace('query_key_value.weight_v', 'key_value.weight') + combined_tensor = torch.cat([partial_state[weight_k_key], partial_state[weight_v_key]], dim=0) + partial_state[original_key] = combined_tensor.detach().clone() + del partial_state[weight_k_key] + del partial_state[weight_v_key] + print(f"{original_key=},{key=}") + else: + # query.weight and key_value.weight to qkv_proj.weight_q, qkv_proj.weight_k, qkv_proj.weight_v + pkeys = list(partial_state.keys()) + for key in pkeys: + if 'query.weight' in key: + partial_state[key.replace('query.weight','qkv_proj.weight_q')] = partial_state[key].detach().clone() + del partial_state[key] + 
elif 'key_value.weight' in key: + split_size = partial_state[key].size(0) // 2 + tensor1, tensor2 = torch.split(partial_state[key], split_size, dim=0) + partial_state[key.replace('key_value.weight','qkv_proj.weight_k')] = tensor1.detach().clone() + partial_state[key.replace('key_value.weight','qkv_proj.weight_v')] = tensor2.detach().clone() + del partial_state[key] + + def is_q_or_o_for_megatron(self, args, name): + if args.model_style != 'megatron': + return False + if 'q' in name or 'o_proj' in name: # since is_qkv_weight is already checked a simple 'q' is good enough. + # Since GQA doesnt support replication we will return true here and then do direct torch.cat without worrying about shuffling + return True + return False + + + def find_size(self, size_total): + # Divide the total size by 3 to get the q,k,v component sizes, which are of equal size. + size = size_total / 3 + return int(size) + + # Find the fused qkv weight in the partial state and split it into q,k,v components. + # Update the partial state accordingly + def convert_partial_state_to_non_fused_qkv(self, partial_state, keys_nxd_to_hf, kv_size_multiplier, num_hidden_layers): + for i in range(num_hidden_layers): + qkv_key = self.get_fused_qkv_key() + qkv = partial_state.pop(f"model.layers.{i}.self_attn.{qkv_key}") + size = self.find_size(qkv.size(0)) + q, k, v = torch.split(qkv, (size, size, size), dim=0) + + # Set q,k,v in partial state + q_key = next(key for key in keys_nxd_to_hf.keys() if "weight_q" in key) + k_key = next(key for key in keys_nxd_to_hf.keys() if "weight_k" in key) + v_key = next(key for key in keys_nxd_to_hf.keys() if "weight_v" in key) + partial_state[f"model.layers.{i}.self_attn.{q_key}"] = q + partial_state[f"model.layers.{i}.self_attn.{k_key}"] = k + partial_state[f"model.layers.{i}.self_attn.{v_key}"] = v + return partial_state + + # Take the individual q,k,v components and concat them + # Update the partial state accordingly + def convert_partial_state_to_fused_qkv(self, partial_state, keys_nxd_to_hf, num_hidden_layers): + for i in range(num_hidden_layers): + q_key = next(key for key in keys_nxd_to_hf.keys() if "weight_q" in key) + k_key = next(key for key in keys_nxd_to_hf.keys() if "weight_k" in key) + v_key = next(key for key in keys_nxd_to_hf.keys() if "weight_v" in key) + q = partial_state.pop(f"model.layers.{i}.self_attn.{q_key}") + k = partial_state.pop(f"model.layers.{i}.self_attn.{k_key}") + v = partial_state.pop(f"model.layers.{i}.self_attn.{v_key}") + qkv = torch.cat([q, k, v], dim=0) + qkv_key = self.get_fused_qkv_key() + partial_state[f"model.layers.{i}.self_attn.{qkv_key}"] = qkv + return partial_state + + # Helper function for convert_to_full_state() + def merge_tp_checkpoints(self, args): + full_state = {} + with open(args.config, "r") as f: + config = json.load(f) + q_heads = config["num_attention_heads"] + kv_heads = config["num_key_value_heads"] + head_dim = config["hidden_size"] // q_heads + is_gqa = q_heads != kv_heads + keys_hf_to_nxd, keys_nxd_to_hf = self.get_hf_to_nxd_model_keys(args.qkv_linear, is_gqa) + + for tp_rank in range(args.tp_size): + for pp_rank in range(args.pp_size): + for ep_rank in range(args.ep_size): + if args.load_xser: + partial_state = self.load_partial_xser(args, tp_rank, pp_rank, ep_rank) + else: + partial_state = self.load_partial_no_xser(args, tp_rank, pp_rank, ep_rank) + pkeys = list(partial_state.keys()) + for key in pkeys: + partial_state[self.rename_keys_for_megatron(key, args.model_style, hf_to_nxdt=False)] = partial_state[key].cpu() + if 
args.model_style=='megatron': + del partial_state[key] + self.modify_qkv_for_megatron(partial_state, args) # dict so gets auto modified. + if args.model_key is not None and args.model_key in partial_state: + partial_state = partial_state[args.model_key] + if args.fuse_qkv: + partial_state = self.convert_partial_state_to_non_fused_qkv(partial_state, keys_nxd_to_hf, args.kv_size_multiplier, args.n_layers) + for name, param in partial_state.items(): + if (self.is_qkv_weight(name) or "o_proj" in name) and args.qkv_linear: + # qkv_proj would be a key if we are using the QKVLinear layer + partition_dim = self.get_partition_dim(name) + name = self.get_weight_key(keys_hf_to_nxd, keys_nxd_to_hf, name, False) + + if name not in full_state: + full_state[name] = [] + + full_state[name].append(param) + if tp_rank != (args.tp_size - 1): + continue + + full_weight = torch.cat(full_state[name], dim=partition_dim) + if "k" in name or "v" in name or self.is_q_or_o_for_megatron(args,name): # no kv replication in megatron so q needs to be appended directly + # If kv_multiplier is set, the kv heads are repeated. So we need to + # take only the first chunk + full_state[name] = torch.chunk(full_weight, args.kv_size_multiplier)[0].detach().clone() + else: + # Since we do the replication of KV heads, the Q heads are placed as: + # Q0Q1Q8Q9...Q2Q3Q10Q11... + # Hence when creating the merged checkpoint, we need to bring the Q heads and o_proj in order. + if "o_proj" in name: + # The shuffling is same for both o_proj and q, but o_proj is sharded on column. + # Hence to reuse the same shuffling code, we just transpose, do the shuffling and + # transpose back + full_weight = torch.transpose(full_weight, 0, 1) + weights = full_weight.reshape(q_heads, head_dim, -1) + weights_shape = weights.size() + weights = weights.reshape( + -1, q_heads // (kv_heads * args.kv_size_multiplier), head_dim, weights_shape[-1] + ) + weight_splits = [] + indicies = torch.arange(0, args.tp_size // kv_heads) * kv_heads + for i in range(kv_heads): + weight_splits.append(weights[indicies + i].reshape(-1, weights_shape[-1])) + full_weight = torch.cat(weight_splits, dim=self.qkv_partition_dim) + full_state[name] = ( + torch.transpose(full_weight, 0, 1).detach().clone() + if "o_proj" in name + else full_weight.detach().clone() + ) + elif "qkv_proj" in name and not is_gqa: + partition_dim = self.get_partition_dim(name) + partition_size = config["hidden_size"] // args.tp_size + q, k, v = torch.split(param, partition_size, dim=partition_dim) + q_name = name.replace("qkv", "q") + k_name = name.replace("qkv", "k") + v_name = name.replace("qkv", "v") + for name, weight in zip([q_name, k_name, v_name], [q, k, v]): + if name not in full_state: + full_state[name] = [] + full_state[name].append(weight) + if tp_rank == (args.tp_size - 1): + full_weight = torch.cat(full_state[name], dim=partition_dim) + full_state[name] = full_weight.detach().clone() + elif ( + "embed_tokens" in name + or self.is_qkv_weight(name) + or "o_proj" in name + or "lm_head" in name + ): + partition_dim = self.get_partition_dim(name) + name = self.get_weight_key(keys_hf_to_nxd, keys_nxd_to_hf, name, False) + if name not in full_state: + full_state[name] = [] + full_state[name].append(param) + if tp_rank == (args.tp_size - 1): + full_weight = torch.cat(full_state[name], dim=partition_dim) + full_state[name] = full_weight.detach().clone() + elif "down_proj" in name: + partition_dim = self.get_partition_dim(name) + expert_partition_dim = 0 + name = self.get_weight_key(keys_hf_to_nxd, 
keys_nxd_to_hf, name, False) + if name not in full_state: + full_state[name] = [[]] + full_state[name][tp_rank].append(param) + if ep_rank == (args.ep_size - 1): + full_weight = torch.cat(full_state[name][tp_rank], dim=expert_partition_dim) + full_state[name][tp_rank] = full_weight.detach().clone() + if tp_rank != (args.tp_size - 1): + full_state[name].append([]) + else: + full_weight = torch.cat(full_state[name], dim=partition_dim) + full_state[name] = full_weight + elif "gate_up_proj" in name: + partition_dim = self.get_partition_dim(name) + expert_partition_dim = 0 + dim_size = param.size()[partition_dim] // 2 + gate_proj_name = name.replace("gate_up_proj", "gate_proj") + up_proj_name = name.replace("gate_up_proj", "up_proj") + gate_proj_weight = param.narrow(partition_dim, 0, dim_size).detach().clone() + up_proj_weight = param.narrow(partition_dim, dim_size, dim_size).detach().clone() + if gate_proj_name not in full_state: + full_state[gate_proj_name] = [[]] + if up_proj_name not in full_state: + full_state[up_proj_name] = [[]] + full_state[gate_proj_name][tp_rank].append(gate_proj_weight) + full_state[up_proj_name][tp_rank].append(up_proj_weight) + if ep_rank == (args.ep_size - 1): + full_gate_proj_weight = torch.cat(full_state[gate_proj_name][tp_rank], dim=expert_partition_dim) + full_up_proj_weight = torch.cat(full_state[up_proj_name][tp_rank], dim=expert_partition_dim) + full_state[gate_proj_name][tp_rank] = full_gate_proj_weight + full_state[up_proj_name][tp_rank] = full_up_proj_weight + if tp_rank != args.tp_size - 1: + full_state[gate_proj_name].append([]) + full_state[up_proj_name].append([]) + else: + full_gate_proj_weight = torch.cat(full_state[gate_proj_name], dim=partition_dim) + full_up_proj_weight = torch.cat(full_state[up_proj_name], dim=partition_dim) + full_state[gate_proj_name] = full_gate_proj_weight + full_state[up_proj_name] = full_up_proj_weight + + elif "expert_mlps" in name: + if name not in full_state: + full_state[name] = [] + full_state[name].append(param) + if ep_rank == args.ep_size - 1: + expert_dim = 0 + full_state[name] = torch.cat(full_state[name], dim=expert_dim) + else: + if name not in full_state: + full_state[name] = param + + full_state = self.post_process_full_state_after_tp_conversion(full_state, args) + return full_state + + # Helper function for convert_from_full_state() + def convert_full_state_to_tp(self, full_state, args, tp_rank, pp_rank, ep_rank, partitions, config): + tp_size = args.tp_size + pp_size = args.pp_size + ep_size = args.ep_size + kv_size_multiplier = args.kv_size_multiplier + + partial_state = {} + q_heads = config["num_attention_heads"] + kv_heads = config["num_key_value_heads"] + head_dim = config["hidden_size"] // q_heads + + is_gqa = q_heads != kv_heads + keys_hf_to_nxd, keys_nxd_to_hf = self.get_hf_to_nxd_model_keys(args.qkv_linear, is_gqa) + + for name, full_p in full_state.items(): + ##################### PP Slice ######################################### + # Embedding only in first PP + if pp_rank != 0 and "embed_tokens" in name: + continue + + # Non-expert parameters only in EP rank 0 + if ep_rank != 0 and "expert_mlps" not in name: + continue + + # LMhead and final layer norm only in last PP rank + if pp_rank != pp_size - 1 and ("lm_head" in name or "model.norm.weight" in name): + continue + if "layers" in name: + layer_idx = int(name.split(".")[2]) + current_stage = len(partitions) + # iterate through the pp cuts and find the current stage + for stage, pp_cut in enumerate(partitions): + cut_layer_idx = 
int(pp_cut.split(".")[2]) + if layer_idx <= cut_layer_idx: + current_stage = stage + break + current_pp_rank = stage_to_pipeline_parallel_rank(current_stage, pp_size=pp_size) + if current_pp_rank != pp_rank: + continue + + ##################### EP Slice ######################################### + if "expert_mlps" in name: + expert_dim = 0 + expert_dim_size = full_p.shape[expert_dim] + if expert_dim_size % ep_size != 0: + raise ValueError(f"Expert dimension ({expert_dim_size}) is not divisible by expert parallelism degree ({ep_size}).") + num_local_experts = expert_dim_size // ep_size + with torch.no_grad(): + weight_slice = full_p.narrow(expert_dim, num_local_experts * ep_rank, num_local_experts) + partial_state[name] = weight_slice + + ##################### TP Slice ######################################### + if (self.is_qkv_weight(name) or "o_proj" in name) and args.qkv_linear: + name = self.get_weight_key(keys_hf_to_nxd, keys_nxd_to_hf, name, True) + if "weight_k" in name or "weight_v" in name or self.is_q_or_o_for_megatron(args,name): + repeated_kv = full_p.repeat(kv_size_multiplier, 1) + + dim_size = repeated_kv.size()[0] + assert dim_size % tp_size == 0, "0th dim after KV replication is not divisible by tp_size" + partition_size = dim_size // tp_size + with torch.no_grad(): + partition_dim = 0 + if "o_proj" in name: # only in megatron case we come here + partition_dim = self.get_partition_dim(name) + to_load = repeated_kv.narrow(partition_dim, tp_rank * partition_size, partition_size).detach().clone() + # Cloning the tensor is really important, since we have performed slice and reshape operations. + # These operations are just views and if we don't clone, we would end up saving the entire tensor + + partial_state[name] = to_load.detach().clone() + else: + # When GQAQKV linear with kv_multiplier is used, we need to reshuffle the order of Q heads + # so they interact with the right KV heads. Now since the heads are shuffled, we have to + # shuffle the o_proj rows since that translates the heads to hidden dim + if "o_proj" in name: + # The shuffling is same for both o_proj and q, but o_proj is sharded on column. + # Hence to reuse the same shuffling code, we just transpose, do the shuffling and + # transpose back + full_p = torch.transpose(full_p, 0, 1) + weights = full_p.reshape(q_heads, head_dim, -1) + weights_shape = weights.size() + weights = weights.reshape(-1, q_heads // (kv_heads * kv_size_multiplier), head_dim, weights_shape[-1]) + weight_splits = [] + indicies = torch.arange(0, kv_heads) * tp_size // kv_heads + for i in range(tp_size // kv_heads): + weight_splits.append(weights[indicies + i]) + weights = torch.cat(weight_splits, dim=self.qkv_partition_dim) + with torch.no_grad(): + to_load = weights[tp_rank].reshape(-1, weights_shape[-1]) + if "o_proj" in name: + to_load = torch.transpose(to_load, 0, 1) + # Cloning the tensor is really important, since we have performed slice and reshape operations. 
+ # These operations are just views and if we don't clone, we would end up saving the entire tensor + partial_state[name] = to_load.detach().clone() + elif ( + "embed_tokens" in name + or self.is_qkv_weight(name) + or "o_proj" in name + or "down_proj" in name + or "lm_head" in name + ): + partition_dim = self.get_partition_dim(name) + dim_size = full_p.size()[partition_dim] + assert dim_size % tp_size == 0, "vocab size is not divisiable" + partition_size = dim_size // tp_size + with torch.no_grad(): + to_load = full_p.narrow(partition_dim, tp_rank * partition_size, partition_size) + partial_state[name] = to_load.detach().clone() + elif "gate_proj" in name or "up_proj" in name: + partition_dim = self.get_partition_dim(name) + dim_size = full_p.size()[partition_dim] + assert dim_size % tp_size == 0, "vocab size is not divisiable" + partition_size = dim_size // tp_size + with torch.no_grad(): + to_load = full_p.narrow(partition_dim, tp_rank * partition_size, partition_size).detach().clone() + token = "gate_proj" if "gate_proj" in name else "up_proj" + updated_name = name.replace(token, "gate_up_proj") + if updated_name in partial_state: + if token == "gate_proj": + partial_state[updated_name] = ( + torch.cat([to_load, partial_state[updated_name]], dim=partition_dim).detach().clone() + ) + else: + partial_state[updated_name] = ( + torch.cat([partial_state[updated_name], to_load], dim=partition_dim).detach().clone() + ) + else: + partial_state[updated_name] = to_load.detach().clone() + else: + # no TP + partial_state[name] = full_p + pkeys = list(partial_state.keys()) + for key in pkeys: + partial_state[self.rename_keys_for_megatron(key, args.model_style, hf_to_nxdt = True)] = partial_state[key] + if args.model_style=='megatron': + del partial_state[key] + self.modify_qkv_for_megatron(partial_state,args) + if args.fuse_qkv: + partial_state = self.convert_partial_state_to_fused_qkv(partial_state, keys_nxd_to_hf, args.n_layers) + return partial_state + + # Placeholder functions for additional processing of full_state + def pre_process_full_state_before_tp_conversion(self, full_state, args): + """Child classes can override this function to implement custom logic.""" + return full_state + + def post_process_full_state_after_tp_conversion(self, full_state, args): + """Child classes can override this function to implement custom logic.""" + return full_state + + # Helper functions for save/load + def load_full_state(self, args): + full_state = torch.load(args.input_dir) + return full_state + + def get_input_filename(self, args, tp_rank, pp_rank, ep_rank, xser): + if xser: + v1_api_filename = os.path.join(args.input_dir, "tp_rank_{:02d}_pp_rank_{:02d}".format(tp_rank, pp_rank)) + else: + v1_api_filename = os.path.join( + args.input_dir, "tp_rank_{:02d}_pp_rank_{:02d}".format(tp_rank, pp_rank), "checkpoint.pt" + ) + + v2_api_filename = os.path.join( + args.input_dir, "dp_rank_00_tp_rank_{:02d}_pp_rank_{:02d}.pt".format(tp_rank, pp_rank) + ) + + v3_api_filename = os.path.join( + args.input_dir, "dp_rank_00_ep_rank_{:02d}_tp_rank_{:02d}_pp_rank_{:02d}.pt".format(ep_rank, tp_rank, pp_rank) + ) + + if os.path.exists(v1_api_filename): + return v1_api_filename + + if os.path.exists(v2_api_filename): + return v2_api_filename + + if os.path.exists(v3_api_filename): + return v3_api_filename + + raise RuntimeError(f"Error: neither {v1_api_filename}, nor {v2_api_filename}, nor {v3_api_filename} exist") + + def get_output_filename(self, args, tp_rank, pp_rank, ep_rank, xser): + if args.ep_size > 1: + return 
os.path.join( + args.output_dir, "model", "dp_rank_00_ep_rank_{:02d}_tp_rank_{:02d}_pp_rank_{:02d}.pt".format(ep_rank, tp_rank, pp_rank) + ) + else: + return os.path.join( + args.output_dir, "model", "dp_rank_00_tp_rank_{:02d}_pp_rank_{:02d}.pt".format(tp_rank, pp_rank) + ) + + def load_partial_xser(self, args, tp_rank, pp_rank, ep_rank): + filename = self.get_input_filename(args, tp_rank, pp_rank, ep_rank, 1) + dir_name = os.path.join(*(filename.split("/")[:-3])) + checkpoint_dir = create_checkpoint_storage(dir_name) + partial_state = _xser_load_data(checkpoint_dir, filename, None, ep_only=ep_rank > 0) + self.prune_state(partial_state, ep_rank) + return partial_state + + def prune_state(self, state, ep_rank): + if ep_rank > 0: + to_remove = [] + for k, v in state.items(): + if v is None: + to_remove.append(k) + for k in to_remove: + state.pop(k) + + def load_partial_no_xser(self, args, tp_rank, pp_rank, ep_rank): + filename = self.get_input_filename(args, tp_rank, pp_rank, ep_rank, 0) + partial_state = torch.load(filename) + return partial_state + + def save_full(self, args, full_state): + save_path = args.output_dir + os.makedirs(save_path, exist_ok=True) + if os.path.isdir(save_path): + save_path = os.path.join(save_path, "checkpoint.pt") + print(f"Saving full checkpoint to {save_path}") + torch.save(full_state, save_path) + + def save_partial_xser(self, args, partial_state, tp_rank, pp_rank, ep_rank): + filename = self.get_output_filename(args, tp_rank, pp_rank, ep_rank, 1) + os.makedirs(args.output_dir + "/model", exist_ok=True) + print(f"Saving to {filename}") + xser.save(partial_state, filename) + + def save_partial_no_xser(self, args, partial_state, tp_rank, pp_rank, ep_rank): + filename = self.get_output_filename(args, tp_rank, pp_rank, ep_rank, 0) + os.makedirs(args.output_dir + "/model", exist_ok=True) + print(f"Saving to {filename}") + torch.save(partial_state, filename) + + # Main functions to run checkpoint conversion + def convert_from_xser(self, args): + for tp_rank in range(args.tp_size): + for pp_rank in range(args.pp_size): + partial_state = self.load_partial_xser(args, tp_rank, pp_rank) + self.save_partial_no_xser(args, partial_state, tp_rank, pp_rank) + + def convert_to_xser(self, args): + for tp_rank in range(args.tp_size): + for pp_rank in range(args.pp_size): + partial_state = self.load_partial_no_xser(args, tp_rank, pp_rank) + self.save_partial_xser(args, partial_state, tp_rank, pp_rank) + + def convert_from_full_state(self, args): + full_state = self.load_full_state(args) + layer_name_pattern = r"^(model\.layers\.\d+)" + model_layer_names = sorted( + list( + set( + [ + re.match(layer_name_pattern, key).group(1) + for key in full_state.keys() + if re.match(layer_name_pattern, key) + ] + ) + ), + key=lambda x: int(re.search(r"\d+", x).group()), + ) + partitions = create_partitions(args.pp_size * args.virtual_pp_size, model_layer_names) + print(f"pipeline_cuts {partitions}") + with open(args.config, "r") as f: + config = json.load(f) + if args.coalesce_qkv: + full_state = self.coalesce_qkv(full_state, config, args.tp_size) + + full_state = self.pre_process_full_state_before_tp_conversion(full_state, args) + + for tp_rank in range(args.tp_size): + for pp_rank in range(args.pp_size): + for ep_rank in range(args.ep_size): + partial_state = self.convert_full_state_to_tp( + full_state, + args, + tp_rank, + pp_rank, + ep_rank, + partitions, + config, + ) + if args.save_xser: + self.save_partial_xser(args, partial_state, tp_rank, pp_rank, ep_rank) + else: + 
self.save_partial_no_xser(args, partial_state, tp_rank, pp_rank, ep_rank) + + def convert_to_full_state(self, args): + full_state = self.merge_tp_checkpoints(args) + self.save_full(args, full_state) + + # Argument parsing and execution + def get_arg_parser(self): + """Child classes can override this to add new arguments.""" + + parser = argparse.ArgumentParser() + parser.add_argument("--input_dir", type=str, required=True, help="Path to input model/weights") + parser.add_argument("--output_dir", type=str, required=True, help="Path to save converted model/weights") + parser.add_argument("--config", type=str, help="Config.json") + parser.add_argument( + "--model_key", type=str, default="model", help="Key of the model state dict in the checkpoint object" + ) + parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel degree for the model") + parser.add_argument("--pp_size", type=int, default=1, help="Pipeline Parallel degree for the model") + parser.add_argument("--ep_size", type=int, default=1, help="Expert Parallel degree for the model") + parser.add_argument("--virtual_pp_size", type=int, default=1, help="Virtual Pipeline Parallel degree for the model") + parser.add_argument("--n_layers", type=int, default=0, help="Number of Layers") + parser.add_argument("--coalesce_qkv", type=bool, default=False, help="whether to coalesce qkv") + parser.add_argument( + "--kv_size_multiplier", type=int, default=1, help="Factor by which the KV heads were replicated" + ) + parser.add_argument( + "--qkv_linear", type=bool, default=False, help="Factor by which the KV heads were replicated" + ) + parser.add_argument( + "--fuse_qkv", type=bool, default=False, help="Whether to fuse qkv" + ) + parser.add_argument("--load_xser", type=bool, default=False, help="Load from xser saved checkpoints") + parser.add_argument("--save_xser", type=bool, default=False, help="Save with xser") + parser.add_argument( + "--convert_from_xser", action="store_true", help="Convert xser saved checkpoint to normal torch checkpoint" + ) + parser.add_argument( + "--convert_to_xser", action="store_true", help="Convert normal torch checkpoint to xser checkpoint" + ) + parser.add_argument("--convert_from_full_state", action="store_true", help="Convert full model to sharded model") + parser.add_argument("--convert_to_full_state", action="store_true", help="Convert sharded model to full model") + parser.add_argument('--model_style', type=str, choices=['hf', 'megatron'], default='hf', help='The source style.') + + return parser + + def run(self, args): + """Main function used to run checkpoint conversion.""" + + assert sum( + int(getattr(args, flag)) + for flag in ["convert_from_full_state", "convert_to_full_state", "convert_from_xser", "convert_to_xser"] + ) == 1, "Exactly one '--convert_*' flag must be specified" + + if args.convert_from_full_state: + self.convert_from_full_state(args) + elif args.convert_to_full_state: + self.convert_to_full_state(args) + elif args.convert_from_xser: + self.convert_from_xser(args) + elif args.convert_to_xser: + self.convert_to_xser(args) diff --git a/src/neuronx_distributed/trace/__init__.py b/src/neuronx_distributed/trace/__init__.py index d13a085..7f70e96 100644 --- a/src/neuronx_distributed/trace/__init__.py +++ b/src/neuronx_distributed/trace/__init__.py @@ -1 +1,7 @@ +from .model_builder import ModelBuilder +from .spmd import ( + NxDModel, + SPMDBucketModel, + SPMDBucketModelScript +) from .trace import parallel_model_load, parallel_model_save, parallel_model_trace diff --git 
a/src/neuronx_distributed/trace/hlo_utils.py b/src/neuronx_distributed/trace/hlo_utils.py new file mode 100644 index 0000000..f1a6083 --- /dev/null +++ b/src/neuronx_distributed/trace/hlo_utils.py @@ -0,0 +1,283 @@ +from typing import Dict + +from torch_neuronx.proto import metaneff_pb2 +from torch_neuronx.pyhlo import hlo_pb2, xla_data_pb2 +from torch_neuronx.xla_impl.trace import ( + hlo_entry_computation, + get_hlo_computation_by_id, + get_hlo_root_instruction, + HloArtifacts, + XLA_DTYPE_TO_METANEFF_DTYPE, +) + +TRANSPOSABLE_WEIGHT_IDX = "transposable_weight_idx" +REQUIRE_TRANSPOSE_WEIGHT_IDX = "require_transpose_weight_idx" +REQUIRE_TRANSPOSE_CUSTOM_CALL = "require_transpose_custom_call" + + +def read_hlo(hlo_path: str): + """Read a HLOModuleProto from given path""" + hlo = hlo_pb2.HloModuleProto() + with open(hlo_path, "rb") as f: + hlo.ParseFromString(f.read()) + return hlo + + +def add_weight_idx_attr_to_hlo(hlo: hlo_pb2.HloModuleProto, weight_name_to_idx: Dict[str, int]): + """ + Add frontend attributes on weight indices for weights + """ + weight_idx = sorted(weight_name_to_idx.values()) + weight_idx_list_str = ",".join([str(idx) for idx in weight_idx]) + hlo.frontend_attributes.map[TRANSPOSABLE_WEIGHT_IDX] = weight_idx_list_str + return hlo + + +def get_layout_transform_map(hlo_stub: hlo_pb2.HloModuleProto, weight_name_to_idx: Dict[str, int]): + """ + Return a map of weight layout transformation from the HLO stub, if the weight + is transformed in the HLO stub + {"weight_name": hlo_computation_proto} + + This map might not contain all the weights from weight_name_to_idx, because + some of them could already be in the optimal layout, so there won't be + a transformation for them in the hlo_stub. + """ + weight_idx_to_name = {} + for weight_name, idx in weight_name_to_idx.items(): + weight_idx_to_name[idx] = weight_name + + weight_name_to_transform_cpt = {} + entry_cpt = hlo_entry_computation(hlo_stub) + for instr in entry_cpt.instructions: + if TRANSPOSABLE_WEIGHT_IDX in instr.frontend_attributes.map: + priority_weight_idx = int(instr.frontend_attributes.map[TRANSPOSABLE_WEIGHT_IDX]) + # Compiler will always wrap the layout transformation into just one + # custom-call, so getting the first called computation id is enough + cpt = get_hlo_computation_by_id(hlo_stub, instr.called_computation_ids[0]) + weight_name = weight_idx_to_name[priority_weight_idx] + weight_name_to_transform_cpt[weight_name] = cpt + return weight_name_to_transform_cpt + + +def update_computation_id_and_name(src_cpt: hlo_pb2.HloComputationProto, start_id: int, name_prefix: str): + """ + Update the id and name inside a computation. + + It will increase all ids inside the computation by `start_id`, and add a + prefix of `name_prefix` for all var names inside the computation. 
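+
+ For example (illustrative values), calling this with start_id=1000 and
+ name_prefix="wlt_" turns an instruction with id=7 and name "add.1" into
+ id=1007 and name "wlt_add.1"; its operand ids are shifted by the same offset.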
+ """ + # Create a new one to avoid polluting the existing one + cpt = hlo_pb2.HloComputationProto() + cpt.CopyFrom(src_cpt) + + # update the id + cpt.id += start_id + cpt.root_id += start_id + + for instr in cpt.instructions: + instr.id += start_id + if len(instr.operand_ids) == 0: + continue + for idx in range(len(instr.operand_ids)): + instr.operand_ids[idx] += start_id + + # update the name + cpt.name = name_prefix + cpt.name + for idx in range(len(cpt.program_shape.parameter_names)): + cpt.program_shape.parameter_names[idx] = name_prefix + cpt.program_shape.parameter_names[idx] + + for instr in cpt.instructions: + instr.name = name_prefix + instr.name + + return cpt + + +def append_layout_computation_to_hlo( + hlo_artifact: HloArtifacts, + weight_name_to_transform_cpt: Dict[str, hlo_pb2.HloComputationProto], + ): + """ + For each weight mentioned in the hlo_artifact.ho_module, if there is a + computation corresponds to that weight in the map of + `weight_name_to_transform_cpt`, append that computation to the end of + `hlo_artifact.ho_module`. + """ + hlo = hlo_artifact.hlo_module + + weight_idx_to_append = [] + layout_transform_cpt_to_append = [] + for weight_name, weight_idx in hlo_artifact.weight_name_to_idx.items(): + # We skip the weights if it is not in `weight_name_to_transform_cpt` + # because they are already in optimal layout + if weight_name in weight_name_to_transform_cpt: + weight_idx_to_append.append(weight_idx) + cpt = weight_name_to_transform_cpt[weight_name] + # Need to update the id and name for the computation, to avoid + # duplicate name or duplicate id in the whole hlo + cpt = update_computation_id_and_name(cpt, start_id=hlo.id+1, name_prefix="wlt_") + layout_transform_cpt_to_append.append(cpt) + weight_idx_str = ",".join([str(idx) for idx in weight_idx_to_append]) + layout_transform_cpt_str = ",".join([cpt.name for cpt in layout_transform_cpt_to_append]) + + hlo.frontend_attributes.map[REQUIRE_TRANSPOSE_WEIGHT_IDX] = weight_idx_str + hlo.frontend_attributes.map[REQUIRE_TRANSPOSE_CUSTOM_CALL] = layout_transform_cpt_str + + # Compiler will be responsible to insert the layout transformation custom-calls into the HLO compute graph. + hlo.computations.extend(layout_transform_cpt_to_append) # extend() will copy the value + return hlo_artifact + + +def extract_weight_layout_transform_hlo( + hlo_stub: hlo_pb2.HloModuleProto, + weight_name_to_idx: Dict[str, int], + ): + """ + Build a new HLO for weight layout transformation following the suggestion + from the `hlo_stub`. + + After the transformation, the layout of some weights will change, but + the other could stay the same, because they are already in the optimal + layout. + + The resulting HLO will take all the weights in original layout as input, + and output them in optimal layout. The number and order of the inputs and + outputs are the same. The order of inputs (weights) is decided by the its + index from the `weight_name_to_idx`, and it is in ascending order. + + This is because during the transformation, we will provide all the weights + as input and expect to get all the weights in the same order from the output. + """ + wlt_hlo = hlo_pb2.HloModuleProto() # weight layout transformation HLO + wlt_hlo.CopyFrom(hlo_stub) + + entry_cpt = hlo_entry_computation(wlt_hlo) + weight_idx_to_info = {} + + # Step 1. 
Update the output of the root instruction in the entry computation + # Find the weights whose layout will change, as part of the output + for instr in entry_cpt.instructions: + is_changed_weight = TRANSPOSABLE_WEIGHT_IDX in instr.frontend_attributes.map + if is_changed_weight: + idx = int(instr.frontend_attributes.map[TRANSPOSABLE_WEIGHT_IDX]) + weight_idx_to_info[idx] = (instr.id, instr.shape) + + # Find the weights whose layout won't change, as part of the output + all_weight_idx = set(weight_name_to_idx.values()) + changed_weight_idx = set(weight_idx_to_info.keys()) + unchanged_weight_idx = list(all_weight_idx - changed_weight_idx) + for instr in entry_cpt.instructions: + is_unchanged_weight = instr.opcode == "parameter" and instr.parameter_number in unchanged_weight_idx + if is_unchanged_weight: + idx = instr.parameter_number + weight_idx_to_info[idx] = (instr.id, instr.shape) + + weight_idx_list = sorted(weight_idx_to_info.keys()) + var_id_list = [weight_idx_to_info[w_idx][0] for w_idx in weight_idx_list] + var_shape_list = [weight_idx_to_info[w_idx][1] for w_idx in weight_idx_list] + + # Update the root instrution to return all the weights + root_instr = get_hlo_root_instruction(entry_cpt) + root_instr.name = "last" + root_instr.opcode = "tuple" + root_instr.ClearField("shape") + root_instr.shape.element_type = xla_data_pb2.PrimitiveType.TUPLE + root_instr.shape.tuple_shapes.extend(var_shape_list) + root_instr.ClearField("operand_ids") + root_instr.operand_ids.extend(var_id_list) + + output_shape = root_instr.shape + + # Clear irrelevant fields + # TODO: create a new instr instead of updating the old one + root_instr.ClearField("custom_call_target") + root_instr.ClearField("backend_config") + root_instr.ClearField("constrain_layout") + root_instr.ClearField("operand_shapes_with_layout") + root_instr.ClearField("frontend_attributes") + root_instr.ClearField("custom_call_api_version") + root_instr.ClearField("precision_config") + root_instr.ClearField("feature_group_count") + root_instr.ClearField("batch_group_count") + root_instr.ClearField("statistics_viz") + + # Step 2: Update output shape of the entry computation + program_shape = entry_cpt.program_shape + program_shape.result.CopyFrom(output_shape) + + # Step 3: Clean up instructions inside the entry computation that are not + # for weight layout transformation + reduced_instrs = [] + for instr in entry_cpt.instructions: + is_not_weight = instr.opcode == "parameter" and instr.parameter_number not in all_weight_idx + if is_not_weight: + continue + reduced_instrs.append(instr) + entry_cpt.ClearField("instructions") + entry_cpt.instructions.extend(reduced_instrs) + + # Step 4: Clean up inputs that are not used in the entry computation + program_shape = entry_cpt.program_shape + reduced_param_names = [] + reduced_param_shapes = [] + input_id_mapping = {} + input_id_in_updated_hlo = 0 + for id, (param_name, param_shape) in enumerate(zip(program_shape.parameter_names, program_shape.parameters)): + if id not in weight_idx_list: + continue + reduced_param_names.append(param_name) + reduced_param_shapes.append(param_shape) + input_id_mapping[id] = input_id_in_updated_hlo + input_id_in_updated_hlo += 1 + + program_shape.ClearField("parameter_names") + program_shape.parameter_names.extend(reduced_param_names) + program_shape.ClearField("parameters") + program_shape.parameters.extend(reduced_param_shapes) + + for instr in entry_cpt.instructions: + if instr.opcode == "parameter": + instr.parameter_number = 
input_id_mapping[instr.parameter_number] + + # Step 5: Update the input and output of the HLO module + host_program_shape = wlt_hlo.host_program_shape + host_program_shape.result.CopyFrom(output_shape) + host_program_shape.ClearField("parameters") + host_program_shape.parameters.extend(reduced_param_shapes) + host_program_shape.ClearField("parameter_names") + host_program_shape.parameter_names.extend(reduced_param_names) + + return wlt_hlo + + +def prepare_metaneff_for_wlt_hlo( + wlt_hlo: hlo_pb2.HloModuleProto, + weight_name_to_idx: Dict[str, int], + ): + """ + Generate a metaneff for weight layout transformation HLO. + """ + metaneff = metaneff_pb2.MetaNeff() + weight_name_sorted_by_idx = sorted([(idx, name) for name, idx in weight_name_to_idx.items()]) + weight_name_sorted_by_idx = [name.replace("->", ".") for (idx, name) in weight_name_sorted_by_idx] + + entry_cpt = hlo_entry_computation(wlt_hlo) + # Prepare meta for input_tensors + for index, param_meta in enumerate(entry_cpt.program_shape.parameters): + input_tensor = metaneff.input_tensors.add() + # Needs to be `input#` to avoid a `ddrs_create_lookup_key` error + input_tensor.name = f"input{index}".encode("utf8") + input_tensor.shape[:] = list(param_meta.dimensions) + input_tensor.data_type = XLA_DTYPE_TO_METANEFF_DTYPE[param_meta.element_type] + input_tensor.type = metaneff_pb2.MetaTensor.Type.INPUT_WEIGHT + input_tensor.checkpoint_key = weight_name_sorted_by_idx[index].encode("utf8") + + # Prepare meta for output_tensors + for index, output_meta in enumerate(entry_cpt.program_shape.result.tuple_shapes): + output_tensor = metaneff.output_tensors.add() + output_tensor.name = f"output{index}".encode("utf8") + output_tensor.shape[:] = list(output_meta.dimensions) + output_tensor.data_type = XLA_DTYPE_TO_METANEFF_DTYPE[output_meta.element_type] + output_tensor.checkpoint_key = weight_name_sorted_by_idx[index].encode("utf8") + + return metaneff diff --git a/src/neuronx_distributed/trace/model_builder.py b/src/neuronx_distributed/trace/model_builder.py new file mode 100644 index 0000000..434f5de --- /dev/null +++ b/src/neuronx_distributed/trace/model_builder.py @@ -0,0 +1,586 @@ +import concurrent.futures +import multiprocessing +import os +import shutil +import time +import logging +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch_neuronx +import torch_neuronx.xla_impl +import torch_neuronx.xla_impl.trace +import torch_xla.distributed.xla_multiprocessing as xmp +from torch_neuronx import BucketModelConfig +from torch_neuronx.proto import metaneff_pb2 +from torch_neuronx.xla_impl.trace import get_torch_dtype, HloArtifacts + +from safetensors.torch import save_file, load_file + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.trace.spmd import ( + NxDModel, + NxDModelExecutor, + SPMDBucketModel, + SPMDBucketModelScript, + default_bucket_kernel, + StateInitializer) +from neuronx_distributed.trace.trace import _mock_parallel_state, get_sharded_checkpoint +from neuronx_distributed.utils.model_utils import init_on_device +import neuronx_distributed.trace.hlo_utils as hlo_utils + +ModelInputType = List[Union[Tuple[Union[torch.Tensor, List[torch.Tensor]]], torch.Tensor]] +logger = logging.getLogger("Neuron") + + +# TODO write a generic class which can accept a function as well +class BaseModelInstance: + def __init__(self, module_cls, input_output_aliases): + self.module_cls = module_cls + self.module = None + self.input_output_aliases = [input_output_aliases] + + def 
load_module(self): + self.module = self.module_cls() + + def get(self, bucket_rank, **kwargs): + return self.module, self.input_output_aliases[0] + + +class ModelContainer: + def __init__(self, model_instance, example_inputs, compiler_args, bucket_config, priority_model_idx): + self.model_instance: BaseModelInstance = model_instance + self.example_inputs = example_inputs + self.compiler_args = compiler_args + self.bucket_config: BucketModelConfig = bucket_config + self.priority_model_idx = priority_model_idx + self.hlo_artifact_collection = None + self.neff_artifact_collection = None + + # these are determined later through the trace function + self.num_params = None + self.num_user_inputs = None # accounts for excluded inputs + self.num_states = None + self.num_weights = None + +class JITWrapper(torch.nn.Module): + """ + Makes a python object like Flattener and Packer JIT traceable. + """ + def __init__(self, func, is_flattener): + super().__init__() + self.func = func + self.is_flattener = is_flattener + + def forward(self, inputs: List[torch.Tensor]): + # flattener expects a tuple while packer expects a list + if (self.is_flattener): + return self.func(tuple(inputs)) + else: + return self.func(inputs) + +class ModelBuilder: + def __init__( + self, + router, + tp_degree, + checkpoint_loader, + compiler_workdir=None, + master_proc_env_vars=None, + ): + if not torch_neuronx.__version__.startswith("2"): + raise AssertionError( + f"ModelBuilder requires torch-neuronx>=2.* but found torch-neuronx=={torch_neuronx.__version__}." + ) + + self.router = router + self.tp_degree = tp_degree + self.checkpoint_loader = checkpoint_loader + self.compiler_workdir = compiler_workdir if compiler_workdir else "/tmp/nxd_model/" + + self.model_collection: Dict[str, ModelContainer] = {} + self.master_proc_env_vars: Optional[Dict[str, str]] = master_proc_env_vars + + def add( + self, + key: str, + model_instance: BaseModelInstance, + example_inputs: ModelInputType, + compiler_args: Union[str, List[str]] = None, + bucket_config: BucketModelConfig = None, + priority_model_idx: int = None, + ) -> None: + """ + Adds a model to the model collection to be traced. + """ + if compiler_args is None: + compiler_args = "--enable-saturate-infinity --auto-cast=none --model-type=transformer -O1" + + # This does not validate if the HLOs are same across all ranks. 
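+ # If cross-rank validation is ever needed, the commented-out call below sketches
+ # how _validate_traceable() could be invoked as a pre-trace sanity check.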
+ # _validate_traceable(model_instance.module, self.tp_degree, force_custom_init_on_device=True) + + if bucket_config: + bucket_config.store_example_inputs(example_inputs) + + self.model_collection[key] = ModelContainer( + model_instance, example_inputs, compiler_args, bucket_config, priority_model_idx + ) + return self + + def trace( + self, + tp_degree=None, + initialize_model_weights=True + ): + if tp_degree is None: + tp_degree = self.tp_degree + else: + self.tp_degree = tp_degree + + ctx = multiprocessing.get_context("spawn") + manager = ctx.Manager() + mp_q = manager.Queue() + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "2022" + os.environ["NEURONCORE_NUM_DEVICES"] = str(tp_degree) # for pjrt + os.environ["WORLD_SIZE"] = str(tp_degree) + prev_sharing_strategy = torch.multiprocessing.get_sharing_strategy() + torch.multiprocessing.set_sharing_strategy("file_system") + + if self.master_proc_env_vars: + for env_var, val in self.master_proc_env_vars.items(): + os.environ[env_var] = val + + # Clean compiler working dir + if os.path.exists(self.compiler_workdir): + shutil.rmtree(self.compiler_workdir) + + num_hlos = 0 + logger.info(f"Generating HLOs for the following models: {list(self.model_collection.keys())}") + for key in self.model_collection: + model_artifacts = self.model_collection[key] + bucket_degree = 1 if not model_artifacts.bucket_config else model_artifacts.bucket_config.bucket_degree + num_hlos += bucket_degree + logger.info(f"Generating {bucket_degree} hlos for key: {key}") + xmp.spawn( + self._generate_hlo, + args=( + key, + mp_q, + ), + start_method="spawn", + nprocs=self.tp_degree, + ) + + hlo_artifact_collection = mp_q.get() + model_artifacts.hlo_artifact_collection = hlo_artifact_collection + hm = hlo_artifact_collection[0].hlo_module + id_to_computation = {cpt.id: cpt for cpt in hm.computations} + entry_computation = id_to_computation[hm.entry_computation_id] + model_artifacts.num_params = len([i for i in entry_computation.instructions if i.opcode == "parameter"]) + + self._mark_weight_in_priority_hlo() + + def submit_compilation_job(key, bucket_rank, args): + return key, bucket_rank, torch_neuronx.xla_impl.trace.generate_neff(*args) + + logger.info("Started compilation for all HLOs") + for key, model_artifacts in self.model_collection.items(): + # init placeholder for all hlo + model_artifacts.neff_artifact_collection = [None] * len(model_artifacts.hlo_artifact_collection) + + if model_artifacts.priority_model_idx is not None: + bucket_rank = model_artifacts.priority_model_idx + hlo_artifacts = model_artifacts.hlo_artifact_collection[bucket_rank] + + neff_artifacts = torch_neuronx.xla_impl.trace.generate_neff( + hlo_artifacts, + os.path.join(self.compiler_workdir, key, f"_tp0_bk{bucket_rank}"), + # TODO: improve these compiler flags if possiable + model_artifacts.compiler_args + " --enable-internal-neff-wrapper", + False, + ) + # The neff is still valid for this SPMD model + self.model_collection[key].neff_artifact_collection[bucket_rank] = neff_artifacts + logger.info("Done compilation for the priority HLO") + + self._add_layout_optimization_to_remaining_hlo() + + executor = concurrent.futures.ThreadPoolExecutor() + jobs = [] + for key, model_artifacts in self.model_collection.items(): + for bucket_rank, hlo_artifacts in enumerate(model_artifacts.hlo_artifact_collection): + if bucket_rank == model_artifacts.priority_model_idx: + # no need to compile the priority model again + continue + jobs.append( + executor.submit( + 
submit_compilation_job, + key, + bucket_rank, + ( + hlo_artifacts, + os.path.join(self.compiler_workdir, key, f"_tp0_bk{bucket_rank}"), + model_artifacts.compiler_args, + False, + ), + ) + ) + + for future in concurrent.futures.as_completed(jobs): + key, bucket_rank, neff_artifacts = future.result() + self.model_collection[key].neff_artifact_collection[bucket_rank] = neff_artifacts + + # Save metaneff + for key, model_artifacts in self.model_collection.items(): + for bucket_rank, hlo_artifacts in enumerate(model_artifacts.hlo_artifact_collection): + path = os.path.join(self.compiler_workdir, key, f"_tp0_bk{bucket_rank}", "metaneff.pb") + with open(path, 'wb') as f: + f.write(hlo_artifacts.metaneff) + + logger.info("Finished Compilation for all HLOs") + + logger.info("Finished Compilation for all HLOs") + + nxd_model_executor = self.build_nxd_model() + + if (initialize_model_weights): + self.shard_checkpoint(self.compiler_workdir) + + weights = [] + for rank in range(self.tp_degree): + ckpt = load_file(os.path.join(self.compiler_workdir, f"tp{rank}_sharded_checkpoint.safetensors")) + weights.append(ckpt) + + nxd_model_executor.nxd_model.initialize(weights) + logger.info("NxD Model Initialized") + + torch.multiprocessing.set_sharing_strategy(prev_sharing_strategy) + + return nxd_model_executor + + def shard_checkpoint(self, serialize_path): + if not os.path.exists(serialize_path): + os.makedirs(serialize_path) + + source_model_key = list(self.model_collection.keys())[0] + logger.info("Sharding Weights") + for rank in range(self.tp_degree): + self.shard_weights(rank, self.model_collection[source_model_key], serialize_path) + logger.info("Done Sharding weights") + + + def _generate_hlo( + self, + rank, + key, + mp_q, + ): + os.environ["RANK"] = str(rank) + torch.distributed.init_process_group("xla", init_method="pjrt://") + parallel_state.initialize_model_parallel(tensor_model_parallel_size=self.tp_degree) + + if rank == 0: + model_input_container = self.model_collection[key] + logger.info(f"Started loading module {key}") + start_time = time.time() + model_input_container.model_instance.load_module() + logger.info(f"Finished loading module {key} in {time.time() - start_time} seconds") + example_input_collection = model_input_container.example_inputs + bucket_config = model_input_container.bucket_config + + bucket_degree = 1 + if bucket_config is not None: + bucket_degree = bucket_config.bucket_degree + + hlo_artifact_collection = [] + for bucket_rank in range(bucket_degree): + example_inputs = example_input_collection[bucket_rank] + func_kwargs = ( + {} if bucket_config is None else bucket_config.get_func_kwargs_for_bucket_rank(bucket_rank) + ) + if "bucket_rank" in func_kwargs: + func_kwargs.pop("bucket_rank") # to avoid multiple definition of bucket_rank + func, input_output_aliases = model_input_container.model_instance.get(bucket_rank, **func_kwargs) + + hlo_artifacts = torch_neuronx.xla_impl.trace.generate_hlo( + func, example_inputs, input_output_aliases, False, False, False + ) + hlo_artifacts.metaneff = hlo_artifacts.metaneff.SerializeToString() + hlo_artifact_collection.append(hlo_artifacts) + + mp_q.put(hlo_artifact_collection) + + def shard_weights(self, rank, model_container: ModelContainer, serialize_path: str): + checkpoint = self.checkpoint_loader() + _mock_parallel_state(self.tp_degree, rank) + with init_on_device(torch.device("meta"), force_custom_init_on_device=True): + model_container.model_instance.load_module() + func_kwargs = ( + {} + if model_container.bucket_config is 
None + else model_container.bucket_config.get_func_kwargs_for_bucket_rank(0) + ) + if "bucket_rank" in func_kwargs: + func_kwargs.pop("bucket_rank") # to avoid multiple definition of bucket_rank + model, io_aliases = model_container.model_instance.get(0, **func_kwargs) + + get_sharded_checkpoint(checkpoint, model, rank, self.tp_degree) + + save_file(checkpoint, os.path.join(serialize_path, f"tp{rank}_sharded_checkpoint.safetensors")) + + def build_state_initializer(self): + shapes = {} + dtypes = {} + + # Take any metaneff + source_model_key = list(self.model_collection.keys())[0] + metaneff = metaneff_pb2.MetaNeff() + metaneff_str = self.model_collection[source_model_key].hlo_artifact_collection[0].metaneff + metaneff.ParseFromString(metaneff_str) + for tensor in metaneff.input_tensors: + if tensor.type is metaneff_pb2.MetaTensor.Type.INPUT_STATE: + # proto keys are bytes not strings, and casting as a string causes it to be "b'key'" + checkpoint_key = str(tensor.checkpoint_key).replace("b'","").replace("'","") + shapes[checkpoint_key] = list(tensor.shape) + dtypes[checkpoint_key] = get_torch_dtype(tensor.data_type) + if len(shapes): + return torch.jit.script(StateInitializer(shapes=shapes, dtypes=dtypes, tp_degree=self.tp_degree)) + else: + return None + + def build_flattener_map(self): + flattener_map = [] + for key, model_container in self.model_collection.items(): + flattener = JITWrapper(func=model_container.hlo_artifact_collection[0].flattener,is_flattener=True) + example_inputs = model_container.example_inputs + flattener_script = torch.jit.trace(flattener, ([*example_inputs[0]],), strict=False) + flattener_map.append((key, flattener_script)) + return torch.nn.ModuleDict(flattener_map) + + + def build_packer(self, packer): + # Take any metaneff + source_model_key = list(self.model_collection.keys())[0] + metaneff = metaneff_pb2.MetaNeff() + metaneff_str = self.model_collection[source_model_key].hlo_artifact_collection[0].metaneff + metaneff.ParseFromString(metaneff_str) + + # create example outputs from metaneff + example_outputs = [] + for i,meta_tensor in enumerate(metaneff.output_tensors): + if i not in metaneff.output_aliases_to: + example_outputs.append(torch.zeros(list(meta_tensor.shape),dtype=get_torch_dtype(meta_tensor.data_type))) + + # return jit traced packer + jit_wrapped_packer = JITWrapper(packer,False) + return torch.jit.trace(jit_wrapped_packer, (example_outputs,), strict=False) + + def build_nxd_model(self): + model_map_input = [] + for key, model_container in self.model_collection.items(): + + models = [ + (self._read_neff_from_path(neff_artifacts.neff_filename), hlo_artifacts.metaneff) + for hlo_artifacts, neff_artifacts in zip( + model_container.hlo_artifact_collection, model_container.neff_artifact_collection + ) + ] + + buckets = [torch.classes.neuron.SPMDModel(neff, metaneff, self.tp_degree) for neff, metaneff in models] + + spmd_bucket_model_executor = SPMDBucketModelScript(compiled_models=buckets) + with torch_neuronx.contexts.disable_nrt_load(): + spmd_bucket_model_executor = torch.jit.script(spmd_bucket_model_executor) + if model_container.bucket_config is None: + bucket_kernel = torch.jit.script(default_bucket_kernel) + bucket_kernel_constant_args = () + else: + bucket_kernel = model_container.bucket_config.bucket_kernel() + bucket_kernel_constant_args = model_container.bucket_config.bucket_kernel_constant_args + spmd_bucket_model = SPMDBucketModel( + bucket_kernel, + bucket_kernel_constant_args, + spmd_bucket_model_executor + ) + with 
torch_neuronx.contexts.disable_nrt_load(): + spmd_bucket_model = torch.jit.script(spmd_bucket_model) + model_map_input.append((key, spmd_bucket_model)) + + state_initializer = self.build_state_initializer() + + model_map = torch.nn.ModuleDict(model_map_input) + + flattener_map = self.build_flattener_map() + + input_shape_map = {} + # use to jit trace NxDModelExecutor + example_inputs = None + for key, model_container in self.model_collection.items(): + # example_inputs is of type List[Tuple[Tensor, Tensor, ...]] + example_inputs = model_container.example_inputs + for example_input in example_inputs: + # torch.Size type is not a concept in a jit model, it's just List[int] + input_shape_map[str([list(tensor.shape) for tensor in example_input])] = key + + packer = next(iter(self.model_collection.values())).hlo_artifact_collection[0].packer + traced_packer = self.build_packer(packer) + + # Get weight layout transformation model + wlt_model = self._prepare_weight_layout_transform_model() + + with torch_neuronx.contexts.disable_nrt_load(): + nxd_model = NxDModel( + models=model_map, + tp_degree=self.tp_degree, + flattener_map=flattener_map, + input_shape_map=input_shape_map, + packer=traced_packer, + state_initializer=state_initializer, + weight_loader=wlt_model, + ) + with torch_neuronx.contexts.disable_nrt_load(): + nxd_model = torch.jit.script(nxd_model) + + # mock model as initialized so jit trace doesn't fail + nxd_model.mock_initialization(True) + nxd_model_executor = torch.jit.trace(NxDModelExecutor(nxd_model),example_inputs[0],strict=False) + nxd_model_executor.nxd_model.mock_initialization(False) + + return nxd_model_executor + + def _read_neff_from_path(self, neff_path: str): + with open(neff_path, "rb") as f: + return f.read() + + def _get_priority_hlo_artifact(self) -> HloArtifacts: + for model_artifacts in self.model_collection.values(): + if model_artifacts.priority_model_idx is not None: + return model_artifacts.hlo_artifact_collection[model_artifacts.priority_model_idx] + return None + + def _should_optimize_layout(self): + return self._get_priority_hlo_artifact() is not None + + def _mark_weight_in_priority_hlo(self): + """ + Mark weights in the priority HLO, so compiler will suggest optimal + layout for the weights. + """ + if not self._should_optimize_layout(): + logger.info("Can't find a priority model, skip marking weights") + return + priority_hlo_artifacts = self._get_priority_hlo_artifact() + + hlo_utils.add_weight_idx_attr_to_hlo( + hlo=priority_hlo_artifacts.hlo_module, + weight_name_to_idx=priority_hlo_artifacts.weight_name_to_idx, + ) + + def _get_hlo_stub(self): + """ + Read the HLO stub if it is there, otherwise return None + """ + neff_artifacts = None + for model_artifacts in self.model_collection.values(): + if model_artifacts.priority_model_idx is not None: + neff_artifacts = model_artifacts.neff_artifact_collection[model_artifacts.priority_model_idx] + assert neff_artifacts.neff_filename is not None, "Can't find the path for the NEFF from the priority model" + hlo_stub_filepath = neff_artifacts.neff_filename.replace("graph.neff", "wrapped_neff.hlo") + + if os.path.exists(hlo_stub_filepath): + return hlo_utils.read_hlo(hlo_stub_filepath) + else: + return None + + def _add_layout_optimization_to_remaining_hlo(self): + """ + Apply the layout transformation suggestion from the priority HLO to + other HLOs, so they all can benefit. + + This is a no-op if there is no suggestion on weight layout. 
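+
+ Roughly: read the compiler-produced wrapped_neff.hlo stub for the priority
+ model, build the per-weight transform map with
+ hlo_utils.get_layout_transform_map(), and append those computations to every
+ non-priority HLO via hlo_utils.append_layout_computation_to_hlo().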
+ """ + if not self._should_optimize_layout(): + logger.info("Can't find a priority model, skip optimizing weight layout for other HLOs") + return + + hlo_stub = self._get_hlo_stub() + if hlo_stub is None: + logger.info("No changes on weight layout, skip updating weight layout for other HLOs") + return + + priority_hlo_artifacts = self._get_priority_hlo_artifact() + weight_name_to_transform_cpt = hlo_utils.get_layout_transform_map( + hlo_stub=hlo_stub, + weight_name_to_idx=priority_hlo_artifacts.weight_name_to_idx + ) + + for model_artifacts in self.model_collection.values(): + for bucket_rank, hlo_artifacts in enumerate(model_artifacts.hlo_artifact_collection): + if bucket_rank == model_artifacts.priority_model_idx: + continue + hlo_utils.append_layout_computation_to_hlo(hlo_artifacts, weight_name_to_transform_cpt) + logger.info("Done optimizing weight layout for all HLOs") + + def _prepare_weight_layout_transform_model(self): + """ + Generate a NEFF for weight layout transformation, which will be run on + device before actual inference. + + This will return None if there is no changes on weight layout. + """ + if not self._should_optimize_layout(): + logger.info("Can't find a priority model, falling back to the existing weight layout") + return + + hlo_stub = self._get_hlo_stub() + if hlo_stub is None: + logger.info("No changes on weight layout, falling back to the existing weight layout") + return + + # Clear existing dir + layout_dir = os.path.join(self.compiler_workdir, "layout_opt") + if os.path.exists(layout_dir): + shutil.rmtree(layout_dir) + os.makedirs(layout_dir) + + # Prepare HLO + weight_name_to_idx = self._get_priority_hlo_artifact().weight_name_to_idx + wlt_hlo = hlo_utils.extract_weight_layout_transform_hlo( + hlo_stub=hlo_stub, + weight_name_to_idx=weight_name_to_idx, + ) + + metaneff = hlo_utils.prepare_metaneff_for_wlt_hlo( + wlt_hlo=wlt_hlo, + weight_name_to_idx=weight_name_to_idx, + ) + metaneff_str = metaneff.SerializeToString() + metaneff_path = os.path.join(layout_dir, "metaneff") + with open(metaneff_path, "wb") as f: + f.write(metaneff_str) + + wlt_hlo_artifact = HloArtifacts( + hlo_module=wlt_hlo, + flattener=None, + packer=None, + metaneff=metaneff, + weights=None, + constant_parameter_tensors=None, + weight_name_to_idx=weight_name_to_idx, + ) + + # Generate NEFF + wlt_neff_artifact = torch_neuronx.xla_impl.trace.generate_neff( + wlt_hlo_artifact, + compiler_workdir=layout_dir, + compiler_args="--model-type=transformer -O1", + inline_weights_to_neff=False, + ) + wlt_neff = self._read_neff_from_path(wlt_neff_artifact.neff_filename) + + # Build the model on runtime + wlt_model = torch.classes.neuron.LayoutTransformation(wlt_neff, metaneff_str, self.tp_degree) + logger.info("Done preparing weight layout transformation") + return wlt_model diff --git a/src/neuronx_distributed/trace/spmd.py b/src/neuronx_distributed/trace/spmd.py new file mode 100644 index 0000000..e5952b8 --- /dev/null +++ b/src/neuronx_distributed/trace/spmd.py @@ -0,0 +1,187 @@ +from typing import List, Dict + +import torch +from torch_neuronx.xla_impl import structure + +def default_bucket_kernel(inputs: List[torch.Tensor]): + return inputs, torch.tensor(0).to(torch.int) + +class SPMDBucketModelScript(torch.nn.Module): + """ + BucketModelScript mostly remains the same. Just that he data needs to be passed in. 
+ """ + + def __init__(self, compiled_models: List[torch.classes.neuron.SPMDModel]): + super().__init__() + self.models = compiled_models + + def forward( + self, + inputs: List[torch.Tensor], + bucket_idx_tensor: torch.Tensor, + ): + bucket_idx = torch.ops.aten.Int(bucket_idx_tensor) + initialized = self.models[bucket_idx].is_initialized() + + if initialized: + output = self.models[bucket_idx].forward(inputs) + return output + else: + raise ValueError("This model is not initialized, please call traced_model.nxd_model.initialize(sharded_checkpoint) or traced_model.nxd_model.initialize_with_saved_weights()") + +class SPMDBucketModel(torch.nn.Module): + """ + This classes implements bucketing with the SPMDModel runtime class. + The major difference from torch_neuronx's BucketModel class is that the + weights are not registered in this class but passed in as parameter in the forward. + """ + + def __init__( + self, + bucket_kernel, + bucket_kernel_constant_args, + bucket_model_executor: SPMDBucketModelScript, + ): + super().__init__() + # bucket kernel & preprocessors goes here + # weights and states are passed in + self.bucket_kernel = bucket_kernel + self.bucket_kernel_constant_args = bucket_kernel_constant_args + + self.bucket_model_executor = bucket_model_executor + + def forward( + self, + inputs: List[torch.Tensor], + ): + preprocessed_inputs, bucket_idx_tensor = self.bucket_kernel( + inputs, *self.bucket_kernel_constant_args + ) + + return self.bucket_model_executor(preprocessed_inputs, bucket_idx_tensor) + +class StateInitializer(torch.nn.Module): + # torchscript cannot script dict of with values of different types + # so we store shapes and dtypes in separate dicts + def __init__(self, shapes, dtypes, tp_degree): + super().__init__() + self.shapes = shapes + self.dtypes = dtypes + self.tp_degree = tp_degree + + def forward(self): + results : List[Dict[str, torch.Tensor]]= [] + for rank in range(0, self.tp_degree): + states = {} + for key in self.shapes.keys(): + states[key] = torch.zeros(self.shapes[key], dtype=self.dtypes[key], device=f"privateuseone:{rank}") + results.append(states) + + return results + +class NxDModel(torch.nn.Module): + """ + NxDModel runs houses multiple SPMD bucket models that share the same weights. + Note: The weights must be sharded the same way across all SPMD bucket models. 
+ """ + + def __init__( + self, + models: torch.nn.ModuleDict, + tp_degree: int, + flattener_map: torch.nn.ModuleDict, + input_shape_map: Dict[str, str], + packer: structure.Packer, + state_initializer: StateInitializer, + weight_loader: torch.classes.neuron.LayoutTransformation + ): + super().__init__() + self.models = models + self.flattener_map = flattener_map + self.packer = packer + self.input_shape_map =input_shape_map + self.state_initializer = state_initializer + self.tp_degree = tp_degree + self.weight_loader = weight_loader + + # default values only used for scripting to prevent it complaining about empty lists + self.weights: List[Dict[str, torch.Tensor]] = [{'__neuronprivatetensor__':torch.tensor(0)}] + self.state: List[Dict[str, torch.Tensor]] = [{'__neuronprivatetensor__':torch.tensor(0)}] + + def initialize_spmd_models( + self, + states: List[Dict[str, torch.Tensor]], + weights: List[Dict[str, torch.Tensor]] + ): + for bucket_model in self.models.values(): + for model in bucket_model.bucket_model_executor.models: + model.initialize(states, weights) + + @torch.jit.export + def initialize(self, checkpoint: List[Dict[str, torch.Tensor]]): + if self.weight_loader is not None: + self.weights = self.weight_loader.forward(checkpoint, False) + else: + self.weights = torch.ops.neuron._parallel_load(checkpoint) + + if (self.state_initializer is not None): + self.state = self.state_initializer() + self.initialize_spmd_models(self.state,self.weights) + else: + self.initialize_spmd_models([],self.weights) + + @torch.jit.export + def initialize_with_saved_weights(self): + if (self.state_initializer is not None): + self.state = self.state_initializer() + self.initialize_spmd_models(self.state,self.weights) + else: + self.initialize_spmd_models([],self.weights) + + @torch.jit.unused + def mock_initialization(self, mock: bool): + """ + This function is only used once for jit tracing, + and won't be serialized on jit.save + """ + for script_model in self.models.modules(): + if script_model.original_name == SPMDBucketModel.__name__: + for model in script_model.bucket_model_executor.models: + model.set_mock_initialized(mock) + + def router(self, inputs: List[torch.Tensor]) -> str: + actual_shape = str([tensor.shape for tensor in inputs]) + return self.input_shape_map[actual_shape] + + def forward(self, inputs: List[torch.Tensor]): + """ """ + model_name = self.router(inputs) + + # Initialize empty tensor to ensure jit.script gets the write type + flattened_inputs : List[torch.Tensor] = [torch.zeros(0)] + + # torch.jit.script does not allow indexing of ModuleDict + # so we work around by looping and conditionally executing it + for name, flattener in self.flattener_map.items(): + if name == model_name: + flattened_inputs = flattener(inputs) + + result: List[torch.Tensor] = [torch.zeros(0)] + for name, model in self.models.items(): + if name == model_name: + result = model.forward(flattened_inputs) + + result = self.packer(result) + return result + +class NxDModelExecutor(torch.nn.Module): + """ + Wraps over jit scripted NxDModel class + so traced model can be executed like traced_model(*inputs) + """ + def __init__(self,nxd_model): + super().__init__() + self.nxd_model = nxd_model + + def forward(self, *inputs): + return self.nxd_model(list(inputs)) diff --git a/src/neuronx_distributed/trace/trace.py b/src/neuronx_distributed/trace/trace.py index fba5d3d..312378d 100644 --- a/src/neuronx_distributed/trace/trace.py +++ b/src/neuronx_distributed/trace/trace.py @@ -6,7 +6,9 @@ import pathlib import 
shutil from collections import defaultdict -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, Tuple +from typing import cast + import torch import torch_neuronx @@ -95,7 +97,10 @@ def forward(self, *tensors): if not self.load: self._load() results = [] - futures = [self.executor.submit(model, *tensors) for model in self.models] + if self.executor is not None: + futures = [self.executor.submit(model, *tensors) for model in self.models] + else: + raise RuntimeError("executor is None although it has to be properly initialized") for future in concurrent.futures.as_completed(futures): results.append(future.result()) # Here we are making the assumption that we are operating in SPMD mode. @@ -169,7 +174,7 @@ def _trace( inline_weights_to_neff: bool = True, bucket_config: Optional[BucketModelConfig] = None, tp_degree: int = 1, - max_parallel_compilations: int = None, + max_parallel_compilations: Optional[int] = None, ) -> None: os.environ["RANK"] = str(rank) if requires_init_pg_override(): @@ -189,7 +194,6 @@ def _trace( example_inputs = [example_inputs] for tp_rank in range(tp_degree): - artifacts_collection = [] if rank == tp_rank: for bucket_rank in range(bucket_degree): # Set flag to stop parallel_layes.load() from waiting on all @@ -234,11 +238,11 @@ def parallel_model_trace( inline_weights_to_neff: bool = True, bucket_config: Optional[BucketModelConfig] = None, tp_degree: int = 1, - max_parallel_compilations: int = None, + max_parallel_compilations: Optional[int] = None, spmd_mode: bool = False, checkpoint_loader_callable: Optional[Callable] = None, force_custom_init_on_device: bool = False, - serialization_path: str = None, + serialization_path: Optional[str] = None, ) -> ParallelModel: """ Trace a distributed module/function to produce a compiled Neuron ScriptModule. @@ -267,13 +271,13 @@ def parallel_model_trace( the rank specific weights to generate the other ranks. checkpoint_loader_callable: A callable method to load the model's checkpoint. When using spmd_mode, checkpoint_loader_callable is a required argument. - force_custom_init_on_device: Bool to indidcate whether to force use custom init_on_device functionality + force_custom_init_on_device: Bool to indicate whether to force use custom init_on_device functionality NOTE: If you are trying to use it for Quantized api, make sure this bool is set to True serialization_path: A path to store the serialized traced model if provided. Currently only works for SPMD mode. Returns: A wrapper Module which wraps individual HLO computation which is a - fused neuron::foward operation. + fused neuron::forward operation. """ if bucket_config is not None and inline_weights_to_neff: @@ -315,7 +319,7 @@ def parallel_model_trace( models = _spmd_trace( func, example_inputs, - checkpoint_loader_callable, + cast(Callable, checkpoint_loader_callable), tp_degree, states, compiler_workdir, @@ -326,7 +330,7 @@ def parallel_model_trace( ) else: logging.warn( - f"Using non SPMD mode. Set spmd_mode=True if the worlkload is SPMD for a faster trace. Tracing in non SPMD mode for large models can run into OOM errors as we compile all ranks" + "Using non SPMD mode. Set spmd_mode=True if the worlkload is SPMD for a faster trace. 
Tracing in non SPMD mode for large models can run into OOM errors as we compile all ranks" ) xmp.spawn( _trace, @@ -340,7 +344,7 @@ def parallel_model_trace( inline_weights_to_neff, bucket_config, tp_degree, - max_parallel_compilations if max_parallel_compilations != None else tp_degree, + max_parallel_compilations if max_parallel_compilations is not None else tp_degree, ), start_method="spawn", nprocs=tp_degree, @@ -367,13 +371,13 @@ def parallel_model_save(model: ParallelModel, save_dir: str) -> None: def find_unique_dtypes(model): state_dict = model.state_dict() - dtype_map = defaultdict(int) + dtype_map: defaultdict = defaultdict(int) for _, value in state_dict.items(): dtype_map[value.dtype] += 1 return dict(dtype_map) -def _load_script_modules(model_dir: str) -> List[torch.ScriptModule]: +def _load_script_modules(model_dir: str) -> Tuple[List[Any], List[str]]: models = [] with torch_neuronx.contexts.disable_nrt_load(): model_rank_files = sorted( @@ -420,7 +424,7 @@ def _spmd_trace( compiler_args: Optional[Union[List[str], str]] = None, bucket_config: Optional[BucketModelConfig] = None, force_custom_init_on_device: bool = False, - serialization_path: str = None, + serialization_path: Optional[str] = None, ): """ Xla trace a signle rank and compile it with neuronx-cc. @@ -506,8 +510,12 @@ def _validate_traceable(func: Callable, tp_degree: int, force_custom_init_on_dev with init_on_device(torch.device("meta"), force_custom_init_on_device=force_custom_init_on_device): model, _ = func() + assert isinstance( + model, torch.nn.Module + ), "The first return value of func is expected to be of type torch.nn.Module" + def _validate_children(module: torch.nn.Module): - if module == None: + if module is None: return # Sharding across vocab dimension requires rank level constants for intput masking. 
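The `_validate_traceable` change above builds the candidate model under `init_on_device(torch.device("meta"), ...)` and now asserts that the first value returned by `func` is a `torch.nn.Module`. As a rough illustration of that contract, here is a minimal standalone sketch in plain PyTorch (assuming PyTorch >= 2.0, where `torch.device` works as a context manager); `get_model` and `validate_traceable` are hypothetical names used only for this example, not the library's helpers:

```python
import torch
import torch.nn as nn


def get_model():
    # A trace callable is expected to return (module, aux); the module is what gets traced.
    return nn.Linear(1024, 1024, bias=False), None


def validate_traceable(func):
    # Instantiate on the meta device so validation allocates no real weights,
    # then check the type of the first return value, mirroring the added assert.
    with torch.device("meta"):
        model, _ = func()
    if not isinstance(model, nn.Module):
        raise TypeError("The first return value of func is expected to be of type torch.nn.Module")
    return model


if __name__ == "__main__":
    meta_model = validate_traceable(get_model)
    print(meta_model.weight.device)  # prints "meta": shapes only, no real storage allocated
```

The library's version additionally recurses into child modules via `_validate_children` (for example to check rank-level constants used in vocab-dimension sharding), which this sketch omits.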
@@ -598,6 +606,16 @@ def _load_weights( with init_on_device(torch.device("meta"), force_custom_init_on_device=force_custom_init_on_device): model, _ = func() + get_sharded_checkpoint(checkpoint, model, rank, tp_degree) + + with torch_neuronx.contexts.disable_nrt_load(): + rank_0_path = os.path.join(compiler_workdir, "tp_0.pt") + traced_model = torch.jit.load(rank_0_path) + replace_weights(traced_model, checkpoint) + return traced_model + + +def get_sharded_checkpoint(checkpoint, model, rank, tp_degree): invoke_preshard_hook(model, checkpoint, "") dtype = None @@ -607,12 +625,20 @@ def _load_weights( # Shards the checkpoint to the right weight for the rank shard_children(model, checkpoint, "", dtype, rank, tp_degree) - with torch_neuronx.contexts.disable_nrt_load(): - rank_0_path = os.path.join(compiler_workdir, "tp_0.pt") - traced_model = torch.jit.load(rank_0_path) - replace_weights(traced_model, checkpoint) - return traced_model +def create_local_weight_qkv(rank, world_size, full_weight, partition_dim, q_len, kv_len, out_weight=None): + # Shard q,k,v weights separately and then fuse them for each rank + q_weight, k_weight, v_weight = torch.split(full_weight, [q_len, kv_len, kv_len], dim=partition_dim) + q_weight_list = torch.split(q_weight, divide(q_len, world_size), dim=partition_dim)[rank::world_size] + k_weight_list = torch.split(k_weight, divide(kv_len, world_size), dim=partition_dim)[rank::world_size] + v_weight_list = torch.split(v_weight, divide(kv_len, world_size), dim=partition_dim)[rank::world_size] + + with torch.no_grad(): + return torch.cat(( + torch.cat(q_weight_list, dim=partition_dim), + torch.cat(k_weight_list, dim=partition_dim), + torch.cat(v_weight_list, dim=partition_dim), + ), dim=partition_dim, out=out_weight) def create_local_weight(rank, world_size, full_weight, partition_dim, per_partition_size, stride, out_weight=None): per_partition_per_stride_size = divide(per_partition_size, stride) @@ -635,11 +661,12 @@ def __init__(self, world_size): self.world_size = world_size def size(self): - return tp_degree + return self.world_size parallel_state._TENSOR_MODEL_PARALLEL_GROUP = Mock(tp_degree) parallel_state._DATA_PARALLEL_GROUP = Mock(1) parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = rank + parallel_state._EXPERT_MODEL_PARALLEL_GROUP = Mock(1) def invoke_preshard_hook(module, checkpoint, prefix): @@ -647,7 +674,7 @@ def invoke_preshard_hook(module, checkpoint, prefix): Preshard hooks are hooks to manipulate checkpoints Checkpoint manipulation for GQA replication is one usecase. 
""" - if module == None: + if module is None: return # This is temporary until we formailze the preshard_hook in src @@ -665,7 +692,7 @@ def shard_children(module, checkpoint, prefix, dtype, rank, tp_degree): Checkpoint weights are sharded based on rank and tp_degree """ - if module == None: + if module is None: return for name, child in module._modules.items(): @@ -675,7 +702,6 @@ def shard_children(module, checkpoint, prefix, dtype, rank, tp_degree): if not isinstance(module, __SUPPORTED_SHARDED_MODULES): return - module_parameter: torch.nn.Parameter = None for module_parameter_name, module_parameter in module.named_parameters(): parameter_name = prefix + module_parameter_name @@ -696,8 +722,15 @@ def shard_children(module, checkpoint, prefix, dtype, rank, tp_degree): stride = module_parameter.partition_stride per_partition_size = tensor.shape[partition_dim] // tp_degree - checkpoint[parameter_name] = create_local_weight( - rank, tp_degree, tensor, partition_dim, per_partition_size, stride - ) + if hasattr(module_parameter, "fused_qkv"): + query_len = module_parameter.num_attention_heads * module_parameter.head_dim + kv_len = module_parameter.num_key_value_heads * module_parameter.head_dim + checkpoint[parameter_name] = create_local_weight_qkv( + rank, tp_degree, tensor, partition_dim, query_len, kv_len + ) + else: + checkpoint[parameter_name] = create_local_weight( + rank, tp_degree, tensor, partition_dim, per_partition_size, stride + ) else: checkpoint[parameter_name] = tensor diff --git a/src/neuronx_distributed/trainer/__init__.py b/src/neuronx_distributed/trainer/__init__.py index a1537e2..5329c27 100644 --- a/src/neuronx_distributed/trainer/__init__.py +++ b/src/neuronx_distributed/trainer/__init__.py @@ -1,6 +1,9 @@ -from .checkpoint import load_checkpoint, save_checkpoint -from .trainer import ( - initialize_parallel_model, - initialize_parallel_optimizer, - neuronx_distributed_config, -) +from .checkpoint import load_checkpoint, save_checkpoint # noqa: F401 +from .post_partition_hooks import PostPartitionHooks + +hooks = PostPartitionHooks() +from .trainer import ( # noqa: E402, F401 + initialize_parallel_model, # noqa: E402 + initialize_parallel_optimizer, # noqa: E402 + neuronx_distributed_config, # noqa: E402 +) # noqa: E402 diff --git a/src/neuronx_distributed/trainer/checkpoint.py b/src/neuronx_distributed/trainer/checkpoint.py index 15ab136..3b8678e 100644 --- a/src/neuronx_distributed/trainer/checkpoint.py +++ b/src/neuronx_distributed/trainer/checkpoint.py @@ -2,8 +2,10 @@ import gc import math import os +import re from datetime import datetime -from typing import List, Tuple +from packaging import version +from typing import List, Tuple, Optional, Any, Dict import torch import torch_xla @@ -11,11 +13,26 @@ import torch_xla.utils.serialization as xser from neuronx_distributed.optimizer import NeuronZero1Optimizer + +if version.parse(torch.__version__) >= version.parse("2.1"): + HAVE_DCP_SUPPORT = True + import neuronx_distributed.optimizer.zero_dcp_utils as dcp_utils +else: + HAVE_DCP_SUPPORT = False + dcp_utils = None + from neuronx_distributed.parallel_layers.parallel_state import ( get_data_parallel_group, get_data_parallel_rank, + get_expert_data_parallel_rank, + get_expert_model_parallel_rank, + get_expert_data_parallel_size, + get_expert_model_parallel_size, + get_expert_data_parallel_group, + get_expert_model_parallel_group, get_pipeline_model_parallel_rank, get_tensor_model_parallel_rank, + model_parallel_is_initialized, ) from 
neuronx_distributed.parallel_layers.utils import ( get_local_world_size, @@ -30,18 +47,19 @@ logger = get_logger() -def _get_path(prefix, tp=True, pp=True, dp=False): +def _get_path(prefix: str, tp: bool = True, pp: bool = True, dp: bool = False, ep: bool = False) -> str: path = "" path += "_dp_rank_{:02d}".format(get_data_parallel_rank() if dp else 0) + path += "_ep_rank_{:02d}".format(get_expert_model_parallel_rank()) if ep else "" path += "_tp_rank_{:02d}".format(get_tensor_model_parallel_rank() if tp else 0) path += "_pp_rank_{:02d}".format(get_pipeline_model_parallel_rank() if pp else 0) if path != "": path = path[1:] path += ".pt" - return "{}/{}".format(prefix, path) + return f"{prefix}/{path}" -def _determine_remove_tags(checkpoint_dir: BaseCheckpointStorage, num_kept: int): +def _determine_remove_tags(checkpoint_dir: BaseCheckpointStorage, num_kept: int) -> List[str]: """ deteremine checkpoint tags to be removed to satisfy num_kept return value: a list of tags @@ -70,8 +88,10 @@ def _determine_remove_tags(checkpoint_dir: BaseCheckpointStorage, num_kept: int) return remove_tags +# Global ThreadPoolExecutor to avoid reinitialization +_executor = None -def _bulk_save(checkpoint_dir: BaseCheckpointStorage, save_items: List[Tuple[object, str]]): +def _bulk_save(checkpoint_dir: BaseCheckpointStorage, save_items: List[Tuple[object, str]]) -> None: for obj, filename in save_items: checkpoint_dir.save_object(obj, filename) @@ -91,28 +111,33 @@ def __init__(self, async_save: bool = False): if self._async_save: self._checkpoint_dir = None - self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=1) - self._save_items: list[(torch.Tensor, src)] = list() - self._save_task: concurrent.futurex = None - self._remove_tags: list[str] = None - self._remove_task: concurrent.future = None - - def begin(self, checkpoint_dir: BaseCheckpointStorage, tag: str): + global _executor + _executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self._save_items: List[Tuple[torch.Tensor, str]] = [] + self._save_task: Optional[concurrent.futures.Future] = None + self._dcp_save_items: List[Tuple[Any]] = [] + self._dcp_save_task: Optional[concurrent.futures.Future] = None + self._remove_tags: Optional[List[str]] = None + self._remove_task: Optional[concurrent.futures.Future] = None + self._num_kept: Optional[int] + + def begin(self, checkpoint_dir: BaseCheckpointStorage, tag: str) -> None: self._checkpoint_dir = checkpoint_dir if self._async_save and self._current_tag is not None: self.wait_save(async_remove=True) self._current_tag = tag + xm.rendezvous("create ckpt dir after every neuron-core has filepath") if torch.distributed.get_rank() == 0: method = "async" if self._async_save else "synced" - logger.info(f"{method} saving of checkpoint {tag} began") + logger.info("%s saving of checkpoint %s began", method, tag) self._checkpoint_dir.create_dir(self._current_tag) # create a "checkpoint" tag to mark the directory as checkpoint directory # this is to distinguish checkpoint from users' own data directory under output directory self._checkpoint_dir.save_text("1", os.path.join(self._current_tag, "checkpoint")) - def add_save_task(self, obj: object, filename: str): + def add_save_task(self, obj: Any, filename: str) -> None: assert filename.startswith(self._current_tag + "/") relative_filename = filename[len(self._current_tag) + 1 :] self._relative_filenames.add(relative_filename) @@ -122,60 +147,88 @@ def add_save_task(self, obj: object, filename: str): assert self._checkpoint_dir 
self._checkpoint_dir.save_object(obj, filename) - def end(self, num_kept: int): + def add_dcp_save_task(self, checkpoint_dir: BaseCheckpointStorage, state_dict: dict, optimizer, model, ckpt_path): + path = os.path.join(checkpoint_dir.dirname(), ckpt_path, "optim") + aux_infos = dcp_utils.get_dcp_aux_infos(model, optimizer) + state_dict_cpu = move_all_tensor_to_cpu(state_dict) + if self._async_save: + self._dcp_save_items.append((path, state_dict_cpu, aux_infos)) + else: + dcp_utils.save_optim_state_dict(path, state_dict_cpu, aux_infos) + + def _dealloc_tensor_host_memory_callback(self, future): + """Future callback to asynchronous deallocate the tensor host memory + """ + self._save_items = [] + gc.collect() + + def end(self, num_kept: int) -> None: if self._async_save: self._num_kept = num_kept if len(self._save_items) > 0: - self._save_task = self._executor.submit(_bulk_save, self._checkpoint_dir, self._save_items) - if torch.distributed.get_rank() == 0: - logger.info(f"async saving of checkpoint {self._current_tag} requested") + self._save_task = _executor.submit(_bulk_save, self._checkpoint_dir, self._save_items) + # After save async thread is finished, use callback to async dealloc the tensor host memory + self._save_task.add_done_callback(self._dealloc_tensor_host_memory_callback) + if len(self._dcp_save_items) > 0: + self._dcp_save_task = _executor.submit(dcp_utils.save_optim_state_dict, *(self._dcp_save_items[0])) + # After save async thread is finished, use callback to async dealloc the tensor host memory + self._dcp_save_task.add_done_callback(self._dealloc_tensor_host_memory_callback) + logger.info("async saving of checkpoint %s requested", self._current_tag) else: xm.rendezvous("saving checkpoint done") if torch.distributed.get_rank() == 0: - logger.info(f"synced saving of checkpoint {self._current_tag} completed") + logger.info("synced saving of checkpoint %s completed", self._current_tag) self._checkpoint_dir.save_text("1", os.path.join(self._current_tag, "done")) xm.rendezvous("mark checkpoint as done") self.submit_remove(num_kept, async_remove=False) - def wait_save(self, async_remove): + def wait_save(self, async_remove: bool) -> None: if not self._async_save: return # first wait for save to finish - tasks = [] if self._save_task: done, _ = concurrent.futures.wait([self._save_task]) for f in done: if f.exception(): raise f.exception() + if self._dcp_save_task: + done, _ = concurrent.futures.wait([self._dcp_save_task]) + for f in done: + if f.exception(): + raise f.exception() - xm.rendezvous(f"async saving checkpoint done") + xm.rendezvous("async saving checkpoint done") if self._save_task: self._save_task = None + # This is already asynchronously invoked in _dealloc_tensor_host_memory_callback + # However, this needs to be kept for the sync checkpointing case self._save_items = [] - if torch.distributed.get_rank() == 0: - self._checkpoint_dir.save_text("1", os.path.join(self._current_tag, "done")) + if self._dcp_save_task: + self._dcp_save_task = None + self._dcp_save_items = [] + if torch.distributed.get_rank() == 0: + self._checkpoint_dir.save_text("1", os.path.join(self._current_tag, "done")) - xm.rendezvous(f"mark checkpoint as done") + xm.rendezvous("mark checkpoint as done") - if torch.distributed.get_rank() == 0: - logger.info(f"async saving of checkpoint {self._current_tag} completed") + logger.info("async saving of checkpoint %s completed", self._current_tag) # remove checkpoint if necessary. 
self.wait_remove() self.submit_remove(self._num_kept, async_remove=async_remove) - def submit_remove(self, num_kept: int, async_remove: bool, remove_tags: List[str] = []): + def submit_remove(self, num_kept: int, async_remove: bool, remove_tags: Optional[List[str]] = None) -> None: + remove_tags = remove_tags or [] remove_tags = remove_tags if len(remove_tags) else _determine_remove_tags(self._checkpoint_dir, num_kept) xm.rendezvous("determine remove tags done") if len(remove_tags) == 0: - if torch.distributed.get_rank() == 0: - logger.info(f"no checkpoints to remove.") + logger.info("no checkpoints to remove.") return if torch.distributed.get_rank() == 0: - logger.info(f"removing previous checkpoint in {remove_tags}") + logger.info("removing previous checkpoint in %s", remove_tags) # remove the done file first to avoid the situation # the deletion was interrupted by faults, leaving # a corrupted checkpoint_dir with the "done" tag. @@ -186,7 +239,7 @@ def submit_remove(self, num_kept: int, async_remove: bool, remove_tags: List[str completed_tags.append(remove_tag) self._checkpoint_dir.remove_file(done_file) - logger.info(f"done tags in {completed_tags} cleared") + logger.info("done tags in %s cleared", completed_tags) remove_filenames = [] for remove_tag in remove_tags: @@ -195,18 +248,17 @@ def submit_remove(self, num_kept: int, async_remove: bool, remove_tags: List[str if async_remove: self._remove_tags = remove_tags - self._remove_task = self._executor.submit(self._checkpoint_dir.remove_files, remove_filenames) - if torch.distributed.get_rank() == 0: - logger.info(f"async removal of {self._remove_tags} requested.") + self._remove_task = _executor.submit(self._checkpoint_dir.remove_files, remove_filenames) + logger.info("async removal of %s requested.", self._remove_tags) else: self._checkpoint_dir.remove_files(remove_filenames) xm.rendezvous("remove files done") # wait until everyone deleted the files they wrote, then rank 0 delete what were left if torch.distributed.get_rank() == 0: self._checkpoint_dir.remove_dirs(remove_tags) - logger.info(f"previous checkpoint in {remove_tags} successfully removed") + logger.info("previous checkpoint in %s successfully removed", remove_tags) - def wait_remove(self): + def wait_remove(self) -> None: if self._remove_task: done, _ = concurrent.futures.wait([self._remove_task]) for f in done: @@ -216,24 +268,24 @@ def wait_remove(self): xm.rendezvous("remove files done") if torch.distributed.get_rank() == 0: self._checkpoint_dir.remove_dirs(self._remove_tags) - logger.info(f"async removal of {self._remove_tags} completed") + logger.info("async removal of %s completed", self._remove_tags) self._remove_tags = None self._remove_task = None # This rendevous is necessary, since it avoids the race condition that # can occur when each worker is trying to find the next set of files to # delete. Its a corner case, where worker 0 is still deleting, and worker - # 1 has moved on to the submit_remove task. It tries to find_files, and + # 1 has moved on to the submit_remove task. It tries to find_files, and # in the process runs into a race condition, resulting in file not found # error. xm.rendezvous("Wait for all workers to come from deletion") - def wait_all(self): - # when this function is called, ProcessPool may have been shutdown. + def wait_all(self) -> None: + # when this function is called, ThreadPool may have been shutdown. 
# there fore we must use synced remove self.wait_save(async_remove=False) -def _get_my_group_info(groups: List[List[int]]): +def _get_my_group_info(groups: List[List[int]]) -> Tuple[int, int]: global_rank = torch.distributed.get_rank() for group in groups: if global_rank in group: @@ -242,7 +294,18 @@ def _get_my_group_info(groups: List[List[int]]): raise RuntimeError(f"Error: global rank {global_rank} is not in groups") -def _xser_load_data(checkpoint_dir: BaseCheckpointStorage, path: str, groups: List[List[int]] = None): +def _get_ep_group_info() -> Tuple[int, int, int, int, int, int]: + emp_rank = get_expert_model_parallel_rank() + edp_rank = get_expert_data_parallel_rank() + emp_size = get_expert_model_parallel_size() + edp_size = get_expert_data_parallel_size() + emp_group = get_expert_model_parallel_group(as_list=True) + edp_group = get_expert_data_parallel_group(as_list=True) + + return emp_rank, edp_rank, emp_size, edp_size, emp_group, edp_group + + +def _xser_load_data(checkpoint_dir: BaseCheckpointStorage, path: str, groups: Optional[List[List[int]]] = None, ep_only: bool = False): """ load tensors saved in path into a state_dict. Parameters: @@ -254,56 +317,80 @@ def _xser_load_data(checkpoint_dir: BaseCheckpointStorage, path: str, groups: Li # checkpoint generated using older version still need to be supported. ref_info = checkpoint_dir.load_object(path + ".info.pt") if checkpoint_dir.file_exists(path + ".info.pt") else None - tensor_folder = path + ".tensors" + ep_tensor_folder = path + ".tensors" if groups is not None: my_rank_in_group, my_group_size = _get_my_group_info(groups) + # for non-ep tensors all ranks load from ep_rank_00 folder, round-robin + non_ep_tensor_folder = re.sub("ep_rank_\d{2}", "ep_rank_00", path) + ".tensors" - def convert_fn(tensors): - rewritten_tensors = [] + if model_parallel_is_initialized(): + emp_rank, edp_rank, emp_size, edp_size, emp_group, edp_group = _get_ep_group_info() - for t in tensors: - tensor_file = os.path.join(tensor_folder, "tensor_{}.pt".format(t.tid)) - if (ref_info is not None) and (groups is not None): - # When there is redundency (groups is not None) and we know the tensor's shape and dtype (ref_info is not None) - # we use the following optimization: - # among workers that has same tensor (in same group), only 1 worker read tensor from disk - # other workers will get the tensor from network broadcasting - # - # we used round robin to select which worker will read from disk to evenly - # distribute the load tasks. 
- if (t.tid % my_group_size) == my_rank_in_group: - loaded = checkpoint_dir.load_object(tensor_file).to(xm.xla_device()) + def convert_fn(tensors): + rewritten_tensors = [None for t in tensors] + + def _is_ep(dct): + return "expert_model_parallel" in dct and dct["expert_model_parallel"] + + ep_tensors = [(i, t) for i, t in enumerate(tensors) if _is_ep(ref_info[t.tid])] + non_ep_tensors = [(i, t) for i, t in enumerate(tensors) if not _is_ep(ref_info[t.tid])] + + def _load_tensors(tensor_list, _group, _rank, _group_size, _tensor_folder): + for idx, (original_idx, t) in enumerate(tensor_list): + tensor_file = os.path.join(_tensor_folder, "tensor_{}.pt".format(t.tid)) + + if (ref_info is not None) and (groups is not None): + # When there is redundency (groups is not None) and we know the tensor's shape and dtype (ref_info is not None) + # we use the following optimization: + # among workers that has same tensor (in same group), only 1 worker read tensor from disk + # other workers will get the tensor from network broadcasting + # + # we used round robin to select which worker will read from disk to evenly + # distribute the load tasks. + + if (idx % _group_size) == _rank: + loaded = checkpoint_dir.load_object(tensor_file).to(xm.xla_device()) + else: + dtype = ref_info[t.tid]["dtype"] + shape = ref_info[t.tid]["shape"] + loaded = torch.zeros(shape, dtype=dtype, device=xm.xla_device()) + + # we use all_reduce to implement broadcast because xla does not have native broadcast support. + xm.all_reduce(xm.REDUCE_SUM, [loaded], groups=_group) else: - dtype = ref_info[t.tid]["dtype"] - shape = ref_info[t.tid]["shape"] - loaded = torch.zeros(shape, dtype=dtype, device=xm.xla_device()) - # we use all_reduce to implement broadcast because xla does not have native broadcast support. - xm.all_reduce(xm.REDUCE_SUM, [loaded], groups=groups) - else: - # when dtype and shape are not available or there is no redundency, all workers load tensor from disk - loaded = checkpoint_dir.load_object(tensor_file).to(xm.xla_device()) + # when dtype and shape are not available or there is no redundency, all workers load tensor from disk + loaded = checkpoint_dir.load_object(tensor_file).to(xm.xla_device()) - rewritten_tensors.append(loaded) + rewritten_tensors[original_idx] = loaded + + if groups is not None: + _load_tensors(ep_tensors, edp_group, edp_rank, edp_size, ep_tensor_folder) + _load_tensors(non_ep_tensors, groups, my_rank_in_group, my_group_size, non_ep_tensor_folder) + elif ep_only: + _load_tensors(ep_tensors, None, None, None, ep_tensor_folder) + else: + _load_tensors(ep_tensors + non_ep_tensors, None, None, None, ep_tensor_folder) if groups is not None: xm.mark_step() return rewritten_tensors def select_fn(v): - return type(v) == xser.TensorReference + return isinstance(v, xser.TensorReference) return xm.ToXlaTensorArena(convert_fn, select_fn).transform(ref_data) class _InternalTensorReference: - def __init__(self, tid, shape, dtype): + def __init__(self, tid, shape, dtype, expert_model_parallel): self.tid = tid self.shape = shape self.dtype = dtype + self.expert_model_parallel = expert_model_parallel -def _assign_tensors_to_bins(tensors, bin_count) -> List[List[int]]: +def _assign_tensors_to_bins(tensors: List[torch.Tensor], bin_count: int) -> List[List[int]]: """ assign a list of tensors into multiple bins, such that each bin's total tensor size are similar. 
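The retyped `_assign_tensors_to_bins` above only shows its contract in this diff: split a list of tensors into `bin_count` bins whose total byte sizes are roughly equal, so each worker in a group writes a similar share of the checkpoint. Below is a hedged, standalone sketch of one way to satisfy that contract with a largest-first greedy assignment; `assign_tensors_to_bins` here is an illustrative stand-in and an assumption about the approach, not the body of the library's function:

```python
from typing import List

import torch


def assign_tensors_to_bins(tensors: List[torch.Tensor], bin_count: int) -> List[List[int]]:
    """Return bin_count lists of tensor indices whose total byte sizes are roughly equal."""
    # Sort indices by tensor size, largest first, then always drop the next
    # tensor into the currently lightest bin.
    sizes = sorted(
        ((i, t.numel() * t.element_size()) for i, t in enumerate(tensors)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    bins: List[List[int]] = [[] for _ in range(bin_count)]
    bin_bytes = [0] * bin_count
    for index, nbytes in sizes:
        lightest = min(range(bin_count), key=lambda b: bin_bytes[b])
        bins[lightest].append(index)
        bin_bytes[lightest] += nbytes
    return bins


if __name__ == "__main__":
    tensors = [torch.empty(n) for n in (1000, 10, 500, 500, 10)]
    print(assign_tensors_to_bins(tensors, 2))  # [[0, 1], [2, 3, 4]] -> ~4040 bytes per bin
```

A largest-first greedy pass is a common heuristic for this kind of balancing and keeps the sketch deterministic; the per-worker bin indices would then drive which tensors each rank serializes, matching the round-robin load path shown in `_xser_load_data` above.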
@@ -337,8 +424,8 @@ def _assign_tensors_to_bins(tensors, bin_count) -> List[List[int]]: def _xser_save_data( - checkpoint_dir: BaseCheckpointStorage, path: str, state_dict, iostate, groups: List[List[int]] = None -): + checkpoint_dir: BaseCheckpointStorage, path: str, state_dict, iostate: CheckpointIOState, groups: Optional[List[List[int]]] = None +) -> Any: """ This function save the tensors in a state_dict into a directory. Each tensor will be saved as a separate file. @@ -352,56 +439,61 @@ def _xser_save_data( if groups is not None: my_rank_in_group, my_group_size = _get_my_group_info(groups) - def convert_fn(tensors): + emp_rank, edp_rank, emp_size, edp_size, emp_group, edp_group = _get_ep_group_info() + + def convert_fn(tensors: List[torch.Tensor]) -> List[torch.Tensor]: torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True) if groups is None: my_tensors = None else: - my_tensors = _assign_tensors_to_bins(tensors, my_group_size)[my_rank_in_group] + my_tensors = _assign_tensors_to_bins(tensors, edp_size)[edp_rank] rewritten_tensors = [] for i, t in enumerate(tensors): + is_expert_parallel = hasattr(t, "expert_model_parallel") and t.expert_model_parallel if (my_tensors is None) or (i in my_tensors): t0 = datetime.now() cpu_data = t.cpu() t1 = datetime.now() - iostate.add_save_task(cpu_data, xser._get_tensor_file(path, i)) - if torch.distributed.get_rank() == 0: - logger.debug(f" transfer tensor {i} to cpu elapsed: {(t1 - t0).total_seconds()} seconds") - rewritten_tensors.append(_InternalTensorReference(i, t.shape, t.dtype)) + # if the below condition is not satisfied, someone else will store the same data + if my_tensors is None or (is_expert_parallel or emp_rank == 0): + iostate.add_save_task(cpu_data, xser._get_tensor_file(path, i)) + logger.debug(" transfer tensor %d to cpu elapsed: %d seconds", i, (t1 - t0).total_seconds()) + rewritten_tensors.append(_InternalTensorReference(i, t.shape, t.dtype, is_expert_parallel)) return rewritten_tensors def select_fn(v): - return type(v) == torch.Tensor and xm.is_xla_tensor(v) + return isinstance(v, torch.Tensor) # and xm.is_xla_tensor(v) checkpoint_dir.create_shared_dir(path) return xm.ToXlaTensorArena(convert_fn, select_fn).transform(state_dict) -def _extract_tensor_info_and_update_state_dict(state_dict: dict, tensor_info: dict): +def _extract_tensor_info_and_update_state_dict(state_dict: Dict[str, Any], tensor_info: Dict[int, Dict[str, Any]]) -> None: """ for a given state_dict, replace _InternalTensorReference with XserTensorReference, and put the dtype and shape in a separate accout. 
""" for k, v in state_dict.items(): - if type(v) == _InternalTensorReference: - tensor_info[v.tid] = {"dtype": v.dtype, "shape": v.shape} + if isinstance(v, _InternalTensorReference): + tensor_info[v.tid] = {"dtype": v.dtype, "shape": v.shape, "expert_model_parallel": v.expert_model_parallel} state_dict[k] = xser.TensorReference(v.tid) - if type(v) == dict: + if isinstance(v, dict): _extract_tensor_info_and_update_state_dict(v, tensor_info) def _save( - ckpt, checkpoint_dir: BaseCheckpointStorage, path: str, groups=None, num_workers=8, use_xser=False, iostate=None -): + ckpt: Any, checkpoint_dir: BaseCheckpointStorage, path: str, groups: Optional[List[List[int]]] = None, num_workers: int = 8, use_xser: bool = False, iostate: Optional[CheckpointIOState] = None, optimizer: bool = False +) -> None: if groups is not None: my_rank_in_group, my_group_size = _get_my_group_info(groups) + emp_rank, edp_rank, emp_size, edp_size, emp_group, edp_group = _get_ep_group_info() # quick path when use xser if use_xser: state_dict = _xser_save_data(checkpoint_dir, xser._get_tensors_folder(path), ckpt, iostate, groups) - if (groups is None) or (my_rank_in_group == 0): + if (groups is None) or (edp_rank == 0): tensor_info = {} # to make sure path can be loaded using xser.load(), we must update # state_dict such that it does not have _InternalTensorReference @@ -420,7 +512,7 @@ def _save( iostate.add_save_task(cpu_data, path) -def _load_obj_from_state_dict(obj, state_dict, strict): +def _load_obj_from_state_dict(obj: Any, state_dict: Dict[str, Any], strict: bool) -> None: if isinstance(obj, torch.nn.Module): obj.load_state_dict(state_dict, strict=strict) elif isinstance(obj, dict): @@ -433,14 +525,14 @@ def _load_obj_from_state_dict(obj, state_dict, strict): def _load( - obj, + obj: Any, checkpoint_dir: BaseCheckpointStorage, path: str, - groups: List[List[int]] = None, + groups: Optional[List[List[int]]] = None, num_workers: int = 8, strict: bool = True, use_xser: bool = False, -): +) -> None: """ Load object the save as path. @@ -454,6 +546,7 @@ def _load( # quick path when use xser if use_xser: ckpt = _xser_load_data(checkpoint_dir, path, groups) + ckpt = _move_step_to_cpu(ckpt) _load_obj_from_state_dict(obj, ckpt, strict) return @@ -467,7 +560,7 @@ def _load( xm.rendezvous("load checkpoint done") -def has_checkpoint(checkpoint_dir_str: str): +def has_checkpoint(checkpoint_dir_str: str) -> bool: checkpoint_dir = create_checkpoint_storage(checkpoint_dir_str) return len(checkpoint_dir.list_completed_checkpoint_tags()) > 0 @@ -476,10 +569,10 @@ def has_checkpoint(checkpoint_dir_str: str): def save_checkpoint( - checkpoint_dir_str, - tag, - model=None, - optimizer=None, + checkpoint_dir_str: str, + tag: str, + model: Any = None, + optimizer: Any = None, scheduler=None, user_content=None, num_workers=8, @@ -487,7 +580,8 @@ def save_checkpoint( num_kept_ckpts=None, async_save=False, zero1_optimizer=False, -): + use_zero1_dcp=False, +) -> None: """ Method to save checkpoint, return ``None``. @@ -512,18 +606,18 @@ def save_checkpoint( - newest Parameters: - path (str): + checkpoint_dir_str (str): path to save the checkpoints. tag (str): tag to save the checkpoints. model (torch.nn.Module or dict): - model to save, optinal. + model to save, optional. optimizer (torch.optim.Optimizer or dict): - optimizer to save, optinal. + optimizer to save, optional. scheduler: - scheduler to save, optinal. + scheduler to save, optional. user_content: - user contents to save, optinal. + user contents to save, optional. 
num_workers (int): num of workers to save the checkpoints on the same time, range: 1-32. use_xser (bool): @@ -532,9 +626,11 @@ def save_checkpoint( num_kept_ckpts (int): number of checkpoints to keep on disk, optional. Default: ``None``. async_save (bool): - whether to use asynchronous saving method + whether to use asynchronous saving method. zero1_optimizer (bool): - whether the optimizer state is from a zero1 optimizer, used when optimizer is a dict + whether the optimizer state is from a zero1 optimizer, used when optimizer is a dict. + use_zero1_dcp (bool): + whether to use Distributed Checkpoint for ZeRO-1 optimizer. """ # TODO: Use distributed checkpoint assert torch.distributed.is_initialized(), "Only support distributed training mode." @@ -553,18 +649,22 @@ def save_checkpoint( g_iostate.begin(checkpoint_dir, tag) ckpt_path = str(tag) + ep_enabled = get_expert_model_parallel_size() > 1 + # save model if model is not None: if torch.distributed.get_rank() == 0: checkpoint_dir.create_dir(os.path.join(ckpt_path, "model"), exist_ok=True) - model_path = os.path.join(ckpt_path, _get_path("model")) + model_path = os.path.join(ckpt_path, _get_path("model", ep=ep_enabled)) + if isinstance(model, NxDPPModel): ckpt = model.local_state_dict() elif isinstance(model, dict): ckpt = model else: ckpt = model.state_dict() - groups = get_data_parallel_group(as_list=True) + + groups = get_expert_data_parallel_group(as_list=True) _save( ckpt, checkpoint_dir, @@ -589,17 +689,29 @@ def save_checkpoint( zero1_enabled = isinstance(optimizer, NeuronZero1Optimizer) optimizer_state_dict = optimizer.state_dict() - optimizer_path = os.path.join(ckpt_path, _get_path("optim", dp=zero1_enabled)) + if zero1_enabled: + par_args = {"dp": True} + elif ep_enabled: + par_args = {"ep": True} + else: + par_args = {} + + optimizer_path = os.path.join(ckpt_path, _get_path("optim", **par_args)) groups = None if zero1_enabled else get_data_parallel_group(as_list=True) - _save( - optimizer_state_dict, - checkpoint_dir, - optimizer_path, - groups=groups, - num_workers=num_workers, - use_xser=use_xser, - iostate=g_iostate, - ) + if zero1_enabled and use_zero1_dcp and HAVE_DCP_SUPPORT: + assert async_save, "Now we only use DCP for async checkpoint saving." + g_iostate.add_dcp_save_task(checkpoint_dir, optimizer_state_dict, optimizer, model, ckpt_path) + else: + _save( + optimizer_state_dict, + checkpoint_dir, + optimizer_path, + groups=groups, + num_workers=num_workers, + use_xser=use_xser, + iostate=g_iostate, + optimizer=True, + ) # save scheduler if scheduler is not None: @@ -614,15 +726,26 @@ def save_checkpoint( g_iostate.end(num_kept_ckpts) +def _move_step_to_cpu(ckpt: Any, key: Optional[str] = None) -> Any: + if key == "step": + return ckpt.cpu() + + if isinstance(ckpt, dict): + return {k: _move_step_to_cpu(v, k) for k, v in ckpt.items()} + + return ckpt + + def load_checkpoint( - path, - tag=None, - model=None, - optimizer=None, - scheduler=None, - num_workers=8, - strict=True, -): + path: str, + tag: Optional[str] = None, + model: Optional[torch.nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Any = None, + num_workers: int = 8, + strict: bool = True, + use_zero1_dcp: bool = False, +) -> Any: """ Method to load checkpoint, return user contents if exists otherwise ``None``. If ``tag`` not provided, will try to use the newest tag tracked by ``save_checkpoint``. @@ -631,21 +754,27 @@ def load_checkpoint( path (str): path to load the checkpoints. tag (str): - tag to load the checkpoints. 
+ tag to load the checkpoints, optional. model (torch.nn.Module): - model to load, optinal. + model to load, optional. optimizer (torch.optim.Optimizer): - optimizer to load, optinal. + optimizer to load, optional. scheduler: - scheduler to load, optinal. + scheduler to load, optional. num_workers (int): num of workers to load the checkpoints on the same time, range: 1-32. strict (bool): whether to use strict mode when loading model checkpoint. Default: ``True``. + use_zero1_dcp (bool): + whether to use Distributed Checkpoint for ZeRO-1 optimizer. """ assert torch.distributed.is_initialized(), "Only support distributed training mode." checkpoint_dir = create_checkpoint_storage(path) + ep_enabled = get_expert_model_parallel_size() > 1 + + if ep_enabled and strict: + print("Strict checkpoint loading is not supported with expert parallelism.") if tag is None: tags = checkpoint_dir.list_completed_checkpoint_tags() @@ -657,12 +786,11 @@ def load_checkpoint( use_xser = checkpoint_dir.is_checkpoint_xser(tag) - if torch.distributed.get_rank() == 0: - logger.info("loading checkpoint from {}".format(ckpt_path)) + logger.info("loading checkpoint from %s", ckpt_path) # load model if model is not None: - model_path = os.path.join(ckpt_path, _get_path("model")) + model_path = os.path.join(ckpt_path, _get_path("model", ep=ep_enabled)) groups = get_data_parallel_group(as_list=True) _load( model, checkpoint_dir, model_path, groups=groups, num_workers=num_workers, strict=strict, use_xser=use_xser @@ -680,11 +808,28 @@ def load_checkpoint( elif isinstance(optimizer, NeuronZero1Optimizer): zero1_enabled = True else: - raise RuntimeError(f"Error: invalid type for the argument optimizer for load_checkpoint, expecting a dict or NxDOptimizer, or NeuronZero1Optimizer, got {type(optimizer)}") + raise RuntimeError( + f"Error: invalid type for the argument optimizer for load_checkpoint, expecting a dict or NxDOptimizer, or NeuronZero1Optimizer, got {type(optimizer)}" + ) + + ep_enabled = get_expert_model_parallel_size() > 1 + + if zero1_enabled: + par_args = {"dp": True} + groups = None + elif ep_enabled: + par_args = {"ep": True} + groups = get_expert_data_parallel_group(as_list=True) + else: + par_args = {} + groups = get_data_parallel_group(as_list=True) - groups = None if zero1_enabled else get_data_parallel_group(as_list=True) - optimizer_path = os.path.join(ckpt_path, _get_path("optim", dp=zero1_enabled)) - _load(optimizer, checkpoint_dir, optimizer_path, groups=groups, num_workers=num_workers, use_xser=use_xser) + if zero1_enabled and use_zero1_dcp and HAVE_DCP_SUPPORT: + aux_infos = dcp_utils.get_dcp_aux_infos(model, optimizer) + dcp_utils.load_optim_state_dict(os.path.join(checkpoint_dir.dirname(), ckpt_path, "optim"), optimizer, aux_infos) + else: + optimizer_path = os.path.join(ckpt_path, _get_path("optim", **par_args)) + _load(optimizer, checkpoint_dir, optimizer_path, groups=groups, num_workers=num_workers, use_xser=use_xser) # load scheduler if scheduler is not None: @@ -697,13 +842,12 @@ def load_checkpoint( if checkpoint_dir.file_exists(user_content_path): user_content = checkpoint_dir.load_object(user_content_path, map_location="cpu") - if torch.distributed.get_rank() == 0: - logger.info("loading checkpoint done") + logger.info("loading checkpoint done") xm.rendezvous("load all checkpoints done") return user_content -def finalize_checkpoint(): +def finalize_checkpoint() -> None: if g_iostate: g_iostate.wait_all() diff --git a/src/neuronx_distributed/trainer/checkpoint_storage.py 
b/src/neuronx_distributed/trainer/checkpoint_storage.py index 487b075..2cd8fa9 100644 --- a/src/neuronx_distributed/trainer/checkpoint_storage.py +++ b/src/neuronx_distributed/trainer/checkpoint_storage.py @@ -7,7 +7,7 @@ import time from abc import abstractmethod from io import BytesIO -from typing import List, Tuple +from typing import List, Dict, Any, Optional, Tuple, TYPE_CHECKING import boto3 import botocore @@ -21,34 +21,37 @@ except ImportError: use_crt = False +if TYPE_CHECKING: + from mypy_boto3_s3 import S3ServiceResource, S3Client + class BaseCheckpointStorage: def __init__(self, dirname: str): self._dirname = dirname - def dirname(self): + def dirname(self) -> str: return self._dirname - def is_checkpoint_xser(self, dirname: str): + def is_checkpoint_xser(self, dirname: str) -> bool: dirs = self.find_subdirs_contain_path( pattern="*.tensors", search_depth=2, search_root=dirname, max_count=1, sort_by_mdate=False ) return len(dirs) > 0 - def list_checkpoint_tags(self): + def list_checkpoint_tags(self) -> List[str]: return self.find_subdirs_contain_path(pattern="checkpoint", search_depth=1, sort_by_mdate=True) - def list_completed_checkpoint_tags(self): + def list_completed_checkpoint_tags(self) -> List[str]: return self.find_subdirs_contain_path(pattern="done", search_depth=1, sort_by_mdate=True) def find_subdirs_contain_path( self, pattern: str, search_depth: int, - search_root: str = None, - max_count: int = None, + search_root: Optional[str] = None, + max_count: Optional[int] = None, sort_by_mdate: bool = False, - ): + ) -> List[str]: files = self.find_files(pattern, search_depth + 1, search_root, max_count, sort_by_mdate) subdirs = [] for file in files: @@ -56,11 +59,11 @@ def find_subdirs_contain_path( return subdirs @abstractmethod - def dir_exists(self, dirname: str): + def dir_exists(self, dirname: str) -> bool: raise NotImplementedError @abstractmethod - def file_exists(self, filename: str): + def file_exists(self, filename: str) -> bool: raise NotImplementedError @abstractmethod @@ -68,47 +71,47 @@ def find_files( self, pattern: str, search_depth: int, - search_root: str = None, - max_count: int = None, + search_root: Optional[str] = None, + max_count: Optional[int] = None, sort_by_mdate: bool = False, ): raise NotImplementedError @abstractmethod - def save_text(self, text: str, filename: str): + def save_text(self, text: str, filename: str) -> None: raise NotImplementedError @abstractmethod - def save_object(self, obj: object, filename: str): + def save_object(self, obj: object, filename: str) -> None: raise NotImplementedError @abstractmethod - def load_object(self, filename: str, map_location=None) -> object: + def load_object(self, filename: str, map_location: torch.serialization.MAP_LOCATION = None) -> Any: raise NotImplementedError @abstractmethod - def create_dir(self, dirname: str, exist_ok: bool = True): + def create_dir(self, dirname: str, exist_ok: bool = True) -> None: raise NotImplementedError @abstractmethod def create_shared_dir( - self, dirname: str, exist_ok: bool = True, process_group: torch.distributed.ProcessGroup = None - ): + self, dirname: str, exist_ok: bool = True, process_group: Optional[torch.distributed.ProcessGroup] = None + ) -> None: raise NotImplementedError @abstractmethod - def remove_dir(self, dirname: str): + def remove_dir(self, dirname: str) -> None: raise NotImplementedError @abstractmethod - def remove_file(self, filename: str): + def remove_file(self, filename: str) -> None: raise NotImplementedError - def remove_dirs(self, 
dirnames: List[str]): + def remove_dirs(self, dirnames: List[str]) -> None: for dirname in dirnames: self.remove_dir(dirname) - def remove_files(self, filenames: List[str]): + def remove_files(self, filenames: List[str]) -> None: for filename in filenames: if self.file_exists(filename): self.remove_file(filename) @@ -118,15 +121,15 @@ class FilesysCheckpointStorage(BaseCheckpointStorage): def __init__(self, dirname: str): super().__init__(dirname) - def dir_exists(self, dirname: str): + def dir_exists(self, dirname: str) -> bool: dirname = os.path.join(self._dirname, dirname) return os.path.exists(dirname) and os.path.isdir(dirname) - def file_exists(self, filename: str): + def file_exists(self, filename: str) -> bool: filename = os.path.join(self._dirname, filename) return os.path.exists(filename) and os.path.isfile(filename) - def is_checkpoint_xser(self, ckpt_path: str): + def is_checkpoint_xser(self, ckpt_path: str) -> bool: ckpt_path = os.path.join(self._dirname, ckpt_path) for x in os.listdir(ckpt_path): inner_path = os.path.join(ckpt_path, x) @@ -140,10 +143,10 @@ def find_files( self, pattern: str, search_depth: int, - search_root: str = None, - max_count: int = None, + search_root: Optional[str] = None, + max_count: Optional[int] = None, sort_by_mdate: bool = False, - ): + ) -> List[str]: if not os.path.exists(self._dirname): return [] @@ -156,7 +159,7 @@ def find_files( if sort_by_mdate: paths.sort(key=os.path.getmtime) - if type(max_count) == int and max_count > 0 and len(paths) > max_count: + if isinstance(max_count, int) and max_count > 0 and len(paths) > max_count: paths = paths[0:max_count] files = [] @@ -164,36 +167,36 @@ def find_files( files.append(os.path.relpath(path, self._dirname)) return files - def save_text(self, text: str, filename: str): + def save_text(self, text: str, filename: str) -> None: filename = os.path.join(self._dirname, filename) with open(filename, "w") as f: f.write(text) - def save_object(self, obj: object, filename: str): + def save_object(self, obj: Any, filename: str) -> None: filename = os.path.join(self._dirname, filename) torch.save(obj, filename) - def load_object(self, filename: str, map_location=None): + def load_object(self, filename: str, map_location: torch.serialization.MAP_LOCATION = None) -> Any: filename = os.path.join(self._dirname, filename) return torch.load(filename, map_location=map_location) - def remove_dir(self, dirname: str): + def remove_dir(self, dirname: str) -> None: dirname = os.path.join(self._dirname, dirname) if os.path.exists(dirname): shutil.rmtree(dirname) - def remove_file(self, filename: str): + def remove_file(self, filename: str) -> None: filename = os.path.join(self._dirname, filename) if os.path.exists(filename): os.unlink(filename) - def create_dir(self, dirname: str, exist_ok: bool = True): + def create_dir(self, dirname: str, exist_ok: bool = True) -> None: dirname = os.path.join(self._dirname, dirname) os.makedirs(dirname, exist_ok=exist_ok) def create_shared_dir( - self, dirname: str, exist_ok: bool = True, process_group: torch.distributed.ProcessGroup = None - ): + self, dirname: str, exist_ok: bool = True, process_group: Optional[torch.distributed.ProcessGroup] = None + ) -> None: if process_group is None: return self.create_dir(dirname, exist_ok) @@ -227,14 +230,14 @@ def __init__(self, dirname: str): boto3.set_stream_logger(name="botocore.credentials", level=logging.ERROR) - def dir_exists(self, dirname: str): + def dir_exists(self, dirname: str) -> bool: """ s3 allow create files with common 
prefix at the same time, therefore we can consider any diretory to be existing """ return True - def file_exists(self, filename: str): + def file_exists(self, filename: str) -> bool: subdir = os.path.dirname(filename) basename = os.path.basename(filename) if subdir == "": @@ -247,7 +250,7 @@ def file_exists(self, filename: str): return False - def _list(self, prefix: str = None): + def _list(self, prefix: Optional[str] = None) -> List[Dict[str, Any]]: s3 = S3CheckpointStorage.get_client() if self._base_key and prefix: @@ -274,7 +277,7 @@ def _list(self, prefix: str = None): return results - def _list_with_retry(self, prefix: str = None): + def _list_with_retry(self, prefix: Optional[str] = None) -> List[Dict[str, Any]]: max_try = 4 sleep_second = 60 for try_idx in range(max_try): @@ -296,6 +299,7 @@ def _list_with_retry(self, prefix: str = None): raise e else: raise e + assert False, "unreachable" def _find_files_impl(self, pattern: str, search_depth: int, search_root: str, max_count: int): search_dirs = [search_root] @@ -311,7 +315,7 @@ def _find_files_impl(self, pattern: str, search_depth: int, search_root: str, ma if fnmatch.fnmatch(path["name"], pattern): mdate = path.get("mdate", None) file_mdate_pairs.append((os.path.join(dirname, path["name"]), mdate)) - if type(max_count) == int and max_count > 0 and len(file_mdate_pairs) == max_count: + if isinstance(max_count, int) and max_count > 0 and len(file_mdate_pairs) == max_count: return file_mdate_pairs elif path["type"] == "dir": subdir = os.path.join(dirname, path["name"]) if dirname else path["name"] @@ -322,7 +326,7 @@ def _find_files_impl(self, pattern: str, search_depth: int, search_root: str, ma level += 1 return file_mdate_pairs - def find_files(self, pattern: str, search_depth: int, search_root: str = None, max_count=None, sort_by_mdate=True): + def find_files(self, pattern: str, search_depth: int, search_root: Optional[str] = None, max_count: Optional[int] = None, sort_by_mdate: bool = True) -> List[str]: file_mdate_pairs = self._find_files_impl(pattern, search_depth, search_root, max_count) if len(file_mdate_pairs) > 1 and sort_by_mdate: file_mdate_pairs.sort(key=lambda x: x[1]) @@ -330,61 +334,61 @@ def find_files(self, pattern: str, search_depth: int, search_root: str = None, m files = [x[0] for x in file_mdate_pairs] return files - def save_text(self, text: str, filename: str): + def save_text(self, text: str, filename: str) -> None: class TextStreamCreator: - def __init__(self, text): + def __init__(self, text: str): self._text = text - def create_stream(self): + def create_stream(self) -> BytesIO: stream = BytesIO() stream.write(bytes(self._text, "utf-8")) return stream self.upload_stream_to_file(TextStreamCreator(text), filename) - def save_object(self, obj: object, filename: str): + def save_object(self, obj: object, filename: str) -> None: class ObjectStreamCreator: - def __init__(self, obj): + def __init__(self, obj: object): self._obj = obj - def create_stream(self): + def create_stream(self) -> BytesIO: stream = BytesIO() torch.save(obj, stream) return stream self.upload_stream_to_file(ObjectStreamCreator(obj), filename) - def load_object(self, filename, map_location=None): + def load_object(self, filename: str, map_location: Optional[torch.serialization.MAP_LOCATION] = None) -> Any: stream: BytesIO = self.download_file_to_stream(filename) return torch.load(stream, map_location=map_location) - def create_dir(self, dirname: str, exist_ok: bool = True): + def create_dir(self, dirname: str, exist_ok: bool = True) -> 
None: """ s3 allow create files with common prefix at the same time, therefore nothing need to be done here. """ def create_shared_dir( - self, dirname: str, exist_ok: bool = True, process_group: torch.distributed.ProcessGroup = None - ): + self, dirname: str, exist_ok: bool = True, process_group: Optional[torch.distributed.ProcessGroup] = None + ) -> None: """ s3 allow create files with common prefix at the same time, therefore nothing need to be done here. """ - def remove_dir(self, dirname: str): + def remove_dir(self, dirname: str) -> None: key = self.convert_path_to_key(dirname) client = S3CheckpointStorage.get_client() S3CheckpointStorage.s3_action_with_retry( - S3CheckpointStorage.REMOVE_DIR, client, self._bucket, key, None - ) + S3CheckpointStorage.REMOVE_DIR, client, self._bucket, key, None + ) - def remove_file(self, filename: str): + def remove_file(self, filename: str) -> None: key = self.convert_path_to_key(filename) client = S3CheckpointStorage.get_client() S3CheckpointStorage.s3_action_with_retry( - S3CheckpointStorage.REMOVE_FILE, client, self._bucket, key, None - ) + S3CheckpointStorage.REMOVE_FILE, client, self._bucket, key, None + ) def upload_stream_to_file( self, stream_creator, filename: str, chunk_size_MB: int = 64, max_concurrency: int = 10 @@ -397,7 +401,7 @@ def upload_stream_to_file( S3CheckpointStorage.UPLOAD, client, self._bucket, key, config, upload_stream_creator=stream_creator ) - def convert_path_to_key(self, path: str): + def convert_path_to_key(self, path: str) -> str: return path if self._base_key is None else self._base_key + path def download_file_to_stream(self, filename: str, chunk_size_MB: int = 64, max_concurrency: int = 15) -> BytesIO: @@ -405,9 +409,7 @@ def download_file_to_stream(self, filename: str, chunk_size_MB: int = 64, max_co key = self.convert_path_to_key(filename) chunk_size = chunk_size_MB * 1048576 config = boto3.s3.transfer.TransferConfig(multipart_chunksize=chunk_size, max_concurrency=max_concurrency) - return S3CheckpointStorage.s3_action_with_retry( - S3CheckpointStorage.DOWNLOAD, client, self._bucket, key, config - ) + return S3CheckpointStorage.s3_action_with_retry(S3CheckpointStorage.DOWNLOAD, client, self._bucket, key, config) @staticmethod def parse_path(s3_path: str) -> Tuple[str, str]: @@ -415,7 +417,7 @@ def parse_path(s3_path: str) -> Tuple[str, str]: if not s3_path.startswith(head): raise RuntimeError(f"Error: invalid s3 path: {s3_path} because it does not start with {head}") - s3_path = s3_path[len(head) :] + s3_path = s3_path[len(head):] if len(s3_path) == 0: raise RuntimeError("Error: invalid s3 path: {s3_path} that is empty") @@ -426,40 +428,41 @@ def parse_path(s3_path: str) -> Tuple[str, str]: if first_slash == len(s3_path) - 1: return s3_path[0:-1], None - return s3_path[0:first_slash], s3_path[first_slash + 1 :] + return s3_path[0:first_slash], s3_path[first_slash + 1:] @staticmethod - def get_resource(profile: str = None, creds: botocore.credentials.Credentials = None, session=None, config={}): - config = botocore.config.Config(max_pool_connections=30, **config) + def get_resource( + profile: Optional[str] = None, + creds: Optional[botocore.credentials.Credentials] = None, + session: Optional[boto3.Session] = None, + config: Optional[Dict[str, Any]] = None, + ) -> "S3ServiceResource": + s3_config = botocore.config.Config(max_pool_connections=30, **(config or {})) if profile is not None and creds is not None: raise ValueError("Please provide profile or creds or neither, not both.") if profile is not None: - s3 = 
boto3.Session(profile_name=profile).resource("s3", config=config) + s3 = boto3.Session(profile_name=profile).resource("s3", config=s3_config) elif creds is not None: - s3 = boto3.Session().resource( + s3 = (session or boto3._get_default_session()).resource( "s3", - aws_access_key_id=creds["AccessKeyId"], - aws_secret_access_key=creds["SecretAccessKey"], - aws_session_token=creds["SessionToken"], - config=config, + aws_access_key_id=creds.access_key, + aws_secret_access_key=creds.secret_key, + aws_session_token=creds.token, + config=s3_config, ) else: - s3 = boto3.Session().resource("s3", config=config) if not session else session.resource("s3", config=config) + s3 = (session or boto3._get_default_session()).resource("s3", config=s3_config) return s3 @staticmethod - def get_client(profile: str = None, creds: botocore.credentials.Credentials = None, session=None, config={}): + def get_client(profile: Optional[str] = None, creds: Optional[botocore.credentials.Credentials] = None, session: Optional[boto3.Session] = None, config: Optional[Dict[str, Any]] = None) -> "S3Client": return S3CheckpointStorage.get_resource(profile, creds, session, config).meta.client @staticmethod - def is_slow_down_error(exception): - class_name = exception.__class__.__name__ - module_name = exception.__class__.__module__ - full_class_name = f"{module_name}.{class_name}" - + def is_slow_down_error(exception: Exception) -> bool: # Example Invalid response status that is slow down # AWS_ERROR_S3_INVALID_RESPONSE_STATUS: Invalid response status from request. # Body from error request is: @@ -485,8 +488,7 @@ def is_slow_down_error(exception): if isinstance(exception, botocore.exceptions.ConnectionClosedError): if ( - "Connection was closed before we received a valid response from endpoint" in message - and ".s3." in message + "Connection was closed before we received a valid response from endpoint" in message and ".s3." 
in message ): return True @@ -524,19 +526,17 @@ def s3_action_with_retry(action, client, bucket, key, config, upload_stream_crea elif action == S3CheckpointStorage.REMOVE_DIR: prefix = key if key.endswith("/") else key + "/" response = client.list_objects(Bucket=bucket, Prefix=prefix) - while 'Contents' in response: - objects = response['Contents'] + while "Contents" in response: + objects = response["Contents"] assert len(objects) > 0 - delete = {'Objects' : []} + delete = {"Objects": []} for obj in objects: - delete['Objects'].append( - {'Key' : obj['Key']} - ) + delete["Objects"].append({"Key": obj["Key"]}) client.delete_objects(Bucket=bucket, Delete=delete) response = client.list_objects(Bucket=bucket, Prefix=prefix) return else: - raise RuntimError(f"Error: unknow action {action}") + raise RuntimeError(f"Error: unknow action {action}") except Exception as e: if S3CheckpointStorage.is_slow_down_error(e): if try_idx < max_try - 1: diff --git a/src/neuronx_distributed/trainer/model.py b/src/neuronx_distributed/trainer/model.py index cf664b4..168d5c4 100644 --- a/src/neuronx_distributed/trainer/model.py +++ b/src/neuronx_distributed/trainer/model.py @@ -2,6 +2,8 @@ import torch +from neuronx_distributed.utils.model_utils import get_delay_tracing + class NxDModel(torch.nn.Module): def __init__(self, module, nxd_config): @@ -12,7 +14,9 @@ def __init__(self, module, nxd_config): self.pp_enabled = nxd_config["pipeline_parallel_size"] > 1 - self.train() + if not self.pp_enabled: + # When pp is enabled run_train() will handle this, so self.train() can be skipped during init + self.train() def __repr__(self): return "NxDModel({})".format(self.module.__repr__()) @@ -46,7 +50,7 @@ def forward(self, *args, **kwargs): return self.module(*args, **kwargs) def named_parameters(self, *args, **kwargs): - if self.pp_enabled: + if self.pp_enabled and not get_delay_tracing(self.nxd_config): for n, p in self.module.local_named_parameters(*args, **kwargs): yield n, p return diff --git a/src/neuronx_distributed/trainer/optimizer.py b/src/neuronx_distributed/trainer/optimizer.py index eaff32a..ac5ec70 100644 --- a/src/neuronx_distributed/trainer/optimizer.py +++ b/src/neuronx_distributed/trainer/optimizer.py @@ -55,7 +55,9 @@ def add_param_group(self, param_group): self.optimizer.add_param_group(param_group) def state_dict(self): - return self.optimizer.state_dict() + state_dict = self.optimizer.state_dict() + state_dict = self._mark_expert_parallel_states(state_dict) + return state_dict def load_state_dict(self, state_dict): self.optimizer.load_state_dict(state_dict) @@ -69,6 +71,48 @@ def __getstate__(self): def __setstate__(self, state): self.optimizer.__setstate__(state) + + def _mark_expert_parallel_states(self, state_dict): + if state_dict is None: + return None + + ep_ids = set() + idx = 0 + param_set = set() + for param_group in self.__getstate__()["param_groups"]: + for group, params in param_group.items(): + if group == "params": + for p in params: + if isinstance(p, torch.Tensor) and hasattr(p, "expert_model_parallel") and p.expert_model_parallel: + if id(p) not in param_set: + ep_ids.add(idx) + idx += 1 + param_set.add(id(p)) + + + for id_p, param_state_dict in state_dict["state"].items(): + if id_p in ep_ids: + for state_key in param_state_dict: + param_state_dict[state_key].expert_model_parallel = True + + return state_dict + + + def _fetch_gradients(self): + gradients = [] + ep_gradients = [] + for param_group in self.optimizer.__getstate__()["param_groups"]: + for group, params in 
param_group.items(): + if group == "params": + for p in params: + if isinstance(p, torch.Tensor) and p.grad is not None: + if hasattr(p, "expert_model_parallel") and p.expert_model_parallel: + ep_gradients.append(p.grad.data) + else: + gradients.append(p.grad.data) + + return gradients, ep_gradients + def step(self, closure=None): # sequence parallel all-reduce if self.nxd_config["sequence_parallel"]: @@ -76,7 +120,16 @@ def step(self, closure=None): optimizer_config = self.nxd_config["optimizer_config"] if not optimizer_config["zero_one_enabled"]: - grads.bucket_allreduce_gradients(xm._fetch_gradients(self)) + non_ep_gradients, ep_gradients = self._fetch_gradients() + grads.bucket_allreduce_gradients(non_ep_gradients + ep_gradients) + if len(ep_gradients) > 0: + # initial allreduce takes place over the expert data parallel group + # which coincides with data parallel group when ep is disabled. when ep + # is enabled, non-ep gradients would additionally need to be reduced + # over the expert model parallel groups (since ep happens over dp ranks). + # non-ep gradient reduction needs to take place separately over emp/edp + # groups (in two separate steps) to side-step the MPMD limitation in runtime. + grads.bucket_allreduce_gradients(non_ep_gradients, reduce_over_ep_group=True) if optimizer_config["grad_clipping"]: self._grad_norm = grads.clip_grad_norm(self.params, optimizer_config["max_grad_norm"]) ret = self.optimizer.step(closure=closure) diff --git a/src/neuronx_distributed/trainer/post_partition_hooks.py b/src/neuronx_distributed/trainer/post_partition_hooks.py new file mode 100644 index 0000000..e44eec9 --- /dev/null +++ b/src/neuronx_distributed/trainer/post_partition_hooks.py @@ -0,0 +1,35 @@ +from functools import partial +from typing import Any, Callable + + +class PostPartitionHooks: + def __init__( + self, + ): + self.hooks = [] + + def register_post_partition_hook(self, callable_function: Callable[..., Any], func_args=(), func_kwargs={}): + if not callable(callable_function): + raise ValueError("callable_function must be a callable object") + + self.hooks.append( + { + "function": partial(callable_function, *func_args, **func_kwargs), + "name": callable_function.__name__, + } + ) + + def execute_all_hooks(self, model=None): + hook_outputs = [] + for hook in self.hooks: + func = hook["function"] + name = hook["name"] + if name == "filter_to_local_parameter_group": + assert model is not None, "When executing filter_to_local_parameter_group hook, model object cannot be None" + hook_outputs.append(func(model=model)) + else: + hook_outputs.append(func()) + + # Finally, clear all hooks + self.hooks.clear() + return hook_outputs diff --git a/src/neuronx_distributed/trainer/trainer.py b/src/neuronx_distributed/trainer/trainer.py index 1b0bf42..abb8016 100644 --- a/src/neuronx_distributed/trainer/trainer.py +++ b/src/neuronx_distributed/trainer/trainer.py @@ -5,8 +5,8 @@ import torch_xla.core.xla_model as xm from packaging import version +from neuronx_distributed.optimizer import NeuronZero1Optimizer, NeuronEPZero1Optimizer from neuronx_distributed.modules.lora import LoraConfig, LoraModel -from neuronx_distributed.optimizer import NeuronZero1Optimizer from neuronx_distributed.parallel_layers import parallel_state from neuronx_distributed.parallel_layers.pad import pad_model from neuronx_distributed.pipeline import NxDPPModel @@ -20,15 +20,19 @@ get_model_sequential, init_on_device, is_hf_pretrained_model, - is_nxtt_pretrained_model, + is_nxdt_pretrained_model, + check_delay_tracing, + 
get_delay_tracing ) + logger = get_logger() def neuronx_distributed_config( tensor_parallel_size=1, pipeline_parallel_size=1, + expert_parallel_size=1, pipeline_config=None, optimizer_config=None, activation_checkpoint_config=None, @@ -110,6 +114,7 @@ def neuronx_distributed_config( config = { "tensor_parallel_size": tensor_parallel_size, "pipeline_parallel_size": pipeline_parallel_size, + "expert_parallel_size": expert_parallel_size, "pipeline_config": pipeline_config, "optimizer_config": optimizer_config, "activation_checkpoint_config": activation_checkpoint_config, @@ -124,6 +129,7 @@ def neuronx_distributed_config( parallel_state.initialize_model_parallel( tensor_model_parallel_size=config["tensor_parallel_size"], pipeline_model_parallel_size=config["pipeline_parallel_size"], + expert_model_parallel_size=config["expert_parallel_size"], ) if torch.distributed.is_initialized() and parallel_state.is_global_rank_zero(): @@ -136,6 +142,7 @@ def initialize_parallel_model(nxd_config, model_fn, *model_args, **model_kwargs) parallel_state.initialize_model_parallel( tensor_model_parallel_size=nxd_config["tensor_parallel_size"], pipeline_model_parallel_size=nxd_config["pipeline_parallel_size"], + expert_model_parallel_size=nxd_config["expert_parallel_size"], ) # Phase 1: get the base model @@ -157,8 +164,10 @@ def initialize_parallel_model(nxd_config, model_fn, *model_args, **model_kwargs) if pp_enabled: nxd_config["pipeline_config"].update({"param_init_fn": param_init_fn}) nxd_config["pipeline_config"].update({"use_model_wrapper": True}) + nxd_config["pipeline_config"]["_delay_tracing"] = check_delay_tracing(nxd_config) model = NxDPPModel(model, **nxd_config["pipeline_config"]) + # Phase 3: materialize model and move to device sequential_move_factor = nxd_config["model_init_config"]["sequential_move_factor"] @@ -192,16 +201,16 @@ def initialize_parallel_model(nxd_config, model_fn, *model_args, **model_kwargs) for name in base_model._no_split_modules: activation_checkpoint_classes.append(get_module_class_from_name(nxd_model, name)) - elif is_nxtt_pretrained_model(base_model): - # Toolkit transformer layer will always be this type - from neuronx_training_toolkit.models.megatron.transformer import ( + elif is_nxdt_pretrained_model(base_model): + # NxDT transformer layer will always be this type + from neuronx_distributed_training.models.megatron.transformer import ( ParallelTransformerLayer, ) activation_checkpoint_classes = [ParallelTransformerLayer] else: raise RuntimeError( - '`activation_checkpoint_config` "full" is only supported for huggingface transformers or nxtt models.' + '`activation_checkpoint_config` "full" is only supported for huggingface transformers or nxdt models.' 
) else: @@ -220,6 +229,12 @@ def initialize_parallel_model(nxd_config, model_fn, *model_args, **model_kwargs) def initialize_parallel_optimizer(nxd_config, optimizer_class, parameters, **defaults): + optimizer = initialize_optimizer_from_class(nxd_config, optimizer_class, parameters, **defaults) + nxd_optim = NxDOptimizer(optimizer, nxd_config) + return nxd_optim + + +def initialize_optimizer_from_class(nxd_config, optimizer_class, parameters, model=None, **defaults): optimizer_config = nxd_config["optimizer_config"] mixed_precision_config = nxd_config["mixed_precision_config"] if optimizer_config["zero_one_enabled"]: @@ -233,13 +248,17 @@ def initialize_parallel_optimizer(nxd_config, optimizer_class, parameters, **def zero1_configs.update({"coalesce_cc": True}) if mixed_precision_config["use_master_weights"]: if "XLA_DOWNCAST_BF16" in os.environ and os.environ["XLA_DOWNCAST_BF16"] == "1": - zero1_configs.update({"optimizer_dtype": torch.double}) + defaults.update({"optimizer_dtype": torch.double}) if mixed_precision_config["use_fp32_grad_acc"]: zero1_configs.update({"use_grad_acc_hook": True, "higher_cc_precision": True}) zero1_configs.update( {"save_master_weights": True if mixed_precision_config["use_master_weights_in_ckpt"] else False} ) - optimizer = NeuronZero1Optimizer( + if get_delay_tracing(nxd_config): + defaults["lazy_init"] = True + logger.info("printing defaults here %s", defaults) + logger.info("printing zero1 config here %s", zero1_configs) + optimizer = zero1_optimizer_cls( parameters, optimizer_class, **zero1_configs, @@ -254,5 +273,28 @@ def initialize_parallel_optimizer(nxd_config, optimizer_class, parameters, **def "Non Zero-1 optimizer does not support `use_fp32_grad_acc` of `use_master_weights_in_ckpt`." ) optimizer = optimizer_class(parameters, **defaults) + if get_delay_tracing(nxd_config): + from neuronx_distributed.trainer import hooks + + hooks.register_post_partition_hook(filter_to_local_parameter_group, [optimizer]) + return optimizer + + +""" +During the delayed tracing and partition flow, the optimizer gets initialized with all parameters, as the model has not +yet been moved to the device. When the model is moved to device on xla new tensors are created on device. +We filter the parameters to those in model.local_parameters() once the model is moved onto device. +""" + - return NxDOptimizer(optimizer, nxd_config) +def filter_to_local_parameter_group(optimizer, model): + parameters = optimizer.param_groups + for param_group in parameters: + filtered_param_list = [] + for meta_param in param_group["params"]: + if meta_param in model.meta_device_parameter_map: + meta_or_xla_param = model.meta_device_parameter_map[meta_param] + if meta_or_xla_param.device.type == "xla": + filtered_param_list.append(meta_or_xla_param) + param_group["params"] = filtered_param_list + return diff --git a/src/neuronx_distributed/utils/activation_checkpoint.py b/src/neuronx_distributed/utils/activation_checkpoint.py index fa8daed..b72830b 100644 --- a/src/neuronx_distributed/utils/activation_checkpoint.py +++ b/src/neuronx_distributed/utils/activation_checkpoint.py @@ -64,7 +64,9 @@ def apply_activation_checkpointing( rmsg("Using NxDPPModel as input, `check_fn` will be ignored. 
Only transformer layer will be wrapped.") ) activation_checkpoint_module = model.transformer_layer_cls - check_fn = lambda m: isinstance(m, activation_checkpoint_module) + + def check_fn(m): + return isinstance(m, activation_checkpoint_module) models = model.local_stage_modules elif isinstance(model, NxDModel): models = model.local_modules() diff --git a/src/neuronx_distributed/utils/logger.py b/src/neuronx_distributed/utils/logger.py index 9694f74..8f78ba0 100644 --- a/src/neuronx_distributed/utils/logger.py +++ b/src/neuronx_distributed/utils/logger.py @@ -2,42 +2,41 @@ import logging import os import sys +from functools import lru_cache, wraps +from typing import Any, Optional, Callable, TypeVar +from typing_extensions import ParamSpec -_logger_initialized = False -_log_level = None +# Third Party +import torch +T = TypeVar("T") +P = ParamSpec("P") -def get_log_level(): - global _log_level - if _log_level is None: - default = "info" - log_level = os.environ.get("NXD_LOG_LEVEL", default=default) - log_level = log_level.lower() - # allowed_levels = ["info", "trace", "debug", "warning", "error", "fatal", "off"] - if log_level == "off": - level = logging.FATAL + 1 - elif log_level == "fatal": - # fatal added so that log level can take same values for cpp and py - # fatal in cpp exceptions kills the process - # so use fatal for that only - level = logging.FATAL - elif log_level == "error": - level = logging.ERROR - elif log_level == "warning": - level = logging.WARNING - elif log_level == "info": - level = logging.INFO - elif log_level in ["debug", "trace"]: - level = logging.DEBUG - else: - level = logging.INFO - _log_level = level - return _log_level +@lru_cache +def get_log_level() -> int: + """Get the log level as configured or the default""" + log_level = os.environ.get("NXD_LOG_LEVEL", default="info").lower() + if log_level == "off": + return logging.FATAL + 1 + if log_level == "fatal": + # fatal added so that log level can take same values for cpp and py + # fatal in cpp exceptions kills the process + # so use fatal for that only + return logging.FATAL + if log_level == "error": + return logging.ERROR + if log_level == "warning": + return logging.WARNING + if log_level == "info": + return logging.INFO + if log_level in ["debug", "trace"]: + return logging.DEBUG + raise ValueError(f"Allowed NXD_LOG_LEVELS are: info, trace, debug, warning, error, fatal, off. 
Got: {log_level}") class PackagePathFilter(logging.Filter): - def filter(self, record): + def filter(self, record: Any) -> bool: pathname = record.pathname record.relativepath = None abs_sys_paths = map(os.path.abspath, sys.path) @@ -50,33 +49,64 @@ def filter(self, record): return True -def get_logger(name="neuronx_distributed"): - global _logger_initialized - if not _logger_initialized: - level = get_log_level() - logger = logging.getLogger(name) - hide_time = os.getenv("NXD_LOG_HIDE_TIME", "False") +def get_logger(name: str = "neuronx_distributed", rank0_only: bool = True) -> logging.Logger: + logger = logging.getLogger(name + ("[0]" if rank0_only else "[]")) + if getattr(logger, "initialized", False): + return logger # already configured - fmt = "[" - if hide_time.lower() in ["true", "1"]: - hide_time = True - else: - hide_time = False - fmt += "%(asctime)s.%(msecs)03d: " - logger.handlers = [] - log_formatter = logging.Formatter( - fmt=fmt + "%(levelname).1s %(relativepath)s:%(lineno)d] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(log_formatter) - stdout_handler.addFilter(PackagePathFilter()) - logger.addHandler(stdout_handler) + hide_time = os.getenv("NXD_LOG_HIDE_TIME", "false").lower() + time = "" if hide_time in ["true", "1"] else "%(asctime)s.%(msecs)03d: " + log_formatter = logging.Formatter( + fmt=f"[{time}%(levelname).1s %(relativepath)s:%(lineno)d] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(log_formatter) + stdout_handler.addFilter(PackagePathFilter()) + logger.handlers = [stdout_handler] # overwrite - if level: - logger.setLevel(level) - else: - logger.disabled = True - _logger_initialized = True - logger.propagate = False - return logging.getLogger(name) + level = get_log_level() + if level <= logging.FATAL: + logger.setLevel(level) + if rank0_only: + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ): + setattr(logger, level, _rank0_only(getattr(logger, level))) + else: + logger.disabled = True + logger.propagate = False + logger.initialized = True + return logger + + +def _rank0_only(fn: Callable[P, T], default: Optional[T] = None, **extra_kwargs: P.kwargs) -> Callable[P, Optional[T]]: + """Wrap a logging.Logger function to call internal function only in rank zero. + Function that can be used as a decorator to enable a function/method being called only on global rank 0. 
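A minimal usage sketch of the rank-zero-only logger introduced above, assuming each worker process has the standard `RANK` environment variable set (as `torchrun` does); `torch.distributed` does not need to be initialized because the wrapper falls back to `RANK`. Illustrative only, not part of the patch:

```
# Illustrative usage sketch (not part of the patch).
import os

from neuronx_distributed.utils.logger import get_logger

logger = get_logger()                     # rank0_only=True by default
logger.info("emitted on global rank 0 only")

per_rank_logger = get_logger("neuronx_distributed", rank0_only=False)
per_rank_logger.info("emitted on every rank (RANK=%s)", os.environ.get("RANK", "0"))
```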
+ + Arguments: + fn: function to decorate + default: value to return when the global rank is not 0 + Returns: + Decorated function + """ + + @wraps(fn) + def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else int(os.environ.get("RANK", "0")) + if rank == 0: + # excluding the wrapper from calling stack for logger to use + # see https://docs.python.org/3/library/logging.html#logging.Logger.findCaller + kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + 1 + return fn(*args, **kwargs) + return default + + return wrapped_fn diff --git a/src/neuronx_distributed/utils/medusa_utils.py b/src/neuronx_distributed/utils/medusa_utils.py new file mode 100644 index 0000000..bc75b3f --- /dev/null +++ b/src/neuronx_distributed/utils/medusa_utils.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F + +TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient) + +def pad_path(path, length, pad_value=-2): + """ + Pad the given path list with a specific value up to a specified length. + + Parameters: + - path (list): The original list that needs padding. + - length (int): The desired length of the padded list. + - pad_value (optional, default=-2): The value to use for padding. + + Returns: + - list: A new list based on the original path but padded to the desired length. + + Example: + >>> pad_path([1,2,3], 5) + [1, 2, 3, -2, -2] + + Note: + If the given path is already longer than the specified length, + then no padding occurs, and the original path is returned. + """ + + # Calculate the number of padding values needed by subtracting the length + # of the path from the desired length. + # Append the padding values to the original path and return the new list. + return path + [pad_value] * (length - len(path)) + +def generate_medusa_buffers(medusa_choices): + """ + Generate buffers for the Medusa structure based on the provided choices. + + Parameters: + - medusa_choices (list): A nested list representing tree in the Medusa structure. + - device (str): Device to which the tensors should be moved. Default is "cuda". + + Returns: + - dict: A dictionary containing buffers related to the Medusa structure. 
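To make the tree encoding concrete, here is a standalone sketch using a made-up three-node `medusa_choices` tree (illustrative only, not part of the patch; exact buffer values depend on the tree, so only shapes are asserted):

```
# Illustrative sketch (not part of the patch). Each entry of medusa_choices is a
# root-to-node path of top-k slot indices, e.g. [0, 0] means "top-1 of head 1
# following top-1 of head 0". Three nodes plus the root give medusa_len = 4.
import torch

from neuronx_distributed.utils.medusa_utils import generate_medusa_buffers, pad_path

assert pad_path([1, 2, 3], 5) == [1, 2, 3, -2, -2]   # example from the docstring above

buffers = generate_medusa_buffers([[0], [1], [0, 0]])
assert buffers["medusa_attn_mask"].shape == (4, 4)   # one row/column per tree node + root
print(buffers["tree_indices"])        # positions of the tree nodes in the flattened top-k candidates
print(buffers["medusa_position_ids"]) # depth of each node, used as position-id offsets
print(buffers["retrieve_indices"])    # one (padded) row per root-to-leaf path, for verification
```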
+ """ + + # Sort the medusa_choices based on their lengths and then their values + sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x)) + medusa_len = len(sorted_medusa_choices) + 1 + + # Initialize depth_counts to keep track of how many choices have a particular depth + depth_counts = [] + prev_depth = 0 + for path in sorted_medusa_choices: + depth = len(path) + if depth != prev_depth: + depth_counts.append(0) + depth_counts[depth - 1] += 1 + prev_depth = depth + + # Create the attention mask for Medusa + medusa_attn_mask = torch.eye(medusa_len, medusa_len) + medusa_attn_mask[:, 0] = 1 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_medusa_choice = sorted_medusa_choices[start + j] + # retrieve ancestor position + if len(cur_medusa_choice) == 1: + continue + ancestor_idx = [] + for c in range(len(cur_medusa_choice) - 1): + ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1) + medusa_attn_mask[j + start + 1, ancestor_idx] = 1 + start += depth_counts[i] + + # Generate tree indices for the Medusa structure + medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long) + medusa_tree_indices[0] = 0 + start = 0 + for i in range(len(depth_counts)): + for j in range(depth_counts[i]): + cur_medusa_choice = sorted_medusa_choices[start + j] + medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1 + start += depth_counts[i] + + # Generate position IDs for the Medusa structure + medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long) + start = 0 + for i in range(len(depth_counts)): + medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1 + start += depth_counts[i] + + # Generate retrieval indices for Medusa structure verification + retrieve_indices_nest = [] + retrieve_paths = [] + for i in range(len(sorted_medusa_choices)): + cur_medusa_choice = sorted_medusa_choices[-i-1] + retrieve_indice = [] + if cur_medusa_choice in retrieve_paths: + continue + else: + for c in range(len(cur_medusa_choice)): + retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1])) + retrieve_paths.append(cur_medusa_choice[:c+1]) + retrieve_indices_nest.append(retrieve_indice) + max_length = max([len(x) for x in retrieve_indices_nest]) + retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest] + retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) + retrieve_indices = retrieve_indices + 1 + retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1) + + # Aggregate the generated buffers into a dictionary + medusa_buffers = { + "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0)[0][0], + "tree_indices": medusa_tree_indices, + "medusa_position_ids": medusa_position_ids, + "retrieve_indices": retrieve_indices, + } + + return medusa_buffers + +def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices): + """ + Generate candidates based on provided logits and indices. + + Returns: + - tuple (torch.Tensor, torch.Tensor): A tuple containing two sets of candidates: + 1. Cartesian candidates derived from the combined original and Medusa logits. + 2. Tree candidates mapped from the Cartesian candidates using tree indices. + """ + # Greedy decoding: Select the most probable candidate from the original logits. 
+ + candidates_logit=logits.squeeze(0).squeeze(0)[:1] + candidates_medusa_logits=medusa_logits.squeeze(1).squeeze(1) + + # Combine the selected candidate from the original logits with the topk medusa logits. + candidates = torch.cat([candidates_logit, candidates_medusa_logits.contiguous().view(-1)], dim=-1) + + # Map the combined candidates to the tree indices to get tree candidates. + tree_candidates = candidates[tree_indices] + + # Extend the tree candidates by appending a zero. + tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device)], dim=0) + + # Retrieve the cartesian candidates using the retrieve indices. + cart_candidates = tree_candidates_ext[retrieve_indices] + + # Unsqueeze the tree candidates for dimension consistency. + tree_candidates = tree_candidates.unsqueeze(0) + + return cart_candidates, tree_candidates + +def evaluate_posterior(logits, candidates): + """ + Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. + + Depending on the temperature value, the function either uses greedy decoding or evaluates posterior + probabilities to select the best candidate. + + Returns: + - best_candidate (torch.Tensor): Index of the chosen best candidate. + - accept_length (int): Length of the accepted candidate sequence. + """ + # Find the tokens that match the maximum logits for each position in the sequence + posterior_mask = ( + candidates[:, 1:] == logits[:,:-1,:1].squeeze(2).long()).int() + candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) + accept_length = candidates_accept_length.max() + # Choose the best candidate + if accept_length == 0: + # Default to the first candidate if none are accepted + best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) + else: + best_candidate = torch.argmax(candidates_accept_length).to(torch.long) + return best_candidate, accept_length + +def update_inference_inputs( + input_ids, + candidates, + best_candidate, + accept_length, + retrieve_indices, + outputs, + logits, + medusa_logits, + new_token, +): + """ + Update the input sequences and relevant tensors based on the selected best candidate from the inference results. + + Returns: + - input_ids (torch.Tensor): Updated input token sequences. + - logits (torch.Tensor): Updated logits. + - medusa_logits (torch.Tensor): Updated medusa logits. + - new_token (int): Updated counter for the new tokens added. 
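A small numeric sketch of the index bookkeeping performed here (made-up values, illustrative only): with `prev_input_len = 10`, an accepted length of 1, and a best path whose retrieve row is `[0, 1, 3]`, the two committed tokens land at flat positions 10 and 11.

```
# Illustrative sketch (not part of the patch), mirroring the select_indices
# computation in update_inference_inputs below.
import torch

retrieve_indices = torch.tensor([[0, 1, 3]])
best_candidate, accept_length, prev_input_len = 0, 1, 10

select_indices = retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
print(select_indices)   # tensor([10, 11])
```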
+ - selct_indices + """ + # Calculate the starting position for new tokens based on the previous input length + prev_input_len = input_ids.shape[1] + # Map the best candidate indices to the original indices in the sequence + select_indices = ( + retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len + ) + # Append the tokens from the best candidate to the input sequence + input_ids = torch.cat( + [input_ids, candidates[None, best_candidate, : accept_length + 1]], dim=-1 + ) + logits = logits[None, best_candidate, accept_length : accept_length + 1] + medusa_logits = medusa_logits[ + :, None, best_candidate, accept_length : accept_length + 1 + ] + # Update the new token counter + new_token += accept_length + 1 + + return input_ids, logits, medusa_logits, new_token, select_indices \ No newline at end of file diff --git a/src/neuronx_distributed/utils/model_utils.py b/src/neuronx_distributed/utils/model_utils.py index 535bb39..0a50c16 100644 --- a/src/neuronx_distributed/utils/model_utils.py +++ b/src/neuronx_distributed/utils/model_utils.py @@ -34,11 +34,11 @@ except ImportError: _TORCHDISTX_AVAIL = False -_NXTT_AVAIL = True +_NXDT_AVAIL = True try: - from neuronx_training_toolkit.models.megatron.module import MegatronModule + from neuronx_distributed_training.models.megatron.module import MegatronModule except ImportError: - _NXTT_AVAIL = False + _NXDT_AVAIL = False def analyze_shared_parameters(module, shared_parameters=None, prefix=""): @@ -99,11 +99,39 @@ def is_hf_transformers_available(): def is_hf_accelerate_available(): return _Accelerate_AVAIL -def is_nxtt_pretrained_model(model): - return _NXTT_AVAIL and isinstance(model, MegatronModule) +def is_nxdt_pretrained_model(model): + return _NXDT_AVAIL and isinstance(model, MegatronModule) + +def is_nxdt_available(): + return _NXDT_AVAIL + +def recursive_filter(item, predicate): + """ Filter a structure containing tensors based on the given predicate """ + + def _is_tensor_or_parameter(obj): + return isinstance(obj, (torch.Tensor, nn.Parameter)) + + def _augmented_predicate(obj): + return predicate(obj) if _is_tensor_or_parameter(obj) else True + + if isinstance(item, dict): + out = {} + for k, v in item.items(): + if _augmented_predicate(v): + out[k] = recursive_filter(v, predicate) + elif isinstance(item, (list, tuple, set)): + out = [] + for x in item: + if _augmented_predicate(x): + out.append(recursive_filter(x, predicate)) + out = type(item)(out) + else: + # under normal circumstances this should not return None, unless + # there is an unexpected data structure involved + out = item if _augmented_predicate(item) else None + + return out -def is_nxtt_available(): - return _NXTT_AVAIL @contextmanager @@ -128,10 +156,12 @@ def preserve_parallel_attributes(model: torch.nn.Module) -> None: """ Preserve the following 3 attributes for the model parameters - tensor_model_parallel + - expert_model_parallel - sequence_parallel_enabled - shared """ tp_params = {} + ep_params = {} seq_parallel_params = {} shared_parameters = {} for name, param in model.named_parameters(): @@ -141,6 +171,8 @@ def preserve_parallel_attributes(model: torch.nn.Module) -> None: "partition_dim": param.partition_dim, "stride": param.partition_stride, } + if hasattr(param, "expert_model_parallel"): + ep_params[name] = param.expert_model_parallel if hasattr(param, "sequence_parallel_enabled"): seq_parallel_params[name] = param.sequence_parallel_enabled if hasattr(param, "shared"): @@ -151,6 +183,8 @@ def preserve_parallel_attributes(model: 
torch.nn.Module) -> None: for name, param in model.named_parameters(): if name in tp_params and not hasattr(param, "tensor_model_parallel"): set_tensor_model_parallel_attributes(param, *tp_params[name].values()) + if name in ep_params and not hasattr(param, "expert_model_parallel"): + setattr(param, "expert_model_parallel", ep_params[name]) if name in seq_parallel_params and not hasattr(param, "sequence_parallel_enabled"): setattr(param, "sequence_parallel_enabled", seq_parallel_params[name]) if name in shared_parameters and not hasattr(param, "shared"): @@ -300,3 +334,15 @@ def get_model_sequential(model, device, sequential_move_factor=11, param_init_fn move_model_to_device(model, device) xm.rendezvous("get_model_sequential" + str(worker)) return model + + +def check_delay_tracing(nxd_config): + # Temporarily disabling delayed tracing while we investigate some issues + # TODO re-enable once the issues with delayed tracing are resolved + return False + + +def get_delay_tracing(arg): + # Temporarily disabling delayed tracing while we investigate some issues + # TODO re-enable once the issues with delayed tracing are resolved + return False diff --git a/src/neuronx_distributed/utils/sampling.py b/src/neuronx_distributed/utils/sampling.py index 6954fd9..c5a01a5 100644 --- a/src/neuronx_distributed/utils/sampling.py +++ b/src/neuronx_distributed/utils/sampling.py @@ -11,7 +11,11 @@ class Sampler: def __init__(self, config: PretrainedConfig): self.on_device_sampling = config.on_device_sampling - if config.do_sample == True and config.num_beams == 1: + if hasattr(config, "is_medusa"): + self.is_medusa = config.is_medusa + else: + self.is_medusa = False + if config.do_sample and config.num_beams == 1: self.top_k = config.top_k self.sampling_method = self.multinomial else: @@ -68,4 +72,6 @@ def multinomial(self, token_logits): # count negative values to find index of sampled value counts = count_nonzero((diffs < 0), dim=dim) # return token indeces + if self.is_medusa: + return top_k_logits_indices return gather(input=top_k_logits_indices, dim=dim, index=counts.unsqueeze(1)).flatten() diff --git a/src/neuronx_distributed/utils/serialization.py b/src/neuronx_distributed/utils/serialization.py index 042f135..334357b 100644 --- a/src/neuronx_distributed/utils/serialization.py +++ b/src/neuronx_distributed/utils/serialization.py @@ -30,7 +30,7 @@ def uncompress_from_string(serialized_data_string): def is_instance_namedtuple(iterable): - return isinstance(iterable, tuple) and iterable.__class__.__base__ == tuple and hasattr(iterable, "_fields") + return isinstance(iterable, tuple) and isinstance(iterable.__class__.__base__, tuple) and hasattr(iterable, "_fields") def find_loss_from_output_and_spec(output_val, spec_val): diff --git a/src/neuronx_distributed/utils/speculative_decoding.py b/src/neuronx_distributed/utils/speculative_decoding.py index 808966f..83140b6 100644 --- a/src/neuronx_distributed/utils/speculative_decoding.py +++ b/src/neuronx_distributed/utils/speculative_decoding.py @@ -4,12 +4,19 @@ import torch from transformers.generation.stopping_criteria import StoppingCriteriaList +from neuronx_distributed.utils.medusa_utils import ( + evaluate_posterior, + generate_candidates, + generate_medusa_buffers, + update_inference_inputs, +) + class NeuronSpeculation: - def assisted_decoding( + def _assisted_decoding( self, input_ids: torch.LongTensor, - assistant_model: "PreTrainedModel", + candidate_generator: "CandidateGenerator", #noqa do_sample: bool = False, stopping_criteria: 
Optional[StoppingCriteriaList] = None, pad_token_id: Optional[int] = None, @@ -19,6 +26,22 @@ def assisted_decoding( if do_sample: raise ValueError("Sampling is unsupported as part of speculation. Only greedy speculation is supported.") + assistant_model = candidate_generator.assistant_model + if self.config.is_medusa: + # TODO: move this to sampling + return self._medusa_assisted_decoding( + input_ids, assistant_model, stopping_criteria, pad_token_id, eos_token_id, **model_kwargs + ) + else: + return self._standard_assisted_decoding( + input_ids, assistant_model, stopping_criteria, pad_token_id, eos_token_id, **model_kwargs + ) + + def _standard_assisted_decoding( + self, input_ids, assistant_model, stopping_criteria, pad_token_id, eos_token_id, **model_kwargs + ): + # Implementation of standard assisted decoding + # Initialize the num_assistant_tokens used for speculation. if hasattr(assistant_model, "num_assistant_tokens"): num_assistant_tokens = assistant_model.num_assistant_tokens @@ -34,11 +57,10 @@ def assisted_decoding( if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + assistant_model = candidate_generator.assistant_model # Prepare assistant model's keys of inputs assistant_kwargs = copy.deepcopy(model_kwargs) - input_ids_key = "input_ids" - attention_key = "attention_mask" # Other auxiliary variables max_len = stopping_criteria[0].max_length @@ -54,8 +76,6 @@ def assisted_decoding( # Prepare the input ids and attention mask for the draft model candidate_input_ids = input_ids - candidate_input_ids[0, curr_pos + 1] = new_token - assistant_kwargs["attention_mask"][0, curr_pos + 1] = 1 # This is the finally return outputs; append the first generated token returned_ids = torch.cat((input_ids[:, : curr_pos + 1], new_token), dim=1) @@ -63,7 +83,7 @@ def assisted_decoding( # Speculation loop while True: # 1 Token generation using draft model - for _ in range(int(spec_len - 1)): + for _ in range(int(num_assistant_tokens)): # 1.1 Prepare assistant model inputs assistant_inputs = assistant_model.prepare_inputs_for_generation( candidate_input_ids, @@ -146,9 +166,13 @@ def assisted_decoding( break input_ids = valid_tokens[:, -1:] candidate_input_ids = valid_tokens[:, -1:] - model_inputs["attention_mask"] = model_inputs["attention_mask"].index_fill( + model_inputs_attn_mask = model_inputs["attention_mask"] + n_matches_concat_tensor = torch.zeros(1, n_matches + 1, dtype=model_inputs_attn_mask.dtype) + model_inputs_attn_mask = torch.cat([model_inputs_attn_mask, n_matches_concat_tensor], dim=-1) + model_inputs["attention_mask"] = model_inputs_attn_mask.index_fill( 1, torch.arange(curr_pos + 1, curr_pos + 1 + n_matches + 1), 1 ) + curr_pos = curr_pos + n_matches + 1 assistant_kwargs["attention_mask"] = copy.deepcopy(model_inputs["attention_mask"]) @@ -157,8 +181,125 @@ def assisted_decoding( if cur_len >= max_len: break # 8. 
If the rest length is smaller than speculation length, we directly run the target model to finish - if max_len - cur_len < spec_len - 1: + if max_len - cur_len < spec_len: # @yihsian: TODO: complete with using target tokengen model break return returned_ids + + def _medusa_assisted_decoding( + self, input_ids, assistant_model, stopping_criteria, pad_token_id, eos_token_id, **model_kwargs + ): + medusa_kwargs = copy.deepcopy(model_kwargs) + + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + + mc_sim_7b_63 = self.config.medusa_tree + + medusa_buffers = generate_medusa_buffers(mc_sim_7b_63) + + model_inputs = self.prepare_inputs_for_generation(input_ids, **medusa_kwargs) + + outputs = self(**model_inputs) + + non_zero_input_ids = input_ids.nonzero() + cur_len = torch.tensor([non_zero_input_ids.size(0)], dtype=torch.int64) + + logits, medusa_logits = self._extract_logits(outputs) + + medusa_logits = medusa_logits[:, :, None, :] + + accept_length = 0 + final_accept_length = 0 + new_token = 0 + accept_lengths_tree = [] + cur_length = cur_len[0].item() + 1 + accept_lengths_tree.append(1) + count = 0 + select_indices = torch.arange( + cur_len[0].item(), cur_len[0].item() + self.config.num_medusa_heads + 1, dtype=torch.int64 + ) + + for i in range(self.config.max_new_tokens): + count = count + 1 + candidates, tree_candidates = generate_candidates( + medusa_logits, + logits, + medusa_buffers["tree_indices"], + medusa_buffers["retrieve_indices"], + ) + position_ids = medusa_buffers["medusa_position_ids"] + input_ids.nonzero().shape[0] + + medusa_kwargs = self._prepare_medusa_kwargs( + position_ids, cur_len, medusa_buffers, select_indices, medusa_kwargs + ) + + tree_candidates = tree_candidates.long() + + model_inputs = self.prepare_medusa_inputs_for_generation(tree_candidates, **medusa_kwargs) + + outputs = self(**model_inputs) + + tree_logits, tree_medusa_logits = self._extract_logits(outputs) + + logits = tree_logits[0, medusa_buffers["retrieve_indices"]] + medusa_logits = tree_medusa_logits[:, 0, medusa_buffers["retrieve_indices"]] + + best_candidate, accept_length = evaluate_posterior(logits, candidates) + cur_len = torch.tensor([input_ids.nonzero().size(0) - 1], dtype=torch.int64) + + input_ids, logits, medusa_logits, new_token, select_indices = update_inference_inputs( + input_ids[:, : (int(cur_len[0] + 1))], + candidates, + best_candidate, + accept_length, + medusa_buffers["retrieve_indices"], + outputs, + logits, + medusa_logits, + new_token, + ) + + medusa_kwargs["attention_mask"] = self._update_attention_mask( + model_inputs, accept_length, cur_len, medusa_kwargs + ) + cur_len = 1 + cur_len + accept_length_tree = input_ids.shape[1] - cur_length + cur_length = accept_length_tree + cur_length + accept_lengths_tree.append(accept_length_tree) + final_accept_length += accept_length + 1 + if eos_token_id in new_token or final_accept_length > self.config.max_new_tokens: + break + return input_ids + + def _prepare_medusa_kwargs(self, position_ids, cur_len, medusa_buffers, select_indices, medusa_kwargs): + medusa_kwargs["position_ids"] = position_ids.unsqueeze(0) + medusa_kwargs["accepted_indices"] = torch.arange( + cur_len[0].item(), cur_len[0].item() + self.config.num_medusa_heads + 1, dtype=torch.int64 + ) + for index, value in enumerate(select_indices): + medusa_kwargs["accepted_indices"][index] = value + medusa_kwargs["accepted_indices"] = medusa_kwargs["accepted_indices"].unsqueeze(0) + medusa_kwargs["current_length"] = torch.arange( + 
cur_len[0].item(), cur_len[0].item() + self.config.num_medusa_heads + 1, dtype=torch.int64 + ).unsqueeze(0) + medusa_mask = medusa_buffers["medusa_attn_mask"].unsqueeze(0) + medusa_kwargs["medusa_mask"] = medusa_mask.type_as(torch.LongTensor()) + medusa_kwargs["scatter_index"] = torch.arange( + position_ids[0], position_ids[0] + self.config.medusa_speculation_length, dtype=torch.int64 + ).unsqueeze(0) + return medusa_kwargs + + def _update_attention_mask(self, model_inputs, accept_length, cur_len, medusa_kwargs): + accept_length_concat_tensor = torch.zeros(1, accept_length + 1, dtype=model_inputs["attention_mask"].dtype) + attn_mask = torch.cat([model_inputs["attention_mask"], accept_length_concat_tensor], dim=-1) + + medusa_kwargs["attention_mask"] = attn_mask.index_fill( + 1, torch.arange(int(cur_len[0]) + 1, int(cur_len[0]) + 1 + accept_length + 1), 1 + ) + return medusa_kwargs["attention_mask"] + + def _extract_logits(self, outputs): + logits = outputs["hidden_states"][:1, :, :] + medusa_logits = outputs["hidden_states"][1:, :, :].unsqueeze(1) + return logits, medusa_logits diff --git a/test/integration/combinatorial_tests/common/compare_gpu_trn1_metrics.py b/test/integration/combinatorial_tests/common/compare_gpu_trn1_metrics.py index aedaa9b..a64f5df 100644 --- a/test/integration/combinatorial_tests/common/compare_gpu_trn1_metrics.py +++ b/test/integration/combinatorial_tests/common/compare_gpu_trn1_metrics.py @@ -1,7 +1,5 @@ import argparse -import math import os -from collections import defaultdict import tensorboard.backend.event_processing.event_accumulator as event_acc import torch @@ -21,8 +19,8 @@ def load_events(event_file): def get_gpu_values(args): if args.gpu_event_file.endswith('pt'): if len(args.tags) > 1: - raise ValueError(f"Too many tags provided to use gpu pt benchmark. Only one tag is supported for gpu pt " - f"benchmark") + raise ValueError("Too many tags provided to use gpu pt benchmark. 
Only one tag is supported for gpu pt " + "benchmark") else: tag = args.tags[0] print('GPU benchmark is a pt file') @@ -48,6 +46,10 @@ def main(): "--rtol", type=float, help="Relative tolerance to use", default=0.05 ) + parser.add_argument( + "--confidence_interval", type=float, help="Relative tolerance to use", default=0.95 + ) + args = parser.parse_args() if not os.path.exists(args.gpu_event_file): @@ -59,18 +61,19 @@ def main(): trn1_events = load_events(args.trn1_event_file) for tag in args.tags: - if tag in gpu_benchmark and tag in trn1_events: + trn1_tag = 'step loss' if (tag == 'loss' and tag not in trn1_events) else tag + if tag in gpu_benchmark and trn1_tag in trn1_events: # extract values for the relevant tag from GPU to handle different benchmark types - gpu = [val.value if type(val) is not float else val for val in gpu_benchmark[tag]] - trn = [val.value if type(val) is not float else val for val in trn1_events[tag]] + gpu = [val.value if not isinstance(val, float) else val for val in gpu_benchmark[tag]] + trn = [val.value if not isinstance(val, float) else val for val in trn1_events[trn1_tag]] - are_all_close = torch.allclose(torch.Tensor(trn), torch.Tensor(gpu), rtol=args.rtol, atol=args.atol, equal_nan=False) - if not are_all_close: - raise ValueError(f"Tolerance exceeded for values for tag '{tag}'") + are_close = torch.isclose(torch.Tensor(trn), torch.Tensor(gpu), rtol=args.rtol, atol=args.atol, equal_nan=False) + + if torch.mean(are_close.float()).item() < args.confidence_interval: + raise ValueError(f"Tolerance exceeded for values for tag '{tag}', GPU value={gpu}, trn1 value={trn}") else: raise ValueError(f"Tag '{tag}' not found in one of the event files") - if __name__ == "__main__": main() diff --git a/test/integration/combinatorial_tests/config.json b/test/integration/combinatorial_tests/config.json index c1d0a7f..3def9e7 100644 --- a/test/integration/combinatorial_tests/config.json +++ b/test/integration/combinatorial_tests/config.json @@ -26,4 +26,3 @@ "selective_checkpoint_enabled": false, "move_model_to_device":true } - diff --git a/test/integration/combinatorial_tests/configs/test_TP32_SP0_SC0_PP1_Zero1Opt0_GCP0_FP32.txt b/test/integration/combinatorial_tests/configs/test_TP32_SP0_SC0_PP1_Zero1Opt0_FP32.txt similarity index 100% rename from test/integration/combinatorial_tests/configs/test_TP32_SP0_SC0_PP1_Zero1Opt0_GCP0_FP32.txt rename to test/integration/combinatorial_tests/configs/test_TP32_SP0_SC0_PP1_Zero1Opt0_FP32.txt diff --git a/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_FP32.txt b/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_FP32.txt index 464e45d..efc2629 100644 --- a/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_FP32.txt +++ b/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_FP32.txt @@ -7,4 +7,4 @@ PP_DEGREE=4 USE_ZERO_1=1 USE_MIX_PRECISION=1 QKV_LINEAR=1 -GPU_COMPATIBLE_PRECISION=0 +GPU_COMPATIBLE_PRECISION=1 diff --git a/test/integration/combinatorial_tests/configs/test_TP32_SP0_SC0_PP1_Zero1Opt0_GCP1_FP32.txt b/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_GCP0_FP32.txt similarity index 57% rename from test/integration/combinatorial_tests/configs/test_TP32_SP0_SC0_PP1_Zero1Opt0_GCP1_FP32.txt rename to test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_GCP0_FP32.txt index 01e3a57..464e45d 100644 --- 
a/test/integration/combinatorial_tests/configs/test_TP32_SP0_SC0_PP1_Zero1Opt0_GCP1_FP32.txt +++ b/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_GCP0_FP32.txt @@ -1,11 +1,10 @@ GBS=4 SEQ_LEN=4096 -TP_DEGREE=32 -SEQUENCE_PARALLEL=0 +TP_DEGREE=8 +SEQUENCE_PARALLEL=1 SELECTIVE_CHECKPOINT=0 -PP_DEGREE=1 -USE_ZERO_1=0 +PP_DEGREE=4 +USE_ZERO_1=1 USE_MIX_PRECISION=1 -KV_REPLICATOR=4 QKV_LINEAR=1 GPU_COMPATIBLE_PRECISION=0 diff --git a/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_MetaDeviceInit0_FP32.txt b/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_MetaDeviceInit0_FP32.txt new file mode 100644 index 0000000..b52d479 --- /dev/null +++ b/test/integration/combinatorial_tests/configs/test_TP8_SP1_SC0_PP4_Zero1Opt1_MetaDeviceInit0_FP32.txt @@ -0,0 +1,11 @@ +GBS=4 +SEQ_LEN=4096 +TP_DEGREE=8 +SEQUENCE_PARALLEL=1 +SELECTIVE_CHECKPOINT=0 +PP_DEGREE=4 +USE_ZERO_1=1 +USE_MIX_PRECISION=1 +QKV_LINEAR=1 +GPU_COMPATIBLE_PRECISION=1 +USE_META_DEVICE_INIT=0 diff --git a/test/integration/combinatorial_tests/run.sh b/test/integration/combinatorial_tests/run.sh index fd6f60a..e84a3b6 100644 --- a/test/integration/combinatorial_tests/run.sh +++ b/test/integration/combinatorial_tests/run.sh @@ -17,8 +17,8 @@ echo "MBS: $MBS" echo "SEQUENCE_PARALLEL: $SEQUENCE_PARALLEL" echo "SEQ_LEN: $SEQ_LEN" echo "PIPELINE_PARALLEL: $PIPELINE_PARALLEL" # TODO - enable pp when running model -echo "USE_ZERO_1= $USE_ZERO_1" # 0: use pure DP; 1: use ZeRO-1 -echo "USE_MIX_PRECISION=$USE_MIX_PRECISION" # 0: bf16; 1: mixed precision +echo "USE_ZERO_1: $USE_ZERO_1" # 0: use pure DP; 1: use ZeRO-1 +echo "USE_MIX_PRECISION: $USE_MIX_PRECISION" # 0: bf16; 1: mixed precision echo "SELECTIVE_CHECKPOINT: $SELECTIVE_CHECKPOINT" ############################################# diff --git a/test/integration/common/integration_test_utils.py b/test/integration/common/integration_test_utils.py index d13c228..2046fe7 100644 --- a/test/integration/common/integration_test_utils.py +++ b/test/integration/common/integration_test_utils.py @@ -138,7 +138,7 @@ def flatten(input_tup): grads_on_device.append(child.weight.grad.data) if hasattr(child, "bias") and child.bias is not None: grads_on_device.append(child.bias.grad.data) - + torch.distributed.barrier() xm.mark_step() @@ -148,7 +148,7 @@ def flatten(input_tup): del(device) return (output_on_cpu, grads_on_cpu) - + # assert_close_on_output_tensor allows the user to choose if the output tensors should be compared with # torch.testing.assert_close (compares both relative diff and absolute diff, preferred) or using a simple absolute @@ -193,12 +193,12 @@ def test_modules(test_module: torch.nn.Module, control_module: torch.nn.Module, xm.rendezvous("start_test_modules") test_output, test_grads = exercise_single_module_fwd_bwd(test_module, input_tensors, mark_step_between_fwd_bwd) del(test_module) - xm.master_print(f"done exercising test module") + xm.master_print("done exercising test module") xm.rendezvous("start_testing_control_module") control_output, control_grads = exercise_single_module_fwd_bwd(control_module, input_tensors, mark_step_between_fwd_bwd) del(control_module) - xm.master_print(f"done exercising control module") + xm.master_print("done exercising control module") xm.rendezvous("check_outputs") @@ -210,7 +210,7 @@ def test_modules(test_module: torch.nn.Module, control_module: torch.nn.Module, # TODO: this is only toggle-able because some testcases fail assert_close on their output # e.g. 
V1305356298 if assert_close_on_output_tensor: - torch.testing.assert_close(test_output, control_output, atol=1e-5, rtol=0.01), f"Control and test outputs from fwd pass did not match!" + torch.testing.assert_close(test_output, control_output, atol=1e-5, rtol=0.01), "Control and test outputs from fwd pass did not match!" else: error = test_output.sub(control_output).abs().max() limit = 1e-3 @@ -230,6 +230,6 @@ def test_modules(test_module: torch.nn.Module, control_module: torch.nn.Module, except Exception as e: xm.master_print(e) grads_pass = False - + # If we got here compilation passed - return (True, output_pass, grads_pass) \ No newline at end of file + return (True, output_pass, grads_pass) diff --git a/test/integration/gpt_neox_20B/tp_zero1_gpt_neox_20b_hf_pretrain.py b/test/integration/gpt_neox_20B/tp_zero1_gpt_neox_20b_hf_pretrain.py index 0345b38..f898919 100644 --- a/test/integration/gpt_neox_20B/tp_zero1_gpt_neox_20b_hf_pretrain.py +++ b/test/integration/gpt_neox_20B/tp_zero1_gpt_neox_20b_hf_pretrain.py @@ -181,7 +181,7 @@ def get_instance_type(self): headers={"X-aws-ec2-metadata-token": token.text}, ) return data.text - except: + except Exception: return os.environ.get("HOSTNAME", "unknown") def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): @@ -189,7 +189,7 @@ def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None) grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" print( f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " - f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} seq/s " f"{grad_norm_msg}", flush=True, ) diff --git a/test/integration/inference/test_model_builder.py b/test/integration/inference/test_model_builder.py new file mode 100644 index 0000000..d5f865a --- /dev/null +++ b/test/integration/inference/test_model_builder.py @@ -0,0 +1,510 @@ +import os +import shutil +import torch +import multiprocessing +from functools import partial + +from neuronx_distributed.trace.model_builder import ModelBuilder, BaseModelInstance +from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear +from torch_neuronx import BucketModelConfig + +from typing import List + +ckpt_path = "/tmp/test_model_builder_ckpt.pt" + +class CPLOnlyModel(torch.nn.Module): + def __init__(self, + hidden_dim, + is_distributed): + super().__init__() + if is_distributed: + self.lay1 = ColumnParallelLinear(input_size=hidden_dim, output_size=hidden_dim, bias=False, gather_output=True, dtype=torch.float32) + self.lay2 = ColumnParallelLinear(input_size=hidden_dim, output_size=hidden_dim, bias=False, gather_output=True, dtype=torch.float32) + else: + self.lay1 = torch.nn.Linear(hidden_dim,hidden_dim, bias=False, dtype=torch.float32) + self.lay2 = torch.nn.Linear(hidden_dim,hidden_dim, bias=False, dtype=torch.float32) + + def forward(self, x): + rx = self.lay1(x) + ry = self.lay2(rx) + return ry + + +class CPLRPLModel(torch.nn.Module): + def __init__(self, + hidden_dim, + is_distributed): + super().__init__() + if is_distributed: + self.lay1 = ColumnParallelLinear(input_size=hidden_dim, output_size=hidden_dim, bias=False, gather_output=False, dtype=torch.float32) + self.lay2 = RowParallelLinear(input_size=hidden_dim, output_size=hidden_dim, bias=False, input_is_parallel=True, dtype=torch.float32) + else: + self.lay1 = torch.nn.Linear(hidden_dim,hidden_dim, bias=False, dtype=torch.float32) + self.lay2 = 
torch.nn.Linear(hidden_dim,hidden_dim, bias=False, dtype=torch.float32) + + def forward(self, x): + rx = self.lay1(x) + ry = self.lay2(rx) + return ry + + +class StatefulModel(torch.nn.Module): + def __init__(self, + batch_size, + hidden_dim, + is_distributed): + super().__init__() + if is_distributed: + self.lay1 = ColumnParallelLinear(input_size=hidden_dim, output_size=hidden_dim, bias=False, gather_output=False, dtype=torch.float32) + self.lay2 = RowParallelLinear(input_size=hidden_dim, output_size=hidden_dim, bias=False, input_is_parallel=True, dtype=torch.float32) + else: + self.lay1 = torch.nn.Linear(hidden_dim,hidden_dim, bias=False, dtype=torch.float32) + self.lay2 = torch.nn.Linear(hidden_dim,hidden_dim, bias=False, dtype=torch.float32) + + self.state = torch.nn.Parameter(torch.zeros(batch_size, hidden_dim), requires_grad=False) + + def forward(self, x): + rx = self.lay1(x) + ry = self.lay2(rx) + return ry + self.state, ry + + +def checkpoint_loader_fn(ckpt_path): + model_sd = torch.load(ckpt_path) + return model_sd + + +def batch_bucket_model_kernel(inputs: List[torch.Tensor]): + inp = inputs[0] + batch_size = inp.shape[0] + if (batch_size == 1 or batch_size == 4): + return inputs, torch.tensor(0) + else: + return inputs, torch.tensor(1) + + +def get_bucket_kernel(): + return torch.jit.script(batch_bucket_model_kernel) + + +def generate_simple_CPL_only_model(batch_size, hidden_dim): + model = CPLOnlyModel(hidden_dim=hidden_dim, is_distributed=False) + torch.save(model.state_dict(), ckpt_path) + + builder = ModelBuilder(router=None, + tp_degree=2, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/") + x = torch.randn((batch_size, hidden_dim)) + builder.add(key = "main", + model_instance = BaseModelInstance(partial(CPLOnlyModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={}), + example_inputs=[(x,)], + compiler_args="--auto-cast=none") + traced_model = builder.trace(initialize_model_weights=True) + + return model,traced_model + + +def test_saving_loading_model(): + _,traced_model = generate_simple_CPL_only_model(2,4) + torch.jit.save(traced_model, "test.pt") + torch.jit.load("test.pt") + os.remove("test.pt") + del traced_model + torch.classes.neuron.Runtime().unsafe_close() + + +def test_CPL_only_model(): + hidden_dim=4 + batch_size=2 + model,traced_model = generate_simple_CPL_only_model(batch_size,hidden_dim) + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + +def test_executing_loaded_model(): + hidden_dim=4 + batch_size=2 + + model,traced_model = generate_simple_CPL_only_model(batch_size,hidden_dim) + + torch.jit.save(traced_model, "test.pt") + del traced_model + + loaded_traced_model = torch.jit.load("test.pt") + loaded_traced_model.nxd_model.initialize_with_saved_weights() + + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = loaded_traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + os.remove("test.pt") + + +def test_CPL_RPL_model(): + + hidden_dim=4 + batch_size=2 + + model = CPLRPLModel(hidden_dim=hidden_dim, is_distributed=False) + torch.save(model.state_dict(), ckpt_path) + + builder = ModelBuilder(router=None, + tp_degree=2, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/") + x = torch.randn((batch_size, 
hidden_dim)) + builder.add(key = "main", + model_instance = BaseModelInstance(partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={}), + example_inputs=[(x,)], + compiler_args="--auto-cast=none") + traced_model = builder.trace(initialize_model_weights=True) + + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + +def test_multiple_input_shapes(): + + hidden_dim=4 + batch_size=2 + + model = CPLRPLModel(hidden_dim=hidden_dim, is_distributed=False) + torch.save(model.state_dict(), ckpt_path) + + builder = ModelBuilder(router=None, + tp_degree=2, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/") + x = torch.randn((batch_size, hidden_dim)) + builder.add(key = "ctx", + model_instance = BaseModelInstance(partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={}), + example_inputs=[(x,)], + compiler_args="--auto-cast=none") + y = torch.randn((batch_size+1, hidden_dim)) + builder.add(key = "tkg", + model_instance = BaseModelInstance(partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={}), + example_inputs=[(y,)], + compiler_args="--auto-cast=none") + traced_model = builder.trace(initialize_model_weights=True) + + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + x = torch.randn((batch_size+1, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + +def test_weight_layout_optimization(): + # Currently compiler outputs hlo stub in the current working dir, and it + # needs to be clean before the execution, so moving it to a new dir + original_dir = os.getcwd() + new_dir = os.path.join(original_dir, "wlt") + print(f"current dir is {original_dir}, will move working dir to {new_dir}") + if os.path.exists(new_dir): + shutil.rmtree(new_dir) + os.makedirs(new_dir) + os.chdir(new_dir) + assert os.getcwd() == new_dir + + hidden_dim = 4 + batch_size = 2 + + model = CPLRPLModel(hidden_dim=hidden_dim, is_distributed=False) + torch.save(model.state_dict(), ckpt_path) + + builder = ModelBuilder( + router=None, + tp_degree=2, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/", + ) + x = torch.randn((batch_size, hidden_dim)) + builder.add( + key="ctx", + model_instance=BaseModelInstance( + partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={} + ), + example_inputs=[(x,)], + ) + y = torch.randn((batch_size + 1, hidden_dim)) + builder.add( + key="tkg", + model_instance=BaseModelInstance( + partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={} + ), + example_inputs=[(y,)], + priority_model_idx=0, + ) + traced_model = builder.trace(initialize_model_weights=True) + + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + x = torch.randn((batch_size + 1, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + os.chdir(original_dir) + + +def 
test_weight_layout_optimization_with_serialization(): + # Currently compiler outputs hlo stub in the current working dir, and it + # needs to be clean before the execution, so moving it to a new dir + original_dir = os.getcwd() + new_dir = os.path.join(original_dir, "wlt") + print(f"current dir is {original_dir}, will move working dir to {new_dir}") + if os.path.exists(new_dir): + shutil.rmtree(new_dir) + os.makedirs(new_dir) + os.chdir(new_dir) + assert os.getcwd() == new_dir + + hidden_dim = 4 + batch_size = 2 + tp_degree = 2 + + model = CPLRPLModel(hidden_dim=hidden_dim, is_distributed=False) + torch.save(model.state_dict(), ckpt_path) + + builder = ModelBuilder( + router=None, + tp_degree=tp_degree, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/", + ) + x = torch.randn((batch_size, hidden_dim)) + builder.add( + key="ctx", + model_instance=BaseModelInstance( + partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={} + ), + example_inputs=[(x,)], + ) + y = torch.randn((batch_size + 1, hidden_dim)) + builder.add( + key="tkg", + model_instance=BaseModelInstance( + partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={} + ), + example_inputs=[(y,)], + priority_model_idx=0, + ) + traced_model = builder.trace(initialize_model_weights=False) + + # Save the traced model + torch.jit.save(traced_model, "traced_model.pt") + del traced_model + + # Shard weights from checkpoint + shard_weights_path = "weights/" + builder.shard_checkpoint(serialize_path=shard_weights_path) + weights = [] + for rank in range(tp_degree): + ckpt = torch.load(os.path.join(shard_weights_path, f"tp{rank}_sharded_checkpoint.pt")) + weights.append(ckpt) + + # Load the traced model + traced_model = torch.jit.load("traced_model.pt") + print("Done loading serialized model") + + # Load new weights + traced_model.nxd_model.initialize(weights) + + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + x = torch.randn((batch_size + 1, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + os.chdir(original_dir) + + +class StatefulModelInstance(BaseModelInstance): + def __init__(self): + self.module = None + self.input_output_aliases = None + + def load_module(self): + self.module = StatefulModel(batch_size=2, hidden_dim=4, is_distributed=True) + self.input_output_aliases = {self.module.state: 1} + + def get(self, bucket_rank, **kwargs): + return self.module, self.input_output_aliases + + +def test_stateful_model(): + + hidden_dim=4 + batch_size=2 + + model = StatefulModel(batch_size=batch_size, hidden_dim=hidden_dim, is_distributed=False) + sd = model.state_dict() + torch.save(sd, ckpt_path) + + builder = ModelBuilder(router=None, + tp_degree=2, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/") + x = torch.randn((batch_size, hidden_dim)) + builder.add(key = "main", + model_instance = StatefulModelInstance(), + example_inputs=[(x,)], + compiler_args="--auto-cast=none") + traced_model = builder.trace(initialize_model_weights=True) + + model = StatefulModel(batch_size=batch_size, hidden_dim=hidden_dim, is_distributed=False) + model.load_state_dict(sd) + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result, 
new_state = model(x) + model.state.data = new_state + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + +def test_batch_bucketed_model(): + hidden_dim=4 + batch_sizes_ctx=[1,2] + batch_sizes_tkg = [4,8] + + model = CPLRPLModel(hidden_dim=hidden_dim, is_distributed=False) + torch.save(model.state_dict(), ckpt_path) + + builder = ModelBuilder(router=None, + tp_degree=2, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/") + inps = [(torch.randn((batch_size, hidden_dim)),) for batch_size in batch_sizes_ctx] + bucket_config = BucketModelConfig( + get_bucket_kernel + ) + + builder.add(key = "ctx", + model_instance = BaseModelInstance(partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={}), + example_inputs=inps, + bucket_config=bucket_config, + compiler_args="--auto-cast=none") + inps = [(torch.randn((batch_size, hidden_dim)),) for batch_size in batch_sizes_tkg] + builder.add(key = "tkg", + model_instance = BaseModelInstance(partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={}), + example_inputs=inps, + bucket_config=bucket_config, + compiler_args="--auto-cast=none") + traced_model = builder.trace(initialize_model_weights=True) + + # Test multiple invocations + for _ in range(5): + for batch_size in batch_sizes_ctx: + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + for batch_size in batch_sizes_tkg: + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + +def test_loading_checkpoint(): + + hidden_dim=4 + batch_size=2 + tp_degree=2 + + model = CPLRPLModel(hidden_dim=hidden_dim, is_distributed=False) + torch.save(model.state_dict(), ckpt_path) + + builder = ModelBuilder(router=None, + tp_degree=tp_degree, + checkpoint_loader=partial(torch.load, ckpt_path), + compiler_workdir="new_compiler_workdir/") + x = torch.randn((batch_size, hidden_dim)) + builder.add(key = "main", + model_instance = BaseModelInstance(partial(CPLRPLModel, hidden_dim=hidden_dim, is_distributed=True), input_output_aliases={}), + example_inputs=[(x,)], + compiler_args="--auto-cast=none") + + traced_model = builder.trace(initialize_model_weights=False) # stops weight sharding + + # Save the traced model + torch.jit.save(traced_model, "traced_model.pt") + del traced_model + + # Shard weights from checkpoint + shard_weights_path = "weights/" + builder.shard_checkpoint(serialize_path=shard_weights_path) + weights = [] + for rank in range(tp_degree): + ckpt = torch.load(os.path.join(shard_weights_path, f"tp{rank}_sharded_checkpoint.pt")) + weights.append(ckpt) + + # Load the traced model + traced_model = torch.jit.load("traced_model.pt") + + # Load new weights + traced_model.nxd_model.initialize(weights) + + # Test multiple invocations + for _ in range(5): + x = torch.randn((batch_size, hidden_dim)) + cpu_result = model(x) + nxd_result = traced_model(x) + torch.testing.assert_close(cpu_result, nxd_result) + + +if __name__ == "__main__": + test_list = [ + test_saving_loading_model, + test_CPL_only_model, + test_executing_loaded_model, + test_CPL_RPL_model, + test_multiple_input_shapes, + test_weight_layout_optimization, + test_weight_layout_optimization_with_serialization, + test_stateful_model, + test_batch_bucketed_model, + test_loading_checkpoint, + ] + # Run tests in a separate 
process so it can init and release runtime properly + for test in test_list: + print(f"Starting test: {test.__name__}") + p = multiprocessing.Process(target=test) + p.start() + p.join() + if p.exitcode == 0: + print(f"Test succeeded: {test.__name__}\n") + else: + raise Exception(f"Test failed: {test.__name__}\n") + print(f"All {len(test_list)} tests on ModelBuilder succeeded!") + diff --git a/test/integration/llama2_70B_4layers_PP/module_llama.py b/test/integration/llama2_70B_4layers_PP/module_llama.py index eff6c9e..651ab9c 100644 --- a/test/integration/llama2_70B_4layers_PP/module_llama.py +++ b/test/integration/llama2_70B_4layers_PP/module_llama.py @@ -30,7 +30,7 @@ def on_train_batch_end(self, *args, **kwargs): and self.trainer.strategy.pipeline_parallel_rank == self.trainer.strategy.pipeline_parallel_size - 1 ): print( - f"step {self.global_step} loss is {self.loss.detach().cpu().item()}, lr is {self.lr}, input_ids {torch.sum(self.input_ids.detach().cpu()).item()}" + f"step {self.global_step} loss is {self.loss.detach().cpu().item()}, lr is {self.lr}, throughput {self.tps} seq/s, input_ids {torch.sum(self.input_ids.detach().cpu()).item()}" ) step_now = self.global_step - 1 diff --git a/test/integration/llama2_7B/logger.py b/test/integration/llama2_7B/logger.py index e922805..6ae859f 100644 --- a/test/integration/llama2_7B/logger.py +++ b/test/integration/llama2_7B/logger.py @@ -66,7 +66,7 @@ def get_instance_type(self): headers={"X-aws-ec2-metadata-token": token.text}, ) return data.text - except: + except Exception: return os.environ.get("HOSTNAME", "unknown") def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): @@ -74,7 +74,7 @@ def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None) grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" print( f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " - f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} seq/s " f"{grad_norm_msg}", flush=True, ) diff --git a/test/integration/llama2_7B/module_llama.py b/test/integration/llama2_7B/module_llama.py index aedbf45..a73d5f5 100644 --- a/test/integration/llama2_7B/module_llama.py +++ b/test/integration/llama2_7B/module_llama.py @@ -3,16 +3,8 @@ import numpy as np import torch from module_llama_orig import NeuronLlamaLTModule as NeuronLlamaLTModuleOrigin -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import ( - _FxValidator, -) from tensorboard.backend.event_processing.event_accumulator import EventAccumulator -from neuronx_distributed.trainer import ( - initialize_parallel_model, - initialize_parallel_optimizer, -) - def load_events(event_file): accumulator = EventAccumulator(event_file) @@ -53,7 +45,7 @@ def on_train_batch_end(self, *args, **kwargs): and self.trainer.strategy.pipeline_parallel_rank == self.trainer.strategy.pipeline_parallel_size - 1 ): print( - f"step {self.global_step} loss is {self.loss.detach().cpu().item()}, lr is {self.lr}, input_ids {torch.sum(self.input_ids.detach().cpu()).item()}" + f"step {self.global_step} loss is {self.loss.detach().cpu().item()}, lr is {self.lr}, throughput {self.tps} seq/s, input_ids {torch.sum(self.input_ids.detach().cpu()).item()}" ) step_now = self.global_step - 1 diff --git a/test/integration/llama2_7B/run_llama_7b_tp_ptl.sh b/test/integration/llama2_7B/run_llama_7b_tp_ptl.sh index 9663f29..38f8201 100644 --- 
a/test/integration/llama2_7B/run_llama_7b_tp_ptl.sh +++ b/test/integration/llama2_7B/run_llama_7b_tp_ptl.sh @@ -3,7 +3,7 @@ ############################################# # User defined parameters and env vars -if [ -z "$SEQ_LEN" ]; +if [ -z "$SEQ_LEN" ] || [ "$SEQ_LEN" -eq 4096 ]; then DATA_PATH="$HOME/wikicorpus_datasets/wikicorpus_llama_v2_tokenized_4k" SEQ_LEN=4096 @@ -46,7 +46,7 @@ fi SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -export NEURON_CC_FLAGS="--model-type transformer --distribution-strategy=llm-training --cache_dir=$HOME/neuron_compile_cache$SEQ_LEN/ --retry_failed_compilation" +export NEURON_CC_FLAGS="--model-type transformer --distribution-strategy=llm-training --cache_dir=$HOME/neuron_compile_cache$SEQ_LEN/ --retry_failed_compilation --enable-saturate-infinity" export NEURON_FUSE_SOFTMAX=1 # Async Runtime @@ -150,7 +150,7 @@ torchrun $DISTRIBUTED_ARGS \ --data_dir $DATA_PATH \ --tensor_parallel_size $TP_DEGREE \ --train_batch_size $MBS \ - --steps_this_run $STEPS_THIS_RUN\ + --steps_this_run $STEPS_THIS_RUN \ --max_steps $TOTAL_STEPS \ --warmup_steps $WARMUP_STEPS \ --lr $LR \ diff --git a/test/integration/llama2_7B/test_long_seqlen.py b/test/integration/llama2_7B/test_long_seqlen.py index 549bfdb..5b0c495 100644 --- a/test/integration/llama2_7B/test_long_seqlen.py +++ b/test/integration/llama2_7B/test_long_seqlen.py @@ -1,12 +1,8 @@ -import argparse -import atexit -import json import os import re import signal import subprocess import sys -import traceback SUCCEEDED = "succeeded" ERRORS = "errors" @@ -14,29 +10,11 @@ PERFORMANCE_DEGADATION = "performance degradation" -def parse_args(): - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) - parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") - parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_parallel_layers/layers") - args, leftovers = parser.parse_known_args() - S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args - - -test_config, S3_BUCKET_NAME, args = parse_args() -results = {"inference_success": 1} - - def run_job(seq_len=32768, mem_threshold=0, throughputs_threshold=99999): p1 = subprocess.run( - [f"{seq_len}", "neuron_parallel_compile", "./run_llama_7b_tp_ptl.sh", f"{seq_len}"], + [f"export SEQ_LEN={seq_len}; neuron_parallel_compile ./run_llama_7b_tp_ptl.sh"], + shell=True, + text=True, stderr=sys.stderr, stdout=sys.stdout, ) @@ -53,7 +31,9 @@ def run_job(seq_len=32768, mem_threshold=0, throughputs_threshold=99999): ) p2 = subprocess.run( - [f"{seq_len}", "./run_llama_7b_tp_ptl.sh", f"{seq_len}"], + [f"export SEQ_LEN={seq_len}; ./run_llama_7b_tp_ptl.sh"], + shell=True, + text=True, stderr=sys.stderr, stdout=sys.stdout, ) @@ -99,21 +79,14 @@ def extract_throughput(seq_len=32768): return throughputs -def on_exit(): - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) - - if __name__ == "__main__": succeeded = [] failed = [] for seq_len, mem_thershold, perf_threshold in [ # Threshold with 5%-8% tolarance - [8192, 88590512128, 9.15], - [16384, 109604828160, 3.10], - [32768, 124354230272, 1.20], + [8192, 88590512128, 6.60], + [16384, 109604828160, 2.60], + [32768, 124354230272, 1.00], ]: return_status = run_job(seq_len, 
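        # run_job (defined above) exports SEQ_LEN for run_llama_7b_tp_ptl.sh, launches it
        # first under neuron_parallel_compile and then directly for the measured run, and
        # (per its signature) evaluates that run against the memory and throughput
        # thresholds supplied here.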
mem_thershold, perf_threshold) if return_status == SUCCEEDED: @@ -121,11 +94,5 @@ def on_exit(): else: failed.append(seq_len) print(f"Succeeded: seq len {succeeded}, Failed: seq len {failed}") - try: - assert not failed, "Job failed" - except: - results["inference_success"] = 0 - print(traceback.format_exc()) - raise - print(f"Tests finished successfully!") - atexit.register(on_exit) + assert not failed, "Job failed" + print("Tests finished successfully!") diff --git a/test/integration/llama3_70B_4layers_PP/run_llama_test.sh b/test/integration/llama3_70B_4layers_PP/run_llama_test.sh index de347b4..a0eeeaa 100755 --- a/test/integration/llama3_70B_4layers_PP/run_llama_test.sh +++ b/test/integration/llama3_70B_4layers_PP/run_llama_test.sh @@ -134,4 +134,3 @@ else fi fi fi - diff --git a/test/integration/llama3_8B/logger.py b/test/integration/llama3_8B/logger.py new file mode 100644 index 0000000..e922805 --- /dev/null +++ b/test/integration/llama3_8B/logger.py @@ -0,0 +1,92 @@ +import inspect +import os +import sys +import time + +import numpy as np +import requests +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from torch.utils.tensorboard import SummaryWriter + + +def load_events(event_file): + accumulator = EventAccumulator(event_file) + accumulator.Reload() + tags = accumulator.Tags() + + data = {} + for tag in tags["scalars"]: + data[tag] = accumulator.Scalars(tag) + return data + + +class Logger: + def __init__(self, args, world_size, model_dtype): + xla = "torch_xla" in sys.modules + self.throughputs = [] + dtype_short = model_dtype.replace("torch.", "") + self.tb = SummaryWriter( + os.path.join( + args.output_dir, + f"neuron_tblogs_{time.strftime('%m%d%y_%H%M')}" + f"_{dtype_short}" + f"_w{world_size}" + f"_lr{args.lr}" + f"_bs{args.batch_size}" + f"_acc{args.grad_accum_usteps}" + f"_warmup{args.warmup_steps}" + f"_max{args.max_steps}" + f"_xla{xla}" + f"_{self.get_instance_type()}", + ) + ) + self.tb.add_text("script", "```\n" + inspect.getsource(sys.modules[__name__]) + "\n```", 0) + + self.golden_steploss = [] + event_file = os.getenv("GOLDEN_EVENT_FILE") if not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None) else None + if event_file is not None: + data = load_events(event_file) + for step in data["step loss"]: + self.golden_steploss.append(step.value) + else: + golden = "golden_steploss.txt" + if os.path.exists(golden): + with open(golden, "r") as f: + self.golden_steploss = [float(i) for i in f] + print(f"Read {len(self.golden_steploss)} golden step loss values from {golden}") + + def get_instance_type(self): + try: + token = requests.put( + "http://169.254.169.254/latest/api/token", + headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, + ) + data = requests.get( + "http://169.254.169.254/latest/meta-data/instance-type", + headers={"X-aws-ec2-metadata-token": token.text}, + ) + return data.text + except: + return os.environ.get("HOSTNAME", "unknown") + + def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): + time_now = time.asctime() + grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" + print( + f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " + f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"{grad_norm_msg}", + flush=True, + ) + self.tb.add_scalar("step loss", step_loss, step) + self.tb.add_scalar("learning rate", learning_rate, step) + self.tb.add_scalar("throughput", throughput, step) + if grad_norm: + self.tb.add_scalar("grad-norm", 
grad_norm, step) + self.throughputs.append(throughput) + + # Comparing Loss to Golden + if not os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None): + step_0start = step - 1 + if step_0start < len(self.golden_steploss) and step_0start >= 0: + np.testing.assert_allclose(step_loss, self.golden_steploss[step_0start], rtol=1.5e-1) diff --git a/test/integration/llama3_8B/tp_zero1_llama3_8B_hf_pretrain.sh b/test/integration/llama3_8B/tp_zero1_llama3_8B_hf_pretrain.sh new file mode 100644 index 0000000..006d5e6 --- /dev/null +++ b/test/integration/llama3_8B/tp_zero1_llama3_8B_hf_pretrain.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +############################################# +# User defined parameters and env vars + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +export NEURON_CC_FLAGS="--model-type transformer --distribution-strategy=llm-training" +export NEURON_FUSE_SOFTMAX=1 + +# Async Runtime +export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 + +# HOST OOM +export MALLOC_ARENA_MAX=64 + +MODEL_SIZE="8B" +LLAMA_VERSION='3' + +# TP degree +TP_DEGREE=8 +# 0: bf16; 1: mixed precision +USE_MIX_PRECISION=1 +# 0: use pure DP; 1: use ZeRO-1 +USE_ZERO_1=1 +# global batch size +: ${GBS:=256} +# micro batch size +MBS=1 +# number of steps to run +TOTAL_STEPS=10000 +# warmup steps +WARMUP_STEPS=100 +# learning rate +LR=1.5e-4 +# model path +MODEL_PATH=$SCRIPT_DIR/${MODEL_SIZE}_config_llama${LLAMA_VERSION} +# data path +DATA_PATH="$HOME/examples_datasets/wikicorpus_llama3_tokenized_4k" +# sequence length +SEQ_LEN=4096 + +############################################# + +export NUM_NEURONCORES=32 +NODE_ID=0 +WORLD_SIZE=1 +DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES" +if [ ! -z "$SLURM_NTASKS" ]; then + WORLD_SIZE=$SLURM_NTASKS + NODE_ID=$SLURM_NODEID + MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`) + DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000" + if [ $NODE_ID -eq 0 ]; then + echo "WORLD_SIZE=$WORLD_SIZE" + echo "NODE_ID=$NODE_ID" + echo "MASTER_ADDRESS=$MASTER_ADDRESS" + echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS" + fi + export FI_EFA_USE_DEVICE_RDMA=1 + export FI_PROVIDER=efa +fi + +echo "WORLD_SIZE=$WORLD_SIZE" +echo "NODE_ID=$NODE_ID" +echo "MASTER_ADDRESS=$MASTER_ADDRESS" + +sudo sysctl -w net.ipv4.ip_local_reserved_ports=44000,48620 + +export NEURON_RT_NUM_CORES=32 +export NUM_NEURONCORES=$NEURON_RT_NUM_CORES +export TPU_NUM_DEVICES=$NEURON_RT_NUM_CORES +export TPU_CHIPS_PER_HOST_BOUNDS=$NEURON_RT_NUM_CORES +export NEURON_RT_ROOT_COMM_ID=localhost:48620 + +############################################# + +EXTRA_ARGS=" " +if [ $USE_MIX_PRECISION -gt 0 ]; then + EXTRA_ARGS+=" --use_mix_precision" +fi +if [ $USE_ZERO_1 -gt 0 ]; then + EXTRA_ARGS+=" --use_zero_1" +fi + +DP=$(($NEURON_RT_NUM_CORES * $WORLD_SIZE / $TP_DEGREE)) +ACC_STEPS=$(($GBS / $MBS / $DP)) + + +if [ $NEURON_EXTRACT_GRAPHS_ONLY -gt 0 ]; then + STEPS_THIS_RUN=2 + OUTPUT_LOG=log_compile-$NODE_ID.log +else + STEPS_THIS_RUN=150 + OUTPUT_LOG=log_exe-$NODE_ID.log +fi + +echo TP_DEGREE=$TP_DEGREE +echo USE_MIX_PRECISION=$USE_MIX_PRECISION +echo USE_ZERO_1=$USE_ZERO_1 +echo GBS=$GBS +echo MBS=$MBS +echo TOTAL_STEPS=$TOTAL_STEPS +echo WARMUP_STEPS=$WARMUP_STEPS +echo LR=$LR +echo MODEL_PATH=$MODEL_PATH +echo DATA_PATH=$DATA_PATH +echo SEQ_LEN=$SEQ_LEN + +echo EXTRA_ARGS=$EXTRA_ARGS +echo DP=$DP +echo ACC_STEPS=$ACC_STEPS +echo STEPS_THIS_RUN=$STEPS_THIS_RUN +echo OUTPUT_LOG=$OUTPUT_LOG + +torchrun 
$DISTRIBUTED_ARGS \ + tp_zero1_llama_hf_pretrain.py \ + --model_path $MODEL_PATH \ + --data_dir $DATA_PATH \ + --tensor_parallel_size $TP_DEGREE \ + --batch_size $MBS \ + --steps_this_run $STEPS_THIS_RUN\ + --max_steps $TOTAL_STEPS \ + --warmup_steps $WARMUP_STEPS \ + --lr $LR \ + --grad_accum_usteps $ACC_STEPS \ + --seq_len $SEQ_LEN \ + --sequence_parallel_enabled \ + --selective_checkpoint_enabled \ + --logging_interval 10 \ + --qkv_linear \ + $EXTRA_ARGS |& tee $OUTPUT_LOG +exit ${PIPESTATUS[0]} + +ret_val=${PIPESTATUS[0]} +echo ret_val=$ret_val + +if [ -v PERF_TEST ]; +then + echo "Performance test complete" +else + if [ $ret_val -eq 0 ]; then + success=1 + else + success=0 + fi + + if [ -z "$NEURON_EXTRACT_GRAPHS_ONLY" ]; then + echo "success=$success" + echo "update json with $HOME/ktest/dump_to_s3_update_test_json.sh" + dump_to_s3_update_json_scr=$HOME/ktest/dump_to_s3_update_test_json.sh + if [ -e $dump_to_s3_update_json_scr ]; then + $dump_to_s3_update_json_scr $@ --key=inference_success --value=$success || echo "Unable to update test result JSON." + else + echo "WARNING: Script $dump_to_s3_update_json_scr not found. Not updating test result JSON." + fi + fi +fi \ No newline at end of file diff --git a/test/integration/modules/lora/test_llama2_7b_lora_finetune.sh b/test/integration/modules/lora/test_llama2_7b_lora_finetune.sh new file mode 100644 index 0000000..09746e3 --- /dev/null +++ b/test/integration/modules/lora/test_llama2_7b_lora_finetune.sh @@ -0,0 +1,164 @@ +#!/bin/bash + +############################################# +# Override transformers and Optimum-Neuron packages, can be removed once ON released changes in https://github.com/huggingface/optimum-neuron/pull/370 +pip install git+https://github.com/huggingface/optimum-neuron.git +pip install --no-warn-conflicts transformers==4.32.1 nltk + +############################################# +# User defined parameters and env vars + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training " +export NEURON_FUSE_SOFTMAX=1 + +# Async Runtime +export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 + +# HOST OOM +export MALLOC_ARENA_MAX=64 + +# TP degree +TP_DEGREE=8 +# 0: bf16; 1: mixed precision +USE_MIX_PRECISION=1 +# 0: use pure DP; 1: use ZeRO-1 +USE_ZERO_1=0 +# global batch size +GBS=8 +# micro batch size +MBS=1 +# number of steps to run +TOTAL_STEPS=1000 +# number of epochs to run +TOTAL_EPOCHS=2 +# warmup steps +WARMUP_STEPS=5 +# learning rate +LR=5.0e-4 +# model path +MODEL_PATH=$SCRIPT_DIR/finetune_config +# pretrained weight path +PRETRAINED_PATH="$HOME/llama-2-7b-sharded" +# base model name +BASE_MODEL="NousResearch/Llama-2-7b-hf" +# sequence length +SEQ_LEN=4096 +# golden rouge score path +GOLDEN_ROUGE_SCORE_PATH="llama2_7b_rouge_score_goldens_lora.json" + +############################################# + +export NUM_NEURONCORES=${TP_DEGREE} +NODE_ID=0 +WORLD_SIZE=1 +DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES" +if [ ! 
-z "$SLURM_NTASKS" ]; then + WORLD_SIZE=$SLURM_NTASKS + NODE_ID=$SLURM_NODEID + MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`) + DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000" + if [ $NODE_ID -eq 0 ]; then + echo "WORLD_SIZE=$WORLD_SIZE" + echo "NODE_ID=$NODE_ID" + echo "MASTER_ADDRESS=$MASTER_ADDRESS" + echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS" + fi + export FI_EFA_USE_DEVICE_RDMA=1 + export FI_PROVIDER=efa +fi + +echo "WORLD_SIZE=$WORLD_SIZE" +echo "NODE_ID=$NODE_ID" +echo "MASTER_ADDRESS=$MASTER_ADDRESS" + +sudo sysctl -w net.ipv4.ip_local_reserved_ports=44000,48620 + +############################################# + +EXTRA_ARGS=" " +if [ $USE_MIX_PRECISION -gt 0 ]; then + EXTRA_ARGS+=" --use_mix_precision" +fi +if [ $USE_ZERO_1 -gt 0 ]; then + EXTRA_ARGS+=" --use_zero_1" +fi + +ACC_STEPS=$(($GBS / $MBS)) + + +if [ $NEURON_EXTRACT_GRAPHS_ONLY -gt 0 ]; then + STEPS_THIS_RUN=10 + OUTPUT_LOG=log_compile-$NODE_ID.log +else + STEPS_THIS_RUN=1000 + OUTPUT_LOG=log_exe-$NODE_ID.log +fi + +echo TP_DEGREE=$TP_DEGREE +echo USE_MIX_PRECISION=$USE_MIX_PRECISION +echo USE_ZERO_1=$USE_ZERO_1 +echo GBS=$GBS +echo MBS=$MBS +echo TOTAL_STEPS=$TOTAL_STEPS +echo TOTAL_EPOCHS=$TOTAL_EPOCHS +echo WARMUP_STEPS=$WARMUP_STEPS +echo LR=$LR +echo MODEL_PATH=$MODEL_PATH +echo SEQ_LEN=$SEQ_LEN + +echo EXTRA_ARGS=$EXTRA_ARGS +echo DP=$DP +echo ACC_STEPS=$ACC_STEPS +echo STEPS_THIS_RUN=$STEPS_THIS_RUN +echo OUTPUT_LOG=$OUTPUT_LOG + +torchrun $DISTRIBUTED_ARGS \ + tp_llama_hf_finetune_ptl.py \ + --model_path $MODEL_PATH \ + --model_name $BASE_MODEL \ + --data_dir "databricks/databricks-dolly-15k" \ + --tensor_parallel_size $TP_DEGREE \ + --batch_size $MBS \ + --steps_this_run $STEPS_THIS_RUN \ + --max_steps $TOTAL_STEPS \ + --num_train_epochs $TOTAL_EPOCHS \ + --warmup_steps $WARMUP_STEPS \ + --lr $LR \ + --grad_accum_usteps $ACC_STEPS \ + --seq_len $SEQ_LEN \ + --selective_checkpoint_enabled \ + --separate_qkv \ + --golden_rouge_score_path $GOLDEN_ROUGE_SCORE_PATH \ + --pretrained_ckpt $PRETRAINED_PATH \ + --task "open_qa" \ + $EXTRA_ARGS \ + --enable_lora \ + --use_gpu_compatible_precision 0 \ + + +ret_val=${PIPESTATUS[0]} +echo ret_val=$ret_val + +if [ -v PERF_TEST ]; +then + echo "Performance test complete" +else + if [ $ret_val -eq 0 ]; then + success=1 + else + success=0 + fi + + if [ -z "$NEURON_EXTRACT_GRAPHS_ONLY" ]; then + echo "success=$success" + echo "update json with $HOME/ktest/dump_to_s3_update_test_json.sh" + dump_to_s3_update_json_scr=$HOME/ktest/dump_to_s3_update_test_json.sh + if [ -e $dump_to_s3_update_json_scr ]; then + $dump_to_s3_update_json_scr $@ --key=inference_success --value=$success || echo "Unable to update test result JSON." + else + echo "WARNING: Script $dump_to_s3_update_json_scr not found. Not updating test result JSON." 
+ fi + fi +fi diff --git a/test/integration/modules/lora/test_llama_lora_finetune.sh b/test/integration/modules/lora/test_llama_lora_finetune.sh new file mode 100644 index 0000000..9705e31 --- /dev/null +++ b/test/integration/modules/lora/test_llama_lora_finetune.sh @@ -0,0 +1,154 @@ +#!/bin/bash + +############################################# +# Override transformers and Optimum-Neuron packages, can be removed once ON released changes in https://github.com/huggingface/optimum-neuron/pull/370 +pip install git+https://github.com/huggingface/optimum-neuron.git +pip install --no-warn-conflicts transformers==4.32.1 nltk + +############################################# +# User defined parameters and env vars + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training -O1 " +export NEURON_FUSE_SOFTMAX=1 + +# Async Runtime +export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3 + +# HOST OOM +export MALLOC_ARENA_MAX=64 + +# TP degree +TP_DEGREE=32 +# PP degree +PP_DEGREE=1 +# 0: bf16; 1: mixed precision +USE_MIX_PRECISION=0 +# 0: use pure DP; 1: use ZeRO-1 +USE_ZERO_1=0 +# global batch size +GBS=1 +# micro batch size +MBS=1 +# number of steps to run +TOTAL_STEPS=20 +# number of epochs to run +TOTAL_EPOCHS=1 +# warmup steps +WARMUP_STEPS=5 +# learning rate +LR=5.0e-4 +# model path +MODEL_PATH=$SCRIPT_DIR +# pretrained weight path +PRETRAINED_PATH=/dev/shm/llama3_model +# base model name +BASE_MODEL='meta-llama/Meta-Llama-3-8B' +# HF Token +HF_TOKEN='' +# sequence length +SEQ_LEN=4096 + +############################################# +PROCESSES_PER_NODE=32 +export NUM_NEURONCORES=${PROCESSES_PER_NODE} +NODE_ID=0 +WORLD_SIZE=$TP_DEGREE +DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES" +if [ ! 
-z "$SLURM_NTASKS" ]; then + WORLD_SIZE=$SLURM_NTASKS + NODE_ID=$SLURM_NODEID + MASTER_ADDRESS=(`scontrol show hostnames $SLURM_JOB_NODELIST`) + DISTRIBUTED_ARGS="--nproc_per_node $NUM_NEURONCORES --nnodes $WORLD_SIZE --node_rank $NODE_ID --master_addr $MASTER_ADDRESS --master_port 44000" + if [ $NODE_ID -eq 0 ]; then + echo "WORLD_SIZE=$WORLD_SIZE" + echo "NODE_ID=$NODE_ID" + echo "MASTER_ADDRESS=$MASTER_ADDRESS" + echo "DISTRIBUTED_ARGS=$DISTRIBUTED_ARGS" + fi + export FI_EFA_USE_DEVICE_RDMA=1 + export FI_PROVIDER=efa +fi + +echo "WORLD_SIZE=$WORLD_SIZE" +echo "NODE_ID=$NODE_ID" +echo "MASTER_ADDRESS=$MASTER_ADDRESS" + +sudo sysctl -w net.ipv4.ip_local_reserved_ports=44000,48620 + +export NEURON_RT_NUM_CORES=${PROCESSES_PER_NODE} +export NUM_NEURONCORES=$NEURON_RT_NUM_CORES +export TPU_NUM_DEVICES=$NEURON_RT_NUM_CORES +export TPU_CHIPS_PER_HOST_BOUNDS=$NEURON_RT_NUM_CORES +export NEURON_RT_ROOT_COMM_ID=localhost:48620 + +############################################# + +EXTRA_ARGS=" " +if [ $USE_MIX_PRECISION -gt 0 ]; then + EXTRA_ARGS+=" --use_mix_precision" +fi +if [ $USE_ZERO_1 -gt 0 ]; then + EXTRA_ARGS+=" --use_zero_1" +fi + +DP=$(($NEURON_RT_NUM_CORES / $TP_DEGREE / $PP_DEGREE)) +ACC_STEPS=$(($GBS / $MBS / $DP)) + +if [ $NEURON_EXTRACT_GRAPHS_ONLY -gt 0 ]; then + STEPS_THIS_RUN=10 + OUTPUT_LOG=log_compile-$NODE_ID.log +else + STEPS_THIS_RUN=$TOTAL_STEPS + OUTPUT_LOG=log_exe-$NODE_ID.log +fi + +echo TP_DEGREE=$TP_DEGREE +echo USE_MIX_PRECISION=$USE_MIX_PRECISION +echo USE_ZERO_1=$USE_ZERO_1 +echo GBS=$GBS +echo MBS=$MBS +echo TOTAL_STEPS=$TOTAL_STEPS +echo TOTAL_EPOCHS=$TOTAL_EPOCHS +echo WARMUP_STEPS=$WARMUP_STEPS +echo LR=$LR +echo MODEL_PATH=$MODEL_PATH +echo SEQ_LEN=$SEQ_LEN + +echo EXTRA_ARGS=$EXTRA_ARGS +echo DP=$DP +echo ACC_STEPS=$ACC_STEPS +echo STEPS_THIS_RUN=$STEPS_THIS_RUN +echo OUTPUT_LOG=$OUTPUT_LOG + +export XLA_USE_BF16=1 + +torchrun $DISTRIBUTED_ARGS \ + tp_llama_hf_finetune_ptl.py \ + --model_path $MODEL_PATH \ + --model_name $BASE_MODEL \ + --hf_token $HF_TOKEN \ + --data_dir "databricks/databricks-dolly-15k" \ + --tensor_parallel_size $TP_DEGREE \ + --batch_size $MBS \ + --steps_this_run $STEPS_THIS_RUN \ + --max_steps $TOTAL_STEPS \ + --num_train_epochs $TOTAL_EPOCHS \ + --warmup_steps $WARMUP_STEPS \ + --lr $LR \ + --selective_checkpoint_enabled \ + --grad_accum_usteps $ACC_STEPS \ + --seq_len $SEQ_LEN \ + --pretrained_ckpt $PRETRAINED_PATH \ + --sequence_parallel_enabled \ + --separate_qkv \ + --task "open_qa" \ + $EXTRA_ARGS \ + --use_gpu_compatible_precision 0 \ + --enable_lora \ + --qkv_linear 1 \ + --kv_replicator 4 \ + +ret_val=${PIPESTATUS[0]} +echo ret_val=$ret_val diff --git a/test/integration/modules/moe/checkpoint_test_runner.py b/test/integration/modules/moe/checkpoint_test_runner.py new file mode 100644 index 0000000..f8674a7 --- /dev/null +++ b/test/integration/modules/moe/checkpoint_test_runner.py @@ -0,0 +1,162 @@ +import dataclasses + +import copy +import json +import os +import torch +import torch.nn.functional as F +from torch.optim import Adam +import torch_xla.core.xla_model as xm # TRN enablement + +# Imports from MoE unit tests (for this import to succeed, test/unit_test/modules/moe must be added to PYTHONPATH) +import utils_testing as ut + +from neuronx_distributed.parallel_layers import mappings, parallel_state, random +from neuronx_distributed.trainer.checkpoint import save_checkpoint, load_checkpoint + +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from examples.training.mixtral.modeling_mixtral_moe_nxd 
import MixtralForCausalLM + +def override_state(stateful_object): + state_dict = stateful_object.state_dict() + + stack = [state_dict] + while len(stack) > 0: + item = stack.pop() + if isinstance(item, torch.Tensor): + item.data.normal_(std=0.02) + elif isinstance(item, dict): + stack.extend(item.values()) + elif isinstance(item, list): + stack.extend(item) + + stateful_object.load_state_dict(state_dict) + +def get_converter_args(tp_degree, ep_degree, mixtral_config, cur_dir): + class Arguments: + pass + args = Arguments() + args.input_dir = cur_dir + args.output_dir = cur_dir + args.config = os.path.join(cur_dir, "config.json") + args.model_key = "model" + args.tp_size = tp_degree + args.ep_size = ep_degree + args.pp_size = 1 + args.virtual_pp_size = 1 + args.n_layers = mixtral_config.num_hidden_layers + args.coalesce_qkv = True + args.kv_size_multiplier = 1 + args.load_xser = True + args.save_xser = True + + +def assert_same_tensors(obj1, obj2): + + #assert type(obj1) == type(obj2), f"Type mismatch {type(obj1)} vs {type(obj2)}" + + if isinstance(obj1, (list, tuple)): + for item1, item2 in zip(obj1, obj2): + assert_same_tensors(item1, item2) + elif isinstance(obj1, dict): + for k1, k2 in zip(obj1, obj2): + assert k1 == k2, f"Key mismatch {k1} vs {k2}" + assert_same_tensors(obj1[k1], obj2[k2]) + elif isinstance(obj1, torch.Tensor): + ut.check_tensors(obj1.cpu(), obj2.cpu(), atol=0.0, rtol=0.0) + + +def display_object(obj, indent=0): + s = "" + for _ in range(indent): + s += " " + if isinstance(obj, dict): + print(s + "{") + for k, v in obj.items(): + print(s + str(k)) + display_object(v, indent+2) + print(s + "}") + elif isinstance(obj, list): + print(s + "[") + for item in obj: + display_object(item, indent+2) + print(s + "]") + elif isinstance(obj, torch.Tensor): + print(s + str(list(obj.shape)) + " " + str(obj.device)) + else: + print(s + str(obj)) + + +def _create_optimizer_states(model, optimizer): + for p in model.parameters(): + p.grad = torch.zeros_like(p) + + if optimizer.nxd_config["optimizer_config"]["zero_one_enabled"]: + optimizer.optimizer._reduce_gradients() + optimizer.optimizer.ep_zero_optimizer.base_optimizer.step() + optimizer.optimizer.non_ep_zero_optimizer.base_optimizer.step() + else: + optimizer.step() + + +def run_checkpoint_test(cfg): + device = "xla" + tp_degree = getattr(cfg, "tp_degree", 1) + ep_degree = getattr(cfg, "ep_degree", 1) + cur_dir = os.path.dirname(os.path.abspath(__file__)) + + config_path = os.path.join(cur_dir, "config.json") + with open(config_path, "r") as f: + json_config = json.load(f) + + # 1. Initialize original model + mixtral_config = MixtralConfig(**json_config) + mixtral_config.pretraining_tp = tp_degree + mixtral_config.sequence_parallel_enabled = True + mixtral_config.move_model_to_device = True + mixtral_config.moe_frequency = 1 + mixtral_config.capacity_factor = 2.0 + ut.nxd_init(tp_degree=tp_degree, ep_degree=ep_degree, seed=0) + model = MixtralForCausalLM(mixtral_config).to(device) + + optimizer = ut.initialize_neuron_optimizer(model, zero1=cfg.zero1, optimizer="adam") + + _create_optimizer_states(model, optimizer) + + xm.mark_step() + torch.distributed.barrier() + + # 2. Keep a copy of the original state_dicts + original_model_state_dict = copy.deepcopy(model.state_dict()) + original_optimizer_state_dict = copy.deepcopy(optimizer.state_dict()) + + # 3. Save a partial checkpoint + save_checkpoint( + checkpoint_dir_str=cur_dir, + tag="model", + model=model, + optimizer=optimizer, + use_xser=True + ) + + # 4. 
Override states + override_state(model) + override_state(optimizer) + + # 5. Load the partial checkpoint + load_checkpoint( + path=cur_dir, + tag="model", + model=model, + optimizer=optimizer, + strict=False, + ) + + xm.mark_step() + + # 6. Verify correctness wrt original state_dict + assert_same_tensors(original_model_state_dict, model.state_dict()) + if torch.distributed.get_rank() == 0: + display_object(original_optimizer_state_dict) + display_object(optimizer.state_dict()) + assert_same_tensors(original_optimizer_state_dict, optimizer.state_dict()) diff --git a/test/integration/modules/moe/config.json b/test/integration/modules/moe/config.json new file mode 100644 index 0000000..30d8ecf --- /dev/null +++ b/test/integration/modules/moe/config.json @@ -0,0 +1,6 @@ +{ + "num_attention_heads": 32, + "num_key_value_heads": 32, + "hidden_size": 1536, + "num_hidden_layers": 4 +} diff --git a/test/integration/modules/moe/device_correctness_test_configs.py b/test/integration/modules/moe/device_correctness_test_configs.py index 106870f..0cc72c1 100644 --- a/test/integration/modules/moe/device_correctness_test_configs.py +++ b/test/integration/modules/moe/device_correctness_test_configs.py @@ -1,25 +1,65 @@ import dataclasses -import itertools +from typing import List +import torch # Imports from MoE unit tests (for this import to succeed, test/unit_test/modules/moe must be added to PYTHONPATH) -from utils_testing import ( - ExptCfgCorrectness, - filter_valid_expt_configs, - get_random_activations, -) +from utils_testing import ExptCfg, get_random_activations -from neuronx_distributed.modules.moe import MoESequenceParallelMode +GLU_MLP_ARGS = [True, False] -@dataclasses.dataclass -class ExptCfgDeviceCorrectness(ExptCfgCorrectness): - test_mode: str = "training" +TEST_MODEL_CONFIGS = { + "sbase-small": { + "hidden_size": 4096, + "intermediate_size": 10944, + "num_experts": 4, + "top_k": 1, + }, + "sbase-large": { + "hidden_size": 8192, + "intermediate_size": 20480, + "num_experts": 24, + "top_k": 1, + }, + "mixtral": { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_experts": 8, + "top_k": 2, + }, + "dbrx": { + "hidden_size": 6144, + "intermediate_size": 10752, + "num_experts": 16, + "top_k": 4, + } +} + + +def get_model_config(model_name, scale_down_factor=1): + + assert model_name in TEST_MODEL_CONFIGS + config_dict = TEST_MODEL_CONFIGS[model_name].copy() + config_dict.update({ + "hidden_size": int(config_dict["hidden_size"] / scale_down_factor), + "intermediate_size": int(config_dict["intermediate_size"] / scale_down_factor), + }) + return config_dict + +def get_neuron_cc_flags(test_dtype): + cc_flags = [ + "--model-type=transformer", + "--enable-saturate-infinity", #clip matmul transpose input to [-MAX, MAX] to avoid nans (0*INF) + "--retry_failed_compilation", + ] + if test_dtype == torch.float32: + # Disable auto-casting + cc_flags.append("--auto-cast=none") + return " ".join(cc_flags) -def get_device_correctness_test_configs(dtype): - GLU_MLP_ARGS = [True, False] - PERMUTE_STRATEGY_ARGS = ["matmul", "index"] +def get_device_correctness_test_configs(dtype) -> List[ExptCfg]: test_configs = [] # S-BASE test cases @@ -28,65 +68,63 @@ def get_device_correctness_test_configs(dtype): test_cfg = { "dtype": dtype, "glu_mlp": glu_mlp, + "num_iters": 25, "implementation": "sbase", - "expert_mlps_permute_strategy": "index", } - # Test forward_full_capacity and token-gen + # Training tests + test_cfg["test_mode"] = "training" sbase_test_configs.extend( [ - # Training / Context-encoding - 
ExptCfgDeviceCorrectness( + # Test forward_all_experts (full capacity) + ExptCfg( seq_len=256, batch_size=1, - hidden_size=1024, - intermediate_size=2560, - num_experts=4, - capacity_factor=4.0, + capacity_factor=None, + **get_model_config("sbase-small", scale_down_factor=4), + **test_cfg + ), + # Test forward_capacity_factor + ExptCfg( + seq_len=256, + batch_size=1, + capacity_factor=2.0, + **get_model_config("sbase-large", scale_down_factor=8), + **test_cfg + ), + ] + ) + + # Inference tests + test_cfg["test_mode"] = "inference" + sbase_test_configs.extend( + [ + # Context encoding + ExptCfg( + seq_len=256, + batch_size=1, + capacity_factor=None, + **get_model_config("sbase-small", scale_down_factor=4), **test_cfg ), # Token-generation - ExptCfgDeviceCorrectness( + ExptCfg( seq_len=1, batch_size=1, - hidden_size=1024, - intermediate_size=2560, - num_experts=24, - capacity_factor=1.0, - test_mode="inference", + capacity_factor=None, + **get_model_config("sbase-large", scale_down_factor=8), **test_cfg ), - ExptCfgDeviceCorrectness( + ExptCfg( seq_len=1, batch_size=4, - hidden_size=1024, - intermediate_size=2560, - num_experts=24, - capacity_factor=1.0, - test_mode="inference", + capacity_factor=None, + **get_model_config("sbase-large", scale_down_factor=8), **test_cfg ), ] ) - for permute_strategy in PERMUTE_STRATEGY_ARGS: - test_cfg["expert_mlps_permute_strategy"] = permute_strategy - sbase_test_configs.extend( - [ - # Training / Context-encoding - # capacity_factor such that some tokens may be dropped - ExptCfgDeviceCorrectness( - seq_len=256, - batch_size=1, - hidden_size=1024, - intermediate_size=2560, - num_experts=24, - capacity_factor=2.0, - **test_cfg - ), - ] - ) - # Test each S-BASE configuration on 2 random activation functions for test_no, cfg in enumerate(sbase_test_configs): for hidden_act in get_random_activations(num=2, seed=test_no): @@ -98,207 +136,242 @@ def get_device_correctness_test_configs(dtype): "glu_mlp": True, "hidden_act": "silu", "implementation": "topk", - "expert_mlps_permute_strategy": "index", } - # Test forward_full_capacity and token-gen + + # Training tests + test_cfg["test_mode"] = "training" + test_configs.extend( + [ + # Test forward_all_experts (full capacity) + ExptCfg( + seq_len=256, + batch_size=1, + capacity_factor=None, + **get_model_config("mixtral", scale_down_factor=4), + **test_cfg + ), + ExptCfg( + seq_len=256, + batch_size=1, + capacity_factor=None, + **get_model_config("dbrx", scale_down_factor=4), + **test_cfg + ), + # Test forward_capacity_factor + ExptCfg( + seq_len=256, + batch_size=1, + capacity_factor=2.0, + **get_model_config("mixtral", scale_down_factor=4), + **test_cfg + ), + ExptCfg( + seq_len=256, + batch_size=1, + capacity_factor=2.0, + **get_model_config("dbrx", scale_down_factor=4), + **test_cfg + ), + ] + ) + + # Inference tests + test_cfg["test_mode"] = "inference" test_configs.extend( [ - # Training / Context-encoding - # capacity_factor = num_experts/top_k to ensure no dropped tokens - ExptCfgDeviceCorrectness( + # Context-encoding + ExptCfg( seq_len=256, batch_size=1, - hidden_size=1024, - intermediate_size=3584, - num_experts=8, - top_k=2, - capacity_factor=4.0, + capacity_factor=None, + **get_model_config("mixtral", scale_down_factor=4), + **test_cfg + ), + ExptCfg( + seq_len=256, + batch_size=1, + capacity_factor=None, + **get_model_config("dbrx", scale_down_factor=4), **test_cfg ), # Token-generation - ExptCfgDeviceCorrectness( + ExptCfg( seq_len=1, batch_size=1, - hidden_size=1024, - intermediate_size=3584, - 
num_experts=8, - capacity_factor=1.0, - top_k=2, - test_mode="inference", + capacity_factor=None, + **get_model_config("mixtral", scale_down_factor=4), **test_cfg ), - ExptCfgDeviceCorrectness( + ExptCfg( seq_len=1, batch_size=4, - hidden_size=768, - intermediate_size=2688, - num_experts=16, - capacity_factor=1.0, - top_k=4, - test_mode="inference", + capacity_factor=None, + **get_model_config("dbrx", scale_down_factor=4), **test_cfg ), ] ) - for permute_strategy in PERMUTE_STRATEGY_ARGS: - test_cfg["expert_mlps_permute_strategy"] = permute_strategy - test_configs.extend( - [ - # Training / Context-encoding - # capacity_factor such that some tokens may be dropped - ExptCfgDeviceCorrectness( - seq_len=256, - batch_size=1, - hidden_size=1024, - intermediate_size=3584, - num_experts=8, - top_k=2, - capacity_factor=2.0, - **test_cfg - ), - ExptCfgDeviceCorrectness( - seq_len=256, - batch_size=1, - hidden_size=1536, - intermediate_size=2688, - num_experts=16, - top_k=4, - capacity_factor=2.0, - **test_cfg - ), - ] - ) - - return filter_valid_expt_configs(test_configs) + return test_configs @dataclasses.dataclass -class ExptCfgParallel(ExptCfgDeviceCorrectness): +class ExptCfgParallel(ExptCfg): # Default values must be over-ridden tp_degree: int = 0 ep_degree: int = 0 - sequence_parallel_mode: MoESequenceParallelMode = -1 + sequence_parallel_enabled: bool = False -def get_device_correctness_parallel_test_configs(dtype, tp_degree, sp_mode): - GLU_MLP_ARGS = [True, False] - PERMUTE_STRATEGY_ARGS = ["matmul", "index"] +def get_device_correctness_parallel_test_configs(dtype, test_mode, tp_degree, ep_degree, zero1): + assert test_mode in {"training", "inference"} test_configs = [] # S-BASE test cases - # All test cases use "silu" since other activations are tested in the single-core test - for test_no, (glu_mlp, permute_strategy) in enumerate(itertools.product(GLU_MLP_ARGS, PERMUTE_STRATEGY_ARGS)): - test_cfg = { - "dtype": dtype, - "glu_mlp": glu_mlp, - "hidden_act": "silu", - "implementation": "sbase", - "expert_mlps_permute_strategy": permute_strategy, - "num_iters": 5, - } - test_configs.extend( - [ - # Training / Context-encoding - ExptCfgParallel( - seq_len=256, - batch_size=1, - hidden_size=1024, - intermediate_size=2560, - num_experts=24, - capacity_factor=2.0, - **test_cfg - ), - # Token-generation - ExptCfgParallel( - seq_len=1, - batch_size=1, - hidden_size=1024, - intermediate_size=2560, - num_experts=24, - capacity_factor=1.0, - test_mode="inference", - **test_cfg - ), - ExptCfgParallel( - seq_len=1, - batch_size=4, - hidden_size=1024, - intermediate_size=2560, - num_experts=24, - capacity_factor=1.0, - test_mode="inference", - **test_cfg - ), - ] - ) + + # All test cases use glu_mlp = True (glu_mlp = False tested in single-core test) + # All test cases use "silu" (other activations tested in the single-core test) + test_cfg = { + "dtype": dtype, + "glu_mlp": True, + "hidden_act": "silu", + "implementation": "sbase", + "num_iters": 1, + "zero1": zero1, + } + + # Training tests + test_cfg["test_mode"] = "training" + test_configs.extend( + [ + ExptCfgParallel( + seq_len=256, + batch_size=1, + capacity_factor=2.0, + **get_model_config("sbase-large", scale_down_factor=8), + **test_cfg + ), + ] + ) + + # Inference tests + test_cfg["test_mode"] = "inference" + test_configs.extend( + [ + # Context-encoding + ExptCfgParallel( + seq_len=256, + batch_size=1, + capacity_factor=None, + **get_model_config("sbase-large", scale_down_factor=4), + **test_cfg + ), + # Token-generation + ExptCfgParallel( + 
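                # Worked example of the helper above: get_model_config("sbase-large",
                # scale_down_factor=4) resolves to hidden_size=2048, intermediate_size=5120,
                # num_experts=24, top_k=1 (the 8192/20480 base sizes divided by 4, with the
                # expert count and top_k left unchanged), so these token-generation cases
                # run against a proportionally scaled-down S-BASE model.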
seq_len=1, + batch_size=1, + capacity_factor=None, + **get_model_config("sbase-large", scale_down_factor=4), + **test_cfg + ), + ExptCfgParallel( + seq_len=1, + batch_size=4, + capacity_factor=None, + **get_model_config("sbase-large", scale_down_factor=4), + **test_cfg + ), + ] + ) # TopK test cases - for permute_strategy in PERMUTE_STRATEGY_ARGS: - test_cfg = { - "dtype": dtype, - "glu_mlp": True, - "hidden_act": "silu", - "implementation": "topk", - "expert_mlps_permute_strategy": permute_strategy, - "num_iters": 5, - } - test_configs.extend( - [ - # Training / Context-encoding - ExptCfgParallel( - seq_len=256, - batch_size=1, - hidden_size=1024, - intermediate_size=3584, - num_experts=8, - top_k=2, - capacity_factor=2.0, - **test_cfg - ), - ExptCfgParallel( - seq_len=256, - batch_size=1, - hidden_size=1536, - intermediate_size=2688, - num_experts=16, - top_k=4, - capacity_factor=2.0, - **test_cfg - ), - # Token-generation - ExptCfgParallel( - seq_len=1, - batch_size=1, - hidden_size=1024, - intermediate_size=3584, - num_experts=8, - capacity_factor=1.0, - top_k=2, - test_mode="inference", - **test_cfg - ), - ExptCfgParallel( - seq_len=1, - batch_size=4, - hidden_size=768, - intermediate_size=2688, - num_experts=16, - capacity_factor=1.0, - top_k=4, - test_mode="inference", - **test_cfg - ), - ] - ) + test_cfg = { + "dtype": dtype, + "glu_mlp": True, + "hidden_act": "silu", + "implementation": "topk", + "num_iters": 1, + "zero1": zero1, + } - # Add TP degree, SP mode to config + # Training tests + test_cfg["test_mode"] = "training" + test_configs.extend( + [ + ExptCfgParallel( + seq_len=256, + batch_size=1, + capacity_factor=2.0, + **get_model_config("mixtral", scale_down_factor=8), + **test_cfg + ), + ExptCfgParallel( + seq_len=256, + batch_size=1, + capacity_factor=2.0, + **get_model_config("dbrx", scale_down_factor=8), + **test_cfg + ), + ] + ) + + # Inference tests + test_cfg["test_mode"] = "inference" + test_configs.extend( + [ + # Context-encoding + ExptCfgParallel( + seq_len=256, + batch_size=1, + capacity_factor=None, + **get_model_config("mixtral", scale_down_factor=1), + **test_cfg + ), + ExptCfgParallel( + seq_len=256, + batch_size=1, + capacity_factor=None, + **get_model_config("dbrx", scale_down_factor=2), + **test_cfg + ), + # Token-generation + ExptCfgParallel( + seq_len=1, + batch_size=1, + capacity_factor=None, + **get_model_config("mixtral", scale_down_factor=1), + **test_cfg + ), + ExptCfgParallel( + seq_len=1, + batch_size=4, + capacity_factor=None, + **get_model_config("dbrx", scale_down_factor=2), + **test_cfg + ), + ] + ) + + # Add tp_degree, sequence_parallel_enabled to config test_configs_parallel = [] for cfg in test_configs: + # Filter to required test_mode + if cfg.test_mode != test_mode: + continue + + # EP + token-gen not supported + if ep_degree > 1 and cfg.seq_len == 1 and cfg.test_mode == "inference": + continue + + # Enable SP in training, disable SP in inference + sequence_parallel_enabled = True if test_mode == "training" else False cfg_parallel = dataclasses.replace( - cfg, tp_degree=tp_degree, ep_degree=1, sequence_parallel_mode=sp_mode + cfg, + tp_degree=tp_degree, + ep_degree=ep_degree, + sequence_parallel_enabled=sequence_parallel_enabled, ) test_configs_parallel.append(cfg_parallel) - return filter_valid_expt_configs(test_configs_parallel) + return test_configs_parallel diff --git a/test/integration/modules/moe/device_correctness_test_runner.py b/test/integration/modules/moe/device_correctness_test_runner.py index bf63c38..8e0d849 100644 --- 
a/test/integration/modules/moe/device_correctness_test_runner.py +++ b/test/integration/modules/moe/device_correctness_test_runner.py @@ -1,40 +1,34 @@ import dataclasses +import gc +import os import torch import torch.nn.functional as F +from torch.optim import Adam import torch_xla.core.xla_model as xm # TRN enablement # Imports from MoE unit tests (for this import to succeed, test/unit_test/modules/moe must be added to PYTHONPATH) import utils_testing as ut -from neuronx_distributed.modules.moe import MoESequenceParallelMode from neuronx_distributed.parallel_layers import mappings, parallel_state, random -STATE_KEYS = { - "_TENSOR_MODEL_PARALLEL_GROUP", - "_TENSOR_MODEL_PARALLEL_GROUP_SPMD", - "_PIPELINE_MODEL_PARALLEL_GROUP", - "_PIPELINE_GLOBAL_RANKS", - "_PIPELINE_MODEL_PARALLEL_GROUP_SPMD", - "_NEXT_RANK_GROUP_SPMD", - "_PREV_RANK_GROUP_SPMD", - "_NEXT_RANK_GROUP", - "_PREV_RANK_GROUP", - "_DATA_PARALLEL_GROUP", - "_DATA_PARALLEL_GROUP_SPMD", - "_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE", - "_MPU_TENSOR_MODEL_PARALLEL_RANK", -} - -PARALLEL_STATE_MAP = {} - - -def get_model_outputs(cfg, model, ip, target, sequence_parallel_enabled): +def get_model_outputs(cfg, model, optimizer, ip, target, sequence_parallel_enabled, dp_size, reduce_gradients=False, serialize_dp=False): assert model.is_test is False - if model.return_router_logits: - op, _ = model(ip) - else: - (op,) = model(ip) + + # for cpu, sequentially run each data-parallel shard and accummulate grads + # so that the token dropping pattern is the same as trn + ip_chunks = split_ip_into_chunks(ip, dp_size, serialize_dp, cfg.test_mode) + outputs = [] + + for ip in ip_chunks: + if model.return_router_logits: + op, _ = model(ip) + else: + (op,) = model(ip) + outputs.append(op) + + batch_dim = 1 if cfg.test_mode == "training" else 0 + op = torch.cat(outputs, dim=batch_dim) if cfg.test_mode == "training": if sequence_parallel_enabled: @@ -45,41 +39,93 @@ def get_model_outputs(cfg, model, ip, target, sequence_parallel_enabled): loss = F.nll_loss(op_full, target) del op_full loss.backward() + + # prevents runtime errors when running back-to-back unit tests with cross-node ep + xm.mark_step() + + loss = reduce_gradients_and_losses_and_take_step(optimizer, loss, reduce_gradients) grad_dict = ut.get_model_grads_dict(model) + return op, loss, grad_dict else: + assert cfg.test_mode == "inference" return op, torch.Tensor([0]), {} +def reduce_gradients_and_losses_and_take_step(optimizer, loss, reduce_gradients): + optimizer.step() + if not reduce_gradients: + return loss + + # reduce loss + edp_group = parallel_state.get_expert_data_parallel_group(as_list=True) + emp_group = parallel_state.get_expert_model_parallel_group(as_list=True) + dp_size = parallel_state.get_data_parallel_size() -def nxd_init(tp_degree, ep_degree, seed): - assert ep_degree == 1 + loss /= dp_size + xm.all_reduce("sum", [loss], groups=edp_group) + xm.all_reduce("sum", [loss], groups=emp_group) + return loss - world_size = torch.distributed.get_world_size() - parallel_state_key = f"{world_size}_{tp_degree}_{ep_degree}" - def _save_parallel_state(key): - state = {} - for attr in STATE_KEYS: - state[attr] = parallel_state.__dict__[attr] - PARALLEL_STATE_MAP[key] = state +def split_ip_into_chunks(ip, dp_size, serialize_dp, test_mode): + # inference input is already sharded by dp + if test_mode == 'inference' or not serialize_dp: + return [ip] - def _load_parallel_state(key): - for k, v in PARALLEL_STATE_MAP[key].items(): - parallel_state.__dict__[k] = v + batch_dim = 1 if 
test_mode == "training" else 0 + split_tensor = torch.tensor_split(ip, dp_size, dim=batch_dim) + return [t.contiguous() for t in split_tensor] - if parallel_state_key in PARALLEL_STATE_MAP: - _load_parallel_state(parallel_state_key) + +def shard_batch(tensor, cfg, dp_size, dp_rank, test_mode): + assert tensor.dim() < 4 and tensor.dim() > 0 + shape = list(tensor.shape) + if test_mode == "training": + tensor = tensor.reshape(cfg.seq_len, dp_size*cfg.batch_size, -1) + tensor = tensor.narrow(1, dp_rank*cfg.batch_size, cfg.batch_size) + + if len(shape) > 2: + shape[1] //= dp_size + else: + shape[0] //= dp_size + return tensor.reshape(*shape) else: - parallel_state.destroy_model_parallel() - parallel_state.initialize_model_parallel( - tensor_model_parallel_size=tp_degree, - pipeline_model_parallel_size=1, - ) - _save_parallel_state(parallel_state_key) + return tensor + + +def _get_slice_for_rank(tensor, sharding_info, split_dims=None): + tp_rank, tp_size, ep_rank, ep_size = sharding_info + for dim in split_dims: + rank, size = (tp_rank, tp_size) if dim > 0 else (ep_rank, ep_size) + tensor = torch.tensor_split(tensor, size, dim=dim)[rank] + return tensor + +def _slice_and_compare_tensors(cpu_dict, trn_dict, sharding_info, it, **tols): + assert set(cpu_dict.keys()) == set(trn_dict.keys()) + for key in sorted(cpu_dict): + cpu_dict[key] = cpu_dict[key].detach() + if cpu_dict[key].shape == trn_dict[key].shape: + key_tensor_for_rank = cpu_dict[key] + else: + if "gate_up_proj" in key: + gate_proj_tensor, up_proj_tensor = torch.tensor_split(cpu_dict[key], 2, dim=2) + gate_proj_tensor_for_rank = _get_slice_for_rank(gate_proj_tensor, sharding_info, split_dims=(0, 2)) + up_proj_tensor_for_rank = _get_slice_for_rank(up_proj_tensor, sharding_info, split_dims=(0, 2)) + key_tensor_for_rank = torch.cat([gate_proj_tensor_for_rank, up_proj_tensor_for_rank], dim=2) + elif "up_proj" in key: + key_tensor_for_rank = _get_slice_for_rank(cpu_dict[key], sharding_info, split_dims=(0, 2)) + elif "down_proj" in key: + key_tensor_for_rank = _get_slice_for_rank(cpu_dict[key], sharding_info, split_dims=(0, 1)) + else: + raise Exception( + f"Unexpected shapes for key: {key}, {cpu_dict[key].shape}, {trn_dict[key].shape}" + ) - # Set seed - random.model_parallel_xla_manual_seed(seed) + additional_msg = f"Iteration {it} \nKey: {key}" + ut.check_tensors( + key_tensor_for_rank, trn_dict[key].detach(), **tols, additional_msg=additional_msg + ) def run_device_correctness_test(cfg, output_tols, grad_tols): device = "xla" @@ -87,42 +133,62 @@ def run_device_correctness_test(cfg, output_tols, grad_tols): tp_degree = getattr(cfg, "tp_degree", 1) ep_degree = getattr(cfg, "ep_degree", 1) assert cfg.test_mode in {"training", "inference"}, f"Unknown test_mode: {cfg.test_mode}" - sequence_parallel_mode = getattr(cfg, "sequence_parallel_mode", MoESequenceParallelMode.NO_SP) - sequence_parallel_enabled = cfg.test_mode == "training" and sequence_parallel_mode != MoESequenceParallelMode.NO_SP - - # Initialize model on cpu and trn - nxd_init(tp_degree=1, ep_degree=1, seed=0) - model_cpu = ut.initialize_neuron_model(cfg) - nxd_init(tp_degree=tp_degree, ep_degree=ep_degree, seed=0) - model_trn = ut.initialize_neuron_model(cfg_trn) + sequence_parallel_enabled = cfg.sequence_parallel_enabled + if sequence_parallel_enabled: + assert cfg.test_mode == "training" + + ut.nxd_init(tp_degree=tp_degree, ep_degree=ep_degree, seed=0) + # using non-zero learning rate for zero-1 so that we can do an end-to-end test + lr = cfg_trn.lr if cfg_trn.zero1 else 0.0 + 
grad_clipping = True + dp_size = parallel_state.get_data_parallel_size() + dp_rank = parallel_state.get_data_parallel_rank() + tp_size = parallel_state.get_tensor_model_parallel_size() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + ep_size = parallel_state.get_expert_model_parallel_size() + ep_rank = parallel_state.get_expert_model_parallel_rank() if cfg.test_mode == "training": - model_cpu.train() - model_trn.train() - # Set sinkhorn_iterations=0, because small precision errors can cause differences in routing decisions - model_cpu.router.sinkhorn_iterations = 0 - model_trn.router.sinkhorn_iterations = 0 grad_ctx_mgr = torch.enable_grad else: - model_cpu.eval() - model_trn.eval() grad_ctx_mgr = torch.no_grad with grad_ctx_mgr(): for it in range(cfg.num_iters): + print(f"iteration {it}") + # Initialize model on cpu and trn + ut.nxd_init(tp_degree=1, ep_degree=1, seed=it) + model_cpu = ut.initialize_neuron_model(cfg) + ut.nxd_init(tp_degree=tp_degree, ep_degree=ep_degree, seed=it) + model_trn = ut.initialize_neuron_model(cfg_trn) + ut.match_expert_weights(model_trn, model_cpu, cfg.glu_mlp) + + if cfg.test_mode == "training": + model_cpu.train() + model_trn.train() + # Set sinkhorn_iterations=0, because small precision errors can cause differences in routing decisions + model_cpu.router.sinkhorn_iterations = 0 + model_trn.router.sinkhorn_iterations = 0 + else: + model_cpu.eval() + model_trn.eval() + optimizer_cpu = ut.initialize_neuron_optimizer(model_cpu, grad_clipping=grad_clipping, override_grad_reduction=True, zero1=False, lr=lr) + optimizer_trn = ut.initialize_neuron_optimizer(model_trn, grad_clipping=grad_clipping, zero1=cfg_trn.zero1, lr=lr) # Init NxD with tp_degree=1 and ep_degree=1, for running on cpu model - nxd_init(tp_degree=1, ep_degree=1, seed=it) + ut.nxd_init(tp_degree=1, ep_degree=1, seed=it) # Initialize input, target, model on cpu if cfg.test_mode == "training": # Input is SBH in training - ip_cpu = torch.randn(cfg.seq_len, cfg.batch_size, cfg.hidden_size, dtype=cfg.dtype).detach() + ip_cpu = torch.randn(cfg.seq_len, cfg.batch_size * dp_size, cfg.hidden_size, dtype=cfg.dtype).detach() + target_cpu = torch.randint( + 0, cfg.hidden_size - 1, (cfg.seq_len * cfg.batch_size * dp_size,), dtype=torch.long).detach() else: # Input is BSH in inference ip_cpu = torch.randn(cfg.batch_size, cfg.seq_len, cfg.hidden_size, dtype=cfg.dtype).detach() + target_cpu = torch.randint( + 0, cfg.hidden_size - 1, (cfg.seq_len * cfg.batch_size,), dtype=torch.long).detach() ip_trn_full = ip_cpu.detach().to(device) - target_cpu = torch.randint( - 0, cfg.hidden_size - 1, (cfg.seq_len * cfg.batch_size,), dtype=torch.long - ).detach() + ip_trn_full = shard_batch(ip_trn_full, cfg, dp_size, dp_rank, cfg.test_mode) # torch.topk behavior is different on cpu and device in the case of ties. # This causes mismatches in expert assignment for the TopK tests in bf16. 
@@ -134,33 +200,47 @@ def run_device_correctness_test(cfg, output_tols, grad_tols): # Simulate dropping of tokens in input where the expert assignments are not matching on cpu and device with torch.no_grad(): router_logits_cpu, expert_index_cpu = model_cpu(ip_cpu)[-2:] - expert_index_trn = model_trn(ip_trn_full)[-1] - expert_mismatch_indices = set(torch.where(expert_index_cpu != expert_index_trn.cpu())[0].tolist()) + ut.nxd_init(tp_degree=tp_degree, ep_degree=ep_degree, seed=it) + if sequence_parallel_enabled: + ip_trn = mappings.scatter_to_sequence_parallel_region(ip_trn_full) + else: + ip_trn = ip_trn_full + expert_index_trn = model_trn(ip_trn)[-1] + ut.nxd_init(tp_degree=1, ep_degree=1, seed=it) + local_ip_cpu = shard_batch(ip_cpu, cfg, dp_size, dp_rank, cfg.test_mode) + local_expert_index_cpu = shard_batch(expert_index_cpu, cfg, dp_size, dp_rank, cfg.test_mode) + local_router_logits_cpu = shard_batch(router_logits_cpu, cfg, dp_size, dp_rank, cfg.test_mode) + expert_mismatch_indices = set(torch.where(local_expert_index_cpu != expert_index_trn.cpu())[0].tolist()) if len(expert_mismatch_indices) > 0: # Check that mismatches only happen when the (top_k+1) router logits are non-unique for mismatch_idx in expert_mismatch_indices: - router_logits_idx = router_logits_cpu[mismatch_idx] + router_logits_idx = local_router_logits_cpu[mismatch_idx] topk_logits, _ = torch.topk(router_logits_idx, min(cfg.top_k + 1, cfg.num_experts)) assert len(topk_logits) != len(torch.unique(topk_logits)), str(topk_logits) # Update the input tensor to mask tokens where there is an expert assignment mismatch - ip_cpu = ut.drop_tokens_in_tensor(ip_cpu, expert_mismatch_indices) + # Modifying local_ip_cpu also modifies ip_cpu since they share underlying memory + local_ip_cpu = ut.drop_tokens_in_tensor(local_ip_cpu, expert_mismatch_indices) ip_trn_full = ip_cpu.detach().to(device) - + ip_trn_full = shard_batch(ip_trn_full, cfg, dp_size, dp_rank, cfg.test_mode) # Reset is_test model_cpu.is_test = False model_trn.is_test = False + sharding_info = (tp_rank, tp_size, ep_rank, ep_size) + # Get outputs and gradients from cpu op_cpu, loss_cpu, grad_dict_cpu = get_model_outputs( - cfg, model_cpu, ip_cpu, target_cpu, sequence_parallel_enabled + cfg, model_cpu, optimizer_cpu, ip_cpu, target_cpu, sequence_parallel_enabled, dp_size, serialize_dp=True ) # Re-init NxD with actual TP degree - nxd_init(tp_degree=tp_degree, ep_degree=ep_degree, seed=it) + ut.nxd_init(tp_degree=tp_degree, ep_degree=ep_degree, seed=it) # Get sharded input for rank (for sequence parallel) tp_size = parallel_state.get_tensor_model_parallel_size() tp_rank = parallel_state.get_tensor_model_parallel_rank() + + ip_trn_full = ip_cpu.detach().to(device) if sequence_parallel_enabled: ip_trn = mappings.scatter_to_sequence_parallel_region(ip_trn_full) else: @@ -168,8 +248,12 @@ def run_device_correctness_test(cfg, output_tols, grad_tols): # Get outputs and gradients from trn, using the same input and target target_trn = target_cpu.clone().detach().to(device) + + # Data-parallel sharding + target_trn = shard_batch(target_trn, cfg, dp_size, dp_rank, cfg.test_mode) + op_trn, loss_trn, grad_dict_trn = get_model_outputs( - cfg, model_trn, ip_trn, target_trn, sequence_parallel_enabled + cfg, model_trn, optimizer_trn, ip_trn, target_trn, sequence_parallel_enabled, dp_size, reduce_gradients=True ) xm.mark_step() # TRN enablement @@ -179,6 +263,10 @@ def run_device_correctness_test(cfg, output_tols, grad_tols): if sequence_parallel_enabled: # Compare with only output shard 
belonging to the TP rank op_cpu = torch.tensor_split(op_cpu, tp_degree, dim=0)[tp_rank] + + if cfg.test_mode == "training": + batch_dim = 1 + op_cpu = op_cpu.narrow(batch_dim, dp_rank*cfg.batch_size, cfg.batch_size) ut.check_tensors(op_cpu.detach(), op_trn.detach(), **output_tols, additional_msg=f"Iteration {it}") del op_cpu, op_trn @@ -186,33 +274,21 @@ def run_device_correctness_test(cfg, output_tols, grad_tols): ut.check_tensors(loss_cpu.detach(), loss_trn.detach(), **output_tols) del loss_cpu, loss_trn - # Check gradients on each tp_rank - assert set(grad_dict_cpu.keys()) == set(grad_dict_trn.keys()) - for key in sorted(grad_dict_cpu): - grad_dict_cpu[key] = grad_dict_cpu[key].detach() - if grad_dict_cpu[key].shape == grad_dict_trn[key].shape: - key_grad_for_rank = grad_dict_cpu[key] - else: - if "gate_up_proj" in key: - gate_proj_grad, up_proj_grad = torch.tensor_split(grad_dict_cpu[key], 2, dim=2) - gate_proj_grad_for_rank = torch.tensor_split(gate_proj_grad, tp_size, dim=2)[tp_rank] - up_proj_grad_for_rank = torch.tensor_split(up_proj_grad, tp_size, dim=2)[tp_rank] - key_grad_for_rank = torch.cat([gate_proj_grad_for_rank, up_proj_grad_for_rank], dim=2) - elif "up_proj" in key: - key_grad_for_rank = torch.tensor_split(grad_dict_cpu[key], tp_size, dim=2)[tp_rank] - elif "down_proj" in key: - key_grad_for_rank = torch.tensor_split(grad_dict_cpu[key], tp_size, dim=1)[tp_rank] - else: - raise Exception( - f"Unexpected shapes for key: {key}, {grad_dict_cpu[key].shape}, {grad_dict_trn[key].shape}" - ) - - additional_msg = f"Iteration {it} \nKey: {key}" - ut.check_tensors( - key_grad_for_rank, grad_dict_trn[key].detach(), **grad_tols, additional_msg=additional_msg - ) - - del grad_dict_cpu, grad_dict_trn - - model_cpu.zero_grad(set_to_none=True) - model_trn.zero_grad(set_to_none=True) + if not cfg_trn.zero1: + # Check gradients on each rank + _slice_and_compare_tensors(grad_dict_cpu, grad_dict_trn, sharding_info, it, **grad_tols) + del grad_dict_cpu, grad_dict_trn + else: + # if zero1 is enabled, compare the updated parameters directly instead of the gradients, because the true gradients used by the zero1 optimizer are private and may not match + trn_parameters = {n: p for n, p in model_trn.named_parameters()} + cpu_parameters = {n: p for n, p in model_cpu.named_parameters()} + param_tols = {k: cfg_trn.lr * v for k, v in grad_tols.items()} + _slice_and_compare_tensors(cpu_parameters, trn_parameters, sharding_info, it, **param_tols) + del cpu_parameters, trn_parameters, grad_dict_cpu, grad_dict_trn + + optimizer_cpu.zero_grad(set_to_none=True) + optimizer_trn.zero_grad(set_to_none=True) + xm.mark_step() + + del model_cpu, model_trn + gc.collect() diff --git a/test/integration/modules/moe/launch_device_correctness_parallel.sh b/test/integration/modules/moe/launch_device_correctness_parallel.sh new file mode 100755 index 0000000..df7d725 --- /dev/null +++ b/test/integration/modules/moe/launch_device_correctness_parallel.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -e + +export SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +export NXD_DIR=/home/ubuntu/ktest/NeuronxDistributed +export UNIT_TEST_DIR=$NXD_DIR/test/unit_test/modules/moe + +export PYTHONPATH=$NXD_DIR/src:$NXD_DIR:$UNIT_TEST_DIR:$PYTHONPATH +echo $PYTHONPATH + +MASTER_ADDR_JOB=(`scontrol show hostnames $SLURM_JOB_NODELIST`) + +# prevents hanging during NCCL init +#sudo sysctl -w net.ipv4.ip_local_reserved_ports=48620 + +export OMP_NUM_THREADS=1 +export FI_LOG_LEVEL=warn +export FI_EFA_USE_DEVICE_RDMA=1 +export 
FI_PROVIDER=efa +export FI_EFA_FORK_SAFE=1 + +torchrun --nproc_per_node=32 \ + --nnodes=${SLURM_NTASKS} \ + --node_rank=${SLURM_NODEID} \ + --master_addr=${MASTER_ADDR_JOB} \ + --master_port=2020 \ + $SCRIPT_DIR/test_device_correctness_parallel.py \ + --test_json test.json \ + --test_tp_degree=$1 \ + --test_ep_degree=$2 \ + --test_mode=training \ + --test_dtype=$4 \ + --zero1=$3 diff --git a/test/integration/modules/moe/test_checkpoint_parallel.py b/test/integration/modules/moe/test_checkpoint_parallel.py new file mode 100644 index 0000000..0781097 --- /dev/null +++ b/test/integration/modules/moe/test_checkpoint_parallel.py @@ -0,0 +1,128 @@ +import argparse +import atexit +import json +import os +import time +import traceback + +import torch + +# Imports from MoE unit tests (for this import to succeed, test/unit_test/modules/moe must be added to PYTHONPATH) +from device_correctness_test_configs import ( + get_device_correctness_parallel_test_configs, + get_neuron_cc_flags, +) +from checkpoint_test_runner import run_checkpoint_test + +from neuronx_distributed.parallel_layers.utils import is_pjrt_device +import torch_xla.core.xla_model as xm # TRN enablement + +SEPARATOR = "-" * 70 + +def parse_args(): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument( + "--test_json", + required=False, + help="input json listing the test spec for network to compile", + ) + parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") + parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_modules/moe") + parser.add_argument("--test_dtype", required=True, choices=["fp32", "bf16"], help="Either fp32 or bf16") + parser.add_argument("--test_mode", required=True, type=str, help="Either training or inference") + parser.add_argument( + "--test_tp_degree", required=True, type=int, choices=[1, 2, 8, 16, 32], help="One of 1, 2, 8, 16 or 32" + ) + parser.add_argument( + "--test_ep_degree", required=True, type=int, choices=[1, 2, 4, 8, 16, 32], help="One of 1, 2, 4, 8, 16 or 32" + ) + args, leftovers = parser.parse_known_args() + S3_BUCKET_NAME = args.s3_bucket + with open(args.test_json, "r") as f: + test_dict = json.load(f) + test_dtype = torch.float32 if args.test_dtype == "fp32" else torch.bfloat16 + return test_dict, S3_BUCKET_NAME, args, test_dtype, args.test_mode, args.test_tp_degree, args.test_ep_degree + + +test_config, S3_BUCKET_NAME, args, TEST_DTYPE, TEST_MODE, TEST_TP_DEGREE, TEST_EP_DEGREE = parse_args() +results = {"inference_success": 1} + +# Set compiler flags before TRN enablement +os.environ["NEURON_CC_FLAGS"] = get_neuron_cc_flags(test_dtype=TEST_DTYPE) + +def print_rank0(s): + if xm.get_ordinal() == 0: + print(s) + + +def summarize_test(start_time, num_tests, failed): + print_rank0(f"{SEPARATOR}\nRan {num_tests} tests in {round(time.time()-start_time, 1)}s\n\n") + if failed == 0: + print_rank0("OK\n\n") + else: + raise Exception(f"Failed {failed}/{num_tests} tests") + + +def test_moe_layer_checkpoint_parallel(): + def _test_moe_layer_checkpoint_parallel(): + test_configs = get_device_correctness_parallel_test_configs( + dtype=TEST_DTYPE, tp_degree=TEST_TP_DEGREE, ep_degree=TEST_EP_DEGREE, test_mode=TEST_MODE + ) + + start_time = time.time() + failed = 0 + print_rank0(f"Running {len(test_configs)} tests") + for i, cfg in enumerate(test_configs): + print_rank0(f"Running test {i+1}/{len(test_configs)}: {str(cfg)}") + try: + run_checkpoint_test(cfg) + clean_dir() + print_rank0("ok\n") + + except Exception: + 
print_rank0("Failed test") + print_rank0(traceback.format_exc()) + failed += 1 + + # running test only once + break + summarize_test(start_time, len(test_configs), failed) + + global results + try: + _test_moe_layer_checkpoint_parallel() + except: + results["inference_success"] = 0 + print_rank0(traceback.format_exc()) + raise + + +def clean_dir(): + xm.rendezvous("Cleaning directory") + if xm.get_ordinal() == 0: + cur_dir = os.path.dirname(os.path.abspath(__file__)) + path = os.path.join(cur_dir, "model") + os.system(f"rm -rf {path}") + xm.rendezvous("Cleaned directory") + +def on_exit(): + if xm.get_ordinal() == 0: + for k in test_config: + os.system(f"rm {args.test_json}") + with open(args.test_json, "w") as f: + json.dump({k: results}, f) + + +if __name__ == "__main__": + if is_pjrt_device(): + import torch_xla.experimental.pjrt_backend # noqa + + torch.distributed.init_process_group("xla", init_method="pjrt://") + else: + torch.distributed.init_process_group("xla") + + print_rank0( + f"test device correctness parallel, test_dtype={str(TEST_DTYPE)}, test_mode={TEST_MODE}, test_tp_degree={TEST_TP_DEGREE}, test_ep_degree={TEST_EP_DEGREE}" + ) + test_moe_layer_checkpoint_parallel() + atexit.register(on_exit) diff --git a/test/integration/modules/moe/test_device_correctness.py b/test/integration/modules/moe/test_device_correctness.py index 472ed60..61a842d 100644 --- a/test/integration/modules/moe/test_device_correctness.py +++ b/test/integration/modules/moe/test_device_correctness.py @@ -1,16 +1,17 @@ import argparse -import atexit -import json import os import time -import torch import traceback import loss_fn_correctness_test_helper as lch +import torch # Imports from MoE unit tests (for this import to succeed, test/unit_test/modules/moe must be added to PYTHONPATH) import utils_testing as ut -from device_correctness_test_configs import get_device_correctness_test_configs +from device_correctness_test_configs import ( + get_device_correctness_test_configs, + get_neuron_cc_flags, +) from device_correctness_test_runner import run_device_correctness_test from neuronx_distributed.modules.moe import ( @@ -18,7 +19,6 @@ ) from neuronx_distributed.parallel_layers.utils import is_pjrt_device - SEPARATOR = "-" * 70 # FP32 test tolerances @@ -43,37 +43,23 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_modules/moe") parser.add_argument("--test_dtype", required=True, choices=["fp32", "bf16"], help="Either fp32 or bf16") args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) test_dtype = torch.float32 if args.test_dtype == "fp32" else torch.bfloat16 - return test_dict, S3_BUCKET_NAME, args, test_dtype + return S3_BUCKET_NAME, args, test_dtype -test_config, S3_BUCKET_NAME, args, TEST_DTYPE = parse_args() +S3_BUCKET_NAME, args, TEST_DTYPE = parse_args() results = {"inference_success": 1} -if "--model-type" not in os.environ.get("NEURON_CC_FLAGS", ""): - os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer " -else: - assert any(s in os.environ["NEURON_CC_FLAGS"] for s in ["--model-type transformer", "--model-type=transformer"]) +# Set compiler flags 
before TRN enablement +os.environ["NEURON_CC_FLAGS"] = get_neuron_cc_flags(test_dtype=TEST_DTYPE) - -if TEST_DTYPE == torch.float32: - # Set compiler flag to disable auto-casting before TRN enablement - assert "--auto-cast" not in os.environ.get("NEURON_CC_FLAGS", "") - os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --auto-cast=none" - -import torch_xla.core.xla_model as xm # TRN enablement +# TRN enablement +import torch_xla.core.xla_model as xm # noqa: E402 def summarize_test(start_time, num_tests, failed): @@ -100,9 +86,9 @@ def _test_moe_layer_device_correctness(): print(f"Running test {i+1}/{len(test_configs)}: {str(cfg)}") try: run_device_correctness_test(cfg, output_test_tols, grad_test_tols) - print(f"ok\n") - except Exception as e: - print(f"Failed test") + print("ok\n") + except Exception: + print("Failed test") print(traceback.format_exc()) failed += 1 summarize_test(start_time, len(test_configs), failed) @@ -110,7 +96,7 @@ def _test_moe_layer_device_correctness(): global results try: _test_moe_layer_device_correctness() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -148,9 +134,9 @@ def _test_loss_fn_device_correctness(): assert neuron_loss.dtype == cpu_loss.dtype TEST_TOLS = lch.FP32_TEST_TOLS if cfg.dtype == torch.float32 else lch.BF16_TEST_TOLS ut.check_tensors(neuron_loss.cpu(), cpu_loss, **TEST_TOLS, additional_msg=f"Iteration {it}") - print(f"ok\n") - except Exception as e: - print(f"Failed test") + print("ok\n") + except Exception: + print("Failed test") print(traceback.format_exc()) failed += 1 summarize_test(start_time, len(test_configs), failed) @@ -158,20 +144,12 @@ def _test_loss_fn_device_correctness(): global results try: _test_loss_fn_device_correctness() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def on_exit(): - if xm.get_ordinal() == 0: - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) - - if __name__ == "__main__": if is_pjrt_device(): import torch_xla.experimental.pjrt_backend # noqa @@ -184,4 +162,3 @@ def on_exit(): test_moe_layer_device_correctness() print(f"Running loss fn device correctness test, test_dtype={str(TEST_DTYPE)}") test_loss_fn_device_correctness() - atexit.register(on_exit) diff --git a/test/integration/modules/moe/test_device_correctness_parallel.py b/test/integration/modules/moe/test_device_correctness_parallel.py index 8a81c41..c0fb2f1 100644 --- a/test/integration/modules/moe/test_device_correctness_parallel.py +++ b/test/integration/modules/moe/test_device_correctness_parallel.py @@ -1,18 +1,21 @@ import argparse -import atexit -import json import os import time -import torch import traceback +import torch + # Imports from MoE unit tests (for this import to succeed, test/unit_test/modules/moe must be added to PYTHONPATH) -from device_correctness_test_configs import get_device_correctness_parallel_test_configs +from device_correctness_test_configs import ( + get_device_correctness_parallel_test_configs, + get_neuron_cc_flags, +) from device_correctness_test_runner import run_device_correctness_test +from checkpoint_test_runner import run_checkpoint_test +from neuronx_distributed.parallel_layers.parallel_state import rmsg from neuronx_distributed.parallel_layers.utils import is_pjrt_device - SEPARATOR = "-" * 70 # FP32 test tolerances @@ -37,43 +40,33 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - 
parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_modules/moe") parser.add_argument("--test_dtype", required=True, choices=["fp32", "bf16"], help="Either fp32 or bf16") + parser.add_argument("--test_mode", required=True, type=str, help="Either training or inference") + parser.add_argument( + "--test_tp_degree", required=True, type=int, choices=[1, 2, 8, 16, 32], help="One of 1, 2, 8, 16 or 32" + ) parser.add_argument( - "--test_tp_degree", required=True, type=int, choices=[2, 8, 16, 32], help="One of 2, 8, 16 or 32" + "--test_ep_degree", required=False, default=1, type=int, choices=[1, 2, 4, 8, 16, 32], help="One of 1, 2, 4, 8, 16 or 32" ) parser.add_argument( - "--test_sp_mode", required=True, type=str, help="One of MoESequenceParallelMode" + "--zero1", required=False, default=1, type=int, choices=[0, 1], help="Enable zero-1" ) args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) test_dtype = torch.float32 if args.test_dtype == "fp32" else torch.bfloat16 - return test_dict, S3_BUCKET_NAME, args, test_dtype, args.test_tp_degree, args.test_sp_mode + return S3_BUCKET_NAME, args, test_dtype, args.test_mode, args.test_tp_degree, args.test_ep_degree, args.zero1 -test_config, S3_BUCKET_NAME, args, TEST_DTYPE, TEST_TP_DEGREE, TEST_SP_MODE = parse_args() +S3_BUCKET_NAME, args, TEST_DTYPE, TEST_MODE, TEST_TP_DEGREE, TEST_EP_DEGREE, ZERO1 = parse_args() results = {"inference_success": 1} -if "--model-type" not in os.environ.get("NEURON_CC_FLAGS", ""): - os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer " -else: - assert any(s in os.environ["NEURON_CC_FLAGS"] for s in ["--model-type transformer", "--model-type=transformer"]) - +# Set compiler flags before TRN enablement +os.environ["NEURON_CC_FLAGS"] = get_neuron_cc_flags(test_dtype=TEST_DTYPE) -if TEST_DTYPE == torch.float32: - # Set compiler flag to disable auto-casting before TRN enablement - assert "--auto-cast" not in os.environ.get("NEURON_CC_FLAGS", "") - os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --auto-cast=none" - -import torch_xla.core.xla_model as xm # TRN enablement +# TRN enablement +import torch_xla.core.xla_model as xm # noqa: E402 def print_rank0(s): @@ -98,7 +91,9 @@ def test_moe_layer_device_correctness_parallel(): raise ValueError(f"Unknown TEST_DTYPE: {str(TEST_DTYPE)}") def _test_moe_layer_device_correctness_parallel(): - test_configs = get_device_correctness_parallel_test_configs(dtype=TEST_DTYPE, tp_degree=TEST_TP_DEGREE, sp_mode=TEST_SP_MODE) + test_configs = get_device_correctness_parallel_test_configs( + dtype=TEST_DTYPE, tp_degree=TEST_TP_DEGREE, ep_degree=TEST_EP_DEGREE, test_mode=TEST_MODE, zero1=ZERO1, + ) start_time = time.time() failed = 0 print_rank0(f"Running {len(test_configs)} tests") @@ -106,30 +101,22 @@ def _test_moe_layer_device_correctness_parallel(): print_rank0(f"Running test {i+1}/{len(test_configs)}: {str(cfg)}") try: run_device_correctness_test(cfg, output_test_tols, grad_test_tols) - print_rank0(f"ok\n") + print_rank0("ok\n") except Exception as e: - print_rank0(f"Failed test") - print_rank0(traceback.format_exc()) + print(rmsg(f"Test failed: {e}")) + print(rmsg(traceback.format_exc())) failed += 1 summarize_test(start_time, len(test_configs), 
failed) global results try: _test_moe_layer_device_correctness_parallel() - except: + except Exception: results["inference_success"] = 0 print_rank0(traceback.format_exc()) raise -def on_exit(): - if xm.get_ordinal() == 0: - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) - - if __name__ == "__main__": if is_pjrt_device(): import torch_xla.experimental.pjrt_backend # noqa @@ -138,6 +125,7 @@ def on_exit(): else: torch.distributed.init_process_group("xla") - print_rank0(f"test device correctness parallel, test_dtype={str(TEST_DTYPE)}, test_tp_degree={TEST_TP_DEGREE}, test_sp_mode={TEST_SP_MODE}") + print_rank0( + f"test device correctness parallel, test_dtype={str(TEST_DTYPE)}, test_mode={TEST_MODE}, test_tp_degree={TEST_TP_DEGREE}, test_ep_degree={TEST_EP_DEGREE}, zero1={ZERO1}" + ) test_moe_layer_device_correctness_parallel() - atexit.register(on_exit) diff --git a/test/integration/modules/moe/test_ep.py b/test/integration/modules/moe/test_ep.py new file mode 100644 index 0000000..b0a426e --- /dev/null +++ b/test/integration/modules/moe/test_ep.py @@ -0,0 +1,177 @@ +""" +torchrun --no_python --nproc_per_node=8 pytest -rA test/moe/test_ep.py +""" +import os + +import neuronx_distributed as nxd +import pytest +import torch +import torch.distributed +from neuronx_distributed.parallel_layers.layers import divide +from neuronx_distributed.parallel_layers.mappings import ( + enter_expert_parallel_region, + exit_expert_parallel_region, +) + +# do distributed setup. test configuration for parallelism: +# - TP within node +# - EP across node +if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="xla") +n_proc_per_node = int(os.environ["LOCAL_WORLD_SIZE"]) +n_nodes = divide(torch.distributed.get_world_size(), n_proc_per_node) +nxd.parallel_layers.initialize_model_parallel( + tensor_model_parallel_size=n_proc_per_node, + pipeline_model_parallel_size=1, + expert_model_parallel_size=n_nodes, +) +nxd.parallel_layers.random.model_parallel_xla_manual_seed(0) + + +@pytest.mark.parametrize("n_experts", [4], ids=lambda n: f"e={n}") +@pytest.mark.parametrize("expert_capacity", [128], ids=lambda n: f"ec={n}") +@pytest.mark.parametrize("hidden_sz", [4, 64], ids=lambda n: f"h={n}") +@pytest.mark.parametrize("sp_input", [False, True], ids=lambda b: f"sp={int(b)}") +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str) +def test_ep_enter( + n_experts: int, + expert_capacity: int, + hidden_sz: int, + sp_input: bool, + dtype: torch.dtype, +) -> None: + trn = torch.device("xla") + ep_group = nxd.parallel_layers.parallel_state.get_expert_model_parallel_group() + tp_group = nxd.parallel_layers.parallel_state.get_tensor_model_parallel_group() + + n_experts_per_ep_rank = divide(n_experts, ep_group.size()) + assert n_experts_per_ep_rank == 1, "haven't implemented expert packing yet" + + # fully unpartitioned set of tokens. expert_capacity is for the full EP group. + x_global = torch.rand( + n_experts, expert_capacity, hidden_sz, dtype=dtype, device=trn + ) + + # tokens that are held by each DP_nonexp rank. 
(e, c/ep, h) + if not sp_input: + x = x_global.chunk(ep_group.size(), dim=1)[ep_group.rank()].contiguous() + else: + # fmt: off + x = ( + x_global + .chunk(ep_group.size(), dim=1)[ep_group.rank()] # EP + .chunk(tp_group.size(), dim=1)[tp_group.rank()] # SP + ) + # fmt: on + + x_ep = enter_expert_parallel_region(x, input_is_sequence_parallel=sp_input) + + # generate expected tensor after alltoall. + expected = x_global.chunk(ep_group.size(), dim=0)[ep_group.rank()].view( + n_experts_per_ep_rank, + ep_group.size(), + divide(expert_capacity, ep_group.size()), + hidden_sz, + ) + + torch.testing.assert_close(x_ep, expected) + + +@pytest.mark.parametrize("n_experts", [4, 8], ids=lambda n: f"n={n}") +@pytest.mark.parametrize("expert_capacity", [16], ids=lambda n: f"c={n}") +@pytest.mark.parametrize("hidden_sz", [4], ids=lambda n: f"h: {n}") +@pytest.mark.parametrize("seq_parallel", [False, True], ids=lambda b: f"sp={int(b)}") +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) +def test_ep_exit( + n_experts: int, + expert_capacity: int, + hidden_sz: int, + seq_parallel: bool, + dtype: torch.dtype, +): + trn = torch.device("xla") + ep_group = nxd.parallel_layers.parallel_state.get_expert_model_parallel_group() + tp_group = nxd.parallel_layers.parallel_state.get_tensor_model_parallel_group() + + capacity_per_ep_rank = divide(expert_capacity, ep_group.size()) + + x_global = torch.rand( + (n_experts, ep_group.size(), capacity_per_ep_rank, hidden_sz), + dtype=dtype, + device=trn, + ) + + # input: (e/ep, ep, c/sp, h) + x = ( + x_global.detach() + .chunk(ep_group.size(), dim=0)[ep_group.rank()] # EP + .chunk(tp_group.size(), dim=2)[tp_group.rank()] # SP + .contiguous() + ) + + out = exit_expert_parallel_region(x, output_in_sequence_parallel=seq_parallel) + + if seq_parallel: + # expected output: (e, c/sp, h) + expected = ( + x_global.detach() + .chunk(ep_group.size(), dim=1)[ep_group.rank()] # EP + .chunk(tp_group.size(), dim=2)[tp_group.rank()] # SP + .view(n_experts, divide(capacity_per_ep_rank, tp_group.size()), hidden_sz) + ) + else: + # expected output: (e, c, h) + expected = ( + x_global.detach() + .chunk(ep_group.size(), dim=1)[ep_group.rank()] + .view(n_experts, capacity_per_ep_rank, hidden_sz) + ) + torch.testing.assert_close(out, expected) + + +@pytest.mark.parametrize("expert_capacity", [256], ids=lambda n: f"ec={n}") +@pytest.mark.parametrize("hidden_sz", [128], ids=lambda n: f"h={n}") +@pytest.mark.parametrize("n_experts", [4, 8], ids=lambda n: f"e={n}") +@pytest.mark.parametrize("output_sp", [False, True], ids=lambda b: f"sp_out={int(b)}") +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) +def test_ep_enter_then_exit_inversion( + n_experts: int, + expert_capacity: int, + hidden_sz: int, + output_sp: bool, + dtype: torch.dtype, +) -> None: + """we expect that exiting EP should be an inversion of entering EP, or in other + words that entering and then immediate exiting should get us back to the + original input. + """ + trn = torch.device("xla") + dp_group = nxd.parallel_layers.parallel_state.get_data_parallel_group() + tp_group = nxd.parallel_layers.parallel_state.get_tensor_model_parallel_group() + + n_tokens_per_dp_rank = expert_capacity * n_experts + + # need to create these values based off DP rank. 
(e, c, h) + x = ( + (dp_group.rank() * n_tokens_per_dp_rank + torch.arange(0, n_tokens_per_dp_rank)) + .view(n_experts, expert_capacity, 1) + .repeat([1, 1, hidden_sz]) + .to(dtype=dtype, device=trn) + ) + # (e, c, h) -> (e/ep, ep, c, h) + x_ep = enter_expert_parallel_region(x, input_is_sequence_parallel=False) + # here we mimic a dropping operation, normally there would be an MLP in + # between the two operations which would do this via a reduce-scatter. + # (e/ep, ep, c/sp, h) + x_dropped = x_ep.chunk(dim=2, chunks=tp_group.size())[tp_group.rank()] + # (e/ep, ep, c/sp, h) -> (e, c, h) + x_out = exit_expert_parallel_region( + x_dropped, output_in_sequence_parallel=output_sp + ) + + if output_sp: + torch.testing.assert_close( + x_out, x.chunk(tp_group.size(), dim=1)[tp_group.rank()] + ) + else: + torch.testing.assert_close(x_out, x) diff --git a/test/integration/modules/moe/test_experts.py b/test/integration/modules/moe/test_experts.py new file mode 100644 index 0000000..1d86b31 --- /dev/null +++ b/test/integration/modules/moe/test_experts.py @@ -0,0 +1,298 @@ +import os +from types import SimpleNamespace +from typing import Callable, Dict + +import neuronx_distributed as nxd +import pytest +import torch +import torch.distributed +import torch_xla.core.xla_model as xm +from neuronx_distributed.modules.moe.experts import Experts +from neuronx_distributed.parallel_layers.layers import divide +from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_rank, + get_expert_model_parallel_rank, + get_tensor_model_parallel_rank, +) +from torch import Tensor +from torch.nn import Module, ModuleList +from transformers.models.llama.modeling_llama import LlamaMLP + +torch.set_printoptions(precision=2, linewidth=320, sci_mode=False) + + +if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="xla") + +# test configuration for parallelism: +# - TP within node +# - EP across node +# TODO. 
also need to test cases where there is DP_exp > 1 +n_proc_per_node = int(os.environ["LOCAL_WORLD_SIZE"]) +n_nodes = divide(torch.distributed.get_world_size(), n_proc_per_node) +nxd.parallel_layers.initialize_model_parallel( + tensor_model_parallel_size=n_proc_per_node, + pipeline_model_parallel_size=1, + expert_model_parallel_size=n_nodes, +) +nxd.parallel_layers.random.model_parallel_xla_manual_seed(0) + +neuron_cc_flags = [ + "--auto-cast none", + # "--internal-compiler-debug-mode=all" +] +os.environ["NEURON_CC_FLAGS"] = " ".join(neuron_cc_flags) + + +def assert_close(actual, expected, name): + tp_rank = get_tensor_model_parallel_rank() + ep_rank = get_expert_model_parallel_rank() + dp_rank = get_data_parallel_rank() + + def rank_msg(m): + return f"TP={tp_rank}, EP={ep_rank}, DP={dp_rank}, {m}" + + try: + torch.testing.assert_close(actual, expected, atol=1e-1, rtol=5e-3, msg=rank_msg) + except AssertionError as e: + print(f"EP={ep_rank}, TP={tp_rank}, DP={dp_rank}") + print(f"actual {name}") + print(actual) + print(f"expected {name}") + print(expected) + + raise e + + +class ReferenceLlamaExperts(Module): + """unpartitioned, naive implementation of expert MLPs""" + + def __init__( + self, n_experts: int, hidden_size: int, intermediate_size: int + ) -> None: + super().__init__() + self.__n_experts = n_experts + self.__hidden_size = hidden_size + cfg = SimpleNamespace( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + pretraining_tp=1, + hidden_act="silu", + ) + self.experts = ModuleList([LlamaMLP(cfg) for _ in range(n_experts)]) + + def forward(self, x_routed: Tensor) -> Tensor: + """works by iterating through experts and applying MLP to each + + Args: + x_routed (n_experts, expert_capacity, hidden_sz) + + Returns: + output (n_experts, expert_capacity, hidden_sz) + """ + input_n_experts, _, input_hidden_sz = x_routed.shape + assert input_n_experts == self.__n_experts + assert input_hidden_sz == self.__hidden_size + + expert_outputs = [None] * self.__n_experts + for expert_idx, expert in enumerate(self.experts): + expert_outputs[expert_idx] = expert.forward(x_routed[expert_idx, :, :]) + output = torch.stack(expert_outputs) + + return output + + +def _tp_partition(x: Tensor, dim: int) -> Tensor: + tp_group = nxd.parallel_layers.parallel_state.get_tensor_model_parallel_group() + return x.chunk(tp_group.size(), dim=dim)[tp_group.rank()] + + +def _convert_state_dict_expert_fused( + state_dict: Dict[str, Tensor], n_experts: int +) -> Dict[str, Tensor]: + ep_group = nxd.parallel_layers.parallel_state.get_expert_model_parallel_group() + + experts_per_ep_group = divide(n_experts, ep_group.size()) + local_expert_indices = list( + range( + (ep_group.rank() + 0) * experts_per_ep_group, + (ep_group.rank() + 1) * experts_per_ep_group, + ) + ) + + # copy the MLP parameters, accounting for TP. 
+ up_gate_proj = torch.stack( + [ + torch.cat( + [ + _tp_partition(state_dict[f"experts.{e}.up_proj.weight"], dim=0), + _tp_partition(state_dict[f"experts.{e}.gate_proj.weight"], dim=0), + ], + dim=0, + ) + for e in local_expert_indices + ] + ) + down_proj = torch.stack( + [ + _tp_partition(state_dict[f"experts.{e}.down_proj.weight"], dim=1) + for e in local_expert_indices + ] + ) + + return { + "up_gate_proj.weight": up_gate_proj, + "down_proj.weight": down_proj, + } + + +@pytest.mark.parametrize("expert_capacity", [32, 64], ids=lambda n: f"c={n}") +@pytest.mark.parametrize("hidden_sz", [4], ids=lambda n: f"h={n}") +@pytest.mark.parametrize("inter_sz", [32], ids=lambda n: f"i={n}") +@pytest.mark.parametrize("n_experts", [4, 8], ids=lambda n: f"e={n}") +@pytest.mark.parametrize("glu", [True], ids=lambda b: f"glu={int(b)}") +@pytest.mark.parametrize( + "activation_fn", [torch.nn.functional.silu], ids=lambda f: f"act={f}" +) +@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) +def test_forward_backward( + expert_capacity: int, + hidden_sz: int, + inter_sz: int, + n_experts: int, + glu: bool, + activation_fn: Callable[[Tensor], Tensor], + dtype: torch.dtype, +): + ###################################################################################### + # ARRANGE + ###################################################################################### + tp_group = nxd.parallel_layers.parallel_state.get_tensor_model_parallel_group() + ep_group = nxd.parallel_layers.parallel_state.get_expert_model_parallel_group() + + trn = torch.device("xla") + cpu = torch.device("cpu") + # removed support but leaving test code in case we need it later + output_sp = False + + n_experts_per_ep_rank = divide(n_experts, ep_group.size()) + inter_sz_local = divide(inter_sz, tp_group.size()) + + ref_experts = ( + ReferenceLlamaExperts( + n_experts=n_experts, hidden_size=hidden_sz, intermediate_size=inter_sz + ) + .to(dtype=dtype, device=cpu) + .requires_grad_(True) + ) + + experts = Experts( + n_experts=n_experts, + hidden_size=hidden_sz, + intermediate_size=inter_sz, + # only testing EP path here, which isn't compatible with this being false + reduce_output=True, + glu=glu, + activation_fn=activation_fn, + dtype=dtype, + device=trn, + ) + experts.load_state_dict( + _convert_state_dict_expert_fused(ref_experts.state_dict(), n_experts=n_experts), + strict=True, + ) + + xm.mark_step() + + x_global_ref = torch.randn( + # in this case the "expert_capacity" is for the full EP group. + # so it will contain tokens from multiple DP_nonexp groups. 
+ (n_experts, expert_capacity, hidden_sz), + requires_grad=True, + dtype=dtype, + device=cpu, + ) + # input: (e, c/ep, h) + x = ( + x_global_ref.detach() + .chunk(ep_group.size(), dim=1)[ep_group.rank()] # EP + .to(dtype=dtype, device=trn) + .requires_grad_(True) + ) + xm.mark_step() + + ###################################################################################### + # ACT + ###################################################################################### + # impl fwd + output = experts.forward(x) + output.sum().backward(retain_graph=True) + + xm.mark_step() + + # reference fwd + output_global_ref = ref_experts.forward(x_global_ref) + output_global_ref.sum().backward(retain_graph=True) + + ###################################################################################### + # ASSERT + ###################################################################################### + # check output + if not output_sp: + expected_output = output_global_ref.chunk(ep_group.size(), dim=1)[ + ep_group.rank() + ] + else: + # fmt: off + expected_output = ( + output_global_ref + .chunk(ep_group.size(), dim=1)[ep_group.rank()] + .chunk(tp_group.size(), dim=1)[tp_group.rank()] + ) + # fmt: on + assert_close(output.cpu(), expected_output, "output") + + # check weight grads + for local_expert_idx in range(n_experts_per_ep_rank): + global_expert_idx = ep_group.rank() * n_experts_per_ep_rank + local_expert_idx + ref_expert: LlamaMLP = ref_experts.experts[global_expert_idx] + + # down (row parallel projection) + assert_close( + experts.down_proj.weight.grad[local_expert_idx, :, :].cpu(), + _tp_partition(ref_expert.down_proj.weight.grad, dim=1), + "down_weight.grad", + ) + + # up (col parallel projection) + assert_close( + experts.up_gate_proj.weight.grad[ + local_expert_idx, :inter_sz_local, : + ].cpu(), + _tp_partition(ref_expert.up_proj.weight.grad, dim=0), + "up_weight.grad", + ) + + # gate (col parallel projection) + assert_close( + experts.up_gate_proj.weight.grad[ + local_expert_idx, inter_sz_local:, : + ].cpu(), + _tp_partition(ref_expert.gate_proj.weight.grad, dim=0), + "gate_weight.grad", + ) + + # check input grad + assert_close( + x.grad.cpu(), + x_global_ref.grad.chunk(ep_group.size(), dim=1)[ep_group.rank()], + "input.grad", + ) + + +if __name__ == "__main__": + """pytest swallows some of the neuron errors so run as python script if you need + to debug things like that.""" + test_forward_backward( + expert_capacity=32, hidden_sz=4, inter_sz=32, n_experts=4, dtype=torch.bfloat16 + ) diff --git a/test/integration/modules/test_qkv_linear.py b/test/integration/modules/test_qkv_linear.py index 44b5e36..16a64f8 100644 --- a/test/integration/modules/test_qkv_linear.py +++ b/test/integration/modules/test_qkv_linear.py @@ -20,21 +20,13 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_parallel_layers/layers") args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args - -test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} @@ -45,7 +37,7 @@ def 
set_random_seed(seed): torch.manual_seed(seed) -def test_qkv_linear_with_kv_multipler_1(tensor_model_parallel_size): +def test_qkv_linear_with_kv_multipler_1(tensor_model_parallel_size, fuse_qkv=False): def _test_qkv_linear_with_kv_multipler_1(): batch_size = 8 seq_length = 128 @@ -69,6 +61,7 @@ def _test_qkv_linear_with_kv_multipler_1(): sequence_parallel_enabled=True, keep_master_weight=True, kv_size_multiplier=1, + fuse_qkv=fuse_qkv ).to(device) row_linear = layers.RowParallelLinear( @@ -127,10 +120,18 @@ def _test_qkv_linear_with_kv_multipler_1(): with torch.no_grad(): dldy = orig_loss_weight.clone() x = orig_input_tensor.clone() - ref_q_linear.weight.copy_(col_linear.master_weight_q) - ref_k_linear.weight.copy_(col_linear.master_weight_k) - ref_v_linear.weight.copy_(col_linear.master_weight_v) - ref_mlp_linear.weight.copy_(row_linear.master_weight) + if fuse_qkv: + sizes = [tensor_model_parallel_size * hidden_size, tensor_model_parallel_size * hidden_size, tensor_model_parallel_size * hidden_size] + master_weight_q, master_weight_k, master_weight_v = torch.split(col_linear.master_weight_qkv, sizes, dim=0) + ref_q_linear.weight.copy_(master_weight_q) + ref_k_linear.weight.copy_(master_weight_k) + ref_v_linear.weight.copy_(master_weight_v) + ref_mlp_linear.weight.copy_(row_linear.master_weight) + else: + ref_q_linear.weight.copy_(col_linear.master_weight_q) + ref_k_linear.weight.copy_(col_linear.master_weight_k) + ref_v_linear.weight.copy_(col_linear.master_weight_v) + ref_mlp_linear.weight.copy_(row_linear.master_weight) x.requires_grad_() expected_q, expected_k, expected_v = ref_q_linear(x), ref_k_linear(x), ref_v_linear(x) e_q, e_k, e_v = ( @@ -181,42 +182,48 @@ def _test_qkv_linear_with_kv_multipler_1(): atol=1e-2, ), "output_v doesn't match rank{}".format(xm.get_ordinal()) - # if tensor_model_parallel_size_ == 1: expected_q_grad_chunk = ref_q_linear.weight.grad.chunk( chunks=tensor_model_parallel_size_, dim=0, )[parallel_state.get_tensor_model_parallel_rank()] - - assert np.allclose( - col_linear.weight_q.grad.detach().cpu().numpy(), - expected_q_grad_chunk.detach().cpu().numpy(), - rtol=1e-2, - atol=1e-2, - ), "grad_q doesn't match rank{}".format(xm.get_ordinal()) - expected_k_grad_chunk = ref_k_linear.weight.grad.chunk( chunks=tensor_model_parallel_size_, dim=0, )[parallel_state.get_tensor_model_parallel_rank()] - - assert np.allclose( - col_linear.weight_k.grad.detach().cpu().numpy(), - expected_k_grad_chunk.detach().cpu().numpy(), - rtol=1e-2, - atol=1e-2, - ), "grad_k doesn't match rank{}".format(xm.get_ordinal()) - expected_v_grad_chunk = ref_v_linear.weight.grad.chunk( chunks=tensor_model_parallel_size_, dim=0, )[parallel_state.get_tensor_model_parallel_rank()] - assert np.allclose( - col_linear.weight_v.grad.detach().cpu().numpy(), - expected_v_grad_chunk.detach().cpu().numpy(), - rtol=1e-2, - atol=1e-2, - ), "grad_v doesn't match rank{}".format(xm.get_ordinal()) + if fuse_qkv: + expected_qkv_grad_chunk = torch.cat([expected_q_grad_chunk, expected_k_grad_chunk, expected_v_grad_chunk], dim=0) + assert np.allclose( + col_linear.weight_qkv.grad.detach().cpu().numpy(), + expected_qkv_grad_chunk.detach().cpu().numpy(), + rtol=5e-2, + atol=1e-2, + ), "grad_qkv doesn't match rank{}".format(xm.get_ordinal()) + else: + assert np.allclose( + col_linear.weight_q.grad.detach().cpu().numpy(), + expected_q_grad_chunk.detach().cpu().numpy(), + rtol=1e-2, + atol=1e-2, + ), "grad_q doesn't match rank{}".format(xm.get_ordinal()) + + assert np.allclose( + 
col_linear.weight_k.grad.detach().cpu().numpy(), + expected_k_grad_chunk.detach().cpu().numpy(), + rtol=1e-2, + atol=1e-2, + ), "grad_k doesn't match rank{}".format(xm.get_ordinal()) + + assert np.allclose( + col_linear.weight_v.grad.detach().cpu().numpy(), + expected_v_grad_chunk.detach().cpu().numpy(), + rtol=1e-2, + atol=1e-2, + ), "grad_v doesn't match rank{}".format(xm.get_ordinal()) # Reset groups parallel_state.destroy_model_parallel() @@ -231,13 +238,13 @@ def _test_qkv_linear_with_kv_multipler_1(): global results try: _test_qkv_linear_with_kv_multipler_1() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def test_qkv_linear_with_kv_multipler_4(tensor_model_parallel_size): +def test_qkv_linear_with_kv_multipler_4(tensor_model_parallel_size, fuse_qkv=False): def _test_qkv_linear_with_kv_multipler_4(): batch_size = 8 seq_length = 128 @@ -266,6 +273,7 @@ def _test_qkv_linear_with_kv_multipler_4(): sequence_parallel_enabled=True, keep_master_weight=True, kv_size_multiplier=kv_shared_group_size, + fuse_qkv=fuse_qkv ).to(device) row_linear = layers.RowParallelLinear( @@ -324,10 +332,18 @@ def _test_qkv_linear_with_kv_multipler_4(): with torch.no_grad(): dldy = orig_loss_weight.clone() x = orig_input_tensor.clone() - ref_q_linear.weight.copy_(col_linear.master_weight_q) - ref_k_linear.weight.copy_(col_linear.master_weight_k) - ref_v_linear.weight.copy_(col_linear.master_weight_v) - ref_mlp_linear.weight.copy_(row_linear.master_weight) + if fuse_qkv: + sizes = [tensor_model_parallel_size * hidden_size, tensor_model_parallel_size * hidden_size // kv_shared_group_size , tensor_model_parallel_size * hidden_size // kv_shared_group_size] + master_weight_q, master_weight_k, master_weight_v = torch.split(col_linear.master_weight_qkv, sizes, dim=0) + ref_q_linear.weight.copy_(master_weight_q) + ref_k_linear.weight.copy_(master_weight_k) + ref_v_linear.weight.copy_(master_weight_v) + ref_mlp_linear.weight.copy_(row_linear.master_weight) + else: + ref_q_linear.weight.copy_(col_linear.master_weight_q) + ref_k_linear.weight.copy_(col_linear.master_weight_k) + ref_v_linear.weight.copy_(col_linear.master_weight_v) + ref_mlp_linear.weight.copy_(row_linear.master_weight) x.requires_grad_() expected_q, expected_k, expected_v = ref_q_linear(x), ref_k_linear(x), ref_v_linear(x) e_q, e_k, e_v = ( @@ -355,32 +371,53 @@ def _test_qkv_linear_with_kv_multipler_4(): chunks=tensor_model_parallel_size_, dim=0, )[parallel_state.get_tensor_model_parallel_rank()] - assert np.allclose( - col_linear.weight_q.grad.detach().cpu().numpy(), - expected_q_grad_chunk.detach().cpu().numpy(), - rtol=1e-2, - atol=1e-2, - ), "grad_q doesn't match rank{}".format(xm.get_ordinal()) - expected_k_grad_chunk = ref_k_linear.weight.grad.chunk( chunks=tensor_model_parallel_size_ // kv_shared_group_size, dim=0, )[parallel_state.get_tensor_model_parallel_rank() % 8] - assert np.allclose( - col_linear.weight_k.grad.detach().cpu().numpy(), expected_k_grad_chunk.cpu().numpy(), rtol=1e-2, atol=1 - ), "grad_k doesn't match rank{}".format(xm.get_ordinal()) expected_v_grad_chunk = ref_v_linear.weight.grad.chunk( chunks=tensor_model_parallel_size_ // kv_shared_group_size, dim=0, )[parallel_state.get_tensor_model_parallel_rank() % 8] - - assert np.allclose( - col_linear.weight_v.grad.detach().cpu().numpy(), - expected_v_grad_chunk.detach().cpu().numpy(), - rtol=5e-2, - atol=1e-1, - ), "grad_v doesn't match rank{}".format(xm.get_ordinal()) + if fuse_qkv: + grad_q_chunk, grad_k_chunk, grad_v_chunk = 
torch.split(col_linear.weight_qkv.grad.detach().cpu(), [hidden_size,hidden_size,hidden_size], dim=0) + assert np.allclose( + grad_q_chunk.numpy(), + expected_q_grad_chunk.detach().cpu().numpy(), + rtol=5e-2, + atol=1e-2, + ), "grad_q doesn't match rank{}".format(xm.get_ordinal()) + assert np.allclose( + grad_k_chunk.numpy(), + expected_k_grad_chunk.detach().cpu().numpy(), + rtol=1e-2, + atol=1, + ), "grad_k doesn't match rank{}".format(xm.get_ordinal()) + assert np.allclose( + grad_v_chunk.numpy(), + expected_v_grad_chunk.detach().cpu().numpy(), + rtol=5e-2, + atol=1e-1, + ), "grad_v doesn't match rank{}".format(xm.get_ordinal()) + else: + assert np.allclose( + col_linear.weight_q.grad.detach().cpu().numpy(), + expected_q_grad_chunk.detach().cpu().numpy(), + rtol=1e-2, + atol=1e-2, + ), "grad_q doesn't match rank{}".format(xm.get_ordinal()) + + assert np.allclose( + col_linear.weight_k.grad.detach().cpu().numpy(), expected_k_grad_chunk.cpu().numpy(), rtol=1e-2, atol=1 + ), "grad_k doesn't match rank{}".format(xm.get_ordinal()) + + assert np.allclose( + col_linear.weight_v.grad.detach().cpu().numpy(), + expected_v_grad_chunk.detach().cpu().numpy(), + rtol=5e-2, + atol=1e-1, + ), "grad_v doesn't match rank{}".format(xm.get_ordinal()) # Reset groups parallel_state.destroy_model_parallel() @@ -395,20 +432,12 @@ def _test_qkv_linear_with_kv_multipler_4(): global results try: _test_qkv_linear_with_kv_multipler_4() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def on_exit(): - if xm.get_ordinal() == 0: - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) - - if __name__ == "__main__": if requires_init_pg_override(): import torch_xla.experimental.pjrt_backend # noqa @@ -418,6 +447,9 @@ def on_exit(): torch.distributed.init_process_group("xla") world_size = xm.xrt_world_size() tensor_model_parallel_size = 32 + # Set the XLA_DISABLE_FUNCTIONALIZATION flag to avoid accuracy issues with PT2.1 and fused_qkv + os.environ['XLA_DISABLE_FUNCTIONALIZATION'] = '0' test_qkv_linear_with_kv_multipler_1(tensor_model_parallel_size) + test_qkv_linear_with_kv_multipler_1(tensor_model_parallel_size, fuse_qkv=True) test_qkv_linear_with_kv_multipler_4(tensor_model_parallel_size) - atexit.register(on_exit) + test_qkv_linear_with_kv_multipler_4(tensor_model_parallel_size, fuse_qkv=True) diff --git a/test/integration/modules/test_sharded_conv_blocks.py b/test/integration/modules/test_sharded_conv_blocks.py index f0fdb0e..3df8e45 100644 --- a/test/integration/modules/test_sharded_conv_blocks.py +++ b/test/integration/modules/test_sharded_conv_blocks.py @@ -10,13 +10,13 @@ from typing import Tuple import torch -import torch.nn.init as init +import torch.nn.init as init # noqa: F401 import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met from neuronx_distributed.parallel_layers import layers, parallel_state -from neuronx_distributed.parallel_layers.pad import pad_model -from neuronx_distributed.parallel_layers.random import model_parallel_xla_manual_seed +from neuronx_distributed.parallel_layers.pad import pad_model # noqa: F401 +from neuronx_distributed.parallel_layers.random import model_parallel_xla_manual_seed # noqa: F401 from neuronx_distributed.parallel_layers.utils import requires_init_pg_override import diffusers @@ -33,11 +33,12 @@ sys.path.append(parentdir) # Import the module from the parent directory -from common.integration_test_utils import test_init, 
test_cleanup, test_modules, print_separator, set_random_seed +from common.integration_test_utils import test_init, test_cleanup, test_modules, print_separator, set_random_seed # noqa: E402, F401 UP_BLOCKS = (CrossAttnUpBlock2D, UpBlock2D) CROSS_ATTN_BLOCKS = (CrossAttnDownBlock2D, CrossAttnUpBlock2D, UNetMidBlock2DCrossAttn) + def parse_args(): parser = argparse.ArgumentParser(add_help=False) parser.add_argument( @@ -61,29 +62,31 @@ def parse_args(): MID_BLOCK = "mid_block" UP_BLOCK = "up_blocks" + def get_sharded_data(data: torch.Tensor, dim: int) -> torch.Tensor: tp_rank = parallel_state.get_tensor_model_parallel_rank() per_partition_size = data.shape[dim] // parallel_state.get_tensor_model_parallel_size() if dim == 0: return data[ - per_partition_size * tp_rank : per_partition_size * (tp_rank + 1) + per_partition_size * tp_rank: per_partition_size * (tp_rank + 1) ].clone() elif dim == 1: return data[ - :, per_partition_size * tp_rank : per_partition_size * (tp_rank + 1) + :, per_partition_size * tp_rank: per_partition_size * (tp_rank + 1) ].clone() else: raise Exception( f"Partiton value of 0,1 are supported, found {dim}." ) + # Shard a given Conv2d to be an OutputChannelParallelConv2d or InputChannelParallelConv2d, # including copying the weight/bias data to the sharded conv from the original def shard_conv2d(conv: torch.nn.Module, layer_type: type, gather_output: bool = None, input_is_parallel: bool = None) -> torch.nn.Module: allowed_layer_types = [layers.InputChannelParallelConv2d, layers.OutputChannelParallelConv2d] assert layer_type in allowed_layer_types, f"Requested layer must be one of {allowed_layer_types} but got {layer_type}" - assert (layer_type == layers.InputChannelParallelConv2d and input_is_parallel is not None) or (layer_type == layers.OutputChannelParallelConv2d and gather_output is not None), f"Must specify gather_output for OutputChannelParallelConv2d or input_is_parallel for InputChannelParallelConv2d" - + assert (layer_type == layers.InputChannelParallelConv2d and input_is_parallel is not None) or (layer_type == layers.OutputChannelParallelConv2d and gather_output is not None), "Must specify gather_output for OutputChannelParallelConv2d or input_is_parallel for InputChannelParallelConv2d" + orig_conv = conv partition_dim = 0 if layer_type == layers.OutputChannelParallelConv2d else 1 kw = {'gather_output': gather_output} if layer_type == layers.OutputChannelParallelConv2d else {'input_is_parallel': input_is_parallel} @@ -96,15 +99,16 @@ def shard_conv2d(conv: torch.nn.Module, layer_type: type, gather_output: bool = # InputChannelParallel bias not sharded conv.bias.data.copy_(orig_conv.bias.data) - del(orig_conv) + del orig_conv return conv + def shard_groupnorm(norm: torch.nn.Module) -> torch.nn.Module: tp_degree = parallel_state.get_tensor_model_parallel_size() if norm.num_channels % tp_degree != 0 or (norm.num_channels // tp_degree) % norm.num_groups != 0: raise NotImplementedError(f"Have not implemented padding for norms yet. 
Cannot shard {norm} to TP degree {tp_degree}") - + orig_norm = norm norm = torch.nn.GroupNorm(orig_norm.num_groups, orig_norm.num_channels // tp_degree, orig_norm.eps, orig_norm.affine) norm.weight.data = get_sharded_data(orig_norm.weight.data, 0) @@ -112,6 +116,7 @@ def shard_groupnorm(norm: torch.nn.Module) -> torch.nn.Module: return norm + def shard_sd_resnet_block(block: torch.nn.Module) -> torch.nn.Module: assert hasattr(block, 'conv1') and hasattr(block, 'conv2'), f"Expected the module being tested has a conv1 and conv2 to shard but found it doesn't! Selected module: {block}" @@ -138,9 +143,10 @@ def shard_sd_resnet_block(block: torch.nn.Module) -> torch.nn.Module: if hasattr(block, 'conv_shortcut') and block.conv_shortcut is not None: block.conv_shortcut = shard_conv2d(block.conv_shortcut, layers.OutputChannelParallelConv2d, gather_output=True) - + return block + # model: HuggingFace model ID, e.g. stabilityai/stable-diffusion-2-1-base # block_type: down block, mid block, or up block # block_idx: index of the block in the list of selected block type, e.g. which downblock @@ -170,7 +176,7 @@ def _test_stable_diffusion_resnet_block(): control_module = copy.deepcopy(blocks[block_idx].resnets[resnet_idx]) test_module = copy.deepcopy(blocks[block_idx].resnets[resnet_idx]) - del(model) + del model # Build the input tuple input_channels = test_module.conv1.in_channels @@ -190,17 +196,17 @@ def _test_stable_diffusion_resnet_block(): # See V1310769999 pass_fail = test_modules(test_module, control_module, input_tuple, check_output_tensor=False) test_cleanup() - + return pass_fail global results try: ret = None ret = _test_stable_diffusion_resnet_block() - assert all(ret), f"Test failed!" + assert all(ret), "Test failed!" # If we reach this point, test has passed xm.master_print("test passed") - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) print(f"test_stable_diffusion_resnet_block FAILED for {model_id}.{block_type}[{block_idx}].resnets[{resnet_idx}], input size {input_size}") @@ -209,7 +215,7 @@ def _test_stable_diffusion_resnet_block(): if ret is None: # Compilation failed ret = (False, False, False) - + gc.collect() return ret @@ -241,7 +247,7 @@ def _test_stable_diffusion_unet_block(): control_module = copy.deepcopy(blocks[block_idx]) test_module = copy.deepcopy(blocks[block_idx]) - del(model) + del model assert hasattr(control_module, 'resnets'), f"{control_module}\nExpected the module being tested (see above) has an attribute 'resnets' (a list of resnet blocks) but found it doesn't!" 
@@ -256,7 +262,7 @@ def _test_stable_diffusion_unet_block(): hidden_states_shape = (batch_size, input_channels, input_spatial_dim, input_spatial_dim) input_hidden_states = torch.randn(hidden_states_shape, requires_grad=True) - + input_list.append(input_hidden_states) if isinstance(test_module, UP_BLOCKS): @@ -266,10 +272,10 @@ def _test_stable_diffusion_unet_block(): for i, resnet in enumerate(test_module.resnets): # Cross attn upblocks vary their resnet sizes, so need to choose the right number of channels if isinstance(test_module, CROSS_ATTN_BLOCKS) and i != 0: - input_channels = resnet.conv1.in_channels - test_module.attentions[i-1].proj_out.out_features + input_channels = resnet.conv1.in_channels - test_module.attentions[i - 1].proj_out.out_features else: input_channels = resnet.conv1.in_channels // 2 - + shape = (batch_size, input_channels, input_spatial_dim, input_spatial_dim) xm.master_print(f"computed shape of {shape} for resnet {i}") res_hidden_states.append(torch.randn(shape, requires_grad=True)) @@ -279,10 +285,9 @@ def _test_stable_diffusion_unet_block(): res_hidden_states = tuple(res_hidden_states) input_list.append(res_hidden_states) - temb_shape = (batch_size, temb_in_features) input_temb = torch.randn(temb_shape, requires_grad=True) - + input_list.append(input_temb) if isinstance(test_module, CROSS_ATTN_BLOCKS): @@ -306,15 +311,13 @@ def _test_stable_diffusion_unet_block(): return pass_fail - - global results try: ret = None ret = _test_stable_diffusion_unet_block() - assert all(ret), f"Test failed!" + assert all(ret), "Test failed!" xm.master_print("test passed") - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) print(f"test_stable_diffusion_unet_block FAILED for {model_id}.{block_type}[{block_idx}], input size {input_size}") @@ -324,11 +327,11 @@ def _test_stable_diffusion_unet_block(): # Compilation failed ret = (False, False, False) - gc.collect() return ret + def upload_to_s3(): os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') print(met.metrics_report()) @@ -364,7 +367,7 @@ def on_exit(): test_results_csv_file = open("sharded_conv_functional_block_test_results.csv", "w+") test_results_csv_file.write("module_id,batch_size,input_shape,compile_pass,output_pass,grads_pass,test_pass\n") test_results_csv_file.flush() - + # TODO: test more models - at least SD 1.5 and SD 2.1 non-base model_id = "stabilityai/stable-diffusion-2-1-base" @@ -377,7 +380,7 @@ def on_exit(): num_resnets_per_block[MID_BLOCK].append(len(unet.mid_block.resnets)) for block in unet.up_blocks: num_resnets_per_block[UP_BLOCK].append(len(block.resnets)) - del(unet) + del unet for block_type in [DOWN_BLOCK, MID_BLOCK, UP_BLOCK]: for block_idx, num_resnets in enumerate(num_resnets_per_block[block_type]): for resnet_idx in range(0, num_resnets): @@ -390,7 +393,7 @@ def on_exit(): test_results_csv_file.write(f"{block_identifier},{1},{input_size},{pass_fail[0]},{pass_fail[1]},{pass_fail[2]},{all(pass_fail)}\n") test_results_csv_file.flush() xm.mark_step() - + # TODO: mid and up blocks fail for various issues, tracked in the following tickets # V1317582055 # V1317572506 @@ -406,7 +409,7 @@ def on_exit(): test_results_csv_file.write(f"{block_identifier},{1},{input_size},{pass_fail[0]},{pass_fail[1]},{pass_fail[2]},{all(pass_fail)}\n") test_results_csv_file.flush() xm.mark_step() - + if xm.is_master_ordinal(): test_results_csv_file.close() atexit.register(on_exit) diff --git a/test/integration/parallel_layers/pipeline/test_comm.py 
b/test/integration/parallel_layers/pipeline/test_comm.py index aa04fda..6502b0d 100644 --- a/test/integration/parallel_layers/pipeline/test_comm.py +++ b/test/integration/parallel_layers/pipeline/test_comm.py @@ -29,11 +29,6 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument( "--s3_bucket", @@ -41,12 +36,10 @@ def parse_args(): ) args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args -test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} @@ -74,13 +67,13 @@ def _test_send_and_recv(): ) if get_pipeline_model_parallel_rank() == 0: a = torch.rand(2, 3, device=xm.xla_device()) - t = send(a) + t = send(a) #noqa xm.mark_step() torch.save(a.cpu(), "tensor.pt") elif get_pipeline_model_parallel_rank() < 7: recv_a = recv_from(tensor_meta) xm.mark_step() - t = send(recv_a) + t = send(recv_a) #noqa xm.mark_step() recv_a_cpu = recv_a.to(torch.device("cpu")) a = torch.load("tensor.pt", map_location=torch.device("cpu")) @@ -95,7 +88,7 @@ def _test_send_and_recv(): global results try: _test_send_and_recv() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -120,7 +113,7 @@ def _test_1f_1b_comm(): # Testing 1F1B communication if get_pipeline_model_parallel_rank() == 0: forward = torch.rand(2, 3, device=xm.xla_device()) - t = send(forward) + t = send(forward) #noqa xm.mark_step() torch.save(forward.cpu(), "forward.pt") recv_backward = recv_from(backward_tensor_meta, recv_prev=False) @@ -131,11 +124,11 @@ def _test_1f_1b_comm(): elif get_pipeline_model_parallel_rank() < 7: recv_forward = recv_from(forward_tensor_meta) xm.mark_step() - t = send(recv_forward) + t = send(recv_forward) #noqa xm.mark_step() recv_backward = recv_from(backward_tensor_meta, recv_prev=False) xm.mark_step() - t = send(recv_backward, send_next=False) + t = send(recv_backward, send_next=False) #noqa xm.mark_step() recv_forward_cpu = recv_forward.to(torch.device("cpu")) forward = torch.load("forward.pt", map_location=torch.device("cpu")) @@ -147,7 +140,7 @@ def _test_1f_1b_comm(): recv_forward = recv_from(forward_tensor_meta) xm.mark_step() backward = torch.rand(1, 2, device=xm.xla_device()) - t = send(backward, send_next=False) + t = send(backward, send_next=False) #noqa xm.mark_step() torch.save(backward.cpu(), "backward.pt") recv_forward_cpu = recv_forward.to(torch.device("cpu")) @@ -157,7 +150,7 @@ def _test_1f_1b_comm(): global results try: _test_1f_1b_comm() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -191,23 +184,14 @@ def _test_send_and_recv_python_object(): global results try: _test_send_and_recv_python_object() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def upload_to_s3(): - os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') - print(met.metrics_report()) - - def on_exit(): - upload_to_s3() - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) + print(met.metrics_report()) if __name__ == 
"__main__": diff --git a/test/integration/parallel_layers/test_bert_pretraining.py b/test/integration/parallel_layers/test_bert_pretraining.py index 2f513d3..80bbadc 100644 --- a/test/integration/parallel_layers/test_bert_pretraining.py +++ b/test/integration/parallel_layers/test_bert_pretraining.py @@ -64,7 +64,6 @@ import neuronx_distributed as nxd from neuronx_distributed.parallel_layers import ( checkpointing, - grads, layers, parallel_state, ) @@ -74,16 +73,16 @@ ) from neuronx_distributed.utils.model_utils import move_model_to_device +# Workaround for NaNs seen with transformers version >= 4.21.0 +# https://github.com/aws-neuron/aws-neuron-sdk/issues/593 +import transformers.modeling_utils as modeling_utils + os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer" # For PT autocast. torch.cuda.is_bf16_supported = lambda: True -# Workaround for NaNs seen with transformers version >= 4.21.0 -# https://github.com/aws-neuron/aws-neuron-sdk/issues/593 -import transformers.modeling_utils as modeling_utils - if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"): modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16 @@ -222,7 +221,7 @@ def get_instance_type(self): headers={"X-aws-ec2-metadata-token": token.text}, ) return data.text - except: + except Exception: return os.environ.get("HOSTNAME", "unknown") def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None): @@ -230,7 +229,7 @@ def log(self, epoch, step, step_loss, learning_rate, throughput, grad_norm=None) grad_norm_msg = f"grad-norm : {grad_norm}" if grad_norm else "" print( f"LOG {time_now} - ({epoch}, {step}) step_loss : {step_loss:.4f} " - f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} " + f"learning_rate : {learning_rate:.2e} throughput : {throughput:.2f} seq/s " f"{grad_norm_msg}", flush=True, ) @@ -507,8 +506,6 @@ def train_bert_hdf5(flags): ) def train_loop_fn(model, optimizer, train_loader, epoch, global_step, training_ustep, running_loss): - max_grad_norm = 1.0 - for i, data in enumerate(train_loader): training_ustep += 1 ( @@ -577,7 +574,6 @@ def _print_logs(running_loss_reduced_detached, total_norm=None): scheduler_state_dict = None if flags.resume_ckpt: - step = flags.resume_step state_dict = checkpointing.load(flags.output_dir, model) optimizer.load_state_dict(state_dict["optimizer"]) global_step = state_dict["global_step"] diff --git a/test/integration/parallel_layers/test_checkpoint.py b/test/integration/parallel_layers/test_checkpoint.py index 7f990ee..1ba093f 100644 --- a/test/integration/parallel_layers/test_checkpoint.py +++ b/test/integration/parallel_layers/test_checkpoint.py @@ -19,21 +19,15 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_parallel_layers/layers") args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args -test_config, S3_BUCKET_NAME, args = parse_args() +# test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} @@ -69,23 +63,14 @@ def 
_test_zero1_checkpoint(): global results try: _test_zero1_checkpoint() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def upload_to_s3(): - os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') - print(met.metrics_report()) - - def on_exit(): - upload_to_s3() - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) + print(met.metrics_report()) if __name__ == "__main__": diff --git a/test/integration/parallel_layers/test_grads.py b/test/integration/parallel_layers/test_grads.py index af96f27..6ec5ebf 100644 --- a/test/integration/parallel_layers/test_grads.py +++ b/test/integration/parallel_layers/test_grads.py @@ -22,21 +22,13 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_parallel_layers/layers") args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args - -test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} @@ -112,7 +104,7 @@ def _test_single_layer_output_module(): global results try: _test_single_layer_output_module() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -218,24 +210,12 @@ def _test_tp_zero1_pp_gradient_clipping(tensor_parallel_size, pipeline_parallel_ global results try: _test_tp_zero1_pp_gradient_clipping(tensor_parallel_size, pipeline_parallel_size) - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def upload_to_s3(): - os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') - - -def on_exit(): - # upload_to_s3() - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) - - if __name__ == "__main__": if requires_init_pg_override(): import torch_xla.experimental.pjrt_backend # noqa @@ -249,4 +229,3 @@ def on_exit(): test_tp_zero1_pp_gradient_clipping(tensor_parallel_size=8, pipeline_parallel_size=1) test_tp_zero1_pp_gradient_clipping(tensor_parallel_size=2, pipeline_parallel_size=4) test_single_layer_output_module() - atexit.register(on_exit) diff --git a/test/integration/parallel_layers/test_layers.py b/test/integration/parallel_layers/test_layers.py index 62e2870..791237e 100644 --- a/test/integration/parallel_layers/test_layers.py +++ b/test/integration/parallel_layers/test_layers.py @@ -1,6 +1,5 @@ import argparse import atexit -import json import os import traceback from datetime import datetime @@ -26,26 +25,19 @@ sys.path.append(parentdir) # Import the module from the parent directory -from common.integration_test_utils import test_init, test_cleanup, test_modules +from common.integration_test_utils import test_init, test_cleanup, test_modules # noqa: E402 def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, 
help="location to upload all test artifacts") parser.add_argument("--s3_bucket", default="s3://ktf-test-runs/neuronx_distributed_parallel_layers/layers") args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args -test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} @@ -103,7 +95,7 @@ def _test_parallel_embedding(): global results try: _test_parallel_embedding() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -173,7 +165,7 @@ def _test_parallel_embedding(): global results try: _test_parallel_embedding() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -247,7 +239,7 @@ def _test_initialize_parameter_cpu(): global results try: _test_initialize_parameter_cpu() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -342,7 +334,7 @@ def _test_row_parallel_linear_seq_parallel(): global results try: _test_row_parallel_linear_seq_parallel() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -436,7 +428,7 @@ def _test_column_parallel_linear_seq_parallel(): global results try: _test_column_parallel_linear_seq_parallel() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -553,13 +545,12 @@ def _test_padding_attention_heads(): global results try: _test_padding_attention_heads() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise - def test_output_channel_parallel_conv(tensor_model_parallel_size): def _test_output_channel_parallel_conv(): test_init(tensor_model_parallel_size, 1234) @@ -612,11 +603,12 @@ def _test_output_channel_parallel_conv(): try: _test_output_channel_parallel_conv() xm.master_print("test passed") - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise + def test_input_channel_parallel_conv(tensor_model_parallel_size): def _test_input_channel_parallel_conv(): test_init(tensor_model_parallel_size, 1234) @@ -668,11 +660,12 @@ def _test_input_channel_parallel_conv(): try: _test_input_channel_parallel_conv() xm.master_print("test passed") - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise + class BackToBackConvs(torch.nn.Module): def __init__(self, conv1_args, conv2_args, parallel: bool = False): super().__init__() @@ -734,22 +727,14 @@ def _test_back_to_back_parallel_convs(): try: _test_back_to_back_parallel_convs() xm.master_print("test passed") - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def upload_to_s3(): - os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') - print(met.metrics_report()) - def on_exit(): - upload_to_s3() - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) + print(met.metrics_report()) if __name__ == "__main__": diff --git a/test/integration/parallel_layers/test_loss_functions.py b/test/integration/parallel_layers/test_loss_functions.py index ce8173e..7a1e716 100644 --- a/test/integration/parallel_layers/test_loss_functions.py +++ b/test/integration/parallel_layers/test_loss_functions.py @@ -24,11 +24,6 @@ def 
parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument( "--s3_bucket", @@ -36,12 +31,9 @@ def parse_args(): ) args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args - -test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} @@ -122,23 +114,14 @@ def _test_parallel_cross_entropy(): global results try: _test_parallel_cross_entropy() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def upload_to_s3(): - os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') - print(met.metrics_report()) - - def on_exit(): - upload_to_s3() - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) + print(met.metrics_report()) if __name__ == "__main__": diff --git a/test/integration/parallel_layers/test_parallel_state.py b/test/integration/parallel_layers/test_parallel_state.py index bf60d65..790c42c 100644 --- a/test/integration/parallel_layers/test_parallel_state.py +++ b/test/integration/parallel_layers/test_parallel_state.py @@ -18,11 +18,6 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument( "--s3_bucket", @@ -30,12 +25,10 @@ def parse_args(): ) args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args -test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} @@ -77,7 +70,7 @@ def check(group, world_size, rank): global results try: _test_initialize_model_parallel() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise @@ -106,23 +99,14 @@ def _test_get_tensor_model_parallel_src_rank(): global results try: _test_get_tensor_model_parallel_src_rank() - except: + except Exception: results["inference_success"] = 0 print(traceback.format_exc()) raise -def upload_to_s3(): - os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') - print(met.metrics_report()) - - def on_exit(): - upload_to_s3() - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) + print(met.metrics_report()) if __name__ == "__main__": diff --git a/test/integration/quantization/test_quantized_expert_fused_mlp.py b/test/integration/quantization/test_quantized_expert_fused_mlp.py new file mode 100644 index 0000000..5cee6fd --- /dev/null +++ b/test/integration/quantization/test_quantized_expert_fused_mlp.py @@ -0,0 +1,189 @@ +import os +import shutil +from concurrent.futures import ProcessPoolExecutor + +import torch +from torch.ao.nn.quantized.dynamic.modules.linear import _quantize_weight +from torch.ao.quantization.qconfig import 
default_dynamic_qconfig + +import neuronx_distributed.parallel_layers.parallel_state as p_state +from neuronx_distributed.modules.moe.moe_parallel_layers import ( + ExpertFusedColumnParallelLinear, + ExpertFusedRowParallelLinear, +) +from neuronx_distributed.quantization.quantize import convert + +num_experts = 3 +intermediate_size = 4 +hidden_size = 5 +capacity = 2 +torch.manual_seed(0) + +TEMP_COMPILER_WORK_DIR = "compiler_workdir" +TEMP_STATE_DICT_NAME = "quantized_state_dict.pt" + +SCALE = 0.123 + + +class Model(torch.nn.Module): + def __init__(self): + torch.manual_seed(0) # to ensure the weight is the same on every initialization + super().__init__() + self.lay1 = ExpertFusedColumnParallelLinear( + num_experts=num_experts, + input_size=hidden_size, + output_size=intermediate_size, + dtype=torch.float32, + ) + self.lay2 = ExpertFusedRowParallelLinear( + num_experts=num_experts, + input_size=intermediate_size, + output_size=hidden_size, + reduce_output=True, + dtype=torch.float32, + ) + self.lay3 = ExpertFusedColumnParallelLinear( + num_experts=num_experts, + input_size=hidden_size, + output_size=intermediate_size, + dtype=torch.float32, + ) + self.lay4 = ExpertFusedRowParallelLinear( + num_experts=num_experts, + input_size=intermediate_size, + output_size=hidden_size, + reduce_output=True, + dtype=torch.float32, + ) + + def forward(self, x): + x = self.lay1(x) + x = self.lay2(x) + x = self.lay3(x) + x = self.lay4(x) + return x + + +def quantize_weight(float_state_dict): + int8_state_dict = {} + for name, weight in float_state_dict.items(): + weight_observer = default_dynamic_qconfig.weight() + weight_observer(weight) + qint8_weight = _quantize_weight(weight, weight_observer) + int8_state_dict[name] = qint8_weight.int_repr() + int8_state_dict[name.replace("weight", "scale")] = torch.tensor([qint8_weight.q_scale()]) + return int8_state_dict + + +def get_input(): + return torch.randn((num_experts, capacity, hidden_size)) + + +def init_ditributed_env(): + os.environ["RANK"] = str(0) + os.environ["WORLD_SIZE"] = str(1) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "2024" + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="xla") + + p_state.destroy_model_parallel() + p_state.initialize_model_parallel(tensor_model_parallel_size=1) + + +def _prepare_state_dict(): + init_ditributed_env() + with torch.no_grad(): + model = Model() + float_sd = model.state_dict() + q_sd = quantize_weight(float_sd) + torch.save(q_sd, TEMP_STATE_DICT_NAME) + p_state.destroy_model_parallel() + + +def prepare_state_dict(): + with ProcessPoolExecutor(max_workers=1) as executor: + future = executor.submit(_prepare_state_dict) + future.result() + + +def load_model(): + model = Model() + model_quant = convert(model, q_config=None, inplace=True, mapping=None) + print(model_quant) + all_parameters_name = [] + for name, _ in model_quant.named_parameters(): + all_parameters_name.append(name) + print(all_parameters_name) + + alias = {} + + return model_quant, alias + + +def checkpoint_loader_fn(): + return torch.load(TEMP_STATE_DICT_NAME) + + +def load_traced_model(input_fp32): + from neuronx_distributed.trace import parallel_model_trace + + sample_inputs = input_fp32 + traced_model = parallel_model_trace( + load_model, # This loads the parallel model + sample_inputs, + tp_degree=2, + compiler_workdir=TEMP_COMPILER_WORK_DIR, # This is where you will find the hlo & neff + compiler_args="--auto-cast=none", # Pass your compiler flags here, + 
inline_weights_to_neff=False, + spmd_mode=True, + checkpoint_loader_callable=checkpoint_loader_fn, + force_custom_init_on_device=True, + ) + return traced_model + + +def get_output_from_traced_quantized_model(input_fp32): + prepare_state_dict() + traced_quantized_model = load_traced_model(input_fp32) + return traced_quantized_model(input_fp32) + + +def _get_output_from_cpu_model(input_fp32): + init_ditributed_env() + with torch.no_grad(): + model = Model() + output = model(input_fp32) + p_state.destroy_model_parallel() + return output + + +def get_output_from_cpu_model(input_fp32): + """Put execution in another process to avoid neuron device not available error""" + with ProcessPoolExecutor(max_workers=1) as executor: + future = executor.submit(_get_output_from_cpu_model, input_fp32) + output = future.result() + return output + + +def main(): + input_fp32 = get_input() + cpu_float_result = get_output_from_cpu_model(input_fp32) + traced_quantized_result = get_output_from_traced_quantized_model(input_fp32) + + print(f"cpu result: {cpu_float_result}") + print(f"traced quantized result: {traced_quantized_result}") + assert torch.allclose(cpu_float_result, traced_quantized_result, atol=1e-2) + + print("Test succeeded for Quantized Expert-fused Parallel Linear Layers!") + + if os.path.exists(TEMP_STATE_DICT_NAME): + os.remove(TEMP_STATE_DICT_NAME) + + if os.path.exists(TEMP_COMPILER_WORK_DIR) and os.path.isdir(TEMP_COMPILER_WORK_DIR): + shutil.rmtree(TEMP_COMPILER_WORK_DIR) + + +if __name__ == "__main__": + main() diff --git a/test/integration/quantization/test_quantized_mlp.py b/test/integration/quantization/test_quantized_mlp.py index b0cf932..2a876aa 100644 --- a/test/integration/quantization/test_quantized_mlp.py +++ b/test/integration/quantization/test_quantized_mlp.py @@ -1,5 +1,8 @@ import os import shutil +import traceback +from concurrent.futures import ProcessPoolExecutor +from functools import partial import torch import torch.nn.functional as F @@ -8,10 +11,17 @@ ColumnParallelLinear, RowParallelLinear, ) -from neuronx_distributed.quantization.dequantize import dequantize +from neuronx_distributed.quantization.dequantize import scale_dequantize +from neuronx_distributed.quantization.quantization_config import ( + BASE_QCONFIG_DICT_TYPE, + QuantizationType, + get_default_custom_qconfig_dict, + get_default_per_channel_custom_qconfig_dict, +) from neuronx_distributed.quantization.quantization_utils import ( - convert_float_model_to_pytorch_int8_model, convert_qint8_to_int8_state_dict, + quantize_pytorch_model_per_channel_symmetric, + quantize_pytorch_model_per_tensor_symmetric, ) from neuronx_distributed.quantization.quantize import convert @@ -24,29 +34,43 @@ class QuantizedCpuLinear(torch.nn.Module): CPU version of our Dequant logic . Used for testing. 
""" - def __init__(self, in_features: int, out_features: int, bias: bool = False, device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device=None, + dtype=None, + quantization_type=QuantizationType.PER_TENSOR_SYMMETRIC, + per_channel_axis=0, + ) -> None: super(QuantizedCpuLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = torch.nn.Parameter( torch.empty(out_features, in_features, dtype=torch.int8, device=device), requires_grad=False ) - self.scale = torch.nn.Parameter(torch.tensor([1.0], dtype=dtype)) + if quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: + self.scale = torch.nn.Parameter(torch.tensor([1.0], dtype=dtype)) + else: + weight_shape = self.weight.shape + scale_shape = [1] * len(weight_shape) + scale_shape[per_channel_axis] = weight_shape[per_channel_axis] + self.scale = torch.nn.Parameter(torch.ones(scale_shape, device=self.weight.device), requires_grad=False) if bias: raise NotImplementedError() else: self.register_parameter("bias", None) def forward(self, input: torch.Tensor): - weight = dequantize(self.weight, scale=self.scale, upcast_dtype=input.dtype) + weight = scale_dequantize(self.weight, scale=self.scale, upcast_dtype=input.dtype) return F.linear(input, weight, self.bias) @classmethod def from_float( cls, mod, - quantization_type, - quantized_dtype, + q_config, ): """Create a QuantizedRowParallel from a float module @@ -60,6 +84,7 @@ def from_float( bias=mod.bias, device=mod.weight.device, dtype=mod.weight.dtype, + quantization_type=q_config["quantization_type"], ) @@ -103,28 +128,32 @@ def forward(self, x): return x -def load_qunatize_model(): +def load_qunatize_model(q_config: BASE_QCONFIG_DICT_TYPE): model_fp32 = Model() input_fp32 = torch.randn((2, dim)) # Get Pytorch Quantized Model model_fp32.eval() - model_fp32_int8 = convert_float_model_to_pytorch_int8_model(model_fp32) + if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC: + model_fp32_int8 = quantize_pytorch_model_per_channel_symmetric(model_fp32) + else: + model_fp32_int8 = quantize_pytorch_model_per_tensor_symmetric(model_fp32) + state_dict = model_fp32_int8.state_dict() torch.save(state_dict, "fp32_qint8_model.pt") # Get NxD version of Quatization Model on CPU mapping = {torch.nn.Linear: QuantizedCpuLinear} - nxd_quantized_cpu_model = convert(model_fp32, q_config=None, mapping=mapping) + nxd_quantized_cpu_model = convert(model_fp32, q_config=q_config, mapping=mapping) convert_qint8_to_int8_state_dict(state_dict=state_dict) nxd_quantized_cpu_model.load_state_dict(state_dict, strict=False) return model_fp32_int8, input_fp32, model_fp32, nxd_quantized_cpu_model -def load_model(): +def load_model(q_config: BASE_QCONFIG_DICT_TYPE): model_parallel = Model_Parallel() - model_quant = convert(model_parallel, q_config=None, inplace=True, mapping=None) + model_quant = convert(model_parallel, q_config=q_config, inplace=True, mapping=None) print(model_quant) all_parameters_name = [] for name, _ in model_quant.named_parameters(): @@ -140,12 +169,13 @@ def checkpoint_loader_fn(): return torch.load("fp32_qint8_model.pt") -def load_traced_model(input_fp32): +def load_traced_model(input_fp32, qconfig): from neuronx_distributed.trace import parallel_model_trace sample_inputs = input_fp32 + load_model_partial = partial(load_model, qconfig) traced_model = parallel_model_trace( - load_model, # This loads the parallel model + load_model_partial, # This loads the parallel 
model sample_inputs, tp_degree=2, compiler_workdir="compiler_workdir", # This is where you will find the hlo & neff @@ -172,7 +202,7 @@ def validate_against_pytorch_quantization(pytorch_quantized_cpu_model, nxd_quant prefix = key.split("_packed_params._packed_params")[0] assert torch.allclose( pytorch_quantized_cpu_model_sd[key][0].dequantize(), - dequantize( + scale_dequantize( nxd_quantized_cpu_model_sd[prefix + "weight"], nxd_quantized_cpu_model_sd[prefix + "scale"], torch.float32, @@ -183,20 +213,48 @@ def validate_against_pytorch_quantization(pytorch_quantized_cpu_model, nxd_quant print("Test successful for validate_against_pytorch_quantization") +def recreate_sharded_scales(traced_model_sd, scale_name, partition_dim): + tensors_to_gather = [] + for i in range(2): + tensors_to_gather.append(traced_model_sd[f"models.{i}.weights.{scale_name}"]) + recreated_scale = torch.cat(tensors_to_gather, axis=partition_dim) + return recreated_scale + + +def is_scalar_partitioned(scalar_tensor): + if scalar_tensor.shape == (1,) or max(scalar_tensor.shape) == dim: + return False + return True + + +def extract_partition_dim(scale_tensor): + scale_shape = scale_tensor.shape + for i, shape_dim in enumerate(scale_shape): + if shape_dim > 1: + return i + raise RuntimeError("scale is not really sharded") + + def validate_scales_in_nxd_model(nxd_quantized_cpu_model, traced_model): - traced_model_sd = traced_model.models[0].weights.state_dict() + traced_model_sd = traced_model.state_dict() + traced_model_sd_rank0 = traced_model.models[0].weights.state_dict() nxd_quantized_cpu_model_sd = nxd_quantized_cpu_model.state_dict() - for key, _ in traced_model_sd.items(): + for key, _ in traced_model_sd_rank0.items(): if "scale" in key: cpu_scale = nxd_quantized_cpu_model_sd[key.replace("->", ".")] - nxd_scale = traced_model_sd[key] - assert cpu_scale == nxd_scale + if not is_scalar_partitioned(traced_model_sd_rank0[key]): + nxd_scale = traced_model_sd_rank0[key] + else: + nxd_scale = recreate_sharded_scales( + traced_model_sd, key, extract_partition_dim(traced_model_sd_rank0[key]) + ) + assert torch.allclose(cpu_scale, nxd_scale) print("scale verification successful") -def main(): - model_fp32_int8, input_fp32, model_fp32, nxd_quantized_cpu_model = load_qunatize_model() - traced_model = load_traced_model(input_fp32=input_fp32) +def main(q_config): + model_fp32_int8, input_fp32, model_fp32, nxd_quantized_cpu_model = load_qunatize_model(q_config=q_config) + traced_model = load_traced_model(input_fp32=input_fp32, qconfig=q_config) # Validate the CPU version of our de-quant logic matches the pytorch dequant validate_against_pytorch_quantization( @@ -217,9 +275,10 @@ def main(): assert torch.allclose(cpu_result, nxd_result, atol=1e-2) # FP32 model result and NxD result - torch.allclose(fp_32_result, nxd_result, atol=1e-2) + atol = 1e-3 if q_config["quantization_type"] == QuantizationType.PER_CHANNEL_SYMMETRIC else 1e-2 + torch.allclose(fp_32_result, nxd_result, atol=atol) - print("Test successful for Quantized Layers") + print(f"Test successful for Quantized Layers with qconfig {q_config}") if os.path.exists("fp32_qint8_model.pt"): os.remove("fp32_qint8_model.pt") @@ -229,4 +288,18 @@ def main(): if __name__ == "__main__": - main() + try: + q_config = get_default_custom_qconfig_dict() + with ProcessPoolExecutor(max_workers=1) as executor: + future = executor.submit(main, q_config) + results = future.result() + except Exception: + print(traceback.format_exc()) + + try: + q_config = 
get_default_per_channel_custom_qconfig_dict() + with ProcessPoolExecutor(max_workers=1) as executor: + future = executor.submit(main, q_config) + results = future.result() + except Exception: + print(traceback.format_exc()) diff --git a/test/integration/zero1/test_zero1.py b/test/integration/zero1/test_zero1.py index 6f886c2..ed0bb16 100644 --- a/test/integration/zero1/test_zero1.py +++ b/test/integration/zero1/test_zero1.py @@ -28,11 +28,6 @@ def parse_args(): parser = argparse.ArgumentParser(add_help=False) - parser.add_argument( - "--test_json", - required=False, - help="input json listing the test spec for network to compile", - ) parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") parser.add_argument( "--s3_bucket", @@ -40,27 +35,14 @@ def parse_args(): ) args, leftovers = parser.parse_known_args() S3_BUCKET_NAME = args.s3_bucket - with open(args.test_json, "r") as f: - test_dict = json.load(f) - return test_dict, S3_BUCKET_NAME, args + return S3_BUCKET_NAME, args -test_config, S3_BUCKET_NAME, args = parse_args() +S3_BUCKET_NAME, args = parse_args() results = {"inference_success": 1} - -def upload_to_s3(): - os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') - print(met.metrics_report()) - - def on_exit(): - upload_to_s3() - for k in test_config: - os.system(f"rm {args.test_json}") - with open(args.test_json, "w") as f: - json.dump({k: results}, f) - + print(met.metrics_report()) def get_test_result(opt, use_pp, model_dtype, optimizer_dtype, grad_clipping, max_norm, pin_layout, coalesce_cc): seed = 1234 diff --git a/test/integration/zero1_dcp/test_zero1_dcp.py b/test/integration/zero1_dcp/test_zero1_dcp.py new file mode 100644 index 0000000..fc3293b --- /dev/null +++ b/test/integration/zero1_dcp/test_zero1_dcp.py @@ -0,0 +1,142 @@ +import argparse +import atexit +import json +import os +import random +from datetime import datetime +import shutil + +import numpy as np +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.debug.metrics as met + +from neuronx_distributed.optimizer import NeuronZero1Optimizer +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.utils import is_pjrt_device +from neuronx_distributed.optimizer.zero_dcp_utils import get_dcp_aux_infos, save_optim_state_dict, load_optim_state_dict +from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu + +datetime_str = str(datetime.now()) + + +def parse_args(): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument( + "--test_json", + required=False, + help="input json listing the test spec for network to compile", + ) + parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") + parser.add_argument( + "--s3_bucket", + default="s3://ktf-test-runs/neuronx_distributed_parallel_layers/parallel_state", + ) + args, leftovers = parser.parse_known_args() + S3_BUCKET_NAME = args.s3_bucket + with open(args.test_json, "r") as f: + test_dict = json.load(f) + return test_dict, S3_BUCKET_NAME, args + + +test_config, S3_BUCKET_NAME, args = parse_args() +results = {"inference_success": 1} + + +def upload_to_s3(): + os.system(f'aws s3 cp --no-progress "{datetime_str}" {S3_BUCKET_NAME}') + print(met.metrics_report()) + + +def on_exit(): + upload_to_s3() + for k in test_config: + os.system(f"rm {args.test_json}") + with open(args.test_json, "w") as f: + json.dump({k: results}, f) + + +class DummyModel(torch.nn.Module): + 
def __init__(self): + super().__init__() + # test pad/unpad + self.a = torch.nn.Parameter(torch.randn(10, 10)) + self.b = torch.nn.Parameter(torch.randn(15, 10)) + self.c = torch.nn.Parameter(torch.randn(16, 10)) + self.d = torch.nn.Parameter(torch.randn(20, 10)) + + +def test_zero1_dcp(): + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=8, + pipeline_model_parallel_size=1, + ) + + seed = 1234 + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + model = DummyModel() + model.to(device=xm.xla_device()) + for p in model.parameters(): + p.grad = torch.clone(p.data) / 100 + + optimizer = NeuronZero1Optimizer( + model.parameters(), + torch.optim.AdamW, + lr=0.01, + pin_layout=False, + sharding_groups=parallel_state.get_data_parallel_group(as_list=True), + grad_norm_groups=parallel_state.get_tensor_model_parallel_group(as_list=True), + max_norm=1.0, + grad_clipping=True, + ) + xm.mark_step() + + # run step once to get states inited + optimizer.step() + optimizer.zero_grad() + xm.mark_step() + + dcp_aux_infos = get_dcp_aux_infos(model, optimizer) + + if os.path.exists("ckpts"): + load_optim_state_dict("ckpts", optimizer, dcp_aux_infos, False) + xm.mark_step() + xm.rendezvous("sync load 1") + # test able to run step normally + optimizer.step() + optimizer.zero_grad() + xm.mark_step() + + s0 = optimizer.state_dict() + s0 = move_all_tensor_to_cpu(s0) + save_optim_state_dict("ckpts", s0, dcp_aux_infos, False) + xm.mark_step() + xm.rendezvous("sync save 1") + save_optim_state_dict("ckpts2", s0, dcp_aux_infos, False) + xm.mark_step() + xm.rendezvous("sync save 2") + s1 = optimizer.state_dict() + s1 = move_all_tensor_to_cpu(s1) + load_optim_state_dict("ckpts", optimizer, dcp_aux_infos, False) + xm.mark_step() + xm.rendezvous("sync load 2") + s2 = optimizer.state_dict() + s2 = move_all_tensor_to_cpu(s2) + torch.testing.assert_close(s1, s2, rtol=0, atol=0) + shutil.rmtree("ckpts", ignore_errors=True) + + +if __name__ == "__main__": + if is_pjrt_device(): + import torch_xla.experimental.pjrt_backend # noqa + + torch.distributed.init_process_group("xla", init_method="pjrt://") + else: + torch.distributed.init_process_group("xla") + + test_zero1_dcp() + atexit.register(on_exit) diff --git a/test/unit_test/__init__.py b/test/unit_test/__init__.py index 91082f6..41f6ad0 100644 --- a/test/unit_test/__init__.py +++ b/test/unit_test/__init__.py @@ -11,9 +11,6 @@ def parse_common_options(logdir=None, num_cores=None, num_workers=0, opts=None): parser.add_argument("--num_workers", type=int, default=num_workers) parser.add_argument("--metrics_debug", action="store_true") parser.add_argument("--async_closures", action="store_true") - parser.add_argument("--test_json", required=False, help="input json listing the test spec for network to compile") - parser.add_argument("--s3_dir", required=False, help="location to upload all test artifacts") - parser.add_argument("--s3_bucket", default="neuron-canary-nn-models") if opts: for name, aopts in opts: parser.add_argument(name, **aopts) @@ -25,15 +22,4 @@ def parse_common_options(logdir=None, num_cores=None, num_workers=0, opts=None): return args -def update_result(results): - data[test_name].update(results) - os.system(f"rm {FLAGS.test_json}") - with open(FLAGS.test_json, "w+") as file: - dump(data, file) - - FLAGS = parse_common_options() -with open(FLAGS.test_json) as file: - data = loads(file.read()) -test_name = next(iter(data)) -update_result({"inference_success": 1}) diff --git a/test/unit_test/checkpoint/test_checkpoint.py 
b/test/unit_test/checkpoint/test_checkpoint.py index ac38867..13259a2 100644 --- a/test/unit_test/checkpoint/test_checkpoint.py +++ b/test/unit_test/checkpoint/test_checkpoint.py @@ -1,7 +1,10 @@ # Standard Library import os +import time +import pytest import unittest from copy import deepcopy +from packaging import version from unittest.mock import MagicMock, patch # Third Party @@ -13,8 +16,6 @@ import neuronx_distributed as nxd -from .. import update_result - def get_model(): seq_len = 512 @@ -90,94 +91,245 @@ class TestCheckpoint(unittest.TestCase): @patch("neuronx_distributed.trainer.checkpoint.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @patch("neuronx_distributed.trainer.checkpoint.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) @patch("neuronx_distributed.trainer.checkpoint.get_local_world_size", MagicMock(return_value=32)) + @patch("neuronx_distributed.trainer.trainer.get_expert_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_rank", MagicMock(return_value=0)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_rank", MagicMock(return_value=0)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_group", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch( + "neuronx_distributed.trainer.checkpoint.model_parallel_is_initialized", + MagicMock(return_value=True), + ) @patch("torch.distributed.get_rank", MagicMock(return_value=0)) def test_checkpoint(self): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) - - model_state = deepcopy(model.state_dict()) - opt_state = deepcopy(optimizer.state_dict()) - - nxd.save_checkpoint( - "ckpts", - "unittest", - model=model, - optimizer=optimizer, - num_workers=8, - use_xser=True, - ) - - nxd.load_checkpoint( - "ckpts", - "unittest", - model=model, - optimizer=optimizer, - num_workers=8, - ) - - # test save load functionality - torch.testing.assert_close(model.state_dict(), model_state, rtol=0, atol=0) - torch.testing.assert_close(optimizer.state_dict(), opt_state, rtol=0, atol=0) - - # test able to be loaded by xla - xmodel_state = xser.load("ckpts/unittest/model/dp_rank_00_tp_rank_01_pp_rank_01.pt") - xmodel_state = 
xm.send_cpu_data_to_device(xmodel_state, xm.xla_device()) - torch.testing.assert_close(xmodel_state, model_state, rtol=0, atol=0) - - # check format - assert os.path.exists("ckpts/unittest/done") and os.path.isfile("ckpts/unittest/done") - assert os.path.isfile("ckpts/unittest/model/dp_rank_00_tp_rank_01_pp_rank_01.pt") - assert os.path.isfile("ckpts/unittest/optim/dp_rank_00_tp_rank_01_pp_rank_01.pt") - assert os.path.isdir("ckpts/unittest/model/dp_rank_00_tp_rank_01_pp_rank_01.pt.tensors") - assert os.path.isdir("ckpts/unittest/optim/dp_rank_00_tp_rank_01_pp_rank_01.pt.tensors") - - # test auto resume - nxd.load_checkpoint( - "ckpts", - tag=None, - model=model, - optimizer=optimizer, - num_workers=8, - ) - torch.testing.assert_close(model.state_dict(), model_state, rtol=0, atol=0) - torch.testing.assert_close(optimizer.state_dict(), opt_state, rtol=0, atol=0) - - except: - update_result({"inference_success": 0}) - raise + pipeline_cuts = [ + "transformer.h.1", + "transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + "output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": ["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) + optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) + + model_state = deepcopy(model.state_dict()) + opt_state = deepcopy(optimizer.state_dict()) + + nxd.save_checkpoint( + "ckpts", + "unittest", + model=model, + optimizer=optimizer, + num_workers=8, + use_xser=True, + ) + + nxd.load_checkpoint( + "ckpts", + "unittest", + model=model, + optimizer=optimizer, + num_workers=8, + ) + + # test save load functionality + torch.testing.assert_close(model.state_dict(), model_state, rtol=0, atol=0) + torch.testing.assert_close(optimizer.state_dict(), opt_state, rtol=0, atol=0) + + # test able to be loaded by xla + xmodel_state = xser.load("ckpts/unittest/model/dp_rank_00_tp_rank_01_pp_rank_01.pt") + xmodel_state = xm.send_cpu_data_to_device(xmodel_state, xm.xla_device()) + torch.testing.assert_close(xmodel_state, model_state, rtol=0, atol=0) + xopt_state = xser.load("ckpts/unittest/optim/dp_rank_00_tp_rank_01_pp_rank_01.pt") + xopt_state = xm.send_cpu_data_to_device(xopt_state, xm.xla_device()) + torch.testing.assert_close(xopt_state, opt_state, rtol=0, atol=0) + + # check format + assert os.path.exists("ckpts/unittest/done") and os.path.isfile("ckpts/unittest/done") + assert os.path.exists("ckpts/unittest/checkpoint") and os.path.isfile("ckpts/unittest/checkpoint") + assert os.path.isfile("ckpts/unittest/model/dp_rank_00_tp_rank_01_pp_rank_01.pt") + assert os.path.isfile("ckpts/unittest/optim/dp_rank_00_tp_rank_01_pp_rank_01.pt") + assert os.path.isdir("ckpts/unittest/model/dp_rank_00_tp_rank_01_pp_rank_01.pt.tensors") + assert os.path.isdir("ckpts/unittest/optim/dp_rank_00_tp_rank_01_pp_rank_01.pt.tensors") + + # test auto resume + nxd.load_checkpoint( + "ckpts", + 
tag=None, + model=model, + optimizer=optimizer, + num_workers=8, + ) + torch.testing.assert_close(model.state_dict(), model_state, rtol=0, atol=0) + torch.testing.assert_close(optimizer.state_dict(), opt_state, rtol=0, atol=0) + + @pytest.mark.skipif(not version.parse(torch.__version__) >= version.parse("2.1"), reason="skip this test if no DCP support") + @patch("torch.distributed.is_initialized", MagicMock(return_value=True)) + @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=8) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_size", MagicMock(return_value=8) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_rank", MagicMock(return_value=1) + ) + @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.pipeline.model.NxDPPModel._create_pg_with_ranks", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.parallel_layers.parallel_state.get_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch( + "neuronx_distributed.trainer.checkpoint.get_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch( + "neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_group", + MagicMock(return_value=None), + ) + @patch( + "neuronx_distributed.optimizer.zero_redundancy_optimizer.model_parallel_is_initialized", + MagicMock(return_value=True), + ) + @patch( + "neuronx_distributed.optimizer.zero_redundancy_optimizer.get_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) + @patch("neuronx_distributed.trainer.checkpoint.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_local_world_size", MagicMock(return_value=32)) + @patch("neuronx_distributed.trainer.trainer.get_expert_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_rank", MagicMock(return_value=0)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_rank", MagicMock(return_value=0)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_group", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch( + "neuronx_distributed.trainer.checkpoint.model_parallel_is_initialized", + MagicMock(return_value=True), + ) + 
@patch("neuronx_distributed.optimizer.zero_dcp_utils.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.optimizer.zero_dcp_utils.get_tensor_model_parallel_group", MagicMock(return_value=[[i] for i in range(8)])) + @patch("neuronx_distributed.optimizer.zero_dcp_utils.get_pipeline_model_parallel_group", MagicMock(return_value=[[i] for i in range(8)])) + @patch("neuronx_distributed.optimizer.zero_dcp_utils.get_data_parallel_group", MagicMock(return_value=[[i] for i in range(64)])) + @patch("torch.distributed.get_rank", MagicMock(return_value=0)) + @patch("torch.distributed.get_world_size", MagicMock(return_value=1)) + def test_checkpoint_dcp(self): + pipeline_cuts = [ + "transformer.h.1", + "transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + "output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": ["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) + optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) + + model_state = deepcopy(model.state_dict()) + opt_state = deepcopy(optimizer.state_dict()) + + nxd.save_checkpoint( + "ckpts", + "unittest_dcp", + model=model, + optimizer=optimizer, + num_workers=8, + use_xser=True, + async_save=True, + use_zero1_dcp=True, + ) + + time.sleep(5) + + nxd.load_checkpoint( + "ckpts", + "unittest_dcp", + model=model, + optimizer=optimizer, + num_workers=8, + use_zero1_dcp=True, + ) + + # test save load functionality + torch.testing.assert_close(model.state_dict(), model_state, rtol=0, atol=0) + torch.testing.assert_close(optimizer.state_dict(), opt_state, rtol=0, atol=0) + + # check format + assert os.path.exists("ckpts/unittest_dcp/done") and os.path.isfile("ckpts/unittest_dcp/done") + assert os.path.exists("ckpts/unittest_dcp/checkpoint") and os.path.isfile("ckpts/unittest_dcp/checkpoint") + assert os.path.isfile("ckpts/unittest_dcp/model/dp_rank_00_tp_rank_01_pp_rank_01.pt") + assert os.path.isdir("ckpts/unittest_dcp/model/dp_rank_00_tp_rank_01_pp_rank_01.pt.tensors") + assert os.path.isfile("ckpts/unittest_dcp/optim/.metadata") + assert os.path.isfile("ckpts/unittest_dcp/optim/__0_0.distcp") if __name__ == "__main__": diff --git a/test/unit_test/checkpoint/test_checkpoint_storage.py b/test/unit_test/checkpoint/test_checkpoint_storage.py new file mode 100644 index 0000000..cfa7780 --- /dev/null +++ b/test/unit_test/checkpoint/test_checkpoint_storage.py @@ -0,0 +1,21 @@ +import unittest +from unittest.mock import patch + +import neuronx_distributed as nxd + +MODULE = "neuronx_distributed.trainer.checkpoint_storage" + + +class S3CheckpointStorageTest(unittest.TestCase): + @patch(f"{MODULE}.boto3") + def test_uses_default_session(self, mock_boto3): + # Arrange + storage = nxd.trainer.checkpoint_storage.S3CheckpointStorage("s3://some_bucket/some_dir") + # Act + resource = storage.get_resource() + # Assert + 
self.assertEqual(resource, mock_boto3._get_default_session.return_value.resource.return_value) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/checkpoint/test_dcp_functionality.py b/test/unit_test/checkpoint/test_dcp_functionality.py new file mode 100644 index 0000000..cd8d4b5 --- /dev/null +++ b/test/unit_test/checkpoint/test_dcp_functionality.py @@ -0,0 +1,177 @@ +import torch +import pytest +from packaging import version +if not version.parse(torch.__version__) >= version.parse("2.1"): + pytest.skip("skip this test if no DCP support", allow_module_level=True) + +# Standard Library +import os +import unittest +from copy import deepcopy +from unittest.mock import MagicMock, patch + +# Third Party +import torch_xla.core.xla_model as xm +import torch_xla.utils.serialization as xser +from transformers import AutoModelForCausalLM, GPT2Config +from transformers.models.gpt2.modeling_gpt2 import GPT2Block + +import neuronx_distributed as nxd +from neuronx_distributed.optimizer.zero_dcp_utils import ( + _get_optim_pid_to_params, + _get_optim_pid_to_param_names, + _get_param_to_param_names, + _wrap_optim_state_dict, + _unwrap_optim_state_dict, + get_dcp_aux_infos, +) + + +def get_model(): + seq_len = 512 + model_config = GPT2Config( + vocab_size=50257, + n_positions=seq_len, + n_embd=768, + n_layer=8, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-05, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.0, + use_cache=False, + bos_token_id=50256, + eos_token_id=50256, + return_dict=False, + ) + model = AutoModelForCausalLM.from_config(model_config) + return model + + +class DCPFunctionalityTest(unittest.TestCase): + @patch("torch.distributed.is_initialized", MagicMock(return_value=True)) + @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=8) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_size", MagicMock(return_value=8) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_rank", MagicMock(return_value=1) + ) + @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.pipeline.model.NxDPPModel._create_pg_with_ranks", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.parallel_layers.parallel_state.get_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch( + "neuronx_distributed.trainer.checkpoint.get_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch( + "neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_group", + MagicMock(return_value=None), + ) + @patch( + "neuronx_distributed.optimizer.zero_redundancy_optimizer.model_parallel_is_initialized", + 
MagicMock(return_value=True), + ) + @patch( + "neuronx_distributed.optimizer.zero_redundancy_optimizer.get_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) + @patch("neuronx_distributed.trainer.checkpoint.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_local_world_size", MagicMock(return_value=32)) + @patch("neuronx_distributed.trainer.trainer.get_expert_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_rank", MagicMock(return_value=0)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_rank", MagicMock(return_value=0)) + @patch("neuronx_distributed.trainer.checkpoint.get_expert_model_parallel_group", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.trainer.checkpoint.get_expert_data_parallel_group", + MagicMock(return_value=[[i] for i in range(64)]), + ) + @patch( + "neuronx_distributed.trainer.checkpoint.model_parallel_is_initialized", + MagicMock(return_value=True), + ) + @patch("neuronx_distributed.optimizer.zero_dcp_utils.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.optimizer.zero_dcp_utils.get_tensor_model_parallel_group", MagicMock(return_value=[[i] for i in range(8)])) + @patch("neuronx_distributed.optimizer.zero_dcp_utils.get_pipeline_model_parallel_group", MagicMock(return_value=[[i] for i in range(8)])) + @patch("neuronx_distributed.optimizer.zero_dcp_utils.get_data_parallel_group", MagicMock(return_value=[[i] for i in range(64)])) + @patch("torch.distributed.get_rank", MagicMock(return_value=0)) + @patch("torch.distributed.get_world_size", MagicMock(return_value=1)) + def test_checkpoint_dcp(self): + pipeline_cuts = [ + "transformer.h.1", + "transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + "output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": ["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) + optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) + + for k, v in _get_optim_pid_to_params(optimizer).items(): + assert isinstance(k, int) + assert isinstance(v, torch.nn.Parameter) + for k, v in _get_param_to_param_names(model).items(): + assert isinstance(k, torch.nn.Parameter) + assert isinstance(v, str) and v.startswith("transformer") + for k, v in _get_optim_pid_to_param_names(model, optimizer).items(): 
+ assert isinstance(k, int) + assert isinstance(v, str) and v.startswith("transformer") + + aux_infos = get_dcp_aux_infos(model, optimizer) + wrapped_state_dict = _wrap_optim_state_dict(optimizer.state_dict(), aux_infos) + state_dict = _unwrap_optim_state_dict(wrapped_state_dict, aux_infos) + torch.testing.assert_close(state_dict, optimizer.state_dict(), rtol=0, atol=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/lightning/__init__.py b/test/unit_test/lightning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit_test/lightning/test_checkpoint_io.py b/test/unit_test/lightning/test_checkpoint_io.py new file mode 100644 index 0000000..ce0dccb --- /dev/null +++ b/test/unit_test/lightning/test_checkpoint_io.py @@ -0,0 +1,34 @@ +import unittest +from unittest.mock import patch + +from neuronx_distributed.lightning.checkpoint_io import NeuronCheckpointIO + +MODULE = "neuronx_distributed.lightning.checkpoint_io" + + +class NeuronCheckpointIOTest(unittest.TestCase): + @patch(f"{MODULE}.load") + def test_load(self, mock_load): + # Arrange + io = NeuronCheckpointIO() + # Act + io.load_checkpoint(path := "some_path", master_dp_only := True) + # Assert + mock_load.assert_called_once_with( + chkpt_path=path, load_xser=True, master_dp_only=master_dp_only, weights_only=False + ) + + @patch(f"{MODULE}.load") + def test_load_weights_only(self, mock_load): + # Arrange + io = NeuronCheckpointIO(weights_only=True) + # Act + io.load_checkpoint(path := "some_path", master_dp_only := True) + # Assert + mock_load.assert_called_once_with( + chkpt_path=path, load_xser=True, master_dp_only=master_dp_only, weights_only=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/modules/lora/test_lora_layers.py b/test/unit_test/modules/lora/test_lora_layers.py new file mode 100644 index 0000000..93548e2 --- /dev/null +++ b/test/unit_test/modules/lora/test_lora_layers.py @@ -0,0 +1,92 @@ +# Standard Library +import unittest +from unittest.mock import MagicMock, patch + +import torch +from neuronx_distributed.parallel_layers import ( + ColumnParallelLinear, + RowParallelLinear, +) + +from neuronx_distributed.modules.lora import LoraConfig +from neuronx_distributed.modules.lora.layer import ( + LoraLinear, + LoraConv2d, + LoraEmbedding, +) + +from neuronx_distributed.modules.lora.tp_layer import ( + LoraParallelLinear, +) + + +def get_lora_config(use_rslora=False, init_lora_weights="default"): + return LoraConfig( + enable_lora=True, + lora_rank=8, + lora_alpha=16, + lora_dropout=0.05, + use_rslora=use_rslora, + lora_verbose=False, + init_lora_weights=init_lora_weights, + ) + + +class TestLoraLayers(unittest.TestCase): + def test_torch_linear_layer(self): + layer = torch.nn.Linear(32, 32) + rslora_modes = [False, True] + init_weights_modes = ["default", "gaussian"] + + for rslora in rslora_modes: + for init_mode in init_weights_modes: + lora_config = get_lora_config(use_rslora=rslora, init_lora_weights=init_mode) + lora_layer = LoraLinear(layer, lora_config) + layer_str = str(lora_layer) + assert "lora" in layer_str + + def test_torch_conv2d_layer(self): + layer = torch.nn.Conv2d(32, 32, 2) + rslora_modes = [False, True] + init_weights_modes = ["default", "gaussian"] + + for rslora in rslora_modes: + for init_mode in init_weights_modes: + lora_config = get_lora_config(use_rslora=rslora, init_lora_weights=init_mode) + lora_layer = LoraConv2d(layer, lora_config) + layer_str = str(lora_layer) + assert "lora" in layer_str + + def 
test_torch_embedding_layer(self): + layer = torch.nn.Embedding(32, 32) + rslora_modes = [False, True] + init_weights_modes = ["default", "gaussian"] + + for rslora in rslora_modes: + for init_mode in init_weights_modes: + lora_config = get_lora_config(use_rslora=rslora, init_lora_weights=init_mode) + lora_layer = LoraEmbedding(layer, lora_config) + layer_str = str(lora_layer) + assert "lora" in layer_str + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=True)) + @patch("neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True)) + @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=8)) + def test_tp_linear_layers(self): + layers = [ColumnParallelLinear(32, 32), RowParallelLinear(32, 32)] + rslora_modes = [False, True] + init_weights_modes = ["default", "gaussian"] + + for layer in layers: + for rslora in rslora_modes: + for init_mode in init_weights_modes: + lora_config = get_lora_config(use_rslora=rslora, init_lora_weights=init_mode) + lora_layer = LoraParallelLinear(layer, lora_config) + layer_str = str(lora_layer) + assert "lora" in layer_str + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/modules/lora/test_lora_models.py b/test/unit_test/modules/lora/test_lora_models.py new file mode 100644 index 0000000..ff8116f --- /dev/null +++ b/test/unit_test/modules/lora/test_lora_models.py @@ -0,0 +1,83 @@ +# Standard Library +import unittest +from unittest.mock import MagicMock, patch + +# Third Party +import torch +from neuronx_distributed.parallel_layers import ( + ColumnParallelLinear, + RowParallelLinear, +) + +from neuronx_distributed.modules.lora import LoraConfig, LoraModel + + +class NxDModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.rpl = ColumnParallelLinear(32, 32) + self.cpl = RowParallelLinear(32, 32) + self.linear = torch.nn.Linear(32, 32) + + +class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 32) + self.conv2d = torch.nn.Conv2d(32, 32, 4) + self.embedding = torch.nn.Embedding(32, 32) + + +def get_nxd_lora_config(bias="none"): + return LoraConfig( + enable_lora=True, + lora_rank=8, + lora_alpha=16, + lora_dropout=0.05, + lora_verbose=False, + bias=bias, + target_modules=["rpl", "cpl"], + ) + + +def get_lora_config(bias="none"): + return LoraConfig( + enable_lora=True, + lora_rank=8, + lora_alpha=16, + lora_dropout=0.05, + bias=bias, + lora_verbose=True, + target_modules=["linear", "conv2d", "embedding"], + ) + + +class TestLoraModels(unittest.TestCase): + def test_torch_model(self): + bias_modes = ["none", "all", "lora_only"] + for mode in bias_modes: + model = Module() + lora_config=get_lora_config(bias=mode) + model = LoraModel(model, lora_config) + assert isinstance(model, LoraModel) + model_str = str(model) + assert "LoraModel" in model_str + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=True)) + 
@patch("neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True)) + @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=8)) + def test_nxd_model(self): + bias_modes = ["none", "all", "lora_only"] + for mode in bias_modes: + model = NxDModule() + lora_config=get_nxd_lora_config(bias=mode) + model = LoraModel(model, lora_config) + assert isinstance(model, LoraModel) + model_str = str(model) + assert "LoraModel" in model_str + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/modules/lora/test_model_config.py b/test/unit_test/modules/lora/test_model_config.py deleted file mode 100644 index c69480b..0000000 --- a/test/unit_test/modules/lora/test_model_config.py +++ /dev/null @@ -1,213 +0,0 @@ -# Standard Library -import unittest -from unittest.mock import MagicMock, patch - -# Third Party -import torch -from transformers import AutoModelForCausalLM, BertConfig, GPT2Config, LlamaConfig - -import neuronx_distributed as nxd -from neuronx_distributed.modules.lora import LoraConfig, LoraModel - -from ... import update_result - - -def get_gpt2_model(): - seq_len = 512 - model_config = GPT2Config( - vocab_size=50257, - n_positions=seq_len, - n_embd=768, - n_layer=4, - n_head=12, - n_inner=None, - activation_function="gelu_new", - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=1e-05, - initializer_range=0.02, - summary_type="cls_index", - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.0, - use_cache=False, - bos_token_id=50256, - eos_token_id=50256, - return_dict=False, - ) - model = AutoModelForCausalLM.from_config(model_config) - return model - - -def get_bert_model(): - model_config = BertConfig( - vocab_size=30522, - hidden_size=768, - num_hidden_layers=4, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - layer_norm_eps=1e-12, - pad_token_id=0, - position_embedding_type="absolute", - use_cache=True, - classifier_dropout=None, - ) - model = AutoModelForCausalLM.from_config(model_config) - return model - - -def get_llama_model(): - model_config = LlamaConfig( - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=4, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - ) - model = AutoModelForCausalLM.from_config(model_config) - return model - - -def get_lora_config(target_modules): - lora_config = LoraConfig( - enable_lora=True, - lora_rank=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - lora_verbose=True, - target_modules=target_modules, - ) - return lora_config - - -class TestModelWrapper(unittest.TestCase): - @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) - @patch( - "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) - ) - @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - 
@patch("torch.distributed.get_rank") - def test_gpt2_model_wrapper(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - lora_config=get_lora_config(target_modules=["c_attn"]), - ) - model = nxd.initialize_parallel_model(nxd_config, get_gpt2_model) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert not model.pp_enabled - assert model.dtype == torch.float32 - assert isinstance(model.module, LoraModel) - model_str = str(model) - assert "NxDPPModel" not in model_str - assert "NxDCheckpointWrapper" in model_str - assert "LoraModel" in model_str - - except: - update_result({"inference_success": 0}) - raise - - @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) - @patch( - "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) - ) - @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_bert_model_wrapper(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - lora_config=get_lora_config(target_modules=["query", "value"]), - ) - model = nxd.initialize_parallel_model(nxd_config, get_bert_model) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert not model.pp_enabled - assert model.dtype == torch.float32 - assert isinstance(model.module, LoraModel) - model_str = str(model) - assert "NxDPPModel" not in model_str - assert "LoraModel" in model_str - - except: - update_result({"inference_success": 0}) - raise - - @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) - @patch( - "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) - ) - @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_llama_model_wrapper(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - lora_config=get_lora_config(target_modules=["q_proj", "v_proj"]), - ) - model = nxd.initialize_parallel_model(nxd_config, get_llama_model) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert not model.pp_enabled - assert model.dtype == torch.float32 - assert isinstance(model.module, LoraModel) - model_str = str(model) - assert "NxDPPModel" not in model_str - assert "NxDCheckpointWrapper" in model_str - assert "LoraModel" in model_str - - except: - update_result({"inference_success": 0}) - raise - - -if __name__ == "__main__": - unittest.main() diff --git a/test/unit_test/modules/lora/test_model_wrapper.py b/test/unit_test/modules/lora/test_model_wrapper.py index c51cecb..30fe80a 100644 --- a/test/unit_test/modules/lora/test_model_wrapper.py +++ 
b/test/unit_test/modules/lora/test_model_wrapper.py @@ -4,302 +4,224 @@ # Third Party import torch -from transformers import AutoModelForCausalLM, GPT2Config -from transformers.models.gpt2.modeling_gpt2 import GPT2Block - +from neuronx_distributed.parallel_layers import ( + ColumnParallelLinear, + RowParallelLinear, +) import neuronx_distributed as nxd from neuronx_distributed.modules.lora import LoraConfig, LoraModel, get_lora_model +from neuronx_distributed.pipeline.model import NxDPPModel -from ... import update_result +class NxDModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.rpl = ColumnParallelLinear(32, 32) + self.cpl = RowParallelLinear(32, 32) + self.linear = torch.nn.Linear(32, 32) -def get_model(): - seq_len = 512 - model_config = GPT2Config( - vocab_size=50257, - n_positions=seq_len, - n_embd=768, - n_layer=8, - n_head=12, - n_inner=None, - activation_function="gelu_new", - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=1e-05, - initializer_range=0.02, - summary_type="cls_index", - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.0, - use_cache=False, - bos_token_id=50256, - eos_token_id=50256, - return_dict=False, - ) - model = AutoModelForCausalLM.from_config(model_config) - return model +def get_nxd_model(): + return NxDModule() -def get_lora_config(): - lora_config = LoraConfig( + +class NxDPPModule(torch.nn.Module): + def __init__(self, num_layers): + super().__init__() + self.rpl = RowParallelLinear(10, 10) + self.cpl = ColumnParallelLinear(10, 10) + self.layers = torch.nn.ModuleList([torch.nn.Linear(2, 2) for _ in range(num_layers)]) + + +def get_pp_model(num_layers=4): + return NxDPPModule(num_layers) + + +class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 32) + self.conv2d = torch.nn.Conv2d(32, 32, 4) + self.embedding = torch.nn.Embedding(32, 32) + + +def get_nxd_lora_config(): + return LoraConfig( enable_lora=True, - lora_rank=16, - lora_alpha=32, + lora_rank=8, + lora_alpha=16, lora_dropout=0.05, - bias="none", - lora_verbose=False, - target_modules=["c_attn"], + target_modules=["rpl", "cpl"], ) - return lora_config +def get_lora_config(): + return LoraConfig( + enable_lora=True, + lora_rank=8, + lora_alpha=16, + lora_dropout=0.05, + target_modules=["linear", "conv2d", "embedding"], + ) class TestModelWrapper(unittest.TestCase): def test_model_wrapper_single_device(self): - try: - model = get_model() - lora_config = get_lora_config() - lora_model = LoraModel(model, lora_config) - assert isinstance(lora_model, LoraModel) - assert lora_model.lora_config == lora_config - assert type(lora_model.get_base_model()) == type(model) - except: - update_result({"inference_success": 0}) - raise + model = Module() + lora_config = get_lora_config() + lora_model = LoraModel(model, lora_config) + assert isinstance(lora_model, LoraModel) + assert lora_model.lora_config == lora_config + assert type(lora_model.get_base_model()) is type(model) def test_unified_model_wrapper_single_device(self): - try: - model = get_model() - lora_config = get_lora_config() - lora_model = get_lora_model(model, lora_config) - assert isinstance(lora_model, LoraModel) - - assert lora_model.lora_config == lora_config - assert type(lora_model.get_base_model()) == type(model) - except: - update_result({"inference_success": 0}) - raise - + model = Module() + lora_config = get_lora_config() + lora_model = get_lora_model(model, lora_config) + 
assert isinstance(lora_model, LoraModel) + + assert lora_model.lora_config == lora_config + assert type(lora_model.get_base_model()) is type(model) + + @patch("neuronx_distributed.parallel_layers.layers._initialize_parameter_cpu", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.layers._initialize_affine_weight_neuron", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) + ) + @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) + @patch("neuronx_distributed.utils.model_utils.move_model_to_device", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.move_model_to_device", MagicMock(return_value=None)) + def test_model_wrapper(self): + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + lora_config=get_lora_config(), + ) + model = nxd.initialize_parallel_model(nxd_config, get_nxd_model) + + assert isinstance(model, nxd.trainer.model.NxDModel) + assert model.nxd_config == nxd_config + assert isinstance(model.module, LoraModel) + model_str = str(model) + assert "LoraModel" in model_str + + @patch("neuronx_distributed.parallel_layers.layers._initialize_parameter_cpu", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.layers._initialize_affine_weight_neuron", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) + ) + @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) + @patch("neuronx_distributed.utils.model_utils.move_model_to_device", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.move_model_to_device", MagicMock(return_value=None)) + def test_unified_model_wrapper(self): + lora_config = get_nxd_lora_config() + model = NxDModule() + model = get_lora_model(model, lora_config) + + assert isinstance(model, LoraModel) + model_str = str(model) + assert "LoraModel" in model_str + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.layers._initialize_parameter_cpu", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.layers._initialize_affine_weight_neuron", MagicMock(return_value=None)) @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( "neuronx_distributed.pipeline.model.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) ) @patch( - 
"neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=8) + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=2) ) @patch( "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) ) @patch( - "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_size", MagicMock(return_value=8) + "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_size", MagicMock(return_value=2) ) @patch( - "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_rank", MagicMock(return_value=1) + "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_rank", MagicMock(return_value=0) ) @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) - @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_size", MagicMock(return_value=2)) @patch("neuronx_distributed.pipeline.model.NxDPPModel._create_pg_with_ranks", MagicMock(return_value=None)) @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_model_wrapper(self, rank_mock): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - lora_config=get_lora_config(), - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert model.pp_enabled - assert model.dtype == torch.float32 - assert isinstance(model.module, LoraModel) - assert isinstance(model.module.get_base_model(), nxd.pipeline.NxDPPModel) - # assert isinstance(model.module, nxd.pipeline.NxDPPModel) - # assert isinstance(model.module.original_torch_module, LoraModel) - model_str = str(model) - assert "NxDPPModel" in model_str - assert "NxDCheckpointWrapper" in model_str - assert "LoraModel" in model_str - - except: - update_result({"inference_success": 0}) - raise - + @patch("neuronx_distributed.utils.model_utils.move_model_to_device", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.move_model_to_device", MagicMock(return_value=None)) + def test_pp_model_wrapper(self): + pipeline_cuts = [ + "layers.1", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=2, + pipeline_parallel_size=2, + pipeline_config={ + "transformer_layer_cls": torch.nn.Linear, + "tracer_cls": "torch", + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + 
"input_names": ["input_ids", "attention_mask", "labels"], + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + lora_config=get_nxd_lora_config(), + ) + model = nxd.initialize_parallel_model(nxd_config, get_pp_model) + + assert isinstance(model, nxd.trainer.model.NxDModel) + assert model.nxd_config == nxd_config + assert model.pp_enabled + assert isinstance(model.module, LoraModel) + assert isinstance(model.module.get_base_model(), nxd.pipeline.NxDPPModel) + model_str = str(model) + assert "NxDModel" in model_str + assert "NxDPPModel" in model_str + assert "LoraModel" in model_str + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.parallel_layers.layers._initialize_parameter_cpu", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.layers._initialize_affine_weight_neuron", MagicMock(return_value=None)) @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( "neuronx_distributed.pipeline.model.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) ) @patch( - "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=8) + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=2) ) @patch( "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) ) @patch( - "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_size", MagicMock(return_value=8) + "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_size", MagicMock(return_value=2) ) @patch( "neuronx_distributed.pipeline.model.parallel_state.get_tensor_model_parallel_rank", MagicMock(return_value=1) ) @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_rank", MagicMock(return_value=1)) - @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.pipeline.partition.get_pipeline_model_parallel_size", MagicMock(return_value=2)) @patch("neuronx_distributed.pipeline.model.NxDPPModel._create_pg_with_ranks", MagicMock(return_value=None)) @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_unified_model_wrapper(self, rank_mock): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - 
activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - model = get_lora_model(model, get_lora_config()) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert model.pp_enabled - assert model.dtype == torch.float32 - assert isinstance(model.module, LoraModel) - assert isinstance(model.module.get_base_model(), nxd.pipeline.NxDPPModel) - model_str = str(model) - assert "NxDPPModel" in model_str - assert "NxDCheckpointWrapper" in model_str - assert "LoraModel" in model_str - - except: - update_result({"inference_success": 0}) - raise - - @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) - @patch( - "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) - ) - @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_model_wrapper_no_pp(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - lora_config=get_lora_config(), - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert not model.pp_enabled - assert model.dtype == torch.float32 - assert isinstance(model.module, LoraModel) - model_str = str(model) - assert "NxDPPModel" not in model_str - assert "NxDCheckpointWrapper" in model_str - assert "LoraModel" in model_str - - except: - update_result({"inference_success": 0}) - raise - - @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) - @patch( - "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) - ) - @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_unified_model_wrapper_no_pp(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - model = get_lora_model(model, get_lora_config()) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert not model.pp_enabled - assert model.dtype == torch.float32 - assert isinstance(model.module, LoraModel) - model_str = str(model) - assert "NxDPPModel" not in model_str - assert "NxDCheckpointWrapper" in model_str - assert "LoraModel" in model_str - - except: - update_result({"inference_success": 0}) - raise + @patch("neuronx_distributed.utils.model_utils.move_model_to_device", MagicMock(return_value=None)) + @patch("neuronx_distributed.parallel_layers.move_model_to_device", MagicMock(return_value=None)) + def test_unified_pp_model_wrapper(self): + model = NxDPPModel(module=NxDPPModule(4), transformer_layer_cls=torch.nn.Linear, tracer_cls="torch") + model = get_lora_model(model, get_nxd_lora_config()) + assert isinstance(model, LoraModel) + assert isinstance(model.module, NxDPPModel) + 
model_str = str(model) + assert "NxDPPModel" in model_str + assert "LoraModel" in model_str if __name__ == "__main__": diff --git a/test/unit_test/modules/lora/test_save_load.py b/test/unit_test/modules/lora/test_save_load.py index 6c92a3a..654e787 100644 --- a/test/unit_test/modules/lora/test_save_load.py +++ b/test/unit_test/modules/lora/test_save_load.py @@ -2,235 +2,182 @@ import unittest from unittest.mock import MagicMock, patch -# Third Party import torch -from transformers import AutoModelForCausalLM, GPT2Config - -import neuronx_distributed as nxd +from neuronx_distributed.parallel_layers import ( + ColumnParallelLinear, + RowParallelLinear, +) from neuronx_distributed.modules.lora import LoraConfig, LoraModel, get_lora_model -from ... import update_result - - -def get_model(): - seq_len = 512 - model_config = GPT2Config( - vocab_size=50257, - n_positions=seq_len, - n_embd=768, - n_layer=8, - n_head=12, - n_inner=None, - activation_function="gelu_new", - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - layer_norm_epsilon=1e-05, - initializer_range=0.02, - summary_type="cls_index", - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.0, - use_cache=False, - bos_token_id=50256, - eos_token_id=50256, - return_dict=False, + +class NxDModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.rpl = ColumnParallelLinear(32, 32) + self.cpl = RowParallelLinear(32, 32) + self.linear = torch.nn.Linear(32, 32) + + + +class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(32, 32) + self.conv2d = torch.nn.Conv2d(32, 32, 4) + self.embedding = torch.nn.Embedding(32, 32) + + + +def get_nxd_lora_config(save_lora_base, merge_lora, load_lora_from_ckpt=False): + return LoraConfig( + enable_lora=True, + lora_rank=8, + lora_alpha=16, + lora_dropout=0.05, + lora_verbose=True, + target_modules=["rpl", "cpl"], + save_lora_config_adapter=False, + save_lora_base=save_lora_base, + merge_lora=merge_lora, + load_lora_from_ckpt=load_lora_from_ckpt, ) - model = AutoModelForCausalLM.from_config(model_config) - return model - def get_lora_config(save_lora_base, merge_lora, load_lora_from_ckpt=False): - lora_config = LoraConfig( + return LoraConfig( enable_lora=True, - lora_rank=16, - lora_alpha=32, + lora_rank=8, + lora_alpha=16, lora_dropout=0.05, - bias="none", - lora_verbose=False, - target_modules=["c_attn"], + lora_verbose=True, + target_modules=["linear", "conv2d", "embedding"], save_lora_config_adapter=False, save_lora_base=save_lora_base, merge_lora=merge_lora, load_lora_from_ckpt=load_lora_from_ckpt, ) - return lora_config -class TestModelWrapper(unittest.TestCase): + +class TestLoraSaveLoad(unittest.TestCase): def test_save_load_no_base_single_device(self): - try: - model = get_model() - base_model_state_dict = model.state_dict() - lora_config = get_lora_config(save_lora_base=False, merge_lora=False) - model = LoraModel(model, lora_config) - state_dict = model.state_dict() - for key in state_dict: - assert "lora_" in key - - model.lora_config = get_lora_config(save_lora_base=False, merge_lora=False, load_lora_from_ckpt=True) - model.lora_ckpt = state_dict - model.is_checkpoint_loaded = True - - load_result = model.load_state_dict(base_model_state_dict) - assert load_result is not None - assert len(load_result.unexpected_keys) == 0 - except: - update_result({"inference_success": 0}) - raise - - - + model = Module() + base_model_state_dict = model.state_dict() + lora_config = 
get_lora_config(save_lora_base=False, merge_lora=False) + model = LoraModel(model, lora_config) + state_dict = model.state_dict() + for key in state_dict: + assert "lora_" in key + + model.lora_config = get_lora_config(save_lora_base=False, merge_lora=False, load_lora_from_ckpt=True) + model.lora_ckpt = state_dict + model.is_checkpoint_loaded = True + + load_result = model.load_state_dict(base_model_state_dict) + assert load_result is not None + assert len(load_result.unexpected_keys) == 0 + + def test_save_load_with_base_single_device(self): - try: - model = get_model() - lora_config = get_lora_config(save_lora_base=True, merge_lora=False) - model = LoraModel(model, lora_config) - state_dict = model.state_dict() - - model.lora_config = get_lora_config(save_lora_base=True, merge_lora=False, load_lora_from_ckpt=True) - model.lora_ckpt = state_dict - model.is_checkpoint_loaded = True - - load_result = model.load_state_dict() - assert load_result is not None - assert len(load_result.missing_keys) == 0 - assert len(load_result.unexpected_keys) == 0 - except: - update_result({"inference_success": 0}) - raise - - + model = Module() + lora_config = get_lora_config(save_lora_base=True, merge_lora=False) + model = LoraModel(model, lora_config) + state_dict = model.state_dict() + + model.lora_config = get_lora_config(save_lora_base=True, merge_lora=False, load_lora_from_ckpt=True) + model.lora_ckpt = state_dict + model.is_checkpoint_loaded = True + + load_result = model.load_state_dict() + assert load_result is not None + assert len(load_result.missing_keys) == 0 + assert len(load_result.unexpected_keys) == 0 + def test_save_load_with_base_merged_single_device(self): - try: - model = get_model() - base_model_state_dict = model.state_dict() - base_model_keys = base_model_state_dict.keys() - lora_config = get_lora_config(save_lora_base=True, merge_lora=True) - model = LoraModel(model, lora_config) - state_dict = model.state_dict() - keys = state_dict.keys() - - for key in keys: - assert key in base_model_keys - - for key in base_model_keys: - assert key in keys - except: - update_result({"inference_success": 0}) - raise - - + model = Module() + base_model_state_dict = model.state_dict() + base_model_keys = base_model_state_dict.keys() + lora_config = get_lora_config(save_lora_base=True, merge_lora=True) + model = LoraModel(model, lora_config) + state_dict = model.state_dict() + keys = state_dict.keys() + + for key in keys: + assert key in base_model_keys + + for key in base_model_keys: + assert key in keys + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) ) @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_save_load_no_base(self, rank_mock): - try: - lora_config = get_lora_config(save_lora_base=False, merge_lora=False) - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = 
nxd.initialize_parallel_model(nxd_config, get_model) - base_model_state_dict = model.state_dict() - model = get_lora_model(model, lora_config) - state_dict = model.state_dict() - for key in state_dict: - assert "lora_" in key - - model.module.lora_config = get_lora_config(save_lora_base=False, merge_lora=False, load_lora_from_ckpt=True) - model.module.lora_ckpt = state_dict - model.module.is_checkpoint_loaded = True - - load_result = model.load_state_dict(base_model_state_dict) - assert load_result is not None - assert len(load_result.unexpected_keys) == 0 - except: - update_result({"inference_success": 0}) - raise - - + def test_save_load_no_base(self): + model = NxDModule() + base_model_state_dict = model.state_dict() + + lora_config = get_nxd_lora_config(save_lora_base=False, merge_lora=False) + model = get_lora_model(model, lora_config) + + state_dict = model.state_dict() + for key in state_dict: + assert "lora_" in key + + model.lora_config = get_nxd_lora_config(save_lora_base=False, merge_lora=False, load_lora_from_ckpt=True) + model.lora_ckpt = state_dict + model.is_checkpoint_loaded = True + + load_result = model.load_state_dict(base_model_state_dict) + assert load_result is not None + assert len(load_result.unexpected_keys) == 0 + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) ) @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_save_load_with_base(self, rank_mock): - try: - lora_config = get_lora_config(save_lora_base=True, merge_lora=False) - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - lora_config=lora_config, - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - state_dict = model.state_dict() - - model.module.lora_config = get_lora_config(save_lora_base=True, merge_lora=False, load_lora_from_ckpt=True) - model.module.lora_ckpt = state_dict - model.module.is_checkpoint_loaded = True - - load_result = model.load_state_dict(None) - assert load_result is not None - assert len(load_result.missing_keys) == 0 - assert len(load_result.unexpected_keys) == 0 - except: - update_result({"inference_success": 0}) - raise - - + def test_save_load_with_base(self): + model = NxDModule() + lora_config = get_nxd_lora_config(save_lora_base=True, merge_lora=False) + model = get_lora_model(model, lora_config) + state_dict = model.state_dict() + + model.lora_config = get_nxd_lora_config(save_lora_base=True, merge_lora=False, load_lora_from_ckpt=True) + model.lora_ckpt = state_dict + model.is_checkpoint_loaded = True + + load_result = model.load_state_dict(None) + assert load_result is not None + assert len(load_result.missing_keys) == 0 + assert len(load_result.unexpected_keys) == 0 + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", 
MagicMock(return_value=1)) @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( "neuronx_distributed.parallel_layers.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) ) @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) - @patch("torch.distributed.get_rank") - def test_save_load_with_base_merged(self, rank_mock): - try: - lora_config = get_lora_config(save_lora_base=True, merge_lora=True) - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - base_model_state_dict = model.state_dict() - base_model_keys = base_model_state_dict.keys() - model = get_lora_model(model, lora_config) - state_dict = model.state_dict() - keys = state_dict.keys() - - for key in keys: - assert key in base_model_keys - - for key in base_model_keys: - assert key in keys - except: - update_result({"inference_success": 0}) - raise - + def test_save_load_with_base_merged(self): + model = NxDModule() + base_model_state_dict = model.state_dict() + base_model_keys = base_model_state_dict.keys() + + lora_config = get_nxd_lora_config(save_lora_base=True, merge_lora=True) + model = get_lora_model(model, lora_config) + state_dict = model.state_dict() + keys = state_dict.keys() + + for key in keys: + assert key in base_model_keys + + for key in base_model_keys: + assert key in keys + if __name__ == "__main__": unittest.main() diff --git a/test/unit_test/modules/moe/mixtral_model.py b/test/unit_test/modules/moe/mixtral_model.py index 5b69e6c..667727c 100644 --- a/test/unit_test/modules/moe/mixtral_model.py +++ b/test/unit_test/modules/moe/mixtral_model.py @@ -8,17 +8,15 @@ if version.parse(transformers_ver) < version.parse("4.36.0"): assert False, f"transformers library version is {transformers_ver}. Minimum required is 4.36.0" -from transformers.models.mixtral.configuration_mixtral import MixtralConfig -from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from transformers.models.mixtral.configuration_mixtral import MixtralConfig # noqa: E402 +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock # noqa: E402 -from . import utils_testing as ut +from . import utils_testing as ut # noqa: E402 def initialize_mixtral_model(cfg, seed=5): assert cfg.implementation == "topk" - assert cfg.glu_mlp == True, f"Mixtral implementation available only for GLU MLP" - - intermediate_size = int(8 / 3 * cfg.hidden_size) + assert cfg.glu_mlp is True, "Mixtral implementation available only for GLU MLP" mixtral_config = MixtralConfig() mixtral_config.hidden_size = cfg.hidden_size @@ -41,7 +39,7 @@ def convert_mixtral_to_neuron_state_dict(mixtral_state_dict, cfg): This function implements workarounds for this. 
""" - assert cfg.glu_mlp == True, f"Only GPU MLP is supported for Mixtral Top-K model" + assert cfg.glu_mlp is True, "Only GPU MLP is supported for Mixtral Top-K model" neuron_state_dict = {} # Copy router weights @@ -62,7 +60,7 @@ def convert_mixtral_to_neuron_state_dict(mixtral_state_dict, cfg): gate_up_proj_slice = torch.narrow(gate_up_proj, 0, e, 1) gate_up_proj_weights = torch.cat([gate_proj_weights, up_proj_weights], dim=1) gate_up_proj_slice.copy_(gate_up_proj_weights) - neuron_state_dict["expert_mlps.gate_up_proj.weight"] = gate_up_proj + neuron_state_dict["expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj down_proj = torch.empty(cfg.num_experts, intermediate_size, hidden_size, device=device) for e in range(cfg.num_experts): @@ -71,6 +69,6 @@ def convert_mixtral_to_neuron_state_dict(mixtral_state_dict, cfg): if down_proj_weights is not None: down_proj_slice = torch.narrow(down_proj, 0, e, 1) down_proj_slice.copy_(down_proj_weights) - neuron_state_dict["expert_mlps.down_proj.weight"] = down_proj + neuron_state_dict["expert_mlps.mlp_op.down_proj.weight"] = down_proj return neuron_state_dict diff --git a/test/unit_test/modules/moe/sbase_model.py b/test/unit_test/modules/moe/sbase_model.py index 787e7f4..5ea39ff 100644 --- a/test/unit_test/modules/moe/sbase_model.py +++ b/test/unit_test/modules/moe/sbase_model.py @@ -106,9 +106,7 @@ def convert_sbase_to_neuron_state_dict(sbase_state_dict, cfg): # copy the MLP parameters if cfg.glu_mlp: - gate_up_proj = torch.empty( - cfg.num_experts, hidden_size, 2 * intermediate_size, device=device, dtype=cfg.dtype - ) + gate_up_proj = torch.empty(cfg.num_experts, hidden_size, 2 * intermediate_size, device=device, dtype=cfg.dtype) for e in range(cfg.num_experts): # Copy gate_proj and up_proj after concatenation gate_proj_weights = sbase_state_dict[f"experts.{e}.gate_proj.weight"].T @@ -116,7 +114,7 @@ def convert_sbase_to_neuron_state_dict(sbase_state_dict, cfg): gate_up_proj_slice = torch.narrow(gate_up_proj, 0, e, 1) gate_up_proj_weights = torch.cat([gate_proj_weights, up_proj_weights], dim=1) gate_up_proj_slice.copy_(gate_up_proj_weights) - neuron_state_dict["expert_mlps.gate_up_proj.weight"] = gate_up_proj + neuron_state_dict["expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj else: up_proj = torch.empty(cfg.num_experts, hidden_size, intermediate_size, device=device, dtype=cfg.dtype) for e in range(cfg.num_experts): @@ -124,7 +122,7 @@ def convert_sbase_to_neuron_state_dict(sbase_state_dict, cfg): up_proj_weights = sbase_state_dict[f"experts.{e}.up_proj.weight"].T up_proj_slice = torch.narrow(up_proj, 0, e, 1) up_proj_slice.copy_(up_proj_weights) - neuron_state_dict["expert_mlps.up_proj.weight"] = up_proj + neuron_state_dict["expert_mlps.mlp_op.up_proj.weight"] = up_proj down_proj = torch.empty(cfg.num_experts, intermediate_size, hidden_size, device=device, dtype=cfg.dtype) for e in range(cfg.num_experts): @@ -132,6 +130,6 @@ def convert_sbase_to_neuron_state_dict(sbase_state_dict, cfg): down_proj_weights = sbase_state_dict[f"experts.{e}.down_proj.weight"].T down_proj_slice = torch.narrow(down_proj, 0, e, 1) down_proj_slice.copy_(down_proj_weights) - neuron_state_dict["expert_mlps.down_proj.weight"] = down_proj + neuron_state_dict["expert_mlps.mlp_op.down_proj.weight"] = down_proj return neuron_state_dict diff --git a/test/unit_test/modules/moe/test_impl_correctness.py b/test/unit_test/modules/moe/test_impl_correctness.py index 80b71ba..423a50b 100644 --- a/test/unit_test/modules/moe/test_impl_correctness.py +++ 
b/test/unit_test/modules/moe/test_impl_correctness.py @@ -9,17 +9,15 @@ from parameterized import parameterized from neuronx_distributed import parallel_layers -from neuronx_distributed.modules.moe import ACT2FN from neuronx_distributed.modules.moe import ( load_balancing_loss_func as neuron_load_balancing_loss_func, ) -from ... import update_result from . import loss_fn_correctness_test_helper as lch from . import mixtral_model as m_mixtral from . import sbase_model as m_sbase from . import utils_testing as ut -from .utils_testing import ExptCfgCorrectness +from .utils_testing import ExptCfg if not torch.distributed.is_initialized(): # Simulate torchrun (required because MoE uses parallel layers for TP) @@ -47,12 +45,14 @@ } -def get_impl_correctness_test_configs(test_type): - assert test_type in {"fwd", "bwd"} +def get_impl_correctness_test_configs(test_modes): + test_modes = set(test_modes) + assert ( + len(test_modes) > 0 and len(test_modes - {"training", "inference"}) == 0 + ), f"Unknown test modes: {str(test_modes)}" GLU_MLP_ARGS = [True, False] DTYPE_ARGS = [torch.float32, torch.bfloat16] - PERMUTE_STRATEGY_ARGS = ["matmul", "index"] test_configs = [] @@ -65,53 +65,40 @@ def get_impl_correctness_test_configs(test_type): "implementation": "sbase", } - # Test forward_full_capacity - test_cfg["expert_mlps_permute_strategy"] = "index" + # Training tests + test_cfg["test_mode"] = "training" sbase_test_configs.extend( [ - ExptCfgCorrectness( - seq_len=128, batch_size=1, hidden_size=384, num_experts=2, capacity_factor=2.0, **test_cfg - ), - ExptCfgCorrectness( - seq_len=128, batch_size=4, hidden_size=384, num_experts=8, capacity_factor=8.0, **test_cfg - ), + # Test forward_all_experts (full capacity) + ExptCfg(seq_len=128, batch_size=1, hidden_size=384, num_experts=2, capacity_factor=None, **test_cfg), + ExptCfg(seq_len=128, batch_size=4, hidden_size=384, num_experts=8, capacity_factor=None, **test_cfg), + # Test forward_capacity_factor + ExptCfg(seq_len=128, batch_size=1, hidden_size=384, num_experts=4, capacity_factor=2.0, **test_cfg), + ExptCfg(seq_len=128, batch_size=4, hidden_size=384, num_experts=8, capacity_factor=1.0, **test_cfg), ] ) - # Test cases for inference (therefore we return them only for the fwd correctness) - if test_type == "fwd": - # Tests for token-gen - sbase_test_configs.extend( - [ - ExptCfgCorrectness( - seq_len=1, batch_size=1, hidden_size=384, num_experts=4, capacity_factor=1.0, **test_cfg - ), - ExptCfgCorrectness( - seq_len=1, batch_size=4, hidden_size=960, num_experts=4, capacity_factor=1.0, **test_cfg - ), - ] - ) - - # capacity_factor such that some tokens may be dropped - for permute_strategy in PERMUTE_STRATEGY_ARGS: - test_cfg["expert_mlps_permute_strategy"] = permute_strategy - sbase_test_configs.extend( - [ - # capacity_factor such that some tokens may be dropped - ExptCfgCorrectness( - seq_len=128, batch_size=1, hidden_size=384, num_experts=4, capacity_factor=2.0, **test_cfg - ), - ExptCfgCorrectness( - seq_len=128, batch_size=4, hidden_size=384, num_experts=8, capacity_factor=1.0, **test_cfg - ), - ] - ) + # Inference tests + test_cfg["test_mode"] = "inference" + sbase_test_configs.extend( + [ + # Test context encoding + ExptCfg(seq_len=128, batch_size=1, hidden_size=384, num_experts=2, capacity_factor=None, **test_cfg), + ExptCfg(seq_len=128, batch_size=4, hidden_size=384, num_experts=8, capacity_factor=None, **test_cfg), + ExptCfg(seq_len=128, batch_size=1, hidden_size=384, num_experts=4, capacity_factor=2.0, **test_cfg), + ExptCfg(seq_len=128, 
batch_size=4, hidden_size=384, num_experts=8, capacity_factor=1.0, **test_cfg), + # Test token generation + ExptCfg(seq_len=1, batch_size=1, hidden_size=384, num_experts=4, capacity_factor=None, **test_cfg), + ExptCfg(seq_len=1, batch_size=2, hidden_size=384, num_experts=4, capacity_factor=None, **test_cfg), + ExptCfg(seq_len=1, batch_size=8, hidden_size=960, num_experts=4, capacity_factor=None, **test_cfg), + ] + ) # Test each configuration on 2 random activation functions for test_no, cfg in enumerate(sbase_test_configs): for hidden_act in ut.get_random_activations(num=2, seed=test_no): test_configs.append(dataclasses.replace(cfg, hidden_act=hidden_act)) - + # Mixtral test cases # Only fp32 testing with full capacity is supported for mixtral because we havent hacked the golden implementation test_cfg = { @@ -119,63 +106,119 @@ def get_impl_correctness_test_configs(test_type): "glu_mlp": True, "hidden_act": "silu", "implementation": "topk", - "expert_mlps_permute_strategy": "index", + "capacity_factor": None, } - # Test forward_full_capacity + # Training tests + test_cfg["test_mode"] = "training" test_configs.extend( [ - ExptCfgCorrectness( - seq_len=128, batch_size=1, hidden_size=384, num_experts=2, capacity_factor=2.0, top_k=1, **test_cfg, + ExptCfg( + seq_len=128, + batch_size=1, + hidden_size=384, + num_experts=2, + top_k=1, + **test_cfg, ), - ExptCfgCorrectness( - seq_len=128, batch_size=2, hidden_size=384, num_experts=4, capacity_factor=2.0, top_k=2, **test_cfg, + ExptCfg( + seq_len=128, + batch_size=2, + hidden_size=384, + num_experts=4, + top_k=2, + **test_cfg, ), - ExptCfgCorrectness( - seq_len=128, batch_size=4, hidden_size=384, num_experts=4, capacity_factor=1.0, top_k=4, **test_cfg, + ExptCfg( + seq_len=128, + batch_size=4, + hidden_size=384, + num_experts=4, + top_k=4, + **test_cfg, ), ] ) - # Test cases for inference (therefore we return them only for the fwd correctness) - if test_type == "fwd": - # Tests for token-gen - test_configs.extend( - [ - ExptCfgCorrectness( - seq_len=1, batch_size=1, hidden_size=384, num_experts=4, capacity_factor=1.0, top_k=2, **test_cfg, - ), - ExptCfgCorrectness( - seq_len=1, batch_size=2, hidden_size=384, num_experts=4, capacity_factor=1.0, top_k=4, **test_cfg, - ), - ExptCfgCorrectness( - seq_len=1, batch_size=4, hidden_size=384, num_experts=8, capacity_factor=1.0, top_k=4, **test_cfg, - ), - ] - ) + # Inference tests + test_cfg["test_mode"] = "inference" + test_configs.extend( + [ + # Test context encoding + ExptCfg( + seq_len=128, + batch_size=1, + hidden_size=384, + num_experts=2, + top_k=1, + **test_cfg, + ), + ExptCfg( + seq_len=128, + batch_size=2, + hidden_size=384, + num_experts=4, + top_k=2, + **test_cfg, + ), + ExptCfg( + seq_len=128, + batch_size=4, + hidden_size=384, + num_experts=4, + top_k=4, + **test_cfg, + ), + # Test token generation + ExptCfg( + seq_len=1, + batch_size=1, + hidden_size=384, + num_experts=4, + top_k=2, + **test_cfg, + ), + ExptCfg( + seq_len=1, + batch_size=2, + hidden_size=384, + num_experts=4, + top_k=4, + **test_cfg, + ), + ExptCfg( + seq_len=1, + batch_size=4, + hidden_size=384, + num_experts=8, + top_k=4, + **test_cfg, + ), + ] + ) - # Add full capacity tests for matmul and index - full_capacity_permute_test_configs = [] + # Add full capacity tests for forward_capacity_factor + forward_capacity_factor_full_capacity_test_configs = [] eps = 10**-6 for cfg in test_configs: if cfg.seq_len == 1: - # Skip for token-gen + # Skip for token generation configs continue - - if cfg.capacity_factor >= cfg.num_experts 
/ cfg.top_k: + + if cfg.capacity_factor is None: # Set capacity_factor = full_capacity_factor - eps full_cf_eps = float(cfg.num_experts / cfg.top_k) - eps - for permute_strategy in PERMUTE_STRATEGY_ARGS: - full_cf_eps_cfg = dataclasses.replace( - cfg, - capacity_factor=full_cf_eps, - expert_mlps_permute_strategy=permute_strategy - ) - full_capacity_permute_test_configs.append(full_cf_eps_cfg) + full_cf_eps_cfg = dataclasses.replace( + cfg, + capacity_factor=full_cf_eps, + ) + forward_capacity_factor_full_capacity_test_configs.append(full_cf_eps_cfg) + + test_configs.extend(forward_capacity_factor_full_capacity_test_configs) - test_configs.extend(full_capacity_permute_test_configs) - - return test_configs + test_mode_configs = [cfg for cfg in test_configs if cfg.test_mode in test_modes] + + return test_mode_configs def initialize_neuron_and_golden_models(cfg): @@ -228,188 +271,181 @@ def get_expected_dropped_token_indices(expert_ind, cfg): class TestImplCorrectness(unittest.TestCase): - @parameterized.expand(get_impl_correctness_test_configs("fwd"), name_func=ut.custom_name_func) + @parameterized.expand( + get_impl_correctness_test_configs(test_modes=["training", "inference"]), name_func=ut.custom_name_func + ) def test_fwd_correctness(self, cfg): - try: - is_token_gen = cfg.seq_len == 1 - model_neuron, model_golden = initialize_neuron_and_golden_models(cfg) + model_neuron, model_golden = initialize_neuron_and_golden_models(cfg) + + # Set is_test=True + model_neuron.is_test = True + if cfg.implementation == "sbase": + model_golden.is_test = True + + is_token_gen = cfg.seq_len == 1 + if cfg.test_mode == "inference": + model_neuron.eval() + model_golden.eval() + else: + model_neuron.train() + model_golden.train() - # Set is_test=True - model_neuron.is_test = True - if cfg.implementation == "sbase": - model_golden.is_test = True + with torch.no_grad(): + for it in range(cfg.num_iters): + if cfg.test_mode == "inference": + # Inference: input is BSH + ip = torch.randn( + cfg.batch_size, cfg.seq_len, cfg.hidden_size, dtype=cfg.dtype, device=cfg.device + ) + else: + # Training: input is SBH + ip = torch.randn( + cfg.seq_len, cfg.batch_size, cfg.hidden_size, dtype=cfg.dtype, device=cfg.device + ) - if is_token_gen: - model_neuron.eval() - model_golden.eval() - else: - model_neuron.train() - model_golden.train() - - with torch.no_grad(): - for it in range(cfg.num_iters): - if is_token_gen: - # Token gen: input should be BSH - ip = torch.randn(cfg.batch_size, cfg.seq_len, cfg.hidden_size, dtype=cfg.dtype, device=cfg.device) - else: - # For training, input is SBH - # For context encoding, the ordering of S and B does not matter for cpu testing - ip = torch.randn(cfg.seq_len, cfg.batch_size, cfg.hidden_size, dtype=cfg.dtype, device=cfg.device) + if cfg.implementation == "topk": + # Run fwd on both the Neuron and Mixtral HF model + op_neuron, router_logits_neuron, exp_ind_neuron = model_neuron(ip) + op_mixtral, router_logits_mixtral = model_golden(ip) - if cfg.implementation == "topk": - # Run fwd on both the Neuron and Mixtral HF model - op_neuron, router_logits_neuron, exp_ind_neuron = model_neuron(ip) - op_mixtral, router_logits_mixtral = model_golden(ip) + # Check that router logits and outputs match + ut.check_tensors( + router_logits_neuron, router_logits_mixtral, **TEST_TOLS, additional_msg=f"Iteration {it}" + ) + ut.check_tensors(op_neuron, op_mixtral, **TEST_TOLS, additional_msg=f"Iteration {it}") - # check that router logits and outputs match + elif cfg.implementation == "sbase": + # Run fwd 
on both the Neuron and S-BASE model + op_neuron, _, exp_ind_neuron = model_neuron(ip) + op_sbase, _, exp_ind_sbase = model_golden(ip) + + if cfg.dtype == torch.bfloat16: + # Skip this check for token-gen (because perc_discrepancy may be large since S*B is small) + if not is_token_gen: + # Permit minor discrepancies for bf16 + perc_discrepancy = 1 - torch.mean( + torch.isclose(exp_ind_neuron, exp_ind_sbase, **TEST_TOLS).to(torch.float32) + ) + assert ( + perc_discrepancy.item() < BF16_EXPERT_ASSIGNMENT_DIFF_TOL + ), f" diff is {perc_discrepancy}" + else: + # Check that the initial expert assignments were identical for fp32 ut.check_tensors( - router_logits_neuron, router_logits_mixtral, **TEST_TOLS, additional_msg=f"Iteration {it}" + exp_ind_neuron, exp_ind_sbase, **TEST_TOLS, additional_msg=f"Iteration {it}" ) - ut.check_tensors(op_neuron, op_mixtral, **TEST_TOLS, additional_msg=f"Iteration {it}") - - elif cfg.implementation == "sbase": - # Run fwd on both the Neuron and S-BASE model - op_neuron, _, exp_ind_neuron = model_neuron(ip) - op_sbase, _, exp_ind_sbase = model_golden(ip) - - if cfg.dtype == torch.bfloat16: - # Skip this check for token-gen (because perc_discrepancy may be large since S*B is small) - if not is_token_gen: - # Permit minor discrepancies for bf16 - perc_discrepancy = 1 - torch.mean( - torch.isclose(exp_ind_neuron, exp_ind_sbase, **TEST_TOLS).to(torch.float32) - ) - assert ( - perc_discrepancy.item() < BF16_EXPERT_ASSIGNMENT_DIFF_TOL - ), f" diff is {perc_discrepancy}" - else: - # Check that the initial expert assignments were identical for fp32 - ut.check_tensors( - exp_ind_neuron, exp_ind_sbase, **TEST_TOLS, additional_msg=f"Iteration {it}" - ) - # Token-gen is dropless - if not is_token_gen: - # Get the indices of the tokens which should have been dropped by the model_neuron - expected_dropped_token_indices = get_expected_dropped_token_indices(exp_ind_neuron, cfg) + # Token-gen is dropless + if not is_token_gen: + # Get the indices of the tokens which should have been dropped by the model_neuron + expected_dropped_token_indices = get_expected_dropped_token_indices(exp_ind_neuron, cfg) - # Manually simulate the dropping of tokens in op_sbase - op_sbase = ut.drop_tokens_in_tensor(op_sbase, expected_dropped_token_indices) + # Manually simulate the dropping of tokens in op_sbase + op_sbase = ut.drop_tokens_in_tensor(op_sbase, expected_dropped_token_indices) - if cfg.dtype == torch.bfloat16: - # Simulate dropping of tokens in op_neuron and op_sbase where the expert assignments are not matching with neuron - expert_mismatch_indices = torch.where(exp_ind_neuron != exp_ind_sbase)[0].tolist() - op_sbase = ut.drop_tokens_in_tensor(op_sbase, expert_mismatch_indices) - op_neuron = ut.drop_tokens_in_tensor(op_neuron, expert_mismatch_indices) + if cfg.dtype == torch.bfloat16: + # Simulate dropping of tokens in op_neuron and op_sbase where the expert assignments are not matching with neuron + expert_mismatch_indices = torch.where(exp_ind_neuron != exp_ind_sbase)[0].tolist() + op_sbase = ut.drop_tokens_in_tensor(op_sbase, expert_mismatch_indices) + op_neuron = ut.drop_tokens_in_tensor(op_neuron, expert_mismatch_indices) - # Check that op_neuron matches the op_sbase with the dropped tokens - ut.check_tensors(op_neuron, op_sbase, **TEST_TOLS, additional_msg=f"Iteration {it}") + # Check that op_neuron matches the op_sbase with the dropped tokens + ut.check_tensors(op_neuron, op_sbase, **TEST_TOLS, additional_msg=f"Iteration {it}") - else: - raise ValueError(f"Unknown implementation: 
{cfg.implementation}") - except: - update_result({"inference_success": 0}) - raise + else: + raise ValueError(f"Unknown implementation: {cfg.implementation}") - @parameterized.expand(get_impl_correctness_test_configs("bwd"), name_func=ut.custom_name_func) + @parameterized.expand(get_impl_correctness_test_configs(test_modes=["training"]), name_func=ut.custom_name_func) def test_bwd_correctness(self, cfg): - try: - model_neuron, model_golden = initialize_neuron_and_golden_models(cfg) + model_neuron, model_golden = initialize_neuron_and_golden_models(cfg) - # Set is_test=True - model_neuron.is_test = True - if cfg.implementation == "sbase": - model_golden.is_test = True + # Set is_test=True + model_neuron.is_test = True + if cfg.implementation == "sbase": + model_golden.is_test = True - # Set models to train mode - model_neuron.train() - model_golden.train() + # Set models to train mode + model_neuron.train() + model_golden.train() - optimizer_neuron = torch.optim.Adadelta(model_neuron.parameters()) - optimizer_golden = torch.optim.Adadelta(model_golden.parameters()) - mse_loss = torch.nn.MSELoss() + optimizer_neuron = torch.optim.Adadelta(model_neuron.parameters()) + optimizer_golden = torch.optim.Adadelta(model_golden.parameters()) + mse_loss = torch.nn.MSELoss() - for it in range(cfg.num_iters): - # Generate random input tensor - ip = torch.randn(cfg.seq_len, cfg.batch_size, cfg.hidden_size, dtype=cfg.dtype, device=cfg.device) - - if cfg.dtype == torch.bfloat16: - # Simulate dropping of tokens in input where the expert assignments are not matching with neuron - assert cfg.implementation == "sbase" - with torch.no_grad(): - op_neuron, _, exp_ind_neuron = model_neuron(ip) - op_sbase, _, exp_ind_sbase = model_golden(ip) - expert_mismatch_indices = torch.where(exp_ind_neuron != exp_ind_sbase)[0].tolist() - ip = ut.drop_tokens_in_tensor(ip, expert_mismatch_indices) - - # Run forward pass on model_neuron - op_neuron, _, exp_ind_neuron = model_neuron(ip) - - if cfg.implementation == "sbase": - # Get the indices of the tokens which should have been dropped by the model_neuron - expected_dropped_token_indices = get_expected_dropped_token_indices(exp_ind_neuron, cfg) - # Manually simulate the dropping of tokens in the input passed - ip = ut.drop_tokens_in_tensor(ip.clone().detach(), expected_dropped_token_indices) - - # Run forward pass on model_golden - if cfg.implementation == "sbase": - op_golden, _, _ = model_golden(ip) - elif cfg.implementation == "topk": - op_golden, _ = model_golden(ip) - else: - raise ValueError(f"Unknown implementation: {cfg.implementation}") + for it in range(cfg.num_iters): + # Generate random input tensor + ip = torch.randn(cfg.seq_len, cfg.batch_size, cfg.hidden_size, dtype=cfg.dtype, device=cfg.device) - # Compute MSE loss wrt which we get the gradients - targets = torch.zeros_like(ip, device=cfg.device, dtype=torch.float32) - loss_neuron = mse_loss(op_neuron.to(torch.float32), targets) - loss_golden = mse_loss(op_golden.to(torch.float32), targets) - ut.check_tensors(loss_neuron, loss_golden, **TEST_TOLS, additional_msg=f"Iteration {it}") - - # Run backward pass to compute gradients - loss_neuron.backward() - loss_golden.backward() - - # Compare gradients - grads_neuron = ut.get_model_grads_dict(model_neuron) - grads_golden = convert_golden_to_neuron_state_dict(ut.get_model_grads_dict(model_golden), cfg=cfg) - assert set(grads_neuron.keys()) == set(grads_golden.keys()) - for key in grads_neuron: - ut.check_tensors( - grads_neuron[key], grads_golden[key], **TEST_TOLS, 
additional_msg=f"Iteration: {it}, key: {key}" - ) + if cfg.dtype == torch.bfloat16: + # Simulate dropping of tokens in input where the expert assignments are not matching with neuron + assert cfg.implementation == "sbase" + with torch.no_grad(): + op_neuron, _, exp_ind_neuron = model_neuron(ip) + op_sbase, _, exp_ind_sbase = model_golden(ip) + expert_mismatch_indices = torch.where(exp_ind_neuron != exp_ind_sbase)[0].tolist() + ip = ut.drop_tokens_in_tensor(ip, expert_mismatch_indices) + + # Run forward pass on model_neuron + op_neuron, _, exp_ind_neuron = model_neuron(ip) - # Zero out gradients before next iteration - optimizer_neuron.zero_grad() - optimizer_golden.zero_grad() - except: - update_result({"inference_success": 0}) - raise + if cfg.implementation == "sbase": + # Get the indices of the tokens which should have been dropped by the model_neuron + expected_dropped_token_indices = get_expected_dropped_token_indices(exp_ind_neuron, cfg) + # Manually simulate the dropping of tokens in the input passed + ip = ut.drop_tokens_in_tensor(ip.clone().detach(), expected_dropped_token_indices) + + # Run forward pass on model_golden + if cfg.implementation == "sbase": + op_golden, _, _ = model_golden(ip) + elif cfg.implementation == "topk": + op_golden, _ = model_golden(ip) + else: + raise ValueError(f"Unknown implementation: {cfg.implementation}") + + # Compute MSE loss wrt which we get the gradients + targets = torch.zeros_like(ip, device=cfg.device, dtype=torch.float32) + loss_neuron = mse_loss(op_neuron.to(torch.float32), targets) + loss_golden = mse_loss(op_golden.to(torch.float32), targets) + ut.check_tensors(loss_neuron, loss_golden, **TEST_TOLS, additional_msg=f"Iteration {it}") + + # Run backward pass to compute gradients + loss_neuron.backward() + loss_golden.backward() + + # Compare gradients + grads_neuron = ut.get_model_grads_dict(model_neuron) + grads_golden = convert_golden_to_neuron_state_dict(ut.get_model_grads_dict(model_golden), cfg=cfg) + assert set(grads_neuron.keys()) == set(grads_golden.keys()) + for key in grads_neuron: + ut.check_tensors( + grads_neuron[key], grads_golden[key], **TEST_TOLS, additional_msg=f"Iteration: {it}, key: {key}" + ) + + # Zero out gradients before next iteration + optimizer_neuron.zero_grad() + optimizer_golden.zero_grad() @parameterized.expand( lch.get_loss_fn_correctness_test_configs(dtypes=[torch.bfloat16, torch.float32]), name_func=ut.custom_name_func ) def test_loss_fn_correctness(self, cfg): - try: - # Set random seed for reproducibility - torch.manual_seed(cfg.num_experts) - with torch.no_grad(): - for it in range(cfg.num_iters): - test_gate_logits = [ - torch.randn(cfg.batch_size * cfg.seq_len, cfg.num_experts, device=cfg.device, dtype=cfg.dtype) - for _ in range(cfg.num_layers) - ] - test_gate_logits = tuple(test_gate_logits) - hf_loss = lch.hf_load_balancing_loss_func(test_gate_logits, cfg.num_experts, cfg.top_k) - concatenated_test_gate_logits = torch.cat([layer_gate for layer_gate in test_gate_logits], dim=0) - neuron_loss = neuron_load_balancing_loss_func( - concatenated_test_gate_logits, cfg.num_experts, cfg.top_k - ) - assert neuron_loss.dtype == hf_loss.dtype - test_tols = lch.FP32_TEST_TOLS if cfg.dtype == torch.float32 else lch.BF16_TEST_TOLS - ut.check_tensors(neuron_loss, hf_loss, **test_tols, additional_msg=f"Iteration {it}") - except: - update_result({"inference_success": 0}) - raise + # Set random seed for reproducibility + torch.manual_seed(cfg.num_experts) + with torch.no_grad(): + for it in range(cfg.num_iters): + 
test_gate_logits = [ + torch.randn(cfg.batch_size * cfg.seq_len, cfg.num_experts, device=cfg.device, dtype=cfg.dtype) + for _ in range(cfg.num_layers) + ] + test_gate_logits = tuple(test_gate_logits) + hf_loss = lch.hf_load_balancing_loss_func(test_gate_logits, cfg.num_experts, cfg.top_k) + concatenated_test_gate_logits = torch.cat([layer_gate for layer_gate in test_gate_logits], dim=0) + neuron_loss = neuron_load_balancing_loss_func( + concatenated_test_gate_logits, cfg.num_experts, cfg.top_k + ) + assert neuron_loss.dtype == hf_loss.dtype + test_tols = lch.FP32_TEST_TOLS if cfg.dtype == torch.float32 else lch.BF16_TEST_TOLS + ut.check_tensors(neuron_loss, hf_loss, **test_tols, additional_msg=f"Iteration {it}") if __name__ == "__main__": diff --git a/test/unit_test/modules/moe/utils_testing.py b/test/unit_test/modules/moe/utils_testing.py index 66a7587..f59dbf8 100644 --- a/test/unit_test/modules/moe/utils_testing.py +++ b/test/unit_test/modules/moe/utils_testing.py @@ -1,20 +1,78 @@ import math import random from dataclasses import dataclass +import functools import torch +from torch.optim import Adam, SGD +from neuronx_distributed.optimizer import NeuronZero1Optimizer, NeuronEPZero1Optimizer from neuronx_distributed.modules.moe import ( ACT2FN, - ExpertMLPsCapacityFactor, + ExpertMLPs, MoE, - MoESequenceParallelMode, RouterSinkhorn, RouterTopK, ) +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers import random as nxd_random +from neuronx_distributed.utils.model_utils import move_model_to_device +from neuronx_distributed.trainer.optimizer import NxDOptimizer +import torch_xla.core.xla_model as xm ALL_ACTIVATIONS = sorted(list(ACT2FN.keys())) +STATE_KEYS = { + "_TENSOR_MODEL_PARALLEL_GROUP", + "_TENSOR_MODEL_PARALLEL_GROUP_SPMD", + "_PIPELINE_MODEL_PARALLEL_GROUP", + "_PIPELINE_GLOBAL_RANKS", + "_PIPELINE_MODEL_PARALLEL_GROUP_SPMD", + "_NEXT_RANK_GROUP_SPMD", + "_PREV_RANK_GROUP_SPMD", + "_NEXT_RANK_GROUP", + "_PREV_RANK_GROUP", + "_EXPERT_MODEL_PARALLEL_GROUP", + "_EXPERT_MODEL_PARALLEL_GROUP_SPMD", + "_EXP_DATA_PARALLEL_GROUP", + "_EXP_DATA_PARALLEL_GROUP_SPMD", + "_DATA_PARALLEL_GROUP", + "_DATA_PARALLEL_GROUP_SPMD", + "_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE", + "_MPU_TENSOR_MODEL_PARALLEL_RANK", +} + +PARALLEL_STATE_MAP = {} + +def nxd_init(tp_degree, ep_degree, seed): + + world_size = torch.distributed.get_world_size() + parallel_state_key = f"{world_size}_{tp_degree}_{ep_degree}" + + def _save_parallel_state(key): + state = {} + for attr in STATE_KEYS: + state[attr] = parallel_state.__dict__[attr] + PARALLEL_STATE_MAP[key] = state + + def _load_parallel_state(key): + for k, v in PARALLEL_STATE_MAP[key].items(): + parallel_state.__dict__[k] = v + + if parallel_state_key in PARALLEL_STATE_MAP: + _load_parallel_state(parallel_state_key) + else: + parallel_state.destroy_model_parallel() + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=tp_degree, + expert_model_parallel_size=ep_degree, + pipeline_model_parallel_size=1, + ) + _save_parallel_state(parallel_state_key) + + # Set seed + nxd_random.model_parallel_xla_manual_seed(seed) + @dataclass class ExptCfg: @@ -26,17 +84,16 @@ class ExptCfg: capacity_factor: float dtype: torch.dtype glu_mlp: bool - expert_mlps_permute_strategy: str # Either 'matmul' or 'index' + test_mode: str # Either 'training' or 'inference' implementation: str # "sbase" or "topk" intermediate_size: int = None hidden_act: str = "silu" # One of ACT2FN device: str = "cpu" top_k: int = 1 - - 
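# A minimal, self-contained sketch of how the new `test_mode` field above is used: config
# variants are produced with dataclasses.replace() and then filtered by mode, mirroring the
# `[cfg for cfg in test_configs if cfg.test_mode in test_modes]` filter in
# get_impl_correctness_test_configs(). `_Cfg` is a hypothetical stand-in for ExptCfg,
# trimmed to the fields needed here; all values are illustrative.
import dataclasses

@dataclasses.dataclass
class _Cfg:
    seq_len: int
    test_mode: str  # 'training' or 'inference'
    hidden_act: str = "silu"

_base = _Cfg(seq_len=128, test_mode="training")
_configs = [
    _base,
    dataclasses.replace(_base, hidden_act="gelu"),                 # activation variant
    dataclasses.replace(_base, test_mode="inference", seq_len=1),  # token-generation variant
]
# Keep only the requested modes before parameterizing the tests.
_training_only = [c for c in _configs if c.test_mode in ("training",)]
assert len(_training_only) == 2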
-@dataclass -class ExptCfgCorrectness(ExptCfg): + zero1: bool = False + sequence_parallel_enabled: bool = False num_iters: int = 10 + lr: float = 0.1 def get_random_activations(num, seed=None): @@ -45,28 +102,6 @@ def get_random_activations(num, seed=None): return random.sample(ALL_ACTIVATIONS, num) -def filter_valid_expt_configs(expt_configs): - valid_expt_configs = [] - - for cfg in expt_configs: - # OPTIMIZED_SP_MATMUL mode does not apply to the the 'index' permute strategy - sequence_parallel_mode = MoESequenceParallelMode[getattr(cfg, "sequence_parallel_mode", MoESequenceParallelMode.NO_SP)] - if ( - cfg.expert_mlps_permute_strategy == "index" - and sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL - ): - continue - - # OPTIMIZED_SP_MATMUL is not supported for inference due to SPMD restriction - test_mode = getattr(cfg, "test_mode", "training") - if test_mode == "inference" and sequence_parallel_mode == MoESequenceParallelMode.OPTIMIZED_SP_MATMUL: - continue - - valid_expt_configs.append(cfg) - - return valid_expt_configs - - class StackedModel(torch.nn.Module): def __init__(self, stack_size, cfg, return_router_logits): super().__init__() @@ -99,7 +134,8 @@ def custom_name_func(testcase_func, param_num, param): def check_tensors(t1, t2, atol, rtol, additional_msg=None): - msg = lambda s: "%s\n%s" % (s, additional_msg) if additional_msg is not None else s + def msg(s): + return "%s\n%s" % (s, additional_msg) if additional_msg is not None else s torch.testing.assert_close(t1, t2, atol=atol, rtol=rtol, msg=msg, check_device=False, check_dtype=True) @@ -119,9 +155,13 @@ def get_model_grads_dict(model): def get_expert_capacity(cfg): - expert_capacity = math.ceil(cfg.seq_len * cfg.batch_size * cfg.capacity_factor / cfg.num_experts) - expert_capacity = min(expert_capacity, cfg.seq_len * cfg.batch_size) - return expert_capacity + total_tokens = cfg.seq_len * cfg.batch_size + if cfg.capacity_factor is None: + return total_tokens + else: + expert_capacity = math.ceil(total_tokens * cfg.capacity_factor / cfg.num_experts) + expert_capacity = min(expert_capacity, total_tokens) + return expert_capacity def get_intermediate_size(cfg): @@ -134,6 +174,89 @@ def get_intermediate_size(cfg): return intermediate_size +def match_expert_weights(model_trn, model_cpu, glu_mlp): + """ + Copy expert weights from the CPU model to the TRN model. This is necessary + under expert parallelism because NxD weight initialization currently does not + take expert parallelism into account. 
+ """ + + module = model_cpu.expert_mlps.mlp_op.gate_up_proj if glu_mlp else model_cpu.expert_mlps.mlp_op.up_proj + num_experts = module.weight.shape[0] + ep_degree = parallel_state.get_expert_model_parallel_size() + ep_rank = parallel_state.get_expert_model_parallel_rank() + tp_degree = parallel_state.get_tensor_model_parallel_size() + tp_rank = parallel_state.get_tensor_model_parallel_rank() + num_local_experts = num_experts // ep_degree + expert_dim = 0 + input_dim = 1 + output_dim = 2 + + with torch.no_grad(): + for (cpu_name, cpu_param), (trn_name, trn_param) in zip(model_cpu.named_parameters(), model_trn.named_parameters()): + if "gate_up_proj" in cpu_name: + _, input_size, output_size = cpu_param.shape + stride = 2 if glu_mlp else 1 + local_output_size = output_size // tp_degree // stride + single_output_size = output_size // stride + weight_slice = cpu_param.narrow(expert_dim, num_local_experts * ep_rank, num_local_experts) + gate_weight_slice = weight_slice.narrow(output_dim, local_output_size * tp_rank, local_output_size) + up_weight_slice = weight_slice.narrow(output_dim, local_output_size * tp_rank + single_output_size, local_output_size) + gate_up_weight_slice = torch.cat((gate_weight_slice, up_weight_slice), dim=output_dim) + trn_param.copy_(gate_up_weight_slice.contiguous()) + elif "up_proj" in cpu_name: + _, input_size, output_size = cpu_param.shape + stride = 1 + local_output_size = output_size // tp_degree // stride + weight_slice = cpu_param.narrow(expert_dim, num_local_experts * ep_rank, num_local_experts) + up_weight_slice = weight_slice.narrow(output_dim, local_output_size * tp_rank, local_output_size) + trn_param.copy_(up_weight_slice.contiguous()) + elif "down_proj" in cpu_name: + _, input_size, output_size = cpu_param.shape + local_input_size = input_size // tp_degree + weight_slice = cpu_param.narrow(expert_dim, num_local_experts * ep_rank, num_local_experts) + weight_slice = weight_slice.narrow(input_dim, local_input_size * tp_rank, local_input_size) + trn_param.copy_(weight_slice.contiguous()) + + xm.mark_step() + +def initialize_neuron_optimizer(model, override_grad_reduction=True, sequence_parallel=False, grad_clipping=False, zero1=False, lr=0.0, optimizer=None): + optimizer_config = {"zero_one_enabled": zero1, "grad_clipping": grad_clipping} + + if not zero1: + optimizer_config["max_grad_norm"] = 1.0 + + # MoE parameters are not in sequence parallel + nxd_config = {"optimizer_config": optimizer_config, "sequence_parallel": sequence_parallel} + + def dummy_fetch_grads(self, *args, **kwargs): + return [], [] + + base_optimizer_cls = Adam if optimizer == "adam" else SGD + + if zero1: + ep_enabled = parallel_state.get_expert_model_parallel_size() > 1 + zero1_optimizer_cls = NeuronEPZero1Optimizer if ep_enabled else NeuronZero1Optimizer + + optimizer = zero1_optimizer_cls( + [p for p in model.parameters()], + base_optimizer_cls, + grad_clipping=optimizer_config["grad_clipping"], + pin_layout=False, + grad_norm_groups=parallel_state.get_tensor_model_parallel_group(as_list=True), + max_norm=1.0, + lr=lr, + ) + + else: + optimizer = base_optimizer_cls(model.parameters(), lr=lr) + + nxd_opt = NxDOptimizer(optimizer, nxd_config) + if override_grad_reduction: + nxd_opt._fetch_gradients = functools.partial(dummy_fetch_grads, self=nxd_opt) + return nxd_opt + + def initialize_neuron_model(cfg, seed=0): """ Create a Neuron model, as specified in the config. 
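# For reference, a minimal sketch of the sharding arithmetic used in match_expert_weights
# above: each rank copies a contiguous block of experts along dim 0 (expert parallelism)
# and a contiguous block of output features along dim 2 (tensor parallelism), via
# torch.Tensor.narrow(dim, start, length). All sizes and ranks below are illustrative
# assumptions.
import math
import torch

num_experts, ep_degree, tp_degree = 8, 2, 4
hidden, intermediate = 16, 32
full_up_proj = torch.randn(num_experts, hidden, intermediate)  # [E, H, I], non-GLU case

ep_rank, tp_rank = 1, 3                                         # example ranks
num_local_experts = num_experts // ep_degree                    # 4 experts per EP rank
local_out = intermediate // tp_degree                           # 8 output features per TP rank

local_shard = full_up_proj.narrow(0, num_local_experts * ep_rank, num_local_experts) \
                          .narrow(2, local_out * tp_rank, local_out)
assert local_shard.shape == (num_local_experts, hidden, local_out)

# Worked example for the updated get_expert_capacity() shown earlier: with
# capacity_factor=None every token is kept; otherwise capacity is ceil(T * cf / E), capped at T.
seq_len, batch_size, capacity_factor = 128, 4, 2.0
total_tokens = seq_len * batch_size                             # 512 tokens
assert min(math.ceil(total_tokens * capacity_factor / num_experts), total_tokens) == 128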
@@ -142,14 +265,11 @@ def initialize_neuron_model(cfg, seed=0): # Set random seed for reproducibility torch.manual_seed(seed) - sequence_parallel_mode = getattr(cfg, "sequence_parallel_mode", MoESequenceParallelMode.NO_SP) - # Initialize router router_args = dict( num_experts=cfg.num_experts, top_k=cfg.top_k, hidden_size=cfg.hidden_size, - sequence_parallel_mode=sequence_parallel_mode, dtype=cfg.dtype, device=torch.device("cpu"), ) @@ -171,8 +291,6 @@ def initialize_neuron_model(cfg, seed=0): init_method=torch.nn.init.kaiming_uniform_, output_layer_init_method=torch.nn.init.kaiming_uniform_, glu_mlp=cfg.glu_mlp, - sequence_parallel_mode=sequence_parallel_mode, - permute_strategy=cfg.expert_mlps_permute_strategy, dtype=cfg.dtype, device=torch.device("cpu"), ) @@ -182,10 +300,12 @@ def initialize_neuron_model(cfg, seed=0): expert_mlps_args.update({"normalize_top_k_affinities": False}) # Initialize ExpertMLPs - expert_mlps = ExpertMLPsCapacityFactor(**expert_mlps_args) + expert_mlps = ExpertMLPs(**expert_mlps_args) # Workaround for when testing top_k=1 with topk if cfg.implementation == "topk": expert_mlps.normalize_top_k_affinities = True + # Enable selective loading for unit tests + expert_mlps.SELECTIVE_LOADING_THRESHOLD = 1.0 # Initialize model neuron_model = MoE( @@ -193,10 +313,10 @@ def initialize_neuron_model(cfg, seed=0): expert_mlps=expert_mlps, # Always return router logits in testing return_router_logits=True, - sequence_parallel_mode=sequence_parallel_mode, + sequence_parallel_enabled=cfg.sequence_parallel_enabled, ) # Move model to required device - neuron_model = neuron_model.to(device=cfg.device) + move_model_to_device(neuron_model, cfg.device) return neuron_model diff --git a/test/unit_test/parallel_layers/__init__.py b/test/unit_test/parallel_layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit_test/parallel_layers/test_checkpointing.py b/test/unit_test/parallel_layers/test_checkpointing.py new file mode 100644 index 0000000..087e1d0 --- /dev/null +++ b/test/unit_test/parallel_layers/test_checkpointing.py @@ -0,0 +1,60 @@ +import unittest +from unittest.mock import patch + +from neuronx_distributed.parallel_layers.checkpointing import load + +MODULE = "neuronx_distributed.parallel_layers.checkpointing" + + +class CheckpointingTest(unittest.TestCase): + @patch(f"{MODULE}.get_tensor_model_parallel_size", return_value=1) + @patch(f"{MODULE}.get_pipeline_model_parallel_rank", return_value=0) + @patch(f"{MODULE}.get_tensor_model_parallel_rank", return_value=0) + @patch(f"{MODULE}.torch") + def test_load(self, mock_torch, mock_tp_rk, mock_pp_rk, mock_tp_sz): + # Act + res = load(path := "some_path") + # Assert + mock_torch.load.assert_called_once_with( + path + "/tp_rank_00_pp_rank_00/checkpoint.pt", map_location="cpu", weights_only=False + ) + assert res == mock_torch.load.return_value + + @patch(f"{MODULE}.get_tensor_model_parallel_size", return_value=1) + @patch(f"{MODULE}.get_pipeline_model_parallel_rank", return_value=0) + @patch(f"{MODULE}.get_tensor_model_parallel_rank", return_value=0) + @patch(f"{MODULE}.torch") + def test_load_xser(self, mock_torch, mock_tp_rk, mock_pp_rk, mock_tp_sz): + # Act + res = load(path := "some_path", load_xser=True) + # Assert + mock_torch.load.assert_called_once_with(path + "/tp_rank_00_pp_rank_00", weights_only=False) + assert res == mock_torch.load.return_value + + @patch(f"{MODULE}.get_tensor_model_parallel_size", return_value=1) + @patch(f"{MODULE}.get_pipeline_model_parallel_rank", return_value=0) + 
@patch(f"{MODULE}.get_tensor_model_parallel_rank", return_value=0) + @patch(f"{MODULE}.torch") + def test_load_weights_only(self, mock_torch, mock_tp_rk, mock_pp_rk, mock_tp_sz): + # Act + res = load(path := "some_path", weights_only=True) + # Assert + mock_torch.load.assert_called_once_with( + path + "/tp_rank_00_pp_rank_00/checkpoint.pt", map_location="cpu", weights_only=True + ) + assert res == mock_torch.load.return_value + + @patch(f"{MODULE}.get_tensor_model_parallel_size", return_value=1) + @patch(f"{MODULE}.get_pipeline_model_parallel_rank", return_value=0) + @patch(f"{MODULE}.get_tensor_model_parallel_rank", return_value=0) + @patch(f"{MODULE}.torch") + def test_load_xser_weights_only(self, mock_torch, mock_tp_rk, mock_pp_rk, mock_tp_sz): + # Act + res = load(path := "some_path", load_xser=True, weights_only=True) + # Assert + mock_torch.load.assert_called_once_with(path + "/tp_rank_00_pp_rank_00", weights_only=True) + assert res == mock_torch.load.return_value + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/pipeline/test_auto_partition.py b/test/unit_test/pipeline/test_auto_partition.py index c8573d5..6d285da 100644 --- a/test/unit_test/pipeline/test_auto_partition.py +++ b/test/unit_test/pipeline/test_auto_partition.py @@ -13,8 +13,6 @@ from neuronx_distributed.pipeline.model import NxDPPModel from neuronx_distributed.pipeline.partition import create_partitions -from .. import update_result - class NxDModule(torch.nn.Module): def __init__(self, num_layers): @@ -52,23 +50,19 @@ class TestAutoPartition(unittest.TestCase): @patch("neuronx_distributed.pipeline.model.parallel_state") @patch("torch.distributed.get_rank") def test_model_autopartition(self, rank_mock, state_mock): - try: - num_layers = 40 - model = get_model_nxd(num_layers) - transformer_layer_cls = torch.nn.Linear + num_layers = 40 + model = get_model_nxd(num_layers) + transformer_layer_cls = torch.nn.Linear - pipeline_parallel_size = 4 - model_layers = model.get_model_layers(model.original_torch_module, transformer_layer_cls) - partitions = create_partitions(pipeline_parallel_size, model_layers) + pipeline_parallel_size = 4 + model_layers = model.get_model_layers(model.original_torch_module, transformer_layer_cls) + partitions = create_partitions(pipeline_parallel_size, model_layers) - expected_model_layers = [f"layers.{x}" for x in range(num_layers)] - expected_partitions = [f"layers.{x}" for x in range(9, num_layers, 10)] - expected_partitions = expected_partitions[:-1] - assert model_layers == expected_model_layers - assert partitions == expected_partitions - except: - update_result({"inference_success": 0}) - raise + expected_model_layers = [f"layers.{x}" for x in range(num_layers)] + expected_partitions = [f"layers.{x}" for x in range(9, num_layers, 10)] + expected_partitions = expected_partitions[:-1] + assert model_layers == expected_model_layers + assert partitions == expected_partitions @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=1)) @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @@ -86,21 +80,17 @@ def test_model_autopartition(self, rank_mock, state_mock): @patch("neuronx_distributed.pipeline.model.parallel_state") @patch("torch.distributed.get_rank") def test_model_autopartition_unevenly_divisible_layers(self, rank_mock, state_mock): - try: - num_layers = 19 - model = get_model_nxd(num_layers) - transformer_layer_cls = torch.nn.Linear - - 
pipeline_parallel_size = 4 - model_layers = model.get_model_layers(model.original_torch_module, transformer_layer_cls) - partitions = create_partitions(pipeline_parallel_size, model_layers) - expected_model_layers = [f"layers.{x}" for x in range(num_layers)] - expected_partitions = ["layers.3", "layers.8", "layers.13"] - assert model_layers == expected_model_layers - assert partitions == expected_partitions - except: - update_result({"inference_success": 0}) - raise + num_layers = 19 + model = get_model_nxd(num_layers) + transformer_layer_cls = torch.nn.Linear + + pipeline_parallel_size = 4 + model_layers = model.get_model_layers(model.original_torch_module, transformer_layer_cls) + partitions = create_partitions(pipeline_parallel_size, model_layers) + expected_model_layers = [f"layers.{x}" for x in range(num_layers)] + expected_partitions = ["layers.3", "layers.8", "layers.13"] + assert model_layers == expected_model_layers + assert partitions == expected_partitions if __name__ == "__main__": diff --git a/test/unit_test/pipeline/test_get_delayed_tracing.py b/test/unit_test/pipeline/test_get_delayed_tracing.py new file mode 100644 index 0000000..9d0ca4f --- /dev/null +++ b/test/unit_test/pipeline/test_get_delayed_tracing.py @@ -0,0 +1,112 @@ +import unittest +from unittest.mock import MagicMock + +from neuronx_distributed.utils.model_utils import ( + get_delay_tracing, + check_delay_tracing +) +from neuronx_distributed.pipeline.model import NxDPPModel + +class TestGetDelayTracing(unittest.TestCase): + @unittest.skip("disabled delayed tracing") + def test_nxdppmodel_with_delay_tracing(self): + + mock_model = MagicMock(spec=NxDPPModel) + mock_model._delay_tracing = True + + result = get_delay_tracing(mock_model) + self.assertTrue(result) + + @unittest.skip("disabled delayed tracing") + def test_nxdppmodel_without_delay_tracing(self): + + mock_model = MagicMock(spec=NxDPPModel) + mock_model._delay_tracing = False + + result = get_delay_tracing(mock_model) + self.assertEqual(result, False) + + @unittest.skip("disabled delayed tracing") + def test_dict_with_pipeline_config(self): + arg = { + "pipeline_config": { + "_delay_tracing": True + } + } + + result = get_delay_tracing(arg) + self.assertTrue(result) + + @unittest.skip("disabled delayed tracing") + def test_dict_without_pipeline_config(self): + arg = { + "pipeline_config": { + "other_config": True + } + } + + result = get_delay_tracing(arg) + self.assertEqual(result, None) + + @unittest.skip("disabled delayed tracing") + def test_dict_without_delay_tracing(self): + arg = { + "pipeline_config": {} + } + + result = get_delay_tracing(arg) + self.assertEqual(result, None) + + @unittest.skip("disabled delayed tracing") + def test_non_nxdppmodel_and_non_dict(self): + result = get_delay_tracing("some string") + self.assertEqual(result, None) + + +class TestCheckDelayTracing(unittest.TestCase): + @unittest.skip("disabled delayed tracing") + def test_pipeline_config_with_use_model_wrapper_and_no_input_names(self): + nxd_config = { + "pipeline_config": { + "use_model_wrapper": True + } + } + + result = check_delay_tracing(nxd_config) + self.assertTrue(result) + + @unittest.skip("disabled delayed tracing") + def test_pipeline_config_with_use_model_wrapper_and_input_names(self): + nxd_config = { + "pipeline_config": { + "use_model_wrapper": True, + "input_names": ["x"] + } + } + + result = check_delay_tracing(nxd_config) + self.assertFalse(result) + + @unittest.skip("disabled delayed tracing") + def 
test_pipeline_config_without_use_model_wrapper(self): + nxd_config = { + "pipeline_config": { + "use_model_wrapper": False + } + } + + result = check_delay_tracing(nxd_config) + self.assertFalse(result) + + @unittest.skip("disabled delayed tracing") + def test_pipeline_config_missing_use_model_wrapper(self): + nxd_config = { + "pipeline_config": {} + } + + result = check_delay_tracing(nxd_config) + self.assertFalse(result) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/pipeline/test_input_signature.py b/test/unit_test/pipeline/test_input_signature.py new file mode 100644 index 0000000..97996e9 --- /dev/null +++ b/test/unit_test/pipeline/test_input_signature.py @@ -0,0 +1,59 @@ +import unittest +from typing import List, Optional +from unittest.mock import MagicMock, patch + +import torch + +from neuronx_distributed.pipeline.trace import get_concrete_args +from neuronx_distributed.pipeline.model import NxDPPModel + + +class NxDModule(torch.nn.Module): + def __init__(self, num_layers): + super().__init__() + self.layers = torch.nn.ModuleList([torch.nn.Linear(2, 2) for _ in range(num_layers)]) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + return None + + +def get_model_nxd(num_layers): + model = NxDPPModel(module=NxDModule(num_layers), transformer_layer_cls=torch.nn.Linear, tracer_cls="torch") + return model + + +def run_concrete_args(*args, **kwargs): + model = get_model_nxd(4) + return get_concrete_args(model.original_torch_module, None, args, kwargs) + + +class TestSignatureAnalysis(unittest.TestCase): + @patch("neuronx_distributed.pipeline.model.parallel_state") + @patch("torch.distributed.get_rank") + def test_signature_analyse(self, rank_mock, state_mock): + concrete_args = run_concrete_args(1, 2, "a", use_cache=True, output_hidden_states=True, return_dict=True) + expected_concrete_args = [ + "past_key_values", + "inputs_embeds", + "labels", + "output_attentions", + "cache_position", + ] + assert list(concrete_args.keys()) == expected_concrete_args + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/pipeline/test_partition.py b/test/unit_test/pipeline/test_partition.py index 8b1bdf3..932e3ee 100644 --- a/test/unit_test/pipeline/test_partition.py +++ b/test/unit_test/pipeline/test_partition.py @@ -17,7 +17,6 @@ from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy from neuronx_distributed.pipeline.model import NxDPPModel -from .. 
import update_result from .test_base import get_traced_model_gpt @@ -63,29 +62,25 @@ class TestPartition(unittest.TestCase): @patch("neuronx_distributed.pipeline.model.parallel_state") @patch("torch.distributed.get_rank") def test_partition_traced_model(self, rank_mock, state_mock): - try: - traced_model = get_traced_model_nxd() - split_mod = partition.partition_traced_model(traced_model) - for name, module in split_mod.named_children(): - if name == "submod_0": - for n, child_module in module.named_children(): - if n == "rpl": - assert isinstance(child_module, RowParallelLinear) - elif n == "cpl": - assert isinstance(child_module, ColumnParallelLinear) - else: - assert False, "Unexpected node in submoule 0" - elif name == "submod_1": - for n, child_module in module.named_children(): - if n == "linear4": - assert isinstance(child_module, torch.nn.Linear) - else: - assert False, "Unexpected node in submoule 1" - else: - assert False, "Unexpected number of submodule" - except: - update_result({"inference_success": 0}) - raise + traced_model = get_traced_model_nxd() + split_mod = partition.partition_traced_model(traced_model) + for name, module in split_mod.named_children(): + if name == "submod_0": + for n, child_module in module.named_children(): + if n == "rpl": + assert isinstance(child_module, RowParallelLinear) + elif n == "cpl": + assert isinstance(child_module, ColumnParallelLinear) + else: + assert False, "Unexpected node in submoule 0" + elif name == "submod_1": + for n, child_module in module.named_children(): + if n == "linear4": + assert isinstance(child_module, torch.nn.Linear) + else: + assert False, "Unexpected node in submoule 1" + else: + assert False, "Unexpected number of submodule" @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=1)) @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @@ -103,27 +98,23 @@ def test_partition_traced_model(self, rank_mock, state_mock): @patch("neuronx_distributed.pipeline.model.parallel_state") @patch("torch.distributed.get_rank") def test_partition_traced_model_gpt2(self, rank_mock, state_mock): - try: - traced_model = get_traced_model_gpt() - split_mod = partition.partition_traced_model(traced_model) - partition_count = 0 - for name, module in split_mod.named_children(): - partition_count += 1 - if name != "submod_7": - for n, child_module in module.named_children(): - if "transformer_h_" in n: - assert isinstance(child_module, GPT2Block) - num_params = sum([np.prod(p.size()) for p in module.parameters()]) - if partition_count == 1: - assert num_params == 53166336 - elif partition_count == 8: - assert num_params == 38598912 - else: - assert num_params == 7087872 - assert partition_count == 8 - except: - update_result({"inference_success": 0}) - raise + traced_model = get_traced_model_gpt() + split_mod = partition.partition_traced_model(traced_model) + partition_count = 0 + for name, module in split_mod.named_children(): + partition_count += 1 + if name != "submod_7": + for n, child_module in module.named_children(): + if "transformer_h_" in n: + assert isinstance(child_module, GPT2Block) + num_params = sum([np.prod(p.size()) for p in module.parameters()]) + if partition_count == 1: + assert num_params == 53166336 + elif partition_count == 8: + assert num_params == 38598912 + else: + assert num_params == 7087872 + assert partition_count == 8 @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", 
MagicMock(return_value=1)) @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @@ -141,70 +132,66 @@ def test_partition_traced_model_gpt2(self, rank_mock, state_mock): @patch("neuronx_distributed.pipeline.model.parallel_state") @patch("torch.distributed.get_rank") def test_analyze_pipeline_module(self, rank_mock, state_mock): - try: - traced_model = get_traced_model_gpt() - split_mod = partition.partition_traced_model(traced_model) - ( - stage_id_to_IO_input_names, - stage_id_to_model_input_names, - stage_id_to_input_count, - stage_id_to_output_count, - ) = partition.analyze_pipeline_module(split_mod) - expected_io_names = {0: OrderedDict()} - # Note: Renamed "add_2" to "add_3" when upgrading transformers from 4.28.1 to 4.36.2 - for i in range(7): - if i == 0: - io_dict = OrderedDict( - [ - ( - "transformer_h_" + str(i + 1), - partition.PipelineIO("transformer_h_" + str(i + 1), input_idx=0, output_idx=0), - ), - ("mul", partition.PipelineIO("mul", input_idx=1, output_idx=1)), - ("add_3", partition.PipelineIO("add_3", output_idx=2)), - ] - ) - elif i < 6: - io_dict = OrderedDict( - [ - ( - "transformer_h_" + str(i + 1), - partition.PipelineIO("transformer_h_" + str(i + 1), input_idx=0, output_idx=0), - ), - ("mul", partition.PipelineIO("mul", input_idx=1)), - ("add_3", partition.PipelineIO("add_3")), - ] - ) - else: - io_dict = OrderedDict( - [ - ( - "transformer_h_" + str(i + 1), - partition.PipelineIO("transformer_h_" + str(i + 1), input_idx=0, output_idx=0), - ), - ("add_3", partition.PipelineIO("add_3", input_idx=1)), - ] - ) - expected_io_names.update({i + 1: io_dict}) - expected_stage_id_to_model_input_names = { - 0: {"input_ids": 0, "attention_mask": 1}, - 1: {}, - 2: {}, - 3: {}, - 4: {}, - 5: {}, - 6: {}, - 7: {"labels": 2}, - } - expected_stage_id_to_input_count = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 3} - expected_stage_id_to_output_count = {0: 3, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 2} - assert str(stage_id_to_IO_input_names) == str(expected_io_names) - assert stage_id_to_model_input_names == expected_stage_id_to_model_input_names - assert stage_id_to_input_count == expected_stage_id_to_input_count - assert stage_id_to_output_count == expected_stage_id_to_output_count - except: - update_result({"inference_success": 0}) - raise + traced_model = get_traced_model_gpt() + split_mod = partition.partition_traced_model(traced_model) + ( + stage_id_to_IO_input_names, + stage_id_to_model_input_names, + stage_id_to_input_count, + stage_id_to_output_count, + ) = partition.analyze_pipeline_module(split_mod) + expected_io_names = {0: OrderedDict()} + # Note: Renamed "add_2" to "add_3" when upgrading transformers from 4.28.1 to 4.36.2 + for i in range(7): + if i == 0: + io_dict = OrderedDict( + [ + ( + "transformer_h_" + str(i + 1), + partition.PipelineIO("transformer_h_" + str(i + 1), input_idx=0, output_idx=0), + ), + ("mul", partition.PipelineIO("mul", input_idx=1, output_idx=1)), + ("add_3", partition.PipelineIO("add_3", output_idx=2)), + ] + ) + elif i < 6: + io_dict = OrderedDict( + [ + ( + "transformer_h_" + str(i + 1), + partition.PipelineIO("transformer_h_" + str(i + 1), input_idx=0, output_idx=0), + ), + ("mul", partition.PipelineIO("mul", input_idx=1)), + ("add_3", partition.PipelineIO("add_3")), + ] + ) + else: + io_dict = OrderedDict( + [ + ( + "transformer_h_" + str(i + 1), + partition.PipelineIO("transformer_h_" + str(i + 1), input_idx=0, output_idx=0), + ), + ("add_3", partition.PipelineIO("add_3", 
input_idx=1)), + ] + ) + expected_io_names.update({i + 1: io_dict}) + expected_stage_id_to_model_input_names = { + 0: {"input_ids": 0, "attention_mask": 1}, + 1: {}, + 2: {}, + 3: {}, + 4: {}, + 5: {}, + 6: {}, + 7: {"labels": 2}, + } + expected_stage_id_to_input_count = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 3} + expected_stage_id_to_output_count = {0: 3, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 2} + assert str(stage_id_to_IO_input_names) == str(expected_io_names) + assert stage_id_to_model_input_names == expected_stage_id_to_model_input_names + assert stage_id_to_input_count == expected_stage_id_to_input_count + assert stage_id_to_output_count == expected_stage_id_to_output_count if __name__ == "__main__": diff --git a/test/unit_test/pipeline/test_post_partition_hooks.py b/test/unit_test/pipeline/test_post_partition_hooks.py new file mode 100644 index 0000000..6b725e0 --- /dev/null +++ b/test/unit_test/pipeline/test_post_partition_hooks.py @@ -0,0 +1,34 @@ +import unittest + +from neuronx_distributed.trainer import PostPartitionHooks + + +def function(a, b): + a = a * b + return a + + +class TestPostPartitionHooks(unittest.TestCase): + def test_execute_post_partition_hook(self): + hooks = PostPartitionHooks() + inputs = [5,6] + hooks.register_post_partition_hook(function, inputs) + output = hooks.execute_all_hooks() + expected_result = 30 + assert output[0] == expected_result + + def test_register_post_partition_hooks(self): + hooks = PostPartitionHooks() + inputs = [5,6] + hooks.register_post_partition_hook(function, inputs) + hooks.register_post_partition_hook(function, inputs) + + expected_len_hooks = 2 + assert len(hooks.hooks) == expected_len_hooks + # Execute hooks and validate that hooks get cleared + hooks.execute_all_hooks() + assert len(hooks.hooks) == 0 + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/pipeline/test_scheduler.py b/test/unit_test/pipeline/test_scheduler.py index bc363cd..04fdcd9 100644 --- a/test/unit_test/pipeline/test_scheduler.py +++ b/test/unit_test/pipeline/test_scheduler.py @@ -16,8 +16,6 @@ TrainSchedule, ) -from .. 
import update_result - class TestScheduler(unittest.TestCase): @patch("torch.distributed.get_rank") @@ -25,272 +23,232 @@ def test_train_schedule_1F1B(self, rank_mock): """ Test the new Train1F1BSchedule has the same schedule as the old TrainSchedule """ - try: - for pp_size in [2, 4, 8, 16]: - for pp_rank in range(pp_size): - for mb in [1, 4, 8, 32]: - schedule = TrainSchedule(mb, pp_size, pp_rank) - all_steps = list(schedule.steps()) - all_steps_flat = [] - for item in all_steps: - all_steps_flat.extend(item) - new_schedule = Train1F1BSchedule(mb, pp_size, pp_rank) - all_steps_new = list(new_schedule.steps()) - all_steps_new_flat = [] - for item in all_steps_new: - all_steps_new_flat.extend(item) - - def _is_same(lst1, lst2): - if len(lst1) != len(lst2): + for pp_size in [2, 4, 8, 16]: + for pp_rank in range(pp_size): + for mb in [1, 4, 8, 32]: + schedule = TrainSchedule(mb, pp_size, pp_rank) + all_steps = list(schedule.steps()) + all_steps_flat = [] + for item in all_steps: + all_steps_flat.extend(item) + new_schedule = Train1F1BSchedule(mb, pp_size, pp_rank) + all_steps_new = list(new_schedule.steps()) + all_steps_new_flat = [] + for item in all_steps_new: + all_steps_new_flat.extend(item) + + def _is_same(lst1, lst2): + if len(lst1) != len(lst2): + return False + for i in range(len(lst1)): + if lst1[i] != lst2[i]: return False - for i in range(len(lst1)): - if lst1[i] != lst2[i]: - return False - return True + return True - assert _is_same(all_steps_new_flat, all_steps_flat) - except: - update_result({"inference_success": 0}) - raise + assert _is_same(all_steps_new_flat, all_steps_flat) @patch("torch.distributed.get_rank") def test_train_schedule_mb2_stage0(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=0) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [ForwardPostprocessTask(mb=0)], - [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [ForwardPostprocessTask(mb=1)], - [], - [], - [], - [BackwardPreprocessTask(mb=0), BackwardStepTask(mb=0)], - [], - [BackwardPreprocessTask(mb=1), BackwardStepTask(mb=1), ReduceGradsTask()], - ] - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=0) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [ForwardPostprocessTask(mb=0)], + [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [ForwardPostprocessTask(mb=1)], + [], + [], + [], + [BackwardPreprocessTask(mb=0), BackwardStepTask(mb=0)], + [], + [BackwardPreprocessTask(mb=1), BackwardStepTask(mb=1), ReduceGradsTask()], + ] + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_train_schedule_mb2_stage1(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=1) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [], - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [ForwardPostprocessTask(mb=0)], - [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [ForwardPostprocessTask(mb=1)], - [], - [BackwardPreprocessTask(mb=0), BackwardStepTask(mb=0)], - [BackwardPostprocessTask(mb=0)], - [BackwardPreprocessTask(mb=1), BackwardStepTask(mb=1)], - [BackwardPostprocessTask(mb=1), ReduceGradsTask()], - ] - assert str(tasks) == str(expected_tasks) - except: - 
update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=1) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [], + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [ForwardPostprocessTask(mb=0)], + [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [ForwardPostprocessTask(mb=1)], + [], + [BackwardPreprocessTask(mb=0), BackwardStepTask(mb=0)], + [BackwardPostprocessTask(mb=0)], + [BackwardPreprocessTask(mb=1), BackwardStepTask(mb=1)], + [BackwardPostprocessTask(mb=1), ReduceGradsTask()], + ] + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_train_schedule_mb2_stage2(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=2) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [], - [], - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [ForwardPostprocessTask(mb=0)], - [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=1), BackwardStepTask(mb=0)], - [BackwardPostprocessTask(mb=0)], - [BackwardPreprocessTask(mb=1), BackwardStepTask(mb=1)], - [BackwardPostprocessTask(mb=1)], - [ReduceGradsTask()], - ] - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=2) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [], + [], + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [ForwardPostprocessTask(mb=0)], + [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=1), BackwardStepTask(mb=0)], + [BackwardPostprocessTask(mb=0)], + [BackwardPreprocessTask(mb=1), BackwardStepTask(mb=1)], + [BackwardPostprocessTask(mb=1)], + [ReduceGradsTask()], + ] + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_train_schedule_mb2_stage3(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=3) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [], - [], - [], - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [BackwardStepTask(mb=0)], - [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [BackwardStepTask(mb=1)], - [BackwardPostprocessTask(mb=1)], - [], - [ReduceGradsTask()], - ] - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=2, stages=4, stage_id=3) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [], + [], + [], + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [BackwardStepTask(mb=0)], + [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [BackwardStepTask(mb=1)], + [BackwardPostprocessTask(mb=1)], + [], + [ReduceGradsTask()], + ] + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_train_schedule_mb8_stage0(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=0) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [ForwardPostprocessTask(mb=0)], - [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [ForwardPostprocessTask(mb=1)], - 
[ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], - [ForwardPostprocessTask(mb=2)], - [ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], - [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=3), BackwardStepTask(mb=0)], - [ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], - [BackwardPreprocessTask(mb=1), ForwardPostprocessTask(mb=4), BackwardStepTask(mb=1)], - [ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], - [BackwardPreprocessTask(mb=2), ForwardPostprocessTask(mb=5), BackwardStepTask(mb=2)], - [ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], - [BackwardPreprocessTask(mb=3), ForwardPostprocessTask(mb=6), BackwardStepTask(mb=3)], - [ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], - [BackwardPreprocessTask(mb=4), ForwardPostprocessTask(mb=7), BackwardStepTask(mb=4)], - [], - [BackwardPreprocessTask(mb=5), BackwardStepTask(mb=5)], - [], - [BackwardPreprocessTask(mb=6), BackwardStepTask(mb=6)], - [], - [BackwardPreprocessTask(mb=7), BackwardStepTask(mb=7), ReduceGradsTask()], - ] - - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=0) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [ForwardPostprocessTask(mb=0)], + [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [ForwardPostprocessTask(mb=1)], + [ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], + [ForwardPostprocessTask(mb=2)], + [ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], + [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=3), BackwardStepTask(mb=0)], + [ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], + [BackwardPreprocessTask(mb=1), ForwardPostprocessTask(mb=4), BackwardStepTask(mb=1)], + [ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], + [BackwardPreprocessTask(mb=2), ForwardPostprocessTask(mb=5), BackwardStepTask(mb=2)], + [ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], + [BackwardPreprocessTask(mb=3), ForwardPostprocessTask(mb=6), BackwardStepTask(mb=3)], + [ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], + [BackwardPreprocessTask(mb=4), ForwardPostprocessTask(mb=7), BackwardStepTask(mb=4)], + [], + [BackwardPreprocessTask(mb=5), BackwardStepTask(mb=5)], + [], + [BackwardPreprocessTask(mb=6), BackwardStepTask(mb=6)], + [], + [BackwardPreprocessTask(mb=7), BackwardStepTask(mb=7), ReduceGradsTask()], + ] + + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_train_schedule_mb8_stage1(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=1) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [], - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [ForwardPostprocessTask(mb=0)], - [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [ForwardPostprocessTask(mb=1)], - [ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], - [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=2), BackwardStepTask(mb=0)], - [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], - [BackwardPreprocessTask(mb=1), ForwardPostprocessTask(mb=3), BackwardStepTask(mb=1)], - [BackwardPostprocessTask(mb=1), ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], - [BackwardPreprocessTask(mb=2), ForwardPostprocessTask(mb=4), BackwardStepTask(mb=2)], - [BackwardPostprocessTask(mb=2), ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], - 
[BackwardPreprocessTask(mb=3), ForwardPostprocessTask(mb=5), BackwardStepTask(mb=3)], - [BackwardPostprocessTask(mb=3), ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], - [BackwardPreprocessTask(mb=4), ForwardPostprocessTask(mb=6), BackwardStepTask(mb=4)], - [BackwardPostprocessTask(mb=4), ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], - [BackwardPreprocessTask(mb=5), ForwardPostprocessTask(mb=7), BackwardStepTask(mb=5)], - [BackwardPostprocessTask(mb=5)], - [BackwardPreprocessTask(mb=6), BackwardStepTask(mb=6)], - [BackwardPostprocessTask(mb=6)], - [BackwardPreprocessTask(mb=7), BackwardStepTask(mb=7)], - [BackwardPostprocessTask(mb=7), ReduceGradsTask()], - ] - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=1) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [], + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [ForwardPostprocessTask(mb=0)], + [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [ForwardPostprocessTask(mb=1)], + [ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], + [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=2), BackwardStepTask(mb=0)], + [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], + [BackwardPreprocessTask(mb=1), ForwardPostprocessTask(mb=3), BackwardStepTask(mb=1)], + [BackwardPostprocessTask(mb=1), ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], + [BackwardPreprocessTask(mb=2), ForwardPostprocessTask(mb=4), BackwardStepTask(mb=2)], + [BackwardPostprocessTask(mb=2), ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], + [BackwardPreprocessTask(mb=3), ForwardPostprocessTask(mb=5), BackwardStepTask(mb=3)], + [BackwardPostprocessTask(mb=3), ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], + [BackwardPreprocessTask(mb=4), ForwardPostprocessTask(mb=6), BackwardStepTask(mb=4)], + [BackwardPostprocessTask(mb=4), ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], + [BackwardPreprocessTask(mb=5), ForwardPostprocessTask(mb=7), BackwardStepTask(mb=5)], + [BackwardPostprocessTask(mb=5)], + [BackwardPreprocessTask(mb=6), BackwardStepTask(mb=6)], + [BackwardPostprocessTask(mb=6)], + [BackwardPreprocessTask(mb=7), BackwardStepTask(mb=7)], + [BackwardPostprocessTask(mb=7), ReduceGradsTask()], + ] + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_train_schedule_mb8_stage2(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=2) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [], - [], - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [ForwardPostprocessTask(mb=0)], - [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=1), BackwardStepTask(mb=0)], - [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], - [BackwardPreprocessTask(mb=1), ForwardPostprocessTask(mb=2), BackwardStepTask(mb=1)], - [BackwardPostprocessTask(mb=1), ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], - [BackwardPreprocessTask(mb=2), ForwardPostprocessTask(mb=3), BackwardStepTask(mb=2)], - [BackwardPostprocessTask(mb=2), ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], - [BackwardPreprocessTask(mb=3), ForwardPostprocessTask(mb=4), BackwardStepTask(mb=3)], - [BackwardPostprocessTask(mb=3), ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], - 
[BackwardPreprocessTask(mb=4), ForwardPostprocessTask(mb=5), BackwardStepTask(mb=4)], - [BackwardPostprocessTask(mb=4), ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], - [BackwardPreprocessTask(mb=5), ForwardPostprocessTask(mb=6), BackwardStepTask(mb=5)], - [BackwardPostprocessTask(mb=5), ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], - [BackwardPreprocessTask(mb=6), ForwardPostprocessTask(mb=7), BackwardStepTask(mb=6)], - [BackwardPostprocessTask(mb=6)], - [BackwardPreprocessTask(mb=7), BackwardStepTask(mb=7)], - [BackwardPostprocessTask(mb=7)], - [ReduceGradsTask()], - ] - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=2) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [], + [], + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [ForwardPostprocessTask(mb=0)], + [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [BackwardPreprocessTask(mb=0), ForwardPostprocessTask(mb=1), BackwardStepTask(mb=0)], + [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], + [BackwardPreprocessTask(mb=1), ForwardPostprocessTask(mb=2), BackwardStepTask(mb=1)], + [BackwardPostprocessTask(mb=1), ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], + [BackwardPreprocessTask(mb=2), ForwardPostprocessTask(mb=3), BackwardStepTask(mb=2)], + [BackwardPostprocessTask(mb=2), ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], + [BackwardPreprocessTask(mb=3), ForwardPostprocessTask(mb=4), BackwardStepTask(mb=3)], + [BackwardPostprocessTask(mb=3), ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], + [BackwardPreprocessTask(mb=4), ForwardPostprocessTask(mb=5), BackwardStepTask(mb=4)], + [BackwardPostprocessTask(mb=4), ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], + [BackwardPreprocessTask(mb=5), ForwardPostprocessTask(mb=6), BackwardStepTask(mb=5)], + [BackwardPostprocessTask(mb=5), ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], + [BackwardPreprocessTask(mb=6), ForwardPostprocessTask(mb=7), BackwardStepTask(mb=6)], + [BackwardPostprocessTask(mb=6)], + [BackwardPreprocessTask(mb=7), BackwardStepTask(mb=7)], + [BackwardPostprocessTask(mb=7)], + [ReduceGradsTask()], + ] + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_train_schedule_mb8_stage3(self, rank_mock): - try: - train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=3) - tasks = [task for task in train_scheduler.steps()] - expected_tasks = [ - [], - [], - [], - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], - [BackwardStepTask(mb=0)], - [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], - [BackwardStepTask(mb=1)], - [BackwardPostprocessTask(mb=1), ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], - [BackwardStepTask(mb=2)], - [BackwardPostprocessTask(mb=2), ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], - [BackwardStepTask(mb=3)], - [BackwardPostprocessTask(mb=3), ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], - [BackwardStepTask(mb=4)], - [BackwardPostprocessTask(mb=4), ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], - [BackwardStepTask(mb=5)], - [BackwardPostprocessTask(mb=5), ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], - [BackwardStepTask(mb=6)], - [BackwardPostprocessTask(mb=6), ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], - [BackwardStepTask(mb=7)], - [BackwardPostprocessTask(mb=7)], - [], - 
[ReduceGradsTask()], - ] - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + train_scheduler = TrainSchedule(num_microbatches=8, stages=4, stage_id=3) + tasks = [task for task in train_scheduler.steps()] + expected_tasks = [ + [], + [], + [], + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0)], + [BackwardStepTask(mb=0)], + [BackwardPostprocessTask(mb=0), ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1)], + [BackwardStepTask(mb=1)], + [BackwardPostprocessTask(mb=1), ForwardPreprocessTask(mb=2), ForwardStepTask(mb=2)], + [BackwardStepTask(mb=2)], + [BackwardPostprocessTask(mb=2), ForwardPreprocessTask(mb=3), ForwardStepTask(mb=3)], + [BackwardStepTask(mb=3)], + [BackwardPostprocessTask(mb=3), ForwardPreprocessTask(mb=4), ForwardStepTask(mb=4)], + [BackwardStepTask(mb=4)], + [BackwardPostprocessTask(mb=4), ForwardPreprocessTask(mb=5), ForwardStepTask(mb=5)], + [BackwardStepTask(mb=5)], + [BackwardPostprocessTask(mb=5), ForwardPreprocessTask(mb=6), ForwardStepTask(mb=6)], + [BackwardStepTask(mb=6)], + [BackwardPostprocessTask(mb=6), ForwardPreprocessTask(mb=7), ForwardStepTask(mb=7)], + [BackwardStepTask(mb=7)], + [BackwardPostprocessTask(mb=7)], + [], + [ReduceGradsTask()], + ] + assert str(tasks) == str(expected_tasks) @patch("torch.distributed.get_rank") def test_inference_schedule_mb2_stage4(self, rank_mock): - try: - inference_scheduler = InferenceSchedule(num_microbatches=2, stages=4, stage_id=2) - tasks = [task for task in inference_scheduler.steps()] - expected_tasks = [ - [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0), ForwardPostprocessTask(mb=0)], - [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1), ForwardPostprocessTask(mb=1)], - ] - assert str(tasks) == str(expected_tasks) - except: - update_result({"inference_success": 0}) - raise + inference_scheduler = InferenceSchedule(num_microbatches=2, stages=4, stage_id=2) + tasks = [task for task in inference_scheduler.steps()] + expected_tasks = [ + [ForwardPreprocessTask(mb=0), ForwardStepTask(mb=0), ForwardPostprocessTask(mb=0)], + [ForwardPreprocessTask(mb=1), ForwardStepTask(mb=1), ForwardPostprocessTask(mb=1)], + ] + assert str(tasks) == str(expected_tasks) if __name__ == "__main__": diff --git a/test/unit_test/pipeline/test_shared_weights.py b/test/unit_test/pipeline/test_shared_weights.py index d77c134..ef5b08a 100644 --- a/test/unit_test/pipeline/test_shared_weights.py +++ b/test/unit_test/pipeline/test_shared_weights.py @@ -4,7 +4,6 @@ import neuronx_distributed.pipeline.partition as partition -from .. 
import update_result from .test_base import get_traced_model_gpt @@ -27,17 +26,13 @@ class TestSharedWeights(unittest.TestCase): @patch("neuronx_distributed.pipeline.model.parallel_state") @patch("torch.distributed.get_rank") def test_analyze_shared_weights_across_stages(self, rank_mock, state_mock): - try: - traced_model = get_traced_model_gpt() - split_mod = partition.partition_traced_model(traced_model) - partitions = [] - for _, module in split_mod.named_children(): - partitions.append(module) - shared_weights = partition.analyze_shared_weights_across_stages(traced_model, partitions) - assert shared_weights == [[("transformer_wte.weight", 0), ("lm_head.weight", 7)]] - except: - update_result({"inference_success": 0}) - raise + traced_model = get_traced_model_gpt() + split_mod = partition.partition_traced_model(traced_model) + partitions = [] + for _, module in split_mod.named_children(): + partitions.append(module) + shared_weights = partition.analyze_shared_weights_across_stages(traced_model, partitions) + assert shared_weights == [[("transformer_wte.weight", 0), ("lm_head.weight", 7)]] if __name__ == "__main__": diff --git a/test/unit_test/pipeline/test_trace.py b/test/unit_test/pipeline/test_trace.py index 8e8a295..997da51 100644 --- a/test/unit_test/pipeline/test_trace.py +++ b/test/unit_test/pipeline/test_trace.py @@ -9,8 +9,6 @@ from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear from neuronx_distributed.parallel_layers.loss_functions import parallel_cross_entropy -from .. import update_result - class MyModule(torch.nn.Module): def __init__(self): @@ -56,87 +54,59 @@ def forward(self, x): class TestTrace(unittest.TestCase): def test_get_concrete_args(self): - try: - mod = MyModule() - args = tracer.get_concrete_args(mod, []) - assert "x" in args - assert len(args) == 1 - except: - update_result({"inference_success": 0}) - raise + mod = MyModule() + args = tracer.get_concrete_args(mod, []) + assert "x" in args + assert len(args) == 1 def test_get_tracer_class_torch_model(self): - try: - mod = MyModule() - tracer_cls = tracer.get_tracer_class(mod) - assert tracer_cls == tracer.TorchTracerWrapper - except: - update_result({"inference_success": 0}) - raise + mod = MyModule() + tracer_cls = tracer.get_tracer_class(mod) + assert tracer_cls == tracer.TorchTracerWrapper def test_get_tracer_class_hf_cls_input(self): - try: - mod = MyModule() - tracer_cls = tracer.get_tracer_class(mod, "hf") - assert tracer_cls == tracer.HFTracerWrapper - except: - update_result({"inference_success": 0}) - raise + mod = MyModule() + tracer_cls = tracer.get_tracer_class(mod, "hf") + assert tracer_cls == tracer.HFTracerWrapper def test_get_tracer_class_torch_cls_input(self): - try: - mod = MyModule() - tracer_cls = tracer.get_tracer_class(mod, "torch") - assert tracer_cls == tracer.TorchTracerWrapper - except: - update_result({"inference_success": 0}) - raise + mod = MyModule() + tracer_cls = tracer.get_tracer_class(mod, "torch") + assert tracer_cls == tracer.TorchTracerWrapper def test_get_tracer_class_invalid_input(self): - try: - mod = MyModule() - with self.assertRaises(ValueError): - tracer.get_tracer_class(mod, "test") - except: - update_result({"inference_success": 0}) - raise + mod = MyModule() + with self.assertRaises(ValueError): + tracer.get_tracer_class(mod, "test") @patch("torch.distributed.get_rank") def test_trace_model(self, rank_mock): - try: - mod = MyModule() - traced_model = tracer.trace_model(model=mod, input_names=["x"]) - expected_nodes = [ - {"op": 
"placeholder", "name": "x"}, - {"op": "call_module", "name": "linear1"}, - {"op": "call_module", "name": "linear2"}, - {"op": "call_module", "name": "linear3"}, - {"op": "call_function", "name": "add"}, - {"op": "call_method", "name": "transpose"}, - {"op": "output", "name": "output"}, - ] - ops = [{"op": node.op, "name": node.name} for node in traced_model.graph.nodes] - assert ops == expected_nodes - except: - update_result({"inference_success": 0}) - raise + mod = MyModule() + traced_model = tracer.trace_model(model=mod, input_names=["x"]) + expected_nodes = [ + {"op": "placeholder", "name": "x"}, + {"op": "call_module", "name": "linear1"}, + {"op": "call_module", "name": "linear2"}, + {"op": "call_module", "name": "linear3"}, + {"op": "call_function", "name": "add"}, + {"op": "call_method", "name": "transpose"}, + {"op": "output", "name": "output"}, + ] + ops = [{"op": node.op, "name": node.name} for node in traced_model.graph.nodes] + assert ops == expected_nodes @patch("torch.distributed.get_rank") def test_trace_model_will_leaf_module(self, rank_mock): - try: - mod = NestedModule() - traced_model = tracer.trace_model(model=mod, input_names=["x"], leaf_modules=["MyModule"]) - expected_nodes = [ - {"op": "placeholder", "name": "x"}, - {"op": "call_module", "name": "my_mod"}, - {"op": "call_module", "name": "linear4"}, - {"op": "output", "name": "output"}, - ] - ops = [{"op": node.op, "name": node.name} for node in traced_model.graph.nodes] - assert ops == expected_nodes - except: - update_result({"inference_success": 0}) - raise + mod = NestedModule() + traced_model = tracer.trace_model(model=mod, input_names=["x"], leaf_modules=["MyModule"]) + expected_nodes = [ + {"op": "placeholder", "name": "x"}, + {"op": "call_module", "name": "my_mod"}, + {"op": "call_module", "name": "linear4"}, + {"op": "output", "name": "output"}, + ] + ops = [{"op": node.op, "name": node.name} for node in traced_model.graph.nodes] + assert ops == expected_nodes @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=1)) @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @@ -144,27 +114,23 @@ def test_trace_model_will_leaf_module(self, rank_mock): @patch("neuronx_distributed.parallel_layers.layers._initialize_affine_weight_neuron", MagicMock(return_value=None)) @patch("torch.distributed.get_rank") def test_nxd_trace_model_will_leaf_module(self, rank_mock): - try: - mod = NestedNxDModule() - traced_model = tracer.trace_model( - model=mod, - input_names=["x"], - leaf_modules=["MyModule", "ColumnParallelLinear"], - autowrap_functions=[parallel_cross_entropy], - ) - expected_nodes = [ - {"op": "placeholder", "name": "x"}, - {"op": "call_module", "name": "my_mod"}, - {"op": "call_module", "name": "cpl"}, - {"op": "call_function", "name": "parallel_cross_entropy"}, - {"op": "call_module", "name": "linear4"}, - {"op": "output", "name": "output"}, - ] - ops = [{"op": node.op, "name": node.name} for node in traced_model.graph.nodes] - assert ops == expected_nodes - except: - update_result({"inference_success": 0}) - raise + mod = NestedNxDModule() + traced_model = tracer.trace_model( + model=mod, + input_names=["x"], + leaf_modules=["MyModule", "ColumnParallelLinear"], + autowrap_functions=[parallel_cross_entropy], + ) + expected_nodes = [ + {"op": "placeholder", "name": "x"}, + {"op": "call_module", "name": "my_mod"}, + {"op": "call_module", "name": "cpl"}, + {"op": "call_function", "name": 
"parallel_cross_entropy"}, + {"op": "call_module", "name": "linear4"}, + {"op": "output", "name": "output"}, + ] + ops = [{"op": node.op, "name": node.name} for node in traced_model.graph.nodes] + assert ops == expected_nodes if __name__ == "__main__": diff --git a/test/unit_test/quantization/test_dequantize.py b/test/unit_test/quantization/test_dequantize.py index c2754d0..7f9f3f8 100644 --- a/test/unit_test/quantization/test_dequantize.py +++ b/test/unit_test/quantization/test_dequantize.py @@ -2,16 +2,23 @@ import torch -from neuronx_distributed.quantization.dequantize import dequantize +from neuronx_distributed.quantization.dequantize import direct_cast_dequantize, scale_dequantize class TestDequantize(unittest.TestCase): - def test_dequantize(self): + def test_direct_cast_dequantize(self): + tensor = torch.tensor([-10, 30], dtype=torch.int8) + dequantized_tensor = direct_cast_dequantize(tensor=tensor, upcast_dtype=torch.bfloat16) + + assert dequantized_tensor.dtype == torch.bfloat16 + torch.testing.assert_close(dequantized_tensor, torch.tensor([-10.0, 30.0], dtype=torch.bfloat16)) + + def test_scale_dequantize(self): tensor = torch.tensor([-10, 30], dtype=torch.int8) scale = torch.tensor(10.0) input = torch.tensor([-89.9, 84.8], dtype=torch.bfloat16) - dequantized_tensor = dequantize(tensor=tensor, scale=scale, upcast_dtype=input.dtype) + dequantized_tensor = scale_dequantize(tensor=tensor, scale=scale, upcast_dtype=input.dtype) assert dequantized_tensor.dtype == torch.bfloat16 torch.testing.assert_close(dequantized_tensor, torch.tensor([-100.0, 300.0], dtype=torch.bfloat16)) diff --git a/test/unit_test/quantization/test_observer.py b/test/unit_test/quantization/test_observer.py new file mode 100644 index 0000000..64da95f --- /dev/null +++ b/test/unit_test/quantization/test_observer.py @@ -0,0 +1,43 @@ +import unittest + +import torch + +from neuronx_distributed.quantization.observer import PerChannelAbsMaxObserver + + +class TestPerChannelAbsMaxObserver(unittest.TestCase): + def test_forward(self): + # Channel axis = 0 + tensor_observer = PerChannelAbsMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=0 + )() + tensor = torch.Tensor( + [ + [1.0, -190, -950, 900, 100], + [2.0, 255, -900, 80, -80], + ] + ) + + tensor_observer(tensor) + + expected_max_val = torch.Tensor([[950.0], [900.0]]) + expected_scale = expected_max_val / 127.0 + assert torch.allclose(tensor_observer.max_val, expected_max_val) + assert torch.allclose(tensor_observer.calculate_qparams()[0], expected_scale.squeeze(1)) + + # Channel axis = 1 + tensor_observer = PerChannelAbsMaxObserver.with_args( + dtype=torch.qint8, qscheme=torch.per_channel_symmetric, ch_axis=1 + )() + tensor = torch.Tensor( + [ + [1.0, -190, -950, 900, 100], + [2.0, 255, -900, 80, -80], + ] + ) + + tensor_observer(tensor) + expected_max_val = torch.Tensor([[2.0], [255.0], [950.0], [900.0], [100.0]]) + expected_scale = expected_max_val / 127.0 + assert torch.allclose(tensor_observer.max_val, expected_max_val) + assert torch.allclose(tensor_observer.calculate_qparams()[0], expected_scale.squeeze(1)) diff --git a/test/unit_test/quantization/test_quantization_layers.py b/test/unit_test/quantization/test_quantization_layers.py index 5a49fc0..4b731f5 100644 --- a/test/unit_test/quantization/test_quantization_layers.py +++ b/test/unit_test/quantization/test_quantization_layers.py @@ -5,16 +5,24 @@ import torch import torch_xla.core.xla_model as xm +from neuronx_distributed.modules.moe.moe_parallel_layers import ( + 
ExpertFusedColumnParallelLinear, + ExpertFusedRowParallelLinear, +) from neuronx_distributed.parallel_layers import parallel_state from neuronx_distributed.parallel_layers.layers import ( ColumnParallelLinear, RowParallelLinear, ) +from neuronx_distributed.quantization.quantization_config import ( + get_default_custom_qconfig_dict, + get_default_per_channel_custom_qconfig_dict, +) from neuronx_distributed.quantization.quantization_layers import ( BaseQuantizeParallelLinear, - QuantizationType, QuantizedColumnParallel, - QuantizedDtype, + QuantizedExpertFusedColumnParallel, + QuantizedExpertFusedRowParallel, QuantizedParallelLinearLayerStateDictAdaptor, QuantizedRowParallel, ) @@ -61,6 +69,19 @@ def test_get_scale_from_state_dict(self): QuantizedParallelLinearLayerStateDictAdaptor.get_scale_from_state_dict("lay1.", state_dict), ) + weight = torch.quantize_per_channel( + torch.randn(4, 4), torch.tensor([0.1, 0.2, 0.3, 0.4]), torch.tensor([0, 1, 2, 3]), 0, torch.qint8 + ) + state_dict = { + "lay1._packed_params.dtype": torch.qint8, + "lay1._packed_params._packed_params": (weight, MagicMock()), + } + + assert torch.allclose( + torch.Tensor([[0.1], [0.2], [0.3], [0.4]]), + QuantizedParallelLinearLayerStateDictAdaptor.get_scale_from_state_dict("lay1.", state_dict), + ) + class TestBaseQuantizeParallelLinear(unittest.TestCase): @patch.multiple(BaseQuantizeParallelLinear, __abstractmethods__=set()) @@ -69,31 +90,72 @@ def test_init(self): BaseQuantizeParallelLinear(quantization_type="something") self.assertTrue( - "something quantization is not supported currently. Specify from [['scalar']]" in str(context.exception) + "something quantization is not supported currently. Specify from [['per_tensor_symmetric', 'per_channel_symmetric']]" + in str(context.exception) ) with self.assertRaises(AssertionError) as context: - BaseQuantizeParallelLinear(quantization_type="scalar", quantized_dtype=torch.float16) + BaseQuantizeParallelLinear(quantization_type="per_tensor_symmetric", quantized_dtype=torch.float16) self.assertTrue("torch.float16 quantization is not supported currently. 
Specify from [['torch.int8']]") - test_class = BaseQuantizeParallelLinear(quantization_type="scalar", quantized_dtype=torch.int8) + test_class = BaseQuantizeParallelLinear(quantization_type="per_tensor_symmetric", quantized_dtype=torch.int8) assert test_class.scale is None @patch.multiple(BaseQuantizeParallelLinear, __abstractmethods__=set()) def test_init_weight(self): - test_class = BaseQuantizeParallelLinear(quantization_type="scalar", quantized_dtype=torch.int8) + test_class = BaseQuantizeParallelLinear(quantization_type="per_tensor_symmetric", quantized_dtype=torch.int8) weight = torch.empty((5, 5), dtype=torch.int8) test_class._init_weight(weight=weight) torch.testing.assert_close(weight, torch.zeros(5, 5, dtype=torch.int8)) @patch.multiple(BaseQuantizeParallelLinear, __abstractmethods__=set()) def test_init_bias(self): - test_class = BaseQuantizeParallelLinear(quantization_type="scalar", quantized_dtype=torch.int8) + test_class = BaseQuantizeParallelLinear(quantization_type="per_tensor_symmetric", quantized_dtype=torch.int8) bias = torch.empty((5,), dtype=torch.bfloat16) test_class._init_bias(bias=bias) torch.testing.assert_close(bias, torch.zeros(5, dtype=torch.bfloat16)) + @patch.multiple(BaseQuantizeParallelLinear, __abstractmethods__=set()) + def test_setup_for_scale(self): + # for per_tensor_symmetric + test_class = BaseQuantizeParallelLinear(quantization_type="per_tensor_symmetric", quantized_dtype=torch.int8) + test_class._setup_for_scale( + weight_shape=MagicMock(), quantization_type=test_class.quantization_type, weight_partition_dim=MagicMock() + ) + assert hasattr(test_class.scale, "get_tensor_from_state_dict") + assert test_class.scale.tensor_model_parallel is False + assert torch.allclose(test_class.scale, torch.Tensor([1.0])) + + del test_class + + # for per_channel_symmetric, per channel same as partition dim + test_class = BaseQuantizeParallelLinear(quantization_type="per_channel_symmetric", quantized_dtype=torch.int8) + test_class.weight = MagicMock(device=torch.device("cpu")) + test_class._setup_for_scale( + weight_shape=(8, 10), + quantization_type=test_class.quantization_type, + weight_partition_dim=0, + per_channel_axis=0, + ) + assert hasattr(test_class.scale, "get_tensor_from_state_dict") + assert torch.allclose(test_class.scale, torch.ones((8, 1))) + assert test_class.scale.tensor_model_parallel is True + assert test_class.scale.partition_dim == 0 + + # for per_channel_symmetric, per channel not same as partition dim + test_class = BaseQuantizeParallelLinear(quantization_type="per_channel_symmetric", quantized_dtype=torch.int8) + test_class.weight = MagicMock(device=torch.device("cpu")) + test_class._setup_for_scale( + weight_shape=(8, 10), + quantization_type=test_class.quantization_type, + weight_partition_dim=0, + per_channel_axis=1, + ) + assert hasattr(test_class.scale, "get_tensor_from_state_dict") + assert torch.allclose(test_class.scale, torch.ones((1, 10))) + assert test_class.scale.tensor_model_parallel is False + class TestQuantizedColumnParallel(unittest.TestCase): def setUp(self) -> None: @@ -109,24 +171,30 @@ def tearDown(self) -> None: return @patch("neuronx_distributed.quantization.quantization_layers.get_tensor_model_parallel_size", return_value=1) - @patch("neuronx_distributed.quantization.quantization_layers.QuantizedColumnParallel._setup_for_weight") + @patch("neuronx_distributed.quantization.quantization_layers.BaseQuantizeParallelLinear._setup_for_weight") 
@patch("neuronx_distributed.quantization.quantization_layers.QuantizedColumnParallel._setup_for_bias") + @patch("neuronx_distributed.quantization.quantization_layers.BaseQuantizeParallelLinear._setup_for_scale") @patch("neuronx_distributed.quantization.quantization_layers.QuantizedColumnParallel._setup_for_parallelism") def test_init( self, mock_setup_for_parallelism, + mock_setup_for_scale, mock_setup_for_bias, mock_setup_for_weight, mock_get_tensor_model_parallel_size, ): - test_class = QuantizedColumnParallel( - input_size=5, output_size=5, quantization_type="scalar", quantized_dtype=torch.int8, dtype=torch.bfloat16 + _ = QuantizedColumnParallel( + input_size=5, + output_size=5, + quantization_type="per_tensor_symmetric", + quantized_dtype=torch.int8, + dtype=torch.bfloat16, ) mock_get_tensor_model_parallel_size.assert_called_once() mock_setup_for_weight.assert_called_once() mock_setup_for_bias.assert_called_once() + mock_setup_for_scale.assert_called_once() mock_setup_for_parallelism.assert_called_once() - assert hasattr(test_class.scale, "get_tensor_from_state_dict") @patch("neuronx_distributed.quantization.quantization_layers._initialize_affine_weight_neuron", return_value=1) @patch("neuronx_distributed.quantization.quantization_layers._initialize_parameter_cpu") @@ -182,8 +250,24 @@ def test_from_float(self): cpl = ColumnParallelLinear( input_size=4, output_size=6, device=torch.device("cpu"), bias=True, dtype=torch.float32 ) - qcpl = QuantizedColumnParallel.from_float(cpl, quantized_dtype=torch.int8) + q_config = get_default_custom_qconfig_dict() + qcpl = QuantizedColumnParallel.from_float(cpl, q_config=q_config) + assert qcpl.weight.dtype == torch.int8 + assert qcpl.scale.shape == (1,) + assert qcpl.bias is not None + + # Channel axis = 0 + q_config = get_default_per_channel_custom_qconfig_dict() + qcpl = QuantizedColumnParallel.from_float(cpl, q_config=q_config) + assert qcpl.weight.dtype == torch.int8 + assert qcpl.scale.shape == (6, 1) + assert qcpl.bias is not None + + # Channel axis = 1 + q_config["quantization_per_channel_axis"] = 1 + qcpl = QuantizedColumnParallel.from_float(cpl, q_config=q_config) assert qcpl.weight.dtype == torch.int8 + assert qcpl.scale.shape == (1, 4) assert qcpl.bias is not None @@ -200,17 +284,19 @@ def tearDown(self) -> None: parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = self.initial_rank return + @patch("neuronx_distributed.quantization.quantization_layers.BaseQuantizeParallelLinear._setup_for_scale") @patch("neuronx_distributed.quantization.quantization_layers.QuantizedRowParallel._setup_for_bias") @patch("neuronx_distributed.quantization.quantization_layers.QuantizedRowParallel._setup_for_weight") def test_init( self, mock_setup_for_weight, mock_setup_for_bias, + mock_setup_for_scale, ): - test_class = QuantizedRowParallel(input_size=6, output_size=4, dtype=torch.bfloat16) + _ = QuantizedRowParallel(input_size=6, output_size=4, dtype=torch.bfloat16) mock_setup_for_weight.assert_called_once() mock_setup_for_bias.assert_called_once() - assert hasattr(test_class.scale, "get_tensor_from_state_dict") + mock_setup_for_scale.assert_called_once() @patch("neuronx_distributed.quantization.quantization_layers._initialize_affine_weight_neuron", return_value=1) @patch("neuronx_distributed.quantization.quantization_layers._initialize_parameter_cpu") @@ -256,6 +342,235 @@ def test_setup_for_bias(self): def test_from_float(self): rpl = RowParallelLinear(input_size=6, output_size=4, device=torch.device("cpu"), bias=True, dtype=torch.float32) - qrpl = 
QuantizedRowParallel.from_float(rpl, quantized_dtype=torch.int8) + q_config = get_default_custom_qconfig_dict() + qrpl = QuantizedRowParallel.from_float(rpl, q_config=q_config) assert qrpl.weight.dtype == torch.int8 + assert qrpl.scale.shape == (1,) assert qrpl.bias is not None + + # Channel axis = 0 + q_config = get_default_per_channel_custom_qconfig_dict() + qrpl = QuantizedRowParallel.from_float(rpl, q_config=q_config) + assert qrpl.weight.dtype == torch.int8 + assert qrpl.bias is not None + assert qrpl.scale.shape == (4, 1) + + # Channel axis = 1 + q_config["quantization_per_channel_axis"] = 1 + qrpl = QuantizedRowParallel.from_float(rpl, q_config=q_config) + assert qrpl.weight.dtype == torch.int8 + assert qrpl.bias is not None + assert qrpl.scale.shape == (1, 6) + + +class TestQuantizedExpertFusedColumnParallel(unittest.TestCase): + def setUp(self) -> None: + self.initial_world_size = parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + self.initial_rank = parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK + self.initial_expert_world_size = parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE + + + parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = 1 + parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = 0 + parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = 1 + + def tearDown(self) -> None: + parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = self.initial_world_size + parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = self.initial_rank + parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = self.initial_expert_world_size + return + + @patch("neuronx_distributed.quantization.quantization_layers.get_tensor_model_parallel_size", return_value=1) + @patch("neuronx_distributed.quantization.quantization_layers.QuantizedExpertFusedColumnParallel._setup_for_weight") + @patch("neuronx_distributed.quantization.quantization_layers.QuantizedExpertFusedColumnParallel._setup_for_bias") + @patch("neuronx_distributed.quantization.quantization_layers.BaseQuantizeParallelLinear._setup_for_scale") + @patch( + "neuronx_distributed.quantization.quantization_layers.QuantizedExpertFusedColumnParallel._setup_for_parallelism" + ) + def test_init( + self, + mock_setup_for_parallelism, + mock_setup_for_bias, + mock_setup_for_weight, + mock_setup_for_scale, + mock_get_tensor_model_parallel_size, + ): + _ = QuantizedExpertFusedColumnParallel( + num_experts=2, + input_size=3, + output_size=4, + quantization_type="per_tensor_symmetric", + quantized_dtype=torch.int8, + dtype=torch.bfloat16, + ) + mock_get_tensor_model_parallel_size.assert_called_once() + mock_setup_for_weight.assert_called_once() + mock_setup_for_bias.assert_called_once() + mock_setup_for_scale.assert_called_once() + mock_setup_for_parallelism.assert_called_once() + + @patch("neuronx_distributed.quantization.quantization_layers._initialize_affine_weight_neuron", return_value=1) + @patch("neuronx_distributed.quantization.quantization_layers._initialize_parameter_cpu") + @patch("neuronx_distributed.quantization.quantization_layers.get_tensor_model_parallel_size", return_value=2) + def test_setup_for_weight( + self, + mock_get_tensor_model_parallel_size, + mock_initialize_parameter_cpu, + mock_initialize_affine_weight_neuron, + ): + num_experts = 2 + input_size = 4 + output_size = 6 + + layer = QuantizedExpertFusedColumnParallel(num_experts, input_size, output_size, device=torch.device("cpu")) + + # Assert _initialize_parameter_cpu called with right inputs + mock_initialize_parameter_cpu.assert_called_once_with( + param=layer.weight, + 
partition_dim=2, + init_method=layer._init_weight, + param_dtype=torch.int8, + stride=layer.stride, + return_master_param=layer.keep_master_weight, + ) + + # Assert weight properties + self.assertEqual(layer.weight.shape, (2, 4, 3)) # Adjusted for partition_dim=2 + self.assertEqual(layer.weight.dtype, torch.int8) + self.assertFalse(layer.weight.requires_grad) + + del layer + + layer = QuantizedExpertFusedColumnParallel(num_experts, input_size, output_size, device=xm.xla_device()) + + # Assert _initialize_affine_weight_neuron called with right inputs + mock_initialize_affine_weight_neuron.assert_called_once_with( + weight=layer.weight, init_method=layer._init_weight, partition_dim=2, stride=layer.stride + ) + + # Assert weight properties + self.assertEqual(layer.weight.shape, (2, 4, 3)) # Adjusted for partition_dim=2 + self.assertEqual(layer.weight.dtype, torch.int8) + self.assertFalse(layer.weight.requires_grad) + + def test_from_float(self): + cpl = ExpertFusedColumnParallelLinear( + num_experts=2, input_size=4, output_size=6, device=torch.device("cpu"), dtype=torch.float32 + ) + q_config = get_default_custom_qconfig_dict() + qcpl = QuantizedExpertFusedColumnParallel.from_float(cpl, q_config=q_config) + assert qcpl.weight.dtype == torch.int8 + assert qcpl.bias is None + assert qcpl.scale.shape == (1,) + + q_config = get_default_per_channel_custom_qconfig_dict() + q_config["quantization_per_channel_axis"] = 1 # First dimension is reserved for experts + qcpl = QuantizedExpertFusedColumnParallel.from_float(cpl, q_config=q_config) + assert qcpl.scale.shape == (1, 4, 1) + assert qcpl.scale.tensor_model_parallel is False + + q_config = get_default_per_channel_custom_qconfig_dict() + q_config["quantization_per_channel_axis"] = 2 # First dimension is reserved for experts + qcpl = QuantizedExpertFusedColumnParallel.from_float(cpl, q_config=q_config) + assert qcpl.scale.shape == (1, 1, 6) + assert qcpl.scale.tensor_model_parallel is True and qcpl.scale.partition_dim == 2 + + +class TestQuantizedExpertFusedRowParallel(unittest.TestCase): + def setUp(self) -> None: + self.initial_world_size = parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + self.initial_rank = parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK + self.initial_expert_world_size = parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE + + parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = 1 + parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = 0 + parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = 1 + + def tearDown(self) -> None: + parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = self.initial_world_size + parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = self.initial_rank + parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = self.initial_expert_world_size + return + + @patch("neuronx_distributed.quantization.quantization_layers.BaseQuantizeParallelLinear._setup_for_scale") + @patch("neuronx_distributed.quantization.quantization_layers.QuantizedExpertFusedRowParallel._setup_for_bias") + @patch("neuronx_distributed.quantization.quantization_layers.QuantizedExpertFusedRowParallel._setup_for_weight") + def test_init( + self, + mock_setup_for_weight, + mock_setup_for_bias, + mock_setup_for_scale, + ): + _ = QuantizedExpertFusedRowParallel(num_experts=2, input_size=6, output_size=4, dtype=torch.bfloat16) + mock_setup_for_weight.assert_called_once() + mock_setup_for_bias.assert_called_once() + mock_setup_for_scale.assert_called_once() + + @patch("neuronx_distributed.quantization.quantization_layers._initialize_affine_weight_neuron", 
return_value=1) + @patch("neuronx_distributed.quantization.quantization_layers._initialize_parameter_cpu") + def test_setup_for_weight(self, mock_initialize_parameter_cpu, mock_initialize_affine_weight_neuron): + num_experts = 2 + input_size = 6 + output_size = 4 + parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = 2 + layer = QuantizedExpertFusedRowParallel(num_experts, input_size, output_size, device=torch.device("cpu")) + + # Assert _initialize_parameter_cpu called with right inputs + mock_initialize_parameter_cpu.assert_called_once_with( + param=layer.weight, + partition_dim=1, + init_method=layer._init_weight, + param_dtype=torch.int8, + stride=layer.stride, + return_master_param=layer.keep_master_weight, + ) + + # Assert weight properties + self.assertEqual(layer.weight.shape, (2, 3, 4)) + self.assertEqual(layer.weight.dtype, torch.int8) + self.assertFalse(layer.weight.requires_grad) + assert hasattr(layer.weight, "get_tensor_from_state_dict") + + del layer + + layer = QuantizedExpertFusedRowParallel(num_experts, input_size, output_size, device=xm.xla_device()) + + # Assert _initialize_affine_weight_neuron called with right inputs + mock_initialize_affine_weight_neuron.assert_called_once_with( + weight=layer.weight, init_method=layer._init_weight, partition_dim=1, stride=layer.stride + ) + + # Assert weight properties + self.assertEqual(layer.weight.shape, (2, 3, 4)) # Adjusted for partition_dim=1 + self.assertEqual(layer.weight.dtype, torch.int8) + self.assertFalse(layer.weight.requires_grad) + + def test_from_float(self): + rpl = ExpertFusedRowParallelLinear( + num_experts=2, + input_size=6, + output_size=4, + device=torch.device("cpu"), + dtype=torch.float32, + ) + q_config = get_default_custom_qconfig_dict() + qrpl = QuantizedExpertFusedRowParallel.from_float(rpl, q_config=q_config) + assert qrpl.weight.dtype == torch.int8 + assert qrpl.bias is None + assert qrpl.scale.shape == (1,) + + q_config = get_default_per_channel_custom_qconfig_dict() + q_config["quantization_per_channel_axis"] = 1 # First dimension is reserved for experts + qrpl = QuantizedExpertFusedRowParallel.from_float(rpl, q_config=q_config) + assert qrpl.scale.shape == (1, 6, 1) + assert qrpl.scale.tensor_model_parallel is True and qrpl.scale.partition_dim == 1 + + q_config = get_default_per_channel_custom_qconfig_dict() + q_config["quantization_per_channel_axis"] = 2 # First dimension is reserved for experts + qrpl = QuantizedExpertFusedRowParallel.from_float(rpl, q_config=q_config) + assert qrpl.scale.shape == (1, 1, 4) + assert qrpl.scale.tensor_model_parallel is False + + +if __name__ == "__main__": + unittest.main(verbosity=3, failfast=False) diff --git a/test/unit_test/quantization/test_quantization_utils.py b/test/unit_test/quantization/test_quantization_utils.py index e69de29..f93aed7 100644 --- a/test/unit_test/quantization/test_quantization_utils.py +++ b/test/unit_test/quantization/test_quantization_utils.py @@ -0,0 +1,77 @@ +import unittest + +import pytest +import torch +import torch.ao.quantization + +from neuronx_distributed.quantization.quantization_utils import ( + extract_q_scale, + quantize_per_channel_symmetric, + quantize_per_tensor_symmetric, +) + + +class TestExtractQscale(unittest.TestCase): + def test_extract_q_scale_per_tensor(self): + q_tensor = torch.quantize_per_tensor(torch.randn(4, 4), 0.1, 2, torch.quint8) + assert extract_q_scale(q_tensor) == 0.1 + + q_tensor = torch.quantize_per_channel( + torch.randn(4, 4), + torch.tensor([0.1, 0.2, 0.3, 0.4]), + torch.tensor([0, 1, 2, 3]), + 0, 
torch.qint8, + ) + assert torch.allclose(extract_q_scale(q_tensor), torch.Tensor([[0.1], [0.2], [0.3], [0.4]])) + + +@pytest.mark.skip("Not testing convert_qint8_to_int8_state_dict as its a temporary solution before refactoring") +class TestConvertQint8ToInt8StateDict(unittest.TestCase): + def test_convert_qint8_to_int8_state_dict(self): + pass + + +@pytest.mark.skip("Customer facing optional code. Test later") +class TestQuantizePytorchModelPerChannelSymmetric(unittest.TestCase): + def test_quantize_pytorch_model_per_channel_symmetric(self): + pass + + +@pytest.mark.skip("Customer facing optional code. Test later") +class TestQuantizePytorchModelPerTensorSymmetric(unittest.TestCase): + def test_quantize_pytorch_model_per_tensor_symmetric(self): + pass + + +class TestQuantizePerTensorSymmetric(unittest.TestCase): + # Tests quantize_per_tensor_symmetric function + def test_quantize_per_tensor_symmetric(self): + tensor = torch.Tensor([[1.0, -190, -950, 900, 100]]) + expected_scale = 950 / (float(127 - (-128)) / 2) + expected_scale = torch.tensor(expected_scale) + + # We just verify the scale. Once that is fine, _quantize method is already pytorch tested + quantized_tensor = quantize_per_tensor_symmetric(tensor) + assert torch.allclose(torch.tensor(quantized_tensor.q_scale()), expected_scale) + + +class TestQuantizePerChannelSymmetric(unittest.TestCase): + # Tests quantize_per_channel_symmetric function + def test_quantize_per_channel_symmetric(self): + tensor = torch.Tensor( + [ + [1.0, -190, -950, 900, 100], + [2.0, 255, -900, 80, -80], + ] + ) + + # channel axis = 1 + expected_scale = torch.Tensor([2.0, 255, 950, 900, 100]) / 127.0 + quantized_tensor = quantize_per_channel_symmetric(tensor, 1) + assert torch.allclose(quantized_tensor.q_per_channel_scales().to(torch.float32), expected_scale) + + # channel axis = 0 + expected_scale = torch.Tensor([950, 900]) / 127 + quantized_tensor = quantize_per_channel_symmetric(tensor, 0) + assert torch.allclose(quantized_tensor.q_per_channel_scales().to(torch.float32), expected_scale) diff --git a/test/unit_test/quantization/test_quantize.py b/test/unit_test/quantization/test_quantize.py index d36ea37..87e7530 100644 --- a/test/unit_test/quantization/test_quantize.py +++ b/test/unit_test/quantization/test_quantize.py @@ -8,10 +8,7 @@ RowParallelLinear, ) from neuronx_distributed.quantization.quantization_layers import ( - BaseQuantizeParallelLinear, - QuantizationType, QuantizedColumnParallel, - QuantizedDtype, QuantizedRowParallel, ) from neuronx_distributed.quantization.quantize import convert @@ -66,7 +63,7 @@ def __init__(self): # With inplace model1 = Model() - model2 = convert(module=model1, q_config=None, inplace=True, mapping=None) + convert(module=model1, q_config=None, inplace=True, mapping=None) assert isinstance(model1.linear1.lay1, QuantizedColumnParallel) assert isinstance(model1.linear1.lay2, QuantizedRowParallel) diff --git a/test/unit_test/trace/test_spmd_trace.py b/test/unit_test/trace/test_spmd_trace.py index b31a95a..eedfc2f 100644 --- a/test/unit_test/trace/test_spmd_trace.py +++ b/test/unit_test/trace/test_spmd_trace.py @@ -15,113 +15,107 @@ shard_children, ) -from .. 
import update_result - class TestCheckpoint(unittest.TestCase): def test_validate_traceable(self): - try: - class Model(torch.nn.Module): - def __init__(self, shard_across_embedding): - super().__init__() - self.embed_tokens = ParallelEmbedding(10, 10, shard_across_embedding=shard_across_embedding) + class Model(torch.nn.Module): + def __init__(self, shard_across_embedding): + super().__init__() + self.embed_tokens = ParallelEmbedding(10, 10, shard_across_embedding=shard_across_embedding) + + def forward(self): + pass - def forward(self): - pass + def model_over_vocab(): + return (Model(shard_across_embedding=False), {}) - model_over_vocab = lambda: (Model(shard_across_embedding=False), {}) - model_over_embed = lambda: (Model(shard_across_embedding=True), {}) + def model_over_embed(): + return (Model(shard_across_embedding=True), {}) - with self.assertRaises(ValueError): - _validate_traceable(model_over_vocab, tp_degree=1) - _validate_traceable(model_over_embed, tp_degree=1) + with self.assertRaises(ValueError): + _validate_traceable(model_over_vocab, tp_degree=1) + _validate_traceable(model_over_embed, tp_degree=1) - class ModelWithChildren(torch.nn.Module): - def __init__(self, shard_across_embedding): - super().__init__() - self.model = Model(shard_across_embedding) + class ModelWithChildren(torch.nn.Module): + def __init__(self, shard_across_embedding): + super().__init__() + self.model = Model(shard_across_embedding) - def forward(self): - pass + def forward(self): + pass - model_over_vocab = lambda: (ModelWithChildren(shard_across_embedding=False), {}) - model_over_embed = lambda: (ModelWithChildren(shard_across_embedding=True), {}) + def model_over_vocab(): + return (ModelWithChildren(shard_across_embedding=False), {}) - with self.assertRaises(ValueError): - _validate_traceable(model_over_vocab, tp_degree=1) - _validate_traceable(model_over_embed, tp_degree=1) + def model_over_embed(): + return (ModelWithChildren(shard_across_embedding=True), {}) - except: - update_result({"inference_success": 0}) - raise + with self.assertRaises(ValueError): + _validate_traceable(model_over_vocab, tp_degree=1) + _validate_traceable(model_over_embed, tp_degree=1) def test_shard_children(self): - try: - - class InnerModel(torch.nn.Module): - def __init__(self): - super().__init__() - if parallel_state.model_parallel_is_initialized(): - self.embed_tokens = ParallelEmbedding(10, 32, shard_across_embedding=True) - self.cpl = ColumnParallelLinear(10, 64) - self.rpl = RowParallelLinear(64, 10) - else: - self.embed_tokens = torch.nn.Embedding(10, 32) - self.cpl = torch.nn.Linear(10, 64) - self.rpl = torch.nn.Linear(64, 10) - - def forward(self): - pass - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.lay1 = InnerModel() - if parallel_state.model_parallel_is_initialized(): - self.embed_tokens = ParallelEmbedding(10, 32, shard_across_embedding=True) - self.cpl = ColumnParallelLinear(8, 128) - self.rpl = RowParallelLinear(128, 8) - else: - self.embed_tokens = torch.nn.Embedding(8, 32) - self.cpl = torch.nn.Linear(8, 128) - self.rpl = torch.nn.Linear(128, 8) - - def forward(self): - pass - - model = Model() - checkpoint = model.state_dict() - - for tp_degree in [2, 4, 8, 16, 32]: - for rank in range(0, tp_degree): - _mock_parallel_state(tp_degree, rank) - nxd_model = Model() - sharded_checkpoint = copy.deepcopy(checkpoint) - shard_children(nxd_model, sharded_checkpoint, "", torch.float32, rank=rank, tp_degree=tp_degree) - - def validate_shard_weight(prefix): - embed_shard = 
sharded_checkpoint[prefix + "embed_tokens.weight"] - embed = checkpoint[prefix + "embed_tokens.weight"] - assert embed_shard.shape == (embed.shape[0], embed.shape[1] / tp_degree) - assert torch.equal(embed_shard, torch.split(embed, embed.shape[1] // tp_degree, dim=1)[rank]) - - rpl_shard = sharded_checkpoint[prefix + "rpl.weight"] - rpl = checkpoint[prefix + "rpl.weight"] - assert rpl_shard.shape == (rpl.shape[0], rpl.shape[1] / tp_degree) - assert torch.equal(rpl_shard, torch.split(rpl, rpl.shape[1] // tp_degree, dim=1)[rank]) - - cpl_shard = sharded_checkpoint[prefix + "cpl.weight"] - cpl = checkpoint[prefix + "cpl.weight"] - assert cpl_shard.shape == (cpl.shape[0] / tp_degree, cpl.shape[1]) - assert torch.equal(cpl_shard, torch.split(cpl, cpl.shape[0] // tp_degree, dim=0)[rank]) - - validate_shard_weight("") - validate_shard_weight("lay1.") - - except: - update_result({"inference_success": 0}) - raise + + class InnerModel(torch.nn.Module): + def __init__(self): + super().__init__() + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = ParallelEmbedding(10, 32, shard_across_embedding=True) + self.cpl = ColumnParallelLinear(10, 64) + self.rpl = RowParallelLinear(64, 10) + else: + self.embed_tokens = torch.nn.Embedding(10, 32) + self.cpl = torch.nn.Linear(10, 64) + self.rpl = torch.nn.Linear(64, 10) + + def forward(self): + pass + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.lay1 = InnerModel() + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = ParallelEmbedding(10, 32, shard_across_embedding=True) + self.cpl = ColumnParallelLinear(8, 128) + self.rpl = RowParallelLinear(128, 8) + else: + self.embed_tokens = torch.nn.Embedding(8, 32) + self.cpl = torch.nn.Linear(8, 128) + self.rpl = torch.nn.Linear(128, 8) + + def forward(self): + pass + + model = Model() + checkpoint = model.state_dict() + + for tp_degree in [2, 4, 8, 16, 32]: + for rank in range(0, tp_degree): + _mock_parallel_state(tp_degree, rank) + nxd_model = Model() + sharded_checkpoint = copy.deepcopy(checkpoint) + shard_children(nxd_model, sharded_checkpoint, "", torch.float32, rank=rank, tp_degree=tp_degree) + + def validate_shard_weight(prefix): + embed_shard = sharded_checkpoint[prefix + "embed_tokens.weight"] + embed = checkpoint[prefix + "embed_tokens.weight"] + assert embed_shard.shape == (embed.shape[0], embed.shape[1] / tp_degree) + assert torch.equal(embed_shard, torch.split(embed, embed.shape[1] // tp_degree, dim=1)[rank]) + + rpl_shard = sharded_checkpoint[prefix + "rpl.weight"] + rpl = checkpoint[prefix + "rpl.weight"] + assert rpl_shard.shape == (rpl.shape[0], rpl.shape[1] / tp_degree) + assert torch.equal(rpl_shard, torch.split(rpl, rpl.shape[1] // tp_degree, dim=1)[rank]) + + cpl_shard = sharded_checkpoint[prefix + "cpl.weight"] + cpl = checkpoint[prefix + "cpl.weight"] + assert cpl_shard.shape == (cpl.shape[0] / tp_degree, cpl.shape[1]) + assert torch.equal(cpl_shard, torch.split(cpl, cpl.shape[0] // tp_degree, dim=0)[rank]) + + validate_shard_weight("") + validate_shard_weight("lay1.") if __name__ == "__main__": diff --git a/test/unit_test/utils/test_logger.py b/test/unit_test/utils/test_logger.py new file mode 100644 index 0000000..254dc8d --- /dev/null +++ b/test/unit_test/utils/test_logger.py @@ -0,0 +1,122 @@ +# Standard Library +import unittest +from unittest.mock import patch, MagicMock +import logging + +from neuronx_distributed.utils.logger import get_log_level, get_logger, _rank0_only + + +class TestLogger(unittest.TestCase): 
+ def test_log_level(self): + levels = { + "trace": logging.DEBUG, + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "fatal": logging.FATAL, + "off": logging.FATAL + 1, + "unsupported": ValueError, + } + for value, level in levels.items(): + with patch.dict("os.environ", {"NXD_LOG_LEVEL": value}): + get_log_level.cache_clear() + if not isinstance(level, int): + with self.assertRaises(level): + get_log_level() + else: + self.assertEqual(level, get_log_level()) + + def test_logger(self): + for level in ["info", "off"]: + with patch.dict("os.environ", {"NXD_LOG_LEVEL": level}), patch("logging.getLogger") as mock_logger: + # Arrange + get_log_level.cache_clear() + mock_logger.return_value.initialized = False + lvls = ["debug", "info", "warning", "error", "exception", "fatal", "critical"] + original = {lvl: getattr(mock_logger.return_value, lvl) for lvl in lvls} + # Act + logger = get_logger("some_name") + # Assert + self.assertEqual(logger, mock_logger.return_value) + if level == "off": + logger.setLevel.assert_not_called() + self.assertTrue(logger.disabled) + else: + logger.setLevel.assert_called_once_with(get_log_level()) + for lvl, method in original.items(): + self.assertNotEqual(getattr(logger, lvl), method) + self.assertFalse(logger.propagate) + self.assertTrue(logger.initialized) + + @patch.dict("os.environ", {"NXD_LOG_LEVEL": "info"}) + @patch("logging.getLogger") + def test_logger_skips_initialized(self, mock_logger): + # Arrange + mock_logger.return_value.initialized = True + # Act + logger = get_logger() + # Assert + self.assertEqual(logger, mock_logger.return_value) + logger.addHandler.assert_not_called() + + @patch.dict("os.environ", {"NXD_LOG_LEVEL": "info"}) + @patch("logging.getLogger") + def test_logger_not_rank0_only(self, mock_logger): + # Arrange + lvls = ["debug", "info", "warning", "error", "exception", "fatal", "critical"] + original = {lvl: getattr(mock_logger.return_value, lvl) for lvl in lvls} + # Act + logger = get_logger(rank0_only=False) + # Assert + self.assertEqual(logger, mock_logger.return_value) + for lvl, method in original.items(): + self.assertEqual(getattr(logger, lvl), method) + + @patch("torch.distributed") + def test_rank0_works(self, mock_dist): + # Arrange + fn = MagicMock() + mock_dist.is_initialized.return_value = True + mock_dist.get_rank.return_value = 0 + # Act + wrapped = _rank0_only(fn) + # Assert + self.assertEqual(wrapped(), fn.return_value) + fn.assert_called_once_with() + + @patch.dict("os.environ", {"RANK": "1"}) + @patch("torch.distributed") + def test_rank0_checks_torch_dist(self, mock_dist): + # Arrange + fn = MagicMock() + mock_dist.is_initialized.return_value = True + mock_dist.get_rank.return_value = 42 + # Act + wrapped = _rank0_only(fn) + # Assert + self.assertNotEqual(wrapped(), fn.return_value) + fn.assert_not_called() + + @patch.dict("os.environ", {"RANK": "42"}) + @patch("torch.distributed") + def test_rank0_checks_environ(self, mock_dist): + # Arrange + fn = MagicMock() + mock_dist.is_initialized.return_value = False + # Act + wrapped = _rank0_only(fn) + # Assert + self.assertNotEqual(wrapped(), fn.return_value) + fn.assert_not_called() + + @patch("torch.distributed") + def test_rank0_only_assumes_rank0(self, mock_dist): + # Arrange + fn = MagicMock() + mock_dist.is_initialized.return_value = False + # Act + wrapped = _rank0_only(fn) + # Assert + self.assertEqual(wrapped(), fn.return_value) + fn.assert_called_once_with() diff --git 
a/test/unit_test/utils/test_sampling.py b/test/unit_test/utils/test_sampling.py index f3f89b1..e64f65b 100644 --- a/test/unit_test/utils/test_sampling.py +++ b/test/unit_test/utils/test_sampling.py @@ -31,7 +31,6 @@ def test_multinomial_sampling(self): config.top_k = 3 config.num_beams = 1 sampler = Sampler(config) - neg_inf = -float("inf") torch.random.manual_seed(5) x = torch.rand((2, 100)) sampled = sampler.sample(x) diff --git a/test/unit_test/utils/test_serialization.py b/test/unit_test/utils/test_serialization.py index ae70f26..efd0266 100644 --- a/test/unit_test/utils/test_serialization.py +++ b/test/unit_test/utils/test_serialization.py @@ -7,8 +7,6 @@ from neuronx_distributed.utils.serialization import SerializationManager, TensorMeta -from .. import update_result - class TestSerialization(unittest.TestCase): def test_with_class_type(self): @@ -19,94 +17,82 @@ def __init__(self, x): def increment(self): self.x += 1 - try: - data = A(3.2) - s = SerializationManager() - serialized, tx_list, tensor_meta = s.serialize(data) - obj = s.deserialize(serialized, tx_list) - self.assertTrue(data == obj) - self.assertTrue(len(tx_list) == 0) - self.assertTrue(len(tensor_meta) == 0) - except: - update_result({"inference_success": 0}) - raise + data = A(3.2) + s = SerializationManager() + serialized, tx_list, tensor_meta = s.serialize(data) + obj = s.deserialize(serialized, tx_list) + self.assertTrue(data == obj) + self.assertTrue(len(tx_list) == 0) + self.assertTrue(len(tensor_meta) == 0) def test_with_tensor(self): - try: - data = torch.ones(2, 3) - s = SerializationManager() - serialized, tx_list, tensor_meta = s.serialize(data) - expected_tensor_meta = [ - TensorMeta( - tensor_index=0, - dtype=torch.float32, - shape=torch.Size([2, 3]), - requires_grad=False, - device=device(type="cpu"), - ) - ] - obj = s.deserialize(serialized, tx_list) - self.assertTrue(torch.equal(obj, data)) - self.assertTrue(tensor_meta == expected_tensor_meta) - self.assertTrue(torch.equal(tx_list[0], torch.ones(2, 3))) - except: - update_result({"inference_success": 0}) - raise + data = torch.ones(2, 3) + s = SerializationManager() + serialized, tx_list, tensor_meta = s.serialize(data) + expected_tensor_meta = [ + TensorMeta( + tensor_index=0, + dtype=torch.float32, + shape=torch.Size([2, 3]), + requires_grad=False, + device=device(type="cpu"), + ) + ] + obj = s.deserialize(serialized, tx_list) + self.assertTrue(torch.equal(obj, data)) + self.assertTrue(tensor_meta == expected_tensor_meta) + self.assertTrue(torch.equal(tx_list[0], torch.ones(2, 3))) def test_with_mixed_type(self): class MyClass: pass cls_type = MyClass() - try: - data = { - "a": 1, - "b": [torch.ones([2, 4]), torch.ones([2, 4]), (1, 2)], - "c": (1, 2, torch.tensor(1.0)), - "d": torch.zeros([2, 4]), - "f": cls_type, - } - s = SerializationManager() - serialized, tx_list, tensor_meta = s.serialize(data) - expected_tensor_meta = [ - TensorMeta( - tensor_index=0, - dtype=torch.float32, - shape=torch.Size([2, 4]), - requires_grad=False, - device=device(type="cpu"), - ), - TensorMeta( - tensor_index=1, - dtype=torch.float32, - shape=torch.Size([2, 4]), - requires_grad=False, - device=device(type="cpu"), - ), - TensorMeta( - tensor_index=2, - dtype=torch.float32, - shape=torch.Size([]), - requires_grad=False, - device=device(type="cpu"), - ), - TensorMeta( - tensor_index=3, - dtype=torch.float32, - shape=torch.Size([2, 4]), - requires_grad=False, - device=device(type="cpu"), - ), - ] - expected_tx_list = [torch.ones([2, 4]), torch.ones([2, 4]), 
torch.tensor(1.0), torch.zeros([2, 4])] - obj = s.deserialize(serialized, tx_list) - self.assertTrue(data == obj) - self.assertTrue(tensor_meta == expected_tensor_meta) - for i in range(4): - self.assertTrue(torch.equal(tx_list[i], expected_tx_list[i])) - except: - update_result({"inference_success": 0}) - raise + data = { + "a": 1, + "b": [torch.ones([2, 4]), torch.ones([2, 4]), (1, 2)], + "c": (1, 2, torch.tensor(1.0)), + "d": torch.zeros([2, 4]), + "f": cls_type, + } + s = SerializationManager() + serialized, tx_list, tensor_meta = s.serialize(data) + expected_tensor_meta = [ + TensorMeta( + tensor_index=0, + dtype=torch.float32, + shape=torch.Size([2, 4]), + requires_grad=False, + device=device(type="cpu"), + ), + TensorMeta( + tensor_index=1, + dtype=torch.float32, + shape=torch.Size([2, 4]), + requires_grad=False, + device=device(type="cpu"), + ), + TensorMeta( + tensor_index=2, + dtype=torch.float32, + shape=torch.Size([]), + requires_grad=False, + device=device(type="cpu"), + ), + TensorMeta( + tensor_index=3, + dtype=torch.float32, + shape=torch.Size([2, 4]), + requires_grad=False, + device=device(type="cpu"), + ), + ] + expected_tx_list = [torch.ones([2, 4]), torch.ones([2, 4]), torch.tensor(1.0), torch.zeros([2, 4])] + obj = s.deserialize(serialized, tx_list) + self.assertTrue(data == obj) + self.assertTrue(tensor_meta == expected_tensor_meta) + for i in range(4): + self.assertTrue(torch.equal(tx_list[i], expected_tx_list[i])) if __name__ == "__main__": diff --git a/test/unit_test/utils/test_tensor_utils.py b/test/unit_test/utils/test_tensor_utils.py index 30c97d0..15942c3 100644 --- a/test/unit_test/utils/test_tensor_utils.py +++ b/test/unit_test/utils/test_tensor_utils.py @@ -8,8 +8,6 @@ from neuronx_distributed.utils.tensor_utils import cumsum -from .. import update_result - def get_cumsum_test_cases(): testcase_tensor_shapes = [ @@ -35,18 +33,14 @@ class TestCumSum(unittest.TestCase): @parameterized.expand(get_cumsum_test_cases()) def test_cumsum(self, tensor_shape, dtype): - try: - # Set random seed for reproducibility - torch.manual_seed(tensor_shape[0] * tensor_shape[1]) - # Generate random 0-1 matrix - ip = torch.randint(high=2, size=tensor_shape, dtype=dtype) - op = cumsum(ip) - op_gt = torch.cumsum(ip, dim=0) - # Check that outputs match - torch.testing.assert_close(op, op_gt) - except: - update_result({"inference_success": 0}) - raise + # Set random seed for reproducibility + torch.manual_seed(tensor_shape[0] * tensor_shape[1]) + # Generate random 0-1 matrix + ip = torch.randint(high=2, size=tensor_shape, dtype=dtype) + op = cumsum(ip) + op_gt = torch.cumsum(ip, dim=0) + # Check that outputs match + torch.testing.assert_close(op, op_gt) if __name__ == "__main__": diff --git a/test/unit_test/wrapper/test_model_wrapper.py b/test/unit_test/wrapper/test_model_wrapper.py index 0ad242a..3f0e796 100644 --- a/test/unit_test/wrapper/test_model_wrapper.py +++ b/test/unit_test/wrapper/test_model_wrapper.py @@ -9,8 +9,6 @@ import neuronx_distributed as nxd -from .. 
import update_result - def get_model(): seq_len = 512 @@ -64,52 +62,47 @@ class TestModelWrapper(unittest.TestCase): @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) @patch("torch.distributed.get_rank") def test_model_wrapper(self, rank_mock): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert model.pp_enabled - assert model.dtype == torch.float32 - model_str = str(model) - assert "NxDPPModel" in model_str - assert "NxDCheckpointWrapper" in model_str + pipeline_cuts = [ + "transformer.h.1", + "transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + "output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": ["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) - except: - update_result({"inference_success": 0}) - raise + assert isinstance(model, nxd.trainer.model.NxDModel) + assert model.nxd_config == nxd_config + assert model.pp_enabled + assert model.dtype == torch.float32 + model_str = str(model) + assert "NxDPPModel" in model_str + assert "NxDCheckpointWrapper" in model_str @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( @@ -118,30 +111,25 @@ def test_model_wrapper(self, rank_mock): @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) @patch("torch.distributed.get_rank") def test_model_wrapper_no_pp(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - - assert isinstance(model, nxd.trainer.model.NxDModel) - assert model.nxd_config == nxd_config - assert not model.pp_enabled - assert model.dtype == torch.float32 - model_str = str(model) - assert "NxDPPModel" not in model_str - assert 
"NxDCheckpointWrapper" in model_str + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) - except: - update_result({"inference_success": 0}) - raise + assert isinstance(model, nxd.trainer.model.NxDModel) + assert model.nxd_config == nxd_config + assert not model.pp_enabled + assert model.dtype == torch.float32 + model_str = str(model) + assert "NxDPPModel" not in model_str + assert "NxDCheckpointWrapper" in model_str @patch("neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( @@ -150,26 +138,22 @@ def test_model_wrapper_no_pp(self, rank_mock): @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) @patch("torch.distributed.get_rank") def test_model_wrapper_no_pp_load_state_dict_returns_load_result(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - state_dict = model.state_dict() - load_result = model.load_state_dict(state_dict) - assert load_result is not None - assert len(load_result.missing_keys) == 0 - assert len(load_result.unexpected_keys) == 0 - except: - update_result({"inference_success": 0}) - raise + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) + state_dict = model.state_dict() + load_result = model.load_state_dict(state_dict) + assert load_result is not None + assert len(load_result.missing_keys) == 0 + assert len(load_result.unexpected_keys) == 0 @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( @@ -193,50 +177,45 @@ def test_model_wrapper_no_pp_load_state_dict_returns_load_result(self, rank_mock @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) @patch("torch.distributed.get_rank") def test_model_wrapper_pp_load_state_dict_returns_load_result(self, rank_mock): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) + pipeline_cuts = [ + "transformer.h.1", + 
"transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + "output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": ["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) - state_dict = model.state_dict() - load_result = model.load_state_dict(state_dict) - assert load_result is not None - assert len(load_result.missing_keys) == 0 - assert len(load_result.unexpected_keys) == 0 - - except: - update_result({"inference_success": 0}) - raise + state_dict = model.state_dict() + load_result = model.load_state_dict(state_dict) + assert load_result is not None + assert len(load_result.missing_keys) == 0 + assert len(load_result.unexpected_keys) == 0 @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) @patch( @@ -260,49 +239,44 @@ def test_model_wrapper_pp_load_state_dict_returns_load_result(self, rank_mock): @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) @patch("torch.distributed.get_rank") def test_model_wrapper_return_state_dict_with_prefix(self, rank_mock): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - - prefix = 'model_prefix.' 
- state_dict = model.state_dict(prefix=prefix) - for key in state_dict: - assert key.startswith(prefix) + pipeline_cuts = [ + "transformer.h.1", + "transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + "output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": ["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) - except Exception as e: - update_result({"inference_success": 0}) - raise + prefix = 'model_prefix.' + state_dict = model.state_dict(prefix=prefix) + for key in state_dict: + assert key.startswith(prefix) if __name__ == "__main__": diff --git a/test/unit_test/wrapper/test_nxd_config.py b/test/unit_test/wrapper/test_nxd_config.py index 6ced1bc..f2366cf 100644 --- a/test/unit_test/wrapper/test_nxd_config.py +++ b/test/unit_test/wrapper/test_nxd_config.py @@ -4,8 +4,6 @@ import neuronx_distributed as nxd -from .. import update_result - # Third Party @@ -28,35 +26,31 @@ class TestNxDConfig(unittest.TestCase): ) @patch("torch.distributed.get_rank") def test_neuronx_distributed_config0(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=4, - pipeline_config=None, - optimizer_config=None, - activation_checkpoint_config=TestObject, - pad_model=False, - sequence_parallel=False, - model_init_config=None, - ) - - assert nxd_config["optimizer_config"] == { - "zero_one_enabled": False, - "grad_clipping": True, - "max_grad_norm": 1.0, - } - assert nxd_config["model_init_config"] == { - "sequential_move_factor": 11, - "meta_device_init": False, - "param_init_fn": None, - } - assert nxd_config["pipeline_config"] is None - assert nxd_config["activation_checkpoint_config"] == TestObject - assert nxd_config["pad_model"] == False - assert nxd_config["sequence_parallel"] == False - except: - update_result({"inference_success": 0}) - raise + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=4, + pipeline_config=None, + optimizer_config=None, + activation_checkpoint_config=TestObject, + pad_model=False, + sequence_parallel=False, + model_init_config=None, + ) + + assert nxd_config["optimizer_config"] == { + "zero_one_enabled": False, + "grad_clipping": True, + "max_grad_norm": 1.0, + } + assert nxd_config["model_init_config"] == { + "sequential_move_factor": 11, + "meta_device_init": False, + "param_init_fn": None, + } + assert nxd_config["pipeline_config"] is None + assert nxd_config["activation_checkpoint_config"] == TestObject + assert nxd_config["pad_model"] is False + assert nxd_config["sequence_parallel"] is False @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @@ -72,26 +66,22 @@ def test_neuronx_distributed_config0(self, rank_mock): ) 
@patch("torch.distributed.get_rank") def test_neuronx_distributed_config1(self, rank_mock): - try: - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=4, - pipeline_config=None, - optimizer_config={"zero_one_enabled": True}, - activation_checkpoint_config=None, - pad_model=False, - sequence_parallel=False, - model_init_config=None, - ) - - assert nxd_config["optimizer_config"] == { - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - } - except: - update_result({"inference_success": 0}) - raise + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=4, + pipeline_config=None, + optimizer_config={"zero_one_enabled": True}, + activation_checkpoint_config=None, + pad_model=False, + sequence_parallel=False, + model_init_config=None, + ) + + assert nxd_config["optimizer_config"] == { + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + } @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @@ -107,23 +97,113 @@ def test_neuronx_distributed_config1(self, rank_mock): ) @patch("torch.distributed.get_rank") def test_neuronx_distributed_config2(self, rank_mock): - try: - model_init_config = {"meta_device_init": True, "param_init_fn": lambda x: None} - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=4, - pipeline_config=None, - optimizer_config=None, - activation_checkpoint_config=None, - pad_model=False, - sequence_parallel=False, - model_init_config=model_init_config, - ) - - assert nxd_config["model_init_config"] == model_init_config - except: - update_result({"inference_success": 0}) - raise + model_init_config = {"meta_device_init": True, "param_init_fn": lambda x: None} + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=4, + pipeline_config=None, + optimizer_config=None, + activation_checkpoint_config=None, + pad_model=False, + sequence_parallel=False, + model_init_config=model_init_config, + ) + + assert nxd_config["model_init_config"] == model_init_config + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=4) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) + ) + @patch("torch.distributed.get_rank") + def test_neuronx_distributed_config_check_mixed_precision_setting(self, rank_mock): + mixed_precision_config = { + "use_master_weights": True, + "use_fp32_grad_acc": True, + "use_master_weights_in_ckpt": False, + } + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=4, + pipeline_config=None, + optimizer_config=None, + activation_checkpoint_config=None, + pad_model=False, + sequence_parallel=False, + 
mixed_precision_config=mixed_precision_config + ) + + assert nxd_config["mixed_precision_config"] == mixed_precision_config + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=4) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) + ) + @patch("torch.distributed.get_rank") + def test_neuronx_distributed_config_check_default_mixed_precision_setting_with_optimizer_config_none(self, rank_mock): + mixed_precision_config = { + "use_master_weights": False, + "use_fp32_grad_acc": False, + "use_master_weights_in_ckpt": False, + } + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=4, + pipeline_config=None, + optimizer_config=None, + activation_checkpoint_config=None, + pad_model=False, + sequence_parallel=False, + ) + + assert nxd_config["mixed_precision_config"] == mixed_precision_config + + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) + @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) + @patch("neuronx_distributed.pipeline.model.parallel_state.initialize_model_parallel", MagicMock(return_value=None)) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.model_parallel_is_initialized", MagicMock(return_value=True) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_size", MagicMock(return_value=4) + ) + @patch( + "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) + ) + @patch("torch.distributed.get_rank") + def test_neuronx_distributed_config_check_default_mixed_precision_setting_with_optimizer_config_has_zero(self, rank_mock): + mixed_precision_config = { + "use_master_weights": True, + "use_fp32_grad_acc": True, + "use_master_weights_in_ckpt": False, + } + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=4, + pipeline_config=None, + optimizer_config={"zero_one_enabled": True}, + activation_checkpoint_config=None, + pad_model=False, + sequence_parallel=False, + ) + + assert nxd_config["mixed_precision_config"] == mixed_precision_config @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_size", MagicMock(return_value=8)) @patch("neuronx_distributed.parallel_layers.layers.get_tensor_model_parallel_rank", MagicMock(return_value=1)) @@ -138,27 +218,23 @@ def test_neuronx_distributed_config2(self, rank_mock): "neuronx_distributed.pipeline.model.parallel_state.get_pipeline_model_parallel_rank", MagicMock(return_value=1) ) @patch("torch.distributed.get_rank") - def test_neuronx_distributed_config3(self, rank_mock): - try: - mixed_precision_config = { - "use_master_weights": True, - "use_fp32_grad_acc": True, - "use_master_weights_in_ckpt": False, - } - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - 
pipeline_parallel_size=4, - pipeline_config=None, - optimizer_config=None, - activation_checkpoint_config=None, - pad_model=False, - sequence_parallel=False, - ) - - assert nxd_config["mixed_precision_config"] == mixed_precision_config - except: - update_result({"inference_success": 0}) - raise + def test_neuronx_distributed_config_check_default_mixed_precision_setting_with_optimizer_config_has_no_zero(self, rank_mock): + mixed_precision_config = { + "use_master_weights": False, + "use_fp32_grad_acc": False, + "use_master_weights_in_ckpt": False, + } + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=4, + pipeline_config=None, + optimizer_config={"zero_one_enabled": False}, + activation_checkpoint_config=None, + pad_model=False, + sequence_parallel=False, + ) + + assert nxd_config["mixed_precision_config"] == mixed_precision_config if __name__ == "__main__": diff --git a/test/unit_test/wrapper/test_optimizer_wrapper.py b/test/unit_test/wrapper/test_optimizer_wrapper.py index b764c2e..46a6c4b 100644 --- a/test/unit_test/wrapper/test_optimizer_wrapper.py +++ b/test/unit_test/wrapper/test_optimizer_wrapper.py @@ -9,9 +9,6 @@ import neuronx_distributed as nxd -from .. import update_result - - def get_model(): seq_len = 512 model_config = GPT2Config( @@ -92,56 +89,51 @@ class TestOptimizerWrapper(unittest.TestCase): @patch("neuronx_distributed.utils.model_utils.get_local_world_size", MagicMock(return_value=32)) @patch("torch.distributed.get_rank") def test_optimizer_wrapper(self, rank_mock): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": True, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) - - assert optimizer.nxd_config == nxd_config - assert isinstance(optimizer, nxd.trainer.optimizer.NxDOptimizer) - assert isinstance(optimizer.optimizer, nxd.optimizer.NeuronZero1Optimizer) - assert isinstance(optimizer.optimizer.base_optimizer, torch.optim.AdamW) - assert len(list(model.parameters())) == len(optimizer.params) - assert optimizer.grad_norm is None - - for method in ["step", "zero_grad", "state_dict"]: - getattr(optimizer, method)() - assert getattr(nxd.optimizer.zero_redundancy_optimizer.NeuronZero1Optimizer, method).called - - except: - update_result({"inference_success": 0}) - raise + pipeline_cuts = [ + "transformer.h.1", + "transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + 
"output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": ["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": True, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) + optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) + + assert optimizer.nxd_config == nxd_config + assert isinstance(optimizer, nxd.trainer.optimizer.NxDOptimizer) + assert isinstance(optimizer.optimizer, nxd.optimizer.NeuronZero1Optimizer) + assert isinstance(optimizer.optimizer.base_optimizer, torch.optim.AdamW) + assert len(list(model.parameters())) == len(optimizer.params) + assert optimizer.grad_norm is None + + for method in ["step", "zero_grad", "state_dict"]: + getattr(optimizer, method)() + assert getattr(nxd.optimizer.zero_redundancy_optimizer.NeuronZero1Optimizer, method).called @patch("torch.optim.AdamW.step", MagicMock(return_value=None)) @patch("torch.optim.AdamW.zero_grad", MagicMock(return_value=None)) @@ -177,55 +169,50 @@ def test_optimizer_wrapper(self, rank_mock): @patch("neuronx_distributed.trainer.optimizer.grads.clip_grad_norm", MagicMock(return_value=None)) @patch("torch.distributed.get_rank") def test_optimizer_wrapper_no_zero1(self, rank_mock): - try: - pipeline_cuts = [ - "transformer.h.1", - "transformer.h.2", - "transformer.h.3", - "transformer.h.4", - "transformer.h.5", - "transformer.h.6", - "transformer.h.7", - ] - nxd_config = nxd.neuronx_distributed_config( - tensor_parallel_size=8, - pipeline_parallel_size=8, - pipeline_config={ - "transformer_layer_cls": GPT2Block, - "tracer_cls": "hf", - "num_microbatches": 1, - "output_loss_value_spec": True, - "input_names": ["input_ids", "attention_mask", "labels"], - "pipeline_cuts": pipeline_cuts, - "param_init_fn": None, - "leaf_module_cls": ["GPT2Block"], - "use_zero1_optimizer": True, - "use_optimizer_wrapper": True, - }, - optimizer_config={ - "zero_one_enabled": False, - "grad_clipping": True, - "max_grad_norm": 1.0, - }, - sequence_parallel=True, - activation_checkpoint_config="full", - ) - model = nxd.initialize_parallel_model(nxd_config, get_model) - optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) - - assert optimizer.nxd_config == nxd_config - assert isinstance(optimizer, nxd.trainer.optimizer.NxDOptimizer) - assert isinstance(optimizer.optimizer, torch.optim.AdamW) - assert len(list(model.parameters())) == len(optimizer.params) - assert optimizer.grad_norm is None - - for method in ["step", "zero_grad", "state_dict"]: - getattr(optimizer, method)() - assert getattr(torch.optim.AdamW, method).called, method - - except: - update_result({"inference_success": 0}) - raise + pipeline_cuts = [ + "transformer.h.1", + "transformer.h.2", + "transformer.h.3", + "transformer.h.4", + "transformer.h.5", + "transformer.h.6", + "transformer.h.7", + ] + nxd_config = nxd.neuronx_distributed_config( + tensor_parallel_size=8, + pipeline_parallel_size=8, + pipeline_config={ + "transformer_layer_cls": GPT2Block, + "tracer_cls": "hf", + "num_microbatches": 1, + "output_loss_value_spec": True, + "input_names": ["input_ids", "attention_mask", "labels"], + "pipeline_cuts": pipeline_cuts, + "param_init_fn": None, + "leaf_module_cls": 
["GPT2Block"], + "use_zero1_optimizer": True, + "use_optimizer_wrapper": True, + }, + optimizer_config={ + "zero_one_enabled": False, + "grad_clipping": True, + "max_grad_norm": 1.0, + }, + sequence_parallel=True, + activation_checkpoint_config="full", + ) + model = nxd.initialize_parallel_model(nxd_config, get_model) + optimizer = nxd.initialize_parallel_optimizer(nxd_config, torch.optim.AdamW, model.parameters(), lr=1e-3) + + assert optimizer.nxd_config == nxd_config + assert isinstance(optimizer, nxd.trainer.optimizer.NxDOptimizer) + assert isinstance(optimizer.optimizer, torch.optim.AdamW) + assert len(list(model.parameters())) == len(optimizer.params) + assert optimizer.grad_norm is None + + for method in ["step", "zero_grad", "state_dict"]: + getattr(optimizer, method)() + assert getattr(torch.optim.AdamW, method).called, method if __name__ == "__main__":