Verify that the implementation of the qDRIFT subroutine is correct. 
Compute error statistics for the qDRIFT subroutine.

In [2]:
import sys
import os
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from qiskit.quantum_info import Pauli, SparsePauliOp, Operator
# Adjust the path to the directory containing 'scripts/algo'
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', 'scripts')))
from algo.qft_qpe_qdrift import *
from utils.scalable_numerical_tests import *

In [None]:

# ------------------------------------ test ---------------------------------
import numpy as np
import plotly.graph_objects as go
# Define parameters. None of these depend on the evolution time, so they are
# set once instead of being re-assigned on every pass of the outer loop.
times = [10, 1, 0.1, 0.01, 0.001]
num_qubits = 3
num_terms = 100
nums_samples = np.arange(400, 10000, 100)

# The loop variable is named ``evolution_time`` (not ``time``) so that it does
# not shadow the ``time`` module imported in a later cell.
for evolution_time in times:
    theoretical_errors = []
    unitary_errors = []
    unitary_errors_inf = []
    diamond_distances = []

    # Generate a fresh random Hamiltonian for this evolution time
    hamiltonian = generate_random_hamiltonian(num_qubits, num_terms)
    print(f"Hamiltonian: {hamiltonian}")

    # lambda = sum of |coefficients|. It depends only on the Hamiltonian, so it
    # is computed once here rather than once per sample count.
    lam = sum(abs(term[0]) for term in hamiltonian)

    # Compute exact unitary evolution (reference for all sample counts)
    U_exact = exact_unitary_evolution(hamiltonian, evolution_time)

    for num_samples in nums_samples:
        # Draw one qDRIFT sequence of ``num_samples`` factors
        sampled_unitaries, _labels = qdrift_sample(hamiltonian, evolution_time, num_samples)

        # Multiply the sampled unitaries in the order they were sampled
        U_qdrift = sampled_unitaries[0]
        for unitary in sampled_unitaries[1:]:
            U_qdrift = U_qdrift @ unitary
        U_qdrift = Operator(U_qdrift)

        # Compute errors against the exact evolution
        error = unitary_error_2_norm(U_exact, U_qdrift)
        error_inf = unitary_error_inf_norm(U_exact, U_qdrift)
        diamond_dist = compute_diamond_distance(U_exact, U_qdrift)
        theoretical_error = estimate_theoretical_qdrift_errror(
            num_samples=num_samples, lam=lam, time=evolution_time
        )

        # NOTE(review): the bound is scaled by 2 before plotting; the origin of
        # this factor is not documented here — confirm against the derivation.
        theoretical_errors.append(2 * theoretical_error)
        unitary_errors.append(error)
        diamond_distances.append(diamond_dist)
        unitary_errors_inf.append(error_inf)

    # Graph error vs number of samples with plotly
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=nums_samples, y=theoretical_errors, mode='lines', name='Theoretical Error'))
    fig.add_trace(go.Scatter(x=nums_samples, y=unitary_errors, mode='lines', name='Unitary Error 2-norm'))
    fig.add_trace(go.Scatter(x=nums_samples, y=unitary_errors_inf, mode='lines', name='Unitary Error inf-norm'))
    fig.add_trace(go.Scatter(x=nums_samples, y=diamond_distances, mode='lines', name='Diamond Distance'))
    fig.update_layout(title=f'Error vs Number of Samples for time={evolution_time}', xaxis_title='Number of Samples', yaxis_title='Error')
    fig.show()
    

Hamiltonian: [(0.7085835219099342, Pauli('XXI')), (0.6998280412093427, Pauli('ZXI')), (0.28510793999779616, Pauli('IZX')), (0.48926134088772644, Pauli('ZII')), (0.8045667004965198, Pauli('ZZZ')), (0.5600563657465342, Pauli('ZXX')), (0.49159425030962167, Pauli('XXZ')), (0.14619962190035374, Pauli('XZI')), (0.9361345439394804, Pauli('XZZ')), (0.9579431076075867, Pauli('XIX')), (0.9795908005431186, Pauli('IZI')), (0.9462328671524697, Pauli('XXI')), (0.7375611608873655, Pauli('XXI')), (0.2968391216253773, Pauli('IXI')), (0.6394453136835669, Pauli('ZXZ')), (0.027311319560118452, Pauli('IIX')), (0.8365182267664694, Pauli('IXI')), (0.11560618441185533, Pauli('ZXZ')), (0.1225348866474083, Pauli('IIX')), (0.21720586420104304, Pauli('IZX')), (0.7705810151554098, Pauli('IXZ')), (0.7094408417804973, Pauli('IXZ')), (0.645981115489179, Pauli('IXI')), (0.4095893913176323, Pauli('III')), (0.6557740583167097, Pauli('XXX')), (0.4929289882207646, Pauli('ZII')), (0.9303857789722326, Pauli('ZIX')), (0.1773

Hamiltonian: [(0.22881981610920288, Pauli('IXI')), (0.08568885510414603, Pauli('ZIZ')), (0.9780417827751423, Pauli('XZI')), (0.9438925698182191, Pauli('XIZ')), (0.8312566993886173, Pauli('XIX')), (0.033018770134606745, Pauli('XXZ')), (0.567655808297048, Pauli('XIZ')), (0.5374859767473626, Pauli('XXZ')), (0.6201775318598606, Pauli('XIZ')), (0.6155052444882491, Pauli('ZXX')), (0.27116288311697956, Pauli('XXZ')), (0.07642160171712731, Pauli('IZZ')), (0.0874544804026326, Pauli('IZX')), (0.8483798366974515, Pauli('XXI')), (0.6816487373624669, Pauli('ZIZ')), (0.7499553269008541, Pauli('ZZZ')), (0.42700242833052104, Pauli('XIX')), (0.06454025039351097, Pauli('XZX')), (0.7561598841128296, Pauli('ZXI')), (0.954658457447389, Pauli('XXZ')), (0.5496731623219319, Pauli('XIZ')), (0.9465091136376764, Pauli('ZZX')), (0.5087629799601083, Pauli('XXX')), (0.13312103154693333, Pauli('ZIX')), (0.15960889420292612, Pauli('XZX')), (0.08718765821676622, Pauli('ZIX')), (0.7659296350477676, Pauli('IZI')), (0.27

Hamiltonian: [(0.6822471578882037, Pauli('IZI')), (0.6252746233146389, Pauli('XII')), (0.278889135670619, Pauli('XIZ')), (0.17141234413897044, Pauli('XXZ')), (0.349220226504276, Pauli('ZII')), (0.7019087909622977, Pauli('XII')), (0.3384163722911042, Pauli('XIX')), (0.03794836407381863, Pauli('XXZ')), (0.5779101496175315, Pauli('XXX')), (0.765632450172375, Pauli('ZIZ')), (0.6857544793286314, Pauli('IXI')), (0.25700154441279033, Pauli('ZZX')), (0.8594956529045921, Pauli('XIZ')), (0.8072405606535435, Pauli('III')), (0.7722454423017627, Pauli('ZZI')), (0.3891183033222487, Pauli('XZZ')), (0.4217420792348099, Pauli('XXX')), (0.9321376226207483, Pauli('XZI')), (0.3688178522914083, Pauli('XXX')), (0.16764434507785886, Pauli('ZIX')), (0.04584512513707384, Pauli('ZXZ')), (0.3497027489248664, Pauli('ZIX')), (0.8870255938590337, Pauli('IXI')), (0.33937565789127466, Pauli('ZXX')), (0.2064445186848165, Pauli('ZIX')), (0.8322569556254616, Pauli('ZXX')), (0.5392233853602701, Pauli('ZIX')), (0.69315385

Hamiltonian: [(0.001982575209876658, Pauli('IXZ')), (0.29491510811577126, Pauli('ZII')), (0.17046287331354004, Pauli('ZXZ')), (0.28811483557956485, Pauli('XII')), (0.5704044070047506, Pauli('XXI')), (0.6015996551545847, Pauli('III')), (0.5554351402393628, Pauli('ZIZ')), (0.2782669643538547, Pauli('XXI')), (0.46710142716226044, Pauli('XIX')), (0.8263936101746265, Pauli('ZXX')), (0.0899051574852775, Pauli('ZIX')), (0.43521311159413, Pauli('XIZ')), (0.26544720947728795, Pauli('ZXZ')), (0.25212314182771645, Pauli('ZXX')), (0.8488177630313891, Pauli('IZI')), (0.2649043545154147, Pauli('XIZ')), (0.555597252396558, Pauli('IZZ')), (0.3374348919295673, Pauli('XXX')), (0.5513077273714191, Pauli('IIZ')), (0.034785988994769146, Pauli('ZXZ')), (0.5932191188259796, Pauli('XXZ')), (0.8123036029817249, Pauli('ZXI')), (0.25006428580883766, Pauli('ZXX')), (0.35174648087367677, Pauli('ZZX')), (0.7888637280293658, Pauli('ZXI')), (0.611032958039719, Pauli('XIZ')), (0.26146280796773347, Pauli('XIZ')), (0.85

Hamiltonian: [(0.8957573420773478, Pauli('IXZ')), (0.39884622302402606, Pauli('XZX')), (0.33034074130989155, Pauli('ZII')), (0.8553364518151559, Pauli('ZXZ')), (0.26541188137074645, Pauli('IZZ')), (0.9167800763193148, Pauli('XXZ')), (0.8091102207049914, Pauli('XII')), (0.31672691647937234, Pauli('ZXZ')), (0.4805693205214948, Pauli('XII')), (0.26013438594147265, Pauli('IZZ')), (0.10548844303524363, Pauli('III')), (0.18914744845963538, Pauli('IZI')), (0.8554991211775981, Pauli('XIZ')), (0.818364726178464, Pauli('IXZ')), (0.45862384014618063, Pauli('ZXZ')), (0.11953084820735593, Pauli('IZI')), (0.30478357597024075, Pauli('IXX')), (0.7856631696935065, Pauli('IIZ')), (0.7746116759477443, Pauli('IIZ')), (0.38608277411783687, Pauli('XIZ')), (0.9281462608939228, Pauli('XZX')), (0.35827910725111745, Pauli('ZIZ')), (0.17779967274228392, Pauli('IIZ')), (0.09820685663866069, Pauli('XIX')), (0.6744475455972274, Pauli('XIZ')), (0.6537999235813242, Pauli('ZII')), (0.9400893542743448, Pauli('ZXZ')), (

In [12]:
from joblib import Parallel, delayed
import numpy as np
import plotly.graph_objects as go
import multiprocessing
import time

# Define parameters
times = [10, 1, 0.1, 0.01, 0.001]
num_qubits = 3
num_terms = 100
nums_samples = np.arange(400, 100000, 100)
num_trials = 10  # Number of qDRIFT trials to average over


def run_qdrift_trial(hamiltonian, time, num_samples, U_exact=None):
    """Run a single qDRIFT trial and return its error metrics.

    Args:
        hamiltonian: Hamiltonian terms as returned by
            ``generate_random_hamiltonian`` (printed above as a list of
            ``(coefficient, Pauli)`` tuples).
        time: Evolution time passed to ``qdrift_sample``.
        num_samples: Number of qDRIFT samples to draw.
        U_exact: Exact evolution to compare against. If ``None``, it is
            computed with ``exact_unitary_evolution(hamiltonian, time)``.
            Passing it explicitly avoids recomputation and avoids depending on
            a notebook-global ``U_exact`` inside joblib worker processes.

    Returns:
        Tuple ``(error_2_norm, error_inf_norm, diamond_distance)``.
    """
    # Local import: ``reduce`` is not imported anywhere in this cell, and a
    # function-scope import keeps the function self-contained when joblib
    # serializes it for worker processes.
    from functools import reduce

    if U_exact is None:
        U_exact = exact_unitary_evolution(hamiltonian, time)

    sampled_unitaries, _labels = qdrift_sample(hamiltonian, time, num_samples)

    # Compose the sampled unitaries left-to-right, in sampling order
    U_qdrift = Operator(reduce(lambda a, b: a @ b, sampled_unitaries))

    # Compute errors
    error = unitary_error_2_norm(U_exact, U_qdrift)
    error_inf = unitary_error_inf_norm(U_exact, U_qdrift)
    diamond_dist = compute_diamond_distance(U_exact, U_qdrift)

    return error, error_inf, diamond_dist


start = time.time()
# Loop variable is ``evolution_time`` rather than ``time`` so the ``time``
# module stays usable (it is needed below to report the elapsed time).
for evolution_time in times:
    # Generate a random Hamiltonian
    hamiltonian = generate_random_hamiltonian(num_qubits, num_terms)
    lam = sum(abs(term[0]) for term in hamiltonian)  # Compute λ once
    U_exact = exact_unitary_evolution(hamiltonian, evolution_time)  # Compute exact unitary once

    # Run in parallel; U_exact is passed explicitly to every worker
    results = Parallel(n_jobs=multiprocessing.cpu_count())(
        delayed(run_qdrift_trial)(hamiltonian, evolution_time, num_samples, U_exact)
        for num_samples in nums_samples
        for _ in range(num_trials)
    )

    # Reshape results: Convert from (num_samples * num_trials) → (num_samples, num_trials)
    # (Parallel returns results in submission order, so trials are contiguous.)
    results = np.array(results).reshape(len(nums_samples), num_trials, 3)

    # Compute mean errors over trials
    unitary_errors = results[:, :, 0].mean(axis=1)
    unitary_errors_inf = results[:, :, 1].mean(axis=1)
    diamond_distances = results[:, :, 2].mean(axis=1)
    # NOTE(review): same factor-of-2 scaling of the bound as the previous cell — confirm.
    theoretical_errors = [
        2 * estimate_theoretical_qdrift_errror(num_samples, lam, evolution_time) for num_samples in nums_samples
    ]

    # Plot results
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=nums_samples, y=theoretical_errors, mode='lines', name='Theoretical Error'))
    fig.add_trace(go.Scatter(x=nums_samples, y=unitary_errors, mode='lines', name='Unitary Error 2-norm'))
    fig.add_trace(go.Scatter(x=nums_samples, y=unitary_errors_inf, mode='lines', name='Unitary Error inf-norm'))
    fig.add_trace(go.Scatter(x=nums_samples, y=diamond_distances, mode='lines', name='Diamond Distance'))
    fig.update_layout(title=f'Error vs Number of Samples for time={evolution_time}', xaxis_title='Number of Samples', yaxis_title='Error')
    fig.show()

print(f"Total elapsed time: {time.time() - start:.1f} s")
# ------------------------------------ test ---------------------------------