Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xtuner/v1/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base_loss_ctx import BaseLossConfig, BaseLossContext, BaseLossKwargs
from .ce_loss import CELossConfig, CELossContext
from .ce_loss import CELossConfig, CELossContext, LMHeadLossContext
from .chunk_loss import ChunkLoss
from .moe_loss import (
BalancingLoss,
Expand Down Expand Up @@ -28,6 +28,7 @@
"BaseLossConfig",
"BaseLossContext",
"BaseLossKwargs",
"LMHeadLossContext",
]

import torch
Expand Down
95 changes: 12 additions & 83 deletions xtuner/v1/loss/base_loss_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,10 @@
from typing import Annotated, Any, Literal, TypeVar

import torch
import torch.distributed as dist
import torch.nn as nn
from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.nn.functional import all_reduce
from typing_extensions import Self

from xtuner.v1.loss.utils import sp_split

from .chunk_loss import ChunkLoss


# Do loss calibration among dp, sp and grad accumulation:
Expand Down Expand Up @@ -46,18 +39,13 @@


class BaseLossKwargs(BaseModel):
"""Everything needed to compute the loss."""

model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True)
shifted_labels: torch.Tensor
"""Everything needed to compute the loss.

def sp_split(self, sp_mesh: DeviceMesh) -> Self:
self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100)
return self
Subclasses should implement sp_split() and to() methods if they contain tensors that need to be split across
sequence parallel mesh or moved to device.
"""

def to(self, device: torch.device | str) -> Self:
self.shifted_labels = self.shifted_labels.to(device)
return self
model_config = ConfigDict(title="loss keyword arguments", extra="forbid", arbitrary_types_allowed=True)

def chunk(self, chunk_size) -> list["BaseLossKwargs"]:
tensor_fields: dict[str, tuple[torch.Tensor, ...]] = {}
Expand Down Expand Up @@ -114,10 +102,13 @@ class BaseLossConfig(BaseModel):
chunk_size: Annotated[int | None, Parameter(help="chunk size when mode is chunk")] = 1024

# Abstract read-only property: the base implementation only raises
# NotImplementedError. Concrete configs are expected to return a
# BaseLossContext subclass (CELossConfig returns CELossContext).
@property
@abstractmethod
def loss_ctx_cls(self) -> type["BaseLossContext"]:
raise NotImplementedError

# TODO: exposing this as a private property may not be a good idea
# Abstract read-only property: the base implementation only raises
# NotImplementedError. Concrete configs are expected to return a
# BaseLossKwargs subclass (CELossConfig returns CELossKwargs).
@property
@abstractmethod
def _loss_kwargs_cls(self) -> type["BaseLossKwargs"]:
raise NotImplementedError

Expand Down Expand Up @@ -160,72 +151,10 @@ def __init__(self, loss_cfg: BaseLossConfig, loss_kwargs: BaseLossKwargs):
self._batch_size = 1

@staticmethod
@abstractmethod
def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]: ...

# Abstract hook that performs the actual loss computation.
# It is called directly by eager_mode and is handed to ChunkLoss.apply by
# chunk_mode, so it must work on whatever loss_kwargs those callers pass.
# Per the annotation it returns (loss, (logits or None, extra-info dict)).
@abstractmethod
def loss_fn(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None,
loss_kwargs: BaseLossKwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
"""Step 2.a and 2.b in the loss calculation."""
...

# Eager path: no chunking, simply forwards all arguments to loss_fn and
# returns its result unchanged.
def eager_mode(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None,
loss_kwargs: BaseLossKwargs,
):
return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs)

# Chunked path: splits loss_kwargs into pieces of loss_cfg.chunk_size and
# lets ChunkLoss.apply drive loss_fn over them.
# Returns (loss, (None, extra_info)): no logits are returned in this mode.
def chunk_mode(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None,
loss_kwargs: BaseLossKwargs,
):
# chunk_size is typed `int | None` on the config, so guard before using it.
assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode"

chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size)
loss, extra_info = ChunkLoss.apply(
hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size
)
# Logits slot is None here; only the loss and extra info come back from ChunkLoss.
return loss, (None, extra_info)

# Entry point for computing the loss from the final hidden states and the
# head weight.
#
# Flow visible in this block:
#   1. Reject a non-None head_bias (NotImplementedError).
#   2. Dispatch to eager_mode when loss_cfg.mode == "eager", otherwise to
#      chunk_mode.
#   3. Wrap extra_info in ModelForwardExtraLogInfo if it is not one already
#      and store a detached copy of the local loss under "local_base_loss".
#   4. If torch.distributed is initialized, SUM-all_reduce the loss over the
#      WORLD group (autograd-aware all_reduce).
#
# Returns (loss, (logits, extra_info)).
def forward(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
# Function-scope import; presumably avoids a circular import with the model package (confirm).
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo

assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward"
if head_bias is not None:
raise NotImplementedError("Loss does not support head_bias yet.")

# Any mode other than "eager" is routed through chunk_mode.
if self.loss_cfg.mode == "eager":
loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
else:
loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)

# TODO: yanhuida, should be removed
if not isinstance(extra_info, ModelForwardExtraLogInfo):
extra_info = ModelForwardExtraLogInfo(extra_info)

# Keep the pre-reduction (rank-local) loss for logging; detached so it carries no graph.
extra_info["local_base_loss"] = loss.detach().clone()

# Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support
if dist.is_initialized():
loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)

return loss, (logits, extra_info)
# Default batching behavior: stamps every context in the list with the list
# length as its _batch_size and returns the same list object (contexts are
# mutated in place). Extra *args/**kwargs are accepted but not used here.
def build_batches(loss_ctx_list: list[_BaseLossContextT], *args, **kwargs) -> list[_BaseLossContextT]:
for ctx in loss_ctx_list:
ctx._batch_size = len(loss_ctx_list)
return loss_ctx_list

@classmethod
def cat(cls: type[_BaseLossContextT], chunks: list[_BaseLossContextT]) -> _BaseLossContextT:
Expand Down
71 changes: 67 additions & 4 deletions xtuner/v1/loss/ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import torch.nn.functional as F
from cyclopts import Parameter
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.nn.functional import all_reduce

from xtuner.v1.loss import BaseLossConfig, BaseLossContext, BaseLossKwargs
from xtuner.v1.loss.chunk_loss import ChunkLoss
from xtuner.v1.utils.device import get_device

# from xtuner.v1.profiler.prober import ProberList
Expand Down Expand Up @@ -37,7 +39,11 @@ class CELossConfig(BaseLossConfig):
def loss_ctx_cls(self) -> type["CELossContext"]:
return CELossContext

def model_post_init(self, __context: Any) -> None:
# Read-only property returning the kwargs class used with this config:
# always CELossKwargs.
@property
def _loss_kwargs_cls(self) -> type["CELossKwargs"]:
return CELossKwargs

def model_post_init(self, _context: Any) -> None:
if self.mode == "liger":
assert self.loss_reduction == "token", "Currently, cannot use liger kernel with sample or square reduction"

Expand Down Expand Up @@ -80,8 +86,16 @@ class CELossKwargs(BaseLossKwargs):
shifted_labels: torch.Tensor
loss_weight: torch.Tensor | None = None

# Replaces self.shifted_labels with the result of the sp_split helper
# (split_dim=1, padding_value=-100) for the given sequence-parallel mesh,
# then returns self. The instance is mutated in place; loss_weight is not
# touched by this method.
# NOTE(review): the module-level `sp_split` helper must be imported in this
# file; that import is not visible in this excerpt, so confirm it exists.
def sp_split(self, sp_mesh: DeviceMesh) -> "CELossKwargs":
self.shifted_labels = sp_split(self.shifted_labels, sp_mesh=sp_mesh, split_dim=1, padding_value=-100)
return self

# Moves self.shifted_labels to `device` in place and returns self.
# NOTE(review): loss_weight is not moved here; confirm that is intended
# (it may be None or created later on the target device).
def to(self, device: torch.device | str) -> "CELossKwargs":
self.shifted_labels = self.shifted_labels.to(device)
return self


class CELossContext(BaseLossContext):
class LMHeadLossContext(BaseLossContext):
"""Cross-entropy loss context for language models.

Args:
Expand Down Expand Up @@ -163,6 +177,7 @@ def build_batches( # type: ignore[override]

for loss_ctx in loss_ctx_list:
loss_ctx._batch_size = len(loss_ctx_list)
assert loss_ctx.loss_kwargs.loss_weight is not None
loss_ctx.loss_kwargs.loss_weight /= global_denominator + 1e-12
return loss_ctx_list

Expand Down Expand Up @@ -195,15 +210,30 @@ def loss_fn(

return loss, (logits, {})

# Eager path: no chunking, forwards all arguments to loss_fn and returns
# its (loss, (logits, extra_info)) result unchanged.
def eager_mode(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None,
loss_kwargs: CELossKwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
return self.loss_fn(hidden_states, head_weight, head_bias, loss_kwargs)

def chunk_mode(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None,
loss_kwargs: CELossKwargs,
):
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
if self.loss_cfg.mode == "chunk":
return super().chunk_mode(hidden_states, head_weight, head_bias, loss_kwargs)
assert self.loss_cfg.chunk_size is not None, "chunk_size must be set in chunk mode"

chunks = loss_kwargs.chunk(self.loss_cfg.chunk_size)
loss, extra_info = ChunkLoss.apply(
hidden_states, head_weight, head_bias, self.loss_fn, chunks, self.loss_cfg.chunk_size
)
return loss, (None, extra_info)
else:
assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode"
shifted_labels = loss_kwargs.shifted_labels # (bs, seq_len)
Expand All @@ -225,3 +255,36 @@ def chunk_mode(
# Read-only accessor for _batch_size (the value assigned in build_batches).
@property
def batch_size(self) -> int:
return self._batch_size

# Entry point for computing the LM-head loss from the final hidden states
# and the head weight.
#
# Flow visible in this block:
#   1. Reject a non-None head_bias (NotImplementedError).
#   2. Dispatch to eager_mode when loss_cfg.mode == "eager", otherwise to
#      chunk_mode.
#   3. Wrap extra_info in ModelForwardExtraLogInfo if it is not one already
#      and store a detached copy of the local loss under "local_base_loss".
#   4. If torch.distributed is initialized, SUM-all_reduce the loss over the
#      WORLD group (autograd-aware all_reduce).
#
# Returns (loss, (logits, extra_info)).
# NOTE(review): this uses `dist` (torch.distributed); the corresponding
# import is not visible in this excerpt, so confirm it exists in this module.
def forward(
self,
hidden_states: torch.Tensor,
head_weight: torch.Tensor,
head_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor | None, dict[str, Any]]]:
# Function-scope import; presumably avoids a circular import with the model package (confirm).
from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo

assert self.loss_kwargs is not None, "loss_kwargs must be set before calling forward"
if head_bias is not None:
raise NotImplementedError("Loss does not support head_bias yet.")

# Any mode other than "eager" is routed through chunk_mode.
if self.loss_cfg.mode == "eager":
loss, (logits, extra_info) = self.eager_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)
else:
loss, (logits, extra_info) = self.chunk_mode(hidden_states, head_weight, head_bias, self.loss_kwargs)

# TODO: yanhuida, should be removed
if not isinstance(extra_info, ModelForwardExtraLogInfo):
extra_info = ModelForwardExtraLogInfo(extra_info)

# Keep the pre-reduction (rank-local) loss for logging; detached so it carries no graph.
extra_info["local_base_loss"] = loss.detach().clone()

# Step 2.c in the loss calculation: reduce the loss over all ranks using all_reduce with autograd support
if dist.is_initialized():
loss = all_reduce(loss, op=dist.ReduceOp.SUM, group=dist.group.WORLD)

return loss, (logits, extra_info)


# Deprecated: Use LMHeadLossContext instead. Will be removed in version 1.1.0
CELossContext = LMHeadLossContext
9 changes: 5 additions & 4 deletions xtuner/v1/loss/moe_loss.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Annotated, Any, Literal
from typing import Annotated, Literal

import torch
import torch.nn as nn
from cyclopts import Parameter
from pydantic import BaseModel, ConfigDict
from torch import distributed as dist
from torch.distributed._functional_collectives import all_reduce
from torch.distributed.device_mesh import DeviceMesh

from xtuner.v1.utils.device import get_device

Expand Down Expand Up @@ -223,7 +222,7 @@ def forward(

tokens_per_expert_global = tokens_per_expert.to(router_weights.dtype) # (nlayers, ne)
if self.loss_cfg.balancing_loss_global_average and dist.is_initialized():
tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD)
tokens_per_expert_global = all_reduce(tokens_per_expert_global, "sum", dist.group.WORLD) # type: ignore
tokens_global = tokens_per_expert_global.sum(-1) # (nlayers, )
seqlen_global = tokens_global // num_experts_per_tok
routing_weights_sum_global = all_reduce_autograd(router_weights.sum(dim=1), "sum", dist.group.WORLD)
Expand Down Expand Up @@ -327,7 +326,9 @@ def forward(self, router_logits: torch.Tensor) -> torch.Tensor:
if self.loss_cfg.z_loss_global_average and dist.is_initialized():
unmasked_num = router_logits.shape[1]
unmasked_num_rank = torch.tensor(unmasked_num, device=router_logits.device, dtype=torch.int64)
unmasked_num_global = all_reduce(unmasked_num_rank, "sum", dist.group.WORLD)
group = dist.group.WORLD
assert group is not None
unmasked_num_global = all_reduce(unmasked_num_rank, "sum", group)
world_size = dist.get_world_size()
loss = loss * unmasked_num * world_size / unmasked_num_global

Expand Down
16 changes: 7 additions & 9 deletions xtuner/v1/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain
from pathlib import Path
from shutil import copy, copytree
from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, TypedDict, cast
from typing import Annotated, Any, Generator, Iterable, Literal, Mapping, Sequence, cast

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -620,7 +620,8 @@ def build_loss_ctx_batch(
if lm_loss_ctx_list is not None:
loss_ctx_cls = lm_loss_ctx_list[0].__class__
lm_loss_ctx_list = loss_ctx_cls.build_batches(
lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh)
lm_loss_ctx_list, cu_seq_lens_list=cu_seq_lens_list, sp_mesh=sp_mesh
)

if lm_loss_ctx_list is not None:
for i, lm_loss_ctx in enumerate(lm_loss_ctx_list):
Expand Down Expand Up @@ -1714,10 +1715,7 @@ def _collect_full_state_dict(self, module: nn.Module):
return ret

def _build_loss_ctx(
self,
loss_ctx_cfg: BaseLossConfig | None,
data_batch: list[dict],
sp_mesh: DeviceMesh | None
self, loss_ctx_cfg: BaseLossConfig | None, data_batch: list[dict], sp_mesh: DeviceMesh | None
) -> list[BaseLossContext] | None:
if loss_ctx_cfg is None:
return None
Expand All @@ -1728,9 +1726,9 @@ def _build_loss_ctx(
if first_loss_ctx is None:
return None
else:
ret = [first_loss_ctx] + [
loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]]
return ret
ret = [first_loss_ctx] + [loss_ctx_cfg.build(data=data, sp_mesh=sp_mesh) for data in data_batch[1:]]
return ret # type: ignore[return-value]

# NOTE: Add this overload for inferring the return type for easier type checking and using
@overload # type: ignore
def __call__( # type: ignore
Expand Down
8 changes: 4 additions & 4 deletions xtuner/v1/model/dense/dense.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Self, cast, Literal
from typing import Self, cast

import torch
import torch.distributed as dist
Expand All @@ -20,7 +20,7 @@
from xtuner.v1.config import FSDPConfig
from xtuner.v1.data_proto import SequenceContext
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.loss import CELossContext, BaseLossContext
from xtuner.v1.loss import BaseLossContext, CELossContext
from xtuner.v1.model.base import (
DEFAULT_FLOAT8_CFG,
BaseModel,
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(self, config: TransformerConfig):
def forward(
self,
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
loss_ctx: dict[Literal["lm"], BaseLossContext] | None = None,
loss_ctx: dict[str, BaseLossContext | list[BaseLossContext]] | None = None,
) -> ModelOutputs:
input_ids = seq_ctx.input_ids
position_ids = seq_ctx.position_ids
Expand Down Expand Up @@ -117,7 +117,7 @@ def forward(
output["logits"] = logits
else:
# Training mode
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"])
loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx["lm"]) # type: ignore[call-overload]
output["loss"] = loss
output["logits"] = logits
output["extra_info"] = extra_info
Expand Down
Loading
Loading