In [1]:
# [file name]: mamba1d.py
# [file content begin]
from flax import linen as nn
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve
import numpy as np

class SSMLayer(nn.Module):
    d_model: int
    d_state: int = 16
    dt_rank: int = 2
    kernel_init: nn.initializers.Initializer = nn.initializers.normal(0.02)
    
    @nn.compact
    def __call__(self, x):
        batch, seq_len, _ = x.shape
        
        # 参数定义（调整卷积核维度）
        A = self.param("A", self.kernel_init, (self.d_state,))
        D = self.param("D", nn.initializers.ones, (self.d_model,))
        dt = self.param("dt", self.kernel_init, (seq_len, self.dt_rank))
        B = self.param("B", self.kernel_init, (self.dt_rank, self.d_state))
        C = self.param("C", self.kernel_init, (self.d_model, self.d_state))  # 调整为 [d_model, d_state]
        
        # 离散化过程
        delta = jnp.exp(dt @ B)  # [seq_len, d_state]
        A_bar = jnp.exp(-delta * jnp.exp(A))  # [seq_len, d_state]
        B_bar = (delta / jnp.exp(A)) * (1 - A_bar)  # [seq_len, d_state]
        
        # 调整卷积核维度为 [d_model, d_state, seq_len]
        B_bar_exp = B_bar.T[None, :, :]  # [1, d_state, seq_len]
        C_exp = C[:, :, None]            # [d_model, d_state, 1]
        kernel = B_bar_exp * C_exp       # [d_model, d_state, seq_len]
        kernel = jnp.moveaxis(kernel, -1, 0)  # [seq_len, d_state, d_model]
        
        # 执行逐通道一维卷积（调整输入和卷积核维度）
        u = x * D  # [batch, seq_len, d_model]
        # 将输入和卷积核转为 [d_model, batch, seq_len] 和 [d_model, d_state, seq_len]
        u_perm = jnp.moveaxis(u, -1, 0)  # [d_model, batch, seq_len]
        kernel_perm = jnp.moveaxis(kernel, -1, 0)  # [d_model, d_state, seq_len]
        
        # 逐通道卷积（每个通道独立处理）
        y = jax.vmap(lambda u_ch, k_ch: convolve(u_ch, k_ch, mode='same'))(u_perm, kernel_perm)
        y = jnp.moveaxis(y, 0, -1)  # 恢复维度 [batch, seq_len, d_model]
        
        # 残差连接
        y = y + u * A_bar[None, :, None]
        return y

class MambaBlock(nn.Module):
    d_model: int
    b: int = 4
    expand: int = 2
    
    @nn.compact
    def __call__(self, x):
        x = x[..., None]  # [batch, L, 1]
        x = x.reshape(x.shape[0], -1, self.b)  # [batch, L//b, b]
        x = jnp.tile(x, (1, 1, self.d_model))  # [batch, L//b, b*d_model]
        x = x.reshape(x.shape[0], -1, self.d_model)  # [batch, seq_len, d_model]
        
        x = nn.Dense(self.d_model * self.expand)(x)
        x = SSMLayer(d_model=self.d_model * self.expand)(x)
        x = nn.silu(x)
        x = nn.Dense(self.d_model)(x)
        return x.reshape(x.shape[0], -1)  # [batch, seq_len]

class Mamba_Enc(nn.Module):
    d_model: int
    b: int = 4
    n_layers: int = 2
    
    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)  # [batch, L]
        for _ in range(self.n_layers):
            x = MambaBlock(d_model=self.d_model, b=self.b)(x)
        return jnp.sum(x, axis=1)  # [batch, d_model]
# [file content end]

In [2]:
# [file name]: main_gs.py
# [file content begin]
import jax
import jax.numpy as jnp
import netket as nk
import time
from netket.experimental.driver.vmc_srt import VMC_SRt

seed = 0

# 1D Lattice
L = 100
J2 = 0.8

# Settings optimization
diag_shift = 1e-3
eta = 0.01
N_opt = 10000
N_samples = 3000
N_discard = 0

# Settings wave function
f = 1
d_model = 8  # 直接定义d_model，无需与heads关联
b = 4        # 确保 L % b == 0

#! End input

lattice = nk.graph.Chain(length=L, pbc=True, max_neighbor_order=2)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)
hamiltonian = nk.operator.Heisenberg(hilbert=hilbert, graph=lattice, J=[1.0, J2])

if L <= 16:
    evals = nk.exact.lanczos_ed(hamiltonian, compute_eigenvectors=False)
    print(f'Exact E0: {evals[0]/(4*L):.5f}')

# 初始化Mamba波函数
wf = Mamba_Enc(d_model=d_model, b=b, n_layers=2)  # 注意移除h参数

key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)
params = wf.init(subkey, jnp.zeros((1, lattice.n_nodes)))

# 采样器与优化器配置
sampler = nk.sampler.MetropolisExchange(
    hilbert=hilbert, graph=lattice, d_max=2, n_chains=N_samples, n_sweeps=lattice.n_nodes
)
vstate = nk.vqs.MCState(
    sampler=sampler, 
    model=wf, 
    sampler_seed=key, 
    n_samples=N_samples, 
    n_discard_per_chain=N_discard,
    variables=params
)
print(f'Parameters: {nk.jax.tree_size(vstate.parameters)}')

optimizer = nk.optimizer.Sgd(learning_rate=eta)
vmc = VMC_SRt(hamiltonian=hamiltonian, optimizer=optimizer, variational_state=vstate, diag_shift=diag_shift)

# 运行优化
start = time.time()
vmc.run(out='Mamba_GS', n_iter=N_opt)
print(f'Time: {time.time()-start:.1f}s')
# [file content end]

ValueError: One input must be smaller than the other in every dimension.