In [1]:
from neural_network import CNN
from quantum_error_correction_code import SurfaceCode
from agent import DQN

from jax import random, lax
import jax.numpy as jnp

# Root PRNG key for the notebook. Later cells split this key, so changing
# SEED changes every value drawn from it.
SEED = 42
key = random.key(SEED)

In [2]:
# Five identical convolution layers, each described by the tuple (6, 3, 1, 1).
# NOTE(review): the fields are presumably (channels, kernel size, stride,
# padding) — confirm against the CNN class.
CONV_LAYERS = [(6, 3, 1, 1)] * 5

# Distance of the surface code; also the side length of the network's input grid.
CODE_DISTANCE = 5

In [3]:
# Surface code of the configured distance; the agent reads its data-qubit count.
code = SurfaceCode(CODE_DISTANCE)

# The network input has shape (6, d, d) with d = CODE_DISTANCE
# (presumably channels-first — confirm against CNN).
image_shape = (6, CODE_DISTANCE, CODE_DISTANCE)
model = CNN(input_shape=image_shape, conv_layers=CONV_LAYERS)
print(model.layer_sizes)  # sanity check: shape after each layer

# DQN agent wrapping the CNN, with a discount of 0.8.
agent = DQN(model=model, discount=0.8, num_data_qubits=code.num_data_qubits)

[(6, 5, 5), (6, 5, 5), (6, 5, 5), (6, 5, 5), (6, 5, 5), (6, 5, 5)]


In [7]:
# Round-trip check of the flat action encoding: select entry [3, 2, 1] of a
# (6, d, d) array through argmax, split the flat index into its two parts,
# merge them again and require that the original index comes back.
one_hot = jnp.zeros((6, CODE_DISTANCE, CODE_DISTANCE)).at[3, 2, 1].set(1)
action = one_hot.argmax()  # argmax without an axis works on the flattened array
print(action)

deformation_action_idx, data_qubit_action_idx = agent.split_action(action)
print(deformation_action_idx, data_qubit_action_idx)

action_ = agent.merge_action(deformation_action_idx, data_qubit_action_idx)
print(action_)
assert action == action_

86
3 11
86


In [5]:
# Start from the all-zero deformation (one int32 entry per position of the d x d grid).
deformation = jnp.zeros(CODE_DISTANCE * CODE_DISTANCE, dtype=jnp.int32)

# Look up the deformation currently stored for the qubit picked earlier and
# merge that pair back into a flat action index; the call is the cell's last
# expression so its value is displayed.
current_deformation_idx = deformation[data_qubit_action_idx]
agent.merge_action(
    deformation_action_idx=current_deformation_idx,
    data_qubit_action_idx=data_qubit_action_idx,
)

Array(11, dtype=int32)

In [6]:
# Initialise the online network parameters from a fresh subkey.
subkey, key = random.split(key)
online_net_params = model.init(subkey)

# Start from the all-zero deformation: one int32 entry per grid position.
deformation = jnp.zeros(CODE_DISTANCE**2, dtype=jnp.int32)

# Loop-invariant sizes and indices, hoisted out of the rollout loop.
n_qubits = CODE_DISTANCE**2
n_actions = 6 * n_qubits  # same flat action count the original mask used
qubit_indices = jnp.arange(n_qubits)

for _ in range(10):
    img = code.deformation_image(deformation)

    # Boolean mask over the flat action space. For every qubit q it marks the
    # entry deformation[q] * n_qubits + q, i.e. the action that would assign the
    # deformation the qubit already has. Built with an explicit bool dtype so the
    # mask is a true boolean array rather than a float array holding 0.0 / 1.0.
    disallowed_actions = (
        jnp.zeros(n_actions, dtype=bool)
        .at[n_qubits * deformation + qubit_indices]
        .set(True)
    )

    # epsilon = 0: presumably disables random exploration — confirm in DQN.act.
    action, done, key = agent.act(
        key,
        online_net_params,
        img,
        disallowed_actions=disallowed_actions,
        epsilon=0,
    )

    d_idx, q_idx = agent.split_action(action)
    print(deformation, "->", action, "=", f"deformation {d_idx} on data qubit {q_idx}")

    # Apply the chosen deformation to the chosen qubit (functional update).
    deformation = deformation.at[q_idx].set(d_idx)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] -> 62 = deformation 2 on data qubit 12
[0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0] -> 67 = deformation 2 on data qubit 17
[0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 2 0 0 0 0 0 0 0] -> 72 = deformation 2 on data qubit 22
[0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 2 0 0 0 0 2 0 0] -> 68 = deformation 2 on data qubit 18
[0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 2 2 0 0 0 2 0 0] -> 38 = deformation 1 on data qubit 13
[0 0 0 0 0 0 0 0 0 0 0 0 2 1 0 0 0 2 2 0 0 0 2 0 0] -> 66 = deformation 2 on data qubit 16
[0 0 0 0 0 0 0 0 0 0 0 0 2 1 0 0 2 2 2 0 0 0 2 0 0] -> 63 = deformation 2 on data qubit 13
[0 0 0 0 0 0 0 0 0 0 0 0 2 2 0 0 2 2 2 0 0 0 2 0 0] -> 38 = deformation 1 on data qubit 13
[0 0 0 0 0 0 0 0 0 0 0 0 2 1 0 0 2 2 2 0 0 0 2 0 0] -> 63 = deformation 2 on data qubit 13
[0 0 0 0 0 0 0 0 0 0 0 0 2 2 0 0 2 2 2 0 0 0 2 0 0] -> 38 = deformation 1 on data qubit 13
