diff --git a/src/parallax/models/kimi_k2.py b/src/parallax/models/deepseek_v3.py
similarity index 95%
rename from src/parallax/models/kimi_k2.py
rename to src/parallax/models/deepseek_v3.py
index d3b3804..c48cfec 100644
--- a/src/parallax/models/kimi_k2.py
+++ b/src/parallax/models/deepseek_v3.py
@@ -11,7 +11,7 @@
 from mlx_lm.models.deepseek_v3 import ModelArgs
 
 
-class ParallaxKimiK2Attention(MLXDeepseekV3Attention):
+class ParallaxDeepSeekV3Attention(MLXDeepseekV3Attention):
     """A custom attention module for Parallax, extending the DeepseekV3 Attention class.
 
-    We apply explicit KV cache handling and passing in `offset` directly from Request.
+    We apply explicit KV cache handling and pass in `offset` directly from the Request.
@@ -95,14 +95,14 @@ def __call__(
         return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values)
 
 
-class ParallaxKimiK2Block(MLXDeepseekV3Block):
-    """A custom transformer block for Parallax, extending the Qwen3 Block class.
+class ParallaxDeepSeekV3Block(MLXDeepseekV3Block):
+    """A custom transformer block for Parallax, extending the DeepseekV3 Block class.
 
     This version handles the KV cache explicitly and returns new K and V states.
     """
     def __init__(self, args: ModelArgs, layer_idx: int):
         super().__init__(args, layer_idx=layer_idx)
-        self.self_attn = ParallaxKimiK2Attention(args)
+        self.self_attn = ParallaxDeepSeekV3Attention(args)
 
     def __call__(
         self,
@@ -124,4 +124,4 @@ def get_architecture(cls):
         return "DeepseekV3ForCausalLM"
 
 
-EntryClass = ParallaxKimiK2Block
+EntryClass = ParallaxDeepSeekV3Block
diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index dd0b819..86bd3b2 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -143,6 +143,7 @@ def __init__(
         self.finished_batch = []
         self.start_layer = start_layer
         self.end_layer = end_layer
+        self.is_first_peer = start_layer == 0
         self.is_last_peer = end_layer == self.config.get("num_hidden_layers")
         self.num_shard_layers = end_layer - start_layer
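
A minimal usage sketch of the new flag (not part of the diff): `is_first_peer` complements the existing `is_last_peer`, letting a shard that serves layers [start_layer, end_layer) tell whether it sits at the head or tail of the pipeline. Everything below except the two flag computations is an assumption for illustration — `NUM_HIDDEN_LAYERS`, `shard_roles`, and the three-peer split are hypothetical, not Parallax APIs.

# Hypothetical sketch: deriving a shard's pipeline role from its layer range,
# mirroring the executor's flag computation. NUM_HIDDEN_LAYERS stands in for
# self.config.get("num_hidden_layers"); 61 is DeepSeek-V3's depth, used here
# only as an example value.
NUM_HIDDEN_LAYERS = 61

def shard_roles(start_layer: int, end_layer: int) -> tuple[bool, bool]:
    """Return (is_first_peer, is_last_peer) for a [start_layer, end_layer) shard."""
    is_first_peer = start_layer == 0               # same test the diff adds
    is_last_peer = end_layer == NUM_HIDDEN_LAYERS  # same test already in executor.py
    return is_first_peer, is_last_peer

# Three peers splitting the stack: only the first would embed raw token ids,
# only the last would apply the LM head; middle peers pass hidden states along.
for start, end in [(0, 20), (20, 40), (40, 61)]:
    first, last = shard_roles(start, end)
    print(f"layers [{start}, {end}): is_first_peer={first}, is_last_peer={last}")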