diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py index 5cbac1ccc8..e83c507d63 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py @@ -9,7 +9,10 @@ from lightllm.distributed import dist_group_manager from lightllm.common.fused_moe.topk_select import select_experts from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 +from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( + per_token_group_quant_fp8, + tma_align_input_scale, +) from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.utils.log_utils import init_logger @@ -152,7 +155,7 @@ def low_latency_dispatch( ) return recv_x, masked_m, topk_idx, topk_weights, handle, hook - def dispatch( + def select_experts_and_quant_input( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -168,9 +171,6 @@ def dispatch( num_expert_group=self.n_group, scoring_func=self.scoring_func, ) - topk_idx = topk_idx.to(torch.long) - buffer = dist_group_manager.ep_buffer - num_experts = self.n_routed_experts M, K = hidden_states.shape w1, w1_scale = self.w1 block_size_k = 0 @@ -180,7 +180,17 @@ def dispatch( input_scale = torch.empty((M, K // block_size_k), dtype=torch.float32, device=hidden_states.device) qinput_tensor = torch.empty((M, K), dtype=w1.dtype, device=hidden_states.device) per_token_group_quant_fp8(hidden_states, block_size_k, qinput_tensor, input_scale) + return topk_weights, topk_idx.to(torch.long), (qinput_tensor, input_scale) + def dispatch( + self, + qinput_tensor: Tuple[torch.Tensor], + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + overlap_event: Optional[Any] = None, + ): + buffer = dist_group_manager.ep_buffer + num_experts = self.n_routed_experts # get_dispatch_layout ( num_tokens_per_rank, @@ -189,16 +199,10 @@ def dispatch( is_token_in_rank, previous_event, ) = buffer.get_dispatch_layout( - topk_idx, num_experts, previous_event=None, async_finish=True, allocate_on_comm_stream=False + topk_idx, num_experts, previous_event=overlap_event, async_finish=True, allocate_on_comm_stream=True ) - - # normal dispatch - # recv_x [recive_num_tokens, hidden] recv_x_scale [recive_num_tokens, hidden // block_size] - # recv_topk_idx [recive_num_tokens, topk_num] - # recv_topk_weights [recive_num_tokens, topk_num] - # num_recv_tokens_per_expert_list list [cur_node_expert_num] padding with expert_alignment=128 recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = buffer.dispatch( - (qinput_tensor, input_scale), + qinput_tensor, topk_idx=topk_idx, topk_weights=topk_weights, num_tokens_per_rank=num_tokens_per_rank, @@ -207,14 +211,14 @@ def dispatch( num_tokens_per_expert=num_tokens_per_expert, previous_event=previous_event, async_finish=True, - allocate_on_comm_stream=False, + allocate_on_comm_stream=True, expert_alignment=128, ) def hook(): event.current_stream_wait() - return qinput_tensor, recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook + return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, hook def masked_group_gemm( self, recv_x: Tuple[torch.Tensor], masked_m: torch.Tensor, dtype: 
torch.dtype, expected_m: int @@ -228,35 +232,35 @@ def prefilled_group_gemm( num_recv_tokens_per_expert_list, recv_x: Tuple[torch.Tensor], recv_topk_idx: torch.Tensor, - hidden_states: torch.Tensor, - qinput_tensor: torch.Tensor, recv_topk_weights: torch.Tensor, + hidden_dtype=torch.bfloat16, ): + device = recv_x[0].device w1, w1_scale = self.w1 w2, w2_scale = self.w2 - _, K = hidden_states.shape + _, K = recv_x[0].shape _, N, _ = w1.shape # scatter all_tokens = sum(num_recv_tokens_per_expert_list) # total number of padded rows across all experts. # gather_out shape [num_recv_tokens, hidden] - gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype) + gather_out = torch.empty_like(recv_x[0], device=device, dtype=hidden_dtype) if all_tokens > 0: - input_tensor = ( - torch.empty((all_tokens, K), device=hidden_states.device, dtype=qinput_tensor.dtype), - torch.empty((all_tokens, K // 128), device=hidden_states.device, dtype=torch.float32), - ) + input_tensor = [ + torch.empty((all_tokens, K), device=device, dtype=recv_x[0].dtype), + torch.empty((all_tokens, K // 128), device=device, dtype=torch.float32), + ] # m_indices is filled by ep_scatter below. # m_indices records which expert each packed row belongs to, e.g. [0, 0, 0, 0, ...., 1, 1, 1, 1, ...., cur_expert_num - 1, ..] # the number of 0s is num_recv_tokens_per_expert_list[0], the number of 1s is num_recv_tokens_per_expert_list[1] # ... - m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32) # output_index shape [num_recv_tokens, topk_num] # output_index records each token's destination row index in input_tensor output_index = torch.empty_like(recv_topk_idx) num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, device=hidden_states.device, dtype=torch.int32 - ) + num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + ).cuda(non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) @@ -271,26 +275,25 @@ def prefilled_group_gemm( m_indices, output_index, ) - + input_tensor[1] = tma_align_input_scale(input_tensor[1]) # groupgemm (contiguous layout) - gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype) + gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + quant # TODO fused kernel - silu_out = torch.empty((all_tokens, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) + silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype) silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out) qsilu_out, qsilu_out_scale = tma_aligned_quantize(silu_out) # groupgemm (contiguous layout) - gemm_out_b = torch.empty((all_tokens, K), device=hidden_states.device, dtype=hidden_states.dtype) + gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices ) - # gather and local reduce ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
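The refactor above splits the old monolithic `dispatch` into `select_experts_and_quant_input`, an event-overlapped `dispatch`, `prefilled_group_gemm`, and `combine`, so communication can run behind independent compute. A minimal sketch of the intended call order for one micro batch, assuming `experts` is a loaded `FusedMoeWeightEP` instance and taking the return signatures from the hunks in this patch (all other names are illustrative):

```python
# Sketch only: phase ordering of the refactored EP MoE prefill path.
from deep_ep import Buffer

topk_weights, topk_idx, qinput = experts.select_experts_and_quant_input(hidden_states, router_logits)
overlap_event = Buffer.capture()  # event passed to dispatch as previous_event

recv_x, recv_topk_idx, recv_topk_weights, recv_counts, handle, hook = experts.dispatch(
    qinput, topk_idx, topk_weights, overlap_event=overlap_event
)
# ... independent compute (the other micro batch's attention, shared experts) runs here ...
hook()  # wait for the dispatch event

moe_out = experts.prefilled_group_gemm(recv_counts, recv_x, recv_topk_idx, recv_topk_weights)

combine_event = Buffer.capture()
ffn_out, combine_hook = experts.combine(moe_out, handle, combine_event)
# ... more independent compute ...
combine_hook()  # wait for the combine event
```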
@@ -312,6 +315,7 @@ def combine( self, gemm_out_b: torch.Tensor, handle: Any, + overlap_event: Optional[Any] = None, ): # normal combine combined_x, _, event = dist_group_manager.ep_buffer.combine( gemm_out_b, handle, topk_weights=None, async_finish=True, - previous_event=None, - allocate_on_comm_stream=False, + previous_event=overlap_event, + allocate_on_comm_stream=True, ) def hook(): diff --git a/lightllm/common/basemodel/triton_kernel/add_in_place.py b/lightllm/common/basemodel/triton_kernel/add_in_place.py new file mode 100644 index 0000000000..770e69fc51 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/add_in_place.py @@ -0,0 +1,39 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _add_in_place( + input_ptr, + other_ptr, + n_elements, + alpha, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(input_ptr + offsets, mask=mask) + y = tl.load(other_ptr + offsets, mask=mask) + x = x + y * alpha + tl.store(input_ptr + offsets, x, mask=mask) + + +@torch.no_grad() +def add_in_place(input: torch.Tensor, other: torch.Tensor, *, alpha=1): + """In-place element-wise update: input += other * alpha.""" + assert input.is_contiguous(), "input tensor must be contiguous" + assert other.is_contiguous(), "other tensor must be contiguous" + n_elements = input.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + _add_in_place[grid]( + input, + other, + n_elements, + alpha, + BLOCK_SIZE=1024, + ) + return input diff --git a/lightllm/common/fused_moe/deepep_scatter_gather.py b/lightllm/common/fused_moe/deepep_scatter_gather.py index a87c6275a1..0e2ff1280d 100644 --- a/lightllm/common/fused_moe/deepep_scatter_gather.py +++ b/lightllm/common/fused_moe/deepep_scatter_gather.py @@ -48,7 +48,6 @@ def _fwd_kernel_ep_scatter_2( recv_topk, recv_topk_stride0, recv_topk_stride1, - num_recv_tokens_per_expert, output_tensor, output_tensor_stride0, output_tensor_stride1, @@ -58,39 +57,30 @@ output_index, output_index_stride0, output_index_stride1, - m_indices, topk_num: tl.constexpr, - num_experts: tl.constexpr, HIDDEN_SIZE: tl.constexpr, HIDDEN_SIZE_PAD: tl.constexpr, SCALE_HIDDEN_SIZE: tl.constexpr, SCALE_HIDDEN_SIZE_PAD: tl.constexpr, - BLOCK_E: tl.constexpr, - BLOCK_D: tl.constexpr, ): start_token_id = tl.program_id(0) grid_num = tl.num_programs(0) - for token_id in range(start_token_id, total_token_num, grid_num): - for topk_index in tl.range(0, topk_num, 1, num_stages=4): - expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) - if expert_id >= 0: - dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) - tl.store(output_index + token_id * output_index_stride0 + topk_index, dest_token_index) - - offset_in = tl.arange(0, HIDDEN_SIZE_PAD) - mask = offset_in < HIDDEN_SIZE + offset_in = tl.arange(0, HIDDEN_SIZE_PAD) + mask = offset_in < HIDDEN_SIZE - offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) - mask_s = offset_in_s < SCALE_HIDDEN_SIZE + offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) + mask_s = offset_in_s < SCALE_HIDDEN_SIZE + for token_id in range(start_token_id, total_token_num, grid_num): to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s) - for topk_index in range(0, topk_num): + for topk_index in tl.range(0, topk_num, 1, num_stages=4): expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) if expert_id >= 0: - dest_token_index = tl.load(output_index + token_id * output_index_stride0 + topk_index) + dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) + tl.store(output_index + token_id * output_index_stride0 +
topk_index, dest_token_index) output_tensor_ptr = output_tensor + dest_token_index * output_tensor_stride0 output_tensor_scale_ptr = output_tensor_scale + dest_token_index * output_tensor_scale_stride0 tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) @@ -143,7 +133,6 @@ def ep_scatter( recv_topk, recv_topk.stride(0), recv_topk.stride(1), - num_recv_tokens_per_expert, output_tensor, output_tensor.stride(0), output_tensor.stride(1), @@ -153,16 +142,12 @@ output_index, output_index.stride(0), output_index.stride(1), - m_indices, topk_num=recv_topk.shape[1], - num_experts=num_experts, num_warps=num_warps, HIDDEN_SIZE=hidden_size, HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D, SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D), - BLOCK_E=BLOCK_E, - BLOCK_D=BLOCK_D, ) return @@ -217,11 +202,12 @@ def ep_gather( input_index: torch.Tensor, output_tensor: torch.Tensor, ): - BLOCK_D = 128 # block size of quantization - num_warps = 4 + BLOCK_D = 1024 # per-program tile size along the hidden dim (no longer tied to the 128 quant block) + num_warps = 2 num_tokens = output_tensor.shape[0] hidden_size = input_tensor.shape[1] - grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 16 * 1024)) + assert hidden_size % BLOCK_D == 0 + grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024)) _fwd_kernel_ep_gather[grid]( num_tokens, input_tensor, diff --git a/lightllm/common/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/fused_moe/grouped_fused_moe_ep.py index 7d55b447e5..3b5cc6b91d 100644 --- a/lightllm/common/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/fused_moe/grouped_fused_moe_ep.py @@ -8,8 +8,10 @@ from lightllm.utils.log_utils import init_logger from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.common.fused_moe.moe_silu_and_mul_mix_quant_ep import silu_and_mul_masked_post_quant_fwd -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 -from lightllm.common.quantization.deepgemm_quant import get_tma_aligned_size +from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( + per_token_group_quant_fp8, + tma_align_input_scale, +) from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank import numpy as np @@ -20,8 +22,6 @@ from deep_ep import Buffer, EventOverlap import deep_gemm - # Set the number of SMs to use - Buffer.set_num_sms(20) except: logger.warning("no deepep or deep_gemm") @@ -30,12 +30,10 @@ def tma_aligned_quantize( input_tensor: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn ) -> Tuple[torch.Tensor, torch.Tensor]: m, k = input_tensor.shape - padded_m = get_tma_aligned_size(m, 4) # the dtype of input_scale is torch.float32 - input_scale = torch.empty((k // block_size, padded_m), dtype=torch.float32, device=input_tensor.device).t() + input_scale = torch.empty((m, k // block_size), dtype=torch.float32, device=input_tensor.device) qinput_tensor = torch.empty((m, k), dtype=dtype, device=input_tensor.device) per_token_group_quant_fp8(input_tensor, block_size, qinput_tensor, input_scale) - input_scale = input_scale[:m, :] - + input_scale = tma_align_input_scale(input_scale) return qinput_tensor, input_scale @@ -147,10 +145,10 @@ def fused_experts_impl( # gather_out shape [num_recv_tokens, hidden] gather_out = torch.empty_like(recv_x[0], device=hidden_states.device, dtype=hidden_states.dtype)
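# Illustration (hypothetical numbers): with topk_num = 2 and
# num_recv_tokens_per_expert_list = [3, 1, 2], ep_scatter packs all_tokens = 6 rows so that
#   m_indices        = [0, 0, 0, 1, 2, 2]  -> row i of input_tensor belongs to expert m_indices[i]
#   expert_start_loc = [0, 3, 4]           -> exclusive cumsum of the per-expert counts
# and output_index[t, j] records the packed row chosen for token t's j-th expert, which is
# exactly the index ep_gather later reads to weight gemm_out_b rows by recv_topk_weights.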
if all_tokens > 0: - input_tensor = ( + input_tensor = [ torch.empty((all_tokens, K), device=hidden_states.device, dtype=qinput_tensor.dtype), torch.empty((all_tokens, K // 128), device=hidden_states.device, dtype=torch.float32), - ) + ] # m_indices is filled by ep_scatter below. # m_indices records which expert each packed row belongs to, e.g. [0, 0, 0, 0, ...., 1, 1, 1, 1, ...., cur_expert_num - 1, ..] # the number of 0s is num_recv_tokens_per_expert_list[0], the number of 1s is num_recv_tokens_per_expert_list[1] # ... @@ -161,8 +159,8 @@ output_index = torch.empty_like(recv_topk_idx) num_recv_tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, device=hidden_states.device, dtype=torch.int32 - ) + num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" + ).cuda(non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) @@ -180,7 +178,7 @@ # groupgemm (contiguous layout) gemm_out_a = torch.empty((all_tokens, N), device=hidden_states.device, dtype=hidden_states.dtype) - + input_tensor[1] = tma_align_input_scale(input_tensor[1]) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices) # silu_and_mul_fwd + quant diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 74b9fb968e..622a9711c3 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -3,38 +3,18 @@ from .quantize_method import QuantizationMethod from .registry import QUANTMETHODS import torch.nn.functional as F -from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8 -from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_gemm_kernel import w8a8_block_fp8_matmul +from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import ( + per_token_group_quant_fp8, + tma_align_input_scale, +) try: HAS_DEEPGEMM = True import deep_gemm - from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor except: HAS_DEEPGEMM = False -# copy from -# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58 -def get_tma_aligned_size(x: int, element_size: int) -> int: - """ - Global memory address of TMA must be 16-byte aligned. - Since we use column-major layout for the LHS scaling tensor, - the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. - - Arguments: - x: original M-axis shape of the LHS scaling tensor. - element_size: element size of the LHS scaling tensor. - - Returns: - M-axis shape of the LHS scaling tensor after padding.
- """ - tma_alignment_bytes = 16 - assert tma_alignment_bytes % element_size == 0 - alignment = tma_alignment_bytes // element_size - return ceil_div(x, alignment) * alignment - - class DeepGEMMBaseQuantizationMethod(QuantizationMethod): def __init__(self): super().__init__() @@ -69,15 +49,13 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ m, k = input_tensor.shape n = weights[0].shape[1] if input_scale is None: - padded_m = get_tma_aligned_size(m, 4) # the dtype of input_scale is torch.float32 - input_scale = torch.empty( - (k // self.block_size, padded_m), dtype=torch.float32, device=input_tensor.device - ).t() + input_scale = torch.empty((m, k // self.block_size), dtype=torch.float32, device=input_tensor.device) qinput_tensor = self.cache_manager.alloc_tensor( (m, k), qweight.dtype, device=qweight.device, is_graph_out=False ) per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale) - input_scale = input_scale[:m, :] + input_scale = tma_align_input_scale(input_scale) + if out is None: if use_custom_tensor_mananger: out = self.cache_manager.alloc_tensor( @@ -85,16 +63,5 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ ) else: out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device) - if n % 128 != 0: - w8a8_block_fp8_matmul( - qinput_tensor, - qweight, - input_scale, - weight_scale, - out, - (self.block_size, self.block_size), - dtype=input_tensor.dtype, - ) - else: - deep_gemm.gemm_fp8_fp8_bf16_nt([qinput_tensor, input_scale], [qweight.t(), weight_scale.t()], out) + deep_gemm.gemm_fp8_fp8_bf16_nt([qinput_tensor, input_scale], [qweight.t(), weight_scale.t()], out) return out diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py index e1a377ebbe..952891ff9f 100644 --- a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py +++ b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py @@ -7,6 +7,17 @@ from functools import lru_cache from typing import Any, Dict, List, Optional, Tuple +try: + HAS_SGLANG_KERNEL = True + from sgl_kernel import sgl_per_token_group_quant_fp8 +except: + HAS_SGLANG_KERNEL = False + +try: + from deep_gemm import ceil_div +except: + pass + # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py @triton.jit @@ -47,7 +58,7 @@ def _per_token_group_quant_fp8( tl.store(y_s_ptr, y_s) -def per_token_group_quant_fp8( +def lightllm_per_token_group_quant_fp8( x: torch.Tensor, group_size: int, x_q: torch.Tensor, @@ -96,10 +107,93 @@ def per_token_group_quant_fp8( num_warps=num_warps, num_stages=num_stages, ) - return +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + x_q: torch.Tensor, + x_s: torch.Tensor, + eps: float = 1e-10, + dtype: torch.dtype = torch.float8_e4m3fn, +): + if HAS_SGLANG_KERNEL: + finfo = torch.finfo(dtype) + fp8_max, fp8_min = finfo.max, finfo.min + sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max) + else: + lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn) + + +# copy from +# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58 +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. 
+# copy from +# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58 +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return ceil_div(x, alignment) * alignment + + +@triton.jit +def _tma_align_input_scale_kernel( + input_scale_ptr, + output_ptr, + m, + k_div_block_size, + input_scale_stride_m, + input_scale_stride_k, + output_stride_m, + output_stride_k, + BLOCK_SIZE_K: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + grid_m = tl.num_programs(0) + k_offsets = tl.arange(0, BLOCK_SIZE_K) + + for m_base in range(pid_m, m, grid_m): + input_offset = input_scale_ptr + m_base * input_scale_stride_m + k_offsets * input_scale_stride_k + input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size) + + output_offset = output_ptr + k_offsets * output_stride_k + m_base * output_stride_m + tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size) + + +def tma_align_input_scale(input_scale: torch.Tensor): + assert input_scale.dim() == 2 + m, k_div_block_size = input_scale.shape + padded_m = get_tma_aligned_size(m, input_scale.element_size()) + output = torch.empty((k_div_block_size, padded_m), dtype=input_scale.dtype, device=input_scale.device) + + grid_m = min(m, 8192) + BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size) + + _tma_align_input_scale_kernel[(grid_m,)]( + input_scale_ptr=input_scale, + output_ptr=output, + m=m, + k_div_block_size=k_div_block_size, + input_scale_stride_m=input_scale.stride(0), + input_scale_stride_k=input_scale.stride(1), + output_stride_m=output.stride(1), # Note: these are swapped + output_stride_k=output.stride(0), # for column-major + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return output.t()[:m] + + def torch_quant(x, group_size, dtype=torch.float8_e4m3fn): M, N = x.shape x_q = torch.randn((M, N)).cuda().to(torch.float8_e4m3fn) @@ -115,7 +209,26 @@ def torch_quant(x, group_size, dtype=torch.float8_e4m3fn): return x_q.reshape(M, N), x_s -if __name__ == "__main__": +def test_tma_align(): + m = 576 + k = 8192 + x = torch.randn((m, k // 128), dtype=torch.float32).cuda() + for _ in range(10): + x_padded = tma_align_input_scale(x) + print(x_padded.shape) + import time + + torch.cuda.synchronize() + start = time.time() + for _ in range(100): + x_padded = tma_align_input_scale(x) + torch.cuda.synchronize() + print("Time:", time.time() - start) + x_padded = tma_align_input_scale(x) + print(torch.abs(x_padded - x).max()) + + +def test_per_token_group_quant_fp8(): group_size = 128 x = torch.randn((1024, 8192), dtype=torch.bfloat16).cuda() @@ -127,3 +240,7 @@ def torch_quant(x, group_size, dtype=torch.float8_e4m3fn): th_x_q, th_x_s = torch_quant(x, group_size) print("th_x_s - x_s", torch.abs(th_x_s - x_s.reshape(-1)).max()) print("th_x_q - x_q", torch.abs(th_x_q.to(torch.float32) - x_q.to(torch.float32)).max()) + + +if __name__ == "__main__": + test_tma_align() diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index d091375fc0..d170ce1155 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -34,6 +34,7 @@ get_global_rank, get_current_rank_in_dp, ) +from lightllm.utils.device_utils import get_device_sm_count from
contextlib import nullcontext, contextmanager logger = init_logger(__name__) @@ -54,7 +55,12 @@ try: import deep_ep + from deep_gemm.jit_kernels.utils import set_num_sms + deepep_sms = int(os.getenv("DEEPEP_SMS", deep_ep.Buffer.num_sms)) + device_sms = get_device_sm_count() + deep_ep.Buffer.set_num_sms(deepep_sms) + set_num_sms(device_sms - deepep_sms) HAS_DEEPEP = True except: HAS_DEEPEP = False diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index c7c884e022..b247c8e98c 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -19,5 +19,6 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len + self.max_value_in_b_seq_len = self.b_seq_len.max().item() return diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 70bdd35345..ff0a3ed541 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -29,7 +29,6 @@ from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.dist_utils import get_global_world_size -from lightllm.utils.custom_kernel_utis import torch_cat_3 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer): @@ -254,6 +253,7 @@ def _decompress_kv( compressed_kv, k_rope, infer_state.b_req_idx, + infer_state.max_value_in_b_seq_len, infer_state.b_seq_len, infer_state.req_manager.req_to_token_indexs, infer_state.b_kv_start_loc, @@ -290,10 +290,7 @@ def _context_attention_flashinfer_kernel_with_CC( o_tensor = ( self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out ) - repeat_k_rope = self.alloc_tensor((k_rope.shape[0], self.tp_q_head_num_, k_rope.shape[2]), dtype=k_rope.dtype) - repeat_rope(repeat_k_rope, k_rope) - - k = torch_cat_3([k_nope, repeat_k_rope], dim=-1) + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) return o_tensor @@ -309,9 +306,7 @@ def _context_attention_flashinfer_kernel_with_CC_fp8( o_tensor = ( self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out ) - repeat_k_rope = self.alloc_tensor((k_rope.shape[0], self.tp_q_head_num_, k_rope.shape[2]), dtype=k_rope.dtype) - repeat_rope(repeat_k_rope, k_rope) - k = torch_cat_3([k_nope, repeat_k_rope], dim=-1) + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1) infer_state.prefill_wrapper.run(q, k, v, out=o_tensor) return o_tensor @@ -620,12 +615,15 @@ def overlap_tpsp_token_forward( _0_o = None _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) _0_router_logits = layer_weight.moe_gate.mm(_0_input1) - # 1 hook if getattr(infer_state1, "hook", None) is not None: infer_state1.hook() infer_state1.hook = None + # 0 shared expert + if self.n_shared_experts is not None: + _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight) + # 0 dispatch ( _0_recv_x, @@ -637,10 +635,6 @@ def overlap_tpsp_token_forward( ) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits) infer_state.hook = 
_0_hook - # 0 shared expert - if self.n_shared_experts is not None: - _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight) - # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight) @@ -656,12 +650,15 @@ # to do gate and dispatch _1_router_logits = layer_weight.moe_gate.mm(_1_input1) - # 0 hook if getattr(infer_state, "hook", None) is not None: infer_state.hook() infer_state.hook = None + # 1 shared expert + if self.n_shared_experts is not None: + _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight) + # 1 dispatch ( _1_recv_x, @@ -673,10 +670,6 @@ ) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits) infer_state1.hook = _1_hook - # 1 shared expert - if self.n_shared_experts is not None: - _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight) - # moe calc expected_m = triton.cdiv( input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts ) @@ -756,21 +749,12 @@ infer_state1.hook() infer_state1.hook = None - # 0 dispatch execute - ( - _0_qinput_tensor, - _0_recv_x, - _0_topk_idx, - _0_topk_weight, - _0_num_recv_tokens_per_expert_list, - _0_handle, - _0_hook, - ) = layer_weight.experts.dispatch(_0_input1, _0_router_logits) - infer_state.hook = _0_hook + _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( + _0_input1, _0_router_logits + ) + from deep_ep import Buffer - # 0 shared expert - if self.n_shared_experts is not None: - _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight) + _0_overlap_event = Buffer.capture() # 1 attention _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight) @@ -788,22 +772,31 @@ _1_router_logits = layer_weight.moe_gate.mm(_1_input1) + # 0 dispatch execute + ( + _0_recv_x, + _0_recv_topk_idx, + _0_recv_topk_weight, + _0_num_recv_tokens_per_expert_list, + _0_handle, + _0_hook, + ) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event) + infer_state.hook = _0_hook + # wait 0 dispatch if getattr(infer_state, "hook", None) is not None: infer_state.hook() infer_state.hook = None - # 1 dispatch execute - ( - _1_qinput_tensor, - _1_recv_x, - _1_topk_idx, - _1_topk_weight, - _1_num_recv_tokens_per_expert_list, - _1_handle, - _1_hook, - ) = layer_weight.experts.dispatch(_1_input1, _1_router_logits) - infer_state1.hook = _1_hook + _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input( + _1_input1, _1_router_logits + ) + + _1_overlap_event = Buffer.capture() + + # 0 shared expert + if self.n_shared_experts is not None: + _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight) # 1 shared expert if self.n_shared_experts is not None: @@ -811,21 +804,33 @@ # 0 moe calc _0_moe_out = layer_weight.experts.prefilled_group_gemm( - _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_topk_idx, _0_input1, _0_qinput_tensor, _0_topk_weight + _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight ) + # 1 dispatch execute + ( + _1_recv_x, + _1_recv_topk_idx, + _1_recv_topk_weight,
_1_num_recv_tokens_per_expert_list, + _1_handle, + _1_hook, + ) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event) + infer_state1.hook = _1_hook + # wait 1 dispatch if getattr(infer_state1, "hook", None) is not None: infer_state1.hook() infer_state1.hook = None + _0_combine_event = Buffer.capture() # 0 combine execute - _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle) + _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event) infer_state.hook = _0_hook # 1 moe calc _1_moe_out = layer_weight.experts.prefilled_group_gemm( - _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_topk_idx, _1_input1, _1_qinput_tensor, _1_topk_weight + _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight ) # wait 0 combine @@ -833,13 +838,15 @@ def overlap_tpsp_context_forward( infer_state.hook() infer_state.hook = None + _1_combine_event = Buffer.capture() + _0_ffn_out *= self.routed_scaling_factor if self.n_shared_experts is not None: _0_ffn_out.add_(_0_shared_output) input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_)) # 1 combine execute - _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle) + _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event) def _1_hook_post(): _1_hook() diff --git a/lightllm/models/deepseek2/triton_kernel/repeat_rope.py b/lightllm/models/deepseek2/triton_kernel/repeat_rope.py index 386de3d4ec..dc70b7f972 100644 --- a/lightllm/models/deepseek2/triton_kernel/repeat_rope.py +++ b/lightllm/models/deepseek2/triton_kernel/repeat_rope.py @@ -16,19 +16,30 @@ def _repeat_rope_tensor( copy_head_num, head_dim, total_len, + NEED_MASK: tl.constexpr, + BLOCK_HEAD: tl.constexpr, BLOCK_N: tl.constexpr, ): start_index = tl.program_id(0) grid_num = tl.num_programs(0) offs_d = tl.arange(0, BLOCK_N) + offs_head = tl.arange(0, BLOCK_HEAD) for cur_index in range(start_index, total_len, step=grid_num): - in_tensor = tl.load( - in_ptr + in_stride_0 * cur_index + in_stride_1 * 0 + offs_d, mask=offs_d < head_dim, other=0 - ) - for cur_head in tl.range(copy_head_num, num_stages=3): + if NEED_MASK: + in_tensor = tl.load( + in_ptr + in_stride_0 * cur_index + in_stride_1 * 0 + offs_d, mask=offs_d < head_dim, other=0 + ) + tl.store( + out_ptr + out_stride_0 * cur_index + out_stride_1 * offs_head[:, None] + offs_d[None, :], + in_tensor[None, :], + mask=(offs_head[:, None] < copy_head_num) & (offs_d[None, :] < head_dim), + ) + else: + in_tensor = tl.load(in_ptr + in_stride_0 * cur_index + in_stride_1 * 0 + offs_d) tl.store( - out_ptr + out_stride_0 * cur_index + out_stride_1 * cur_head + offs_d, in_tensor, mask=offs_d < head_dim + out_ptr + out_stride_0 * cur_index + out_stride_1 * offs_head[:, None] + offs_d[None, :], + in_tensor[None, :], ) return @@ -41,18 +52,23 @@ def repeat_rope(dest_tensor: torch.Tensor, source_tensor: torch.Tensor): assert head_num == 1 BLOCK_N = triton.next_power_of_2(head_dim) + BLOCK_HEAD = triton.next_power_of_2(repeat_head_num) if BLOCK_N <= 256: num_warps = 1 elif BLOCK_N <= 1024: + num_warps = 2 + elif BLOCK_N <= 2048: num_warps = 4 else: num_warps = 8 - if seq_len <= 16 * 1024: + if seq_len <= 8 * 1024: grid = (seq_len,) else: - grid = (16 * 1024,) + grid = (8 * 1024,) + + NEED_MASK = (BLOCK_N != head_dim) or (BLOCK_HEAD != repeat_head_num) _repeat_rope_tensor[grid]( source_tensor, @@ -62,8 +78,9 @@ def repeat_rope(dest_tensor: torch.Tensor, source_tensor: torch.Tensor): 
repeat_head_num, head_dim, seq_len, + NEED_MASK=NEED_MASK, + BLOCK_HEAD=BLOCK_HEAD, BLOCK_N=BLOCK_N, num_warps=num_warps, - num_stages=3, ) return diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py index 31339bf370..cf42f15551 100644 --- a/lightllm/models/deepseek2/triton_kernel/sample_kv.py +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -70,6 +70,7 @@ def sample_kv( kv_nope, kv_rope, b_req_idx, + max_value_in_b_seq_len, b_seq_len, req_to_token_indexs, b_kv_start_loc, @@ -87,7 +88,7 @@ batch = b_seq_len.shape[0] - max_input_len = b_seq_len.max() + max_input_len = max_value_in_b_seq_len grid = ( batch, triton.cdiv(max_input_len, BLOCK), diff --git a/lightllm/server/api_cli.py index f6f4e67e77..5d3deaee68 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -163,6 +163,12 @@ def make_argument_parser() -> argparse.ArgumentParser: default=6, help="schedule new requests after every router_max_wait_tokens decode steps.", ) + parser.add_argument( + "--disable_aggressive_schedule", + action="store_true", + help="""aggressive scheduling can lead to frequent prefill interruptions during decode. + disabling it allows the router_max_wait_tokens parameter to work more effectively.""", + ) parser.add_argument("--use_dynamic_prompt_cache", action="store_true", help="use_dynamic_prompt_cache test") diff --git a/lightllm/server/api_start.py index 8c02dcda20..a879450461 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -89,7 +89,7 @@ def normal_or_p_d_start(args): args.token_healing_mode, args.use_reward_model, args.return_all_prompt_logprobs, - args.first_token_constraint_mode, + args.output_constraint_mode != "none", ].count(True) <= 1 # Some modes cannot run together with dynamic_prompt_cache yet, TODO. if args.use_dynamic_prompt_cache: diff --git a/lightllm/server/api_tgi.py index 81d0cf4d6f..675fd35391 100755 --- a/lightllm/server/api_tgi.py +++ b/lightllm/server/api_tgi.py @@ -126,6 +126,7 @@ async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerMana "generated_tokens": count_output_tokens_dict[sub_id], "finish_reason": finish_status_dict[sub_id].get_finish_reason(), "tokens": tokens_dict[sub_id], + "prompt_tokens": tokens_dict[sub_id][0]["prompt_tokens"], } if prompt_token_ids is not None: ret["prompt_token_ids"] = prompt_token_ids diff --git a/lightllm/server/core/objs/start_args_type.py index d67b3e2bea..e38aad0e59 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -36,6 +36,7 @@ class StartArgs: router_token_ratio: float = field(default=0.0) router_max_new_token_len: int = field(default=1024) router_max_wait_tokens: int = field(default=6) + disable_aggressive_schedule: bool = field(default=False) use_dynamic_prompt_cache: bool = field(default=False) chunked_prefill_size: int = field(default=8192) disable_chunked_prefill: bool = field(default=False) diff --git a/lightllm/server/router/manager.py index ef68eceb30..51411c3011 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -26,7 +26,7 @@ from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient from .stats import Stats from .pause_strategy import Fcfs, select_paused_reqs -from lightllm.utils.log_utils import init_logger +from
lightllm.utils.log_utils import init_logger, log_time_ready from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock @@ -266,9 +266,11 @@ async def loop_for_fwd( self.metric_client.gauge_set("lightllm_batch_pause_size", 0.0) self.metric_client.gauge_set("lightllm_queue_size", 0.0) self.metric_client.gauge_set("lightllm_batch_current_max_tokens", 0.0) - for dp_i in range(self.dp_size_in_node): - frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i) - logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n") + # print at most once every 60s + if log_time_ready("frozen_info", 60): + for dp_i in range(self.dp_size_in_node): + frozen_token_num = self.shared_token_load.get_frozened_token_count(dp_i) + logger.debug(f"dp_i {dp_i} frozen token num: {frozen_token_num} \n") if self.running_batch is None: await asyncio.sleep(0.01) # 10ms @@ -318,7 +320,11 @@ async def _step(self): self.running_batch = new_batch await self._prefill_batch(self.running_batch) self._filter_runing_batch() - self.has_wait_tokens = self.max_wait_tokens + + # aggressive scheduling control + if not self.args.disable_aggressive_schedule: + self.has_wait_tokens = self.max_wait_tokens + elif self.is_multinode_and_multidp: # In multi-node multi-dp mode, decode must still be called continuously even when the local running_batch is None, # because dps on other nodes may still have running requests, so this node also needs to call decode; the inference backend will @@ -333,7 +339,11 @@ new_mini_batch = await self.get_schedule_result(self.running_batch) self.has_wait_tokens = 0 if new_mini_batch is not None: - self.has_wait_tokens = self.max_wait_tokens + + # aggressive scheduling control + if not self.args.disable_aggressive_schedule: + self.has_wait_tokens = self.max_wait_tokens + self.stats_tool.count_prompt_tokens(new_mini_batch) await self._prefill_batch(new_mini_batch) if not new_mini_batch.is_clear(): diff --git a/lightllm/utils/custom_kernel_utis.py index bf5a65719d..f72dfabfcb 100644 --- a/lightllm/utils/custom_kernel_utis.py +++ b/lightllm/utils/custom_kernel_utis.py @@ -19,7 +19,7 @@ def custom_cat(tensors): start_loc = 0 for t, size in zip(tensors, sizes): - out_tensor[start_loc : (start_loc + size)].copy_(t, non_blocking=True) + out_tensor[start_loc : (start_loc + size)].copy_(t) start_loc += size return out_tensor diff --git a/lightllm/utils/envs_utils.py index c030a17437..d4f975c0f2 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -1,9 +1,11 @@ import os import json +import torch from easydict import EasyDict from functools import lru_cache from lightllm.utils.log_utils import init_logger + logger = init_logger(__name__) @@ -18,7 +20,15 @@ def get_unique_server_name(): return service_uni_name +def set_cuda_arch(args): + if args.enable_flashinfer_prefill or args.enable_flashinfer_decode: + capability = torch.cuda.get_device_capability() + arch = f"{capability[0]}.{capability[1]}" + os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" + + def set_env_start_args(args): + set_cuda_arch(args) if not isinstance(args, dict): args = vars(args) os.environ["LIGHTLLM_START_ARGS"] = json.dumps(args) diff --git a/lightllm/utils/log_utils.py index d3d007b31c..f15309d5cf 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -4,6 +4,7 @@ import logging import sys import os +import time from typing import Optional _FORMAT = "%(levelname)s %(asctime)s
[%(filename)s:%(lineno)d] %(message)s" @@ -13,6 +14,7 @@ _LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0) _LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None) + class NewLineFormatter(logging.Formatter): """Adds logging prefix to newlines to align multi-line messages.""" @@ -44,14 +46,14 @@ def _setup_logger(): _default_handler.flush = sys.stdout.flush # type: ignore _default_handler.setLevel(_LOG_LEVEL) _root_logger.addHandler(_default_handler) - + if _default_file_handler is None and _LOG_DIR is not None: if not os.path.exists(_LOG_DIR): try: os.makedirs(_LOG_DIR) except OSError as e: _root_logger.warn(f"Error creating directory {_LOG_DIR} : {e}") - _default_file_handler = logging.FileHandler(_LOG_DIR + '/default.log') + _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log") _default_file_handler.setLevel(_LOG_LEVEL) _default_file_handler.setFormatter(fmt) _root_logger.addHandler(_default_file_handler) @@ -61,6 +63,7 @@ # being propagated to the parent logger. _root_logger.propagate = False + # The logger is initialized when the module is imported. # This is thread-safe as the module is only imported once, # guaranteed by the Python GIL. @@ -91,3 +94,24 @@ def init_logger(name: str): logger.addHandler(_inference_log_file_handler[pid]) logger.propagate = False return logger + + +_log_time_mark_dict = {} + + +def log_time_ready(mark_name, time_count: int): + """ + Returns True if at least time_count seconds have passed since the last recorded time for this mark; otherwise returns False. + Used to throttle how often certain log messages are emitted. + """ + global _log_time_mark_dict + + if mark_name not in _log_time_mark_dict: + _log_time_mark_dict[mark_name] = time.time() + return False + cur_time_mark = time.time() + if cur_time_mark - _log_time_mark_dict[mark_name] >= time_count: + _log_time_mark_dict[mark_name] = cur_time_mark + return True + else: + return False diff --git a/requirements.txt index fedbdb1a9c..10e3f3046e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -78,7 +78,6 @@ tiktoken==0.7.0 matplotlib==3.8.2 psutil==5.9.4 prometheus_client==0.20.0 -outlines==0.2.1 cchardet==2.1.7 ujson==5.10.0 frozendict==2.4.6 @@ -87,3 +86,5 @@ easydict==1.13 gunicorn==23.0.0 vllm==0.7.2 flashinfer-python==0.2.2 +sglang-kernel==0.0.5.post4 +outlines==0.2.1 diff --git a/setup.py index 9180e7711f..1fcaa7ac0e 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,6 @@ install_requires=[ "pyzmq", "uvloop", - "torch", "transformers", "einops", "packaging",
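`log_time_ready` above is a process-local rate limiter keyed by mark name; note that the first call for a mark only records the timestamp and returns False. A minimal usage sketch (the loop body is illustrative):

```python
import time
from lightllm.utils.log_utils import init_logger, log_time_ready

logger = init_logger(__name__)

for step in range(10_000):
    time.sleep(0.01)  # stand-in for real work
    # returns True at most once per 60s for the "step_stats" mark
    if log_time_ready("step_stats", 60):
        logger.debug(f"step {step}: periodic stats")
```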
diff --git a/test/model/model_infer.py b/test/model/model_infer.py index 77bc01c684..a466206f63 100644 --- a/test/model/model_infer.py +++ b/test/model/model_infer.py @@ -7,7 +7,7 @@ from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.deepseek2.model import Deepseek2TpPartModel -from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch +from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch from torch.profiler import profile, record_function, ProfilerActivity @@ -53,6 +53,66 @@ def test_model_inference(args, model_class): return +def overlap_prefill( + model_part, + batch_size, + max_len_in_batch, + input_ids, + mem_indexes, + b_req_idx, + b_start_loc, + b_seq_len, + total_token_num, + b_ready_cache_len, +): + + _0_batch_size = batch_size // 2 + _0_total_token_num = total_token_num // 2 + _0_max_len_in_batch = max_len_in_batch + _0_input_ids = input_ids[: total_token_num // 2] + _0_mem_indexes = mem_indexes[: total_token_num // 2] + _0_b_req_idx = b_req_idx[: batch_size // 2] + _0_b_seq_len = b_seq_len[: batch_size // 2] + _0_b_start_loc = b_start_loc[: batch_size // 2] + _0_b_ready_cache_len = b_ready_cache_len[: batch_size // 2] + micro_batch1 = PrefillMicroBatch( + _0_batch_size, + _0_total_token_num, + _0_max_len_in_batch, + _0_input_ids, + _0_mem_indexes, + _0_b_req_idx, + _0_b_start_loc, + _0_b_seq_len, + _0_b_ready_cache_len, + ) + + _1_batch_size = batch_size - batch_size // 2 + _1_total_token_num = total_token_num - total_token_num // 2 + _1_max_len_in_batch = max_len_in_batch + _1_input_ids = input_ids[total_token_num // 2 :] + _1_mem_indexes = mem_indexes[total_token_num // 2 :] + _1_b_req_idx = b_req_idx[batch_size // 2 :] + _1_b_seq_len = b_seq_len[batch_size // 2 :] + _1_b_start_loc = b_start_loc[: batch_size // 2] # with uniform input_len, the first half equals the rebased start locs of the second micro batch + _1_b_ready_cache_len = b_ready_cache_len[batch_size // 2 :] + + micro_batch2 = PrefillMicroBatch( + _1_batch_size, + _1_total_token_num, + _1_max_len_in_batch, + _1_input_ids, + _1_mem_indexes, + _1_b_req_idx, + _1_b_start_loc, + _1_b_seq_len, + _1_b_ready_cache_len, + ) + + logits, logits1 = model_part.microbatch_overlap_prefill(micro_batch1, micro_batch2) + return torch.cat((logits, logits1), dim=0) + + def overlap_decode( model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_start_loc, b_seq_len, total_token_num ): @@ -121,7 +181,7 @@ def torch_profile(fn, log_dir=None): with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False, - on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir) + on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir), ) as prof: fn() print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) @@ -138,7 +198,7 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o enable_decode_overlap = args.enable_decode_microbatch_overlap group_size = 1 - if enable_decode_overlap: + if enable_decode_overlap or args.enable_prefill_microbatch_overlap: assert batch_size % 2 == 0, "batch size must be even number" group_size = 2 init_distributed_env(model_kvargs) @@ -171,19 +231,32 @@ total_token_num = input_len * batch_size mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() - - logics = model_part.forward( - batch_size, - total_token_num, - input_len, - test_data, - mem_indexes, - b_req_idx, - b_start_loc, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, - is_prefill=True, - ) + if args.enable_prefill_microbatch_overlap: + logics = overlap_prefill( + model_part, + batch_size, + input_len, + test_data, + mem_indexes, + b_req_idx, + b_start_loc, + b_seq_len, + total_token_num, + b_ready_cache_len, + ) + else: + logics = model_part.forward( + batch_size, + total_token_num, + input_len, + test_data, + mem_indexes, + b_req_idx, + b_start_loc, + b_seq_len,
b_ready_cache_len=b_ready_cache_len, - is_prefill=True, - ), - log_dir=f"./logs_decode_overlap/forward_prefill_{model_kvargs['rank_id']}", - ) - except Exception as e: - print(str(e)) - raise - - logics = model_part.forward( - batch_size, - total_token_num, - input_len, - test_data, - mem_indexes, - b_req_idx, - b_start_loc, - b_seq_len, - b_ready_cache_len=b_ready_cache_len, - is_prefill=True, - ) + if args.enable_prefill_microbatch_overlap: + logics = overlap_prefill( + model_part, + batch_size, + input_len, + test_data, + mem_indexes, + b_req_idx, + b_start_loc, + b_seq_len, + total_token_num, + b_ready_cache_len, + ) + else: + logics = model_part.forward( + batch_size, + total_token_num, + input_len, + test_data, + mem_indexes, + b_req_idx, + b_start_loc, + b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + ) prob_out = torch.softmax(logics, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() @@ -300,6 +365,45 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o proton.finalize() print("prefill time cost:", (time.time() - prefill_start_time) * 1000) + if args.torch_profile: + print("Profile Prefill") + try: + if args.enable_prefill_microbatch_overlap: + torch_profile( + lambda: overlap_prefill( + model_part, + batch_size, + input_len, + test_data, + mem_indexes, + b_req_idx, + b_start_loc, + b_seq_len, + total_token_num, + b_ready_cache_len, + ), + log_dir=f"./logs_sglang_4k/forward_prefill_{model_kvargs['rank_id']}", + ) + else: + torch_profile( + lambda: model_part.forward( + batch_size, + total_token_num, + input_len, + test_data, + mem_indexes, + b_req_idx, + b_start_loc, + b_seq_len, + b_ready_cache_len=b_ready_cache_len, + is_prefill=True, + ), + log_dir=f"./logs_sglang_4k/forward_prefill_{model_kvargs['rank_id']}", + ) + except Exception as e: + print(str(e)) + raise + if rank_id == 0: if args.profile: proton.start(name="forward_decode", context="python") @@ -337,7 +441,7 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o b_seq_len, total_token_num, ), - log_dir=f"./logs_decode_overlap/forward_decode_{model_kvargs['rank_id']}", + log_dir=f"./logs_sglang_4k/forward_decode_{model_kvargs['rank_id']}", ) else: logits = decode( @@ -351,7 +455,7 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o b_seq_len, total_token_num, ) - if i ==0 and args.torch_profile: + if i == 0 and args.torch_profile: torch_profile( lambda: decode( model_part, @@ -364,7 +468,7 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o b_seq_len, total_token_num, ), - log_dir=f"./logs_decode_overlap/forward_decode_{model_kvargs['rank_id']}", + log_dir=f"./logs_sglang_4k/forward_decode_{model_kvargs['rank_id']}", ) prob_out = torch.softmax(logits, dim=-1) diff --git a/test/model/test_model.py b/test/model/test_model.py index 3308968645..ffae00297f 100644 --- a/test/model/test_model.py +++ b/test/model/test_model.py @@ -83,8 +83,8 @@ def test_model_infer(self): import torch parser = make_argument_parser() - parser.add_argument("--batch_size", type=int, default=128, help="batch size") - parser.add_argument("--input_len", type=int, default=64, help="input sequence length") + parser.add_argument("--batch_size", type=int, default=2, help="batch size") + parser.add_argument("--input_len", type=int, default=4096, help="input sequence length") parser.add_argument("--output_len", type=int, default=128, help="output 
sequence length") parser.add_argument( "--profile", diff --git a/unit_tests/common/basemodel/triton_kernel/test_add_in_place.py b/unit_tests/common/basemodel/triton_kernel/test_add_in_place.py new file mode 100644 index 0000000000..7e9a55a775 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_add_in_place.py @@ -0,0 +1,31 @@ +import torch +import time +import pytest +from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy +from lightllm.common.basemodel.triton_kernel.add_in_place import add_in_place +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +@pytest.mark.parametrize( + "dim1, dim2, alpha", + [ + (dim1, dim2, alpha) + for dim1 in range(1, 1024, 100) + for dim2 in range(1, 1024, 100) + for alpha in [0.1, 0.3, 0.5, 0.7, 0.1] + ], +) +def test_add_in_place(dim1, dim2, alpha): + input = torch.rand((dim1, dim2), device="cuda") + other = torch.rand((dim1, dim2), device="cuda") + + output = input + other * alpha + add_in_place(input, other, alpha=alpha) + rlt = torch.allclose(input, output, atol=1e-5, rtol=0) + assert rlt + + +if __name__ == "__main__": + pytest.main() diff --git a/unit_tests/models/deepseek2/test_rope_repeat.py b/unit_tests/models/deepseek2/test_rope_repeat.py index 15a4fb9734..2f578e03eb 100644 --- a/unit_tests/models/deepseek2/test_rope_repeat.py +++ b/unit_tests/models/deepseek2/test_rope_repeat.py @@ -8,6 +8,13 @@ def test_torch_cat(): source = torch.randn((100, 1, 1077), device="cuda") dest = torch.randn((100, 7, 1077), device="cuda") + repeat_rope(dest, source) + torch.equal(dest[:, 0, :], source) + torch.equal(dest[:, -1, :], source) + + source = torch.randn((100, 1, 128), device="cuda") + dest = torch.randn((100, 64, 128), device="cuda") + repeat_rope(dest, source) torch.equal(dest[:, 0, :], source) torch.equal(dest[:, -1, :], source)