In [1]:
import time

from absl import app
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import numpy as np

In [2]:
def conway_mlp(x):
  """Implements a MLP representing Conway's game of life rules."""
  w = jnp.array([[0.0, -1.0], [0.0, 1.0], [0.0, 1.0],
                 [0, -1.0], [1.0, 1.0], [1.0, 1.0]])
  b = jnp.array([3.5, -3.5, -1.5, 1.5, -2.5, -3.5])
  h = jnp.maximum(jnp.dot(w, x) + b, 0.)
  w = jnp.array([[2.0, -4.0, 2.0, -4.0, 2.0, -4.0]])
  b = jnp.array([-4.0])
  y = jnp.maximum(jnp.dot(w, h) + b, 0.0)
  return y

In [3]:
def conway_graph(size) -> jraph.GraphsTuple:
  """Returns a graph representing the game field of conway's game of life."""
  # Creates nodes: each node represents a cell in the game.
  n_node = size**2
  nodes = np.zeros((n_node, 1))
  node_indices = jnp.arange(n_node)
  # Creates edges, senders and receivers:
  # the senders represent the connections to the 8 neighboring fields.
  n_edge = 8 * n_node
  edges = jnp.zeros((n_edge, 1))
  senders = jnp.vstack(
      [node_indices - size - 1, node_indices - size, node_indices - size + 1,
       node_indices - 1, node_indices + 1,
       node_indices + size - 1, node_indices + size, node_indices + size + 1])
  senders = senders.T.reshape(-1)
  senders = (senders + size**2) % size**2
  receivers = jnp.repeat(node_indices, 8)
  # Adds a glider to the game
  nodes[0, 0] = 1.0
  nodes[1, 0] = 1.0
  nodes[2, 0] = 1.0
  nodes[2 + size, 0] = 1.0
  nodes[1 + 2 * size, 0] = 1.0
  return jraph.GraphsTuple(n_node=jnp.array([n_node]),
                           n_edge=jnp.array([n_edge]),
                           nodes=jnp.asarray(nodes),
                           edges=edges,
                           globals=None,
                           senders=senders,
                           receivers=receivers)


In [4]:
def display_graph(graph: jraph.GraphsTuple):
  """Prints the nodes of the graph representing Conway's game of life."""
  size = int(np.sqrt(np.sum(graph.n_node)))


In [21]:
def display_graph(graph, size):
    def _display_node(node):
        if node == 1.0:
            return 'x'
        else:
            return ' '

    nodes = graph.nodes.copy()
    output = ''.join(_display_node(nodes[i * size + j]) for j in range(size) for i in range(size))
    print('-' * size + '\n' + output)

# Call the display_graph function with your graph and size parameters
# For example: display_graph(your_graph, your_size)


In [17]:
def main(_):

  def net_fn(graph: jraph.GraphsTuple):
    unf = jraph.concatenated_args(conway_mlp)
    net = jraph.InteractionNetwork(
        update_edge_fn=lambda e, n_s, n_r: n_s,
        update_node_fn=jax.vmap(unf))
    return net(graph)

  net = hk.without_apply_rng(hk.transform(net_fn))

  cg = conway_graph(size=20)
  params = net.init(jax.random.PRNGKey(42), cg)
  for _ in range(100):
    time.sleep(0.05)
    cg = jax.jit(net.apply)(params, cg)
    display_graph(cg)

In [22]:
if __name__ == '__main__':
  app.run(main)

FATAL Flags parsing error: Unknown command line flag 'f'
Pass --helpshort or --helpfull to see help on flags.


AttributeError: 'tuple' object has no attribute 'tb_frame'