# Understanding H-Nets from the ground up

H-Net's core idea is a novel dynamic chunking (DC) mechanism that sits between the main network and the encoder/decoder networks. It learns how to segment the data end to end, using standard differentiable optimization.

The input is a sequence of bytes. The encoder is a neural network (e.g., a Mamba SSM) that encodes the input sequence into a sequence of latent vectors of the same length. The decoder is another neural network (e.g., a Mamba SSM) that maps latent vectors back to the resolution of the original byte sequence.

The dynamic chunking mechanism has two main modules:
- routing module + downsampler (encoding)
- smoothing module + upsampler (decoding)

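Notation for the first stage ($s = 0$): $x^0$ is the input sequence of $L^0$ embedded bytes, each of dimension $D^0$, and $p^0$ holds one boundary probability per position.

$$ x^0 \in \mathbb{R}^{L^0 \times D^0} $$
$$ p^0 \in [0, 1]^{L^0} $$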
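As a concrete illustration of the notation, the sketch below turns a short string into $x^0$. It assumes a learned byte embedding table (`nn.Embedding(256, D0)`), which is a common choice for byte-level models but is not specified in this notebook; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

text = "hello world"
byte_ids = torch.tensor(list(text.encode("utf-8")))  # one id in [0, 255] per byte
L0 = byte_ids.shape[0]                               # L^0 = 11 bytes
D0 = 8                                               # D^0, chosen arbitrarily here

# Assumed byte embedding: one learned D^0-dimensional vector per byte value
embed = nn.Embedding(256, D0)
x0 = embed(byte_ids).unsqueeze(0)                    # (batch=1, L^0, D^0)
print(x0.shape)                                      # torch.Size([1, 11, 8])
```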

Encoder:
$$ \hat{x}^0 = \mathcal{E}^0 (x^0) $$
Chunker:
$$ (x^{1}, p^0) = \operatorname{Chunk}(\hat{x}^0) $$
Main Neural Network (operating on the compressed sequence $x^1$):
$$ \hat{z}^1 = \mathcal{M}(x^1) $$
Dechunker:
$$ z^0 = \operatorname{Dechunk}(\hat{z}^{1}, p^0) + \operatorname{Linear}(\hat{x}^0) $$
Decoder:
$$ \hat{z}^0 = \mathcal{D}^0 (z^0) $$
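The five equations can be traced end to end with placeholder components. In the sketch below, linear layers stand in for $\mathcal{E}^0$, $\mathcal{M}$ and $\mathcal{D}^0$, and a hypothetical fixed-stride chunker replaces the learned one, so $p^0$ and the smoothing step are omitted. The point is only to show how the length shrinks from $L^0$ to $L^1$ and is then restored.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L0, D = 12, 8
x0 = torch.randn(1, L0, D)

# Hypothetical stand-ins for the learned networks: plain linear layers
encoder, main, decoder, skip = (nn.Linear(D, D) for _ in range(4))

x_hat0 = encoder(x0)                          # x_hat^0: (1, 12, 8)

# Hypothetical fixed chunker: a boundary every 4 positions.
# H-Net learns these boundaries instead (see the routing module below).
b0 = torch.zeros(1, L0)
b0[:, ::4] = 1
x1 = x_hat0[:, b0[0].bool()]                  # x^1: (1, 3, 8), so L^1 = 3

z_hat1 = main(x1)                             # z_hat^1: (1, 3, 8)

# Naive dechunk: copy each chunk vector to every position of its chunk
chunk_idx = torch.cumsum(b0[0], dim=0).long() - 1   # [0,0,0,0,1,1,1,1,2,2,2,2]
z0 = z_hat1[:, chunk_idx] + skip(x_hat0)      # z^0: (1, 12, 8)
z_hat0 = decoder(z0)                          # z_hat^0: (1, 12, 8)
print(x1.shape, z_hat0.shape)
```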


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Chunking Layer

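The Chunking Layer contains a routing module, which predicts a boundary probability $p_t$ and a binary boundary indicator $b_t$ for every position, and a downsampler, which keeps only the vectors at positions where $b_t = 1$. The routing module compares adjacent encoder outputs: when consecutive vectors are dissimilar, a new chunk is likely to start.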

$$ q_t = W_q \hat{x}_t, \quad k_t = W_k \hat{x}_t, \quad p_t = \frac{1}{2} \left( 1 - \frac{q_t^T k_{t-1}}{\| q_t \| \| k_{t-1} \|} \right) \in [0,1], \quad b_t = 1_{\{ p_t \geq 0.5 \}}, $$

with $p_1 = 1$ by definition, so the first position always starts a chunk.
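To build intuition for the formula: $p_t$ depends only on the angle between $q_t$ and $k_{t-1}$. The sketch below evaluates it for three hand-picked 2-D vectors (illustrative values, not learned projections), using PyTorch's `F.cosine_similarity` in place of the explicit dot-product-over-norms expression.

```python
import torch
import torch.nn.functional as F

def boundary_prob(q_t, k_prev):
    # p_t = (1 - cos(q_t, k_{t-1})) / 2
    return 0.5 * (1 - F.cosine_similarity(q_t, k_prev, dim=0))

k_prev = torch.tensor([1.0, 0.0])
p_same = boundary_prob(torch.tensor([2.0, 0.0]), k_prev)   # same direction: 0.0 (no boundary)
p_orth = boundary_prob(torch.tensor([0.0, 3.0]), k_prev)   # orthogonal: 0.5 (boundary, since p >= 0.5)
p_opp = boundary_prob(torch.tensor([-1.0, 0.0]), k_prev)   # opposite: 1.0 (boundary)
print(p_same.item(), p_orth.item(), p_opp.item())
```

Only the direction of the vectors matters: scaling $q_t$ by 2 or 3 above does not change $p_t$.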

In [None]:
class SimilarityRouter(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W_q = nn.Linear(dim, dim, bias=False)
        self.W_k = nn.Linear(dim, dim, bias=False)

    def forward(self, x_hat):
        # x_hat: (batch size, sequence length, embedding dimension) (B, L, D)
        q = self.W_q(x_hat)  # (B, L, D)
        k = self.W_k(x_hat)  # (B, L, D)

        # Shift k to compare with previous
        # k_{t-1} for t=1 is k_1 (paper sets p_1=1)
        k_shifted = torch.cat([k[:, :1, :],
                               k[:, :-1, :]],
                               dim=1)

        # Cosine similarity: q_t^T k_{t-1} / (|q_t| |k_{t-1}|)
        dot = torch.sum(q * k_shifted, dim=-1)  # (B, L)
        norm_q = torch.norm(q, dim=-1) + 1e-8
        norm_k = torch.norm(k_shifted, dim=-1) + 1e-8
        cos_sim = dot / (norm_q * norm_k)

        # p_t = 1/2 (1 - cos_sim)
        p = 0.5 * (1 - cos_sim)  # (B, L)
        p[:, 0] = 1.0  # Force p_1 = 1.0

        # Boundary indicators b_t = 1 if p_t >= 0.5
        b = (p >= 0.5).float()  # (B, L)

        return p, b

In [None]:
def downsample(x_hat, b, p):
    # x_hat: (B, L, D), b: (B, L), p: (B, L)
    # Find positions where b_t == 1
    mask = b.bool()  # (B, L)

    # Compressed x_next: gather elements where mask is True
    # For each batch, collect non-zero indices
    x_next = []
    P_next = []
    for i in range(x_hat.shape[0]):
        idx = mask[i].nonzero(as_tuple=False).squeeze(-1)  # Indices where b=1
        x_next.append(x_hat[i, idx, :])
        P_next.append(p[i, idx])

    # Pad to the longest chunk count so the batch can be stacked (padded P entries are 0)
    max_len = max(len(x) for x in x_next)
    x_next_padded = torch.stack([F.pad(x, (0, 0, 0, max_len - len(x))) for x in x_next])
    P_next_padded = torch.stack([F.pad(P, (0, max_len - len(P))) for P in P_next])

    return x_next_padded, P_next_padded
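For a single sequence, `downsample` reduces to boolean-mask indexing. The toy example below uses hand-picked probabilities to show which rows survive; the batched function above additionally pads each sequence to the longest chunk count.

```python
import torch

# Toy input: batch 1, L = 6 positions, D = 2; row t holds [2t, 2t + 1]
x_hat = torch.arange(12.0).reshape(1, 6, 2)
p = torch.tensor([[1.0, 0.1, 0.2, 0.9, 0.3, 0.7]])  # hand-picked probabilities
b = (p >= 0.5).float()                              # boundaries at t = 0, 3, 5

mask = b[0].bool()
x_next = x_hat[:, mask]   # vectors at boundary positions: rows 0, 3, 5 -> (1, 3, 2)
P_next = p[:, mask]       # their probabilities: [1.0, 0.9, 0.7]
print(x_next)
print(P_next)
```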

## Dechunker 

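The Dechunker consists of two parts, a smoothing module and an upsampler. Its output is added to a linear projection of the encoder output $\hat{x}^s$ (a residual connection):

$$ z^s = \operatorname{Dechunk}(\hat{z}^{s+1}, p^s) + \operatorname{Linear}(\hat{x}^s). $$

### Dechunker 1/2: Smoothing module.

The critical challenge in training a dynamic chunking module lies in the *discrete nature* of chunk boundaries, which impedes gradient flow during backpropagation. The smoothing module turns chunking into a differentiable operation by applying an exponential moving average (EMA) to the main network's outputs, where $P_t$ is the boundary probability of the $t$-th selected position:

$$ \bar{z}_t = P_t \hat{z}_t + (1 - P_t) \bar{z}_{t-1} $$

Confident boundaries ($P_t \approx 1$) keep their own representation, while uncertain ones are blended with the previous chunk.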

In [None]:
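def smoothing_module(z_hat, P):
    # z_hat: (B, L_compressed, D), P: (B, L_compressed) boundary probs of the kept positions
    # EMA smoothing: bar_z_t = P_t * z_hat_t + (1 - P_t) * bar_z_{t-1}
    P = P.unsqueeze(-1)  # (B, L_compressed, 1), broadcasts over D
    # Collect steps in a list instead of writing into a preallocated tensor:
    # in-place writes would break autograd on the backward pass.
    outputs = [z_hat[:, 0]]
    for t in range(1, z_hat.shape[1]):
        outputs.append(P[:, t] * z_hat[:, t] + (1 - P[:, t]) * outputs[-1])
    return torch.stack(outputs, dim=1)  # (B, L_compressed, D)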
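A worked example of the recurrence on three chunks with one scalar feature each (hand-picked numbers): the confident first and third chunks mostly keep their own values, while the uncertain second chunk ($P = 0.5$) is pulled halfway toward the previous chunk.

```python
import torch

z_hat = torch.tensor([4.0, 0.0, 3.0])  # one scalar feature per chunk
P = torch.tensor([1.0, 0.5, 0.9])      # boundary probability of each chunk

bar_z = [z_hat[0]]
for t in range(1, len(z_hat)):
    # bar_z_t = P_t * z_hat_t + (1 - P_t) * bar_z_{t-1}
    bar_z.append(P[t] * z_hat[t] + (1 - P[t]) * bar_z[-1])
bar_z = torch.stack(bar_z)
# t=1: 0.5 * 0 + 0.5 * 4 = 2.0
# t=2: 0.9 * 3 + 0.1 * 2 = 2.9
print(bar_z)
```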

### Dechunker 2/2: Upsampler. 

The upsampler decompresses the smoothed sequence $\bar{z}$ (derived from $\hat{z}^{s+1}$) back to the resolution $L^s$ of the previous stage, so that it can be combined with the residual in $z^s$. It is defined as follows:

$$
c_t = p_t^{b_t} (1 - p_t)^{1 - b_t} = 
\begin{cases} 
p_t & \text{if } b_t = 1, \\ 
1 - p_t & \text{otherwise},
\end{cases}
$$

$$
\text{STE}(c_t) = c_t + \text{stopgradient}(1 - c_t),
$$

$$
\tilde{z}_t = \bar{z}_{\sum_{k=1}^t b_k},
$$

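$$
\text{Upsampler}(\tilde{z}, c)_t = \text{STE} (c_t ) \cdot \tilde{z}_t.
$$

Because $\text{STE}(c_t)$ equals 1 in the forward pass, the upsampler simply copies each chunk's vector to every position of that chunk; in the backward pass, gradients still flow into the routing module through $c_t$.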
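The index $\sum_{k=1}^t b_k$ is a running count of boundaries, so it can be computed with `torch.cumsum` and applied with `torch.gather`. The toy sketch below (hand-picked values; 0-based indices in code versus 1-based in the formula) shows each chunk's vector being repeated over the positions of that chunk.

```python
import torch

b = torch.tensor([[1.0, 0.0, 0.0, 1.0, 0.0, 1.0]])  # boundary indicators (b_1 = 1)
bar_z = torch.tensor([[[10.0], [20.0], [30.0]]])    # one smoothed vector per chunk: (1, 3, 1)

# Chunk index of position t is sum_{k<=t} b_k; subtract 1 for 0-based indexing
chunk_idx = torch.cumsum(b, dim=1).long() - 1        # [[0, 0, 0, 1, 1, 2]]
idx = chunk_idx.unsqueeze(-1).expand(-1, -1, bar_z.shape[-1])  # (1, 6, 1)
z_tilde = torch.gather(bar_z, 1, idx)                # (1, 6, 1)
print(z_tilde.squeeze(-1))  # each chunk value repeated over its positions
```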

In [None]:
def upsample(bar_z, b, p, original_L, D):
    # bar_z: (B, L_compressed, D), b: (B, original_L), p: (B, original_L)
    # Chunk index of position t is sum_{k<=t} b_k (1-based); subtract 1 for 0-based.
    # Since b_1 = 1, every index is >= 0, and the largest index is
    # (number of boundaries - 1), so it never runs past the end of bar_z.
    chunk_idx = torch.cumsum(b, dim=1).long() - 1            # (B, original_L)
    idx = chunk_idx.unsqueeze(-1).expand(-1, original_L, D)  # (B, original_L, D)
    z_tilde = torch.gather(bar_z, 1, idx)                    # z_tilde_t = bar_z[chunk_idx_t]

    # Confidence c_t = p_t^{b_t} (1 - p_t)^{1 - b_t}
    c = torch.where(b.bool(), p, 1 - p)                      # (B, original_L)
    # Straight-through estimator: value 1 forward, gradient of c_t backward
    ste_c = c + (1 - c).detach()

    return ste_c.unsqueeze(-1) * z_tilde                     # (B, original_L, D)
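Why bother with the STE if its forward value is 1? The sketch below is a standalone re-implementation of just the $c_t$ and $\text{STE}$ formulas on two hand-picked probabilities. It shows that the output is unchanged in the forward pass, while the router probabilities still receive a gradient (+1 where $b_t = 1$, -1 where $b_t = 0$) even though the threshold that produces $b_t$ is not differentiable.

```python
import torch

p = torch.tensor([0.9, 0.2], requires_grad=True)  # router probabilities for two positions
b = (p >= 0.5).float()                            # [1., 0.]: non-differentiable threshold
c = torch.where(b.bool(), p, 1 - p)               # c_t = p_t if b_t = 1 else 1 - p_t
ste = c + (1 - c).detach()                        # forward value is 1 at every position

ste.sum().backward()
print(ste.detach())  # approximately [1., 1.]
print(p.grad)        # [1., -1.]: the gradient reaches p through c
```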

In [None]:
class HNet(nn.Module):
    def __init__(self, dim, num_stages=1):
        super().__init__()
        self.num_stages = num_stages
        # One encoder / router / decoder / skip projection per stage.
        # nn.Linear layers are placeholders for E^s, D^s (e.g., Mamba layers) and M.
        self.encoders = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_stages)])
        self.routers = nn.ModuleList([SimilarityRouter(dim) for _ in range(num_stages)])
        self.decoders = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_stages)])
        self.skips = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_stages)])
        self.main = nn.Linear(dim, dim)  # Placeholder for the main network M

    def forward(self, x0):
        # x0: (B, L0, D)
        xs = [x0]   # input x^s of every stage
        cache = []  # per stage: (x_hat^s, p^s, b^s, P^{s+1})

        # Encode and chunk (compress): x^s -> x_hat^s -> x^{s+1}
        for s in range(self.num_stages):
            x_hat = self.encoders[s](xs[-1])
            p, b = self.routers[s](x_hat)
            x_next, P_next = downsample(x_hat, b, p)
            xs.append(x_next)
            cache.append((x_hat, p, b, P_next))

        # Main network on the most compressed sequence
        z_hat = self.main(xs[-1])

        # Dechunk and decode (decompress), innermost stage first
        for s in reversed(range(self.num_stages)):
            x_hat, p, b, P_next = cache[s]
            bar_z = smoothing_module(z_hat, P_next)
            z = upsample(bar_z, b, p, xs[s].shape[1], xs[s].shape[2])
            z = z + self.skips[s](x_hat)  # residual: Linear(x_hat^s), reusing the cached encoder output
            z_hat = self.decoders[s](z)

        return z_hat  # (B, L0, D)

In [6]:
# Example
dim = 512
model = HNet(dim, num_stages=1)
x = torch.randn(1, 1024, dim)  # Batch 1, seq len 1024, dim 512
output = model(x)
print(output.shape)  # Should be (1, 1024, 512)
