In [1]:
# import numpy as np
import jax
import jax.numpy as jnp
import itertools
import stim

In [2]:
circuit = """
H 0 1 2 3 4 5 6 7 8 9 10 11\nCZ 0 4 0 5 1 3 1 5 2 3 2 4 6 10 6 11 7 9 7 11 8 9 8 10 3 6 4 7 5 8    
"""

print(circuit)


H 0 1 2 3 4 5 6 7 8 9 10 11
CZ 0 4 0 5 1 3 1 5 2 3 2 4 6 10 6 11 7 9 7 11 8 9 8 10 3 6 4 7 5 8    



In [3]:
key = jax.random.key(0)

In [4]:
def generate_random_rgs(key: jnp.ndarray, num_rows: int, num_cols: int):
    nodes = jnp.arange(num_rows * num_cols).reshape((num_rows, num_cols))
    nodes = nodes.tolist()
    is_connected = jax.random.bernoulli(key, shape=(int(num_rows**2) * (num_cols - 1),))
    edges = itertools.chain(
        *[itertools.product(nodes[i], nodes[i + 1]) for i in range(num_rows - 1)]
    )

    circuit = "H " + " ".join([str(i) for i in range(int(num_cols * num_rows))]) + '\nCZ '

    circuit += ' '.join([ f'{i} {j}' for is_c, (i, j) in zip(is_connected, edges) if is_c])

    return circuit

def generate_random_rgs_v2(key: jnp.ndarray, num_rows: int, num_cols: int):
    nodes = jnp.arange(num_rows * num_cols).reshape((num_rows, num_cols))
    nodes = nodes.tolist()
    is_connected = jax.random.bernoulli(key, shape=(int(num_cols ** 2) * 2,))

    edges = itertools.chain(
        *[itertools.product(nodes[i], nodes[i + 1]) for i in [0, num_rows - 2]]
    )

    circuit = "H " + " ".join([str(i) for i in range(int(num_cols * num_rows))]) + '\nCZ '
    circuit += ' '.join([ f'{i} {j}' for is_c, (i, j) in zip(is_connected, edges) if is_c])
    circuit += ' ' + '3 6 4 7 5 8'

    return circuit

def gen_all_bell():

    stabilzers = []
    for i, j in itertools.product(range(0, 3), range(9, 12)):
        stab = ["_"] * 12
        stab[i] = 'X'
        stab[j] = 'Z'
        stabilzers.append("+" + "".join(stab))

        stab = ["_"] * 12
        stab[j] = 'X'
        stab[i] = 'Z'
        stabilzers.append("+" + "".join(stab))

    return stabilzers

In [5]:
all_bell_stabs = gen_all_bell()

In [6]:
all_bell_stabs

['+X________Z__',
 '+Z________X__',
 '+X_________Z_',
 '+Z_________X_',
 '+X__________Z',
 '+Z__________X',
 '+_X_______Z__',
 '+_Z_______X__',
 '+_X________Z_',
 '+_Z________X_',
 '+_X_________Z',
 '+_Z_________X',
 '+__X______Z__',
 '+__Z______X__',
 '+__X_______Z_',
 '+__Z_______X_',
 '+__X________Z',
 '+__Z________X']

In [20]:
count = 0
for idx in range(2000):
    key, graph_key, meas_1_key, meas_2_key = jax.random.split(key, 4)
    circuit = generate_random_rgs_v2(graph_key, 4, 3)

    s = stim.TableauSimulator()
    s.do(stim.Circuit(circuit))

    for i in range(3, 9):
        s.postselect_x(i, desired_value=False)

    node_1 = jax.random.choice(meas_1_key, jnp.array([0, 1, 2]))
    node_2 = jax.random.choice(meas_1_key, jnp.array([9, 10, 11]))
    for i in [node_1, node_2]:
        s.postselect_z(i, desired_value=False)

    r = s.canonical_stabilizers()
    set_r = set(map(str, filter(lambda x: x.weight == 2, r)))

    is_contained_bell_pair = set_r.issubset(
        set(all_bell_stabs)
    )

    if is_contained_bell_pair:
        count += len(set_r) / 2

count

1515.0

In [19]:
all_bell_stabs

['+X________Z__',
 '+Z________X__',
 '+X_________Z_',
 '+Z_________X_',
 '+X__________Z',
 '+Z__________X',
 '+_X_______Z__',
 '+_Z_______X__',
 '+_X________Z_',
 '+_Z________X_',
 '+_X_________Z',
 '+_Z_________X',
 '+__X______Z__',
 '+__Z______X__',
 '+__X_______Z_',
 '+__Z_______X_',
 '+__X________Z',
 '+__Z________X']

In [17]:
set_r

{'+__X_______Z_', '+__Z_______X_'}

In [174]:
key, subkey = jax.random.split(key)
circuit = generate_random_rgs(key, 4, 3)

s = stim.TableauSimulator()
s.do(stim.Circuit(circuit))

s.canonical_stabilizers()

for i in range(3, 9):
    s.postselect_x(i, desired_value=False)

for i in [1, 10]:
    s.postselect_z(i, desired_value=False)

r = s.canonical_stabilizers()
r

[stim.PauliString("+X__________Z"),
 stim.PauliString("+Z__________X"),
 stim.PauliString("+_Z__________"),
 stim.PauliString("+__X_________"),
 stim.PauliString("+___X________"),
 stim.PauliString("+____X_______"),
 stim.PauliString("+_____X______"),
 stim.PauliString("+______X_____"),
 stim.PauliString("+_______X____"),
 stim.PauliString("+________X___"),
 stim.PauliString("+_________X__"),
 stim.PauliString("+__________Z_")]

TypeError: peek_observable_expectation(): incompatible function arguments. The following argument types are supported:
    1. (self: stim._stim_polyfill.TableauSimulator, observable: stim._stim_polyfill.PauliString) -> int

Invoked with: <stim._stim_polyfill.TableauSimulator object at 0x11ff365f0>, 'XX'

In [175]:
circuit

'H 0 1 2 3 4 5 6 7 8 9 10 11\nCZ 0 4 1 3 1 4 1 5 2 3 2 4 3 6 3 8 4 6 4 7 4 8 5 8 7 11 8 9 8 10'

In [184]:
set(map(str ,filter(lambda x: x.weight == 2, r))).issubset(set(all_bell_stabs))

True

In [178]:
str(r[1])

'+Z__________X'

In [87]:
r[1].weight

3

In [89]:
r[1].pauli_indices()

[1, 10, 11]

In [147]:
for i, j in itertools.product(range(0, 3), range(9, 12)):
    stab = ["_"] * 12
    stab[i] = 'X'
    stab[j] = 'Z'
    stab_str = "".join(stab)
    print(stab_str)

    stab = ["_"] * 12
    stab[j] = 'X'
    stab[i] = 'Z'
    stab_str = "".join(stab)
    print(stab_str)

X________Z__
Z________X__
X_________Z_
Z_________X_
X__________Z
Z__________X
_X_______Z__
_Z_______X__
_X________Z_
_Z________X_
_X_________Z
_Z_________X
__X______Z__
__Z______X__
__X_______Z_
__Z_______X_
__X________Z
__Z________X


In [None]:
stab_str.translate(str.maketrans({"X": "Z", "Z": "X"}))

'__Z________X'