In [1]:
import os
import glob
from pathlib import Path
import shutil
import logging

import tensorflow as tf
from tensorflow import keras
import numpy as np
from tqdm import tqdm
import yaml

from rl.network import ResNet
from rl.mcts import MCTS
from rl.buffer import ReplayBuffer, Sample
from rl.game import Game, encode_state

logging.basicConfig(level=logging.INFO)

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

base_path = "graphs"
index = "20241201"
qubits = config["game_settings"]["N"]
training_settings = config["training_settings"]
network_settings = config["network_settings"]
mcts_settings = config["mcts_settings"]
num_cpus = training_settings["num_cpus"]
num_gpus = training_settings["num_gpus"]
n_episodes = training_settings["n_episodes"]
buffer_size = training_settings["buffer_size"]
batch_size = training_settings["batch_size"]
epochs_per_update = training_settings["epochs_per_update"]
update_period = training_settings["update_period"]
save_period = training_settings["save_period"]
eval_period = training_settings.get("eval_period", 100)


def select_action(policy, valid_actions, prev_action):
    """
    Selects an action given a policy distribution, a set of valid_actions,
    and the previous action. If prev_action is not None, exclude it if needed.
    """
    if prev_action is not None:
        # Ensure valid_actions excludes prev_action if required by game logic
        prob = policy[valid_actions]
    else:
        valid_actions = np.arange(len(policy))
        prob = policy
    prob /= np.sum(prob)
    return np.random.choice(valid_actions, p=prob)



def selfplay(weights, qubits, current_episode, config):
    """
    Runs a single self-play episode with the current network weights and returns
    the record of (state, policy, reward) samples.
    """
    record = []
    game = Game(qubits, config)
    state = game.get_initial_state()
    game.reset_used_columns()

    network = ResNet(action_space=game.action_space, config=config)
    network.set_weights(weights)  # set network weights before prediction
    # Warm-up prediction if needed (comment why if required)
    _ = network.predict(encode_state(state, qubits))

    mcts = MCTS(qubits=qubits, network=network, config=config)
    done = False
    total_score = 0
    step_count = 0
    prev_action = None

    while not done and step_count < game.MAX_STEPS:
        # MCTS search returns a policy distribution over actions
        mcts_policy = mcts.search(
            root_state=state,
            prev_action=prev_action,
            num_simulations=config["mcts_settings"]["num_mcts_simulations"]
        )

        valid_actions = game.get_valid_actions(state, prev_action) if prev_action is not None else np.arange(game.action_space)
        action = select_action(mcts_policy, valid_actions, prev_action)

        record.append(Sample(state.copy(), mcts_policy.copy(), reward=None))
        state, done, action_score = game.step(state, action, prev_action)
        prev_action = action
        total_score += action_score
        step_count += 1

    reward = game.get_reward(state, total_score)
    for sample in record:
        sample.reward = reward
    return record


def evaluate_self_play(qubits, network, config):
    """
    Evaluates the network by running it on a set of predefined states.
    Returns average depth and count metrics.
    """
    pattern = os.path.join("graphs", f"adj_matrix_{qubits}_*.npy")
    file_paths = glob.glob(pattern)

    avg_depth = []
    avg_counts = []
    for file_path in file_paths:
        state = np.load(file_path)
        game = Game(qubits, config)
        swap_pairs = []
        done = False
        step_count = 0
        prev_action = None

        while not done and step_count < game.MAX_STEPS:
            encoded_state = encode_state(state, qubits)
            policy_output, value_output = network.predict(np.expand_dims(encoded_state, axis=0))
            policy = policy_output[0]

            valid_actions = game.get_valid_actions(state, prev_action) if prev_action is not None else np.arange(game.action_space)
            if len(valid_actions) == 0:
                # If no valid actions, break or handle differently
                logging.warning("No valid actions found. Breaking.")
                break

            action = select_action(policy, valid_actions, prev_action)

            if action < len(game.coupling_map):
                selected_action = game.coupling_map[action]
                swap_pairs.append(selected_action)
            else:
                # Clarify why this branch exists. If it's needed:
                indices = action % 2
                extended_pairs = game.coupling_map[indices::2]
                swap_pairs.extend(extended_pairs)

            state, done, _ = game.step(state, action, prev_action)
            prev_action = action
            step_count += 1

        # Determine metrics
        if not done:
            depth = game.MAX_STEPS
            swap_count = game.MAX_STEPS
        else:
            # If game tracks current_layer
            game.current_layer += 1
            depth = game.current_layer
            swap_count = len(swap_pairs)

        logging.info(f"depth: {depth}, count: {swap_count}")
        avg_counts.append(swap_count)
        avg_depth.append(depth)

    return avg_depth, avg_counts

In [2]:
logdir = Path("log")
if logdir.exists():
    shutil.rmtree(logdir)
summary_writer = tf.summary.create_file_writer(str(logdir))

game = Game(qubits, config)
network = ResNet(action_space=game.action_space, config=config)

dummy_state = encode_state(game.state, qubits)
network.predict(encode_state(game.state, qubits))
current_weights = network.get_weights()

optimizer = keras.optimizers.legacy.Adam(
    learning_rate=network_settings["learning_rate"]
)

replay = ReplayBuffer(buffer_size=buffer_size)

n_updates = 0

n = 0
while n < n_episodes:
    # Collect data via self-play
    for _ in tqdm(range(update_period)):
        finished_records = selfplay(current_weights, qubits, n, config)
        replay.add_record(finished_records)
        n += 1

    # Update network if enough samples collected
    if len(replay) >= batch_size:
        num_iters = epochs_per_update * (len(replay) // batch_size)
        value_loss_weight = 0.5
        policy_loss_weight = 1.5

        for i in tqdm(range(num_iters)):
            states, mcts_policy, rewards = replay.get_minibatch(batch_size=batch_size)
            with tf.GradientTape() as tape:
                p_pred, v_pred = network(states, training=True)
                # Compute losses
                value_loss = tf.square(rewards - v_pred)
                # Using a manual cross-entropy; ensure p_pred is a probability:
                policy_loss = -tf.reduce_sum(mcts_policy * tf.math.log(p_pred + 1e-5), axis=1, keepdims=True)
                loss = tf.reduce_mean(value_loss_weight * value_loss + policy_loss_weight * policy_loss)

            grads = tape.gradient(loss, network.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, 1.0)
            optimizer.apply_gradients(zip(grads, network.trainable_variables))
            n_updates += 1

            if i % 10 == 0:
                with summary_writer.as_default():
                    tf.summary.scalar("value_loss", tf.reduce_mean(value_loss), step=n_updates)
                    tf.summary.scalar("policy_loss", tf.reduce_mean(policy_loss), step=n_updates)

        current_weights = network.get_weights()

    # Save periodically
    if n % save_period == 0:
        network.save(f"checkpoints/network{qubits}_{index}_{n}", save_format="tf")
        network.save_weights(f"checkpoints/network{qubits}_{index}_{n}.weights.h5")
        logging.info("Model saved.")

    # Evaluate periodically
    if n % eval_period == 0:
        depth, count = evaluate_self_play(qubits, network, config)
        logging.info(f"Episode {n}: SWAP depth = {np.mean(depth)}, SWAP count = {np.mean(count)}")

 95%|█████████▌| 19/20 [04:02<00:13, 13.70s/it]

In [4]:
game = Game(qubits, config)
network = ResNet(action_space=game.action_space, config=config)
network = keras.models.load_model(f"checkpoints/network{qubits}_{index}_100")





In [7]:
depths = []
for _ in range(20):
    depth, count = evaluate_self_play(qubits, network, config)
    depths.append(depth)
min_depth = np.min(np.vstack(depths), axis=0)

ValueError: in user code:

    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/engine/training.py", line 2436, in predict_function  *
        return step_function(self, iterator)
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/engine/training.py", line 2409, in run_step  *
        outputs = model.predict_step(data)
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/engine/training.py", line 2377, in predict_step  *
        return self(x, training=False)
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/engine/training.py", line 565, in error_handler  *
        del filtered_tb
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/engine/training.py", line 588, in __call__  *
        return super().__call__(*args, **kwargs)
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/engine/training.py", line 560, in error_handler  *
        filtered_tb = _process_traceback_frames(e.__traceback__)
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/engine/base_layer.py", line 1136, in __call__  *
        outputs = call_fn(inputs, *args, **kwargs)
    File "/var/folders/dn/mrzg6ww170s58lpznrdv3h_00000gn/T/__autograph_generated_file9vfy36h2.py", line 56, in error_handler  **
        raise ag__.ld(e)
    File "/var/folders/dn/mrzg6ww170s58lpznrdv3h_00000gn/T/__autograph_generated_file9vfy36h2.py", line 34, in error_handler
        retval__1 = ag__.converted_call(ag__.ld(fn), tuple(ag__.ld(args)), dict(**ag__.ld(kwargs)), fscope_1)
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/saving/legacy/saved_model/utils.py", line 65, in return_outputs_and_add_losses  **
        outputs, losses = fn(*args, **kwargs)
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/saving/legacy/saved_model/utils.py", line 190, in wrap_with_training_arg
        return control_flow_util.smart_cond(
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/utils/control_flow_util.py", line 108, in smart_cond
        return tf.__internal__.smart_cond.smart_cond(
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/saving/legacy/saved_model/utils.py", line 193, in <lambda>
        lambda: replace_training_and_call(False),
    File "/Users/kento/Develop/venv/ai_transpiler/lib/python3.12/site-packages/tf_keras/src/saving/legacy/saved_model/utils.py", line 188, in replace_training_and_call
        return wrapped_call(*new_args, **new_kwargs)

    ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (2 total):
        * <tf.Tensor 'inputs:0' shape=(None, 8, 8, 1) dtype=float32>
        * False
      Keyword arguments: {}
    
     Expected these arguments to match one of the following 2 option(s):
    
    Option 1:
      Positional arguments (2 total):
        * TensorSpec(shape=(None, 9, 9, 1), dtype=tf.float32, name='input_1')
        * True
      Keyword arguments: {}
    
    Option 2:
      Positional arguments (2 total):
        * TensorSpec(shape=(None, 9, 9, 1), dtype=tf.float32, name='input_1')
        * False
      Keyword arguments: {}


In [4]:
for _ in range(20):
    game = Game(qubits, config)
    state = game.state
    ans = []
    done = False
    total_score = 0
    step_count = 0
    prev_action = None
    print(state)
    while not done and step_count < game.MAX_STEPS:
        encoded_state = encode_state(state, qubits)
        input_state = np.expand_dims(encoded_state, axis=0)

        policy_output, value_output = network.predict(input_state)
        policy = policy_output[0]
        if prev_action is not None:
            indices = [i for i in range(game.action_space) if i != prev_action]
            prob = policy[indices]
            action = np.random.choice(indices, p=prob / prob.sum())
        else:
            indices = list(range(game.action_space))
            action = np.random.choice(indices, p=policy)
        selected_action = game.coupling_map[action]
        ans.append(selected_action)
        state, done, _ = game.step(state, action, prev_action)
        prev_action = action
        step_count += 1
    if done:
        print(f"Game finished successfully in {step_count} steps with {ans}")
    else:
        print(f"Game terminated after reaching the maximum steps ({game.MAX_STEPS}).")
        print(f"Total score: {total_score}")

[[0. 0. 0. 1. 1. 1. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 1. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 1. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0. 0. 0. 0.]
 [1. 0. 1. 0. 0. 1. 0. 0.]]


ValueError: 'a' and 'p' must have same size

In [None]:
min_depth

In [None]:
np.mean(min_depth)

array([1, 5, 6, 4, 6, 7, 7, 2, 6, 3, 2, 5, 3, 7, 5, 7, 6, 3, 5, 2, 4, 9,
       2, 6, 5, 5, 3, 2, 2, 1]) -> 4.366666666666666