In [3]:
import sys,os
import numpy as np
from jax import numpy as jnp
import jax
sys.path.append('../reduce_nsp')
from nsp.utils import l2nl

In [5]:
Sz = np.zeros([2,2])
Sz[0,0] = 1/2
Sz[1,1] = -1/2
Sx = np.zeros([2,2])
Sx[1,0] = 1/2
Sx[0,1] = 1/2
Sy = np.zeros([2,2], dtype=np.complex64)
Sy[1,0] = 1j/2
Sy[0,1] = -1j/2

I = np.eye(2)
I4 = np.eye(4)

SzSz = np.kron(Sz,Sz).real.astype(np.float64)
SxSx = np.kron(Sx,Sx).real.astype(np.float64)
SySy = np.kron(Sy,Sy).real.astype(np.float64)
o = np.kron(I, Sz) + np.kron(Sz, I)
Sp = (Sx + 1j*Sy).real
Sm = (Sx - 1j*Sy).real
g = np.kron(Sp, Sm) + np.kron(Sm, Sp)

In [7]:
def proj_symm(x):
    s = int(np.sqrt(x.shape[0]))
    x = x.reshape(s,s,s,s)
    return ((x + np.einsum("ijkl->jilk", x))/2).reshape(s*s, s*s)

In [9]:
np.random.seed(10)
g2 = np.random.rand(4,4)
g2 = (g2 + g2.T)/2
g2 = proj_symm(g2) / 10
# g2 = np.diag(np.diag(g2))
np.save("test5", g2)

g21 = np.random.rand(2,2)
g21 = (g21 + g21.T)/2
g21 /= 10
# g21 = np.diag(np.diag(g21))
np.save("test5_1", g21)

In [12]:
Lx, Ly = [3, 4]
L = Lx * Ly

G = np.zeros((2**L, 2**L))
for bond in [[i, j] for i in range(L) for j in range(L) if i < j]:
    G += l2nl(g, L, bond, sps=2)

G2 = np.zeros((2**L, 2**L))
for bond in [[i, j] for i in range(L) for j in range(L) if i < j]:
    G2 += l2nl(g2, L, bond, sps=2)    
for i in range(L):
    G2 += l2nl(g21, L, [i], sps=2)

M = np.zeros((2**L, 2**L))
for i in range(L):
    M += l2nl(Sz, L, [i], sps=2)


In [13]:
Jz, Jx, Jy, hz, hx = [-1,0.5,0.3,0,0.5]

s = np.arange(L)
x = s%Lx 
y = s//Lx 
T_x = (x+1)%Lx + Lx*y 
T_y = x +Lx*((y+1)%Ly) 

LH = Jz*SzSz + Jx*SxSx + Jy*SySy
LH1 = hz*Sz + hx*Sx
H = np.zeros((2**L, 2**L))
for bond in [[i,T_x[i]] for i in range(L)]+[[i,T_y[i]] for i in range(L)] :
    H += l2nl(LH, L, bond, sps=2)
for i in range(L):
    H += l2nl(-LH1, L, [i], sps=2)
params_dict=dict(Jz=Jz, Jx=Jx,Jy=Jy,hz=hz, hx=hx)


In [14]:
Hjax = jnp.array(H)
Gjax = jnp.array(G)
G2jax = jnp.array(G2)
Mjax = jnp.array(M)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [15]:
E, V = jax.scipy.linalg.eigh(Hjax)

In [16]:
T = 0.5
beta = 1/T
B = jnp.exp(-beta*E)
Z = B.sum()
M_mean = jnp.einsum("n,in,jn,ij->", B, V, V, Mjax, optimize = "greedy") / Z
M2_mean = jnp.einsum("n,in,jn,ij->", B, V, V, Mjax@Mjax, optimize = "greedy") / Z
E_mean = (E*B).sum() / Z
E_square_mean = ((E*E)*B).sum() / Z
G_mean = jnp.einsum("n,in,jn,ij->", B, V, V, Gjax, optimize = "greedy") / Z
G2_mean = jnp.einsum("n,in,jn,ij->", B, V, V, G2jax, optimize = "greedy") / Z

In [17]:
print(f"L = {L}", params_dict)
print(f"T               = {T}")
print(f"E               = {E_mean / L}")
print(f"C               = {(E_square_mean - E_mean**2)*(beta**2)/L}")
print(f"M               = {M_mean}")
print(f"M^2             = {M2_mean}")
print(f"G               = {G_mean}")
print(f"G2              = {G2_mean}")

L = 12 {'Jz': -1, 'Jx': 0.5, 'Jy': 0.3, 'hz': 0, 'hx': 0.5}
T               = 0.5
E               = -0.4478933811187744
C               = 0.5885588526725769
M               = -3.6841049677605042e-06
M^2             = 28.596725463867188
G               = 0.5638363361358643
G2              = 6.6042680740356445
