# In [None]:
import torch

class HalkoLayer(torch.nn.Module):
    """Apply an independent linear projection to each depth slice of a tensor.

    The layer holds one ``torch.nn.Linear(features, features)`` per depth
    index; slice ``d`` along dimension 1 of the input is transformed by
    ``micro_centers[d]``.

    Args:
        depth: Number of depth slices to transform (one linear layer each).
        features: Size of the last input dimension. Defaults to 128.
    """

    def __init__(self, depth: int, features: int = 128):
        super().__init__()
        self.depth = depth
        self.features = features
        self.micro_centers = torch.nn.ModuleList(
            [torch.nn.Linear(features, features) for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform each depth slice with its own linear layer.

        Args:
            x: Tensor of shape ``[batch, depth, seq, features]``.

        Returns:
            A new tensor with the same shape as ``x``. Slices at depth indices
            ``>= self.depth`` (if any) are copied through unchanged.

        Raises:
            IndexError: If ``x`` has fewer than ``self.depth`` slices along
                dimension 1.
        """
        # Write into a copy rather than into ``x``: assigning into the input
        # would overwrite the caller's data, fail for leaf tensors that
        # require grad, and invalidate the views each Linear saved for its
        # backward pass.
        out = x.clone()
        for d, layer in enumerate(self.micro_centers):
            out[:, d] = layer(x[:, d])
        return out

# Example usage: run a random batch through a 3-slice layer and show the shape.
ha = HalkoLayer(depth=3)
# Named `sample` rather than `input` so the builtin input() is not shadowed.
sample = torch.randn(2, 3, 10, 128)  # [batch=2, depth=3, seq=10, features=128]
output = ha(sample)
print(output.shape)  # torch.Size([2, 3, 10, 128])