In [None]:
import matplotlib.pyplot as plt
import numpy as np
from qualtran.drawing import show_bloq, show_call_graph
from qualtran.resource_counting import QubitCount, get_cost_value
from tqdm.notebook import tqdm

from eftqpe.physical_costing import multicircuit_physical_cost, thc
from eftqpe.utils import make_decreasing_function

# Shared matplotlib style sheet applied to every figure in this notebook.
FIGSTYLE_PATH = "figstyle.mplstyle"
plt.style.use(FIGSTYLE_PATH)

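## Construct the controlled walk operator

Load the THC data for H$_2$O from file and build the controlled walk operator together with its normalization $\lambda_{\mathrm{THC}}$.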

In [None]:
# THC input data for H2O. The numeric suffix presumably encodes the factorization
# parameters (TODO confirm against the file's provenance).
THC_DATA_FILE = "data/thc/h2o_thc_6_4_4_30.npz"
ctrl_walk, lambda_thc = thc.walk_and_lambda_from_file(THC_DATA_FILE)

In [None]:
# Draw the decomposition of the controlled walk operator.
walk_decomposition = ctrl_walk.decompose_bloq()
show_bloq(walk_decomposition)

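## Build the call graph

Compute the call graph `g` of the walk operator and the bloq counts `sigma` that feed the logical resource counts below.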

In [None]:
g, sigma = thc.walk_call_graph(ctrl_walk)

# One line per entry of `sigma`: the key's str() padded to 30 characters, then its
# value. Lines are sorted as formatted strings.
sigma_lines = sorted(f"{k!s:30s}: {v}" for k, v in sigma.items())
print("\n".join(sigma_lines))

In [None]:
# Render the call graph `g` returned by thc.walk_call_graph in the previous cell.
show_call_graph(g)

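## Count logical resources

Magic-state count per walk step and the total number of logical qubits of the controlled walk operator.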

In [None]:
# Count the magic-state cost of one controlled walk step from the bloq counts `sigma`.
magic_per_walk = thc.magic_from_sigma(sigma)
# Bare last expression: Jupyter shows the rich repr instead of print()'s plain str().
magic_per_walk

In [None]:
# Count the logical qubits used by the controlled walk operator.
# (QubitCount / get_cost_value are imported in the top import cell.)
total_qubits = get_cost_value(ctrl_walk, QubitCount())
total_qubits

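## Physical cost estimates

Turn the logical counts into physical runtime and footprint with `multicircuit_physical_cost`: first for a single operating point, then over a sweep of $\gamma$, the error rate per walk step.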

In [None]:
# Logical inputs to the physical costing below, labeled so the output is self-explanatory
# (a bare tuple would show three unlabeled values).
{
    "lambda_thc": lambda_thc,
    "total_qubits": total_qubits,
    "magic_per_walk": magic_per_walk,
}

In [None]:
delta_e = 1e-3  # target energy accuracy (units presumably match lambda_thc — TODO confirm)
epsilon = delta_e / lambda_thc  # accuracy normalized by the walk's lambda
gamma = 1e-6  # error rate per walk step
n_factories = 1  # presumably the number of magic-state factories — TODO confirm

# Keyword arguments (matching the sweep cell below) tie each value to its parameter
# explicitly instead of relying on positional order.
baseline_cost = multicircuit_physical_cost(
    epsilon=epsilon,
    gamma=gamma,
    magic_per_unitary=magic_per_walk,
    n_algo_qubits=total_qubits,
    n_factories=n_factories,
)
baseline_cost

In [None]:
# TODO: switch to a DataFrame, support multiple input files, and save results.

# Sweep the error rate per walk step (10 log-spaced values from 1e-8 to 1e-1) at the same
# target accuracy delta_e as the single-point estimate, recording runtimes and footprint.
gamma_list = np.logspace(-8, -1, 10)
n_factories = 1
sweep_epsilon = delta_e / lambda_thc  # loop-invariant, so computed once

ttot_hr_list = np.zeros_like(gamma_list)
tmax_hr_list = np.zeros_like(gamma_list)
footprint_list = np.zeros_like(gamma_list)

# The loop variable is deliberately not named `gamma`, so the sweep does not overwrite
# the single-point `gamma` defined in the previous cell.
for j, gamma_j in enumerate(tqdm(gamma_list, desc="gamma sweep")):
    cost = multicircuit_physical_cost(
        epsilon=sweep_epsilon,
        gamma=gamma_j,
        magic_per_unitary=magic_per_walk,
        n_algo_qubits=total_qubits,
        n_factories=n_factories,
    )
    ttot_hr_list[j] = cost["t_tot_hr"]
    tmax_hr_list[j] = cost["t_max_hr"]
    footprint_list[j] = cost["physical_cost"].footprint

In [None]:
# Runtime (top panel) and physical footprint (bottom panel) versus gamma, the error rate
# per walk step.
fig, ax = plt.subplots(2, 1, figsize=(2*2, 2*3.5), sharex=True, gridspec_kw={'hspace': 0})
ax[0].plot(gamma_list, ttot_hr_list, '-', label=r'$\mathcal{T}_{\mathrm{tot}}$')
ax[0].plot(gamma_list, tmax_hr_list, '--', label=r'$\mathcal{T}_{\mathrm{max}}$')
# Without a legend the labels above are never drawn and the two curves can't be told apart.
ax[0].legend()

ax[1].plot(gamma_list, footprint_list, '.')

ax[0].set_ylabel('runtime in hours')
ax[1].set_ylabel(r'\# physical qubits')
# Raw string: "\g" is an invalid escape sequence in a normal literal (SyntaxWarning on 3.12+).
ax[1].set_xlabel(r'$\gamma$, error rate per walk step')

ax[0].set_yscale('log')
ax[1].set_xscale('log')  # sharex=True, so this applies to both panels

In [None]:
plt.plot(*make_decreasing_function(footprint_list, ttot_hr_list), 'o-')
plt.axhline(24, color='black', linestyle='dashed')
plt.text(2e5, 1.3*24, "1 day", va='bottom')
plt.axhline(24*30, color='black', linestyle='dashed')
plt.text(2e5, 0.9*24*30, "1 month", va='top')

plt.xlabel(r"\# physical qubits")
plt.grid(axis="x", which="minor")
plt.ylabel("runtime in hours")
plt.yscale("log")
plt.xscale("log")

In [None]:
plt.plot(*make_decreasing_function(footprint_list, ttot_hr_list), 'o-')
plt.axhline(24, color='black', linestyle='dashed')
plt.text(2e5, 1.3*24, "1 day", va='bottom')
plt.axhline(24*30, color='black', linestyle='dashed')
plt.text(2e5, 0.9*24*30, "1 month", va='top')

plt.xlabel(r"\# physical qubits")
plt.grid(axis="x", which="minor")
plt.ylabel("runtime in hours")
plt.yscale("log")
plt.xscale("log")