In [1]:
import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm
# Print the TensorFlow version so the saved notebook records what it ran on.
print(tf.__version__)

# List the visible GPUs and enable memory growth on each one, so TensorFlow
# allocates GPU memory as needed rather than reserving it all up front.
# Per the TF docs, this must run before any GPU has been initialized,
# otherwise set_memory_growth raises a RuntimeError.
gpus = tf.config.list_physical_devices('GPU')
print(gpus)

for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

    
from graphenv.examples.tsp.graph_utils import make_complete_planar_graph
from graphenv.graph_env import GraphEnv
from graphenv.examples.tsp.tsp_nfp_state import TSPNFPState
from graphenv.examples.tsp.tsp_nfp_model import TSPGNNModel

2.8.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# Graph size passed to make_complete_planar_graph (presumably the number of
# cities/nodes in the TSP instance — confirm against graph_utils).
N = 40
# Fixed seed, presumably so the generated instance is identical on every run.
G = make_complete_planar_graph(N=N, seed=0)

# NOTE(review): max_num_neighbors=5 presumably caps how many neighbours each
# vertex exposes in the NFP observation — TODO confirm against TSPNFPState.
tsp_nfp_state = TSPNFPState(G, max_num_neighbors=5)

## Check the greedy search heuristic baseline

In [3]:
import networkx as nx
from networkx.algorithms.approximation.traveling_salesman import greedy_tsp
tsp_approx = nx.approximation.traveling_salesman_problem

# Closed tour from networkx's greedy TSP heuristic (cycle=True: the last node
# repeats the first). Its cost is the sum over *consecutive* node pairs.
# Summing over the pairs instead of range(N) stays correct even when networkx
# expands a hop into a multi-edge shortest path and the returned tour has more
# than N + 1 nodes; range(N) would silently drop the trailing edges.
path = tsp_approx(G, cycle=True, method=greedy_tsp)
reward_baseline = -sum(G[u][v]["weight"] for u, v in zip(path, path[1:]))
print(f"Networkx greedy reward: {reward_baseline:1.3f}")

Networkx greedy reward: -5.987


## Build the environment and an untrained model (no trained weights)

In [4]:
# Environment over the TSP state; max_num_children is set to the node count of G.
env = GraphEnv(
    dict(
        state=tsp_nfp_state,
        max_num_children=G.number_of_nodes(),
    )
)

# Freshly built model with no trained weights.
# NOTE(review): `_create_base_model` is an underscore-prefixed (private) method
# called on the class; prefer a public constructor if the library offers one.
model = TSPGNNModel._create_base_model(embed_dim=32, num_messages=1)

## Roll out episodes by sampling actions from a softmax over the model's predicted action values

In [5]:
def sample_model(graph_env=None, policy_model=None, rng=None):
    """Roll out one episode, sampling each action from a masked softmax.

    At every step the model scores the current observation. The child scores
    (``values[1:, 0]``; row 0 is skipped, matching ``action_mask[1:]``) are
    masked by the action mask, normalised with a softmax, and one action is
    drawn from the resulting distribution.

    Args:
        graph_env: Environment to roll out. Defaults to the notebook-level
            ``env``. It must provide ``reset()``, ``make_observation()`` and a
            ``step(action)`` returning ``(obs, reward, done, info)``.
        policy_model: Callable mapping ``obs['vertex_observations']`` to a
            sequence whose first element is a 2-D array of scores. Defaults to
            the notebook-level ``model``.
        rng: Object with a NumPy-style ``choice(n, p=...)`` method, e.g.
            ``np.random.default_rng(seed)`` for reproducible rollouts.
            Defaults to the global ``np.random`` state.

    Returns:
        The summed reward of the episode.

    Raises:
        ValueError: If an observation has no valid (unmasked) action.
    """
    graph_env = env if graph_env is None else graph_env
    policy_model = model if policy_model is None else policy_model
    rng = np.random if rng is None else rng

    graph_env.reset()
    obs = graph_env.make_observation()
    done = False
    total_reward = 0

    while not done:
        values = np.asarray(policy_model(obs["vertex_observations"])[0], dtype=np.float64)
        child_values = values[1:, 0]
        valid = np.asarray(obs["action_mask"], dtype=bool)[1:]
        if not valid.any():
            # A dtype.min sentinel would silently yield a uniform distribution
            # over invalid actions here; fail loudly instead.
            raise ValueError("no valid action in the current observation")

        # Masked softmax in float64: masked actions get probability exactly 0,
        # and the vector handed to choice() sums to 1 at double precision.
        logits = np.where(valid, child_values, -np.inf)
        weights = np.exp(logits - logits.max())
        probabilities = weights / weights.sum()

        # Sample over the actual number of scored children rather than
        # assuming it equals graph_env.max_num_children.
        action = int(rng.choice(probabilities.size, p=probabilities))
        obs, reward, done, _info = graph_env.step(action)
        total_reward += reward

    return total_reward

In [6]:
# Roll out the untrained policy a few times. Keep the rewards in a named
# variable (rather than relying on Out[6] / `_`, which break on re-run) and
# compare them against the greedy baseline computed above.
num_rollouts = 10
untrained_rewards = [sample_model() for _ in tqdm(range(num_rollouts), desc="rollouts")]
print(
    f"Untrained policy mean reward: {np.mean(untrained_rewards):1.3f} "
    f"(networkx greedy baseline: {reward_baseline:1.3f})"
)
untrained_rewards

  0%|          | 0/10 [00:00<?, ?it/s]

[-7.606531147266801,
 -6.833060707694988,
 -7.946361488918026,
 -7.78510613142895,
 -6.698088958210221,
 -7.476044846218038,
 -6.894692868352494,
 -7.568217203405193,
 -7.955649206270243,
 -8.039583309734674]