Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mlx_lm.models.deepseek_v3 import ModelArgs


class ParallaxKimiK2Attention(MLXDeepseekV3Attention):
class ParallaxDeepSeekV3Attention(MLXDeepseekV3Attention):
"""A custom attention module for Parallax, extending the DeepseekV3 Attention class.

We apply explicit KV cache handling and passing in `offset` directly from Request.
Expand Down Expand Up @@ -95,14 +95,14 @@ def __call__(
return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values)


class ParallaxKimiK2Block(MLXDeepseekV3Block):
class ParallaxDeepSeekV3Block(MLXDeepseekV3Block):
"""A custom transformer block for Parallax, extending the Qwen3 Block class.
This version handles the KV cache explicitly and returns new K and V states.
"""

def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__(args, layer_idx=layer_idx)
self.self_attn = ParallaxKimiK2Attention(args)
self.self_attn = ParallaxDeepSeekV3Attention(args)

def __call__(
self,
Expand All @@ -124,4 +124,4 @@ def get_architecture(cls):
return "DeepseekV3ForCausalLM"


EntryClass = ParallaxKimiK2Block
EntryClass = ParallaxDeepSeekV3Block
1 change: 1 addition & 0 deletions src/parallax/server/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
self.finished_batch = []
self.start_layer = start_layer
self.end_layer = end_layer

self.is_first_peer = start_layer == 0
self.is_last_peer = end_layer == self.config.get("num_hidden_layers")
self.num_shard_layers = end_layer - start_layer
Expand Down