# Simulating Complex Physical Phenomena

This notebook provides a complete code example for simulating complex systems of interacting particles with graph neural networks.

## Working with the SAND Dataset

Download the SAND data ...

In [None]:
from huggingface_hub import snapshot_download

# Fetch only the trajectory archives (*.npz) and the metadata (*.json) from the
# DeepTrackAI/Sand dataset repository into ./sand_dataset.
snapshot_download(
    repo_id="DeepTrackAI/Sand",
    repo_type="dataset",
    local_dir="./sand_dataset",
    allow_patterns=["*.npz", "*.json"],
)

... load the SAND data ...

In [2]:
import numpy as np

def load_npz_data(path, allow_pickle=True):
    """Load every array stored in an ``.npz`` archive.

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npz`` file.
    allow_pickle : bool, optional
        Forwarded to ``np.load``. Defaults to True, matching the previous
        behavior; object arrays can only be read with pickling enabled.

    Returns
    -------
    list
        The archive's arrays in the order listed by ``NpzFile.files``. They
        are read eagerly, so they stay usable after the file is closed.

    Notes
    -----
    SECURITY: with ``allow_pickle=True``, loading can execute arbitrary code
    embedded in the file. Only use it on archives from a trusted source, and
    pass ``allow_pickle=False`` for purely numeric archives.
    """
    with np.load(path, allow_pickle=allow_pickle) as data_file:
        # Build the list inside the `with` so every member is read before the
        # archive is closed.
        return [data_file[key] for key in data_file.files]

# One list per split, as returned by load_npz_data; entry `[i][0]` is used
# further down as the positions of trajectory i.
train_data, val_data, test_data = (
    load_npz_data("sand_dataset/train.npz"),
    load_npz_data("sand_dataset/valid.npz"),
    load_npz_data("sand_dataset/test.npz"),
)

... load the SAND metadata ...

In [3]:
import json

# JSON text is UTF-8 by specification. Pass the encoding explicitly so the read
# does not depend on the platform's default locale encoding.
with open("sand_dataset/metadata.json", "r", encoding="utf-8") as data_file:
    metadata = json.load(data_file)

... prepare a video of a SAND simulation ...

In [None]:
import matplotlib.pyplot as plt
from IPython.display import HTML
from matplotlib.animation import FuncAnimation

# Pick a random training trajectory (unseeded, so a different one each run).
sample_id = np.random.randint(0, len(train_data))
positions = train_data[sample_id][0]

fig, ax = plt.subplots(figsize=(6, 6))
scatter = ax.scatter([], [], s=50, c="y", edgecolors="k", linewidth=0.5)
# Fixed unit-square view; assumes positions lie in [0, 1] x [0, 1].
ax.set_xlim(0, 1)
ax.set_xticks([])
ax.set_ylim(0, 1)
ax.set_yticks([])


def update(frame):
    """Move the particles to their positions at time step ``frame``."""
    scatter.set_offsets(positions[frame])
    return [scatter]


ani = FuncAnimation(fig, update, frames=len(positions), interval=10, blit=True)
video = HTML(ani.to_jshtml())
# Close exactly this figure, once, so the inline backend does not also show it
# as a static image. (A second bare plt.close() could close an unrelated figure.)
plt.close(fig)

... and visualize it.

In [None]:
video  # last expression of the cell: Jupyter renders the HTML animation

## Building a Graph Network-Based Simulator

Implement the message-passing model ...

In [9]:
import deeplay as dl
import torch.nn as nn

# Message-passing model with nine hidden entries of width 64 and two output
# features per node. `simulate` treats the output as per-particle acceleration.
model = dl.GraphToNodeMPM(
    hidden_features=[64 for _ in range(9)],
    out_features=2,
)

... incorporate skip connections in the message-passing layer ...

In [10]:
# Build a residual message-passing backbone (adds skip connections) that
# reuses the current backbone's feature sizes, swap it into the model, and
# materialize the final module.
rmp_backbone = dl.ResidualMessagePassingNeuralNetwork(
    hidden_features=model.backbone.hidden_features,
    out_features=model.backbone.out_features,
    out_activation=nn.ReLU,  # passed as a class, not an instance
)
model.replace("backbone", rmp_backbone)
model = model.create()

## Building the Dataset

Implement a class to manage the dataset with particle simulations ...

In [11]:
def animate(sample_id, regressor, test_data, metadata, time_window, noise_std):
    """Show a test trajectory next to the model's rollout as an HTML animation.

    NOTE(review): `animate` is defined again in the "Simulating the System"
    section, and that later definition replaces this one on a top-to-bottom
    run. The markdown heading for this cell announces a dataset class, and
    `ParticleDataset` is used in the next cell but is not defined anywhere
    visible in this notebook. TODO: restore the intended cell here.

    Returns an IPython HTML object with "Ground Truth" (left) and
    "Simulated" (right) panels.
    """
    truth = test_data[sample_id][0]
    # `simulate` returns particles on axis 0 and time on axis 1; swap them so
    # both trajectories are indexed by frame first.
    rollout = np.transpose(
        simulate(regressor, truth, metadata, time_window, noise_std), (1, 0, 2)
    )

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    scatters = [
        panel.scatter([], [], s=50, c="y", edgecolors="k", linewidth=0.5)
        for panel in ax
    ]
    for panel, title in zip(ax, ("Ground Truth", "Simulated")):
        panel.set_xlim(0, 1)
        panel.set_xticks([])
        panel.set_ylim(0, 1)
        panel.set_yticks([])
        panel.set_title(title)

    def update(frame):
        """Draw frame ``frame`` of both trajectories."""
        for artist, trajectory in zip(scatters, (truth, rollout)):
            artist.set_offsets(trajectory[frame])
        return scatters

    ani = FuncAnimation(fig, update, frames=len(truth), interval=10, blit=True)
    video = HTML(ani.to_jshtml())
    plt.close()
    return video

... initialize the training, validation, and testing datasets ...

In [12]:
# History length (frames per input window) and noise level shared by all splits.
time_window = 6
noise_std = 3e-4

# NOTE(review): `ParticleDataset` is not defined anywhere visible in this
# notebook. TODO: add the class definition before this cell.
train_dataset, val_dataset, test_dataset = (
    ParticleDataset(split_data, metadata,
                    time_window=time_window, noise_std=noise_std)
    for split_data in (train_data, val_data, test_data)
)

... define the data loaders ...

In [None]:
# `torch_geometric.loader.DataLoader` is the current location of the graph
# batching loader; the `torch_geometric.data.DataLoader` alias is deprecated.
from torch_geometric.loader import DataLoader

# Only the training loader shuffles, so evaluation order stays deterministic.
train_loader = DataLoader(
    train_dataset, batch_size=4, shuffle=True, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=4, shuffle=False, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=4, shuffle=False, pin_memory=True
)

... and train the model.

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint

# Use the `nn` alias: `import torch.nn as nn` binds only `nn`, not `torch`, so
# `torch.nn.MSELoss()` raises NameError on a fresh kernel.
regressor = dl.Regressor(model, loss=nn.MSELoss(),
                         optimizer=dl.Adam(lr=1e-4))
regressor = regressor.create()

# Save checkpoints under ./models, tracking the logged "val_loss" metric.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", dirpath="models",
    filename="SAND-GNS-model{epoch:02d}-val_loss{val_loss:.2f}",
    auto_insert_metric_name=False,
)
trainer = dl.Trainer(max_epochs=5, callbacks=[checkpoint_callback])
trainer.fit(regressor, train_loader, val_loader)

## Loading a Pretrained Model

In [17]:
def get_device():
    """Select the device on which to perform computations.

    Returns
    -------
    torch.device
        ``cuda:0`` if CUDA is available, otherwise ``mps`` if the Apple MPS
        backend is available, otherwise ``cpu``.
    """
    # Local import: the notebook only runs `import torch.nn as nn`, which does
    # not bind the name `torch`.
    import torch

    if torch.cuda.is_available():
        return torch.device("cuda:0")
    # `torch.backends.mps` may be missing in older torch builds; guard it.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")

In [19]:
# Device used as `map_location` when the checkpoint is loaded in the next cell.
device = get_device()

In [22]:
import os

import torch  # not bound at notebook level (`import torch.nn as nn` binds only `nn`)

# Pretrained Lightning checkpoint. To load your own run instead, point this at
# a file written by the ModelCheckpoint callback in the training cell.
best_model_path = os.path.join("models", "SAND-GNS-model.ckpt")
# SECURITY: torch.load unpickles the file, so only load checkpoints from a
# trusted source. NOTE(review): newer torch versions default to
# weights_only=True; if this checkpoint stores non-tensor objects, loading may
# require weights_only=False. Confirm for your torch version.
best_model = torch.load(best_model_path, map_location=device)
regressor.load_state_dict(best_model["state_dict"]);

## Testing the Model

In [None]:
# Evaluate the regressor (with the checkpoint weights just loaded) on the test split.
trainer.test(regressor, test_loader)

## Simulating the System

Implement a function to simulate the system ...

In [24]:
def simulate(model, positions, metadata, time_window, noise_std):
    """Roll out a particle trajectory autoregressively with a trained model.

    The first ``time_window`` frames of ``positions`` seed the rollout. Every
    later frame is predicted from the model output, so the ground truth after
    the seed only determines how many frames are produced.

    Parameters
    ----------
    model : trained model (e.g. the notebook's ``regressor``)
        Put into eval mode, called on a ``torch_geometric`` ``Data`` graph,
        and expected to return one normalized acceleration per node. It must
        expose ``.device``.
    positions : np.ndarray
        Ground-truth trajectory with time on axis 0, particles on axis 1 and
        coordinates on axis 2.
    metadata : dict
        Must provide ``"acc_mean"`` and ``"acc_std"``, which undo the
        acceleration normalization.
    time_window : int
        Number of most recent frames passed to the graph builder at each step.
        Must be at least 2, because velocity is a finite difference.
    noise_std : float
        Training noise level, folded into the denormalization scale.

    Returns
    -------
    np.ndarray
        Simulated trajectory with particles on axis 0, time on axis 1 and
        coordinates on axis 2. It has as many frames as ``positions``.

    Raises
    ------
    ValueError
        If at least one rollout step is required and ``time_window < 2``.

    Notes
    -----
    Graphs are built with the notebook-level ``val_dataset.compute_graph``, so
    a ``val_dataset`` object must exist when this function is called.
    """
    # Local imports: neither `torch` nor `Data` is bound at notebook level.
    import torch
    from torch_geometric.data import Data

    model.eval()

    total_time = positions.shape[0]
    if total_time > time_window and time_window < 2:
        raise ValueError(
            "time_window must be at least 2 to estimate velocities, "
            f"got {time_window}"
        )

    # Loop-invariant scale and shift that map normalized model outputs back
    # to accelerations.
    acc_scale = (np.array(metadata["acc_std"]) ** 2 + noise_std ** 2) ** 0.5
    acc_mean = np.array(metadata["acc_mean"])

    # Preallocate the whole rollout as (particles, time, coords). Growing it
    # with np.concatenate every step re-copies the full trajectory on each
    # iteration, which is quadratic in the number of frames.
    seed = np.transpose(positions[:time_window], (1, 0, 2))
    trajectory = np.empty(
        (seed.shape[0], total_time, seed.shape[2]),
        dtype=np.result_type(positions.dtype, np.float64),
    )
    trajectory[:, :seed.shape[1]] = seed

    with torch.no_grad():
        for step in range(time_window, total_time):
            # The `time_window` frames immediately preceding `step`.
            window = trajectory[:, step - time_window:step]
            x, edge_index, edge_attr = val_dataset.compute_graph(window)

            graph = Data(
                x=torch.tensor(x, dtype=torch.float32),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.float32),
            ).to(model.device)

            acceleration = model(graph).cpu().numpy() * acc_scale + acc_mean

            # Update the velocity first, then advance the position with the
            # new velocity (semi-implicit Euler with a unit time step).
            current_position = window[:, -1]
            current_velocity = current_position - window[:, -2]
            next_velocity = current_velocity + acceleration
            trajectory[:, step] = current_position + next_velocity

    return trajectory

... implement a function to animate a simulation ...

In [27]:
def animate(sample_id, regressor, test_data, metadata, time_window, noise_std):
    """Animate a ground-truth test trajectory next to the model's rollout.

    Parameters
    ----------
    sample_id : int
        Index into ``test_data``. ``test_data[sample_id][0]`` supplies the
        ground-truth positions, indexed by time step first.
    regressor : trained model
        Passed unchanged to ``simulate``.
    test_data : sequence
        Samples as returned by ``load_npz_data``.
    metadata : dict
        Forwarded to ``simulate`` (acceleration normalization statistics).
    time_window : int
        History length, forwarded to ``simulate``.
    noise_std : float
        Forwarded to ``simulate``.

    Returns
    -------
    IPython.display.HTML
        Animation with "Ground Truth" (left) and "Simulated" (right) panels.
        It is displayed when it is the last expression of a cell.
    """
    ground_truth = test_data[sample_id][0]
    # `simulate` returns particles on axis 0 and time on axis 1. Swap them so
    # both trajectories are indexed by frame first.
    simulated = np.transpose(
        simulate(regressor, ground_truth, metadata, time_window, noise_std),
        (1, 0, 2),
    )

    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    scatters = []
    for ax, title in zip(axs, ("Ground Truth", "Simulated")):
        scatters.append(
            ax.scatter([], [], s=50, c="y", edgecolors="k", linewidth=0.5)
        )
        # Fixed unit-square view; assumes positions lie in [0, 1] x [0, 1].
        ax.set_xlim(0, 1)
        ax.set_xticks([])
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_title(title)

    def update(frame):
        """Show frame ``frame`` of both trajectories; return the changed artists."""
        scatters[0].set_offsets(ground_truth[frame])
        scatters[1].set_offsets(simulated[frame])
        return scatters

    ani = FuncAnimation(
        fig, update, frames=len(ground_truth), interval=10, blit=True
    )
    video = HTML(ani.to_jshtml())
    # Close this specific figure, not whatever figure happens to be current,
    # so the inline backend does not also display it as a static image.
    plt.close(fig)
    return video

... and run it on a trajectory from the test set.

In [None]:
animate(
    sample_id=23,
    regressor=regressor,
    test_data=test_data,
    metadata=metadata,
    time_window=time_window,
    noise_std=noise_std,
)