In [1]:
# Allows us to import files from the parent folder (the project root).
import sys
import os

# Parent directory of the kernel's working directory. This assumes the kernel was
# started in the notebook's own folder (Jupyter's default).
base_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# Put base_dir first so the local modules (neural_network, agent, ...) take precedence
# over any same-named installed package. Drop older copies so re-running this cell
# does not keep growing sys.path.
sys.path[:] = [base_dir] + [p for p in sys.path if p != base_dir]

from neural_network import CNNPure
from agent import DQN
from deformation_handler import DeformationManager

from jax import random, lax
import jax.numpy as jnp

key = random.key(seed=42)  # root PRNG key for this notebook (fixed seed for reproducibility)

In [2]:
CONV_LAYERS = [(6,3,1,1) for _ in range(5)] # 5 Layers of convolution

CODE_DISTANCE = 3

In [3]:
# Number of deformation types. The same value fixes the size of the deformation set,
# the channel count of the input image and the action-space factor, so these three
# places must agree. (The 6 in CONV_LAYERS presumably matches too - TODO confirm.)
N_DEFORMATIONS = 6
n_data_qubits = CODE_DISTANCE * CODE_DISTANCE

# Renders a per-qubit deformation vector as an image for the network.
dm = DeformationManager(jnp.arange(N_DEFORMATIONS), CODE_DISTANCE)

model = CNNPure(
    input_shape=(N_DEFORMATIONS, CODE_DISTANCE, CODE_DISTANCE),
    conv_layers=CONV_LAYERS
)
online_net_params = model.init(key)

# One action per (deformation type, data qubit) pair.
agent = DQN(
    model=model,
    discount=.8,
    n_actions=N_DEFORMATIONS * n_data_qubits
)

Input image of size 3 by 3 with 6 channels
to image of size 3 by 3 with 6 channels
to image of size 3 by 3 with 6 channels
to image of size 3 by 3 with 6 channels
to image of size 3 by 3 with 6 channels
to image of size 3 by 3 with 6 channels


In [4]:
# Greedy (epsilon = 0) rollout of the agent, starting from the all-zero deformation.
n_data_qubits = CODE_DISTANCE**2
# Derived from the agent so it stays consistent with how the action space was built.
n_deformations = agent.n_actions // n_data_qubits
qubit_ids = jnp.arange(n_data_qubits)

deformation = jnp.zeros(n_data_qubits, dtype=jnp.int32)
# Do not re-create random.key(42) here: that exact key was already passed to
# model.init, and JAX keys must not be reused. Derive a separate, still
# reproducible key instead.
key = random.fold_in(random.key(42), 1)

for _ in range(10):
    img = dm.deformation_image(deformation)
    # Action a encodes (deformation d, qubit q) as a = d * n_data_qubits + q
    # (see the unravel_index below).
    # Forbid re-applying each qubit's current deformation, which would be a no-op.
    # This does not forbid undoing the previous move, so the greedy policy can
    # cycle (e.g. 52 <-> 7 in the saved output).
    current_actions = n_data_qubits * deformation + qubit_ids
    disallowed = jnp.zeros(agent.n_actions, dtype=bool).at[current_actions].set(True)
    # NOTE(review): `done` is returned but ignored. Presumably the rollout should
    # stop when it is True - confirm against DQN.act.
    action, done, key = agent.act(
        key,
        online_net_params,
        img,
        disallowed_actions=disallowed,
        epsilon=0,
    )
    d_idx, q_idx = jnp.unravel_index(action, (n_deformations, n_data_qubits))
    print(deformation, "->", action, "=", f"deformation {d_idx} on data qubit {q_idx}")
    deformation = deformation.at[q_idx].set(d_idx)

[0 0 0 0 0 0 0 0 0] -> 49 = deformation 5 on data qubit 4
[0 0 0 0 5 0 0 0 0] -> 48 = deformation 5 on data qubit 3
[0 0 0 5 5 0 0 0 0] -> 44 = deformation 4 on data qubit 8
[0 0 0 5 5 0 0 0 4] -> 51 = deformation 5 on data qubit 6
[0 0 0 5 5 0 5 0 4] -> 45 = deformation 5 on data qubit 0
[5 0 0 5 5 0 5 0 4] -> 52 = deformation 5 on data qubit 7
[5 0 0 5 5 0 5 5 4] -> 7 = deformation 0 on data qubit 7
[5 0 0 5 5 0 5 0 4] -> 52 = deformation 5 on data qubit 7
[5 0 0 5 5 0 5 5 4] -> 7 = deformation 0 on data qubit 7
[5 0 0 5 5 0 5 0 4] -> 52 = deformation 5 on data qubit 7
