In [14]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

# Set random seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Configuration
batch_size = 2
sequence_length = 4
d_model = 64        # Input/output dimension
num_heads = 4       # Number of attention heads
d_head = d_model // num_heads  # Dimension per head
d_latent = 16      # Compressed latent dimension (much smaller than d_model)
d_rotary = 8       # Dimension for rotary component

class MultiHeadLatentAttention(keras.layers.Layer):
    def __init__(self, d_model, num_heads, d_latent, d_rotary):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.d_latent = d_latent
        self.d_rotary = d_rotary
        
        # Compression projections for KV
        self.kv_down = keras.layers.Dense(d_latent)  # Down projection to latent space
        self.k_up = keras.layers.Dense(d_model)      # Up projection for keys
        self.v_up = keras.layers.Dense(d_model)      # Up projection for values
        
        # Rotary component for keys
        self.k_rotary = keras.layers.Dense(num_heads * d_rotary)
        
        # Query projections
        self.q_down = keras.layers.Dense(d_latent)   # Down projection for queries
        self.q_up = keras.layers.Dense(d_model)      # Up projection for queries
        self.q_rotary = keras.layers.Dense(num_heads * d_rotary)  # Rotary for queries
        
        # Output projection
        self.output_linear = keras.layers.Dense(d_model)

    def split_heads(self, x, rotary=False):
        batch_size = tf.shape(x)[0]
        if rotary:
            dim = self.d_rotary
        else:
            dim = self.d_head
        x = tf.reshape(x, (batch_size, -1, self.num_heads, dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def apply_rotary_embedding(self, x, seq_length):
        # Simplified rotary embedding for demonstration
        position = tf.range(seq_length, dtype=tf.float32)
        position = tf.expand_dims(position, axis=1)
        return x + position * 0.02  # Simple positional modification

    def call(self, x):
        batch_size = tf.shape(x)[0]
        seq_length = tf.shape(x)[1]
        
        print("\nInput shape:", x.shape)
        
        # 1. Compress KV into latent space
        kv_latent = self.kv_down(x)
        print("\nKV latent shape (compressed):", kv_latent.shape)
        
        # 2. Generate keys and values from latent
        k_content = self.k_up(kv_latent)
        v_content = self.v_up(kv_latent)
        k_rot = self.k_rotary(x)
        print("\nKey content shape:", k_content.shape)
        print("Key rotary shape:", k_rot.shape)
        
        # 3. Process queries
        q_latent = self.q_down(x)
        q_content = self.q_up(q_latent)
        q_rot = self.q_rotary(x)
        
        # 4. Split heads for content and rotary components
        q_content = self.split_heads(q_content)
        k_content = self.split_heads(k_content)
        v_content = self.split_heads(v_content)
        q_rot = self.split_heads(q_rot, rotary=True)
        k_rot = self.split_heads(k_rot, rotary=True)
        
        # 5. Apply rotary embeddings
        q_rot = self.apply_rotary_embedding(q_rot, seq_length)
        k_rot = self.apply_rotary_embedding(k_rot, seq_length)
        
        # 6. Concatenate content and rotary components
        q = tf.concat([q_content, q_rot], axis=-1)
        k = tf.concat([k_content, k_rot], axis=-1)
        
        print("\nFinal Q shape (content + rotary):", q.shape)
        print("Final K shape (content + rotary):", k.shape)
        print("V shape:", v_content.shape)
        
        # 7. Compute attention scores
        scale = tf.math.sqrt(tf.cast(self.d_head + self.d_rotary, tf.float32))
        attention_scores = tf.matmul(q, k, transpose_b=True) / scale
        attention_weights = tf.nn.softmax(attention_scores, axis=-1)
        
        # 8. Apply attention to values
        output = tf.matmul(attention_weights, v_content)
        
        # 9. Combine heads and final projection
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        output = tf.reshape(output, (batch_size, -1, self.d_model))
        output = self.output_linear(output)
        
        print("\nFinal output shape:", output.shape)
        
        return output, attention_weights

# Create and test
dummy_input = tf.random.uniform((batch_size, sequence_length, d_model))
mla = MultiHeadLatentAttention(d_model, num_heads, d_latent, d_rotary)
output, attention_weights = mla(dummy_input)


Input shape: (2, 4, 64)

KV latent shape (compressed): (2, 4, 16)

Key content shape: (2, 4, 64)
Key rotary shape: (2, 4, 32)

Final Q shape (content + rotary): (2, 4, 4, 24)
Final K shape (content + rotary): (2, 4, 4, 24)
V shape: (2, 4, 4, 16)

Final output shape: (2, 4, 64)

Input shape: (2, 4, 64)

KV latent shape (compressed): (2, 4, 16)

Key content shape: (2, 4, 64)
Key rotary shape: (2, 4, 32)

Final Q shape (content + rotary): (2, 4, 4, 24)
Final K shape (content + rotary): (2, 4, 4, 24)
V shape: (2, 4, 4, 16)

Final output shape: (2, 4, 64)


Input (2, 4, 64)
   ↓
Compression to latent space (2, 4, 16)  # Much smaller!
   ↓
Up-projection and head splitting
   ↓
Attention computation
   ↓
Combine heads
   ↓
Final output (2, 4, 64)  # Back to original shape

In [15]:
import tensorflow as tf
import numpy as np

class MoEGating(tf.keras.layers.Layer):
    def __init__(self, num_experts=256, num_selected_experts=8, d_model=64):
        super().__init__()
        self.num_experts = num_experts
        self.num_selected_experts = num_selected_experts
        self.d_model = d_model
        
        # Initialize expert centroids (learnable)
        self.expert_centroids = tf.Variable(
            initial_value=tf.random.normal([num_experts, d_model], stddev=0.02),
            trainable=True,
            name="expert_centroids"
        )
        
        # Initialize balance biases
        self.balance_biases = tf.Variable(
            initial_value=tf.zeros([num_experts]),
            trainable=True,
            name="balance_biases"
        )

    def compute_affinity(self, inputs):
        # inputs shape: [batch_size, seq_len, d_model]
        # expert_centroids shape: [num_experts, d_model]
        
        # Reshape inputs to [batch_size * seq_len, d_model]
        batch_size, seq_len, _ = tf.shape(inputs)
        flat_inputs = tf.reshape(inputs, [-1, self.d_model])
        
        # Compute affinity scores: [batch_size * seq_len, num_experts]
        logits = tf.matmul(flat_inputs, self.expert_centroids, transpose_b=True)
        
        # Add balance biases
        logits = logits + self.balance_biases
        
        # Apply sigmoid
        affinities = tf.sigmoid(logits)
        
        # Reshape back to [batch_size, seq_len, num_experts]
        affinities = tf.reshape(affinities, [batch_size, seq_len, self.num_experts])
        return affinities

    def top_k_gating(self, affinities):
        # Get top-k values and indices
        top_k_values, top_k_indices = tf.math.top_k(
            affinities, k=self.num_selected_experts
        )
        
        # Create a mask for selected experts
        mask = tf.one_hot(top_k_indices, depth=self.num_experts)
        
        # Combine all experts selected for a token
        mask = tf.reduce_sum(mask, axis=-2)
        
        # Mask out non-selected expert affinities
        gating = affinities * tf.cast(mask > 0, tf.float32)
        
        # Normalize gating weights
        normalizer = tf.reduce_sum(gating, axis=-1, keepdims=True)
        gating = gating / (normalizer + 1e-9)
        
        return gating, top_k_indices

    def call(self, inputs):
        # Compute expert affinities
        affinities = self.compute_affinity(inputs)
        print("\nAffinity scores shape:", affinities.shape)
        
        # Compute gating weights and expert selection
        gating, selected_experts = self.top_k_gating(affinities)
        print("\nGating weights shape:", gating.shape)
        print("Selected experts shape:", selected_experts.shape)
        
        return gating, selected_experts

# Test the implementation

num_experts = 256
num_selected_experts = 8

# Create dummy input (as if coming from MHA)
dummy_input = output
print("\nInput shape:", dummy_input.shape)

# Create and apply gating
gating_layer = MoEGating(num_experts, num_selected_experts, d_model)
gating_weights, selected_experts = gating_layer(dummy_input)

# Print example output for first sequence in batch
print("\nExample for first sequence in batch:")
print("First token's top experts:", selected_experts[0, 0])
print("Corresponding gating weights:", 
      [float(gating_weights[0, 0, idx]) for idx in selected_experts[0, 0]])

# Verify sum of gating weights is approximately 1
print("\nSum of gating weights for first token:", 
      float(tf.reduce_sum(gating_weights[0, 0])))


Input shape: (2, 4, 64)

Affinity scores shape: (2, 4, 256)

Gating weights shape: (2, 4, 256)
Selected experts shape: (2, 4, 8)

Example for first sequence in batch:
First token's top experts: tf.Tensor([225 255  79 178 159 118 252  76], shape=(8,), dtype=int32)
Corresponding gating weights: [0.12765584886074066, 0.12545745074748993, 0.12526875734329224, 0.12515246868133545, 0.12500016391277313, 0.12431211024522781, 0.123655766248703, 0.12349744141101837]

Sum of gating weights for first token: 1.0


In [16]:
class MoELayer(tf.keras.layers.Layer):
    def __init__(self, num_experts=256, d_model=64, d_ff=256):
        super().__init__()
        self.num_experts = num_experts
        self.d_model = d_model
        self.d_ff = d_ff
        
        # Shared expert (always used)
        self.shared_expert = self._create_ffn()
        
        # Create routed experts
        self.routed_experts = [self._create_ffn() for _ in range(num_experts)]
    
    def _create_ffn(self):
        return tf.keras.Sequential([
            tf.keras.layers.Dense(self.d_ff, activation='gelu'),
            tf.keras.layers.Dense(self.d_model)
        ])
    
    def call(self, inputs, gating_weights, selected_experts):
        batch_size, seq_len, _ = tf.shape(inputs)
        
        # 1. Apply shared expert to all tokens
        shared_output = self.shared_expert(inputs)
        
        # 2. Initialize expert outputs
        expert_outputs = tf.zeros_like(inputs)
        
        # Process each expert
        for expert_idx in range(self.num_experts):
            # Find where this expert is selected
            expert_mask = tf.reduce_any(tf.equal(selected_experts, expert_idx), axis=-1)
            expert_mask = tf.cast(expert_mask, tf.float32)
            
            # Get corresponding gates
            expert_gates = gating_weights[..., expert_idx]
            
            # Process tokens through this expert where it's selected
            expert_output = self.routed_experts[expert_idx](inputs)
            
            # Add weighted output to total
            expert_outputs += expert_output * tf.expand_dims(expert_gates * expert_mask, -1)
        
        # 3. Combine shared and routed outputs with residual connection
        final_output = inputs + shared_output + expert_outputs
        print("\nFinal MoE output shape:", final_output.shape)
        
        return final_output

# Create and test MoE layer
moe_layer = MoELayer(num_experts=num_experts, d_model=d_model)
moe_output = moe_layer(dummy_input, gating_weights, selected_experts)

print("\nMoE layer test:")
print("Input shape:", dummy_input.shape)
print("Output shape:", moe_output.shape)


Final MoE output shape: (2, 4, 64)

MoE layer test:
Input shape: (2, 4, 64)
Output shape: (2, 4, 64)


In [17]:
class RMSNorm(tf.keras.layers.Layer):
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps
        
    def build(self, input_shape):
        self.weight = self.add_weight(
            shape=(input_shape[-1],),
            initializer='ones',
            trainable=True,
            name='weight'
        )
        
    def call(self, x):
        # Calculate RMS
        mean_square = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        x_norm = x * tf.math.rsqrt(mean_square + self.eps)
        return self.weight * x_norm

class OutputHead(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size = vocab_size
        self.dense = tf.keras.layers.Dense(vocab_size)
        
    def call(self, x):
        # Apply final projection to vocabulary size
        logits = self.dense(x)
        # Apply softmax to get probabilities
        probs = tf.nn.softmax(logits, axis=-1)
        return probs

# Complete the pipeline with previous components
# Assuming we have moe_output from previous step
vocab_size = 32000  # Example vocabulary size

# 1. Apply RMSNorm
rms_norm = RMSNorm()
normalized_output = rms_norm(moe_output)
print("\nAfter RMSNorm shape:", normalized_output.shape)

# 2. Apply output head (only at the end of all transformer blocks)
output_head = OutputHead(vocab_size, d_model)
final_output = output_head(normalized_output)
print("\nFinal output shape:", final_output.shape)

# Print sample probabilities for first token
print("\nSample token probabilities (first token):")
print("Sum of probabilities:", float(tf.reduce_sum(final_output[0, 0])))  # Should be close to 1
print("Top 5 token probabilities:", sorted(final_output[0, 0].numpy())[-5:])


After RMSNorm shape: (2, 4, 64)

Final output shape: (2, 4, 32000)

Sample token probabilities (first token):
Sum of probabilities: 1.0
Top 5 token probabilities: [3.9051996e-05, 3.9068873e-05, 3.9444534e-05, 4.0344134e-05, 4.0591844e-05]


In [19]:
class MoEGating(tf.keras.layers.Layer):
    def __init__(self, num_experts=256, num_selected_experts=8, d_model=64):
        super().__init__()
        self.num_experts = num_experts
        self.num_selected_experts = num_selected_experts
        self.d_model = d_model
        
        # Initialize expert centroids (learnable)
        self.expert_centroids = self.add_weight(
            shape=(num_experts, d_model),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
            name="expert_centroids"
        )
        
        # Initialize balance biases
        self.balance_biases = self.add_weight(
            shape=(num_experts,),
            initializer='zeros',
            trainable=True,
            name="balance_biases"
        )

    def call(self, inputs):
        # inputs shape: [batch_size, seq_len, d_model]
        
        # Compute logits
        logits = tf.einsum('bsd,ed->bse', inputs, self.expert_centroids)
        logits = logits + self.balance_biases
        
        # Apply sigmoid
        affinities = tf.sigmoid(logits)
        
        # Get top-k values and indices
        top_k_values, selected_experts = tf.math.top_k(
            affinities, k=self.num_selected_experts
        )
        
        # Create gating weights
        mask = tf.one_hot(selected_experts, depth=self.num_experts)
        mask = tf.reduce_sum(mask, axis=-2)
        gating = affinities * tf.cast(mask > 0, tf.float32)
        gating = gating / (tf.reduce_sum(gating, axis=-1, keepdims=True) + 1e-9)
        
        return gating, selected_experts

class MoELayer(tf.keras.layers.Layer):
    def __init__(self, num_experts=256, d_model=64, d_ff=256):
        super().__init__()
        self.num_experts = num_experts
        self.d_model = d_model
        self.d_ff = d_ff
        
        # Shared expert
        self.shared_expert = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation='gelu'),
            tf.keras.layers.Dense(d_model)
        ])
        
        # Create routed experts
        self.routed_experts = [
            tf.keras.Sequential([
                tf.keras.layers.Dense(d_ff, activation='gelu'),
                tf.keras.layers.Dense(d_model)
            ]) for _ in range(num_experts)
        ]
    
    def call(self, inputs, gating_weights, selected_experts):
        # Apply shared expert
        shared_output = self.shared_expert(inputs)
        expert_outputs = tf.zeros_like(inputs)
        
        # Process through experts
        for i in range(self.num_experts):
            expert_mask = tf.reduce_any(tf.equal(selected_experts, i), axis=-1)
            expert_mask = tf.cast(expert_mask, tf.float32)[..., tf.newaxis]
            expert_gates = gating_weights[..., i:i+1]
            
            expert_output = self.routed_experts[i](inputs)
            expert_outputs += expert_output * expert_gates * expert_mask
        
        return inputs + shared_output + expert_outputs

class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, num_experts=256, num_selected_experts=8, d_model=64, d_ff=256):
        super().__init__()
        self.gating = MoEGating(num_experts, num_selected_experts, d_model)
        self.moe = MoELayer(num_experts, d_model, d_ff)
        self.norm = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=4, key_dim=d_model // 4
        )
        
    def call(self, inputs):
        # Attention
        attn_output = self.attention(inputs, inputs, inputs)
        x = self.norm(inputs + attn_output)
        
        # Gating
        gating_weights, selected_experts = self.gating(x)
        
        # MoE
        x = self.moe(x, gating_weights, selected_experts)
        
        # Norm
        return self.norm(x)

class DeepSeekModel(tf.keras.Model):
    def __init__(self, num_blocks=2, vocab_size=32000, d_model=64):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.blocks = [TransformerBlock(d_model=d_model) for _ in range(num_blocks)]
        self.final_layer = tf.keras.layers.Dense(vocab_size, activation='softmax')
        
    def call(self, inputs):
        x = self.embedding(inputs)
        
        for block in self.blocks:
            x = block(x)
            
        return self.final_layer(x)

# Create model
inputs = tf.keras.Input(shape=(64,))  # sequence length of 4
model = DeepSeekModel(num_blocks=2, vocab_size=32000, d_model=64)
outputs = model(inputs)

# Create the Keras model
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Print model summary
model.summary()