In [None]:
import sys
sys.path.append('../../pyzx_copy')
import pyzx_copy as zx_copy

from zxreinforce.Resetters import Resetter_Circuit
from zxreinforce.own_constants import encode_phase, N_EDGE_ACTIONS, N_NODE_ACTIONS
from zxreinforce.zx_env_circuit import ZXCalculus
from zxreinforce.zx_gym_wrapper import ZXGymWrapper
from zxreinforce.ppo import PPOLightning
from pyzx_copy.utils import heuristic_fixed_pairs_node_disjoint, broken_paths_for_unrouted_pairs, to_networkx_graph

import torch

def decode_action(a, wrapper, n_node_actions=None, n_edge_actions=None):
    """Human-readable action decoding.

    Flat action layout implied by the index arithmetic below::

        [ node actions (max_nodes * n_node_actions) | edge actions (max_edges * n_edge_actions) | STOP ]

    Node action ``k`` on node ``i`` is ``i * n_node_actions + k``; edge action
    ``k`` on edge ``j`` is ``max_nodes * n_node_actions + j * n_edge_actions + k``;
    STOP is the single last index.

    Args:
        a: flat action index (anything ``int()`` accepts, e.g. a Python int or a
            0-d tensor).
        wrapper: object exposing ``max_nodes`` and ``max_edges``.
        n_node_actions: actions per node slot; defaults to ``N_NODE_ACTIONS``.
        n_edge_actions: actions per edge slot; defaults to ``N_EDGE_ACTIONS``.

    Returns:
        ``("STOP", None, None)``, ``("node", name, node_idx)`` or
        ``("edge", name, edge_idx)``. If an action id has no entry in the name
        tables below, a generic ``"node_action_<k>"`` / ``"edge_action_<k>"``
        name is used instead of failing.

    Raises:
        ValueError: if ``a`` is outside ``[0, stop_idx]``.
    """
    node_action_names = ["select_node", "unfuse_rule", "color_change_rule", "split_hadamard", "pi_rule"]
    edge_action_names = ["select_edge", "fuse_rule", "bialgebra_rule"]

    # Conditional expressions only touch the module constants when no explicit
    # value is given.
    n_na = N_NODE_ACTIONS if n_node_actions is None else n_node_actions
    n_ea = N_EDGE_ACTIONS if n_edge_actions is None else n_edge_actions

    n_node_slots = wrapper.max_nodes * n_na
    stop_idx = n_node_slots + wrapper.max_edges * n_ea

    a = int(a)
    # Without this check a negative or too-large index would be silently
    # decoded into a meaningless node/edge tuple.
    if not 0 <= a <= stop_idx:
        raise ValueError(f"action {a} out of range [0, {stop_idx}]")

    if a == stop_idx:
        return ("STOP", None, None)

    if a < n_node_slots:
        node_idx, act_idx = divmod(a, n_na)
        if act_idx < len(node_action_names):
            act_name = node_action_names[act_idx]
        else:
            act_name = f"node_action_{act_idx}"
        return ("node", act_name, int(node_idx))

    edge_idx, act_idx = divmod(a - n_node_slots, n_ea)
    if act_idx < len(edge_action_names):
        act_name = edge_action_names[act_idx]
    else:
        act_name = f"edge_action_{act_idx}"
    return ("edge", act_name, int(edge_idx))

def greedy_action(model, obs):
    """Return the highest-scoring legal action for a single observation.

    The observation is wrapped in a batch of one (``[obs]``) and converted with
    ``model._obs_to_tensors``. Policy logits are computed from the node/edge
    tensors, illegal actions (``action_mask == 0``) are suppressed, and the
    arg-max over the action dimension is returned as a plain ``int``.
    """
    with torch.no_grad():
        node_emb, edge_emb, node_mask, edge_mask, _context, action_mask = model._obs_to_tensors([obs])
        logits = model.policy(node_emb, edge_emb, node_mask, edge_mask)
        # Use the dtype's most negative finite value instead of a hard-coded -1e9.
        # -1e9 does not fit in float16, so masked_fill raises for fp16 logits.
        # In float32, a legal action whose logit is below -1e9 would lose to a
        # masked one.
        fill_value = torch.finfo(logits.dtype).min
        masked = logits.masked_fill(action_mask == 0, fill_value)
        return int(masked.argmax(dim=1).item())


# --- Random-circuit generation (consumed by Resetter_Circuit in build_eval_env) ---
num_qubits_min, num_qubits_max = 2, 3
min_gates, max_gates = 3, 6
p_t = 0.2
p_h = 0.2

# --- Training-run settings (num_envs is also passed when loading the checkpoint) ---
max_epochs = 1000
steps_per_epoch = 100
max_steps = 10000
num_envs = 16

# --- Observation padding (passed to ZXGymWrapper and to PPOLightning) ---
max_nodes = 128
max_edges = 256
max_qubits = 5

# --- ZXCalculus environment behaviour ---
env_max_steps = 300
step_penalty = 0
resetter_seed = -1      # negative -> build_eval_env passes seed=None
count_down_from = 20
dont_allow_stop = True

def build_eval_env():
    """Construct a fresh, wrapped ZX-calculus environment for evaluation.

    Builds the random-circuit resetter, the ``ZXCalculus`` environment on top
    of it, and wraps the result in ``ZXGymWrapper``. Every setting is read from
    the module-level configuration variables. A negative ``resetter_seed`` means
    no fixed seed (``seed=None``) is handed to the resetter.

    Returns:
        The ``ZXGymWrapper`` around the new environment.
    """
    if resetter_seed < 0:
        circuit_seed = None
    else:
        circuit_seed = resetter_seed

    circuit_resetter = Resetter_Circuit(
        num_qubits_min=num_qubits_min,
        num_qubits_max=num_qubits_max,
        min_gates=min_gates,
        max_gates=max_gates,
        p_t=p_t,
        p_h=p_h,
        seed=circuit_seed,
    )
    zx_env = ZXCalculus(
        max_steps=env_max_steps,
        add_reward_per_step=step_penalty,
        resetter=circuit_resetter,
        count_down_from=count_down_from,
        dont_allow_stop=dont_allow_stop,
        extra_state_info=False,
    )
    return ZXGymWrapper(
        zx_env,
        max_nodes=max_nodes,
        max_edges=max_edges,
        max_qubits=max_qubits,
    )

def dummy_env_fn():
    """Environment factory handed to ``PPOLightning`` as ``env_fn``.

    Delegates to ``build_eval_env`` so the observation shapes match the
    ``max_nodes`` / ``max_edges`` values passed to the PPOLightning constructor.
    """
    fresh_env = build_eval_env()
    return fresh_env


In [None]:

# Build one evaluation environment and reset it to a freshly generated circuit.
env = build_eval_env()
obs = env.reset()

# Index of the STOP action: the last entry of the flat action space.
stop_idx = env.action_space.n - 1
print("stop_idx", stop_idx)

# Snapshot and draw the starting ZX graph before any policy action is applied.
g = env.env.graph.copy()
zx_copy.draw(g)

In [None]:
# Checkpoint to evaluate. Other runs kept for reference:
#   runs/ZX-PPO_nq[2-3]_gates[3-6]_N128_E256_envs16_rs2048_mb512_pe4_lr0.0003_20250919-154808/checkpoints/last.ckpt
#   runs/ZX-PPO_nq[2-3]_gates[3-6]_N128_E256_envs16_rs2048_mb512_pe4_lr0.0003_20250919-154808/checkpoints/step=6400-retmean=0.000.ckpt
ckpt = "runs/ZX-PPO_nq[2-3]_gates[3-6]_N128_E256_envs16_rs2048_mb512_pe4_lr0.0003_20250919-194011/checkpoints/last.ckpt"
# Extra keyword arguments are forwarded to the PPOLightning constructor.
# Lightning also uses them to override hyper-parameters saved in the checkpoint,
# so keep these in sync with the training run.
model = PPOLightning.load_from_checkpoint(
    ckpt,
    env_fn=dummy_env_fn,
    max_nodes=max_nodes,
    max_edges=max_edges,
    gamma=0.995,
    gae_lambda=0.95,
    clip_eps=0.2,
    ent_coef=0.02,
    vf_coef=0.5,
    lr=3e-4,
    num_envs=num_envs,
)
model.eval()
model.freeze()


# Greedy rollout, continuing from the current `obs` of `env`. Re-running this
# cell without resetting `env` continues from wherever the last run stopped.
# Stops when the policy picks STOP, the env reports done, or the budget runs out.
stop_idx = env.action_space.n - 1
# Separate name so the config's training `max_steps` is not overwritten.
max_rollout_steps = 200
for step in range(max_rollout_steps):
    a = greedy_action(model, obs)
    kind, name, idx = decode_action(a, env)
    print(f"[{step:03d}] action={kind}:{name} idx={idx}")

    # If the policy picked STOP, don't call env.step to avoid auto-reset; we want the final graph
    if a == stop_idx:
        print("Policy chose STOP.")
        break

    obs, reward, done, _ = env.step(a)
    # Optional: inspect intermediate size
    print(f"   reward={reward:.3f}, |V|={env.env.n_spiders}, |E|={env.env.n_edges}")
    if done:
        # NOTE(review): if env.step auto-resets on episode end (as the STOP
        # comment above suggests), env.env.graph read below may already be the
        # next episode's graph — confirm against the env/wrapper.
        print("Env returned done=True")
        break

# Get final ZX graph
final_graph = env.env.graph.copy()
print(f"Final ZX: |V|={len(list(final_graph.vertices()))}, |E|={len(list(final_graph.edges()))}")

