In [1]:
import os
import jax
import numpy as np
import matplotlib.pyplot as plt
import time
import json
import logging
import sys
from functools import partial
from tqdm.notebook import tqdm

# 设置环境变量
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda"
os.environ["JAX_PLATFORM_NAME"] = "gpu"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['NETKET_EXPERIMENTAL_SHARDING'] = '1'

import netket as nk
import netket.nn as nknn
import flax
import flax.linen as nn
import jax.numpy as jnp
import math
from math import pi
from netket.nn.blocks import SymmExpSum
from netket.operator.spin import sigmax, sigmay, sigmaz
from netket.optimizer.qgt import QGTJacobianPyTree, QGTJacobianDense, QGTOnTheFly
from netket.operator import AbstractOperator, LocalOperator as _LocalOperator
from netket.utils.types import DType as _DType
from netket.hilbert import DiscreteHilbert as _DiscreteHilbert
from netket.nn.activation import reim_selu
from netket.jax import logsumexp_cplx
from netket.nn.symmetric_linear import DenseSymmMatrix, DenseEquivariantIrrep
from netket.utils import HashableArray
from einops import rearrange
from netket.utils.group.planar import rotation, reflection_group, D, glide, glide_group, C
from netket.utils.group import PointGroup, Identity, PermutationGroup

# 显示分片模式和可用设备
print("启用分片模式：", nk.config.netket_experimental_sharding)
print("可用设备：", jax.devices())


启用分片模式： True
可用设备： [CudaDevice(id=0), CudaDevice(id=1)]


In [2]:
from jax.nn.initializers import zeros, lecun_normal, normal
from typing import Any

# 默认初始化器
default_gcnn_initializer = lecun_normal(in_axis=1, out_axis=0)
default_attn_initializer = normal(0.02)

class MultiHeadAttention(nn.Module):
    """对称多头注意力机制（无dropout）"""
    features: int
    num_heads: int
    head_dim: int
    irreps: tuple  # 不可约表示矩阵列表
    symmetries: HashableArray  # 对称操作数组
    param_dtype: Any = np.complex128
    
    @nn.compact
    def __call__(self, x, training=True):
        # 输入形状: [batch_size, n_sites, features]
        batch_size, n_sites, _ = x.shape
        
        # 定义 Q、K、V 等变线性变换
        q = DenseEquivariantIrrep(
            irreps=self.irreps,
            features=self.num_heads * self.head_dim,
            param_dtype=self.param_dtype,
            kernel_init=default_attn_initializer,
        )(x)
        
        k = DenseEquivariantIrrep(
            irreps=self.irreps,
            features=self.num_heads * self.head_dim,
            param_dtype=self.param_dtype,
            kernel_init=default_attn_initializer,
        )(x)
        
        v = DenseEquivariantIrrep(
            irreps=self.irreps,
            features=self.num_heads * self.head_dim,
            param_dtype=self.param_dtype,
            kernel_init=default_attn_initializer,
        )(x)
        
        # 重塑为多头格式
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
        
        # 计算注意力分数
        scale = 1.0 / jnp.sqrt(self.head_dim)
        attn_weights = jnp.einsum('bhid,bhjd->bhij', q, jnp.conj(k)) * scale
        
        # 使用复值 softmax 保持相位信息
        attn_weights_abs = jnp.abs(attn_weights)
        attn_weights_phase = jnp.exp(1j * jnp.angle(attn_weights))
        attn_weights_softmax = nn.softmax(attn_weights_abs, axis=-1)
        attn_weights = attn_weights_softmax * attn_weights_phase
        
        # 计算注意力输出
        output = jnp.einsum('bhij,bhjd->bhid', attn_weights, v)
        output = rearrange(output, 'b h n d -> b n (h d)')
        
        # 输出投影
        output = DenseEquivariantIrrep(
            irreps=self.irreps,
            features=self.features,
            param_dtype=self.param_dtype,
            kernel_init=default_attn_initializer,
        )(output)
        
        return output


class GCNN_Attention_Irrep(nn.Module):
    """结合注意力机制的群等变神经网络（无dropout）"""
    
    symmetries: HashableArray  # 对称操作数组
    irreps: tuple              # 不可约表示矩阵列表
    layers: int                # 层数 
    features: tuple            # 每层特征数
    characters: HashableArray  # 指定所需对称表示的字符
    attention_layers: int = 2  # 注意力层数量
    num_heads: int = 4         # 注意力头数
    head_dim: int = 32         # 每个头的维度
    parity: int = 1            # 宇称值
    param_dtype: Any = np.complex128  # 参数数据类型
    input_mask: Any = None     # 输入掩码
    equal_amplitudes: bool = False  # 是否强制等幅
    use_bias: bool = True      # 是否使用偏置

    def setup(self):
        # 第一层：对称化线性变换
        self.dense_symm = DenseSymmMatrix(
            symmetries=self.symmetries,
            features=self.features[0],
            param_dtype=self.param_dtype,
            use_bias=self.use_bias,
            kernel_init=default_gcnn_initializer,
            bias_init=zeros,
            mask=self.input_mask,
        )
        
        # GCNN 等变层
        self.equivariant_layers = tuple(
            DenseEquivariantIrrep(
                irreps=self.irreps,
                features=self.features[layer + 1],
                use_bias=self.use_bias,
                param_dtype=self.param_dtype,
                kernel_init=default_gcnn_initializer,
                bias_init=zeros,
            )
            for layer in range(self.layers - 1)
        )
        
        self.equivariant_layers_flip = tuple(
            DenseEquivariantIrrep(
                irreps=self.irreps,
                features=self.features[layer + 1],
                use_bias=self.use_bias,
                param_dtype=self.param_dtype,
                kernel_init=default_gcnn_initializer,
                bias_init=zeros,
            )
            for layer in range(self.layers - 1)
        )
        
        # 注意力模块 – 仅在最后几层中使用
        self.attention_modules = tuple(
            MultiHeadAttention(
                features=self.features[layer + 1],
                num_heads=self.num_heads,
                head_dim=self.head_dim,
                irreps=self.irreps,
                symmetries=self.symmetries,
                param_dtype=self.param_dtype,
            ) if layer >= (self.layers - 1 - self.attention_layers) else None
            for layer in range(self.layers - 1)
        )
        
        # 层归一化
        self.layer_norms = tuple(
            nn.LayerNorm(epsilon=1e-5, param_dtype=self.param_dtype)
            for _ in range(self.layers - 1)
        )

    @nn.compact
    def __call__(self, x, training=True):
        if x.ndim < 3:
            x = jnp.expand_dims(x, -2)  # 添加特征维度

        # 分离正反两条路径
        x_flip = self.dense_symm(-1 * x)
        x = self.dense_symm(x)

        for layer in range(self.layers - 1):
            x = reim_selu(x)
            x_flip = reim_selu(x_flip)

            # 保存残差
            residual_x = x
            residual_x_flip = x_flip
            
            # 卷积路径
            x_conv = (
                self.equivariant_layers[layer](x)
                + self.equivariant_layers_flip[layer](x_flip)
            ) / jnp.sqrt(2)
            x_flip_conv = (
                self.equivariant_layers[layer](x_flip)
                + self.equivariant_layers_flip[layer](x)
            ) / jnp.sqrt(2)
            
            x, x_flip = x_conv, x_flip_conv
            
            # 在末端层应用注意力机制
            if layer >= (self.layers - 1 - self.attention_layers):
                x_attn = self.attention_modules[layer](x, training=training)
                x = x + 0.5 * x_attn
                # 层归一化
                x = self.layer_norms[layer](x)
                x_flip = self.layer_norms[layer](x_flip)
                # 残差连接
                x = x + residual_x
                x_flip = x_flip + residual_x_flip
            
        # 拼接正反两条路径
        x = jnp.concatenate((x, x_flip), -1)
        # 构造宇称为1的字符并应用复值 logsumexp
        par_chars = jnp.expand_dims(
            jnp.concatenate(
                (jnp.array(self.characters), jnp.array(self.characters)), 0
            ),
            (0, 1),
        )
        x = logsumexp_cplx(x, axis=(-2, -1), b=par_chars)
        return x

In [3]:
def create_gcnn_attention(
    symmetries, 
    layers, 
    features, 
    attention_layers=2,
    num_heads=4,
    head_dim=32,
    mask=None, 
    characters=None
):
    """简化创建带注意力机制 GCNN 模型的函数（无dropout）"""
    if isinstance(features, int):
        features = (features,) * layers
    if characters is None:
        characters = HashableArray(np.ones(len(np.asarray(symmetries))))
    else:
        characters = HashableArray(characters)
    
    sym = HashableArray(np.asarray(symmetries))
    irreps = tuple(HashableArray(irrep) for irrep in symmetries.irrep_matrices())
    input_mask = HashableArray(mask) if mask is not None else None
    
    return GCNN_Attention_Irrep(
        symmetries=sym,
        irreps=irreps,
        layers=layers,
        features=features,
        characters=characters,
        attention_layers=attention_layers,
        num_heads=num_heads,
        head_dim=head_dim,
        parity=1,
        param_dtype=np.complex128,
        input_mask=input_mask,
        equal_amplitudes=False,
    )

In [None]:
# 哈密顿量参数
J1 = 0.03
J2 = 0.05
Q = 1-J2  # 四自旋相互作用强度，替换h项

# Shastry-Sutherland晶格定义
Lx = 5
Ly = 5

# 自定义边
custom_edges = [
    (0, 1, [1.0, 0.0], 0),
    (1, 0, [1.0, 0.0], 0),
    (1, 2, [0.0, 1.0], 0),
    (2, 1, [0.0, 1.0], 0),
    (3, 2, [1.0, 0.0], 0),
    (2, 3, [1.0, 0.0], 0),
    (0, 3, [0.0, 1.0], 0),
    (3, 0, [0.0, 1.0], 0),
    (2, 0, [1.0, -1.0], 1),
    (3, 1, [1.0, 1.0], 1),
]

# 创建晶格
lattice = nk.graph.Lattice(
    basis_vectors=[[2.0, 0.0], [0.0, 2.0]],
    extent=(Lx, Ly),
    site_offsets=[[0.5, 0.5], [1.5, 0.5], [1.5, 1.5], [0.5, 1.5]],
    custom_edges=custom_edges,
    pbc=[True, True]
)

# 可视化晶格
# lattice.draw()

In [None]:
# Hilbert空间定义
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)

# 自旋-1/2矩阵
sigmax = jnp.array([[0, 0.5], [0.5, 0]])
sigmay = jnp.array([[0, -0.5j], [0.5j, 0]])
sigmaz = jnp.array([[0.5, 0], [0, -0.5]])
unitm = jnp.array([[1.0, 0.0], [0.0, 1.0]])

# 自旋-自旋相互作用
sxsx = np.kron(sigmax, sigmax)
sysy = np.kron(sigmay, sigmay)
szsz = np.kron(sigmaz, sigmaz)
umum = np.kron(unitm, unitm)
SiSj = sxsx + sysy + szsz

# Q项需要的C_ij算子定义
ProjOp = jnp.array(SiSj) - 0.25 * jnp.array(umum)
ProjOp2 = jnp.kron(ProjOp, ProjOp)

# 构建J1-J2部分的哈密顿量
bond_operator = [
    (J1 * SiSj).tolist(),
    (J2 * SiSj).tolist(),
]
bond_color = [0, 1]

# 创建图哈密顿量 - 不包含Q项
H_J = nk.operator.GraphOperator(hilbert, graph=lattice, bond_ops=bond_operator, bond_ops_colors=bond_color)

# 创建Q项哈密顿量
H_Q = nk.operator.LocalOperator(hilbert, dtype=jnp.complex128)

# 获取晶格尺寸
Lx, Ly = lattice.extent[0], lattice.extent[1]

# 遍历所有单元格
for x in range(Lx):
    for y in range(Ly):
        # 计算当前单元格的基本索引
        base = 4 * (y + x * Ly)
        
        # 当前单元格内的四个格点
        site0 = base      # 左下角 (0.5, 0.5)
        site1 = base + 1  # 右下角 (1.5, 0.5)
        site2 = base + 2  # 右上角 (1.5, 1.5)
        site3 = base + 3  # 左上角 (0.5, 1.5)
        
        # 找到相邻单元格（考虑周期性边界条件）
        right_x = (x + 1) % Lx
        right_base = 4 * (y + right_x * Ly)
        
        left_x = (x - 1 + Lx) % Lx
        left_base = 4 * (y + left_x * Ly)
        
        up_y = (y + 1) % Ly
        up_base = 4 * (up_y + x * Ly)
        
        down_y = (y - 1 + Ly) % Ly
        down_base = 4 * (down_y + x * Ly)
        
        # 1. 单元格内部的水平方向plaquette
        H_Q += nk.operator.LocalOperator(hilbert, [(-Q * ProjOp2).tolist()],
                                        [[site0, site1, site3, site2]])
        
        # 2. 单元格内部的垂直方向plaquette
        H_Q += nk.operator.LocalOperator(hilbert, [(-Q * ProjOp2).tolist()],
                                        [[site0, site3, site1, site2]])
        
        # 3. 与右侧单元格形成的水平plaquette（处理x方向周期性）
        H_Q += nk.operator.LocalOperator(hilbert, [(-Q * ProjOp2).tolist()],
                                        [[site1, right_base, site2, right_base + 3]])
        
        # 4. 与上方单元格形成的垂直plaquette（处理y方向周期性）
        H_Q += nk.operator.LocalOperator(hilbert, [(-Q * ProjOp2).tolist()],
                                        [[site3, up_base, site2, up_base + 1]])

# 合并两部分哈密顿量
ha = H_J + 2*H_Q
ha = ha.to_jax_operator()


In [6]:
# 定义对称群（以 C4v 为例）并获取 irreps
nc = 4
cyclic_4 = PointGroup(
    [Identity()] + [rotation((360 / nc)*i) for i in range(1, nc)],
    ndim=2,
)
C4v = glide_group(trans=(1, 1), origin=(0, 0)) @ cyclic_4
C4v_symmetry = lattice.space_group(C4v)
print(f"Number of symmetry operations: {len(C4v_symmetry)}")

# 构造局部簇及掩码
local_cluster = jnp.arange(lattice.n_nodes).tolist()
mask = jnp.zeros(lattice.n_nodes, dtype=bool)
for i in local_cluster:
    mask = mask.at[i].set(True)

# 获取 irreps (此处示例采用 netket 内置方法获得 irreps)
sgb = lattice.space_group_builder(point_group=C4v)
momentum = [0.0, 0.0]
chi = sgb.space_group_irreps(momentum)[0]

# 构建模型（注意：attention_layers, num_heads, head_dim 参数可根据实际需求调整）
model = create_gcnn_attention(
    symmetries=C4v_symmetry,
    layers=4,
    features=4,
    attention_layers=1,  # 使用 1 个注意力层
    num_heads=2,
    head_dim=4,
    mask=mask,
    characters=chi
)

# 定义采样器与变分量子态
sampler = nk.sampler.MetropolisExchange(hilbert=hilbert, graph=lattice, n_chains=2**12, d_max=2)
vqs = nk.vqs.MCState(
    sampler=sampler,
    model=model,
    n_samples=2**12,
    n_samples_per_rank=None,
    n_discard_per_chain=0,
    chunk_size=2**8,
    training_kwargs={"holomorphic": False}  # 非全纯模型
)

n_params = nk.jax.tree_size(vqs.parameters)
print(f"Number of model parameters: {n_params}")  # 应接近预期参数量

Number of symmetry operations: 72
Number of model parameters: 16472


In [7]:
import jax
from jax import tree_util

# 熵梯度计算函数
def T_logp2(params, inputs, temperature, model_instance):
    variables = {"params": params}
    preds = model_instance.apply(variables, inputs)
    return 2.0 * temperature * jnp.mean(jnp.real(preds)**2)

def T_logp_2(params, inputs, temperature, model_instance):
    variables = {"params": params}
    preds = model_instance.apply(variables, inputs)
    return 2.0 * temperature * (jnp.mean(jnp.real(preds)))**2

# 定义自由能优化驱动，继承 nk.experimental.driver.vmc_srt.VMC_SRt
from netket.experimental.driver.vmc_srt import VMC_SRt

class FreeEnergyVMC_SRt(VMC_SRt):
    def __init__(self, temperature, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_temperature = temperature  # 初始温度
        self.temperature = temperature
        self.max_grad_norm = 0.7  # 最大梯度范数

    def _step_with_state(self, state):
        new_state = super()._step_with_state(state)
        params = new_state.parameters
        inputs = new_state.samples

        # 计算熵梯度部分
        mT_grad_S_1 = jax.grad(T_logp2, argnums=0)(params, inputs, self.temperature, self.variational_state.model)
        mT_grad_S_2 = jax.grad(T_logp_2, argnums=0)(params, inputs, self.temperature, self.variational_state.model)
        mT_grad_S = jax.tree_util.tree_map(lambda x, y: x - y, mT_grad_S_1, mT_grad_S_2)

        # 自由能梯度：能量梯度减去熵梯度
        total_grad = jax.tree_util.tree_map(lambda g_e, g_s: g_e - g_s, new_state.gradient, mT_grad_S)

        # 梯度裁剪
        total_grad = tree_util.tree_map(lambda g: jnp.clip(g, -self.clip_norm, self.clip_norm), total_grad)

        new_params = self.optimizer.update(total_grad, params)
        new_state = new_state.replace(parameters=new_params)
        return new_state

# 带进度条与温度递减策略的训练驱动
class CustomFreeEnergyVMC_SRt(FreeEnergyVMC_SRt):
    def __init__(self, reference_energy, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reference_energy = reference_energy

    def run(self, n_iter, out=None):
        outer_pbar = tqdm(total=n_iter, desc=f"Lx={Lx}, Ly={Ly}")
        for i in range(n_iter):
            # 温度递减
            self.temperature = self.init_temperature * (jnp.exp(-i / 50.0))
            self.advance(1)
            energy_mean = self.energy.mean
            energy_var = self.energy.variance
            energy_error = self.energy.error_of_mean
            relative_error = abs((energy_mean - self.reference_energy) / self.reference_energy) * 100
            outer_pbar.set_postfix({
                'Temp': f'{self.temperature:.4f}',
                'Energy': f'{energy_mean:.6f}',
                'E_var': f'{energy_var:.6f}',
                'E_err': f'{energy_error:.6f}',
                'Rel_err(%)': f'{relative_error:.4f}',
            })
            outer_pbar.update(1)
        outer_pbar.close()
        return self

# 初始化优化器和训练驱动
temperature_original = 1.0  # 初始温度
reference_energy = -16.2631
optimizer = nk.optimizer.Sgd(learning_rate=0.02)   # 或使用 Adam

vmc = CustomFreeEnergyVMC_SRt(
    reference_energy=reference_energy,
    temperature=temperature_original,
    hamiltonian=ha,
    optimizer=optimizer,
    diag_shift=0.05,
    variational_state=vqs
)

start = time.time()
vmc.run(n_iter=1000)
end = time.time()
print(f"Optimization time: {end - start:.2f} seconds")

Lx=3, Ly=3:   0%|          | 0/1000 [00:00<?, ?it/s]

Optimization time: 1557.45 seconds
