In [3]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax import grad, jit, value_and_grad
from functools import partial
from tqdm.notebook import tqdm
import netket as nk
import flax.linen as nn

# 设置环境变量
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/local/cuda"
os.environ["JAX_PLATFORM_NAME"] = "gpu"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['NETKET_EXPERIMENTAL_SHARDING'] = '1'
print("Available devices:", jax.devices())

###########################
# 构造物理晶格和 Hilbert 空间
###########################
Lx = 3
Ly = 3
# 注意：custom_edges 格式为 (i, j, distance_vector, color)
custom_edges = [
    (0, 1, [1.0, 0.0], 0),
    (1, 0, [1.0, 0.0], 0),
    (1, 2, [0.0, 1.0], 0),
    (2, 1, [0.0, 1.0], 0),
    (2, 3, [1.0, 0.0], 0),
    (3, 2, [1.0, 0.0], 0),
    (0, 3, [0.0, 1.0], 0),
    (3, 0, [0.0, 1.0], 0),
    (2, 0, [1.0, -1.0], 1),
    (3, 1, [1.0, 1.0], 1),
]
# 构建晶格，其中 custom_edges 参数必须为数值型
lattice = nk.graph.Lattice(
    basis_vectors = [[2.0, 0.0], [0.0, 2.0]],
    extent=(Lx, Ly),
    site_offsets=[[0.5, 0.5], [1.5, 0.5], [1.5, 1.5], [0.5, 1.5]],
    custom_edges=custom_edges,
    pbc=[True, True]
)
# 将自定义边信息保存到其他属性（避免覆盖内部 edges 方法）
lattice.edge_info = custom_edges

# Hilbert 空间，定义自旋1/2和总磁量为0的空间
hi = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes, total_sz=0)

#####################################
# 构造联合神经网络：model_joint
#####################################
class JointNet(nn.Module):
    d_model: int
    num_layers: int
    patch_size: int = 1
    n_sites: int = 4
    
    @nn.compact
    def __call__(self, x):
        # 将输入 x reshape 成 [batch, n_sites, 1]
        x = x.reshape(x.shape[0], self.n_sites, 1)
        x = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(),
                     dtype=jnp.float64)(x)
        # 多层简单残差结构
        for _ in range(self.num_layers):
            x_ln = nn.LayerNorm(dtype=jnp.float64)(x)
            x = x + nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(),
                             dtype=jnp.float64)(x_ln)
        # 输出 Dense 层映射到 2 个通道：
        # 通道 0 对应 J1, 通道 1 对应 J2
        x = nn.Dense(2, kernel_init=nn.initializers.xavier_uniform(),
                     dtype=jnp.float64)(x)
        return x  # 输出 shape 为 [batch, n_sites, 2]

# 初始化联合模型
d_model = 48
num_layers = 4
model_joint = JointNet(d_model=d_model, num_layers=num_layers, n_sites=lattice.n_nodes)
dummy_input = jnp.zeros((1, lattice.n_nodes))
key = jax.random.PRNGKey(42)
params_joint = model_joint.init(key, dummy_input)["params"]

#########################################
# 构造哈密顿量 (J1, J2 部分及完整系统)
#########################################
# 构造自旋-1/2算子
sigmax = jnp.array([[0, 0.5], [0.5, 0]])
sigmay = jnp.array([[0, -0.5j], [0.5j, 0]])
sigmaz = jnp.array([[0.5, 0], [0, -0.5]])
sxsx = np.kron(sigmax, sigmax)
sysy = np.kron(sigmay, sigmay)
szsz = np.kron(sigmaz, sigmaz)
SiSj = sxsx + sysy + szsz

# J1 部分哈密顿量
bond_operator_J1 = [(0.8 * SiSj).tolist()]
bond_color_J1 = [0]
H_J1 = nk.operator.GraphOperator(hi, graph=lattice,
                                 bond_ops=bond_operator_J1,
                                 bond_ops_colors=bond_color_J1)

# J2 部分哈密顿量
bond_operator_J2 = [(1.0 * SiSj).tolist()]
bond_color_J2 = [1]
H_J2 = nk.operator.GraphOperator(hi, graph=lattice,
                                 bond_ops=bond_operator_J2,
                                 bond_ops_colors=bond_color_J2)

# 完整系统哈密顿量（合并 J1 和 J2）
bond_operator = [
    (0.8 * SiSj).tolist(),
    (1.0 * SiSj).tolist()
]
bond_color = [0, 1]
H_full = nk.operator.GraphOperator(hi, graph=lattice,
                                   bond_ops=bond_operator,
                                   bond_ops_colors=bond_color)

# 转换为 JAX operator
ha_J1 = H_J1.to_jax_operator()
ha_J2 = H_J2.to_jax_operator()
ha_full = H_full.to_jax_operator()

#####################################
# 构造联合变分态 vqs_joint
#####################################
sampler_joint = nk.sampler.MetropolisExchange(hilbert=hi, graph=lattice,
                                              n_chains=2**12, d_max=2)
vqs_joint = nk.vqs.MCState(
    sampler=sampler_joint,
    model=lambda x: model_joint.apply({"params": params_joint}, x),
    n_samples=2**12,
    n_samples_per_rank=None,
    n_discard_per_chain=0,
    chunk_size=2**8,
    training_kwargs={"holomorphic": False}
)
# 注意：vqs_joint 保存的是联合模型参数

#########################################
# 定义联合网络损失与局部耦合损失
#########################################
gamma = 0.1           # coupling strength
target_coupling = 0.0 # 目标局部交叉相关（根据物理要求调整）

# 联合模型调用接口
def joint_model_apply(params, samples):
    return model_joint.apply({"params": params}, samples)

# 定义局部耦合损失函数
def local_coupling_loss(outputs, lattice):
    # outputs shape 为 [batch, n_sites, 2]
    loss = 0.0
    count = 0
    # 这里使用 lattice.edge_info 而非 lattice.edges
    for (i, j, _) in lattice.edge_info:
        psi1 = outputs[:, i, 0]
        psi2 = outputs[:, j, 1]
        psi1b = outputs[:, j, 0]
        psi2b = outputs[:, i, 1]
        local_err = jnp.mean((jnp.real(psi1 * jnp.conjugate(psi2)) - target_coupling)**2) \
                    + jnp.mean((jnp.real(psi1b * jnp.conjugate(psi2b)) - target_coupling)**2)
        loss += local_err
        count += 1
    return loss / (count + 1e-8)

# 定义联合 Loss 函数
# 注意：下面 expect_branch 与 expect_state 接口为伪代码，需要根据实际情况实现
def joint_loss(params, samples, lattice):
    outputs = joint_model_apply(params, samples)  # 输出 shape [batch, n_sites, 2]
    psi_J1 = outputs[..., 0]
    psi_J2 = outputs[..., 1]
    # 伪接口：请根据 NetKet 实现期望能量的接口
    E_J1_obj = vqs_joint.expect_branch(psi_J1, ha_J1)  # 返回含 .mean 属性的对象
    E_J2_obj = vqs_joint.expect_branch(psi_J2, ha_J2)
    L_coupling = local_coupling_loss(outputs, lattice)
    loss1 = E_J1_obj.mean + gamma * L_coupling
    loss2 = E_J2_obj.mean + gamma * L_coupling
    total_loss = loss1 + loss2
    return total_loss, (E_J1_obj.mean, E_J2_obj.mean, L_coupling)

# jit 加速联合梯度计算
joint_loss_and_grad = jit(value_and_grad(joint_loss, has_aux=True))

#########################################
# 定义联合优化器及 QuantumGAN 类
#########################################
optimizer = nk.optimizer.Sgd(learning_rate=0.05)

class QuantumGAN:
    def __init__(self, vqs_joint, ha_J1, ha_J2, ha_full, optimizer, lattice,
                 diag_shift=0.01, temperature=1.0, reference_energy=-16.2618):
        self.vqs_joint = vqs_joint
        self.ha_J1 = ha_J1
        self.ha_J2 = ha_J2
        self.ha_full = ha_full
        self.optimizer = optimizer
        self.lattice = lattice
        self.diag_shift = diag_shift
        self.temperature = temperature
        self.init_temperature = temperature
        self.reference_energy = reference_energy
        
        self.energy_history = []
        self.J1_energy_history = []
        self.J2_energy_history = []
        self.coupling_history = []
    
    def _update_temperature(self, iteration):
        self.temperature = self.init_temperature * (jnp.exp(-iteration / 50.0) / 2.0)
        return self.temperature
    
    def _evaluate_full_energy(self):
        # 使用联合模型输出两个通道合成的波函数计算完整哈密顿量的期望
        outputs = joint_model_apply(self.vqs_joint.parameters, self.vqs_joint.samples)
        psi = (outputs[...,0] + outputs[...,1]) / 2.0
        E_full_obj = vqs_joint.expect_state(psi, self.ha_full)
        return E_full_obj.mean
    
    def run(self, n_iter):
        outer_pbar = tqdm(total=n_iter, desc=f"Quantum GAN Training: Lattice extent {self.lattice.extent}")
        params = self.vqs_joint.parameters
        
        for i in range(n_iter):
            self._update_temperature(i)
            samples = self.vqs_joint.samples
            (loss, (E1, E2, Lcoup)), grads = joint_loss_and_grad(params, samples, self.lattice)
            params = self.optimizer.update(grads, params)
            self.vqs_joint.replace(parameters=params)
            
            self.energy_history.append(np.real(self._evaluate_full_energy()))
            self.J1_energy_history.append(np.real(E1))
            self.J2_energy_history.append(np.real(E2))
            self.coupling_history.append(np.real(Lcoup))
            
            relative_error = abs((np.real(self.energy_history[-1]) - self.reference_energy) / self.reference_energy) * 100
            outer_pbar.set_postfix({
                'E_full': f'{np.real(self.energy_history[-1]):.6f}',
                'J1_E': f'{np.real(E1):.6f}',
                'J2_E': f'{np.real(E2):.6f}',
                'Coupling': f'{np.real(Lcoup):.6f}',
                'RelErr(%)': f'{relative_error:.4f}'
            })
            outer_pbar.update(1)
        outer_pbar.close()
        return self

#########################################
# 初始化并运行 QuantumGAN 优化
#########################################
quantum_gan = QuantumGAN(
    vqs_joint = vqs_joint,
    ha_J1 = ha_J1,
    ha_J2 = ha_J2,
    ha_full = ha_full,
    optimizer = optimizer,
    lattice = lattice,
    diag_shift = 0.01,
    temperature = 1.0,
    reference_energy = -16.2618
)

import time
start = time.time()
quantum_gan.run(n_iter=1000)
end = time.time()
print(f"Total time: {end - start:.2f} sec")

Available devices: [CudaDevice(id=0), CudaDevice(id=1)]


TypeError: cannot unpack non-iterable function object