In [1]:
import sys
from functools import reduce
import numpy as np
import cirq
from stabilizer_states import StabilizerStates
from stabilizer_toolkit.decompositions import rank2, validate_decompositions
from stabilizer_toolkit.magic_states import enumerate_ccz, enumerate_t
from stabilizer_toolkit.helpers.unitary import get_tensored_unitary

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
np.set_printoptions(precision=3, linewidth=sys.maxsize, edgeitems=4, threshold=1024, suppress=True) 

While it is possible to load the full five qubit stabilizer state dataset, there are many more states than we actually need.

In [6]:
StabilizerStates.count(5)

2423520

In [None]:
S5 = StabilizerStates(5, 'real')
_, state, _, _ = next(enumerate_ccz(5))

So, we would be looking over ~2.4M states, but luckily we only need to search within the real stabilizer states, which is roughly 147k.

In [None]:
S5 = StabilizerStates(5, 'ternary')
print(len(S5))
_, state, _, _ = next(enumerate_ccz(5))
decompositions, coeffs = rank2.search_all_stabilizer_states(state, S5, num_cpus=8)

146880
[19] [1 0 0 1 1]
[11] [0 1 0 1 1]
[7] [0 0 1 1 1]


2023-04-20 23:47:32,540	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
  9%|█████▏                                                    | 957791864/10786793760 [9:59:28<106:05:46, 25733.94it/s]

In [29]:
_, state, _, _ = next(enumerate_ccz(5))
decompositions, coeffs = rank2.ternary_search(state, S5, debug=False)

[19] [1 0 0 1 1]
[11] [0 1 0 1 1]
[7] [0 0 1 1 1]


2023-04-20 23:36:56,904	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
100%|██████████████████████████████████████████████████████████████████████████| 146880/146880 [05:51<00:00, 417.44it/s]


In [30]:
validate_decompositions(state, decompositions, coeffs)

12 decompositions
|ψ〉	= [ 0.177  0.177  0.177  0.177  0.177  0.177  0.177 -0.177  0.177  0.177  0.177 -0.177  0.177  0.177  0.177  0.177  0.177  0.177  0.177 -0.177  0.177  0.177  0.177  0.177  0.177  0.177  0.177  0.177  0.177  0.177  0.177 -0.177]

✅	= [0.707] * [ 0.    0.    0.25  0.25  0.    0.    0.25 -0.25  0.    0.    0.25 -0.25  0.    0.    0.25  0.25  0.    0.    0.25 -0.25  0.    0.    0.25  0.25  0.    0.    0.25  0.25  0.    0.    0.25 -0.25]
	+ [0.707] * [0.25 0.25 0.   0.   0.25 0.25 0.   0.   0.25 0.25 0.   0.   0.25 0.25 0.   0.   0.25 0.25 0.   0.   0.25 0.25 0.   0.   0.25 0.25 0.   0.   0.25 0.25 0.   0.  ]

✅	= [0.707] * [ 0.    0.25  0.25  0.    0.25  0.    0.   -0.25  0.25  0.    0.   -0.25  0.    0.25  0.25  0.    0.25  0.    0.   -0.25  0.    0.25  0.25  0.    0.    0.25  0.25  0.    0.25  0.    0.   -0.25]
	+ [0.707] * [0.25 0.   0.   0.25 0.   0.25 0.25 0.   0.   0.25 0.25 0.   0.25 0.   0.   0.25 0.   0.25 0.25 0.   0.25 0.   0.   0.25 0.25 0.   0.   0.25 0. 

True

In [28]:
_, state, D, circuit = next(enumerate_ccz(5))
print(circuit)
print(np.where(state < 0))
print(np.diag(D.astype(np.int8)))
print("[ 1  1  1  1  1  1  1 -1  1  1  1 -1  1  1  1  1  1  1  1 -1  1  1  1  1  1  1  1  1  1  1  1 -1]")

[19] [1 0 0 1 1]
[11] [0 1 0 1 1]
[7] [0 0 1 1 1]
[[1 0 0]
 [0 1 0]
 [0 0 1]
 [1 1 1]
 [1 1 1]]
(array([ 7, 11, 19, 31]),)
[ 1  1  1  1  1  1  1 -1  1  1  1 -1  1  1  1  1  1  1  1 -1  1  1  1  1  1  1  1  1  1  1  1 -1]
[ 1  1  1  1  1  1  1 -1  1  1  1 -1  1  1  1  1  1  1  1 -1  1  1  1  1  1  1  1  1  1  1  1 -1]


There are 29 distinct CCZ circuits and corresponding magic states, so at roughly ~6 mins per search that will take 2 hours. You can run the next cell to perform a rank-2 decomposition search for each of these magic states.

In [None]:
for index, state, D, circuit in enumerate_ccz(5):
    decompositions, coeffs = rank2.ternary_search(state, S5)
    print()
    print(f"Distinct circuit index {index}")
    print(f"|ψ〉= {state}")
    print(f" D = diag({np.diag(D)})")
    print(circuit)
    valid = validate_decompositions(state, decompositions, coeffs, show=False)
    status = "✅" if valid else "❌"
    print(f"All {len(decompositions)} decomposition(s) rank-2: {status}")
    print()