In [1]:
from ray.rllib.utils.framework import try_import_tf
# RLlib helper that hands back the TensorFlow handles unpacked here; only `tf`
# is used in the rest of this notebook.
tf1, tf, tfv = try_import_tf()

# Show which GPU devices TensorFlow can see in this kernel.
gpus = tf.config.list_physical_devices('GPU')
print(gpus)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
import numpy as np

from tqdm.auto import tqdm


from graphenv.examples.tsp.graph_utils import make_complete_planar_graph
from graphenv.graph_env import GraphEnv
from graphenv.examples.tsp.tsp_nfp_state import TSPNFPState
from graphenv.examples.tsp.tsp_nfp_model import TSPGNNModel

In [3]:
# Number of nodes requested from make_complete_planar_graph for every TSP instance.
N = 40

# Root state shared by every GraphEnv built below. It receives a zero-argument
# graph factory rather than a graph instance.
# NOTE(review): presumably the factory is re-invoked to draw a new random graph
# per episode, and max_num_neighbors caps the edges kept per node -- confirm
# against TSPNFPState.
tsp_nfp_state = TSPNFPState(lambda: make_complete_planar_graph(N), max_num_neighbors=5)

## Check the greedy search heuristic baseline

In [4]:
import networkx as nx
from networkx.algorithms.approximation.traveling_salesman import greedy_tsp
# Short alias for networkx's TSP driver; the heuristic is selected via `method`.
tsp_approx = nx.approximation.traveling_salesman_problem

# Solve one freshly generated instance with the greedy heuristic. With
# cycle=True the returned path comes back to its starting node.
G = make_complete_planar_graph(N)
path = tsp_approx(G, cycle=True, method=greedy_tsp)

# The reward is the negative tour length. Walk the consecutive node pairs that
# are actually present in `path` instead of assuming it has exactly N edges, so
# the total stays correct whatever length the solver returns.
reward_baseline = -sum(G[u][v]["weight"] for u, v in zip(path, path[1:]))
print(f"Networkx greedy reward: {reward_baseline:1.3f}")

Networkx greedy reward: -6.065


## Initialize a model without any trained weights

In [5]:
# Environment over the shared TSP state; the action space is sized to the
# number of nodes in the baseline graph G.
env = GraphEnv(dict(state=tsp_nfp_state, max_num_children=G.number_of_nodes()))

# Bare GNN with freshly initialised (untrained) weights, using the same
# hyperparameters that are passed to the RLlib custom model further down.
model = TSPGNNModel._create_base_model(embed_dim=32, num_messages=1)

## Sample from the model's logit value predictions with a softmax

In [6]:
def sample_model():
    """Play one episode by sampling actions from the module-level ``model``.

    At every step the model's outputs for the candidate actions are masked with
    the observation's action mask, turned into probabilities with a softmax,
    and one action is drawn at random from that distribution.

    Uses the notebook-level ``env`` and ``model`` objects.

    Returns:
        The sum of the rewards returned by ``env.step`` over the episode.
    """
    env.reset()
    observation = env.make_observation()
    episode_return = 0
    finished = False

    while not finished:
        predictions = model(observation['vertex_observations'])[0]

        # Row 0 is skipped in both the mask and the predictions.
        # NOTE(review): presumably row 0 is the current (parent) state and the
        # remaining rows are its children -- confirm against GraphEnv.
        valid_actions = observation['action_mask'][1:]
        child_values = predictions[1:, 0]

        # Invalid actions get the dtype's minimum so the softmax assigns them
        # essentially zero probability.
        masked_values = tf.where(valid_actions, child_values, predictions.dtype.min)
        probabilities = tf.nn.softmax(masked_values).numpy()

        action = np.random.choice(env.max_num_children, size=1, p=probabilities)[0]
        observation, step_reward, terminated, truncated, _ = env.step(action)
        episode_return += step_reward
        finished = terminated or truncated

    return episode_return

In [7]:
# Episode rewards from 10 independent rollouts of the untrained model, for
# comparison with the greedy baseline printed above.
[sample_model() for _ in tqdm(range(10))]

  0%|          | 0/10 [00:00<?, ?it/s]

[-7.670711860667094,
 -8.255356356799277,
 -8.231026465888045,
 -7.671346259475466,
 -8.335099720680878,
 -7.620007661195499,
 -7.433247461280753,
 -8.283621356137186,
 -7.445996696252394,
 -7.3782104001350906]

## Create an RLlib agent

In [8]:
# Start a new episode and check that the observation it returns lies inside
# the environment's declared observation space.
obs, info = env.reset()
env.observation_space.contains(obs)

True

In [9]:
# Same containment check, narrowed to the action-mask component.
env.observation_space['action_mask'].contains(obs['action_mask'])

True

In [10]:
# Check every entry of the vertex observations against its own sub-space; on
# failure the assertion message is the offending key.
for key in env.observation_space['vertex_observations'].keys():
    assert env.observation_space['vertex_observations'][key].contains(obs['vertex_observations'][key]), key

In [11]:
# Shape declared by the observation space for the connectivity array.
env.observation_space['vertex_observations']['connectivity'].shape

(41, 200, 2)

In [12]:
# Shape of the connectivity array in the actual observation; it should match
# the shape declared by the observation space.
obs['vertex_observations']['connectivity'].shape

(41, 200, 2)

In [13]:
# NOTE(review): `ray.rllib.agents` / `PPOTrainer` is the legacy RLlib entry
# point; newer Ray releases expose PPO under `ray.rllib.algorithms` -- confirm
# against the Ray version this notebook is pinned to.
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env

# Register the custom model and the environment under the string names that
# the config below refers to.
ModelCatalog.register_custom_model('TSPGNNModel', TSPGNNModel)
register_env('GraphEnv', lambda env_config: GraphEnv(env_config))

# Settings layered on top of PPO's defaults. The environment and model
# hyperparameters match the ones used for the untrained model above.
config = {
    # Environment: looked up by its registered name and built from env_config.
    "env": 'GraphEnv',
    "env_config": {
        "state": tsp_nfp_state,
        "max_num_children": G.number_of_nodes(),
    },
    # Policy network: the registered custom GNN model.
    "model": {
        "custom_model": 'TSPGNNModel',
        "custom_model_config": {"num_messages": 1, "embed_dim": 32},
    },
    # Resources and execution framework.
    "num_workers": 1,
    "num_gpus": 0,
    "framework": "tf2",
    "eager_tracing": True,
}

# Start from a copy of the PPO defaults, apply the overrides, and build the
# trainer from the result.
ppo_config = ppo.DEFAULT_CONFIG.copy()
ppo_config.update(config)
agent = ppo.PPOTrainer(config=ppo_config)

2022-05-18 08:15:42,192	INFO services.py:1374 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
[2m[33m(raylet)[0m   if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):




2022-05-18 08:15:51,353	INFO trainable.py:125 -- Trainable.setup took 12.115 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [14]:
# Fresh local environment used to roll out the PPO agent, configured with the
# same state and action-space size as the agent's own environment config.
env = GraphEnv(dict(state=tsp_nfp_state, max_num_children=G.number_of_nodes()))

def sample_ppo_action():
    """Play one episode with the PPO ``agent`` and return its total reward.

    Uses the notebook-level ``env`` and ``agent`` objects. Actions come from
    ``agent.compute_single_action`` and the episode stops as soon as
    ``env.step`` reports it as terminated or truncated.

    Returns:
        The sum of the per-step rewards over the episode.
    """
    observation, _ = env.reset()
    total = 0

    while True:
        action = agent.compute_single_action(observation)
        observation, reward, terminated, truncated, _ = env.step(action)
        total += reward
        if terminated or truncated:
            return total

In [15]:
# Episode rewards from 10 rollouts of the PPO agent. No training step has been
# run in this notebook, so these come from the freshly initialised policy.
[sample_ppo_action() for _ in tqdm(range(10))]

  0%|          | 0/10 [00:00<?, ?it/s]

[-7.470961114058433,
 -6.770144662767127,
 -7.407155430808889,
 -8.720809002028735,
 -7.265972881233817,
 -7.404054532323574,
 -8.361671640332627,
 -7.584063264209201,
 -8.228413557010217,
 -6.924827091584273]

[2m[33m(raylet)[0m   if LooseVersion(aiohttp.__version__) < LooseVersion("4.0.0"):
2022-05-18 08:16:02,630	ERROR worker.py:488 -- print_logs: Connection closed by server.
2022-05-18 08:16:02,631	ERROR import_thread.py:83 -- ImportThread: Connection closed by server.
2022-05-18 08:16:02,639	ERROR worker.py:1259 -- listen_error_messages_raylet: Connection closed by server.
*** SIGTERM received at time=1652883364 on cpu 5 ***
PC: @     0x7f175c8f1eb3  (unknown)  epoll_wait
    @     0x7f175d4e1630  (unknown)  (unknown)
[2022-05-18 08:16:04,917 E 27227 27227] logging.cc:317: *** SIGTERM received at time=1652883364 on cpu 5 ***
[2022-05-18 08:16:04,917 E 27227 27227] logging.cc:317: PC: @     0x7f175c8f1eb3  (unknown)  epoll_wait
[2022-05-18 08:16:04,917 E 27227 27227] logging.cc:317:     @     0x7f175d4e1630  (unknown)  (unknown)
