In [1]:
from ray.rllib.utils.framework import try_import_tf
# try_import_tf() hands back three objects; of these, only `tf` is referenced
# by the cells shown in this notebook.
tf1, tf, tfv = try_import_tf()

# Ask TensorFlow which physical GPU devices it can see and print the list.
gpus = tf.config.list_physical_devices('GPU')
print(gpus)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
import numpy as np

from tqdm.auto import tqdm


from graphenv.examples.tsp.graph_utils import make_complete_planar_graph
from graphenv.graph_env import GraphEnv
from graphenv.examples.tsp.tsp_nfp_state import TSPNFPState
from graphenv.examples.tsp.tsp_nfp_model import TSPGNNModel

In [3]:
# N is forwarded to the graph generator as its `N` argument; the seed is fixed
# at 0 so the same graph is generated on every run.
N = 40
G = make_complete_planar_graph(N=N, seed=0)

# State object built from G; it is passed to GraphEnv as the "state" config
# entry in the cells below.
tsp_nfp_state = TSPNFPState(G, max_num_neighbors=5)

## Check the greedy search heuristic baseline

In [4]:
import networkx as nx
from networkx.algorithms.approximation.traveling_salesman import greedy_tsp
# Solve the tour with networkx's TSP wrapper, using greedy_tsp as the method.
tsp_approx = nx.approximation.traveling_salesman_problem
path = tsp_approx(G, method=greedy_tsp, cycle=True)

# The baseline reward is the negated sum of the "weight" attribute over the
# first N consecutive hops of the returned path.
reward_baseline = -sum(
    G[path[hop]][path[hop + 1]]["weight"] for hop in range(N)
)
print(f"Networkx greedy reward: {reward_baseline:1.3f}")

Networkx greedy reward: -5.987


## Initialize a model without any trained weights

In [5]:
# Environment built on the TSP state created above; the child limit is set to
# the number of nodes in G.
env = GraphEnv(
    dict(
        state=tsp_nfp_state,
        max_num_children=G.number_of_nodes(),
    )
)

# Freshly constructed base model (no weights are loaded in this cell).
model = TSPGNNModel._create_base_model(embed_dim=32, num_messages=1)

## Sample from the model's logit value predictions with a softmax

In [6]:
def sample_model():
    """Play one episode in the global ``env``, sampling actions from ``model``.

    On every step the first output of ``model`` is evaluated on the
    observation's ``'vertex_observations'``. Column 0 of rows ``1:`` is
    combined with ``action_mask[1:]``: entries whose mask is falsy are
    replaced by the dtype's minimum value, the result is passed through a
    softmax, and one action index is drawn from that distribution with
    ``np.random.choice``.

    Returns:
        The sum of the rewards returned by ``env.step`` over the episode.
    """
    env.reset()
    observation = env.make_observation()
    episode_return = 0
    finished = False

    while not finished:
        predictions = model(observation['vertex_observations'])[0]

        # Row 0 is skipped in both the mask and the predictions.
        child_mask = observation['action_mask'][1:]
        child_values = predictions[1:, 0]

        # Filling masked-out entries with the dtype minimum drives their
        # softmax probability to (effectively) zero.
        masked_values = tf.where(child_mask, child_values, predictions.dtype.min)
        probabilities = tf.nn.softmax(masked_values).numpy()

        drawn = np.random.choice(env.max_num_children, size=1, p=probabilities)
        action = drawn[0]

        observation, reward, finished, _info = env.step(action)
        episode_return += reward

    return episode_return

In [7]:
# Run sample_model() ten times (with a tqdm progress bar) and display the
# total reward returned by each run.
[sample_model() for _ in tqdm(range(10))]

  0%|          | 0/10 [00:00<?, ?it/s]

[-7.606531147266801,
 -6.833060707694988,
 -7.946361488918026,
 -7.78510613142895,
 -6.698088958210221,
 -7.476044846218038,
 -6.894692868352494,
 -7.568217203405193,
 -7.955649206270243,
 -8.039583309734674]

## Create an RLlib PPO agent

In [8]:
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog


# Register the custom model and the environment under the string names that
# the trainer config refers to.
ModelCatalog.register_custom_model('TSPGNNModel', TSPGNNModel)
register_env('GraphEnv', lambda config: GraphEnv(config))

# Settings layered on top of the PPO defaults further down in this cell.
config = dict(
    env='GraphEnv',
    env_config=dict(
        state=tsp_nfp_state,
        max_num_children=G.number_of_nodes(),
    ),
    model=dict(
        custom_model='TSPGNNModel',
        custom_model_config=dict(num_messages=1, embed_dim=32),
    ),
    num_workers=1,
    num_gpus=0,
    framework="tf2",
    eager_tracing=True,
)


from ray.rllib.agents import ppo



# Start from PPO's default settings, let the notebook's `config` override any
# top-level keys it defines, then build the trainer from the merged dict.
ppo_config = {**ppo.DEFAULT_CONFIG, **config}
agent = ppo.PPOTrainer(config=ppo_config)

2022-05-11 14:51:35,918	INFO services.py:1374 -- View the Ray dashboard at http://127.0.0.1:8265
(raylet)   if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):




2022-05-11 14:51:44,753	INFO trainable.py:125 -- Trainable.setup took 11.866 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [9]:
# Rebind `env` to a new GraphEnv built from the same state and child limit;
# sample_ppo_action() below steps this instance.
env = GraphEnv(
    dict(
        state=tsp_nfp_state,
        max_num_children=G.number_of_nodes(),
    )
)

def sample_ppo_action():
    """Roll out one full episode in the global ``env`` driven by ``agent``.

    The environment is reset, then ``agent.compute_single_action`` picks an
    action for each observation until ``env.step`` reports the episode done.

    Returns:
        The sum of the rewards returned by ``env.step`` over the episode.
    """
    observation = env.reset()
    cumulative_reward = 0

    # At least one step is always taken; stop once the env reports done.
    while True:
        chosen_action = agent.compute_single_action(observation)
        observation, step_reward, finished, _info = env.step(chosen_action)
        cumulative_reward += step_reward
        if finished:
            break

    return cumulative_reward

In [10]:
# Run sample_ppo_action() ten times (with a tqdm progress bar) and display the
# episode reward returned by each run.
[sample_ppo_action() for _ in tqdm(range(10))]

  0%|          | 0/10 [00:00<?, ?it/s]

[-6.7207132466504875,
 -7.145118287632355,
 -7.131609371651683,
 -7.388476603609273,
 -7.731130821575868,
 -8.0216944128351,
 -7.149810385352723,
 -7.355950309062479,
 -6.950356959883898,
 -7.44635546527844]

[2m[33m(raylet)[0m   if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
[2m[33m(raylet)[0m   if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
[2m[33m(raylet)[0m   if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
