Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MoE layer example #303

Closed
wants to merge 13 commits into from
Closed
4 changes: 2 additions & 2 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def expand(a: TensorLike, *shape: int) -> TensorLike:
offset_idx = idx + offset
requested_length = shape[offset_idx]
utils.check(
requested_length == x or x == 1 or requested_length == -1,
requested_length == x or x == 1 or requested_length == -1 or x == -1,
lambda: f"expand: attempting to expand a dimension of length {x}!",
)

Expand Down Expand Up @@ -1166,7 +1166,7 @@ def compute_broadcast_shape(*_shapes):
common_shape[idx] = shape[idx]

utils.check(
(shape[idx] == 1) or (common_shape[idx] == shape[idx]),
(shape[idx] == 1) or (common_shape[idx] == shape[idx]) or (common_shape[idx] == -1),
lambda: f"Attempting to broadcast a dimension of length {shape[idx]}!",
)

Expand Down
2 changes: 1 addition & 1 deletion thunder/core/baseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def check_valid_length(length: int):
"""Validates that an object represents a valid dimension length."""

check_type(length, int)
check(length >= 0, lambda: f"Found invalid length {length}!")
check(length >= -1, lambda: f"Found invalid length {length}!")


def check_valid_shape(shape: tuple[int, ...] | list[int]):
Expand Down
20 changes: 19 additions & 1 deletion thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def register_method(method_name: str, method: Callable, /) -> None:


class PrimIDs(Enum):
NONZERO_TUPLE = auto()
# Unpacking and input validation prims
ASSERT_TENSOR_METADATA = auto()
CHECK_TENSOR_SHAPE_AND_METADATA = auto()
Expand Down Expand Up @@ -2833,7 +2834,7 @@ def broadcast_in_dim_meta(a: TensorProxy, /, shape: Sequence[int], broadcast_dim
lambda: f"One of the broadcast_dimensions={broadcast_dimensions} was {idx}, which is out-of-bounds for a tensor with {len(shape)} dimensions",
)
utils.check(
original_length == 1 or shape[idx] == original_length,
original_length == 1 or shape[idx] == original_length or original_length == -1,
lambda: f"A dimension of length {original_length} cannot be broadcast to a dimension of length {shape[idx]}",
)

Expand Down Expand Up @@ -3557,6 +3558,23 @@ def matmul_meta(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:

matmul = make_prim(PrimIDs.MATMUL, "matmul", meta=matmul_meta)


def nonzero_tuple_meta(
a: TensorProxy,
/,
) -> tuple[TensorProxy, ...]:
# Checks types
utils.check_type(a, TensorProxy)

# Output shape is data dependent
output_shape = (-1,)

# Returns the output tensor
return tuple(TensorProxy(like=a, shape=output_shape, dtype=dtypes.int64) for _ in range(a.ndim))


nonzero_tuple = make_prim(PrimIDs.NONZERO_TUPLE, "nonzero_tuple", meta=nonzero_tuple_meta)

#
# NN prims
#
Expand Down
27 changes: 27 additions & 0 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,26 @@ def _take_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy
register_grad(pids.TAKE, _take_prim_grad)


def _index_add_prim_grad(a: TensorProxy, /, index: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
fwd = prims.index_add(a, index, value, dim)

g = get_grad(fwd)
put_grad(a, g)

if value.ndim > 0:
g = clang.take(g, index, dim)
g = clang.expand(g, value.shape)
else:
g = clang.take(g, clang.squeeze(index, 0), dim)

put_grad(value, g)

return fwd


register_grad(pids.INDEX_ADD, _index_add_prim_grad)


@torchctx
def _gather_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy:
fwd = prims.gather(a, index, dim)
Expand Down Expand Up @@ -1120,6 +1140,13 @@ def _where_prim_grad(pred: Number | TensorProxy, a: Number | TensorProxy, b: Num

register_grad(pids.WHERE, _where_prim_grad)


def _nonzero_tuple_prim_grad(a: TensorProxy) -> tuple[TensorProxy, ...]:
return prims.nonzero_tuple(a)


register_grad(pids.NONZERO_TUPLE, _nonzero_tuple_prim_grad)

#
# Reduction operator grads
#
Expand Down
6 changes: 6 additions & 0 deletions thunder/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ def is_numbertensor(t):
# TODO: maybe generalize to *args like check_same_dtype
# TODO: change to check_same_shape or add check_same_shape variant and make check_same_dtype use the same pattern
def same_shape(a: Sequence[int], b: Sequence[int], /) -> bool:
# Allow for -1 in the shape to represent an unknown dimension
if -1 in a:
a = tuple(x if x != -1 else b[i] for i, x in enumerate(a))
if -1 in b:
b = tuple(x if x != -1 else a[i] for i, x in enumerate(b))
return tuple(a) == tuple(b)


Expand Down Expand Up @@ -1048,6 +1053,7 @@ def find_producer_symbols(trace: TraceCtx, proxies: Sequence[Proxy], stop_proxie
(__b = ltorch.sub(x, y)
# __b = prims.sub(x, y),)
"""
stop_proxies = filter(lambda x: isinstance(x, Proxy), stop_proxies)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
stop_proxies = filter(lambda x: isinstance(x, Proxy), stop_proxies)
stop_proxies = tuple(filter(lambda x: isinstance(x, Proxy), stop_proxies))

trace_producers = producers(trace)
result = set()
queue = list(proxies)
Expand Down
10 changes: 10 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,7 @@ def _addcmul_transform(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value:

clamp = _register_torch_operation("clamp")
where = _register_torch_operation("where")
nonzero = _register_torch_operation("nonzero")
masked_fill = _register_torch_operation("masked_fill", module=torch.Tensor)
tril = _register_torch_operation("tril")

Expand Down Expand Up @@ -1001,6 +1002,15 @@ def _tril_transform(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | N

_register_implementation(prims.where, where, checker=_where_prim_checker)


def _nonzero_tuple_exec_transform(a: torch.Tensor) -> tuple[torch.Tensor, ...]:
return nonzero(a, as_tuple=True)


_register_implementation(
prims.nonzero_tuple, nonzero, checker=_always_executable, execution_transform=_nonzero_tuple_exec_transform
)

_register_implementation(ltorch.clamp, clamp, checker=_always_executable)
_register_implementation(ltorch.masked_fill, masked_fill, checker=_masked_fill_checker)
_register_implementation(ltorch.tril, checker=_tril_checker, execution_transform=_tril_transform)
Expand Down
50 changes: 50 additions & 0 deletions thunder/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,53 @@ def test_hf_bart_self_attn():
tom = thunder.jit(model)
thunder_result = tom(inp, None)
assert_close(torch_result, thunder_result)


@instantiate(dtypes=(thunder.float32,))
def test_llama_moe(executor, device, dtype):
# This test is a modified version of the LLaMAMoE from LitGPT:
# https://github.com/Lightning-AI/litgpt/blob/96836008be96fa2fe5e6909a3fd7a112cc57716e/litgpt/model.py#L325-L349
class Test(nn.Module):
def __init__(self) -> None:
super().__init__()
self.n_expert = 8
self.n_expert_per_token = 2
self.C = 2
self.gate = nn.Linear(self.C, self.n_expert, bias=False)
self.experts = nn.ModuleList(nn.Linear(2, 2, bias=False) for _ in range(self.n_expert))

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
x = x.view(-1, C) # (B*T, C)
router = self.gate(x) # (B*T, n_expert)
probs, indices = torch.topk(router, self.n_expert_per_token) # (B*T, n_expert_per_token)
probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
masks = indices.unsqueeze(-1) == torch.arange(self.n_expert, device=x.device)
masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
y = torch.zeros_like(x) # (B*T, C)
for i in range(self.n_expert):
# NOTE: zip is not working
# See https://github.com/Lightning-AI/lightning-thunder/issues/284
# for (mask, expert) in zip(masks, self.experts):
mask = masks[i]
expert = self.experts[i]
token_idx, expert_idx = torch.where(mask)
# NOTE: probs[token_idx, expert_idx, None] is not working
pprobs = probs[token_idx, expert_idx]
pprobs = pprobs.unsqueeze(-1)
eexpert = expert(x[token_idx])
# NOTE: The following line uses += instead of torch.index_add in the original code
y = torch.index_add(y, 0, token_idx, pprobs * eexpert)
return y.view(B, T, C)

model = Test().to(device=device, dtype=ttorch.to_torch_dtype(dtype))
model = thunder.jit(model, executors=executor.executors_list())

x = torch.randn(2, 3, 2, device=device, dtype=ttorch.to_torch_dtype(dtype))
y = model(x)
assert y.shape == (2, 3, 2)

y.backward(torch.randn_like(y))
assert all(p.grad is not None for p in model.parameters())
print(thunder.last_backward_traces(model)[-1])
print(thunder.last_traces(model)[-1])
9 changes: 9 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,8 @@ def tril(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = Non
def where(
pred: TensorLike, a: None | Number | TensorLike = None, b: None | Number | TensorLike = None, /
) -> TensorLike:
if a is None and b is None:
return prims.nonzero_tuple(pred)
utils.check(
isinstance(a, (Number, TensorProxy)) and isinstance(b, (Number, TensorProxy)),
lambda: f"torch.where() does not support only specifying a condition",
Expand Down Expand Up @@ -1781,6 +1783,13 @@ def convert(x, if_none):
return result


@torchsymbol(torch.nonzero, is_method=True)
def nonzero(a: TensorLike, /, as_tuple: bool = False) -> TensorLike | tuple[TensorLike, ...]:
if as_tuple:
return prims.nonzero_tuple(a)
raise NotImplementedError("torch.nonzero() only supports as_tuple=True")


#
# Reduction operations
#
Expand Down
Loading