# Learning to Simulate Complex Physics with Graph Neural Networks
### [*Sanchez-Gonzalez et al.*](https://github.com/google-deepmind/deepmind-research/tree/master/learning_to_simulate) *(2020)*
### Ported from TensorFlow/Sonnet + Graph Nets to TensorFlow/Keras + TensorFlow GNN.

## Connect to Google Drive

In [None]:
# Mount Google Drive (Colab only) so DRIVE_DIR can hold model weights, rollouts and logs.
# Outside Colab `google.colab` cannot be imported; skip mounting instead of failing so the
# notebook still runs top to bottom locally (the Parameters cell then uses a local directory).
try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount("/content/drive")
    DRIVE_DIR = "/content/drive/MyDrive/learning_to_simulate"

## Install dependencies

In [None]:
%pip install tensorflow_gnn

## Environment

In [1]:
import datetime, logging, os, pickle, sys

# Optional bootstrap (e.g. on Colab): clone or update the `no-batching` branch of the
# project and change into it. Uncomment when the repository is not already checked out.
# if os.path.exists("learning_to_simulate"):
#     %cd learning_to_simulate
#     !git fetch
#     !git checkout no-batching
#     !git pull origin no-batching
# else:
#     !git clone -b no-batching https://github.com/BitTrain/learning_to_simulate.git
#     %cd learning_to_simulate

# Assumes the working directory is the repository checkout: datasets and the download
# script are resolved relative to BASE_DIR, and the parent directory is put on sys.path
# so the checkout can be imported as the package `learning_to_simulate` below.
BASE_DIR = os.getcwd()
PARENT_DIR = os.path.dirname(BASE_DIR)
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)

# Set before the first `import tensorflow` below. NOTE(review): TensorFlow presumably reads
# this at import time, so changing it after TF was imported in this kernel needs a restart.
os.environ["TF_USE_LEGACY_KERAS"] = '1'  # tensorflow_gnn requires Keras v2

import tensorflow as tf
import tensorflow_gnn as tfgnn
from learning_to_simulate import utils, settings
from learning_to_simulate.models.learned_simulator import LearnedSimulator
# When True, the run cell enables tf.data debug mode before training.
# NOTE(review): assigned after LearnedSimulator is imported; if any module reads this flag
# at import time, this assignment has no effect there — confirm.
settings.TF_DEBUG_MODE = False  # Eager data, input checks

print("TensorFlow", tf.__version__)
print("TensorFlow GNN", tfgnn.__version__)

2025-06-20 21:04:48.992483: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


TensorFlow 2.16.2
TensorFlow GNN 1.0.3


## Parameters

In [2]:
import subprocess

# Datasets published by Google DeepMind for this paper. DATASET must be one of these
# unless it is already present locally at DATA_PATH.
AVAILABLE_DATASETS = (
    "WaterDrop",
    "Water",
    "Sand",
    "Goop",
    "MultiMaterial",
    "RandomFloor",
    "WaterRamps",
    "SandRamps",
    "FluidShake",
    "FluidShakeBox",
    "Continuous",
    "WaterDrop-XL",
    "Water-3D",
    "Sand-3D",
    "Goop-3D",
)
DATASET = "WaterDrop"

# Weights, rollouts and logs go to the Google Drive directory when the Colab cell mounted
# it; otherwise (e.g. running locally) to datasets/local/. Previously this unconditionally
# overwrote the Drive path, so Colab runs wrote checkpoints to ephemeral local storage.
_colab_drive_dir = globals().get("DRIVE_DIR")
if _colab_drive_dir and os.path.isdir(os.path.dirname(_colab_drive_dir)):
    DRIVE_DIR = _colab_drive_dir
else:
    DRIVE_DIR = os.path.join(BASE_DIR, "datasets", "local")

params = {
    "DATASET": DATASET,
    "DATA_PATH": os.path.join(BASE_DIR, "datasets", "deepmind", DATASET),
    "MODEL_PATH": os.path.join(DRIVE_DIR, DATASET, "weights"),
    "OUTPUT_PATH": os.path.join(DRIVE_DIR, DATASET, "rollouts"),
    "LOG_PATH": os.path.join(DRIVE_DIR, DATASET, "logs"),
    "MODE": "train",  # "train", "eval" or "rollout"
    "BATCH_SIZE": None,  # not supported in this version
    "EVAL_SPLIT": "test",
    "NUM_STEPS": 20_000_000,  # tunable
    "NOISE_STD": 3e-4,  # @S-G, p. 6
    "VELOCITY_CONTEXT_SIZE": 5,  # @S-G, p. 4
    "NUM_PARTICLE_TYPES": 9,  # hardcoded
    "STATIC_PARTICLE_ID": 3,  # hardcoded
}

if not os.path.exists(params["DATA_PATH"]):
    if DATASET not in AVAILABLE_DATASETS:
        raise ValueError(f"Unknown dataset {DATASET!r}; choose one of {AVAILABLE_DATASETS}.")
    print(f"Dataset '{DATASET}' not found at {params['DATA_PATH']}. Downloading...")
    # Argument list (no shell string) so paths containing spaces survive; check=True stops
    # the notebook if the download fails instead of continuing without data.
    subprocess.run(
        [
            "bash",
            os.path.join(BASE_DIR, "download_dataset.sh"),
            DATASET,
            os.path.dirname(params["DATA_PATH"]),
        ],
        check=True,
    )

# Create every output directory that does not exist yet.
for key, description in (
    ("MODEL_PATH", "model weights"),
    ("OUTPUT_PATH", "rollouts output"),
    ("LOG_PATH", "TensorBoard logging"),
):
    if not os.path.exists(params[key]):
        os.makedirs(params[key], exist_ok=True)
        print(f"Created {description} path {params[key]}")

print("\nParameters configured:")
for key, value in params.items():
    print(f"{key}:".ljust(22), f"{value}")
print()


Parameters configured:
DATASET:               WaterDrop
DATA_PATH:             /Users/dinbergare/Desktop/experimental/learning_to_simulate/datasets/deepmind/WaterDrop
MODEL_PATH:            /Users/dinbergare/Desktop/experimental/learning_to_simulate/datasets/local/WaterDrop/weights
OUTPUT_PATH:           /Users/dinbergare/Desktop/experimental/learning_to_simulate/datasets/local/WaterDrop/rollouts
LOG_PATH:              /Users/dinbergare/Desktop/experimental/learning_to_simulate/datasets/local/WaterDrop/logs
MODE:                  train
BATCH_SIZE:            None
EVAL_SPLIT:            test
NUM_STEPS:             20000000
NOISE_STD:             0.0003
VELOCITY_CONTEXT_SIZE: 5
NUM_PARTICLE_TYPES:    9
STATIC_PARTICLE_ID:    3



## Modes: train, eval, rollout

In [3]:
def run_train(model, metadata, timestamp, params):
    os.makedirs(params["MODEL_PATH"], exist_ok=True)
    window_length = params["VELOCITY_CONTEXT_SIZE"] + 2
    train_ds, train_size = utils.io.load_dataset(
        params["DATA_PATH"],
        split="train",
        mode="one_step_train",
        window_length=window_length,
        materialize_cache=False  # Caches all pre-processed examples in memory
    )
    valid_ds, valid_size = utils.io.load_dataset(
        params["DATA_PATH"],
        split="valid",
        mode="one_step_train",
        window_length=window_length,
        materialize_cache=False
    )
    test_ds, test_size = utils.io.load_dataset(
        params["DATA_PATH"],
        split="test",
        mode="one_step",
        window_length=window_length,
        materialize_cache=False
    )
    if all((train_size, valid_size, test_size)):
        total_size = train_size + valid_size + test_size
        print(f"\nDataset summary:")
        print(f"{'Split':<12} {'Examples':>10} {'Percent':>10}")
        print("-" * 34)
        print(f"{'Train':<12} {train_size:>10,} {train_size / total_size:>9.1%}")
        print(f"{'Valid':<12} {valid_size:>10,} {valid_size / total_size:>9.1%}")
        print(f"{'Test':<12} {test_size:>10,} {test_size / total_size:>9.1%}")
        print("-" * 34)
        print(f"{'Total':<12} {total_size:>10,} {100:>9.1f}%\n")
    try:
        for dummy in train_ds.take(1):
            inputs = dict(dummy)  # Copy to mutable dict
            inputs["positions"] = inputs["positions"][:, :-1, :]  # Remove target
            model(inputs)  # Build
        checkpoint = utils.io.get_latest_checkpoint(params["MODEL_PATH"])
        model.load_weights(checkpoint)
    except FileNotFoundError:
        print("No saved model weights. Training from scratch.")
    try:
        steps_per_epoch = 100  # tunable
        model.fit(
            train_ds,
            epochs=params["NUM_STEPS"] // steps_per_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_data=valid_ds,
            validation_steps=steps_per_epoch,
            validation_freq=10,  # tuned to training : validation ratio
            callbacks=[
                tf.keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(params["MODEL_PATH"], f"{timestamp}.weights.h5"),
                    save_weights_only=True,
                    save_best_only=True,
                    save_freq="epoch"
                ),
                tf.keras.callbacks.EarlyStopping(
                    monitor="val_loss",
                    patience=10,  # tunable
                    restore_best_weights=True,
                    verbose=1
                ),
                tf.keras.callbacks.LambdaCallback(
                    on_epoch_end=lambda epoch, logs: logs.update(
                        { "learning_rate": tf.keras.backend.get_value(model.optimizer.learning_rate) }
                    )
                ),
                tf.keras.callbacks.TensorBoard(
                    log_dir=os.path.join(params["LOG_PATH"], timestamp),
                    write_graph=False,
                    update_freq="epoch"
                )
            ]
        )
        metrics = model.evaluate(test_ds, return_dict=True, verbose=1)
        logging.info("Evaluation metrics:")
        for k, v in metrics.items():
            logging.info(f"{k}: {v:.6f}")


    except Exception as e:
        print(f"Exception occurred: {e}")
        model.save_weights(os.path.join(params["MODEL_PATH"], f"{timestamp}.crash.weights.h5"))
        print(f"Weights saved to {params['MODEL_PATH']}.")
        return

def run_eval(model, metadata, timestamp, params):
    """Evaluate the latest saved weights one step at a time on params["EVAL_SPLIT"].

    Args:
        model: the compiled LearnedSimulator (Keras model).
        metadata: dataset metadata (unused here; kept for a uniform mode signature).
        timestamp: run identifier (unused here).
        params: configuration dict (see the Parameters cell).

    Raises:
        Whatever `utils.io.get_latest_checkpoint` raises when params["MODEL_PATH"] holds no
        checkpoint (presumably FileNotFoundError, which run_train handles).
    """
    eval_ds, _ = utils.io.load_dataset(
        params["DATA_PATH"],
        split=params["EVAL_SPLIT"],
        mode="one_step",
        window_length=params["VELOCITY_CONTEXT_SIZE"] + 2
    )
    # Call the model once on a single example (target frame dropped along axis 1) so its
    # weights exist before load_weights.
    for dummy in eval_ds.take(1):
        inputs = dict(dummy)  # Copy to mutable dict
        inputs["positions"] = inputs["positions"][:, :-1, :]  # Remove target
        model(inputs)  # Build
    model.load_weights(utils.io.get_latest_checkpoint(params["MODEL_PATH"]))
    metrics = model.evaluate(eval_ds, return_dict=True)
    # print, not logging.info: the root logger defaults to WARNING and drops INFO records.
    print("Evaluation metrics:")
    for name, value in metrics.items():
        print(f"{name}: {value:.6f}")

def run_rollout(model, metadata, timestamp, params):
    """Roll out the latest saved weights over every trajectory of params["EVAL_SPLIT"].

    Each dict returned by `model.rollout` gets the dataset metadata attached under
    "metadata" and is pickled to "<OUTPUT_PATH>/rollout_<EVAL_SPLIT>_<i>.pkl", with i
    counting trajectories from 1.

    Args:
        model: the compiled LearnedSimulator (Keras model).
        metadata: dataset metadata; "sequence_length" sets the number of rollout steps.
        timestamp: run identifier (unused here).
        params: configuration dict (see the Parameters cell).
    """
    os.makedirs(params["OUTPUT_PATH"], exist_ok=True)
    rollout_ds, _ = utils.io.load_dataset(
        params["DATA_PATH"],
        split=params["EVAL_SPLIT"],
        mode="rollout"
    )
    # VELOCITY_CONTEXT_SIZE + 1 seed positions give VELOCITY_CONTEXT_SIZE velocities.
    # Call the model once on one seed window so its weights exist before load_weights.
    num_seed = params["VELOCITY_CONTEXT_SIZE"] + 1
    for dummy in rollout_ds.take(1):
        inputs = dict(dummy)  # Copy to mutable dict
        inputs["positions"] = inputs["positions"][:, :num_seed, :]  # Window
        model(inputs)  # Build
    model.load_weights(utils.io.get_latest_checkpoint(params["MODEL_PATH"]))
    # NOTE(review): with num_seed = VELOCITY_CONTEXT_SIZE + 1 seed frames, a trajectory of
    # `sequence_length` frames leaves sequence_length - num_seed targets, but this uses one
    # more. Confirm against how the dataset and `model.rollout` count frames.
    num_steps = metadata["sequence_length"] - params["VELOCITY_CONTEXT_SIZE"]
    for i, example in enumerate(rollout_ds, start=1):
        result = model.rollout(example, num_steps=num_steps)
        result["metadata"] = metadata
        filename = os.path.join(params["OUTPUT_PATH"], f"rollout_{params['EVAL_SPLIT']}_{i}.pkl")
        # print, not logging.info: the root logger defaults to WARNING and drops INFO records.
        print(f"Rollout {i} computed for {num_steps} steps. Saving to {filename}")
        with open(filename, "wb") as f:
            pickle.dump(result, f)

## Run simulation

In [4]:
def main(argv):
    """Build and compile the LearnedSimulator for the configured dataset, then run one mode.

    Args:
        argv: configuration dict (the notebook's `params`). argv["MODE"] selects
            "train", "eval" or "rollout".

    Raises:
        ValueError: if argv["MODE"] is not a supported mode. This is checked first, before
            any metadata is read or the model is built.
    """
    mode = argv["MODE"]
    if mode not in ("train", "eval", "rollout"):
        raise ValueError(f"Unknown MODE {mode!r}; expected one of 'train', 'eval', 'rollout'.")

    metadata = utils.io.load_metadata(argv["DATA_PATH"])

    model = LearnedSimulator(
        dim=metadata["dim"],
        cutoff_radius=metadata["default_connectivity_radius"],
        boundaries=metadata["bounds"],
        noise_std=argv["NOISE_STD"],
        # NOTE(review): the two noise arguments are presumably the acceleration and velocity
        # noise std; both are set to NOISE_STD — confirm against utils.io.
        normalization_stats=utils.io.get_normalization_stats(metadata, argv["NOISE_STD"], argv["NOISE_STD"]),
        num_particle_types=argv["NUM_PARTICLE_TYPES"],
        static_particle_type_id=argv["STATIC_PARTICLE_ID"],
        velocity_context_size=argv["VELOCITY_CONTEXT_SIZE"],
        # NOTE(review): architecture constants whose meaning is defined by LearnedSimulator;
        # consider moving them into `params`.
        bitwave_sizes=(9, 5, 4, 3, 2, 2, 2, 2, 1, 1, 1),
        bitqueue_size=9,
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            # S-G, p. 12: lr = 1e-4 * 1e-2 ** (step / NUM_STEPS), i.e. decays exponentially
            # from 1e-4 to 1e-6 over all training steps (can be more aggressive).
            learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=1e-4,
                decay_steps=argv["NUM_STEPS"],
                decay_rate=1e-2
            )
        )
    )

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    runners = {"train": run_train, "eval": run_eval, "rollout": run_rollout}
    runners[mode](model, metadata, timestamp, argv)

In [5]:
def _configure_tensorflow(debug_mode):
    """Apply the TensorFlow runtime options used for this notebook run.

    Raises TensorFlow's Python logger threshold to ERROR, keeps tf.function in graph
    (non-eager) mode, and turns on tf.data debug mode only when `debug_mode` is true.
    """
    tf.get_logger().setLevel(logging.ERROR)  # Suppress TF warnings
    tf.config.run_functions_eagerly(False)
    if debug_mode:
        tf.data.experimental.enable_debug_mode()


if __name__ == "__main__":
    _configure_tensorflow(settings.TF_DEBUG_MODE)
    main(params)

2025-06-20 21:05:18.283914: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:7: Filling up shuffle buffer (this may take a while): 9951 of 10000
2025-06-20 21:05:18.286497: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:480] Shuffle buffer filled.
2025-06-20 21:05:18.287264: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


ValueError: Exception encountered when calling layer 'residual_next_state' (type ResidualNextState).

A ResidualNextState() requires an update_fn whose output has the same shape as the input state, but got output shape [307, 128] vs input shape [307, 256] from single input.

Call arguments received by layer 'residual_next_state' (type ResidualNextState):
  • inputs=('tf.Tensor(shape=(307, 256), dtype=float32)', {'neighbors': 'tf.Tensor(shape=(307, 128), dtype=float32)'}, {})