In [1]:
import tensorflow as tf
print(tf.__version__)

gpus = tf.config.list_physical_devices('GPU')
print(gpus)

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# Memory growth can only be configured before the GPUs are initialized; re-running
# this cell after a model has been built raises RuntimeError, so report the error
# instead of aborting the cell.
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    print(f"Could not enable GPU memory growth (devices already initialized): {err}")

2.8.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
from graphenv.examples.tsp.graph_utils import make_complete_planar_graph
from graphenv.graph_env import GraphEnv
from graphenv.examples.tsp.tsp_state import TSPState
from graphenv.examples.tsp.tsp_nfp_state import TSPNFPState

from graphenv.examples.tsp.tsp_model import TSPModel
from graphenv.examples.tsp.tsp_nfp_model import TSPGNNModel

In [3]:
# Number of nodes in the benchmark TSP graph.
# NOTE: the GraphEnv `max_num_children` and TSPModel `num_nodes` arguments further
# down also hard-code 40; keep them in sync with this value.
N_NODES = 40


def G_fn():
    """Graph factory handed to the TSP states: calls make_complete_planar_graph with N_NODES nodes."""
    return make_complete_planar_graph(N=N_NODES)


tsp_state = TSPState(G_fn)
tsp_nfp_state = TSPNFPState(G_fn, max_num_neighbors=10)

In [4]:
%%timeit
tsp_state._make_observation()

38.7 µs ± 2.37 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
%%timeit
tsp_nfp_state._make_observation()

6.61 µs ± 323 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [6]:
# Wrap each state in a GraphEnv. Both envs get the same child cap (40, the
# graph size used in G_fn), so the comparison between them is like-for-like.
tsp_genv = GraphEnv(dict(state=tsp_state, max_num_children=40))
tsp_nfp_genv = GraphEnv(dict(state=tsp_nfp_state, max_num_children=40))

In [7]:
%%timeit
tsp_genv.make_observation()

# Need to reset the state to account for observation caching
tsp_genv.state = tsp_state.new()

1.73 ms ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
%%timeit
tsp_nfp_genv.make_observation()

# Need to reset the state to account for observation caching
tsp_nfp_genv.state = tsp_nfp_state.new()

589 µs ± 480 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
# Take one observation from each environment; they are the model inputs for the
# forward-pass timings below.
input_obs, input_nfp_obs = (
    tsp_genv.make_observation(),
    tsp_nfp_genv.make_observation(),
)

In [10]:
# Build the underlying networks for each observation format. Both use
# embed_dim=256 so their forward-pass timings are comparable.
tsp_model = TSPModel._create_base_model(
    hidden_dim=256,
    embed_dim=256,
    num_nodes=40,
)
tsp_nfp_model = TSPGNNModel._create_base_model(
    num_messages=1,
    embed_dim=256,
)

In [11]:
%%timeit 

with tf.GradientTape() as tape:
    tsp_model(input_obs['vertex_observations'])

3.76 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%%timeit 

with tf.GradientTape() as tape:
    tsp_nfp_model(input_nfp_obs['vertex_observations'])

17.3 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
