diff --git a/xtuner/v1/loss/__init__.py b/xtuner/v1/loss/__init__.py index 1aa575601..e54e39c54 100644 --- a/xtuner/v1/loss/__init__.py +++ b/xtuner/v1/loss/__init__.py @@ -1,5 +1,5 @@ from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs -from .ce_loss import CELossConfig, CELossContext +from .ce_loss import CELossConfig, CELossContext, LMHeadLossContext from .chunk_loss import ChunkLoss from .moe_loss import ( BalancingLoss, @@ -28,6 +28,7 @@ "BaseLossConfig", "BaseLossContext", "BaseLossKwargs", + "LMHeadLossContext", ] import torch diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py index 860873c41..531a6dfda 100644 --- a/xtuner/v1/loss/base_loss_ctx.py +++ b/xtuner/v1/loss/base_loss_ctx.py @@ -3,17 +3,10 @@ from typing import Annotated, Any, Literal, TypeVar import torch -import torch.distributed as dist import torch.nn as nn from cyclopts import Parameter from pydantic import BaseModel, ConfigDict from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.nn.functional import all_reduce -from typing_extensions import Self - -from xtuner.v1.loss.utils import sp_split - -from .chunk_loss import ChunkLoss # Do loss calibration among dp, sp and grad accumulation: @@ -46,18 +39,13 @@ class BaseLossKwargs(BaseModel): - """Everything needed to compute the loss.""" - - model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) - shifted_labels: torch.Tensor + """Everything needed to compute the loss. - def sp_split(self, sp_mesh: DeviceMesh) -> Self: - self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100) - return self + Subclasses should implement sp_split() and to() methods if they contain tensors that need to be split across + sequence parallel mesh or moved to device. + """ - def to(self, device: torch.device | str) -> Self: - self.shifted_labels = self.shifted_labels.to(device) - return self + model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True) def chunk(self, chunk_size) -> list["BaseLossKwargs"]: tensor_fields: dict[str, tuple[torch.Tensor, ...]] = {} @@ -114,10 +102,13 @@ class BaseLossConfig(BaseModel): chunk_size: Annotated[int | None, Parameter(help="chunk size when mode is chunk")] = 1024 @property + @abstractmethod def loss_ctx_cls(self) -> type["BaseLossContext"]: raise NotImplementedError + # TODO: private property maybe not a good idea @property + @abstractmethod def _loss_kwargs_cls(self) -> type["BaseLossKwargs"]: raise NotImplementedError @@ -160,72 +151,10 @@ def __init__(self, loss_cfg: BaseLossConfig, loss_kwargs: BaseLossKwargs): self._batch_size = 1 @staticmethod - @abstractmethod - def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: ... - - @abstractmethod - def loss_fn( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None, - loss_kwargs: BaseLossKwargs, - ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: - """Step 2.a and 2.b in the loss calculation.""" - ... - - def eager_mode( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None, - loss_kwargs: BaseLossKwargs, - ): - return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs) - - def chunk_mode( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None, - loss_kwargs: BaseLossKwargs, - ): - assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode" - - chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size) - loss, extra_info = ChunkLoss.apply( - hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size - ) - return loss, (None, extra_info) - - def forward( - self, - hidden_states: torch.Tensor, - head_weight: torch.Tensor, - head_bias: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: - from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo - - assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward" - if head_bias is not None: - raise NotImplementedError("Loss does not support head_bias yet.") - - if self.loss_cfg.mode == "eager": - loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) - else: - loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) - - # TODO: yanhuida, should be removed - if not isinstance(extra_info, ModelForwardExtraLogInfo): - extra_info = ModelForwardExtraLogInfo(extra_info) - - extra_info["local_base_loss"] = loss.detach().clone() - - # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support - if dist.is_initialized(): - loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) - - return loss, (logits, extra_info) + def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: + for ctx in loss_ctx_list: + ctx._batch_size = len(loss_ctx_list) + return loss_ctx_list @classmethod def cat(cls: type[_BaseLossContextT], chunks: list[_BaseLossContextT]) -> _BaseLossContextT: diff --git a/xtuner/v1/loss/ce_loss.py b/xtuner/v1/loss/ce_loss.py index da305783e..4c29f58d1 100644 --- a/xtuner/v1/loss/ce_loss.py +++ b/xtuner/v1/loss/ce_loss.py @@ -6,8 +6,10 @@ import torch.nn.functional as F from cyclopts import Parameter from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.nn.functional import all_reduce from xtuner.v1.loss import BaseLossConfig, BaseLossContext, BaseLossKwargs +from xtuner.v1.loss.chunk_loss import ChunkLoss from xtuner.v1.utils.device import get_device # from xtuner.v1.profiler.prober import ProberList @@ -37,7 +39,11 @@ class CELossConfig(BaseLossConfig): def loss_ctx_cls(self) -> type["CELossContext"]: return CELossContext - def model_post_init(self, __context: Any) -> None: + @property + def _loss_kwargs_cls(self) -> type["CELossKwargs"]: + return CELossKwargs + + def model_post_init(self, _context: Any) -> None: if self.mode == "liger": assert self.loss_reduction == "token", "Currently, cannot use liger kernel with sample or square reduction" @@ -80,8 +86,16 @@ class CELossKwargs(BaseLossKwargs): shifted_labels: torch.Tensor loss_weight: torch.Tensor | None = None + def sp_split(self, sp_mesh: DeviceMesh) -> "CELossKwargs": + self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100) + return self + + def to(self, device: torch.device | str) -> "CELossKwargs": + self.shifted_labels = self.shifted_labels.to(device) + return self + -class CELossContext(BaseLossContext): +class LMHeadLossContext(BaseLossContext): """Cross-entropy loss context for language models. Args: @@ -163,6 +177,7 @@ def build_batches( # type: ignore[override] for loss_ctx in loss_ctx_list: loss_ctx._batch_size = len(loss_ctx_list) + assert loss_ctx.loss_kwargs.loss_weight is not None loss_ctx.loss_kwargs.loss_weight /= global_denominator + 1e-12 return loss_ctx_list @@ -195,15 +210,30 @@ def loss_fn( return loss, (logits, {}) + def eager_mode( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None, + loss_kwargs: CELossKwargs, + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs) + def chunk_mode( self, hidden_states: torch.Tensor, head_weight: torch.Tensor, head_bias: torch.Tensor | None, loss_kwargs: CELossKwargs, - ): + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: if self.loss_cfg.mode == "chunk": - return super().chunk_mode(hidden_states, head_weight, head_bias, loss_kwargs) + assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode" + + chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size) + loss, extra_info = ChunkLoss.apply( + hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size + ) + return loss, (None, extra_info) else: assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode" shifted_labels = loss_kwargs.shifted_labels # (bs, seq_len) @@ -225,3 +255,36 @@ def chunk_mode( @property def batch_size(self) -> int: return self._batch_size + + def forward( + self, + hidden_states: torch.Tensor, + head_weight: torch.Tensor, + head_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]: + from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo + + assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward" + if head_bias is not None: + raise NotImplementedError("Loss does not support head_bias yet.") + + if self.loss_cfg.mode == "eager": + loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) + else: + loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs) + + # TODO: yanhuida, should be removed + if not isinstance(extra_info, ModelForwardExtraLogInfo): + extra_info = ModelForwardExtraLogInfo(extra_info) + + extra_info["local_base_loss"] = loss.detach().clone() + + # Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support + if dist.is_initialized(): + loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD) + + return loss, (logits, extra_info) + + +# Deprecated: Use LMHeadLossContext instead. Will be removed in version 1.1.0 +CELossContext = LMHeadLossContext diff --git a/xtuner/v1/loss/moe_loss.py b/xtuner/v1/loss/moe_loss.py index bd3d776ef..b6cc8ccfb 100644 --- a/xtuner/v1/loss/moe_loss.py +++ b/xtuner/v1/loss/moe_loss.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, Literal +from typing import Annotated, Literal import torch import torch.nn as nn @@ -6,7 +6,6 @@ from pydantic import BaseModel, ConfigDict from torch import distributed as dist from torch.distributed._functional_collectives import all_reduce -from torch.distributed.device_mesh import DeviceMesh from xtuner.v1.utils.device import get_device @@ -223,7 +222,7 @@ def forward( tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne) if self.loss_cfg.balancing_loss_global_average and dist.is_initialized(): - tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) + tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) # type: ignore tokens_global = tokens_per_expert_global.sum(-1) # (nlayers, ) seqlen_global = tokens_global // num_experts_per_tok routing_weights_sum_global = all_reduce_autograd(router_weights.sum(dim=1), "sum", dist.group.WORLD) @@ -327,7 +326,9 @@ def forward(self, router_logits: torch.Tensor) -> torch.Tensor: if self.loss_cfg.z_loss_global_average and dist.is_initialized(): unmasked_num = router_logits.shape[1] unmasked_num_rank = torch.tensor(unmasked_num, device=router_logits.device, dtype=torch.int64) - unmasked_num_global = all_reduce(unmasked_num_rank, "sum", dist.group.WORLD) + group = dist.group.WORLD + assert group is not None + unmasked_num_global = all_reduce(unmasked_num_rank, "sum", group) world_size = dist.get_world_size() loss = loss * unmasked_num * world_size / unmasked_num_global diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py index cb3eb3ae4..24e2e5ca6 100644 --- a/xtuner/v1/model/base.py +++ b/xtuner/v1/model/base.py @@ -8,7 +8,7 @@ from itertools import chain from pathlib import Path from shutil import copy, copytree -from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, TypedDict, cast +from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, cast import torch import torch.distributed as dist @@ -620,7 +620,8 @@ def build_loss_ctx_batch( if lm_loss_ctx_list is not None: loss_ctx_cls = lm_loss_ctx_list[0].__class__ lm_loss_ctx_list = loss_ctx_cls.build_batches( - lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh) + lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh + ) if lm_loss_ctx_list is not None: for i, lm_loss_ctx in enumerate(lm_loss_ctx_list): @@ -1714,10 +1715,7 @@ def _collect_full_state_dict(self, module: nn.Module): return ret def _build_loss_ctx( - self, - loss_ctx_cfg: BaseLossConfig | None, - data_batch: list[dict], - sp_mesh: DeviceMesh | None + self, loss_ctx_cfg: BaseLossConfig | None, data_batch: list[dict], sp_mesh: DeviceMesh | None ) -> list[BaseLossContext] | None: if loss_ctx_cfg is None: return None @@ -1728,9 +1726,9 @@ def _build_loss_ctx( if first_loss_ctx is None: return None else: - ret = [first_loss_ctx] + [ - loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]] - return ret + ret = [first_loss_ctx] + [loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]] + return ret # type: ignore[return-value] + # NOTE: Add this overload for inferring the return type for easier type checking and using @overload # type: ignore def __call__( # type: ignore diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index c0bc66cd6..ee54f0bb2 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from pathlib import Path -from typing import Self, cast, Literal +from typing import Self, cast import torch import torch.distributed as dist @@ -20,7 +20,7 @@ from xtuner.v1.config import FSDPConfig from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.float8_handler import Float8Handler -from xtuner.v1.loss import CELossContext, BaseLossContext +from xtuner.v1.loss import BaseLossContext, CELossContext from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, BaseModel, @@ -79,7 +79,7 @@ def __init__(self, config: TransformerConfig): def forward( self, seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch - loss_ctx: dict[Literal["lm"], BaseLossContext] | None = None, + loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None, ) -> ModelOutputs: input_ids = seq_ctx.input_ids position_ids = seq_ctx.position_ids @@ -117,7 +117,7 @@ def forward( output["logits"] = logits else: # Training mode - loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) + loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload] output["loss"] = loss output["logits"] = logits output["extra_info"] = extra_info diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py index 1657f4dbe..e9aba4817 100644 --- a/xtuner/v1/model/moe/moe.py +++ b/xtuner/v1/model/moe/moe.py @@ -8,7 +8,6 @@ import torch.distributed as dist import torch.nn.functional as F from cyclopts import Parameter -from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict from torch import nn from torch.distributed._functional_collectives import all_reduce @@ -34,8 +33,6 @@ LMHeadLossContext, ZLossConfig, ZLossContext, - ZLossKwargs, - BaseLossContext, ) from xtuner.v1.model.base import ( DEFAULT_FLOAT8_CFG, @@ -290,11 +287,11 @@ def update_bias(self, total_expert_counts_pre_iter, expected_loads): e_score_correction_bias.add_(updates) - def build_loss_ctx_batch( + def build_loss_ctx_batch( # type: ignore[override] self, data_batch: list["ColateItem"], sp_mesh: DeviceMesh | None = None, - ) -> list[MoELossContextDict]: + ) -> list[MoELossContextDict]: # type: ignore[override] """Build and calibrate loss contexts for MoE model. Args: diff --git a/xtuner/v1/module/lm_head/lm_head.py b/xtuner/v1/module/lm_head/lm_head.py index 92cd96c2a..3e8e48350 100644 --- a/xtuner/v1/module/lm_head/lm_head.py +++ b/xtuner/v1/module/lm_head/lm_head.py @@ -6,7 +6,7 @@ from torch.distributed.tensor import DTensor from typing_extensions import overload -from xtuner.v1.loss import CELossContext +from xtuner.v1.loss import LMHeadLossContext Loss: TypeAlias = torch.Tensor @@ -25,11 +25,11 @@ def forward( @overload # type: ignore[override] def forward( - self, hidden_states: HiddenStates, loss_ctx: CELossContext + self, hidden_states: HiddenStates, loss_ctx: LMHeadLossContext ) -> tuple[Loss, tuple[Logits | None, dict[str, Any]]]: ... def forward( # type: ignore[override] - self, hidden_states: torch.Tensor, loss_ctx: CELossContext | None = None + self, hidden_states: torch.Tensor, loss_ctx: LMHeadLossContext | None = None ) -> tuple[Loss | None, tuple[Logits | None, dict[str, Any]]]: """Forward pass of the language model head.""" if isinstance(self.weight, DTensor): @@ -55,7 +55,7 @@ def __call__( @overload # type: ignore def __call__( - self, hidden_states: HiddenStates, loss_ctx: CELossContext + self, hidden_states: HiddenStates, loss_ctx: LMHeadLossContext ) -> tuple[Loss, tuple[Logits | None, dict[str, Any]]]: ... __call__ = nn.Module.__call__ diff --git a/xtuner/v1/rl/base/loss.py b/xtuner/v1/rl/base/loss.py index 5767e7a76..006f30ca7 100644 --- a/xtuner/v1/rl/base/loss.py +++ b/xtuner/v1/rl/base/loss.py @@ -4,8 +4,7 @@ from torch.distributed.device_mesh import DeviceMesh from typing_extensions import Self -from xtuner.v1.loss import BaseLossConfig, BaseLossKwargs -from xtuner.v1.loss.base_loss_ctx import BaseLossContext +from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext, CELossKwargs from xtuner.v1.loss.utils import sp_gather, sp_split from xtuner.v1.utils.device import get_device @@ -24,7 +23,7 @@ def compute_kl_loss_weight( return kl_loss_weight -class BaseRLLossConfig(BaseLossConfig): +class BaseRLLossConfig(CELossConfig): """Base configuration for reinforcement learning loss functions in XTuner RL. @@ -142,7 +141,7 @@ def build( return LossContext(self, loss_kwargs) -class BaseRLLossKwargs(BaseLossKwargs): +class BaseRLLossKwargs(CELossKwargs): """Keyword arguments for reinforcement learning loss computation. Args: @@ -166,7 +165,9 @@ class BaseRLLossKwargs(BaseLossKwargs): is_weights: torch.Tensor | None = None def sp_split(self, sp_mesh: DeviceMesh) -> Self: - self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100) + # Call parent class to handle shifted_labels + super().sp_split(sp_mesh) + # Handle RL-specific fields self.advantages = sp_split(self.advantages, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0) if self.rollout_logprobs is not None: self.rollout_logprobs = sp_split(self.rollout_logprobs, sp_mesh=sp_mesh, split_dim=1, padding_value=0.0) @@ -180,7 +181,9 @@ def sp_split(self, sp_mesh: DeviceMesh) -> Self: return self def to(self, device: torch.device | str) -> Self: - self.shifted_labels = self.shifted_labels.to(device) + # Call parent class to handle shifted_labels + super().to(device) + # Handle RL-specific fields self.advantages = self.advantages.to(device) if self.old_logprobs is not None: self.old_logprobs = self.old_logprobs.to(device) @@ -199,9 +202,9 @@ def to(self, device: torch.device | str) -> Self: return self -class BaseRLLossContext(BaseLossContext): - loss_cfg: BaseRLLossConfig - loss_kwargs: BaseRLLossKwargs +class BaseRLLossContext(CELossContext): + loss_cfg: BaseRLLossConfig # type: ignore[assignment] + loss_kwargs: BaseRLLossKwargs # type: ignore[assignment] def compute_rollout_is( self, sp_mesh: DeviceMesh, num_tokens: torch.Tensor