In [1]:
from jax_ARNN_helper import *

In [2]:
import netket as nk

from netket import jax as nkjax
import jax.numpy as jnp

from tqdm.notebook import tqdm

import jax

import pennylane as qml

import numpy as np

import matplotlib.pyplot as plt

import pandas as pd

In [3]:
%config Completer.use_jedi = False

In [4]:
num_qubits = 2  # Number of qubits: sets the device wires and the TFIM chain length
depth = 4  # Number of gate layers per sampled circuit (one gate slot per qubit and layer)

## Define TFIM Hamiltonian and obtain exact GS energy

In [6]:
def TFIM_Hamiltonian(coeffs_ZZ, coeffs_X):
    """Build the dense matrix of a periodic transverse-field Ising Hamiltonian.

    H = sum_i coeffs_ZZ[i] * Z_i Z_{(i+1) mod n} + sum_i coeffs_X[i] * X_i

    where n = len(coeffs_X) and qubit 0 is the left-most Kronecker factor.

    Parameters
    ----------
    coeffs_ZZ : sequence of float
        Coupling of the bond between qubit i and qubit (i+1) mod n.
        Only the first n entries are read.
    coeffs_X : sequence of float
        Transverse-field strength per qubit; its length fixes n.

    Returns
    -------
    numpy.ndarray
        Real array of shape (2**n, 2**n).

    Notes
    -----
    The two Z factors of a bond are multiplied onto their sites, so for
    n == 1 the bond collapses to Z·Z = identity instead of producing a
    mismatched 4x4 operator. For n == 2 both bonds act on the same pair
    of qubits, so Z_0 Z_1 is added once per coefficient.
    """
    Z = np.array([[1.0, 0.0], [0.0, -1.0]])
    X = np.array([[0.0, 1.0], [1.0, 0.0]])
    one = np.eye(2)
    n_qubit = len(coeffs_X)
    dim = 2**n_qubit

    def embed(site_ops):
        # Kronecker product of one 2x2 operator per qubit, qubit 0 first.
        M = np.eye(1)
        for O in site_ops:
            M = np.kron(M, O)
        return M

    H = np.zeros((dim, dim))
    for i in range(n_qubit):
        # Nearest-neighbour coupling with periodic wrap-around.
        bond = [one] * n_qubit
        for site in (i, (i + 1) % n_qubit):
            bond[site] = bond[site] @ Z
        H += coeffs_ZZ[i] * embed(bond)

        # Transverse field on qubit i.
        field = [one] * n_qubit
        field[i] = X
        H += coeffs_X[i] * embed(field)

    return H
        

# Uniform model: one ZZ bond and one X field term per qubit, periodic chain.
coeffs_ZZ = [1.0]*num_qubits
coeffs_X = [1.0]*num_qubits
obs_ZZ = [qml.PauliZ(i)@qml.PauliZ((i+1)%num_qubits) for i in range(num_qubits)]
obs_X = [qml.PauliX(i) for i in range(num_qubits)]

# Same Hamiltonian twice: as a PennyLane observable (for the circuit cost)
# and as a dense matrix (for exact diagonalisation).
H_qml = qml.Hamiltonian(coeffs_ZZ + coeffs_X, obs_ZZ + obs_X)
H = TFIM_Hamiltonian(coeffs_ZZ, coeffs_X)

# H is real symmetric, so use the Hermitian solver: it guarantees real
# eigenvalues, whereas the general-purpose `eig` makes no such promise.
eig, v = np.linalg.eigh(H)
E0 = eig.min()

print("min energy: ", E0)

min energy:  -2.8284271247461894


## Define QC in pennylane with TFIM Hamiltonian

In [7]:
# Simulator device with one wire per qubit.
dev = qml.device("default.qubit", wires=num_qubits)

# Lookup table from action index to gate. The trailing 0 is a placeholder
# meaning "apply no gate"; circuit() skips it.
gate_set = [
    qml.RX,    # action 0
    qml.RY,    # action 1
    qml.RZ,    # action 2
    qml.CNOT,  # action 3 (two-qubit gate)
    0,         # action 4 (identity / skip)
]


def circuit(action_sequence, num_qubits=None, depth = None, params=None, wires=0):
    """Apply the gates selected by ``action_sequence``, layer by layer.

    Parameters
    ----------
    action_sequence : array-like
        Indexed as ``action_sequence[0][qubit][layer]``; each entry is an
        index into the module-level ``gate_set``.
    num_qubits : int, optional
        Number of qubits to loop over. When omitted it is taken from the
        second axis of ``action_sequence``.
    depth : int, optional
        Number of layers to loop over. When omitted it is taken from the
        third axis of ``action_sequence``.
    params : sequence of float, optional
        ``params[0]`` is the rotation angle shared by every RX/RY/RZ gate.
        Defaults to ``[np.pi/4]``; a value passed by the caller is now
        honoured instead of being overwritten.
    wires : any
        Unused; kept so the function can be called with a ``wires`` keyword.

    Notes
    -----
    The CNOT action uses qubit ``j`` as control and ``(j+1) % num_qubits``
    as target. Entries of ``gate_set`` equal to ``0`` apply no gate.
    """
    if params is None:
        params = [np.pi/4]  # Fixed rotation angle until trainable parameters are added

    if num_qubits is None or depth is None:
        # The indexing below fixes the layout as (batch, qubit, layer).
        _, inferred_qubits, inferred_depth = np.shape(action_sequence)
        if num_qubits is None:
            num_qubits = inferred_qubits
        if depth is None:
            depth = inferred_depth

    for i in range(depth):
        for j in range(num_qubits):
            action = action_sequence[0][j][i]
            gate = gate_set[action]
            if gate == 0:
                # Identity placeholder: nothing to apply in this slot.
                continue
            if action == 3:
                # CNOT is the only two-qubit gate; the target wraps around the chain.
                gate(wires=[j, (j+1)%num_qubits])
            else:
                gate(params[0], wires=j)

# Callable that runs `circuit` on `dev` and returns the expectation value of H_qml.
# NOTE(review): qml.ExpvalCost is deprecated/removed in newer PennyLane releases -
# confirm the PennyLane version this notebook is pinned to.
cost_fn = qml.ExpvalCost(circuit, H_qml, dev)

## Define loss function 

`include_energy` can be used to turn off the energy part of the loss function. If set to `0.0`, only the `T`-weighted `log_p` term remains, so we only maximize the entropy and end up in an arbitrary "uniform" superposition state.

In [14]:
def loss(variables, samples, measure, T=1., include_energy = 1.0):
    """Surrogate loss combining a measure-weighted term and a T-weighted term.

    Evaluates the model through the module-level ``vs`` (not passed in) and
    returns ``include_energy * mean(out * measure) + mean(T * out)``, where
    ``out = vs._apply_fun(variables, samples)``.

    Parameters
    ----------
    variables : pytree
        Model variables handed to ``vs._apply_fun``.
    samples : array
        Configurations to evaluate the model on.
    measure : sequence of float
        One value per sample; converted with ``jnp.array``.
    T : float
        Weight of the second (unweighted-by-measure) term.
    include_energy : float
        Weight of the measure-weighted term; ``0.0`` switches it off.

    NOTE(review): ``out`` is presumably the per-sample log-probability
    (the original local name was ``log_p``) - confirm against the model.
    """
    model_out = vs._apply_fun(variables, samples)
    energy_term = (model_out*jnp.array(measure)).mean()
    temperature_term = (T*model_out).mean()
    return include_energy*energy_term + temperature_term

def get_sample(vs = None):
    """Draw samples from ``vs`` as a 2-D array with gradients blocked.

    The output of ``vs.sample()`` is reshaped to rows of length ``L``
    (``L`` is read from module scope) and wrapped in
    ``jax.lax.stop_gradient``. ``vs`` must be supplied: the ``None``
    default has no ``sample`` method.
    """
    flat_samples = vs.sample().reshape(-1, L)
    return jax.lax.stop_gradient(flat_samples)

In [15]:
# One autoregressive site per (qubit, layer) gate slot of the circuit.
L = num_qubits*depth

hi = nk.hilbert.Spin(s=2, N=L)
sa = nk.sampler.ARDirectSampler(hi) # Sampler
op = nk.optimizer.Sgd(learning_rate=0.1) # Optimizer (not used by the manual update below)

ma = ARNNConv1D(hilbert=hi, layers=2, features=10, kernel_size=10) #NN model
vs = nk.vqs.MCState(sa, ma, n_samples=100) # Variational State

# Gradient of `loss` with respect to its first argument (the model variables).
grad = nkjax.grad(loss)

df = pd.DataFrame(columns=["Temp_factor", 'unique_samples', 'good_solutions', "unique_good_solutions", "samples", "measures"])

alpha = .050 # Learning rate
Temp = 0.1
include_energy = 1.0
n_steps = 101 # Number of gradient-descent updates
good_energy_threshold = -1.99 # A sampled circuit counts as "good" below this energy


def _sample_and_measure(state):
    """Sample configurations from `state`, map them to gate indices and evaluate each circuit's energy."""
    raw = get_sample(vs = state)
    # Shift and scale the sampled values so they index into gate_set.
    actions = ((raw + 4)/2).astype(int)
    energies = [
        cost_fn(actions[k].reshape(1, num_qubits, depth), num_qubits = num_qubits, depth = depth).item()
        for k in range(len(actions))
    ]
    return raw, actions, energies


for step in tqdm(range(n_steps)):
    s, S, measure = _sample_and_measure(vs)
    grads = grad(vs.variables, s, measure, T=Temp, include_energy=include_energy)
    # Plain gradient descent. tree_map accepts several pytrees and replaces
    # the removed jax.tree_multimap.
    vs.variables = jax.tree_util.tree_map(lambda p, g: p - alpha * g,
                            vs.variables, grads)


# Evaluate the trained sampler on a fresh batch.
s, S, measure = _sample_and_measure(vs)
(idx,) = np.where(np.array(measure) < good_energy_threshold)

# Label the row by its position in the frame. The original used `j`, which
# is never defined in this notebook and fails on a fresh kernel.
new_row = pd.Series(data={"Temp_factor":Temp, 'unique_samples': np.unique(S, axis=0).shape[0], 'good_solutions':idx, "unique_good_solutions": np.unique(S[idx], axis=0).shape[0], "samples": s, "measures":measure}, name=str(len(df)))
# DataFrame.append was removed in pandas 2.0; concat keeps the row label as the index.
df = pd.concat([df, new_row.to_frame().T])

  0%|          | 0/101 [00:00<?, ?it/s]

0 0.1


In [16]:
df

Unnamed: 0,Temp_factor,unique_samples,good_solutions,unique_good_solutions,samples,measures
0,0.1,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1,"[[-4.0, -4.0, -4.0, -4.0, 4.0, 4.0, 0.0, 2.0],...","[-2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2.0, -2...."
