In [3]:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
%matplotlib notebook
import time

from graph_nets import blocks
from graph_nets import utils_tf
from graph_nets import utils_np
from graph_nets.demos import models
from matplotlib import pyplot as plt
from matplotlib import animation
import numpy as np
import pandas as pd
import networkx as nx
import sonnet as snt
import tensorflow as tf
import os
from Simulation_functions import SpringMassSimulator, generate_trajectory, roll_out_physics
from Graph_creator_functions import base_graph


# seaborn is an optional dependency: if it can be imported, reset_orig() is
# called (presumably to undo any matplotlib restyling done by the import --
# confirm against the seaborn docs); if it is missing, nothing happens.
try:
    import seaborn as sns
except ImportError:
    pass
else:
    sns.reset_orig()

# Seed NumPy's global RNG and TensorFlow with the same value.
# NOTE(review): tf.set_random_seed presumably applies to the default graph
# that exists at this point -- confirm it still has an effect after any later
# tf.reset_default_graph() call.
SEED = 1
np.random.seed(SEED)
tf.set_random_seed(SEED)

examples.directory is deprecated; in the future, examples will be found relative to the 'datapath' directory.
  "found relative to the 'datapath' directory.".format(key))
The text.latex.unicode rcparam was deprecated in Matplotlib 2.2 and will be removed in 3.1.
  "2.2", name=key, obj_type="rcparam", addendum=addendum)


## Declare loss operations

In [4]:

def create_loss_ops_with_energy(target_op, output_ops,target_global_energy):
    """Build one combined (global + node) squared-error loss per model output.

    For every output graph the loss is the sum of two terms, each computed in
    float64 as mean(sum(diff**2, axis=-1)):

    * global term: column 2 of the output graph's `globals` against column 2
      of `target_global_energy` (presumably the energy entry -- confirm
      against the simulator);
    * node term: the output graph's `nodes` against columns 2:4 of
      `target_op` (presumably the velocity components of the node state).

    Args:
        target_op: Target node tensor; only `[..., 2:4]` is used.
        output_ops: Iterable of output graphs exposing `.nodes` and `.globals`.
        target_global_energy: Target globals tensor; only `[..., 2:3]` is used.

    Returns:
        A list of scalar loss tensors, one per entry of `output_ops`.
    """

    def _mean_squared_norm(predicted, expected):
        # Squared error summed over the feature axis, averaged over the rest.
        difference = tf.cast(predicted, tf.float64) - tf.cast(expected, tf.float64)
        return tf.reduce_mean(tf.reduce_sum(difference**2, axis=-1))

    losses = []
    for graph in output_ops:
        global_term = _mean_squared_norm(
            graph.globals[..., 2:3], target_global_energy[..., 2:3])
        node_term = _mean_squared_norm(graph.nodes, target_op[..., 2:4])
        losses.append(global_term + node_term)
    return losses


def create_loss_ops(target_op, output_ops):
    """Build one node-level squared-error loss per model output.

    Each loss compares an output graph's `nodes` with columns 2:4 of
    `target_op` (presumably the velocity components of the node state), in
    float64: the squared error is summed over the last axis and then averaged
    over all remaining entries.

    Args:
        target_op: Target node tensor; only `[..., 2:4]` is used.
        output_ops: Iterable of output graphs exposing `.nodes`.

    Returns:
        A list of scalar loss tensors, one per entry of `output_ops`.
    """
    losses = []
    for graph in output_ops:
        error = tf.cast(graph.nodes, tf.float64) - tf.cast(target_op[..., 2:4], tf.float64)
        squared_norm = tf.reduce_sum(error**2, axis=-1)
        losses.append(tf.reduce_mean(squared_norm))
    return losses





def make_all_runnable_in_session(*args):
    """Pass each positional argument through utils_tf.make_runnable_in_session.

    Returns:
        A list holding the converted graphs, in the order they were given
        (an empty list when called without arguments).
    """
    runnable_graphs = []
    for graph in args:
        runnable_graphs.append(utils_tf.make_runnable_in_session(graph))
    return runnable_graphs


# pylint: enable=redefined-outer-name

# Training the Network: Chain with Energy
## Generate:
- Model
- Training trajectory
- Generalization trajectories: 4 mass, 9 mass

In [5]:

# Start from an empty default graph so that re-running this cell does not
# stack a second copy of every op on top of the previous one.
tf.reset_default_graph()
# The graph-level random seed is stored on the default graph, so the seed set
# before the reset is discarded together with the old graph. Re-apply it here,
# otherwise TensorFlow's random ops (e.g. the sampled training time step) are
# not reproducible.
tf.set_random_seed(SEED)

# Dedicated NumPy RNG for drawing the number of masses per training graph.
rand = np.random.RandomState(SEED)

# Model parameters.
# Number of message-passing steps used for training / generalization.
num_processing_steps_tr = 1
num_processing_steps_ge = 1

# Data / training parameters.
num_training_iterations = 20000
batch_size_tr = 256
batch_size_ge = 100
num_time_steps = 50
step_size = 0.1

# Passed to rand.randint as (low, high); NumPy treats `high` as exclusive.
num_masses_min_max_tr = (5, 9)


# Create the model.
model = models.EncodeProcessDecode(node_output_size=2,edge_output_size=2,global_output_size=3)

# Draw the number of masses for every training graph in the batch.
# np.random.RandomState.randint excludes the upper bound of the given range.
num_masses_tr = rand.randint(*num_masses_min_max_tr, size=batch_size_tr)
# One static graph description per training example.
# NOTE(review): the meaning of base_graph's positional arguments
# (True, 50., 0.7, 0.) is not visible in this notebook -- see
# Graph_creator_functions.base_graph.
static_graph_tr = [base_graph(n,True, 50., 0.7, 0.) for n in num_masses_tr]
#base_graph_tr_np = utils_np.data_dicts_to_graphs_tuple([base_graph(4,True, 50., 0.7, 0.)]*)

# Pack the per-example dicts into a single batched GraphsTuple.
base_graph_tr =  utils_tf.data_dicts_to_graphs_tuple(static_graph_tr)
#base_graph_tr =  utils_tf.data_dicts_to_graphs_tuple([base_graph(4,True, 50., 0.7, 0.)] * batch_size_tr)

# Generalization batches: batch_size_ge copies of the same 4-mass graph ...
base_graph_4_ge = utils_tf.data_dicts_to_graphs_tuple(
    [base_graph(4,True, 50., 0.7, 0.)] * batch_size_ge)
# ... and of the same 9-mass graph.
# NOTE(review): the original comment here read "9 masses 0.5m apart in a chain
# like structure"; base_graph is called with 0.7, so the spacing figure may be
# stale -- confirm.
base_graph_9_ge = utils_tf.data_dicts_to_graphs_tuple(
    [base_graph(9,True, 50., 0.7, 0.)] * batch_size_ge)

# Ground-truth physics stepper used to produce the target trajectories.
simulator = SpringMassSimulator(step_size=step_size)
# Training.
# Generate a training trajectory by adding noise to initial
# position, spring constants and gravity
# generate_trajectory returns three values here: the (noisy) initial graph,
# the node trajectory, and per-step globals.
initial_conditions_tr, true_trajectory_tr, true_globals_tr_per_step = generate_trajectory(
    simulator,
    base_graph_tr,
    num_time_steps,
    step_size,
    node_noise_level=0.04,
    edge_noise_level=5.0,
    global_noise_level=1.0,
    do_set_rest=True,do_apply_gravity=True)

# Scalar int32 time index, re-sampled on every session run. For integer
# dtypes tf.random_uniform excludes maxval, so t + 1 <= num_time_steps - 1.
t = tf.random_uniform([], minval=0, maxval=num_time_steps - 1, dtype=tf.int32)
# One-step supervision: the input is the graph with node state at time t,
# the target is the node state at time t + 1.
input_graph_tr = initial_conditions_tr.replace(nodes=true_trajectory_tr[t])
target_nodes_tr = true_trajectory_tr[t + 1]
# Model outputs for the training input (consumed as an iterable of graphs by
# create_loss_ops).
output_ops_tr = model(input_graph_tr, num_processing_steps_tr)


# Test data: 4-mass string.
# Only the first return value (the noisy initial conditions) is kept; the
# trajectory and globals produced by this call are discarded.
# NOTE(review): the model, not `simulator`, is passed as the stepping function
# here -- presumably harmless because its rollout is thrown away; confirm.
initial_conditions_4_ge, _, _ = generate_trajectory(
    lambda x: model(x, num_processing_steps_ge),
    base_graph_4_ge,
    num_time_steps,
    step_size,
    node_noise_level=0.04,
    edge_noise_level=5.0,
    global_noise_level=1.0,
    do_set_rest=True,do_apply_gravity=True)

# Roll out the same initial conditions twice: once with the true simulator
# and once with the learned model, so the two rollouts can be compared.
_, true_nodes_rollout_4_ge, true_globals_4_per_step = roll_out_physics(
    simulator, initial_conditions_4_ge, num_time_steps, step_size)
_, predicted_nodes_rollout_4_ge, predicted_globals_4_per_step = roll_out_physics(
    lambda x: model(x, num_processing_steps_ge), initial_conditions_4_ge,num_time_steps, step_size)



# Test data: 9-mass string.
# Only the first return value (the noisy initial conditions) is kept; the
# trajectory and globals produced by this call are discarded.
# NOTE(review): the model, not `simulator`, is passed as the stepping function
# here -- presumably harmless because its rollout is thrown away; confirm.
initial_conditions_9_ge, _, _ = generate_trajectory(
    lambda x: model(x, num_processing_steps_ge),
    base_graph_9_ge,
    num_time_steps,
    step_size,
    node_noise_level=0.04,
    edge_noise_level=5.0,
    global_noise_level=1.0,
    do_set_rest=True,do_apply_gravity=True)

# Roll out the same initial conditions twice: once with the true simulator
# and once with the learned model, so the two rollouts can be compared.
_, true_nodes_rollout_9_ge,true_globals_9_per_step = roll_out_physics(
    simulator, initial_conditions_9_ge, num_time_steps, step_size)
_, predicted_nodes_rollout_9_ge,predicted_globals_9_per_step = roll_out_physics(
    lambda x: model(x, num_processing_steps_ge), initial_conditions_9_ge,num_time_steps, step_size)

# Training loss.
# The energy-aware variant is kept below for reference; the node-only loss is
# the one that is active.
#loss_ops_tr = create_loss_ops_with_energy(target_nodes_tr, output_ops_tr,true_globals_tr_per_step)
loss_ops_tr = create_loss_ops(target_nodes_tr, output_ops_tr)
# Training loss across processing steps.
# (Sum of the per-output losses divided by num_processing_steps_tr.)
loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr
# Test/generalization loss: 4-mass.
# Squared error between predicted and true rollouts on node columns 2:4,
# summed over the last axis and averaged over everything else.
loss_op_4_ge = tf.reduce_mean(
                tf.reduce_sum(
                    (predicted_nodes_rollout_4_ge[..., 2:4]-true_nodes_rollout_4_ge[..., 2:4])**2,axis=-1))

# Test/generalization loss: 9-mass string.
loss_op_9_ge =tf.reduce_mean(
                tf.reduce_sum(
                    (predicted_nodes_rollout_9_ge[..., 2:4] - true_nodes_rollout_9_ge[..., 2:4])**2,axis=-1))

# Optimizer.
learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(learning_rate)
# Only the one-step training loss is minimized; the rollout losses above are
# evaluation-only.
step_op = optimizer.minimize(loss_op_tr)

# make_all_runnable_in_session returns a list, so after these lines each name
# refers to a one-element list rather than to a bare graph.
input_graph_tr = make_all_runnable_in_session(input_graph_tr)
initial_conditions_4_ge = make_all_runnable_in_session(initial_conditions_4_ge)
initial_conditions_9_ge = make_all_runnable_in_session(initial_conditions_9_ge)


In [6]:
#@title Reset session  { form-width: "30%" }

# This cell resets the Tensorflow session, but keeps the same computational
# graph.

# Release the session left over from a previous run of this cell, if any.
# On the very first run `sess` does not exist yet, which raises NameError.
try:
    sess.close()
except NameError:
    pass

# Open a fresh session on the current graph and (re-)initialize all variables,
# which discards any previously trained weights.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Training bookkeeping, restarted together with the session.
last_iteration = 0
logged_iterations, losses_tr, losses_4_ge, losses_9_ge = [], [], [], []

## Training the Network

In [7]:
# How much time between logging and printing the current results.
# How much time between logging and printing the current results.
log_every_seconds = 20

print("# (iteration number), T (elapsed seconds), "
      "Ltr (training 1-step loss), "
      "Lge4 (test/generalization rollout loss for 4-mass strings), "
      "Lge9 (test/generalization rollout loss for 9-mass strings)")

start_time = time.time()
last_log_time = start_time
# Resume from `last_iteration`, so this cell can be interrupted and re-run.
for iteration in range(last_iteration, num_training_iterations):
    # One optimizer step; the remaining fetches are returned alongside the
    # loss so they can be inspected after the cell has run.
    train_values = sess.run({
        "step": step_op,
        "loss": loss_op_tr,
        "input_graph": input_graph_tr,
        "target_nodes": target_nodes_tr,
        "target_globals":true_globals_tr_per_step,
        "outputs": output_ops_tr})
    # Record progress only after the step has been applied, and point at the
    # NEXT iteration. Recording `iteration` itself made a re-run of this cell
    # repeat an already-applied step (and train one step beyond
    # num_training_iterations after a complete run).
    last_iteration = iteration + 1
    the_time = time.time()
    elapsed_since_last_log = the_time - last_log_time
    # Time-based logging: the rollout evaluation only runs every
    # `log_every_seconds` seconds, not every iteration.
    if elapsed_since_last_log > log_every_seconds:
        last_log_time = the_time
        test_values = sess.run({
            "loss_4": loss_op_4_ge,
            "true_rollout_4": true_nodes_rollout_4_ge,
            "true_globals_4": true_globals_4_per_step,
            "predicted_rollout_4": predicted_nodes_rollout_4_ge,
            "predicted_globals_4": predicted_globals_4_per_step,
            "loss_9": loss_op_9_ge,
            "true_rollout_9": true_nodes_rollout_9_ge,
            "true_globals_9": true_globals_9_per_step,
            "predicted_rollout_9": predicted_nodes_rollout_9_ge,
            "predicted_globals_9": predicted_globals_9_per_step})
        elapsed = time.time() - start_time
        # Keep the loss curves for the plots further down.
        losses_tr.append(train_values["loss"])
        losses_4_ge.append(test_values["loss_4"])
        losses_9_ge.append(test_values["loss_9"])
        logged_iterations.append(iteration)
        print("# {:05d}, T {:.1f}, Ltr {:.4f}, Lge4 {:.4f}, Lge9 {:.4f}".format(
            iteration, elapsed, train_values["loss"], test_values["loss_4"],test_values["loss_9"]))

# (iteration number), T (elapsed seconds), Ltr (training 1-step loss), Lge4 (test/generalization rollout loss for 4-mass strings), Lge9 (test/generalization rollout loss for 9-mass strings)
# 00033, T 22.0, Ltr 3.5378, Lge4 3.9037, Lge9 14.8346
# 00067, T 41.3, Ltr 1.7837, Lge4 5.9813, Lge9 12.6072
# 00103, T 61.7, Ltr 2.4057, Lge4 7.2173, Lge9 8.8781
# 00139, T 82.1, Ltr 1.6400, Lge4 5.8571, Lge9 15.2971
# 00175, T 102.6, Ltr 0.9033, Lge4 5.6127, Lge9 10.5668
# 00211, T 123.0, Ltr 2.3841, Lge4 5.8981, Lge9 5.8321
# 00247, T 143.4, Ltr 0.7259, Lge4 6.5879, Lge9 9.5665
# 00283, T 163.8, Ltr 1.4433, Lge4 4.4225, Lge9 7.4833
# 00319, T 184.1, Ltr 1.6930, Lge4 7.6906, Lge9 8.2032
# 00355, T 204.5, Ltr 0.9885, Lge4 5.3643, Lge9 9.3101
# 00391, T 224.9, Ltr 1.5715, Lge4 2.8951, Lge9 12.2117
# 00427, T 245.4, Ltr 0.1381, Lge4 6.4347, Lge9 11.5013
# 00463, T 265.7, Ltr 1.2998, Lge4 6.8088, Lge9 9.4773
# 00499, T 286.1, Ltr 0.1944, Lge4 8.1700, Lge9 12.3706
# 00535, T 306.7, Ltr 1.5370, Lge4 5.

# 05177, T 2922.1, Ltr 0.8664, Lge4 2.9340, Lge9 11.6566
# 05213, T 2942.3, Ltr 0.6618, Lge4 2.0563, Lge9 8.5253
# 05249, T 2962.6, Ltr 0.7265, Lge4 2.2771, Lge9 14.5177
# 05285, T 2982.9, Ltr 0.5447, Lge4 1.2615, Lge9 13.4713
# 05321, T 3003.1, Ltr 0.7648, Lge4 1.6380, Lge9 5.9344
# 05357, T 3023.4, Ltr 0.7832, Lge4 3.5192, Lge9 8.8268
# 05393, T 3043.7, Ltr 0.4952, Lge4 1.4690, Lge9 12.8250
# 05429, T 3063.9, Ltr 0.6624, Lge4 4.1234, Lge9 11.1902
# 05465, T 3084.2, Ltr 0.5780, Lge4 2.3662, Lge9 9.4318
# 05501, T 3104.4, Ltr 0.4303, Lge4 3.6623, Lge9 16.1990
# 05537, T 3124.7, Ltr 0.3743, Lge4 3.5826, Lge9 11.7426
# 05573, T 3145.0, Ltr 0.4592, Lge4 2.5085, Lge9 4.7099
# 05609, T 3165.2, Ltr 0.6307, Lge4 1.8167, Lge9 10.0379
# 05645, T 3185.5, Ltr 0.6804, Lge4 1.8177, Lge9 7.5906
# 05681, T 3205.8, Ltr 0.4630, Lge4 1.4686, Lge9 5.5741
# 05717, T 3226.0, Ltr 0.4652, Lge4 3.5063, Lge9 11.8213
# 05753, T 3246.3, Ltr 0.3004, Lge4 3.6554, Lge9 5.1668
# 05789, T 3266.6, Ltr 0.4864, Lge4 1.0

# 10421, T 5882.5, Ltr 0.1253, Lge4 0.6285, Lge9 11.0122
# 10457, T 5902.8, Ltr 0.1780, Lge4 1.5674, Lge9 8.9181
# 10493, T 5923.3, Ltr 0.1548, Lge4 2.3543, Lge9 19.6111
# 10529, T 5943.5, Ltr 0.3532, Lge4 1.2386, Lge9 13.9278
# 10565, T 5963.8, Ltr 0.2287, Lge4 1.5384, Lge9 10.9259
# 10601, T 5984.0, Ltr 0.4372, Lge4 1.8022, Lge9 8.8258
# 10637, T 6004.3, Ltr 0.3576, Lge4 1.2821, Lge9 14.0175
# 10673, T 6024.6, Ltr 0.3634, Lge4 5.0412, Lge9 23.6823
# 10709, T 6044.9, Ltr 0.4247, Lge4 0.6665, Lge9 14.5676
# 10745, T 6065.1, Ltr 0.3271, Lge4 0.5038, Lge9 9.8936
# 10781, T 6085.4, Ltr 0.3527, Lge4 1.0877, Lge9 10.4220
# 10817, T 6105.7, Ltr 0.3641, Lge4 0.4343, Lge9 12.0975
# 10853, T 6126.2, Ltr 0.4187, Lge4 1.3964, Lge9 14.0120
# 10889, T 6146.4, Ltr 0.1216, Lge4 2.3962, Lge9 11.1549
# 10925, T 6166.7, Ltr 0.1076, Lge4 0.9248, Lge9 8.6399
# 10961, T 6186.9, Ltr 0.2598, Lge4 0.4620, Lge9 10.6545
# 10997, T 6207.2, Ltr 0.0818, Lge4 1.2946, Lge9 23.4936
# 11033, T 6227.4, Ltr 0.3320, Lge4

# 15675, T 8841.3, Ltr 0.3333, Lge4 1.1529, Lge9 16.2767
# 15711, T 8861.6, Ltr 0.2418, Lge4 0.4827, Lge9 2.7863
# 15747, T 8882.0, Ltr 0.0717, Lge4 0.8007, Lge9 5.1163
# 15783, T 8902.3, Ltr 0.1651, Lge4 1.0840, Lge9 7.9401
# 15819, T 8922.5, Ltr 0.2856, Lge4 2.1255, Lge9 5.5088
# 15855, T 8942.8, Ltr 0.2522, Lge4 1.0254, Lge9 7.4164
# 15891, T 8963.0, Ltr 0.0974, Lge4 2.3544, Lge9 10.3806
# 15927, T 8983.3, Ltr 0.2674, Lge4 1.1117, Lge9 5.8766
# 15963, T 9003.6, Ltr 0.4556, Lge4 1.1855, Lge9 16.9986
# 15999, T 9023.9, Ltr 0.2913, Lge4 2.5360, Lge9 17.3111
# 16035, T 9044.1, Ltr 0.2715, Lge4 1.4236, Lge9 14.1685
# 16071, T 9064.4, Ltr 0.3894, Lge4 1.0482, Lge9 6.6785
# 16107, T 9084.7, Ltr 0.3624, Lge4 1.0297, Lge9 12.3459
# 16143, T 9104.9, Ltr 0.0389, Lge4 1.7926, Lge9 6.5270
# 16179, T 9125.2, Ltr 0.2244, Lge4 0.7438, Lge9 4.6889
# 16215, T 9145.5, Ltr 0.2975, Lge4 0.7662, Lge9 8.6073
# 16251, T 9165.8, Ltr 0.2229, Lge4 0.6424, Lge9 3.7023
# 16286, T 9185.9, Ltr 0.0461, Lge4 0.4254

In [10]:
# Get the weights for the model
# The string literal below is disabled code kept for reference; as a bare
# expression it has no effect when the cell runs.
"""
print(model.name_scopes)
print()
for i in model.trainable_variables[0:6]:
    print()
    print(i)
"""

# Fetch the current value of every trainable variable in the default graph.
tvars = tf.trainable_variables()
tvars_vals = sess.run(tvars)

for var, val in zip(tvars, tvars_vals):
    print(var.name, val.shape)  # Prints each variable's name alongside the shape of its value.


MLPGraphIndependent/graph_independent/edge_model/mlp/linear_0/w:0 (3, 16)
MLPGraphIndependent/graph_independent/edge_model/mlp/linear_0/b:0 (16,)
MLPGraphIndependent/graph_independent/edge_model/mlp/linear_1/w:0 (16, 16)
MLPGraphIndependent/graph_independent/edge_model/mlp/linear_1/b:0 (16,)
MLPGraphIndependent/graph_independent/edge_model/layer_norm/gamma:0 (16,)
MLPGraphIndependent/graph_independent/edge_model/layer_norm/beta:0 (16,)
MLPGraphIndependent/graph_independent/node_model/mlp/linear_0/w:0 (5, 16)
MLPGraphIndependent/graph_independent/node_model/mlp/linear_0/b:0 (16,)
MLPGraphIndependent/graph_independent/node_model/mlp/linear_1/w:0 (16, 16)
MLPGraphIndependent/graph_independent/node_model/mlp/linear_1/b:0 (16,)
MLPGraphIndependent/graph_independent/node_model/layer_norm/gamma:0 (16,)
MLPGraphIndependent/graph_independent/node_model/layer_norm/beta:0 (16,)
MLPGraphIndependent/graph_independent/global_model/mlp/linear_0/w:0 (3, 16)
MLPGraphIndependent/graph_independent/global

In [12]:
# Directory holding the saved rollout arrays (*.npy). The default is the
# original author's local path; set the CHAIN_RESULTS_DIR environment variable
# to point the notebook at a different location without editing this cell.
BASE_PATH = os.environ.get(
    "CHAIN_RESULTS_DIR",
    "/home/vabence/git_workspace/Dyadic_Collaboration/Graph_Networks/Learning_Experiments/Results/Chain_without_Energy")

def get_node_trajectories(rollout_array, batch_size):  # pylint: disable=redefined-outer-name
    """Split a batched rollout into one array per batch element.

    Keeps only the first two entries of the last axis (presumably the x/y
    positions -- confirm against the simulator's node layout) and cuts axis 1
    into `batch_size` equally sized pieces.

    Args:
        rollout_array: NumPy array with at least two axes.
        batch_size: Number of equal sections along axis 1; np.split raises
            ValueError if axis 1 is not evenly divisible.

    Returns:
        A list of `batch_size` arrays.
    """
    leading_pair = rollout_array[..., :2]
    return np.split(leading_pair, indices_or_sections=batch_size, axis=1)

def get_energy_trajectories(energy_array, batch_size):
    """Split batched per-step globals into one array per batch element.

    Keeps only index 2 of the last axis (as a length-1 axis; presumably the
    energy entry -- confirm against the simulator's globals layout) and cuts
    axis 1 into `batch_size` equally sized pieces.

    Args:
        energy_array: NumPy array with at least two axes.
        batch_size: Number of equal sections along axis 1; np.split raises
            ValueError if axis 1 is not evenly divisible.

    Returns:
        A list of `batch_size` arrays.
    """
    third_column = energy_array[..., 2:3]
    return np.split(third_column, indices_or_sections=batch_size, axis=1)

# The string literal below is disabled code kept for reference: it converts
# the last `test_values` rollouts into per-example arrays and writes them to
# BASE_PATH as .npy files. As a bare string it has no effect when the cell
# runs, so the loads further down rely on those files already existing.
"""
#Store the data
true_rollouts_4 = get_node_trajectories(test_values["true_rollout_4"],batch_size_ge)
true_trajectory_4_np = np.array(true_rollouts_4)
true_energy_4 = get_energy_trajectories(test_values["true_globals_4"],batch_size_ge)
true_energy_4_np = np.array(true_energy_4)

predicted_rollouts_4 = get_node_trajectories(test_values["predicted_rollout_4"],batch_size_ge)
predicted_trajectory_4_np=np.array(predicted_rollouts_4)
predicted_energy_4 = get_energy_trajectories(test_values["predicted_globals_4"],batch_size_ge)
predicted_energy_4_np = np.array(predicted_energy_4)

true_rollouts_9 = get_node_trajectories(test_values["true_rollout_9"],batch_size_ge)
true_trajectory_9_np = np.array(true_rollouts_9)
true_energy_9 = get_energy_trajectories(test_values["true_globals_9"],batch_size_ge)
true_energy_9_np = np.array(true_energy_9)

predicted_rollouts_9 = get_node_trajectories(test_values["predicted_rollout_9"],batch_size_ge)
predicted_trajectory_9_np=np.array(predicted_rollouts_9)
predicted_energy_9 = get_energy_trajectories(test_values["predicted_globals_9"],batch_size_ge)
predicted_energy_9_np = np.array(predicted_energy_9)

# Saving the data
np.save(os.path.join(BASE_PATH, "true_trajectory_4.npy"), true_trajectory_4_np)
np.save(os.path.join(BASE_PATH, "predicted_trajectory_4.npy"), predicted_trajectory_4_np)
np.save(os.path.join(BASE_PATH, "true_energy_4.npy"), true_energy_4_np)
np.save(os.path.join(BASE_PATH, "predicted_energy_4.npy"), predicted_energy_4_np)

np.save(os.path.join(BASE_PATH, "true_trajectory_9.npy"), true_trajectory_9_np)
np.save(os.path.join(BASE_PATH, "predicted_trajectory_9.npy"), predicted_trajectory_9_np)
np.save(os.path.join(BASE_PATH, "true_energy_9.npy"), true_energy_9_np)
np.save(os.path.join(BASE_PATH, "predicted_energy_9.npy"), predicted_energy_9_np)
"""
#Getting the data
# Load the saved arrays and keep a single entry along the first axis
# (presumably one batch element, given how the arrays are built in the
# disabled code above).
# NOTE(review): the 4-mass arrays take the LAST entry ([-1]) while the 9-mass
# arrays take the FIRST ([0]) -- confirm this asymmetry is intentional.
true_trajectory_4_np = np.load(os.path.join(BASE_PATH, "true_trajectory_4.npy"))[-1]
predicted_trajectory_4_np = np.load(os.path.join(BASE_PATH, "predicted_trajectory_4.npy"))[-1]
true_energy_4_np = np.load(os.path.join(BASE_PATH, "true_energy_4.npy"))[-1]
predicted_energy_4_np = np.load(os.path.join(BASE_PATH, "predicted_energy_4.npy"))[-1]

true_trajectory_9_np = np.load(os.path.join(BASE_PATH, "true_trajectory_9.npy"))[0]
predicted_trajectory_9_np = np.load(os.path.join(BASE_PATH, "predicted_trajectory_9.npy"))[0]
true_energy_9_np = np.load(os.path.join(BASE_PATH, "true_energy_9.npy"))[0]
predicted_energy_9_np = np.load(os.path.join(BASE_PATH, "predicted_energy_9.npy"))[0]


# Loss curves recorded during training, one panel per loss, all sharing the
# logged iteration numbers as x-axis.
fig = plt.figure(1, figsize=(18, 3))
fig.clf()
x = np.array(logged_iterations)

loss_panels = [
    ("Next step loss", losses_tr),
    ("Rollout loss: 4-mass string", losses_4_ge),
    ("Rollout loss: 9-mass string", losses_9_ge),
]
for panel_index, (title, y) in enumerate(loss_panels, start=1):
    ax = fig.add_subplot(1, 3, panel_index)
    ax.plot(x, y, "k")
    ax.set_title(title)
plt.show()

<IPython.core.display.Javascript object>

In [13]:
# Visualize trajectories for number_of_masses = 4
# Red markers show the true rollout, black markers the model's prediction.
plt.close('all')
fig_animate1 = plt.figure(1, figsize=(5, 5))
ax1 = fig_animate1.add_subplot(1, 1, 1)
ax1.set_xlim(-5, 5)
ax1.set_ylim(-5, 5)

#energy_text = ax1.text(0.02, 0.90, '', transform=ax1.transAxes)

# The trajectory arrays are indexed as [time step, node, coordinate] below.
num_nodes = true_trajectory_4_np.shape[1]

# Create every marker exactly once, up front. FuncAnimation calls init_func
# again each time the animation repeats, so creating artists inside init()
# kept adding new (empty) lines to the axes on every loop.
dots = []
dots1 = []
for i in range(num_nodes):
    # Only the first pair carries a legend label; "_nolegend_" hides the rest.
    dots.append(ax1.plot([], [], linestyle='none', marker='o', markersize=3, color="r",
                         label="Truth" if i == 0 else "_nolegend_"))
    dots1.append(ax1.plot([], [], linestyle='none', marker='o', markersize=3, color="k",
                          label="Prediction" if i == 0 else "_nolegend_"))
# Build the legend once (the original referenced ax1.legend without calling it).
legend = ax1.legend()


def init():
    """Clear all markers; safe to call any number of times."""
    for i in range(num_nodes):
        dots[i][0].set_data([], [])
        dots1[i][0].set_data([], [])
    #energy_text.set_text('')
    return dots,dots1#, energy_text


def animate(z):
    """Move every marker to its position at time step `z`."""
    for i in range(num_nodes):
        # set_data expects sequences, so wrap the scalar coordinates in lists.
        dots[i][0].set_data([true_trajectory_4_np[z,i,0]], [true_trajectory_4_np[z,i,1]])
        dots1[i][0].set_data([predicted_trajectory_4_np[z,i,0]], [predicted_trajectory_4_np[z,i,1]])
    #energy_text.set_text("true_energy = %.3f J, pred_energy = %.3f J" % (true_energy_4_np[z],predicted_energy_4_np[z]))
    return dots,dots1#,energy_text


# Keep a reference to the animation object; it stops if garbage-collected.
anim = animation.FuncAnimation(fig_animate1, animate,init_func = init, interval = 0.1 * 1000, frames=50, blit=False, repeat=True)
ax1.set_title("Trajectory for 4 nodes")
ax1.set_xlabel("x")
ax1.set_ylabel("y")
plt.show()

<IPython.core.display.Javascript object>

In [14]:
# Visualize trajectories for number_of_masses = 9
# Red markers show the true rollout, black markers the model's prediction.
plt.close('all')
fig_animate3 = plt.figure(1, figsize=(5, 5))
ax3 = fig_animate3.add_subplot(1, 1, 1)
ax3.set_xlim(-10, 10)
ax3.set_ylim(-10, 10)

#energy_text = ax3.text(0.02, 0.90, '', transform=ax3.transAxes)

# The trajectory arrays are indexed as [time step, node, coordinate] below.
num_nodes = true_trajectory_9_np.shape[1]

# Create every marker exactly once, up front. FuncAnimation calls init_func
# again each time the animation repeats, so creating artists inside init()
# kept adding new lines -- and duplicate "Truth"/"Prediction" legend entries --
# on every loop.
dots = []
dots1 = []
for i in range(num_nodes):
    # Only the first pair carries a legend label; "_nolegend_" hides the rest.
    dots.append(ax3.plot([], [], linestyle='none', marker='o', markersize=5, color="r",
                         label="Truth" if i == 0 else "_nolegend_"))
    dots1.append(ax3.plot([], [], linestyle='none', marker='o', markersize=5, color="k",
                          label="Prediction" if i == 0 else "_nolegend_"))
# Build the legend once instead of re-creating it on every frame.
legend = ax3.legend()


def init():
    """Clear all markers; safe to call any number of times."""
    for i in range(num_nodes):
        dots[i][0].set_data([], [])
        dots1[i][0].set_data([], [])
    #energy_text.set_text('')
    return dots,dots1#, energy_text


def animate(z):
    """Move every marker to its position at time step `z`."""
    for i in range(num_nodes):
        # set_data expects sequences, so wrap the scalar coordinates in lists.
        dots[i][0].set_data([true_trajectory_9_np[z,i,0]], [true_trajectory_9_np[z,i,1]])
        dots1[i][0].set_data([predicted_trajectory_9_np[z,i,0]], [predicted_trajectory_9_np[z,i,1]])
    #energy_text.set_text("true_energy = %.3f J, pred_energy = %.3f J" % (true_energy_9_np[z],predicted_energy_9_np[z]))
    return dots, dots1, legend#, energy_text


# One frame per simulated time step, played back at the simulation step size.
# Keep a reference to the animation object; it stops if garbage-collected.
anim = animation.FuncAnimation(fig_animate3, animate,init_func = init, interval = step_size * 1000, frames=num_time_steps, blit=False, repeat=True)
ax3.set_title("Trajectory for 9 nodes")
ax3.set_xlabel("x")
ax3.set_ylabel("y")
plt.show()

<IPython.core.display.Javascript object>