In [1]:
# TODO:
# - Use NetKet to generate the ground state of a simple spin model as a function of a model parameter.
# - Use NetKet either to measure observables, or to use another representation of the model (e.g. its weights) as input.

# - Use the 03_ notebook to implement the 'train on the extremes and predict in the middle' scheme.

In [2]:
# import jax.numpy as jnp
# import netket as nk
# import flax
# import flax.linen as nn
# from netket.operator.spin import sigmax,sigmaz 

# N=20

# hi = nk.hilbert.Spin(s=1 / 2, N=N)

# Gamma=-1
# H = sum([Gamma*sigmax(hi,i) for i in range(N)])

# V=-1
# H += sum([V*sigmaz(hi,i)*sigmaz(hi,(i+1)%N) for i in range(N)])

# class FFN(nn.Module):
    
#     alpha : int = 1
            
#     @nn.compact
#     def __call__(self, x):

#         # the first layer is a dense layer
#         dense = nn.Dense(features=self.alpha*x.shape[-1])
        
#         y=dense(x)

#         # the non-linearity is a simple ReLu    
#         y=nn.relu(y)
        
#         # sum the output
#         return jnp.sum(y,axis=-1)
    
# model=FFN(alpha=1)

# sampler = nk.sampler.MetropolisLocal(hi)

# vstate = nk.vqs.MCState(sampler, model, n_samples=1000)

# optimizer = nk.optimizer.Sgd(learning_rate=0.1)

# gs = nk.driver.VMC(H, optimizer, variational_state=vstate,preconditioner=nk.optimizer.SR(diag_shift=0.1))

# log=nk.logging.RuntimeLog()
# gs.run(n_iter=300,out=log)

# ffn_energy=vstate.expect(H)

In [3]:
import jax.numpy as jnp
import netket as nk
import flax
import flax.linen as nn
from netket.operator.spin import sigmax,sigmaz 

# Number of spins in the system; reused below for the chain graph and the Hamiltonian sums.
N = 20

# Hilbert space of N spin-1/2 degrees of freedom.
hi = nk.hilbert.Spin(s=0.5, N=N)

In [4]:
import netket.nn as nknn

# Periodic chain of N sites; its translation group defines the symmetries of SymmModel.
graph = nk.graph.Chain(length=N, pbc=True)


class SymmModel(nknn.Module):
    """Single-layer translation-symmetric network.

    The input is passed through one ``DenseSymm`` layer built from the
    translation group of the module-level ``graph``, followed by a ReLU, and
    the result is summed over its last two axes.
    """

    # Number of features requested from the DenseSymm layer.
    alpha: int

    @nknn.compact
    def __call__(self, x):
        symm_layer = nknn.DenseSymm(
            symmetries=graph.translation_group(),
            features=self.alpha,
            kernel_init=nk.nn.initializers.normal(stddev=0.01),
        )
        activated = nn.relu(symm_layer(x))

        # Collapse the two trailing axes (presumably features and sites --
        # TODO confirm against the DenseSymm output layout) into one value per input.
        return jnp.sum(activated, axis=(-1, -2))

In [5]:
# Instantiate the symmetric network with alpha=4, i.e. 4 features in its DenseSymm layer.
model=SymmModel(alpha=4)

In [6]:
# NetKet MetropolisLocal sampler constructed on the Hilbert space `hi`; the same
# sampler object is handed to get_gs_sample for every Hamiltonian in the sweep.
sampler = nk.sampler.MetropolisLocal(hi)

In [7]:
def get_gs_sample(H, sampler, model, n_samples, *, n_iter=300, learning_rate=0.1, diag_shift=0.1):
    """Optimise a variational state for ``H`` and return configurations sampled from it.

    A fresh ``nk.vqs.MCState`` is built from ``sampler`` and ``model``, optimised
    with the VMC driver (SGD plus an SR preconditioner), and then sampled once.

    Parameters
    ----------
    H :
        Operator whose energy the VMC driver minimises.
    sampler :
        NetKet sampler used to build the Monte Carlo state.
    model :
        Network definition used to build the Monte Carlo state.
    n_samples : int
        Number of configurations to return.
    n_iter : int, keyword-only
        Number of VMC iterations (default 300).
    learning_rate : float, keyword-only
        SGD learning rate (default 0.1).
    diag_shift : float, keyword-only
        Diagonal shift of the SR preconditioner (default 0.1).

    Returns
    -------
    array
        The sampled configurations flattened to two dimensions and truncated
        to the first ``n_samples`` rows; the last axis of the sampler output
        is preserved.
    """
    # The state always uses at least 1000 samples for the optimisation, but it is
    # enlarged when more are requested: previously the size was fixed at 1000, so
    # asking for more configurations silently returned fewer rows than requested.
    vstate = nk.vqs.MCState(sampler, model, n_samples=max(1000, n_samples))

    optimizer = nk.optimizer.Sgd(learning_rate=learning_rate)

    gs = nk.driver.VMC(
        H,
        optimizer,
        variational_state=vstate,
        preconditioner=nk.optimizer.SR(diag_shift=diag_shift),
    )

    log = nk.logging.RuntimeLog()
    gs.run(n_iter=n_iter, out=log)

    s = vstate.sample()
    # Merge all leading axes (presumably chains x samples-per-chain -- TODO confirm)
    # into one sample axis, then keep only the requested number of rows.
    s = s.reshape(-1, s.shape[-1])[:n_samples, :]

    return s

In [8]:
# Sweep the transverse-field strength g and, for each value, collect 100
# configurations sampled from the optimised variational state.
# Keys are formatted as '%.3f' % g so that they match the lookups performed by
# get_training_data (which uses '%.3f' % T); the previous str(g) of a JAX scalar
# produced keys such as '0.1' that could never be found there.
all_data = {}

# Nearest-neighbour ZZ coupling on the periodic chain. It does not depend on g,
# so it is built once outside the loop.
J=-1
zz_term = sum([J*sigmaz(hi,i)*sigmaz(hi,(i+1)%N) for i in range(N)])

for g in jnp.arange(0.1,2,0.1):
    # Work with a plain Python float for both the operator coefficient and the key.
    g = float(g)

    H = sum([-g*sigmax(hi,i) for i in range(N)])
    H = H + zz_term

    s=get_gs_sample(H, sampler, model, 100)

    all_data['%.3f' % g] = s

100%|██████████| 300/300 [00:09<00:00, 32.74it/s, Energy=-20.04999607 ± 0.00000013 [σ²=0.00000000, R̂=1.0092]]
100%|██████████| 300/300 [00:09<00:00, 31.86it/s, Energy=-20.2005811 ± 0.0000035 [σ²=0.0000000, R̂=1.2790]]
100%|██████████| 300/300 [00:10<00:00, 28.87it/s, Energy=-20.45280 ± 0.00030 [σ²=0.00009, R̂=1.0020]]   
100%|██████████| 300/300 [00:09<00:00, 30.91it/s, Energy=-20.80826 ± 0.00018 [σ²=0.00003, R̂=0.9996]]   
100%|██████████| 300/300 [00:10<00:00, 28.48it/s, Energy=-21.2693 ± 0.0010 [σ²=0.0010, R̂=0.9974]]   
100%|██████████| 300/300 [00:10<00:00, 27.60it/s, Energy=-21.84390 ± 0.00045 [σ²=0.00020, R̂=1.0035]]
100%|██████████| 300/300 [00:11<00:00, 27.20it/s, Energy=-22.53651 ± 0.00067 [σ²=0.00045, R̂=0.9991]]
100%|██████████| 300/300 [00:11<00:00, 26.15it/s, Energy=-23.3559 ± 0.0016 [σ²=0.0026, R̂=0.9992]]   
100%|██████████| 300/300 [00:11<00:00, 25.17it/s, Energy=-24.3291 ± 0.0018 [σ²=0.0034, R̂=0.9997]]
100%|██████████| 300/300 [00:12<00:00, 24.34it/s, Energy=-25.490

# Supervised learning

In [9]:
def get_training_data(all_data, Ts, Tc=2.7, train_fraction=0.8, seed=None):
    """Build a shuffled, labelled train/test split from per-parameter samples.

    Parameters
    ----------
    all_data : mapping
        Maps the string ``'%.3f' % T`` to an array-like of samples (one sample
        per row) for the parameter value ``T``.
    Ts : iterable of float
        Parameter values to include; each must have an entry in ``all_data``.
    Tc : float
        Threshold used for labelling: samples with ``T < Tc`` get the one-hot
        label ``[1, 0]``, all others ``[0, 1]``.
        NOTE(review): the parameter sweep in this notebook only covers values
        below the default of 2.7, so with the default every sample receives
        ``[1, 0]`` -- pass ``Tc`` explicitly for that data.
    train_fraction : float
        Fraction of the shuffled samples placed in the training set; the
        split index is ``int(train_fraction * n_total)``.
    seed : int or None
        If given, the shuffle uses ``np.random.default_rng(seed)`` and is
        reproducible. If None (default), NumPy's global random state is used.

    Returns
    -------
    list of three lists
        ``[raw_T, raw_x, raw_y]`` (NumPy arrays, in the order of ``Ts``),
        ``[train_T, train_x, train_y]`` and ``[test_T, test_x, test_y]``
        (JAX arrays, shuffled).

    Raises
    ------
    KeyError
        If ``all_data`` has no entry for one of the requested values.
    ValueError
        If ``Ts`` is empty (raised by ``np.concatenate``).
    """
    # Imported locally because the visible notebook cells never import NumPy
    # at module level; without this the function fails with NameError.
    import numpy as np

    # Per-parameter pieces, concatenated below.
    raw_T = []
    raw_x = []
    raw_y = []

    for T in Ts:
        samples = all_data['%.3f' % (T)]
        n = len(samples)
        label = [1, 0] if T < Tc else [0, 1]
        raw_x.append(samples)
        # np.tile keeps the (n, 2) shape even when n == 0, so an empty entry
        # does not break the concatenation of the labels.
        raw_y.append(np.tile(label, (n, 1)))
        raw_T.append(np.array([T] * n))

    raw_T = np.concatenate(raw_T)
    raw_x = np.concatenate(raw_x, axis=0)
    raw_y = np.concatenate(raw_y, axis=0)

    # Shuffle all three arrays with one shared permutation so rows stay aligned.
    if seed is None:
        indices = np.random.permutation(len(raw_x))
    else:
        indices = np.random.default_rng(seed).permutation(len(raw_x))
    all_T = raw_T[indices]
    all_x = raw_x[indices]
    all_y = raw_y[indices]

    # Split into train and test sets.
    train_split = int(train_fraction * len(all_x))
    train_T = jnp.array(all_T[:train_split])
    train_x = jnp.array(all_x[:train_split])
    train_y = jnp.array(all_y[:train_split])
    test_T = jnp.array(all_T[train_split:])
    test_x = jnp.array(all_x[train_split:])
    test_y = jnp.array(all_y[train_split:])

    return [raw_T, raw_x, raw_y], [train_T, train_x, train_y], [test_T, test_x, test_y]