In [0]:
#@title ##### License
# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Physical dynamics of a mass-spring system
This notebook and the accompanying code demonstrates how to use the Graph Nets library to learn to predict the motion of a set of masses connected by springs.

The network is trained to predict the behaviour of a chain of five masses, connected by identical springs. The first and last masses are fixed; the others are subject to gravity.

After training, the network's prediction ability is illustrated by comparing its output to the true behaviour of the structure. Then the network's ability to generalise is tested, by using it to predict the behaviour of a similar but more complicated mass/spring structure.

In [2]:
#@title ### Install the Graph Nets library on this Colaboratory runtime  { form-width: "60%", run: "auto"}
#@markdown <br>1. Connect to a local or hosted Colaboratory runtime by clicking the **Connect** button at the top-right.<br>2. Choose "Yes" below to install the Graph Nets library on the runtime machine with:<br> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;```pip install graph_nets```<br> Note, this works both with local and hosted Colaboratory runtimes.

install_graph_nets_library = "Yes"  #@param ["Yes", "No"]

if install_graph_nets_library.lower() == "yes":
  print("Installing Graph Nets library with:")
  print("  $ pip install graph_nets\n")
  print("Output message from command:\n")
  !pip install graph_nets
else:
  print("Skipping installation of Graph Nets library")

Installing Graph Nets library with:
  $ pip install graph_nets

Output message from command:



### Install dependencies locally

If you are running this notebook locally (i.e., not through Colaboratory), you will also need to install a few more dependencies. Run the following on the command line to install the graph networks library, as well as a few other dependencies:

```
pip install graph_nets matplotlib scipy
```

# Code

In [3]:
#@title Imports  { form-width: "30%" }

# The demo dependencies are not installed with the library, but you can install
# them with:
#
# $ pip install jupyter matplotlib scipy
#
# Run the demo with:
#
# $ jupyter notebook <path>/<to>/<demos>/shortest_path.ipynb

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from graph_nets import blocks
from graph_nets import utils_tf
from graph_nets.demos import models
from matplotlib import pyplot as plt
import numpy as np
import sonnet as snt
import tensorflow as tf

try:
  import seaborn as sns
except ImportError:
  pass
else:
  sns.reset_orig()

# SEED = 1
# np.random.seed(SEED)
# tf.set_random_seed(SEED)


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [0]:
def base_graph(n, d):
  nodes = np.zeros((n, 5), dtype=np.float32)
  half_width = d * n / 2.0
  nodes[:, 0] = np.linspace(
      -half_width, half_width, num=n, endpoint=False, dtype=np.float32)
  # indicate that the first and last masses are fixed
  nodes[(0, -1), -1] = 1.

  # Edges.
  edges, senders, receivers = [], [], []
  for i in range(n - 1):
    left_node = i
    right_node = i + 1
    # The 'if' statements prevent incoming edges to fixed ends of the string.
    if right_node < n - 1:
      # Left incoming edge.
      edges.append([50., d])
      senders.append(left_node)
      receivers.append(right_node)
    if left_node > 0:
      # Right incoming edge.
      edges.append([50., d])
      senders.append(right_node)
      receivers.append(left_node)

  
  return {
      "globals": [0., -10.],
      "nodes": nodes,
      "edges": edges,
      "receivers": list(receivers),
      "senders": list(senders)
  }


def permute_graphs(graphs):
  
  # Permutated graphs
  perms = []
  
  # Generate n permutations
  for g in graphs:
    # Get number of nodes
    n = len(g["nodes"])
    # Generate random permutation
    p = np.array(np.random.permutation(n))
    # Append permutated graph to list
    perms.append({
      "globals": [0., -10.],
      "nodes": g["nodes"][p],
      "edges": g["edges"],
      "receivers": list(p[np.array(g["senders"])]),
      "senders": list(p[np.array(g["receivers"])])
    })
    
  return perms

In [11]:
graphs = [base_graph(n=4, d=4.0)]
print(graphs)
perms = permute_graphs(graphs)
print(perms)

[{'nodes': array([[-8.,  0.,  0.,  0.,  1.],
       [-4.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.],
       [ 4.,  0.,  0.,  0.,  1.]], dtype=float32), 'globals': [0.0, -10.0], 'senders': [0, 1, 2, 3], 'edges': [[50.0, 4.0], [50.0, 4.0], [50.0, 4.0], [50.0, 4.0]], 'receivers': [1, 2, 1, 2]}]
[{'nodes': array([[ 0.,  0.,  0.,  0.,  0.],
       [-4.,  0.,  0.,  0.,  0.],
       [ 4.,  0.,  0.,  0.,  1.],
       [-8.,  0.,  0.,  0.,  1.]], dtype=float32), 'globals': [0.0, -10.0], 'senders': [1, 3, 1, 3], 'edges': [[50.0, 4.0], [50.0, 4.0], [50.0, 4.0], [50.0, 4.0]], 'receivers': [2, 1, 3, 0]}]
