In [None]:
#H.8.1a
def exact_result_XandZ(alpha, beta, time):
    """Return the exact state exp(-i*time*(alpha*Z + beta*X)) |0>.

    Since H = alpha*Z + beta*X satisfies H**2 = (alpha**2 + beta**2) * I,
    the propagator has the closed form

        exp(-i*t*H) = cos(r*t) * I - i * sin(r*t) * H / r,
        r = sqrt(alpha**2 + beta**2),

    and H|0> = alpha|0> + beta|1>.

    Args:
        alpha: coefficient of Pauli Z.
        beta: coefficient of Pauli X.
        time: evolution time t.

    Returns:
        Length-2 complex array [c_0, c_1] of the amplitudes on |0> and |1>.
    """
    root = np.sqrt(alpha**2 + beta**2)
    # sin(r*t)/r is written as t*sinc(r*t/pi): np.sinc(x) = sin(pi*x)/(pi*x)
    # with sinc(0) = 1, so this stays finite (and exact) for alpha = beta = 0,
    # where sin(r*t)/r would evaluate to 0/0 = nan.
    sin_over_root = time * np.sinc(root * time / np.pi)
    c_0 = np.cos(root * time) - alpha * sin_over_root * 1.j
    c_1 = -beta * sin_over_root * 1.j
    return np.array([c_0, c_1])
    
@qml.qnode(dev)
def trotter_XandZ(alpha, beta, time, n):
    """First-order (Lie-Trotter) approximation of exp(-i*time*(alpha*Z + beta*X)) |0>.

    The evolution is cut into ``n`` equal slices; each slice applies a Z
    rotation followed by an X rotation. ``qml.PauliRot(theta, P)`` implements
    exp(-i*theta/2 * P), which is why both angles carry a factor of 2.

    Returns:
        The final state vector (``qml.state()``).
    """
    for _slice in range(n):
        for coeff, axis in ((alpha, "Z"), (beta, "X")):
            qml.PauliRot(2 * time / n * coeff, axis, wires=0)
    return qml.state()

def trotter_error_XandZ(alpha, beta, time, n):
    """Euclidean distance between the first-order Trotter state and the exact state."""
    approx = trotter_XandZ(alpha, beta, time, n)
    exact = exact_result_XandZ(alpha, beta, time)
    magnitudes = np.abs(approx - exact)
    # 2-norm of the difference vector: sqrt(sum_i |approx_i - exact_i|**2).
    return np.sqrt(sum(m * m for m in magnitudes))

In [None]:
#H.8.1b
@qml.qnode(dev)
def trotter_2_XandZ(alpha, beta, time, n):
    """Second-order (symmetric) Trotter approximation of exp(-i*time*(alpha*Z + beta*X)) |0>.

    Each of the ``n`` steps applies the palindromic sequence
    Z(half) X(half) X(half) Z(half), where "half" is half of the step
    duration time / n. ``qml.PauliRot(theta, P)`` implements
    exp(-i*theta/2 * P), so every angle carries a factor of 2.

    Returns:
        The final state vector (``qml.state()``).
    """
    half_slices = 2 * n
    palindrome = ((alpha, "Z"), (beta, "X"), (beta, "X"), (alpha, "Z"))
    for _step in range(n):
        for coeff, axis in palindrome:
            qml.PauliRot(2 * time / half_slices * coeff, axis, wires=0)
    return qml.state()

def trotter_2_error_XandZ(alpha, beta, time, n):
    """Euclidean distance between the second-order Trotter state and the exact state."""
    trotter_state = trotter_2_XandZ(alpha, beta, time, n)
    target_state = exact_result_XandZ(alpha, beta, time)
    abs_delta = np.abs(trotter_state - target_state)
    squared = abs_delta * abs_delta
    return np.sqrt(sum(squared))

In [None]:
#H.8.1c
@qml.qnode(dev)
def trotter_k_XandZ(alpha, beta, time, n, k):
    """Order-2k Suzuki-Trotter approximation of exp(-i*time*(alpha*Z + beta*X)) |0>.

    Args:
        alpha: coefficient of Pauli Z.
        beta: coefficient of Pauli X.
        time: total evolution time.
        n: number of Trotter steps; each step evolves for time / n.
        k: Suzuki recursion level (integer >= 1); the product formula has
            order 2k, and k = 1 is the symmetric second-order step.

    Returns:
        The final state vector (``qml.state()``).

    Raises:
        ValueError: if k is not an integer >= 1. The recursion below only
            terminates at k == 1, so such a k would otherwise recurse until
            RecursionError.
    """
    if k < 1 or k != int(k):
        raise ValueError(f"k must be an integer >= 1, got {k!r}")

    def U(alpha, beta, time, n, k):
        # Apply one order-2k Suzuki step covering duration time / n.
        if k == 1:
            # Symmetric second-order step. RZ(theta) = exp(-i*theta/2*Z), so
            # RZ(alpha*time/n) is a half step in Z; RX(2*beta*time/n) is a
            # full step exp(-i*beta*time/n*X).
            qml.RZ(alpha*time/n, wires=[0])
            qml.RX(2*beta*time/n, wires=[0])
            qml.RZ(alpha*time/n, wires=[0])
            return
        # Suzuki recursion S_2k(t) = S_2k-2(s t)^2 S_2k-2((1-4s) t) S_2k-2(s t)^2
        # with s = 1 / (4 - 4**(1/(2k-1))).
        s = 1 / (4 - 4 ** (1 / (2 * k - 1)))
        for fraction in (s, s, 1 - 4 * s, s, s):
            U(alpha, beta, fraction * time, n, k - 1)

    for _ in range(n):
        U(alpha, beta, time, n, k)
    return qml.state()

def trotter_k_error_XandZ(alpha, beta, time, n, k):
    """Euclidean distance between the order-2k Suzuki-Trotter state and the exact state."""
    suzuki_state = trotter_k_XandZ(alpha, beta, time, n, k)
    reference = exact_result_XandZ(alpha, beta, time)
    gap = np.abs(suzuki_state - reference)
    return np.sqrt(sum(g * g for g in gap))

In [None]:
#H.8.1d
def trotter_steps_XandZ(alpha, beta, time, error, k, max_n=None):
    """Return the smallest number of Trotter steps that meets an error tolerance.

    Tries n = 1, 2, 3, ... in order and returns the first n for which
    ``trotter_k_error_XandZ(alpha, beta, time, n, k) <= error``.

    Args:
        alpha: coefficient of Pauli Z.
        beta: coefficient of Pauli X.
        time: total evolution time.
        error: target error tolerance.
        k: Suzuki recursion level passed to ``trotter_k_error_XandZ``.
        max_n: optional upper bound on n. When given and no n <= max_n reaches
            the tolerance, RuntimeError is raised instead of searching forever
            (e.g. when ``error`` lies below floating-point round-off). The
            default ``None`` keeps the unbounded search.

    Returns:
        The first n whose error is within tolerance.

    Raises:
        RuntimeError: if ``max_n`` is given and no n <= max_n suffices.
    """
    n = 1
    while max_n is None or n <= max_n:
        if trotter_k_error_XandZ(alpha, beta, time, n, k) <= error:
            return n
        n += 1
    raise RuntimeError(
        f"no number of Trotter steps n <= {max_n} reaches error {error} with k = {k}"
    )

# Suzuki recursion level; the product formula then has order 2 * optimal_k.
optimal_k = 3
# Target distance between the Trotter state and the exact state.
error = 1e-6
n = trotter_steps_XandZ(1, 1, 1, error, optimal_k)
# NOTE(review): the 'depth' key matches older qml.specs output; newer
# PennyLane releases report depth under specs['resources'] — confirm version.
depth = qml.specs(trotter_k_XandZ)(1, 1, 1, n, optimal_k)['depth']
print(
    f"The Trotter circuit of order {2 * optimal_k} uses a circuit of depth "
    f"{depth} gates to achieve error ε = {error} ."
)

In [None]:
#H.8.2a
def truncation_XandZ(alpha, beta, time, K_bits):
    """Truncated-Taylor LCU decomposition of exp(-i*time*(alpha*Z + beta*X)).

    With r = sqrt(alpha**2 + beta**2) and V = (alpha*Z + beta*X) / r (so that
    V**2 = I),

        exp(-i*time*r*V) = sum_m (r*time)**m / m! * (-i)**m * V**m.

    The series is truncated to M = 2**K_bits terms so the index m fits in
    K_bits ancilla qubits. Term m is written as coeff_list[m] * U_list[m] with
        coeff_list[m] = (r*time)**m / m!
        U_list[2k]    = (-1)**k * I
        U_list[2k+1]  = (-1)**k * (-i) * V
    where every U_list entry is a 2x2 unitary.

    Args:
        alpha: coefficient of Pauli Z.
        beta: coefficient of Pauli X.
        time: evolution time.
        K_bits: number of ancilla qubits indexing the terms (>= 1).

    Returns:
        [coeff_list, U_list]: two lists of length 2**K_bits.

    Raises:
        ValueError: if K_bits < 1 (terms are generated in even/odd pairs, so
            at least two are needed).
    """
    if K_bits < 1:
        raise ValueError(f"K_bits must be >= 1, got {K_bits!r}")

    pauli_z = np.array([[1, 0], [0, -1]])
    pauli_x = np.array([[0, 1], [1, 0]])
    root = np.sqrt(alpha**2 + beta**2)
    if root == 0:
        # H = 0: all odd-order coefficients vanish, so V's value is irrelevant,
        # but it must still be a valid unitary for SELECT (0/0 would give nan).
        V = pauli_z
    else:
        V = (alpha * pauli_z + beta * pauli_x) / root

    x = time * root
    coeff_list = []
    U_list = []
    coeff = 1.0  # x**0 / 0!
    for m in range(2**K_bits):
        if m > 0:
            # x**m / m! as a running product: avoids overflowing x**m and m!
            # separately for long truncations.
            coeff = coeff * x / m
        coeff_list.append(coeff)
        sign = (-1) ** (m // 2)
        if m % 2 == 0:
            U_list.append(np.eye(2) * sign)
        else:
            U_list.append(sign * (-1j) * V)
    return [coeff_list, U_list]

In [None]:
#H.8.2b
def LCU_XandZ(alpha, beta, time, K_bits):
    """Approximate exp(-i*time*(alpha*Z + beta*X)) |0> with a truncated-Taylor LCU circuit.

    Wires 0 .. K_bits-1 form the ancilla (index) register and wire K_bits is
    the system qubit. The circuit applies PREPARE, SELECT, PREPARE^dagger and
    the ancillas are then post-selected on |0...0>.

    Args:
        alpha: coefficient of Pauli Z.
        beta: coefficient of Pauli X.
        time: evolution time.
        K_bits: number of ancilla qubits (the series keeps 2**K_bits terms).

    Returns:
        Normalized length-2 state vector of the system qubit.
    """
    aux = range(K_bits)  # ancilla wires; the system qubit is wire K_bits
    dev2 = qml.device("default.qubit", wires=K_bits + 1, shots=None)
    coeff_list, U_list = truncation_XandZ(alpha, beta, time, K_bits)

    # Build the PREPARE unitary once and reuse it (and its adjoint) in the
    # circuit rather than rebuilding it on every evaluation.
    prepare = PREPARE_matrix(coeff_list)
    prepare_dagger = prepare.conj().T

    @qml.qnode(dev2)
    def LCU_circuit():
        qml.QubitUnitary(prepare, wires=aux)
        # NOTE(review): SELECT is defined elsewhere; presumably it controls on
        # the aux wires and acts on wire K_bits — confirm.
        SELECT(U_list)
        qml.QubitUnitary(prepare_dagger, wires=aux)
        return qml.state()

    # With wire 0 as the most significant bit, the first two amplitudes are
    # those with every ancilla in |0>: the post-selected, unnormalized state
    # of the system qubit.
    unnormed = LCU_circuit()[:2]
    normed = unnormed / np.sqrt(sum(np.conjugate(unnormed) * unnormed))  # Normalize!
    return normed

In [None]:
#H.8.2c
def LCU_error_XandZ(alpha, beta, time, K_bits):
    """Euclidean distance between the post-selected LCU state and the exact state."""
    lcu_state = LCU_XandZ(alpha, beta, time, K_bits)
    exact_state = exact_result_XandZ(alpha, beta, time)
    residual = np.abs(lcu_state - exact_state)
    return np.sqrt(sum(v * v for v in residual))

In [None]:
#H.8.3
alpha = 3
error = 0.003

# trotter_depth and LCU_depth are not defined in this section of the notebook.
print(
    f"For α = {alpha} and error ε = {error} "
    f"the optimal Trotter circuit has depth {trotter_depth(alpha, error)} "
    f"and the optimal LCU circuit depth {LCU_depth(alpha, error)} ."
)

alpha_trotter = 0
error_trotter = 1
alpha_LCU = 3
error_LCU = 0.003