In [None]:
import os
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from phy_models.shastry_sutherland import shastry_sutherland_lattice, shastry_sutherland_hamiltonian
import flax

# Configure matplotlib for Chinese text: use SimHei as the sans-serif font,
# and disable the Unicode minus glyph so negative numbers on the axes render
# correctly with this font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

def load_quantum_state(pkl_file, L, J1, J2, layers=4, features=4):
    """Rebuild the GCNN variational state for one (L, J1, J2) point and load its parameters.

    Parameters
    ----------
    pkl_file : str
        Path to a pickle file holding the trained parameter pytree. It is
        assigned directly to ``vqs.parameters``.
    L : int
        Linear size; the lattice is ``shastry_sutherland_lattice(L, L)``.
    J1, J2 : float
        Couplings forwarded to ``shastry_sutherland_hamiltonian``. The third
        coupling passed there is ``Q = 1 - J2``.
    layers, features : int, optional
        GCNN depth and width (default 4 and 4). They MUST match the values
        used at training time, otherwise the loaded parameter tree will not
        fit the rebuilt model.

    Returns
    -------
    tuple
        ``(vqs, lattice, hi, ha)``: the ``nk.vqs.MCState`` with loaded
        parameters, the lattice, the Hilbert space and the Hamiltonian.
    """
    from netket.utils.group.planar import rotation, glide_group
    from netket.utils.group import PointGroup, Identity

    # Model ingredients: lattice, Hamiltonian and Hilbert space.
    lattice = shastry_sutherland_lattice(L, L)
    Q = 1.0 - J2
    ha, hi = shastry_sutherland_hamiltonian(lattice, J1, J2, Q)

    # Exchange moves between sites up to graph distance 2 conserve total S^z.
    sampler = nk.sampler.MetropolisExchange(
        hilbert=hi,
        graph=lattice,
        n_chains=2**10,
        d_max=2,
    )

    # SECURITY: pickle.load can execute arbitrary code stored in the file;
    # only load parameter files you produced yourself.
    with open(pkl_file, "rb") as f:
        parameters = pickle.load(f)

    # Input mask: True on sites 0 .. 4*L*L-1. Presumably the lattice has
    # 4 sites per unit cell, in which case every site is selected (TODO confirm).
    # A single scatter replaces the original one-copy-per-site loop.
    local_cluster = jnp.arange(L * L * 4)
    mask = jnp.zeros(lattice.n_nodes, dtype=bool).at[local_cluster].set(True)

    # Point group: C4 rotations combined with a glide, i.e. the symmetry
    # group the model was trained with.
    nc = 4
    cyclic_4 = PointGroup(
        [Identity()] + [rotation((360 / nc) * i) for i in range(1, nc)],
        ndim=2,
    )
    C4v = glide_group(trans=(1, 1), origin=(0, 0)) @ cyclic_4
    symmetries = lattice.space_group(C4v)
    sgb = lattice.space_group_builder(point_group=C4v)

    # Characters of the first irrep at k = (0, 0).
    momentum = [0.0, 0.0]
    chi = sgb.space_group_irreps(momentum)[0]

    # The architecture must be identical to the trained one (see layers/features).
    model = nk.models.GCNN(
        symmetries=symmetries,
        layers=layers,
        param_dtype=np.complex128,
        features=features,
        equal_amplitudes=False,
        parity=1,
        input_mask=mask,
        characters=chi,
    )

    vqs = nk.vqs.MCState(
        sampler=sampler,
        model=model,
        n_samples=2**20,  # many samples for better statistical accuracy
        n_discard_per_chain=100,
        chunk_size=2**10,
    )

    vqs.parameters = parameters

    print(f"成功加载量子态：L={L}, J1={J1:.2f}, J2={J2:.2f}")
    return vqs, lattice, hi, ha

def compute_spin_structure_factor(vqs, lattice, L, J1=None, J2=None, k_mesh_size=50):
    """Compute the static spin structure factor S(k) on a square k-mesh.

    S(k) = (1/N) * sum_{i,j} <S_i . S_j> * cos(k . (r_j - r_i)),
    where S = sigma / 2 are spin-1/2 operators and r_i are the site
    positions. The N(N+1)/2 distinct correlators are estimated with
    ``vqs.expect``; the Fourier sum over the whole mesh is then evaluated
    in one vectorised step.

    Parameters
    ----------
    vqs : netket variational state
        Must provide ``hilbert`` and ``expect``. When ``J1``/``J2`` are not
        given they are read from ``vqs.J1``/``vqs.J2``.
    lattice : netket Lattice
        Supplies ``n_nodes`` and ``positions`` (one 2D coordinate per site,
        matching the 2D k-vectors used here).
    L : int
        Linear system size, used only in the output file name.
    J1, J2 : float, optional
        Couplings used in the output file name.
    k_mesh_size : int, optional
        Number of k points per axis on [-pi, pi] (default 50).

    Returns
    -------
    kx, ky : ndarray, shape (k_mesh_size,)
    S_k : ndarray, shape (k_mesh_size, k_mesh_size)
        ``S_k[a, b]`` is S at (kx[a], ky[b]).

    Writes ``analysis_results/spin_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.dat``
    with columns kx, ky, S(k), one row per mesh point.
    """
    print("计算自旋结构因子...")
    J1 = vqs.J1 if J1 is None else J1
    J2 = vqs.J2 if J2 is None else J2

    kx = np.linspace(-np.pi, np.pi, k_mesh_size)
    ky = np.linspace(-np.pi, np.pi, k_mesh_size)

    # ``lattice.sites`` holds site objects, not coordinates; ``positions``
    # is the numeric coordinate array needed for k . r.
    positions = np.asarray(lattice.positions, dtype=float)
    n_sites = lattice.n_nodes
    hilbert = vqs.hilbert
    spin = nk.operator.spin

    # Correlation matrix <S_i . S_j>, filled symmetrically from the upper triangle.
    corr = np.zeros((n_sites, n_sites))
    for i in range(n_sites):
        for j in range(i, n_sites):
            sigma_dot = (
                spin.sigmax(hilbert, i) @ spin.sigmax(hilbert, j)
                + spin.sigmay(hilbert, i) @ spin.sigmay(hilbert, j)
                + spin.sigmaz(hilbert, i) @ spin.sigmaz(hilbert, j)
            )
            # NetKet's sigma operators are Pauli matrices: sigma.sigma = 4 S.S.
            value = 0.25 * vqs.expect(sigma_dot).mean.real
            corr[i, j] = value
            corr[j, i] = value

    # Fourier sum for all k at once, with phi_i(k) = exp(i k . r_i):
    #   sum_ij C_ij exp(i k.(r_j - r_i)) = sum_i conj(phi_i) (C @ phi)_i.
    # C is real symmetric, so this equals sum_ij C_ij cos(k . (r_j - r_i)).
    KX, KY = np.meshgrid(kx, ky, indexing="ij")
    k_points = np.column_stack((KX.ravel(), KY.ravel()))
    phases = np.exp(1j * positions @ k_points.T)  # (n_sites, n_k)
    s_flat = np.sum(phases.conj() * (corr @ phases), axis=0).real / n_sites
    S_k = s_flat.reshape(KX.shape)

    # One row per (kx, ky) mesh point so all three columns have equal length.
    os.makedirs("analysis_results", exist_ok=True)
    np.savetxt(
        f"analysis_results/spin_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.dat",
        np.column_stack((KX.ravel(), KY.ravel(), S_k.ravel())),
    )

    return kx, ky, S_k

def compute_plaquette_structure_factor(vqs, lattice, L, J1=None, J2=None, k_mesh_size=50):
    """Compute the plaquette structure factor P(k) on a square k-mesh.

    For each (x, y) in range(L) x range(L), the sites 4*(y + x*L) + {0, 1, 2, 3}
    (presumably the four sites of one unit cell; TODO confirm the site
    ordering) form a plaquette with operator

        P_r = S_0.S_1 + S_1.S_2 + S_2.S_3 + S_3.S_0,   S = sigma / 2.

    The connected correlation C_ij = <P_i P_j> - <P_i><P_j> is Fourier
    transformed using the plaquette centres r_i:

        P(k) = (1/N_plaq) * sum_ij C_ij * cos(k . (r_j - r_i)).

    Parameters
    ----------
    vqs : netket variational state
        Must provide ``hilbert`` and ``expect``. When ``J1``/``J2`` are not
        given they are read from ``vqs.J1``/``vqs.J2``.
    lattice : netket Lattice
        Supplies ``positions`` (2D site coordinates).
    L : int
        Linear system size (number of cells per direction).
    J1, J2 : float, optional
        Couplings used in the output file name.
    k_mesh_size : int, optional
        Number of k points per axis on [-pi, pi] (default 50).

    Returns
    -------
    kx, ky : ndarray, shape (k_mesh_size,)
    P_k : ndarray, shape (k_mesh_size, k_mesh_size)
        ``P_k[a, b]`` is P at (kx[a], ky[b]).

    Writes ``analysis_results/plaquette_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.dat``
    with columns kx, ky, P(k), one row per mesh point.
    """
    print("计算Plaquette结构因子...")
    J1 = vqs.J1 if J1 is None else J1
    J2 = vqs.J2 if J2 is None else J2

    kx = np.linspace(-np.pi, np.pi, k_mesh_size)
    ky = np.linspace(-np.pi, np.pi, k_mesh_size)

    # Numeric site coordinates (``lattice.sites`` holds site objects instead).
    positions = np.asarray(lattice.positions, dtype=float)
    hilbert = vqs.hilbert
    spin = nk.operator.spin

    def _sigma_bond(a, b):
        """Return sigma_a . sigma_b (= 4 S_a . S_b) as a NetKet operator."""
        return (
            spin.sigmax(hilbert, a) @ spin.sigmax(hilbert, b)
            + spin.sigmay(hilbert, a) @ spin.sigmay(hilbert, b)
            + spin.sigmaz(hilbert, a) @ spin.sigmaz(hilbert, b)
        )

    plaq_ops = []
    plaq_pos = []
    for x in range(L):
        for y in range(L):
            base = 4 * (y + x * L)
            sites = [base, base + 1, base + 2, base + 3]
            plaq_op = None
            # Ring of bonds 0-1, 1-2, 2-3, 3-0. Plain (non in-place) addition
            # lets the complex sigma^y terms promote the operator dtype; an
            # in-place += onto a float LocalOperator may be rejected by NetKet.
            for a, b in zip(sites, sites[1:] + sites[:1]):
                bond = _sigma_bond(a, b)
                plaq_op = bond if plaq_op is None else plaq_op + bond
            plaq_ops.append(plaq_op)
            # The plaquette is located at the centre of its four sites.
            plaq_pos.append(positions[sites].mean(axis=0))

    n_plaq = len(plaq_ops)
    plaq_pos = np.asarray(plaq_pos, dtype=float)

    # <P_i> enters every pair, so estimate each one exactly once.
    ev = np.array([vqs.expect(op).mean.real for op in plaq_ops])

    plaq_corr = np.zeros((n_plaq, n_plaq))
    for i in range(n_plaq):
        for j in range(i, n_plaq):
            ev_ij = vqs.expect(plaq_ops[i] @ plaq_ops[j]).mean.real
            plaq_corr[i, j] = ev_ij - ev[i] * ev[j]
            plaq_corr[j, i] = plaq_corr[i, j]
    # The operators were built from Pauli matrices, so P_sigma = 4 P_S and the
    # connected correlator carries a factor 16.
    plaq_corr /= 16.0

    # Vectorised Fourier sum: sum_ij C_ij exp(i k.(r_j - r_i))
    # = sum_i conj(phi_i) (C @ phi)_i with phi_i = exp(i k.r_i); real since C is symmetric.
    KX, KY = np.meshgrid(kx, ky, indexing="ij")
    k_points = np.column_stack((KX.ravel(), KY.ravel()))
    phases = np.exp(1j * plaq_pos @ k_points.T)  # (n_plaq, n_k)
    p_flat = np.sum(phases.conj() * (plaq_corr @ phases), axis=0).real / n_plaq
    P_k = p_flat.reshape(KX.shape)

    # One row per (kx, ky) mesh point so all three columns have equal length.
    os.makedirs("analysis_results", exist_ok=True)
    np.savetxt(
        f"analysis_results/plaquette_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.dat",
        np.column_stack((KX.ravel(), KY.ravel(), P_k.ravel())),
    )

    return kx, ky, P_k

def compute_correlation_ratio(kx, ky, S_k, peak_pos, name="Néel", delta_k=1):
    """Compute the correlation ratio R = 1 - S(Q + dk) / S(Q) on a k-mesh.

    Parameters
    ----------
    kx, ky : array_like, 1D
        Mesh coordinates along the first and second axis of ``S_k``.
    S_k : array_like, 2D
        Structure factor indexed as ``S_k[ix, iy]``.
    peak_pos : tuple of float
        Ordering vector Q = (qx, qy); snapped to the nearest mesh point.
    name : str, optional
        Label used only in the progress message.
    delta_k : int, optional
        Neighbour offset in mesh steps (default 1).

    Returns
    -------
    float
        ``1 - mean(S at the +/-delta_k neighbours) / S(Q)``, where only
        neighbours inside the mesh are averaged. Returns NaN when S(Q) is
        zero or no neighbour lies inside the mesh.
    """
    print(f"计算{name}关联比...")

    S_k = np.asarray(S_k)
    kx_idx = int(np.argmin(np.abs(np.asarray(kx) - peak_pos[0])))
    ky_idx = int(np.argmin(np.abs(np.asarray(ky) - peak_pos[1])))
    n_x, n_y = S_k.shape

    peak_intensity = S_k[kx_idx, ky_idx]

    # Q frequently sits on the mesh edge (e.g. k = pi is the last point of
    # linspace(-pi, pi, n)). There, idx + delta_k is out of range and
    # idx - delta_k < 0 would silently wrap to the opposite edge, so only
    # in-mesh neighbours are used.
    candidates = (
        (kx_idx + delta_k, ky_idx),
        (kx_idx - delta_k, ky_idx),
        (kx_idx, ky_idx + delta_k),
        (kx_idx, ky_idx - delta_k),
    )
    nearby = [
        S_k[ix, iy]
        for ix, iy in candidates
        if 0 <= ix < n_x and 0 <= iy < n_y
    ]
    if not nearby or peak_intensity == 0:
        return float("nan")

    nearby_intensity = np.mean(nearby)
    # R = 1 - S(k + dk) / S(k)
    return float(1.0 - nearby_intensity / peak_intensity)

def compute_dimer_dimer_correlation(vqs, lattice, L, J1=None, J2=None):
    """Compute the distance-binned horizontal dimer-dimer correlation.

    A "horizontal dimer" is any lattice edge (i, j) whose sites differ by
    ~1 in x and ~0 in y (tolerance 0.1). For each dimer D = S_i . S_j
    (S = sigma / 2) the connected correlation
    C_ab = <D_a D_b> - <D_a><D_b> is computed. It is then averaged over all
    pairs whose centre-to-centre distance rounds to the same value (0.1 bins).

    NOTE(review): distances are plain Euclidean, not minimum-image, so pairs
    across a periodic boundary fall into long-distance bins, and dimers that
    wrap around the boundary are not detected as horizontal.

    Parameters
    ----------
    vqs : netket variational state
        Must provide ``hilbert`` and ``expect``. When ``J1``/``J2`` are not
        given they are read from ``vqs.J1``/``vqs.J2``.
    lattice : netket Lattice
        Supplies ``positions`` and ``edges()``.
    L : int
        Linear system size, used only in the output file name.
    J1, J2 : float, optional
        Couplings used in the output file name.

    Returns
    -------
    ndarray, shape (n_bins, 2)
        Rows of (distance, mean correlation) sorted by distance; shape
        (0, 2) if no horizontal dimer is found.

    Writes ``analysis_results/dimer_dimer_correlation_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.dat``.
    """
    print("计算二聚体-二聚体关联函数...")
    J1 = vqs.J1 if J1 is None else J1
    J2 = vqs.J2 if J2 is None else J2

    # Numeric site coordinates (``lattice.sites`` holds site objects instead).
    positions = np.asarray(lattice.positions, dtype=float)
    hilbert = vqs.hilbert
    spin = nk.operator.spin

    def _sigma_bond(a, b):
        """Return sigma_a . sigma_b (= 4 S_a . S_b) as a NetKet operator."""
        return (
            spin.sigmax(hilbert, a) @ spin.sigmax(hilbert, b)
            + spin.sigmay(hilbert, a) @ spin.sigmay(hilbert, b)
            + spin.sigmaz(hilbert, a) @ spin.sigmaz(hilbert, b)
        )

    # Collect horizontal dimers. Only the first two entries of each edge are
    # used, so this works whether or not edges carry a colour.
    dimers = []
    dimer_pos = []
    for edge in lattice.edges():
        i, j = int(edge[0]), int(edge[1])
        dx = abs(positions[i][0] - positions[j][0])
        dy = abs(positions[i][1] - positions[j][1])
        if abs(dx - 1.0) < 0.1 and dy < 0.1:
            dimers.append((i, j))
            dimer_pos.append((positions[i] + positions[j]) / 2.0)

    os.makedirs("analysis_results", exist_ok=True)
    out_file = f"analysis_results/dimer_dimer_correlation_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.dat"

    n_dimers = len(dimers)
    if n_dimers == 0:
        print("警告：未找到水平二聚体")
        avg_corr = np.empty((0, 2))
        np.savetxt(out_file, avg_corr)
        return avg_corr

    dimer_ops = [_sigma_bond(i, j) for i, j in dimers]
    # <D_a> enters every pair, so estimate each one exactly once.
    ev = np.array([vqs.expect(op).mean.real for op in dimer_ops])

    dimer_corr = np.zeros((n_dimers, n_dimers))
    for a in range(n_dimers):
        for b in range(a, n_dimers):
            ev_ab = vqs.expect(dimer_ops[a] @ dimer_ops[b]).mean.real
            dimer_corr[a, b] = ev_ab - ev[a] * ev[b]
            dimer_corr[b, a] = dimer_corr[a, b]
    # Pauli-built dimers satisfy D_sigma = 4 D_S, so the connected correlator carries 16.
    dimer_corr /= 16.0

    dimer_pos = np.asarray(dimer_pos, dtype=float)
    dimer_distances = np.linalg.norm(
        dimer_pos[:, None, :] - dimer_pos[None, :, :], axis=-1
    )

    # Bin by the rounded distance itself, so each pair lands in exactly one
    # bin (a +/-0.1 tolerance around 0.1-spaced centres would double count).
    rounded = np.round(dimer_distances, 1)
    unique_dists = np.unique(rounded)  # np.unique returns sorted values
    avg_corr = np.array(
        [(dist, dimer_corr[rounded == dist].mean()) for dist in unique_dists]
    )

    np.savetxt(out_file, avg_corr)
    return avg_corr

def analyze_quantum_state(pkl_file, L, J1, J2):
    """Load one trained state and compute all order-parameter diagnostics.

    Steps: load the state, compute the spin and plaquette structure factors,
    derive the Néel correlation ratio at (pi, pi) and the plaquette ratio at
    (0, pi), compute the dimer-dimer correlation, and write both ratios to
    ``analysis_results/correlation_ratios_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.txt``
    (UTF-8).

    Returns
    -------
    tuple
        ``(kx, ky, S_k, kx_p, ky_p, P_k, neel_ratio, plaq_ratio, dimer_corr)``.
    """
    vqs, lattice, hi, ha = load_quantum_state(pkl_file, L, J1, J2)

    # The compute_* helpers read the couplings back from the state object
    # to build their output file names.
    vqs.J1 = J1
    vqs.J2 = J2

    # 1. Spin structure factor.
    kx, ky, S_k = compute_spin_structure_factor(vqs, lattice, L)

    # 2. Plaquette structure factor.
    kx_p, ky_p, P_k = compute_plaquette_structure_factor(vqs, lattice, L)

    # 3. Correlation ratios at the respective ordering vectors.
    neel_ratio = compute_correlation_ratio(kx, ky, S_k, (np.pi, np.pi), name="Néel")
    plaq_ratio = compute_correlation_ratio(kx_p, ky_p, P_k, (0, np.pi), name="Plaquette")

    print(f"Néel关联比: {neel_ratio:.4f}")
    print(f"Plaquette关联比: {plaq_ratio:.4f}")

    # 4. Dimer-dimer correlation function.
    dimer_corr = compute_dimer_dimer_correlation(vqs, lattice, L)

    # Save the ratios. The text contains non-ASCII characters, so the
    # encoding is fixed instead of relying on the platform locale.
    os.makedirs("analysis_results", exist_ok=True)
    with open(
        f"analysis_results/correlation_ratios_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.txt",
        "w",
        encoding="utf-8",
    ) as f:
        f.write(f"Néel关联比: {neel_ratio:.6f}\n")
        f.write(f"Plaquette关联比: {plaq_ratio:.6f}\n")

    return kx, ky, S_k, kx_p, ky_p, P_k, neel_ratio, plaq_ratio, dimer_corr

def plot_spin_structure_factor(kx, ky, S_k, L, J1, J2):
    """Render S(k) as a blue-white-red heat map and save it as a 300 dpi PNG.

    ``S_k[a, b]`` is the value at (kx[a], ky[b]); it is transposed for
    ``pcolormesh``, whose colour array is indexed (row=y, column=x). The
    figure goes to
    ``analysis_results/spin_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.png``
    and is closed afterwards.
    """
    pi_ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
    pi_labels = [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$']

    plt.figure(figsize=(10, 8))

    diverging = LinearSegmentedColormap.from_list(
        'blue_white_red', [(0, 0, 1), (1, 1, 1), (1, 0, 0)], N=256
    )

    # pcolormesh (rather than imshow) keeps the axes in k units.
    grid_kx, grid_ky = np.meshgrid(kx, ky)
    plt.pcolormesh(grid_kx, grid_ky, S_k.T, cmap=diverging, shading='auto')
    plt.colorbar(label='S(k)')

    plt.xlabel('$k_x$', fontsize=14)
    plt.ylabel('$k_y$', fontsize=14)
    plt.title(f'Spin Structure Factor - L={L}, J1={J1:.2f}, J2={J2:.2f}', fontsize=16)

    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(pi_ticks, pi_labels)

    plt.tight_layout()
    plt.savefig(f"analysis_results/spin_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.png", dpi=300)
    plt.close()

def plot_plaquette_structure_factor(kx, ky, P_k, L, J1, J2):
    """Render P(k) as a blue-white-red heat map and save it as a 300 dpi PNG.

    ``P_k[a, b]`` is the value at (kx[a], ky[b]); it is transposed for
    ``pcolormesh``, whose colour array is indexed (row=y, column=x). The
    figure goes to
    ``analysis_results/plaquette_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.png``
    and is closed afterwards.
    """
    pi_ticks = [-np.pi, -np.pi/2, 0, np.pi/2, np.pi]
    pi_labels = [r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$']

    plt.figure(figsize=(10, 8))

    diverging = LinearSegmentedColormap.from_list(
        'blue_white_red', [(0, 0, 1), (1, 1, 1), (1, 0, 0)], N=256
    )

    grid_kx, grid_ky = np.meshgrid(kx, ky)
    plt.pcolormesh(grid_kx, grid_ky, P_k.T, cmap=diverging, shading='auto')
    plt.colorbar(label='P(k)')

    plt.xlabel('$k_x$', fontsize=14)
    plt.ylabel('$k_y$', fontsize=14)
    plt.title(f'Plaquette Structure Factor - L={L}, J1={J1:.2f}, J2={J2:.2f}', fontsize=16)

    for set_ticks in (plt.xticks, plt.yticks):
        set_ticks(pi_ticks, pi_labels)

    plt.tight_layout()
    plt.savefig(f"analysis_results/plaquette_structure_factor_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.png", dpi=300)
    plt.close()

def plot_dimer_correlation(dimer_corr, L, J1, J2):
    """Plot correlation vs. distance and save it as a 300 dpi PNG.

    ``dimer_corr`` is a 2-column table: column 0 (distance) is plotted on the
    x axis and column 1 (correlation) on the y axis. The figure goes to
    ``analysis_results/dimer_correlation_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.png``
    and is closed afterwards.
    """
    label_size = 14

    plt.figure(figsize=(10, 6))

    distances = dimer_corr[:, 0]
    correlations = dimer_corr[:, 1]
    plt.plot(distances, correlations, 'o-', color='blue', linewidth=2, markersize=8)

    plt.xlabel('Distance', fontsize=label_size)
    plt.ylabel('Dimer-Dimer Correlation', fontsize=label_size)
    plt.title(f'Dimer-Dimer Correlation - L={L}, J1={J1:.2f}, J2={J2:.2f}', fontsize=16)
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig(f"analysis_results/dimer_correlation_L{L}_J1_{J1:.2f}_J2_{J2:.2f}.png", dpi=300)
    plt.close()

# Main entry point
def main(L=5, J2=0.05, J1_list=(0.05, 0.06), results_root="results"):
    """Analyse and plot every saved GCNN state for fixed L and J2.

    For each ``J1`` in ``J1_list`` the trained parameters are expected at
    ``{results_root}/L={L}/J2={J2:.2f}/J1={J1:.2f}/GCNN_L={L}_J2={J2:.2f}_J1={J1:.2f}.pkl``.
    Missing files are reported and skipped. Defaults: L=5, J2=0.05,
    J1 in (0.05, 0.06), results under ``results/``.
    """
    for J1 in J1_list:
        result_dir = f"{results_root}/L={L}/J2={J2:.2f}/J1={J1:.2f}"
        pkl_file = os.path.join(result_dir, f"GCNN_L={L}_J2={J2:.2f}_J1={J1:.2f}.pkl")

        if not os.path.exists(pkl_file):
            print(f"警告：找不到文件 {pkl_file}")
            continue

        print(f"分析量子态：L={L}, J1={J1}, J2={J2}")

        (kx, ky, S_k, kx_p, ky_p, P_k,
         neel_ratio, plaq_ratio, dimer_corr) = analyze_quantum_state(pkl_file, L, J1, J2)

        plot_spin_structure_factor(kx, ky, S_k, L, J1, J2)
        plot_plaquette_structure_factor(kx_p, ky_p, P_k, L, J1, J2)
        plot_dimer_correlation(dimer_corr, L, J1, J2)

        print(f"完成 L={L}, J1={J1}, J2={J2} 的分析\n")

# Script entry point: run the analysis with main()'s default parameters.
if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'phy_models'