From eba4e3fbb4c568a4b6d657bede93b8fc2bda7eba Mon Sep 17 00:00:00 2001 From: SRE Date: Mon, 22 Sep 2025 20:51:14 +0800 Subject: [PATCH 01/13] feat: add build image workflows (#86) --- .github/workflows/build-images.yaml | 71 +++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 .github/workflows/build-images.yaml diff --git a/.github/workflows/build-images.yaml b/.github/workflows/build-images.yaml new file mode 100644 index 00000000..52d6e31a --- /dev/null +++ b/.github/workflows/build-images.yaml @@ -0,0 +1,71 @@ +name: Build Images + +on: + workflow_dispatch: + push: + branches: + - main + - develop + paths: + - 'docker/**' + - 'src/**' + - 'pyproject.toml' + - 'README.md' + + +env: + IMAGE_NAME: ${{ github.repository }} + NAMESPACE: gradient + +jobs: + build: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + strategy: + matrix: + variant: [blackwell, hopper] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Container Registry + uses: docker/login-action@v3 + with: + registry: ${{ secrets.ALIYUN_ACR }} + username: ${{ secrets.ALIYUN_ACR_USERNAME }} + password: ${{ secrets.ALIYUN_ACR_PASSWORD }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ secrets.ALIYUN_ACR }}/${{ env.NAMESPACE }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch,suffix=-${{ matrix.variant }} + type=ref,event=pr,suffix=-${{ matrix.variant }} + type=raw,value=latest-${{ matrix.variant }},enable={{is_default_branch}} + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . + file: ./docker/Dockerfile.${{ matrix.variant }} + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + platforms: linux/amd64 + + - name: Generate artifact attestation + uses: actions/attest-build-provenance@v1 + with: + subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + subject-digest: ${{ steps.build.outputs.digest }} + push-to-registry: true From 018b69b0c234ed1dd823f1a071b626f69107355b Mon Sep 17 00:00:00 2001 From: "simont@gradient.network" Date: Tue, 23 Sep 2025 14:40:17 +0800 Subject: [PATCH 02/13] feat: add buildkit --- .github/workflows/build-images.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-images.yaml b/.github/workflows/build-images.yaml index 52d6e31a..715248ba 100644 --- a/.github/workflows/build-images.yaml +++ b/.github/workflows/build-images.yaml @@ -19,7 +19,7 @@ env: jobs: build: - runs-on: ubuntu-latest + runs-on: arc-runner-set-parallax permissions: contents: read packages: write @@ -33,6 +33,9 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + with: + endpoint: tcp://buildkit-buildkit-service:1234 + - name: Log in to Container Registry uses: docker/login-action@v3 From 057be8d343d41c8943146ecc452636862f328a86 Mon Sep 17 00:00:00 2001 From: "simont@gradient.network" Date: Tue, 23 Sep 2025 14:46:59 +0800 Subject: [PATCH 03/13] feat: update endpoint --- .github/workflows/build-images.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-images.yaml b/.github/workflows/build-images.yaml index 715248ba..d1123e74 100644 --- a/.github/workflows/build-images.yaml +++ b/.github/workflows/build-images.yaml @@ -34,7 +34,8 @@ jobs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: - endpoint: tcp://buildkit-buildkit-service:1234 + driver: remote + endpoint: tcp://buildkit-buildkit-service.arc-systems:1234 - name: Log in to Container Registry From afabc735d1dec1274ed72cbfef8515c6e73ae296 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Thu, 25 Sep 2025 17:28:47 +0800 Subject: [PATCH 04/13] add kimi k2 init --- src/parallax/models/kimi_k2.py | 125 +++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 src/parallax/models/kimi_k2.py diff --git a/src/parallax/models/kimi_k2.py b/src/parallax/models/kimi_k2.py new file mode 100644 index 00000000..3cc51bfb --- /dev/null +++ b/src/parallax/models/kimi_k2.py @@ -0,0 +1,125 @@ +""" +hidden_dimefines the Qwen3 model. +""" + +from typing import Optional, Tuple + +import mlx.core as mx +from mlx_lm.models.base import scaled_dot_product_attention +from mlx_lm.models.deepseek_v3 import DeepseekV3Attention as MLXDeepseekV3Attention +from mlx_lm.models.deepseek_v3 import ModelArgs +from mlx_lm.models.deepseek_v3 import DeepseekV3DecoderLayer as MLXDeepseeklock + + +class ParallaxKimiK2Attention(MLXDeepseekV3Attention): + """A custom attention module for Parallax, extending the DeepseekV3 Attention class. + + We apply explicit KV cache handling and passing in `offset` directly from Request. + This version returns the new K and V states for external caching. + """ + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + """ + Attention forward pass with explicit KV cache handling. + + Args: + x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. + mask: (batch, n_q_heads, target_len, source_len) + cache: Optional tuple (past_k, past_v). + shape: (batch, n_kv_heads, S_past_padded, head_dim) + offset: source_len_padded (scalar, used for RoPE calculation). + + Returns: + output_h: (batch, target_len, hidden_dim) - Output hidden states. + new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. + new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. + """ + B, L, D = x.shape + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) + + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + + q_pe = self.rope(q_pe, offset=offset) + k_pe = self.rope(k_pe, offset=offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + queries = mx.concatenate([q_nope, q_pe], axis=-1) + + if cache is not None: + past_k, past_v = cache + if past_k is not None and past_v is not None: + if past_k.shape[2] != offset: + raise ValueError( + f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " + f"to match RoPE offset {offset} (S_past_padded)." + ) + final_keys_for_attn = mx.concatenate( + [past_k, mx.concatenate([k_nope, k_pe], axis=-1)], axis=2 + ) + final_values_for_attn = mx.concatenate([past_v, values], axis=2) + else: + raise ValueError("cache was provided but one of k/v was None.") + else: + final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1) + final_values_for_attn = values + + output = scaled_dot_product_attention( + queries, + final_keys_for_attn, + final_values_for_attn, + scale=self.scale, + mask=mask, + cache=None, + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values) + + +class ParallaxKimiK2Block(MLXDeepseeklock): + """A custom transformer block for Parallax, extending the Qwen3 Block class. + This version handles the KV cache explicitly and returns new K and V states. + """ + + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__(args) + self.self_attn = ParallaxKimiK2Attention(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ): + r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, (k_cache, v_cache) + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "DeepseekV3ForCausalLM" + + +EntryClass = ParallaxKimiK2Block From 733c31025b032f27e09e7a607de297bc32f7b6fe Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 26 Sep 2025 12:57:22 +0800 Subject: [PATCH 05/13] update deepseekv2 --- src/parallax/models/deepseek_v2.py | 125 +++++++++++++++++++++++++++++ src/parallax/models/kimi_k2.py | 6 +- 2 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 src/parallax/models/deepseek_v2.py diff --git a/src/parallax/models/deepseek_v2.py b/src/parallax/models/deepseek_v2.py new file mode 100644 index 00000000..2a239ba2 --- /dev/null +++ b/src/parallax/models/deepseek_v2.py @@ -0,0 +1,125 @@ +""" +hidden_dimefines the Qwen3 model. +""" + +from typing import Optional, Tuple + +import mlx.core as mx +from mlx_lm.models.base import scaled_dot_product_attention +from mlx_lm.models.deepseek_v2 import DeepseekV2Attention as MLXDeepseekV2Attention +from mlx_lm.models.deepseek_v2 import ModelArgs +from mlx_lm.models.deepseek_v2 import DeepseekV2DecoderLayer as MLXDeepseekV2Block + + +class ParallaxDeepSeekV2Attention(MLXDeepseekV2Attention): + """A custom attention module for Parallax, extending the DeepseekV2 Attention class. + + We apply explicit KV cache handling and passing in `offset` directly from Request. + This version returns the new K and V states for external caching. + """ + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + """ + Attention forward pass with explicit KV cache handling. + + Args: + x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. + mask: (batch, n_q_heads, target_len, source_len) + cache: Optional tuple (past_k, past_v). + shape: (batch, n_kv_heads, S_past_padded, head_dim) + offset: source_len_padded (scalar, used for RoPE calculation). + + Returns: + output_h: (batch, target_len, hidden_dim) - Output hidden states. + new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. + new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. + """ + B, L, D = x.shape + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) + + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + + q_pe = self.rope(q_pe, offset=offset) + k_pe = self.rope(k_pe, offset=offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + queries = mx.concatenate([q_nope, q_pe], axis=-1) + + if cache is not None: + past_k, past_v = cache + if past_k is not None and past_v is not None: + if past_k.shape[2] != offset: + raise ValueError( + f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " + f"to match RoPE offset {offset} (S_past_padded)." + ) + final_keys_for_attn = mx.concatenate( + [past_k, mx.concatenate([k_nope, k_pe], axis=-1)], axis=2 + ) + final_values_for_attn = mx.concatenate([past_v, values], axis=2) + else: + raise ValueError("cache was provided but one of k/v was None.") + else: + final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1) + final_values_for_attn = values + + output = scaled_dot_product_attention( + queries, + final_keys_for_attn, + final_values_for_attn, + scale=self.scale, + mask=mask, + cache=None, + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values) + + +class ParallaxDeepSeekV2Block(MLXDeepseekV2Block): + """A custom transformer block for Parallax, extending the Qwen3 Block class. + This version handles the KV cache explicitly and returns new K and V states. + """ + + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__(args, layer_idx=layer_idx) + self.self_attn = ParallaxDeepSeekV2Attention(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ): + r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, (k_cache, v_cache) + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "DeepseekV2ForCausalLM" + + +EntryClass = ParallaxDeepSeekV2Block diff --git a/src/parallax/models/kimi_k2.py b/src/parallax/models/kimi_k2.py index 3cc51bfb..794d7963 100644 --- a/src/parallax/models/kimi_k2.py +++ b/src/parallax/models/kimi_k2.py @@ -8,7 +8,7 @@ from mlx_lm.models.base import scaled_dot_product_attention from mlx_lm.models.deepseek_v3 import DeepseekV3Attention as MLXDeepseekV3Attention from mlx_lm.models.deepseek_v3 import ModelArgs -from mlx_lm.models.deepseek_v3 import DeepseekV3DecoderLayer as MLXDeepseeklock +from mlx_lm.models.deepseek_v3 import DeepseekV3DecoderLayer as MLXDeepseekV3Block class ParallaxKimiK2Attention(MLXDeepseekV3Attention): @@ -93,13 +93,13 @@ def __call__( return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values) -class ParallaxKimiK2Block(MLXDeepseeklock): +class ParallaxKimiK2Block(MLXDeepseekV3Block): """A custom transformer block for Parallax, extending the Qwen3 Block class. This version handles the KV cache explicitly and returns new K and V states. """ def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__(args) + super().__init__(args, layer_idx=layer_idx) self.self_attn = ParallaxKimiK2Attention(args) def __call__( From a5a9ed143d1f8f06a296639e1404795742dfda5d Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 26 Sep 2025 14:49:07 +0800 Subject: [PATCH 06/13] support deepseek & kimi --- src/parallax/models/deepseek_v2.py | 5 +++ src/parallax/models/kimi_k2.py | 2 ++ src/parallax/server/executor.py | 4 +++ src/parallax/server/kv_cache.py | 52 +++++++++++++++++++++--------- src/parallax_utils/utils.py | 11 +++++-- 5 files changed, 56 insertions(+), 18 deletions(-) diff --git a/src/parallax/models/deepseek_v2.py b/src/parallax/models/deepseek_v2.py index 2a239ba2..f5b5bf59 100644 --- a/src/parallax/models/deepseek_v2.py +++ b/src/parallax/models/deepseek_v2.py @@ -80,6 +80,8 @@ def __call__( final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1) final_values_for_attn = values + if mask is not None: + mask = mx.array(mask, dtype=queries.dtype) # Ensure mask is the same dtype as queries output = scaled_dot_product_attention( queries, final_keys_for_attn, @@ -90,6 +92,9 @@ def __call__( ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + # print(f"Output values shape: {values.shape}") + # print(f"Output k_nope shape: {(mx.concatenate([k_nope, k_pe], axis=-1)).shape}") return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values) diff --git a/src/parallax/models/kimi_k2.py b/src/parallax/models/kimi_k2.py index 794d7963..a5648988 100644 --- a/src/parallax/models/kimi_k2.py +++ b/src/parallax/models/kimi_k2.py @@ -80,6 +80,8 @@ def __call__( final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1) final_values_for_attn = values + if mask is not None: + mask = mx.array(mask, dtype=queries.dtype) output = scaled_dot_product_attention( queries, final_keys_for_attn, diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index d7b267ce..66a81586 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -160,6 +160,8 @@ def __init__( self.head_dim = self.config.get("head_dim") or self.config.get( "hidden_size" ) // self.config.get("num_attention_heads") + self.qk_nope_head_dim = self.config.get("qk_nope_head_dim", None) + self.qk_rope_head_dim = self.config.get("qk_rope_head_dim", None) self.enable_prefix_cache = enable_prefix_cache self.linear_key_head_dim = self.config.get("linear_key_head_dim", None) self.linear_value_head_dim = self.config.get("linear_value_head_dim", None) @@ -209,6 +211,8 @@ def __init__( linear_v_dim=self.linear_value_head_dim, linear_num_k_heads=self.linear_num_key_heads, linear_num_v_heads=self.linear_num_value_heads, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, max_num_tokens=max_tokens_in_kv_pool, ) mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"]) diff --git a/src/parallax/server/kv_cache.py b/src/parallax/server/kv_cache.py index da19c337..8a2e5ffa 100644 --- a/src/parallax/server/kv_cache.py +++ b/src/parallax/server/kv_cache.py @@ -37,7 +37,8 @@ class KVCache: def __init__( self, num_kv_heads: int, - head_dim: int, + head_dim_k: int, + head_dim_v: int, num_layers: int, dtype: mx.Dtype, block_size: int = 64, @@ -47,6 +48,8 @@ def __init__( linear_v_dim: Optional[int] = None, linear_num_k_heads: Optional[int] = None, linear_num_v_heads: Optional[int] = None, + qk_nope_head_dim: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, num_initial_tokens: int = 0, ): """ @@ -59,7 +62,6 @@ def __init__( num_initial_tokens: The number of tokens to initialize the cache with. """ self.num_kv_heads = num_kv_heads - self.head_dim = head_dim self.dtype = dtype self.block_size = block_size self.conv_dim = conv_dim @@ -68,11 +70,18 @@ def __init__( self.linear_v_dim = linear_v_dim self.linear_num_k_heads = linear_num_k_heads self.linear_num_v_heads = linear_num_v_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.head_dim_v = head_dim_v + self.head_dim_k = head_dim_k num_initial_tokens = self.round_up_to_step(num_initial_tokens) # (num_layers, num_kv_heads, seq_len, head_dim) - self.keys = mx.zeros((num_layers, num_kv_heads, num_initial_tokens, head_dim), dtype) - self.values = mx.zeros((num_layers, num_kv_heads, num_initial_tokens, head_dim), dtype) + + self.keys = mx.zeros((num_layers, num_kv_heads, num_initial_tokens, self.head_dim_k), dtype) + self.values = mx.zeros( + (num_layers, num_kv_heads, num_initial_tokens, self.head_dim_v), dtype + ) self.state0 = ( mx.zeros((num_layers, conv_kernel_size - 1, conv_dim), dtype) if conv_dim else None ) @@ -115,8 +124,8 @@ def update( Updates the cache with new key-value pairs. Args: - keys: New keys to add, shape (num_layers, num_kv_heads, target_len, head_dim) - values: New values to add, shape (num_layers, num_kv_heads, target_len, head_dim) + keys: New keys to add, shape (num_layers, num_kv_heads, target_len, head_dim_k) + values: New values to add, shape (num_layers, num_kv_heads, target_len, head_dim_v) """ if state0 is not None and self.state0 is not None: self.state0 = state0 @@ -128,10 +137,11 @@ def update( prev_tokens = self.num_tokens # Grow the cache based on the block_size size if self.needs_grow(seq_len): - num_layers, num_kv_heads, _, head_dim = keys.shape + num_layers, num_kv_heads, _, head_dim_k = keys.shape + _, _, _, head_dim_v = values.shape n_steps = (self.block_size + seq_len - 1) // self.block_size - k_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim) - v_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim) + k_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim_k) + v_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim_v) new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) @@ -167,6 +177,8 @@ def __init__( linear_v_dim: Optional[int] = None, linear_num_k_heads: Optional[int] = None, linear_num_v_heads: Optional[int] = None, + qk_nope_head_dim: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, ): """ Args: @@ -179,7 +191,6 @@ def __init__( cache_memory_fraction: The fraction of memory to use for the cache. """ self.num_kv_heads = num_kv_heads - self.head_dim = head_dim self.num_layers = num_layers self.dtype = dtype self.block_size = block_size @@ -189,6 +200,13 @@ def __init__( self.linear_v_dim = linear_v_dim self.linear_num_k_heads = linear_num_k_heads self.linear_num_v_heads = linear_num_v_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + if qk_nope_head_dim and qk_rope_head_dim: + self.head_dim_k = qk_nope_head_dim + qk_rope_head_dim + else: + self.head_dim_k = head_dim + self.head_dim_v = head_dim self.request_caches: Dict[str, KVCache] = {} self.tokens_in_cache = 0 @@ -198,7 +216,8 @@ def __init__( kv_cache_memory_fraction=cache_memory_fraction, num_shard_layers=num_layers, num_key_value_heads=num_kv_heads, - head_dim=head_dim, + head_dim_k=self.head_dim_k, + head_dim_v=self.head_dim_v, elem_bytes=dtype.size, ) if max_num_tokens is not None: @@ -264,7 +283,8 @@ def add_request(self, request: Request, num_tokens: int = 128) -> bool: self.request_caches[request.request_id] = KVCache( num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, + head_dim_k=self.head_dim_k, + head_dim_v=self.head_dim_v, num_layers=self.num_layers, dtype=self.dtype, block_size=self.block_size, @@ -311,16 +331,18 @@ def update_requests( Returns: True if requests are updated. """ - batch_size, num_layers, n_kv_heads, _, head_dim = keys.shape + batch_size, num_layers, n_kv_heads, _, head_dim_k = keys.shape + _, _, _, _, head_dim_v = values.shape # Validate - assert keys.shape == values.shape, "key and value must have the same shape" + # assert keys.shape == values.shape, "key and value must have the same shape" assert num_layers == self.num_layers, "key and value must have the same number of layers" assert batch_size == len(requests), "key and value must have the same batch size" assert len(lengths) == batch_size, "lengths must have the same batch size as requests" assert ( n_kv_heads == self.num_kv_heads ), "key and value must have the same number of key-value heads" - assert head_dim == self.head_dim, "key and value must have the same head dimension" + assert head_dim_k == self.head_dim_k, "key and value must have the same head dimension" + assert head_dim_v == self.head_dim_v, "key and value must have the same head dimension" # TODO: Use vmap for better performance for request, key, value, length, state0, state1 in zip( requests, keys, values, lengths, states0, states1 diff --git a/src/parallax_utils/utils.py b/src/parallax_utils/utils.py index 53049c2f..a314df16 100644 --- a/src/parallax_utils/utils.py +++ b/src/parallax_utils/utils.py @@ -47,7 +47,8 @@ def compute_max_tokens_in_cache( kv_cache_memory_fraction: float, num_shard_layers: int, num_key_value_heads: int, - head_dim: int, + head_dim_k: int, + head_dim_v: int, elem_bytes: int, available_cache_bytes: Optional[int] = None, ) -> int: @@ -65,7 +66,9 @@ def compute_max_tokens_in_cache( hw = HardwareInfo.detect() used = mx.get_active_memory() if mx is not None else 0 available_cache_size = int((hw.total_ram_gb * 1024**3 - used) * kv_cache_memory_fraction) - per_token_cache_size = num_shard_layers * num_key_value_heads * head_dim * 2 * elem_bytes + per_token_cache_size = ( + num_shard_layers * num_key_value_heads * (head_dim_k + head_dim_v) * elem_bytes + ) return max(0, available_cache_size // per_token_cache_size) @@ -110,12 +113,14 @@ def compute_max_batch_size( available_cache_bytes = None if memory_gb is not None: available_cache_bytes = int(memory_gb * 1024**3 * kv_cache_memory_fraction) + ## This is an Error due to kv may have different head_dim max_tokens = compute_max_tokens_in_cache( device=device or "", # empty means non-cuda path kv_cache_memory_fraction=kv_cache_memory_fraction, num_shard_layers=num_shard_layers, num_key_value_heads=num_key_value_heads, - head_dim=head_dim, + head_dim_k=head_dim, + head_dim_v=head_dim, elem_bytes=eb, available_cache_bytes=available_cache_bytes, ) From b92e862ae06f9191e49a1066d73c4b9855ca1a2e Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 26 Sep 2025 14:51:45 +0800 Subject: [PATCH 07/13] update --- src/parallax/models/deepseek_v2.py | 2 +- src/parallax/models/kimi_k2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parallax/models/deepseek_v2.py b/src/parallax/models/deepseek_v2.py index f5b5bf59..a9bd2254 100644 --- a/src/parallax/models/deepseek_v2.py +++ b/src/parallax/models/deepseek_v2.py @@ -7,8 +7,8 @@ import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention from mlx_lm.models.deepseek_v2 import DeepseekV2Attention as MLXDeepseekV2Attention -from mlx_lm.models.deepseek_v2 import ModelArgs from mlx_lm.models.deepseek_v2 import DeepseekV2DecoderLayer as MLXDeepseekV2Block +from mlx_lm.models.deepseek_v2 import ModelArgs class ParallaxDeepSeekV2Attention(MLXDeepseekV2Attention): diff --git a/src/parallax/models/kimi_k2.py b/src/parallax/models/kimi_k2.py index a5648988..d3b38043 100644 --- a/src/parallax/models/kimi_k2.py +++ b/src/parallax/models/kimi_k2.py @@ -7,8 +7,8 @@ import mlx.core as mx from mlx_lm.models.base import scaled_dot_product_attention from mlx_lm.models.deepseek_v3 import DeepseekV3Attention as MLXDeepseekV3Attention -from mlx_lm.models.deepseek_v3 import ModelArgs from mlx_lm.models.deepseek_v3 import DeepseekV3DecoderLayer as MLXDeepseekV3Block +from mlx_lm.models.deepseek_v3 import ModelArgs class ParallaxKimiK2Attention(MLXDeepseekV3Attention): From 75672cc6e38b36fc0b23724525be7406eef0121d Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 26 Sep 2025 19:18:56 +0800 Subject: [PATCH 08/13] modify modelInfo for different k_head_dim from v_head_dim --- src/backend/server/static_config.py | 32 +++++++++++++++++++++++++++++ src/parallax_utils/utils.py | 6 ++++-- src/scheduling/model_info.py | 25 ++++++++++++++++------ src/scheduling/node.py | 2 ++ 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index c486119f..c343a6e2 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -106,6 +106,38 @@ num_local_experts=128, num_experts_per_tok=8, ), + "deepseek-ai/DeepSeek-V2-Lite": ModelInfo( + model_name="deepseek-ai/DeepSeek-V2-Lite", + head_size=128, + hidden_dim=2048, + intermediate_dim=10944, + num_attention_heads=16, + num_kv_heads=16, + vocab_size=151936, + num_layers=27, + ffn_num_projections=3, + param_bytes_per_element=2, + cache_bytes_per_element=2, + embedding_bytes_per_element=2, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + ), + "moonshotai/Kimi-K2-Instruct": ModelInfo( + model_name="moonshotai/Kimi-K2-Instruct", + head_size=128, + hidden_dim=7168, + intermediate_dim=18432, + num_attention_heads=64, + num_kv_heads=64, + vocab_size=163840, + num_layers=61, + ffn_num_projections=3, + param_bytes_per_element=2, + cache_bytes_per_element=2, + embedding_bytes_per_element=2, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + ), } # Supported model list diff --git a/src/parallax_utils/utils.py b/src/parallax_utils/utils.py index a314df16..ea712654 100644 --- a/src/parallax_utils/utils.py +++ b/src/parallax_utils/utils.py @@ -104,6 +104,8 @@ def compute_max_batch_size( dtype=None, elem_bytes: Optional[int] = None, memory_gb: Optional[float] = None, + head_dim_k: Optional[int] = None, + head_dim_v: Optional[int] = None, ) -> int: """Compute final max_batch_size by chaining dtype->elem_bytes, KV capacity, and clamping. @@ -119,8 +121,8 @@ def compute_max_batch_size( kv_cache_memory_fraction=kv_cache_memory_fraction, num_shard_layers=num_shard_layers, num_key_value_heads=num_key_value_heads, - head_dim_k=head_dim, - head_dim_v=head_dim, + head_dim_k=head_dim_k if head_dim_k is not None else head_dim, + head_dim_v=head_dim_v if head_dim_v is not None else head_dim, elem_bytes=eb, available_cache_bytes=available_cache_bytes, ) diff --git a/src/scheduling/model_info.py b/src/scheduling/model_info.py index a6c14680..78f79319 100644 --- a/src/scheduling/model_info.py +++ b/src/scheduling/model_info.py @@ -35,10 +35,23 @@ class ModelInfo: cache_bytes_per_element: int = 1 embedding_bytes_per_element: int = 1 + qk_nope_head_dim: Optional[int] = None + qk_rope_head_dim: Optional[int] = None + if qk_nope_head_dim is not None and qk_rope_head_dim is not None: + head_size_k: int = qk_nope_head_dim + qk_rope_head_dim + else: + head_size_k: int = head_size + head_size_v: int = head_size + @property - def kv_dim(self) -> int: + def v_dim(self) -> int: """Return key and value head dim.""" - return self.num_kv_heads * self.head_size + return self.num_kv_heads * self.head_size_v + + @property + def k_dim(self) -> int: + """Return key head dim.""" + return self.num_kv_heads * self.head_size_k @property def embedding_io_bytes(self) -> int: @@ -48,7 +61,7 @@ def embedding_io_bytes(self) -> int: @property def per_token_per_layer_kv_size(self) -> int: """Return bytes per token for KV cache.""" - return 2 * self.cache_bytes_per_element * self.kv_dim + return self.cache_bytes_per_element * (self.k_dim + self.v_dim) def per_layer_kv_cache_size(self, *, batch_size: int = 1, source_seq_len: int = 256) -> int: """Return size of KV cache in bytes for given request dimensions.""" @@ -81,7 +94,7 @@ def decoder_layer_flops( # Q/O projections: (T, hidden_dim) @ (hidden_dim, hidden_dim) qo_flops = 2 * 2 * target_seq_len * self.hidden_dim * self.hidden_dim # K/V projections: (T, hidden_dim) @ (hidden_dim, kv_dim) - kv_flops = 2 * 2 * target_seq_len * self.hidden_dim * self.kv_dim + kv_flops = 2 * target_seq_len * self.hidden_dim * (self.k_dim + self.v_dim) projection_flops = qo_flops + kv_flops # 'roof' estimation for GQA @@ -123,8 +136,8 @@ def decoder_layer_io_bytes( source_seq_len: Source sequence length (prompt tokens) """ # Attention params - qo_params = self.param_bytes_per_element * self.hidden_dim * self.hidden_dim - kv_params = self.param_bytes_per_element * self.hidden_dim * self.kv_dim + qo_params = self.param_bytes_per_element * self.hidden_dim * self.hidden_dim * 2 + kv_params = self.param_bytes_per_element * self.hidden_dim * (self.k_dim + self.v_dim) attention_params = qo_params + kv_params # FFN params diff --git a/src/scheduling/node.py b/src/scheduling/node.py index 13a7a9f3..6e40b7eb 100644 --- a/src/scheduling/node.py +++ b/src/scheduling/node.py @@ -222,6 +222,8 @@ def max_requests(self) -> int: head_dim=self.model_info.head_size, elem_bytes=elem_bytes, memory_gb=self.hardware.memory_gb, + head_dim_k=self.model_info.head_size_k, + head_dim_v=self.model_info.head_size_v, ) if derived_max <= 0: raise ValueError( From bc6f03bd106026528107738e3fdd80d92e613403 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Sun, 28 Sep 2025 09:10:41 +0800 Subject: [PATCH 09/13] update --- src/backend/server/static_config.py | 2 +- src/scheduling/model_info.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 272ca9fd..e874214e 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -1,8 +1,8 @@ import json from huggingface_hub import hf_hub_download -from scheduling.model_info import ModelInfo +from scheduling.model_info import ModelInfo # Supported model list MODEL_LIST = [ diff --git a/src/scheduling/model_info.py b/src/scheduling/model_info.py index 78f79319..b96cdd80 100644 --- a/src/scheduling/model_info.py +++ b/src/scheduling/model_info.py @@ -37,11 +37,19 @@ class ModelInfo: qk_nope_head_dim: Optional[int] = None qk_rope_head_dim: Optional[int] = None - if qk_nope_head_dim is not None and qk_rope_head_dim is not None: - head_size_k: int = qk_nope_head_dim + qk_rope_head_dim - else: - head_size_k: int = head_size - head_size_v: int = head_size + head_size_k: int = None # 将在 __init__ 中设置 + head_size_v: int = None # 将在 __init__ 中设置 + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + # 设置 head_size_k 和 head_size_v + if self.qk_nope_head_dim is not None and self.qk_rope_head_dim is not None: + self.head_size_k = self.qk_nope_head_dim + self.qk_rope_head_dim + else: + self.head_size_k = self.head_size + self.head_size_v = self.head_size @property def v_dim(self) -> int: From c172f5a1f9f85f95d9283d8bcd7939ab87b6fb9c Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Sun, 28 Sep 2025 17:37:15 +0800 Subject: [PATCH 10/13] add "trust remote code = True" --- src/parallax/server/http_server.py | 4 ++-- src/parallax/server/shard_loader.py | 4 +++- src/parallax/sglang/model_runner.py | 3 ++- src/parallax/utils/tokenizer_utils.py | 26 ++++++++++++++++++++++++++ tests/test_executor.py | 2 +- tests/test_model.py | 2 +- 6 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 39363730..bfcf15c5 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -29,12 +29,12 @@ import zmq import zmq.asyncio from fastapi.responses import ORJSONResponse, StreamingResponse -from mlx_lm.tokenizer_utils import StreamingDetokenizer, load_tokenizer +from mlx_lm.tokenizer_utils import StreamingDetokenizer from mlx_lm.utils import get_model_path, load_config from pydantic import BaseModel from starlette.datastructures import State -from parallax.utils.tokenizer_utils import load_detokenizer +from parallax.utils.tokenizer_utils import load_detokenizer, load_tokenizer from parallax.utils.utils import get_zmq_socket from parallax_utils.logging_config import get_logger diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 7b4f12b6..f15250d6 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -10,10 +10,10 @@ import mlx.core as mx import safetensors from mlx import nn -from mlx_lm.tokenizer_utils import load_tokenizer from mlx_lm.utils import get_model_path, load_config from parallax.server.model import ShardedModel +from parallax.utils.tokenizer_utils import load_tokenizer from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -115,6 +115,8 @@ def load( # We need the model object to know its structure and which layers it owns. # This part mirrors the logic from the provided utils.py to get model_args. model_type = config.get("model_type") + if model_type == "kimi_k2": + model_type = "deepseek_v3" if not model_type: raise ValueError("model_type not found in config.json") try: diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index b043bdef..63df8ab0 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -13,7 +13,6 @@ import sglang import sglang.srt.distributed.parallel_state import torch -from mlx_lm.tokenizer_utils import load_tokenizer from mlx_lm.utils import get_model_path, load_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import ( @@ -44,6 +43,8 @@ ) from torch.distributed import Backend +from parallax.utils.tokenizer_utils import load_tokenizer + # from parallax.sglang.monkey_patch.model_runner import ModelRunner as SGLModelRunner logger = logging.getLogger(__name__) diff --git a/src/parallax/utils/tokenizer_utils.py b/src/parallax/utils/tokenizer_utils.py index 463ef59d..faefef22 100755 --- a/src/parallax/utils/tokenizer_utils.py +++ b/src/parallax/utils/tokenizer_utils.py @@ -14,6 +14,7 @@ _is_spm_decoder, _is_spm_decoder_no_space, ) +from mlx_lm.tokenizer_utils import load_tokenizer as _mlx_load_tokenizer class ParallaxNaiveStreamingDetokenizer(NaiveStreamingDetokenizer): @@ -97,3 +98,28 @@ def load_detokenizer(model_path, tokenizer): tokenmap = _get_bpe_tokenmap(tokenizer) return detokenizer_class, tokenmap + + +def load_tokenizer(model_path, trust_remote_code=True, tokenizer_config_extra=None, **kwargs): + """ + Wrapper function for MLX load_tokenizer that defaults trust_remote_code to True. + This is needed for models like Kimi-K2 that contain custom code. + + Args: + model_path: Path to the model + trust_remote_code: Whether to trust remote code (defaults to True) + tokenizer_config_extra: Extra config to pass to AutoTokenizer.from_pretrained + **kwargs: Additional arguments to pass to the original load_tokenizer + + Returns: + The loaded tokenizer + """ + if tokenizer_config_extra is None: + tokenizer_config_extra = {} + + # Add trust_remote_code to the tokenizer config + if trust_remote_code: + tokenizer_config_extra = tokenizer_config_extra.copy() + tokenizer_config_extra["trust_remote_code"] = True + + return _mlx_load_tokenizer(model_path, tokenizer_config_extra=tokenizer_config_extra, **kwargs) diff --git a/tests/test_executor.py b/tests/test_executor.py index fe66c18d..1dd16093 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -4,11 +4,11 @@ import pytest from mlx_lm.generate import generate -from mlx_lm.tokenizer_utils import load_tokenizer from mlx_lm.utils import get_model_path, load_model from parallax.server.executor import Executor from parallax.server.request import InitialRequest +from parallax.utils.tokenizer_utils import load_tokenizer MODEL_REPO = "mlx-community/Qwen3-0.6B-bf16" diff --git a/tests/test_model.py b/tests/test_model.py index c839e91d..a62dd2db 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,11 +7,11 @@ import mlx.core as mx import pytest from mlx_lm.models.base import create_attention_mask -from mlx_lm.tokenizer_utils import load_tokenizer from mlx_lm.utils import get_model_path, load_model from parallax.server.server_info import ShardedModelInfo from parallax.server.shard_loader import MLXModelLoader +from parallax.utils.tokenizer_utils import load_tokenizer from parallax.utils.utils import pad_inputs REPO_ID = "mlx-community/Qwen3-0.6B-bf16" From edeb96963f3796f4d8abe54091d086df926d4451 Mon Sep 17 00:00:00 2001 From: gufengc Date: Tue, 30 Sep 2025 12:13:08 +0800 Subject: [PATCH 11/13] update --- src/backend/server/scheduler_manage.py | 1 + src/backend/server/static_config.py | 4 ++++ src/scheduling/model_info.py | 9 ++++----- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/backend/server/scheduler_manage.py b/src/backend/server/scheduler_manage.py index da3f3d60..76820fda 100644 --- a/src/backend/server/scheduler_manage.py +++ b/src/backend/server/scheduler_manage.py @@ -111,6 +111,7 @@ def _start_scheduler(self, model_name, init_nodes_num): self.init_nodes_num = init_nodes_num model_info = get_model_info(model_name) + logger.info(f"Model info: {model_info}") self.scheduler = Scheduler(model_info, [], min_nodes_bootstrapping=init_nodes_num) # Run the scheduler's event/dispatch loops in background so the process diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 85427cab..dae8cec4 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -27,6 +27,8 @@ "nvidia/Llama-3.3-70B-Instruct-FP8", "nvidia/Llama-3.1-70B-Instruct-FP8", "nvidia/Llama-3.1-8B-Instruct-FP8", + "moonshotai/Kimi-K2-Instruct", + "moonshotai/Kimi-K2-Instruct-0905", ] NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join""" @@ -56,6 +58,8 @@ def get_model_info(model_name): model_info = ModelInfo( model_name=model_name, head_size=config.get("head_dim", 128), + qk_nope_head_dim=config.get("qk_nope_head_dim", None), + qk_rope_head_dim=config.get("qk_rope_head_dim", None), hidden_dim=config.get("hidden_size", 0), intermediate_dim=config.get("intermediate_size", 0), num_attention_heads=config.get("num_attention_heads", 0), diff --git a/src/scheduling/model_info.py b/src/scheduling/model_info.py index b96cdd80..bd49e060 100644 --- a/src/scheduling/model_info.py +++ b/src/scheduling/model_info.py @@ -37,14 +37,13 @@ class ModelInfo: qk_nope_head_dim: Optional[int] = None qk_rope_head_dim: Optional[int] = None - head_size_k: int = None # 将在 __init__ 中设置 - head_size_v: int = None # 将在 __init__ 中设置 + head_size_k: int = None + head_size_v: int = None def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) - # 设置 head_size_k 和 head_size_v if self.qk_nope_head_dim is not None and self.qk_rope_head_dim is not None: self.head_size_k = self.qk_nope_head_dim + self.qk_rope_head_dim else: @@ -144,8 +143,8 @@ def decoder_layer_io_bytes( source_seq_len: Source sequence length (prompt tokens) """ # Attention params - qo_params = self.param_bytes_per_element * self.hidden_dim * self.hidden_dim * 2 - kv_params = self.param_bytes_per_element * self.hidden_dim * (self.k_dim + self.v_dim) + qo_params = self.param_bytes_per_element * self.hidden_dim * self.hidden_dim + kv_params = self.param_bytes_per_element * self.hidden_dim * (self.k_dim + self.v_dim) // 2 attention_params = qo_params + kv_params # FFN params From 5d1cd49c5f526e67ab6990db5cb9e3232b5d189f Mon Sep 17 00:00:00 2001 From: gufengc Date: Tue, 30 Sep 2025 12:20:51 +0800 Subject: [PATCH 12/13] update --- src/backend/server/static_config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index dae8cec4..004cf3b9 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -7,6 +7,12 @@ # Supported model list MODEL_LIST = [ "Qwen/Qwen3-0.6B", + "openai/gpt-oss-20b", + "openai/gpt-oss-120b", + "moonshotai/Kimi-K2-Instruct", + "moonshotai/Kimi-K2-Instruct-0905", + "Qwen/Qwen3-Next-80B-A3B-Instruct", + "Qwen/Qwen3-Next-80B-A3B-Thinking", # "Qwen/Qwen3-8B", # "Qwen/Qwen3-8B-FP8", "Qwen/Qwen3-32B", @@ -16,19 +22,13 @@ # "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8", "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8", "Qwen/Qwen3-235B-A22B-Thinking-2507-FP8", - "Qwen/Qwen3-Next-80B-A3B-Instruct", - "Qwen/Qwen3-Next-80B-A3B-Thinking", # "Qwen/Qwen2.5-3B-Instruct", # "Qwen/Qwen2.5-7B-Instruct", # "Qwen/Qwen2.5-14B-Instruct", "Qwen/Qwen2.5-72B-Instruct", - "openai/gpt-oss-20b", - "openai/gpt-oss-120b", "nvidia/Llama-3.3-70B-Instruct-FP8", "nvidia/Llama-3.1-70B-Instruct-FP8", "nvidia/Llama-3.1-8B-Instruct-FP8", - "moonshotai/Kimi-K2-Instruct", - "moonshotai/Kimi-K2-Instruct-0905", ] NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join""" From ec62e7df290d7675abe8c7acfcd2c4cbb47b2a16 Mon Sep 17 00:00:00 2001 From: gufengc Date: Tue, 30 Sep 2025 12:21:20 +0800 Subject: [PATCH 13/13] update --- src/backend/server/scheduler_manage.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/backend/server/scheduler_manage.py b/src/backend/server/scheduler_manage.py index 76820fda..da3f3d60 100644 --- a/src/backend/server/scheduler_manage.py +++ b/src/backend/server/scheduler_manage.py @@ -111,7 +111,6 @@ def _start_scheduler(self, model_name, init_nodes_num): self.init_nodes_num = init_nodes_num model_info = get_model_info(model_name) - logger.info(f"Model info: {model_info}") self.scheduler = Scheduler(model_info, [], min_nodes_bootstrapping=init_nodes_num) # Run the scheduler's event/dispatch loops in background so the process