In [186]:
### IMPORTS ###
# Quantum libraries:
import pennylane as qml
from pennylane import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial

# Plotting
from matplotlib import pyplot as plt

# Other
import copy
from tqdm.notebook import tqdm # Pretty progress bars
from IPython.display import Markdown, display # Better prints
import joblib # Writing and loading
##############


In [187]:
N = 5
dev = qml.device("default.mixed", wires = 5)
dev_jax = qml.device("default.qubit.jax", wires = 5)
J = 1
l_steps = 100

lams = np.linspace(0,2*J,l_steps)

In [298]:
###################
## VQE FUNCTIONS ##
###################

def qml_build_H(N, lam, J):
    '''
    Set up Hamiltonian: 
            H = lam*Σsigma^i_z - J*Σsigma^i_x*sigma^{i+1}_ 
    '''
    # Interaction of spins with magnetic field
    H = lam * qml.PauliZ(0)
    for i in range(1,N):
        H = H + lam * qml.PauliZ(i)
        
    # Interaction between spins:
    for i in range(0,N-1):
        H = H + J*(-1)*( qml.PauliX(i) @ qml.PauliX(i+1) )
    
    return H

def vqe_circuit(N, param, H, shift_invariance = 0):
    for spin in range(N):
        qml.RY(param[spin], wires = spin)
    
    return qml.expval(H)


In [309]:
qcircuit = qml.QNode(vqe_circuit, dev, interface = "jax")
vcircuit = jax.vmap(qcircuit)

In [310]:
Y = np.arange(N)

In [311]:
drawer = qml.draw(qcircuit)
print(drawer(N, Y, qml_build_H(N, 0, J), 0) )

0: ──RY(0.00)─┤ ╭<𝓗>
1: ──RY(1.00)─┤ ├<𝓗>
2: ──RY(2.00)─┤ ├<𝓗>
3: ──RY(3.00)─┤ ├<𝓗>
4: ──RY(4.00)─┤ ╰<𝓗>


In [312]:
fn = lambda x: qcircuit(N, x, qml_build_H(N, 0, J), 0)

In [313]:
a = fn(Y)

In [314]:
a

DeviceArray(-0.78666747, dtype=float32)

In [315]:
v_fn = jax.vmap(fn)

In [316]:
Xs = []

for i in range(2):
    X = jnp.array(np.random.rand(N))
    Xs.append(X)
    
Xs = jnp.array(Xs)

In [317]:
v_fn(Xs)

NotImplementedError: batching rules are implemented only for id_tap, not for call.