46 commits
fbd6c7c
0815-temp
SangChengC Aug 15, 2025
4acd7e7
0815-add-visual-only
SangChengC Aug 15, 2025
68dd163
0820-add-llm-only
SangChengC Aug 20, 2025
9aaf63b
0820
SangChengC Aug 20, 2025
66d0c10
0820-del-metric
SangChengC Aug 20, 2025
1f46fd2
add redis server for vit/llm disaggregation
shihaobai Aug 26, 2025
3561a17
remove unused code of http manager
shihaobai Aug 26, 2025
0ef48cb
[0826]modify visual server
SangChengC Aug 26, 2025
d4de040
merge main
shihaobai Aug 26, 2025
27ef8f3
add vit manager for vit-llm disaggr
shihaobai Aug 27, 2025
70bc956
[0827]temp
SangChengC Aug 27, 2025
ded28b7
update visual server manager
shihaobai Aug 27, 2025
3a89cf0
add visual start
shihaobai Aug 27, 2025
630d3ee
rename
shihaobai Aug 28, 2025
4407040
add vit register loop
shihaobai Aug 28, 2025
a566580
[0828]temp
SangChengC Aug 28, 2025
1ae9cd3
[0828]temp
SangChengC Aug 28, 2025
c99bb46
fix vit manager
shihaobai Aug 28, 2025
81cbc03
merge
shihaobai Aug 28, 2025
00b3b53
fix llm remote vit init
shihaobai Aug 28, 2025
62f80c4
[0828]temp
SangChengC Aug 28, 2025
b673a36
fix vit transfer
shihaobai Aug 28, 2025
2eaa709
Merge branch 'visual_only2' of https://github.com/ModelTC/lightllm in…
shihaobai Aug 28, 2025
67a3c38
fix connection bug
shihaobai Aug 28, 2025
676215e
add wait for embed for llm
shihaobai Aug 28, 2025
33923b9
[0828]fix vit embed
SangChengC Aug 29, 2025
fcac8e5
[0828]temp
SangChengC Aug 29, 2025
06f7817
[0828]temp
SangChengC Aug 29, 2025
daf1318
[0829]add free_afs
SangChengC Aug 29, 2025
1c16903
[support]add get_image_embedding
SangChengC Sep 3, 2025
c1d98eb
0903
SangChengC Sep 3, 2025
cffa0a0
0909
SangChengC Sep 9, 2025
0a296a1
0911
SangChengC Sep 11, 2025
6b95156
0911-add-other-multimodal's vit dispatch
SangChengC Sep 11, 2025
4853561
[fix]0915-fix-rpyc-cost
SangChengC Sep 16, 2025
ffe2f6b
[fix]fix redis
SangChengC Sep 19, 2025
e723c40
[fix]clean redis before start
SangChengC Sep 23, 2025
efd1213
merge main
Oct 13, 2025
d53a924
merge main
Oct 13, 2025
7d5b9a6
fix other vlm
Oct 13, 2025
b477506
fix other vlm
Oct 13, 2025
ac67fcc
fix other vlm
Oct 13, 2025
40f8c6a
fix other vlm
Oct 13, 2025
d9cb8c3
merge main
Oct 23, 2025
0ed4dc9
Merge branch 'main' into visual_only3
Nov 5, 2025
9838d89
fix1124
Nov 24, 2025
13 changes: 7 additions & 6 deletions lightllm/models/internvl/model.py
@@ -64,21 +64,22 @@ def init_imageitem_extral_params(
img.extra_params["image_patch_max_num"] = 6
elif num_images > 6:
img.extra_params["image_patch_max_num"] = 0
img.patch_num = self.get_image_patch(img)
return

def init_audioitem_extral_params(
self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
):
return

def get_image_token_length(self, img: ImageItem):
return (
self.get_image_patch_func(
img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
)
* self.image_length
def get_image_patch(self, img: ImageItem):
return self.get_image_patch_func(
img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
)

def get_image_token_length(self, img: ImageItem):
return self.get_image_patch(img) * self.image_length

def get_audio_token_length(self, audio: AudioItem):
L = audio.audio_length
L = L if L <= 480000 else 480000 # max_length < 30s
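To make the refactor concrete, a hypothetical walk-through of the new methods (the numbers are illustrative assumptions; only the method names and tiling rules come from the diff):

# Hypothetical walk-through of the InternVL refactor above: the tile count
# is now computed once by get_image_patch() (and cached on img.patch_num in
# init_imageitem_extral_params), and get_image_token_length() reuses it.
patch_num = 7          # e.g. 6 tiles + 1 thumbnail when use_thumbnail=True (illustrative)
image_length = 256     # tokens per vision patch; model-dependent, assumed here
token_len = patch_num * image_length   # 1792 tokens reserved for this image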
92 changes: 66 additions & 26 deletions lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py
@@ -1,7 +1,6 @@
import math
import torch
import triton
import triton.language as tl
import torch


@triton.jit
@@ -17,57 +16,94 @@ def rotary_kernel(
stride_cos_d,
stride_sin_l,
stride_sin_d,
D: tl.constexpr,
HALF_D: tl.constexpr,
L,
H,
D,
BLOCK_SEQ: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_D: tl.constexpr,
):
pid_h = tl.program_id(0).to(tl.int64)
pid_l = tl.program_id(1).to(tl.int64)
pid_blk = tl.program_id(2).to(tl.int64)
pid_head_blk = tl.program_id(0)
pid_seq_blk = tl.program_id(1)

offs_h = pid_head_blk * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
offs_l = pid_seq_blk * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
offs_d = tl.arange(0, BLOCK_D)
d = pid_blk * BLOCK_D + offs_d
mask = d < D

base = pid_l * stride_l + pid_h * stride_h
offs_h = offs_h.to(tl.int64)
offs_l = offs_l.to(tl.int64)
offs_d = offs_d.to(tl.int64)

mask_h = offs_h < H
mask_l = offs_l < L
mask_d = offs_d < D

HALF_D = D // 2

l_b = offs_l[:, None, None]
h_b = offs_h[None, :, None]
d_b = offs_d[None, None, :]

mask = mask_l[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :]

base = l_b * stride_l + h_b * stride_h + d_b * stride_d
x = tl.load(inp_ptr + base, mask=mask, other=0.0)

in_ptr = inp_ptr + base + d * stride_d
cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d
sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d
cos_base_2d = offs_l[:, None] * stride_cos_l + offs_d[None, :] * stride_cos_d
sin_base_2d = offs_l[:, None] * stride_sin_l + offs_d[None, :] * stride_sin_d
mask_ld = mask_l[:, None] & mask_d[None, :]

x = tl.load(in_ptr, mask=mask)
cos = tl.load(cos_ptr_, mask=mask)
sin = tl.load(sin_ptr_, mask=mask)
cos_2d = tl.load(cos_ptr + cos_base_2d, mask=mask_ld, other=0.0)
sin_2d = tl.load(sin_ptr + sin_base_2d, mask=mask_ld, other=0.0)

partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D)
partner_ptr = inp_ptr + base + partner_d * stride_d
partner_val = tl.load(partner_ptr, mask=mask)
rotated = tl.where(d < HALF_D, -partner_val, partner_val)
cos = cos_2d[:, None, :]
sin = sin_2d[:, None, :]

partner_d = tl.where(offs_d < HALF_D, offs_d + HALF_D, offs_d - HALF_D)
partner_d_b = partner_d[None, None, :]

partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d
partner_val = tl.load(inp_ptr + partner_base, mask=mask, other=0.0)

rotated = tl.where(d_b < HALF_D, -partner_val, partner_val)

y = x * cos + rotated * sin

out_ptr_ = out_ptr + base + d
tl.store(out_ptr_, y, mask=mask)
tl.store(out_ptr + base, y, mask=mask)


def apply_rotary_pos_emb_triton(
tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128
tensor: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
assert tensor.is_cuda and cos.is_cuda and sin.is_cuda
assert cos.is_contiguous() and sin.is_contiguous()
if tensor.ndim != 3:
raise RuntimeError("tensor shape should be [L, H, D]")

orig_dtype = tensor.dtype
x = tensor.float()

cos = cos.repeat(1, 2).view(cos.size(0), -1).contiguous().float()
sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float()

L, H, D = x.shape
HALF_D = D // 2
y = torch.empty_like(x)

grid = (H, L, triton.cdiv(D, BLOCK_D))
BLOCK_SEQ = 16
BLOCK_HEAD = 4
BLOCK_D = triton.next_power_of_2(D)

if D >= 128:
num_warps = 8
else:
num_warps = 4

grid = (
triton.cdiv(H, BLOCK_HEAD),
triton.cdiv(L, BLOCK_SEQ),
)

rotary_kernel[grid](
inp_ptr=x,
@@ -81,9 +117,13 @@ def apply_rotary_pos_emb_triton(
stride_cos_d=cos.stride(1),
stride_sin_l=sin.stride(0),
stride_sin_d=sin.stride(1),
L=L,
H=H,
D=D,
HALF_D=HALF_D,
BLOCK_SEQ=BLOCK_SEQ,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_D=BLOCK_D,
num_warps=num_warps,
)

return y.to(orig_dtype)
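The rewrite swaps the old one-element-per-program grid (head, position, D-block) for 2-D tiles of BLOCK_HEAD heads by BLOCK_SEQ positions with BLOCK_D = next_power_of_2(D), so each cos/sin tile is loaded once and broadcast across heads. For sanity-checking, a minimal eager-mode reference of the same rotate-half formula (an assumed test helper, not part of the PR; cos/sin are taken as [L, D//2], matching what the wrapper repeats out to D):

import torch

def rotate_half(x):
    # [..., D] -> concat(-x[..., D//2:], x[..., :D//2]); this is the pairing
    # the kernel realizes via partner_d = d +/- HALF_D
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_ref(tensor, cos, sin):
    cos = cos.repeat(1, 2).float()[:, None, :]  # [L, 1, D], broadcast over heads
    sin = sin.repeat(1, 2).float()[:, None, :]
    return (tensor.float() * cos + rotate_half(tensor.float()) * sin).to(tensor.dtype)

# e.g. torch.testing.assert_close(apply_rotary_pos_emb_triton(x, cos, sin),
#                                 apply_rotary_pos_emb_ref(x, cos, sin))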
47 changes: 38 additions & 9 deletions lightllm/models/qwen2_vl/vision_process.py
@@ -22,6 +22,11 @@
)
from torchvision.transforms.v2 import functional as F

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
@@ -160,9 +165,19 @@ def rescale_and_normalize(

return images

@torch.inference_mode()
def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
try:
return self._preprocess_bydevice(image, device="cuda")
except Exception as e:
logger.warning(f"Exception during image preprocessing on CUDA: {str(e)}")
torch.cuda.current_stream().synchronize()
return self._preprocess_bydevice(image, device="cpu")

def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]:
image_arr = np.asarray(image, dtype=np.uint8)
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)

grouped_images, grouped_images_index = group_images_by_shape(
[image_data], disable_grouping=self.disable_grouping
)
@@ -183,27 +198,39 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
interpolation=self.interpolation,
)
resized_images_grouped[shape] = stacked_images

grouped_images = None
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
resized_images_grouped = None

# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(
resized_images, disable_grouping=self.disable_grouping
)
resized_images = None

processed_images_grouped = {}
processed_grids = {}

for shape, stacked_images in grouped_images.items():
stacked_images = stacked_images.to(device=device, non_blocking=True)

resized_height, resized_width = stacked_images.shape[-2:]
# Fused rescale and normalize

patches = self.rescale_and_normalize(
stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
stacked_images,
self.do_rescale,
self.rescale_factor,
self.do_normalize,
self.image_mean,
self.image_std,
)
if patches.ndim == 4:
# add a temporal dimension if we have images
patches = patches.unsqueeze(1)

if patches.shape[1] % self.temporal_patch_size != 0:
repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
patches = torch.cat([patches, repeats], dim=1)

batch_size, grid_t, channel = patches.shape[:3]
grid_t = grid_t // self.temporal_patch_size
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
@@ -224,8 +251,7 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
.contiguous()
)
# Reorder dimensions to group grid and patch information for subsequent flattening.
# (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)

flatten_patches = patches.view(
batch_size,
grid_t * grid_h * grid_w,
Expand All @@ -235,9 +261,12 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
processed_images_grouped[shape] = flatten_patches
processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

grouped_images = None

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_grids = reorder_images(processed_grids, grouped_images_index)
pixel_values = torch.cat(processed_images, dim=0) # (num_patches_total, C*T*ps*ps)

pixel_values = torch.cat(processed_images, dim=0)
image_grid_thw = torch.as_tensor(processed_grids)

return pixel_values, image_grid_thw
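To make the patch-grid arithmetic in the loop above concrete, a hypothetical example (patch_size=14 and temporal_patch_size=2 are assumed Qwen2-VL-style defaults, not read from this diff):

# Illustrative numbers for one 448x448 RGB image through preprocess():
patch_size, temporal_patch_size, channel = 14, 2, 3   # assumed defaults
grid_h = 448 // patch_size                 # 32
grid_w = 448 // patch_size                 # 32
grid_t = 2 // temporal_patch_size          # single frame padded to 2 -> 1
num_patches = grid_t * grid_h * grid_w     # 1024 rows in pixel_values
patch_dim = channel * temporal_patch_size * patch_size ** 2  # 1176 = C*T*ps*ps
# image_grid_thw for this image would be [[1, 32, 32]]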
18 changes: 14 additions & 4 deletions lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -1,3 +1,5 @@
import rpyc
import socket
import torch
import torch.distributed as dist

@@ -6,9 +8,10 @@

from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, read_afs
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce
from lightllm.utils.envs_utils import get_env_start_args


"""
@@ -29,6 +32,9 @@
class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
def __init__(self, network_config, mode):
super().__init__(network_config, mode)
self.args = get_env_start_args()
self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return

def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
@@ -50,9 +56,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
# skip the same image
if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
continue
# pull the img_embeds by uid from shm
data = read_shm(get_shm_name_embed(img["uuid"]))
img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
# pull the img_embeds by uid from shm or afs
if self.args.enable_remote_vit:
embed = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir)
else:
embed = read_shm(get_shm_name_embed(img["uuid"]))
self.cache_client.root.release([img["uuid"]])
img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1))
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs.append(img_start_loc)
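The new cache client in __init__ disables Nagle's algorithm on the rpyc socket; a condensed sketch of that setup and the reasoning behind it (cache_port here is a hypothetical stand-in for self.args.cache_port):

import rpyc
import socket

cache_port = 10004  # hypothetical; the layer reads args.cache_port
conn = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
# release() calls are tiny request/response round-trips; with Nagle's
# algorithm on, TCP would sit on them waiting for more payload, adding
# avoidable latency to every prefill, so turn it off.
conn._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)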
37 changes: 36 additions & 1 deletion lightllm/server/api_cli.py
@@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
type=str,
choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"],
choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server", "visual"],
default="normal",
help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
config_server is for pd split mode used to register pd_master node, and get pd_master node list,
@@ -516,6 +516,41 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=0.03,
help="""The interval of the schedule time, default is 30ms.""",
)
parser.add_argument(
"--image_embed_dir",
type=str,
default=None,
help="path for vit embed",
)
parser.add_argument(
"--enable_remote_vit",
action="store_true",
help="Whether to enable remote vit for multimodal service.",
)
parser.add_argument(
"--remote_vit_port",
type=int,
default=12346,
help="The port number for the remote vit service.",
)
# redis for vit llm disaggregation
parser.add_argument(
"--redis_port",
type=int,
default=6379,
help="The port number for the redis service in config_server mode.",
)
parser.add_argument(
"--redis_evict_fraction",
type=float,
default=0.3,
help="The evict fraction for the redis service in config_server mode.",
)
parser.add_argument(
"--start_redis",
action="store_true",
help="Whether to start the redis service in config_server mode.",
)
parser.add_argument(
"--enable_cpu_cache",
action="store_true",
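A hypothetical three-process launch exercising the new flags (the module path and the flag pairing are assumptions inferred from the help strings above, not taken from the PR's docs):

# 1) config_server also hosts the redis instance used for vit/llm disaggregation
python -m lightllm.server.api_server --run_mode config_server \
    --start_redis --redis_port 6379

# 2) standalone vit server: runs only the vision tower and writes image
#    embeddings to a shared directory
python -m lightllm.server.api_server --run_mode visual \
    --model_dir /path/to/vlm --remote_vit_port 12346 \
    --image_embed_dir /mnt/afs/image_embeds

# 3) llm server: skips the local vit and pulls embeddings from that directory
python -m lightllm.server.api_server --run_mode normal \
    --model_dir /path/to/vlm --enable_remote_vit \
    --image_embed_dir /mnt/afs/image_embeds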