In [1]:
import sys,os
import numpy as np
from jax import numpy as jnp
import jax
sys.path.append('../reduce_nsp')
from nsp.utils import l2nl, sum_ham

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
Sz = np.zeros([2,2])
Sz[0,0] = 1/2
Sz[1,1] = -1/2
Sx = np.zeros([2,2])
Sx[1,0] = 1/2
Sx[0,1] = 1/2
Sy = np.zeros([2,2], dtype=np.complex64)
Sy[1,0] = 1j/2
Sy[0,1] = -1j/2

I = np.eye(2)
I4 = np.eye(4)

SzSz = np.kron(Sz,Sz).real.astype(np.float64)
SxSx = np.kron(Sx,Sx).real.astype(np.float64)
SySy = np.kron(Sy,Sy).real.astype(np.float64)
o = np.kron(I, Sz) + np.kron(Sz, I)
Sp = (Sx + 1j*Sy).real
Sm = (Sx - 1j*Sy).real
g = np.kron(Sp, Sm) + np.kron(Sm, Sp)

In [3]:
def proj_symm(x):
    s = int(np.sqrt(x.shape[0]))
    x = x.reshape(s,s,s,s)
    return ((x + np.einsum("ijkl->jilk", x))/2).reshape(s*s, s*s)

In [4]:
np.random.seed(10)
g2 = np.random.rand(4,4)
g2 = (g2 + g2.T)/2
g2 = proj_symm(g2) / 10
# g2 = np.diag(np.diag(g2))
np.save("test5", g2)

g21 = np.random.rand(2,2)
g21 = (g21 + g21.T)/2
g21 /= 10
# g21 = np.diag(np.diag(g21))
np.save("test5_1", g21)

In [5]:
L = 10

G = np.zeros((2**L, 2**L))
for bond in [[i, j] for i in range(L) for j in range(L) if i < j]:
    G += l2nl(g, L, bond, sps=2)

G2 = np.zeros((2**L, 2**L))
for bond in [[i, j] for i in range(L) for j in range(L) if i < j]:
    G2 += l2nl(g2, L, bond, sps=2)    
for i in range(L):
    G2 += l2nl(g21, L, [i], sps=2)

M = np.zeros((2**L, 2**L))
for i in range(L):
    M += l2nl(Sz, L, [i], sps=2)


In [6]:
Jz, Jx, Jy, hz, hx = [-1,0.5,0.3,0,0.5]
LH = Jz*SzSz + Jx*SxSx + Jy*SySy
LH1 = hz*Sz + hx*Sx
H = np.zeros((2**L, 2**L))
for bond in [[i,(i+1)%L] for i in range(L)]:
    H += l2nl(LH, L, bond, sps=2)
for i in range(L):
    H += l2nl(-LH1, L, [i], sps=2)
params_dict=dict(Jz=Jz, Jx=Jx,Jy=Jy,hz=hz, hx=hx)


In [7]:
Hjax = jnp.array(H)
Gjax = jnp.array(G)
G2jax = jnp.array(G2)
Mjax = jnp.array(M)

In [8]:
E, V = jax.scipy.linalg.eigh(Hjax)

In [9]:
T = 0.5
beta = 1/T
B = jnp.exp(-beta*E)
Z = B.sum()
M_mean = jnp.einsum("n,in,jn,ij->", B, V, V, Mjax, optimize = "greedy") / Z
M2_mean = jnp.einsum("n,in,jn,ij->", B, V, V, Mjax@Mjax, optimize = "greedy") / Z
E_mean = (E*B).sum() / Z
E_square_mean = ((E*E)*B).sum() / Z
G_mean = jnp.einsum("n,in,jn,ij->", B, V, V, Gjax, optimize = "greedy") / Z
G2_mean = jnp.einsum("n,in,jn,ij->", B, V, V, G2jax, optimize = "greedy") / Z

In [10]:
print(f"L = {L}", params_dict)
print(f"T               = {T}")
print(f"E               = {E_mean / L}")
print(f"C               = {(E_square_mean - E_mean**2)*(beta**2)/L}")
print(f"M               = {M_mean}")
print(f"M^2             = {M2_mean}")
print(f"G               = {G_mean}")
print(f"G2              = {G2_mean}")

L = 10 {'Jz': -1, 'Jx': 0.5, 'Jy': 0.3, 'hz': 0, 'hx': 0.5}
T               = 0.5
E               = -0.17920081317424774
C               = 0.19045209884643555
M               = -6.4285181906598154e-06
M^2             = 6.039064884185791
G               = 1.106979489326477
G2              = 5.013575077056885


### 2sites versions

In [11]:
LL = int(L/2)

In [67]:
Jz, Jx, Jy, hz, hx = [-1,-0.5,-0.3,0,0.5]
LH = Jz*SzSz + Jx*SxSx + Jy*SySy
LH1 = hz*Sz + hx*Sx
H_ = sum_ham(LH, [[1,2]], 4, 2)
H_ += sum_ham(LH/2, [[0,1],[2,3]], 4, 2)
LH = H_
LH1_2 = np.kron(LH1, I) + np.kron(I, LH1)
H2 = np.zeros((4**LL, 4**LL))
# for bond in [[i,(i+1)%L] for i in range(L)]:
H2 += sum_ham(LH, [[i,(i+1)%LL] for i in range(LL)],LL, 4)
H2 += sum_ham(-LH1_2, [[i] for i in range(LL)],LL, 4)


In [68]:
g_2site = sum_ham(g, [[0,2],[1,2],[0,3],[1,3]], 4, 2)
np.save("2site", g_2site)
g_2site_single = g
np.save("2site_single", g_2site_single)

In [92]:
G = np.zeros((4**LL, 4**LL))
G = sum_ham(g_2site, [[i, j] for i in range(LL) for j in range(LL) if i < j], LL, 4)
G += sum_ham(g_2site_single, [[i] for i in range(LL)], LL, 4)

In [119]:
np.random.seed(10)
g2 = np.random.rand(16,16)
g2 = (g2 + g2.T)/2
g2 = proj_symm(g2) / 10
np.save("2site_2", g2)

g2_single = np.random.rand(4,4)
g2_single = (g2_single + g2_single.T)/2
np.save("2site_2_single", g2_single)


In [114]:
G2 = np.zeros((4**LL, 4**LL))
G2 = sum_ham(g2, [[i, j] for i in range(LL) for j in range(LL) if i < j], LL, 4)
G2 += sum_ham(g2_single, [[i] for i in range(LL)], LL, 4)

In [115]:
Hjax = jnp.array(H2)
Gjax = jnp.array(G)
G2jax = jnp.array(G2)

# G2jax = jnp.array(G2)
# Mjax = jnp.array(M)

In [116]:
E, V = jax.scipy.linalg.eigh(Hjax)

In [117]:
T = 0.5
beta = 1/T
B = jnp.exp(-beta*E)
Z = B.sum()
E_mean = (E*B).sum() / Z
E_square_mean = ((E*E)*B).sum() / Z
G_mean = jnp.einsum("n,in,jn,ij->", B, V, V, Gjax, optimize = "greedy") / Z
G2_mean = jnp.einsum("n,in,jn,ij->", B, V, V, G2jax, optimize = "greedy") / Z

In [118]:
print(f"L = {L}", params_dict)
print(f"T               = {T}")
print(f"E               = {E_mean / L}")
print(f"C               = {(E_square_mean - E_mean**2)*(beta**2)/L}")
# print(f"M               = {M_mean}")
# print(f"M^2             = {M2_mean}")
print(f"G               = {G_mean}")
print(f"G2              = {G2_mean}")

L = 10 {'Jz': -1, 'Jx': 0.5, 'Jy': 0.3, 'hz': 0, 'hx': 0.5}
T               = 0.5
E               = -0.2736138701438904
C               = 0.3662477433681488
G               = 7.666713714599609
G2              = 10.180668830871582
