Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
77dcd8f
add max_value_in_b_seq_len in deepseekv2 inferstate.
hiworldwzj Mar 31, 2025
c3618ce
feat: add a triton version for add_in_place
Mar 31, 2025
86b2e3e
fix bug for release infer req.
hiworldwzj Mar 31, 2025
8265f47
tma align kernel
shihaobai Mar 31, 2025
05d2bcb
Merge branch 'prefill_overlap' of https://github.com/ModelTC/lightllm…
shihaobai Mar 31, 2025
1581c13
remove
shihaobai Mar 31, 2025
34e3e74
update tma
shihaobai Apr 1, 2025
7b9fafb
fix: add assert in add_in_place
Apr 1, 2025
8c62384
import repeat copy.
hiworldwzj Apr 1, 2025
2cb6219
revert cat & repeat
shihaobai Apr 1, 2025
4c23310
update overlap
shihaobai Apr 1, 2025
7434b39
update overlap
shihaobai Apr 1, 2025
bef9e7b
feat: disable add_in_place
Apr 2, 2025
ed94d6f
opt: coordinate dispatch position
Apr 2, 2025
6ff1049
update test
shihaobai Apr 2, 2025
e800cc0
tune deep_gemm sms
shihaobai Apr 2, 2025
3522e93
update dependency
shihaobai Apr 2, 2025
faf076a
update dependency
shihaobai Apr 2, 2025
6b9d985
update
shihaobai Apr 2, 2025
428f2b5
fix
shihaobai Apr 2, 2025
1b3d4ff
back move ffn.
Apr 2, 2025
9f9a7d3
Merge remote-tracking branch 'origin/main' into prefill_overlap
Apr 2, 2025
5ca9695
fix
hiworldwzj Apr 2, 2025
456921f
add disable_aggressive_schedule cli.
Apr 2, 2025
d6b63ca
update api
shihaobai Apr 2, 2025
a20fdd0
Merge branch 'prefill_overlap' of https://github.com/ModelTC/lightllm…
shihaobai Apr 2, 2025
a2139b8
fix start check.
Apr 2, 2025
94d7045
Merge remote-tracking branch 'origin/prefill_overlap' into prefill_ov…
Apr 2, 2025
f670748
fix typing.
Apr 2, 2025
3d19cbd
fix
Apr 2, 2025
631b6a8
fix
hiworldwzj Apr 2, 2025
48ff875
fix
shihaobai Apr 2, 2025
43553be
improve scatter_2.
Apr 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from lightllm.distributed import dist_group_manager
from lightllm.common.fused_moe.topk_select import select_experts
from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import (
per_token_group_quant_fp8,
tma_align_input_scale,
)
from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
from lightllm.utils.log_utils import init_logger

Expand Down Expand Up @@ -152,7 +155,7 @@ def low_latency_dispatch(
)
return recv_x, masked_m, topk_idx, topk_weights, handle, hook

def dispatch(
def select_experts_and_quant_input(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand All @@ -168,9 +171,6 @@ def dispatch(
num_expert_group=self.n_group,
scoring_func=self.scoring_func,
)
topk_idx = topk_idx.to(torch.long)
buffer = dist_group_manager.ep_buffer
num_experts = self.n_routed_experts
M, K = hidden_states.shape
w1, w1_scale = self.w1
block_size_k = 0
Expand All @@ -180,7 +180,17 @@ def dispatch(
input_scale = torch.empty((M, K // block_size_k), dtype=torch.float32, device=hidden_states.device)
qinput_tensor = torch.empty((M, K), dtype=w1.dtype, device=hidden_states.device)
per_token_group_quant_fp8(hidden_states, block_size_k, qinput_tensor, input_scale)
return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale)

def dispatch(
self,
qinput_tensor: Tuple[torch.Tensor],
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
overlap_event: Optional[Any] = None,
):
buffer = dist_group_manager.ep_buffer
num_experts = self.n_routed_experts
# get_dispatch_layout
(
num_tokens_per_rank,
Expand All @@ -189,16 +199,10 @@ def dispatch(
is_token_in_rank,
previous_event,
) = buffer.get_dispatch_layout(
topk_idx, num_experts, previous_event=None, async_finish=True, allocate_on_comm_stream=False
topk_idx, num_experts, previous_event=overlap_event, async_finish=True, allocate_on_comm_stream=True
)

# normal dispatch
# recv_x [recive_num_tokens, hidden] recv_x_scale [recive_num_tokens, hidden // block_size]
# recv_topk_idx [recive_num_tokens, topk_num]
# recv_topk_weights [recive_num_tokens, topk_num]
# num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch(
(qinput_tensor, input_scale),
qinput_tensor,
topk_idx=topk_idx,
topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank,
Expand All @@ -207,14 +211,14 @@ def dispatch(
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=True,
allocate_on_comm_stream=False,
allocate_on_comm_stream=True,
expert_alignment=128,
)

def hook():
event.current_stream_wait()

return qinput_tensor, recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook

def masked_group_gemm(
self, recv_x: Tuple[torch.Tensor], masked_m: torch.Tensor, dtype: torch.dtype, expected_m: int
Expand All @@ -228,35 +232,35 @@ def prefilled_group_gemm(
num_recv_tokens_per_expert_list,
recv_x: Tuple[torch.Tensor],
recv_topk_idx: torch.Tensor,
hidden_states: torch.Tensor,
qinput_tensor: torch.Tensor,
recv_topk_weights: torch.Tensor,
hidden_dtype=torch.bfloat16,
):
device = recv_x[0].device
w1, w1_scale = self.w1
w2, w2_scale = self.w2
_, K = hidden_states.shape
_, K = recv_x[0].shape
_, N, _ = w1.shape
# scatter
all_tokens = sum(num_recv_tokens_per_expert_list) # calcu padding all nums.
# gather_out shape [recive_num_tokens, hidden]
gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
gather_out = torch.empty_like(recv_x[0], device=device, dtype=hidden_dtype)
if all_tokens > 0:
input_tensor = (
torch.empty((all_tokens, K), device=hidden_states.device, dtype=qinput_tensor.dtype),
torch.empty((all_tokens, K // 128), device=hidden_states.device, dtype=torch.float32),
)
input_tensor = [
torch.empty((all_tokens, K), device=device, dtype=recv_x[0].dtype),
torch.empty((all_tokens, K // 128), device=device, dtype=torch.float32),
]
# when m_indices is filled ok.
# m_indices show token use which expert, example, [0, 0, 0, 0, .... 1, 1, 1, 1,...., cur_expert_num - 1, ..]
# the count of 0 is num_recv_tokens_per_expert_list[0], the count of 1 is num_recv_tokens_per_expert_list[1]
# ...
m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32)
# output_index shape [recive_num_tokens, topk_num]
# output_index use to show the token index in input_tensor
output_index = torch.empty_like(recv_topk_idx)

num_recv_tokens_per_expert = torch.tensor(
num_recv_tokens_per_expert_list, device=hidden_states.device, dtype=torch.int32
)
num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu"
).cuda(non_blocking=True)

expert_start_loc = torch.empty_like(num_recv_tokens_per_expert)

Expand All @@ -271,26 +275,25 @@ def prefilled_group_gemm(
m_indices,
output_index,
)

input_tensor[1] = tma_align_input_scale(input_tensor[1])
# groupgemm (contiguous layout)
gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype)
gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype)

deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)

# silu_and_mul_fwd + quant
# TODO fused kernel
silu_out = torch.empty((all_tokens, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype)

silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out)
qsilu_out, qsilu_out_scale = tma_aligned_quantize(silu_out)

# groupgemm (contiguous layout)
gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype)
gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)

deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices
)

# gather and local reduce
ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)

Expand All @@ -312,15 +315,16 @@ def combine(
self,
gemm_out_b: torch.Tensor,
handle: Any,
overlap_event: Optional[Any] = None,
):
# normal combine
combined_x, _, event = dist_group_manager.ep_buffer.combine(
gemm_out_b,
handle,
topk_weights=None,
async_finish=True,
previous_event=None,
allocate_on_comm_stream=False,
previous_event=overlap_event,
allocate_on_comm_stream=True,
)

def hook():
Expand Down
38 changes: 38 additions & 0 deletions lightllm/common/basemodel/triton_kernel/add_in_place.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _add_in_place(
    input_ptr,
    other_ptr,
    n_elements,
    alpha,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise in-place update: input[i] += other[i] * alpha.
    # Each program instance handles one contiguous BLOCK_SIZE chunk.
    prog_id = tl.program_id(axis=0)
    idx = prog_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(input_ptr + idx, mask=in_bounds)
    rhs = tl.load(other_ptr + idx, mask=in_bounds)
    tl.store(input_ptr + idx, lhs + rhs * alpha, mask=in_bounds)


@torch.no_grad()
def add_in_place(input: torch.Tensor, other: torch.Tensor, *, alpha=1):
    """In-place ``input += other * alpha`` using the ``_add_in_place`` Triton kernel.

    Both tensors are treated as flat 1-D buffers of ``input.numel()`` elements,
    so they must be contiguous and hold the same number of elements.

    Returns ``input`` (mutated in place) for call-chaining convenience.
    """
    assert input.is_contiguous(), "input tensor must be contiguous"
    assert other.is_contiguous(), "other tensor must be contiguous"
    # Without this check a smaller `other` would be read out of bounds by the
    # kernel, whose mask only bounds accesses by input.numel().
    assert input.numel() == other.numel(), "input and other must have the same number of elements"
    n_elements = input.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _add_in_place[grid](
        input,
        other,
        n_elements,
        alpha,
        BLOCK_SIZE=1024,
    )
    return input
38 changes: 12 additions & 26 deletions lightllm/common/fused_moe/deepep_scatter_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _fwd_kernel_ep_scatter_2(
recv_topk,
recv_topk_stride0,
recv_topk_stride1,
num_recv_tokens_per_expert,
output_tensor,
output_tensor_stride0,
output_tensor_stride1,
Expand All @@ -58,39 +57,30 @@ def _fwd_kernel_ep_scatter_2(
output_index,
output_index_stride0,
output_index_stride1,
m_indices,
topk_num: tl.constexpr,
num_experts: tl.constexpr,
HIDDEN_SIZE: tl.constexpr,
HIDDEN_SIZE_PAD: tl.constexpr,
SCALE_HIDDEN_SIZE: tl.constexpr,
SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
BLOCK_E: tl.constexpr,
BLOCK_D: tl.constexpr,
):
start_token_id = tl.program_id(0)
grid_num = tl.num_programs(0)

for token_id in range(start_token_id, total_token_num, grid_num):
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if expert_id >= 0:
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index)

offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE
offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE

offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE

for token_id in range(start_token_id, total_token_num, grid_num):
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s)

for topk_index in range(0, topk_num):
for topk_index in tl.range(0, topk_num, 1, num_stages=4):
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if expert_id >= 0:
dest_token_index = tl.load(output_index + token_id * output_index_stride0 + topk_index)
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index)
output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0
output_tensor_scale_ptr = output_tensor_scale + dest_token_index * output_tensor_scale_stride0
tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
Expand Down Expand Up @@ -143,7 +133,6 @@ def ep_scatter(
recv_topk,
recv_topk.stride(0),
recv_topk.stride(1),
num_recv_tokens_per_expert,
output_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
Expand All @@ -153,16 +142,12 @@ def ep_scatter(
output_index,
output_index.stride(0),
output_index.stride(1),
m_indices,
topk_num=recv_topk.shape[1],
num_experts=num_experts,
num_warps=num_warps,
HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
BLOCK_E=BLOCK_E,
BLOCK_D=BLOCK_D,
)
return

Expand Down Expand Up @@ -217,11 +202,12 @@ def ep_gather(
input_index: torch.Tensor,
output_tensor: torch.Tensor,
):
BLOCK_D = 128 # block size of quantization
num_warps = 4
BLOCK_D = 1024 # block size of quantization
num_warps = 2
num_tokens = output_tensor.shape[0]
hidden_size = input_tensor.shape[1]
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 16 * 1024))
assert hidden_size % BLOCK_D == 0
grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
_fwd_kernel_ep_gather[grid](
num_tokens,
input_tensor,
Expand Down
24 changes: 11 additions & 13 deletions lightllm/common/fused_moe/grouped_fused_moe_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from lightllm.utils.log_utils import init_logger
from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd
from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
from lightllm.common.quantization.deepgemm_quant import get_tma_aligned_size
from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import (
per_token_group_quant_fp8,
tma_align_input_scale,
)
from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
import numpy as np
Expand All @@ -20,8 +22,6 @@
from deep_ep import Buffer, EventOverlap
import deep_gemm

# Set the number of SMs to use
Buffer.set_num_sms(20)
except:
logger.warning("no deepep or deep_gemm")

Expand All @@ -30,12 +30,10 @@ def tma_aligned_quantize(
input_tensor: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
) -> Tuple[torch.Tensor, torch.Tensor]:
m, k = input_tensor.shape
padded_m = get_tma_aligned_size(m, 4) # the dtype of input_scale is torch.float32
input_scale = torch.empty((k // block_size, padded_m), dtype=torch.float32, device=input_tensor.device).t()
input_scale = torch.empty((m, k // block_size), dtype=torch.float32, device=input_tensor.device)
qinput_tensor = torch.empty((m, k), dtype=dtype, device=input_tensor.device)
per_token_group_quant_fp8(input_tensor, block_size, qinput_tensor, input_scale)
input_scale = input_scale[:m, :]

input_scale = tma_align_input_scale(input_scale)
return qinput_tensor, input_scale


Expand Down Expand Up @@ -147,10 +145,10 @@ def fused_experts_impl(
# gather_out shape [recive_num_tokens, hidden]
gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
if all_tokens > 0:
input_tensor = (
input_tensor = [
torch.empty((all_tokens, K), device=hidden_states.device, dtype=qinput_tensor.dtype),
torch.empty((all_tokens, K // 128), device=hidden_states.device, dtype=torch.float32),
)
]
# when m_indices is filled ok.
# m_indices show token use which expert, example, [0, 0, 0, 0, .... 1, 1, 1, 1,...., cur_expert_num - 1, ..]
# the count of 0 is num_recv_tokens_per_expert_list[0], the count of 1 is num_recv_tokens_per_expert_list[1]
Expand All @@ -161,8 +159,8 @@ def fused_experts_impl(
output_index = torch.empty_like(recv_topk_idx)

num_recv_tokens_per_expert = torch.tensor(
num_recv_tokens_per_expert_list, device=hidden_states.device, dtype=torch.int32
)
num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu"
).cuda(non_blocking=True)

expert_start_loc = torch.empty_like(num_recv_tokens_per_expert)

Expand All @@ -180,7 +178,7 @@ def fused_experts_impl(

# groupgemm (contiguous layout)
gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype)

input_tensor[1] = tma_align_input_scale(input_tensor[1])
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)

# silu_and_mul_fwd + quant
Expand Down
Loading