In [1]:
from flax import nnx
import jax
import jax.numpy as jnp
from typing import Optional, Tuple
import math

In [2]:
# Dynamic Tanh implementation
class DyT(nnx.Module):
    def __init__(self, num_features :jnp.int8, alpha_init_value : jnp.float32 =0.5):
        self.alpha = nnx.Param(
            jnp.ones((1,)) * alpha_init_value,
            name="alpha",
        )
        self.weight = nnx.Param(
            jnp.ones((num_features,)),
            name="weight",
        )
        self.bias = nnx.Param(
            jnp.zeros((num_features,)),
            name="bias",
        )

    def __call__(self, x):
        x = jnp.tanh(self.alpha * x)
        return x * self.weight + self.bias


In [63]:
class BSBRAttention(nnx.Module):
    def __init__(
            self,
            rngs: nnx.Rngs,
            hidden_dim: int,
            num_heads: int,
            chunk_size: int,
            dropout: float = 0.1,
            compression_factor: Optional[int] = None
    ):

        # Fix type annotations from jnp.int8 to int for dimension parameters
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.chunk_size = chunk_size
        self.dropout = dropout

        # Standard attention projections
        self.q_proj = nnx.Linear(in_features=hidden_dim, out_features=hidden_dim, rngs=rngs)
        self.k_proj = nnx.Linear(in_features=hidden_dim, out_features=hidden_dim, rngs=rngs)
        self.v_proj = nnx.Linear(in_features=hidden_dim, out_features=hidden_dim, rngs=rngs)
        self.out_proj = nnx.Linear(in_features=hidden_dim, out_features=hidden_dim, rngs=rngs)

        # Meta-query and key projections for chunk-level attention
        self.meta_r_proj = nnx.Linear(in_features=hidden_dim, out_features=hidden_dim, rngs=rngs)
        self.meta_h_proj = nnx.Linear(in_features=hidden_dim, out_features=hidden_dim, rngs=rngs)

        # Optional compression for state vectors
        self.compression_factor = compression_factor
        if compression_factor is not None:
            compressed_dim = hidden_dim // compression_factor
            state_dim = self.head_dim * self.head_dim
            self.compress_proj = nnx.Linear(in_features=state_dim, out_features=compressed_dim, rngs=rngs)
            self.decompress_proj = nnx.Linear(in_features=compressed_dim, out_features=state_dim, rngs=rngs)

        self.dropout_layer = nnx.Dropout(dropout, rngs=rngs)

    def _reshape_for_heads(self, x: jnp.ndarray) -> jnp.ndarray:
        batch_size, seq_len, _ = x.shape
        # Reshape to (batch_size, seq_len, num_heads, head_dim)
        x = x.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        # Transpose to (batch_size, num_heads, seq_len, head_dim)
        return x.transpose(0, 2, 1, 3)

    def _create_masks(self, seq_len: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # Calculate number of chunks
        num_chunks = math.ceil(seq_len / self.chunk_size)

        # Create block diagonal mask for within-chunk attention
        m_in = jnp.zeros((seq_len, seq_len))

        for i in range(num_chunks):
            start_idx = i * self.chunk_size
            end_idx = min((i + 1) * self.chunk_size, seq_len)
            chunk_size = end_idx - start_idx
            chunk_mask = jnp.tril(jnp.ones((chunk_size, chunk_size)))
            m_in = m_in.at[start_idx:end_idx, start_idx:end_idx].set(chunk_mask)

        # Create causal mask for between-chunk attention
        m_out = jnp.tril(jnp.ones((num_chunks, num_chunks)))

        return m_in, m_out

    def compute_chunk_states(self, keys: jnp.ndarray, values: jnp.ndarray) -> jnp.ndarray:
        batch_size, num_heads, num_chunks, chunk_size, head_dim = keys.shape

        # Optimize: Avoid reshape by using einsum for batch matrix multiplication
        # [batch, heads, chunks, chunk_size, dim] -> [batch, heads, chunks, dim, dim]
        state = jnp.einsum('bhcsd,bhcse->bhcde', keys, values)

        # Flatten the state matrices to vectors
        state = state.reshape(batch_size, num_heads, num_chunks, -1)

        # Apply compression if specified
        if self.compression_factor is not None:
            state = self.compress_proj(state)

        return state

    def __call__(
            self,
            hidden_states: jnp.ndarray,
            attention_mask: Optional[jnp.ndarray] = None
    ) -> jnp.ndarray:
        batch_size, seq_len, _ = hidden_states.shape
        num_chunks = math.ceil(seq_len / self.chunk_size)

        # Calculate padding (if needed)
        padding = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
        padded_seq_len = seq_len + padding

        # Standard projections for Q, K, V
        q = self._reshape_for_heads(self.q_proj(hidden_states))
        k = self._reshape_for_heads(self.k_proj(hidden_states))
        v = self._reshape_for_heads(self.v_proj(hidden_states))

        # Create masks for within-chunk and between-chunk attention
        m_in, m_out = self._create_masks(seq_len)

        # Apply attention mask if provided
        if attention_mask is not None:
            # Reshape attention_mask: [batch, seq_len] -> [batch, 1, 1, seq_len]
            attention_mask_expanded = attention_mask[:, None, None, :]

            # Reshape m_in: [seq_len, seq_len] -> [1, 1, seq_len, seq_len]
            m_in_expanded = m_in[None, None, :, :]

            # Broadcast and combine masks
            m_in = jnp.broadcast_to(m_in_expanded, (batch_size, self.num_heads, seq_len, seq_len))
            m_in = m_in * attention_mask_expanded

        # Pad sequence length if needed
        if padding > 0:
            q = jnp.pad(q, ((0, 0), (0, 0), (0, padding), (0, 0)))
            k = jnp.pad(k, ((0, 0), (0, 0), (0, padding), (0, 0)))
            v = jnp.pad(v, ((0, 0), (0, 0), (0, padding), (0, 0)))
            padded_hidden = jnp.pad(hidden_states, ((0, 0), (0, padding), (0, 0)))
        else:
            padded_hidden = hidden_states

        # Reshape to chunks [batch, num_heads, num_chunks, chunk_size, head_dim]
        q_chunks = q.reshape(batch_size, self.num_heads, num_chunks, self.chunk_size, self.head_dim)
        k_chunks = k.reshape(batch_size, self.num_heads, num_chunks, self.chunk_size, self.head_dim)
        v_chunks = v.reshape(batch_size, self.num_heads, num_chunks, self.chunk_size, self.head_dim)

        # Reshape hidden states for chunk representation
        chunk_hidden = padded_hidden.reshape(batch_size, num_chunks, self.chunk_size, self.hidden_dim)

        # Get representative token for each chunk (optimized)
        chunk_repr_indices = jnp.array([
                min(self.chunk_size - 1, self.chunk_size - padding - 1 if i == num_chunks - 1 and padding > 0 else self.chunk_size - 1)
                for i in range(num_chunks)
        ])
        chunk_repr = chunk_hidden[:, jnp.arange(num_chunks), chunk_repr_indices]

        # Project to meta queries and keys
        r = self.meta_r_proj(chunk_repr)
        h = self.meta_h_proj(chunk_repr)

        # Reshape to [batch, num_heads, num_chunks, head_dim]
        r = r.reshape(batch_size, num_chunks, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
        h = h.reshape(batch_size, num_chunks, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)

        # Compute chunk states
        f = self.compute_chunk_states(k_chunks, v_chunks)  # [batch, num_heads, num_chunks, state_dim]

        # Between-chunk attention: softmax(R·H^T)·F
        # Calculate R·H^T with scaled dot product
        chunk_attn_scores = jnp.matmul(r, h.transpose(0, 1, 3, 2)) / jnp.sqrt(self.head_dim)

        # Apply causal mask for between-chunk attention
        m_out_expanded = jnp.expand_dims(jnp.expand_dims(m_out, 0), 0)
        m_out_expanded = jnp.broadcast_to(m_out_expanded, (batch_size, self.num_heads, num_chunks, num_chunks))
        chunk_attn_scores = jnp.where(m_out_expanded == 0, -1e9, chunk_attn_scores)

        # Apply softmax and dropout
        chunk_attn_probs = nnx.softmax(chunk_attn_scores, axis=-1)
        chunk_attn_probs = self.dropout_layer(chunk_attn_probs)

        # Calculate retrieved states
        retrieved_states = jnp.matmul(chunk_attn_probs, f)

        # Decompress if needed
        if self.compression_factor is not None:
            retrieved_states = self.decompress_proj(retrieved_states)

        # Reshape retrieved states
        retrieved_states = retrieved_states.reshape(
                batch_size, self.num_heads, num_chunks, self.head_dim, self.head_dim
        )

        # Multiply query with retrieved states using einsum
        long_term_output = jnp.einsum('bhcsd,bhcde->bhcse', q_chunks, retrieved_states)

        # Within-chunk attention
        # Create padded mask
        padded_m_in = jnp.zeros((padded_seq_len, padded_seq_len))
        for i in range(num_chunks):
            start_idx = i * self.chunk_size
            end_idx = (i + 1) * self.chunk_size
            chunk_mask = jnp.tril(jnp.ones((self.chunk_size, self.chunk_size)))
            padded_m_in = padded_m_in.at[start_idx:end_idx, start_idx:end_idx].set(chunk_mask)

        # Calculate local attention scores
        local_attn_scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(self.head_dim)

        # Apply mask
        expanded_mask = jnp.expand_dims(jnp.expand_dims(padded_m_in, 0), 0)
        expanded_mask = jnp.broadcast_to(expanded_mask, local_attn_scores.shape)
        local_attn_scores = jnp.where(expanded_mask == 0, -1e9, local_attn_scores)

        # Apply softmax and dropout
        local_attn_probs = nnx.softmax(local_attn_scores, axis=-1)
        local_attn_probs = self.dropout_layer(local_attn_probs)

        # Apply attention
        local_output = jnp.matmul(local_attn_probs, v)

        # Reshape long-term output to match local output
        long_term_output = long_term_output.reshape(batch_size, self.num_heads, padded_seq_len, self.head_dim)

        # Combine outputs
        output = long_term_output + local_output

        # Remove padding if added
        if padding > 0:
            output = output[:, :, :seq_len, :]

        # Reshape back to [batch, seq_len, hidden_dim]
        output = output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, self.hidden_dim)

        # Final projection
        return self.out_proj(output)

In [72]:
class PositionalEncoding(nnx.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        # Create position indices [0, 1, 2, ..., max_len-1]
        position = jnp.arange(0, max_len, dtype=jnp.float32)[:, None]

        # Create division term for frequencies
        div_term = jnp.exp(jnp.arange(0, d_model, 2, dtype=jnp.float32) * (-math.log(10000.0) / d_model))

        # Initialize the positional encoding matrix
        pe = jnp.zeros((max_len, d_model))

        # Compute sine for even indices and cosine for odd indices in one step
        sin_input = position * div_term
        pe = pe.at[:, 0::2].set(jnp.sin(sin_input))
        pe = pe.at[:, 1::2].set(jnp.cos(sin_input))

        # Add batch dimension
        self.pe = pe[None, :, :]  # Shape: (1, max_len, d_model)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # x shape: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.shape[1], :]

In [87]:
class BSBRLayer(nnx.Module):

    def __init__(
            self,
            rngs: nnx.Rngs,
            hidden_dim: jnp.int8,
            num_heads: jnp.int8,
            chunk_size: jnp.int8,
            ff_dim: jnp.int8,
            dropout: jnp.float32 = 0.1,
            compression_factor: Optional[jnp.int8] = None
    ):
        self.attention = BSBRAttention(
                rngs=rngs,
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                chunk_size=chunk_size,
                dropout=dropout,
                compression_factor=compression_factor
        )
        # Layer normalization
        self.dyt1 = DyT(num_features=hidden_dim)
        self.dyt2 = DyT(num_features=hidden_dim)

        # Feedforward network
        self.ff = nnx.Sequential(
                nnx.Linear(in_features=hidden_dim, out_features=ff_dim, rngs=rngs),
                nnx.gelu,
                nnx.Dropout(dropout, rngs=rngs),
                nnx.Linear(in_features=ff_dim, out_features=hidden_dim, rngs=rngs),
                nnx.Dropout(dropout, rngs=rngs),
        )

    def __call__(self, hidden_states: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        # Layer normalization
        residual = hidden_states
        hidden_states = self.dyt1(hidden_states)
        hidden_states = self.attention(hidden_states, attention_mask)
        hidden_states = residual + hidden_states

        # Feed-forward network
        residual = hidden_states
        hidden_states = self.dyt2(hidden_states)
        hidden_states = self.ff(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


In [90]:
class BSBRModel(nnx.Module):
    def __init__(self, rngs:nnx.Rngs ,vocab_size:jnp.int8, hidden_dim: jnp.int8,
                 num_layers: jnp.int8,
                 num_heads: jnp.int8,
                 chunk_size: jnp.int8,
                 ff_dim: jnp.int8,
                 dropout: jnp.float32 = 0.1,
                 compression_factor: Optional[jnp.int8] = None):
        self.embedding = nnx.Embed(num_embeddings=vocab_size, features=hidden_dim, rngs=rngs)
        self.positional_encoding = PositionalEncoding(d_model=hidden_dim)

        self.layers = [BSBRLayer(rngs=rngs,
                                 hidden_dim=hidden_dim,
                                 num_heads=num_heads,
                                 chunk_size=chunk_size,
                                 dropout=dropout,
                                 ff_dim=ff_dim,
                                 compression_factor=compression_factor) for _ in range(num_layers)]
        self.dyt = DyT(num_features=hidden_dim)

    def __call__(self, input_ids: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None) -> jnp.ndarray:
        # Embedding and positional encoding
        hidden_states = self.embedding(input_ids)
        hidden_states = self.positional_encoding(hidden_states)

        # Pass through layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # Final layer normalization
        hidden_states = self.dyt(hidden_states)

        return hidden_states

In [91]:
# Model configuration
vocab_size = 10000
hidden_dim = 512
num_layers = 4
num_heads = 8
chunk_size = 128
ff_dim = 2048

In [92]:
model = BSBRModel(
    rngs=nnx.Rngs(0),
    vocab_size=vocab_size,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    chunk_size=chunk_size,
    ff_dim=ff_dim,
    dropout=0.1,
    compression_factor=4
)

In [93]:
# Create a sample input
batch_size = 2
seq_len = 256
input_ids = jax.random.randint(jax.random.PRNGKey(0), (batch_size, seq_len), 0, vocab_size)
attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32)

In [94]:
output = model(input_ids, attention_mask)

In [95]:
print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {output.shape}")

Input shape: (2, 256)
Output shape: (2, 256, 512)


In [97]:
nnx.display(model)

In [108]:
# Test a longer sequence to demonstrate chunk handling
long_seq_len = 512
long_input_ids = jax.random.randint(jax.random.PRNGKey(0), (batch_size, long_seq_len), 0, vocab_size)
long_attention_mask = jnp.ones((batch_size, long_seq_len), dtype=jnp.float32)

long_output = model(long_input_ids, long_attention_mask)

In [109]:
print(f"Long input shape: {long_input_ids.shape}")
print(f"Long output shape: {long_output.shape}")

Long input shape: (2, 512)
Long output shape: (2, 512, 512)


In [110]:
long_output

Array([[[-1., -1.,  1., ...,  1.,  1.,  1.],
        [-1., -1.,  1., ...,  1., -1.,  1.],
        [ 1., -1., -1., ...,  1., -1.,  1.],
        ...,
        [-1., -1.,  1., ...,  1.,  1.,  1.],
        [-1., -1.,  1., ...,  1.,  1.,  1.],
        [-1., -1.,  1., ...,  1.,  1.,  1.]],

       [[-1.,  1.,  1., ...,  1., -1., -1.],
        [ 1.,  1., -1., ...,  1., -1., -1.],
        [ 1.,  1., -1., ...,  1., -1., -1.],
        ...,
        [ 1.,  1.,  1., ...,  1.,  1.,  1.],
        [ 1.,  1.,  1., ...,  1.,  1.,  1.],
        [ 1.,  1.,  1., ...,  1., -1., -1.]]], dtype=float32)

In [107]:
long_output

Array([[[-0.99601203, -1.        ,  1.        , ...,  1.        ,
         -1.        , -1.        ],
        [ 1.        , -1.        , -1.        , ...,  1.        ,
          0.9999788 , -1.        ],
        [-1.        , -1.        , -1.        , ...,  1.        ,
         -1.        , -1.        ],
        ...,
        [-1.        , -1.        ,  1.        , ...,  1.        ,
          1.        ,  1.        ],
        [ 1.        , -1.        ,  1.        , ...,  1.        ,
          1.        ,  1.        ],
        [-1.        , -1.        ,  1.        , ...,  1.        ,
          1.        ,  1.        ]],

       [[-1.        , -1.        , -1.        , ...,  1.        ,
          1.        ,  1.        ],
        [-1.        , -1.        , -1.        , ...,  1.        ,
          1.        , -1.        ],
        [-1.        ,  1.        , -1.        , ...,  1.        ,
          1.        , -1.        ],
        ...,
        [-1.        , -1.        ,  1.        , ...,  