## Breaking Down the Training Model: `model_fn`

The proposes of this notebook is to understand the model used during the training. The last objetivo to figure out how works `get_one_step_estimator_fn(data_path, noise_std)`

First, import the needed libraries

In [None]:
import sys
sys.path.append('../../')

from learning_to_simulate.train import *
import tensorflow as tf
import json
import os
import graph_nets as gn
import sonnet as snt

tf.compat.v1.enable_eager_execution()

Needed Functions

In [None]:
def _read_metadata(data_path):
  with open(os.path.join(data_path, 'metadata.json'), 'rt') as fp:
    return json.loads(fp.read())

## Module: `EncodeProcessDecode`

For defining the module, it is needed to use sonnet library. The sonnet abstract module, in this version, has two important function the `__init__` where all the known values are loaded, and the `_build()` which is executed when an instance of this class is called with some parameters. 

In [None]:
def build_mlp(
    hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
  """Builds an MLP."""
  return snt.nets.MLP(
      output_sizes=[hidden_size] * num_hidden_layers + [output_size])


class EncodeProcessDecode(snt.AbstractModule):
  """Encode-Process-Decode function approximator for learnable simulator."""

  def __init__(
      self,
      latent_size: int,
      mlp_hidden_size: int,
      mlp_num_hidden_layers: int,
      num_message_passing_steps: int,
      output_size: int,
      name: str = "EncodeProcessDecode"):
    """Inits the model.

    Args:
      latent_size: Size of the node and edge latent representations.
      mlp_hidden_size: Hidden layer size for all MLPs.
      mlp_num_hidden_layers: Number of hidden layers in all MLPs.
      num_message_passing_steps: Number of message passing steps.
      output_size: Output size of the decode node representations as required
        by the downstream update function.
      name: Name of the model.
    """

    super().__init__(name=name)

    self._latent_size = latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._output_size = output_size

    with self._enter_variable_scope():
      self._networks_builder()

  def _build(self, input_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Forward pass of the learnable dynamics model."""

    # Encode the input_graph.
    latent_graph_0 = self._encode(input_graph)

    # Do `m` message passing steps in the latent graphs.
    latent_graph_m = self._process(latent_graph_0)

    # Decode from the last latent graph.
    return self._decode(latent_graph_m)

  def _networks_builder(self):
    """Builds the networks."""

    def build_mlp_with_layer_norm():
      # returns a mlp sonnet module with layers 
      # = [hidden_size] * num_hidden_layers + [output_size]
      mlp = build_mlp(
          hidden_size=self._mlp_hidden_size,
          num_hidden_layers=self._mlp_num_hidden_layers,
          output_size=self._latent_size)
      #adding a final normal layer
      return snt.Sequential([mlp, snt.LayerNorm()])

    # The encoder graph network independently encodes edge and node features.
    encoder_kwargs = dict(
        edge_model_fn=build_mlp_with_layer_norm,
        node_model_fn=build_mlp_with_layer_norm)
    self._encoder_network = gn.modules.GraphIndependent(**encoder_kwargs)

    # Create `num_message_passing_steps` graph networks with unshared parameters
    # that update the node and edge latent features.
    # Note that we can use `modules.InteractionNetwork` because
    # it also outputs the messages as updated edge latent features.
    self._processor_networks = []
    for _ in range(self._num_message_passing_steps):
      self._processor_networks.append(
          gn.modules.InteractionNetwork(
              edge_model_fn=build_mlp_with_layer_norm,
              node_model_fn=build_mlp_with_layer_norm))

    # The decoder MLP decodes node latent features into the output size.
    self._decoder_network = build_mlp(
        hidden_size=self._mlp_hidden_size,
        num_hidden_layers=self._mlp_num_hidden_layers,
        output_size=self._output_size)

  def _encode(
      self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Encodes the input graph features into a latent graph."""

    # Copy the globals to all of the nodes, if applicable.
    if input_graph.globals is not None:
      broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(input_graph)
      input_graph = input_graph.replace(
          nodes=tf.concat([input_graph.nodes, broadcasted_globals], axis=-1),
          globals=None)

    # Encode the node and edge features.
    latent_graph_0 = self._encoder_network(input_graph)
    return latent_graph_0

  def _process(
      self, latent_graph_0: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Processes the latent graph with several steps of message passing."""

    # Do `m` message passing steps in the latent graphs.
    # (In the shared parameters case, just reuse the same `processor_network`)
    latent_graph_prev_k = latent_graph_0
    for processor_network_k in self._processor_networks:
      latent_graph_k = self._process_step(
          processor_network_k, latent_graph_prev_k)
      latent_graph_prev_k = latent_graph_k

    latent_graph_m = latent_graph_k
    return latent_graph_m

  def _process_step(
      self, processor_network_k: snt.Module,
      latent_graph_prev_k: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Single step of message passing with node/edge residual connections."""

    # One step of message passing.
    latent_graph_k = processor_network_k(latent_graph_prev_k)

    # Add residuals. 
    # To the new graph add residual information of nodes and edges
    # of the prev graph
    latent_graph_k = latent_graph_k.replace(
        nodes=latent_graph_k.nodes+latent_graph_prev_k.nodes,
        edges=latent_graph_k.edges+latent_graph_prev_k.edges)
    return latent_graph_k

  def _decode(self, latent_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Decodes from the latent graph."""
    return self._decoder_network(latent_graph.nodes)

In [None]:



#Define some global variables
data_path = "../../information/datasets/WaterDropSample"
model_path = "../../information/models/WaterDropSample"

noise_std = 6.7e-4
batch_size = 2
num_steps = 2
    
# To train this model, the authors uses an estimator, which is another way 
# to define a model in tensorflow; it is like a wrapper of the model
# It needs two principal components: 
#     input_fn: input function that return a tf.data.Dataset
#     model_fn: model function that defines the architecture structure.     
model_fn = get_one_step_estimator_fn(data_path, noise_std)
estimator = tf.estimator.Estimator(model_fn = model_fn,model_dir=model_path)

input_fn = get_input_fn(data_path, batch_size,
                            mode='one_step_train', split='train')

estimator.train(
      input_fn= input_fn,
      max_steps=num_steps)


Now, we are going to break down the function to understand the working of this repository

In [15]:
metadata = _read_metadata(data_path)
metadata

{'bounds': [[0.1, 0.9], [0.1, 0.9]],
 'sequence_length': 1000,
 'default_connectivity_radius': 0.015,
 'dim': 2,
 'dt': 0.0025,
 'vel_mean': [-3.964619574176163e-05, -0.00026272129664401046],
 'vel_std': [0.0013722809722366911, 0.0013119977252142715],
 'acc_mean': [2.602686518497945e-08, 1.0721623948191945e-07],
 'acc_std': [6.742962470925277e-05, 8.700719180424815e-05]}

In [22]:
latent_size=128,
hidden_size=128,
hidden_layers=2,
message_passing_steps=10
model_kwargs = dict(
  latent_size=latent_size,
  mlp_hidden_size=hidden_size,
  mlp_num_hidden_layers=hidden_layers,
  num_message_passing_steps=message_passing_steps)
print(model_kwargs)
def func_keyword_arg(**kwargs):
    print(kwargs)

func_keyword_arg(**model_kwargs)

{'latent_size': (128,), 'mlp_hidden_size': (128,), 'mlp_num_hidden_layers': (2,), 'num_message_passing_steps': 10}
{'latent_size': (128,), 'mlp_hidden_size': (128,), 'mlp_num_hidden_layers': (2,), 'num_message_passing_steps': 10}


In [198]:
import tensorflow as tf

tf.reset_default_graph()
a = tf.placeholder(dtype=tf.float32,shape=(), name='a')
d = tf.placeholder(dtype=tf.float32,shape=(), name='d')
b = tf.get_variable(name='b', initializer=tf.zeros_like(d)) #, use_resource = True)
c=a+b
b_init = tf.assign(b, d)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   
    print(sess.run([c,b_init,b], feed_dict={a:5.,d:10.})) 

[15.0, 10.0, 10.0]
