# Training on 2D Cylinder Flow Prediction

- Remember to select a GPU runtime (L4 or T4 will be enough)

# Imports

In [2]:
import os

import torch

def format_pytorch_version(version):
  """Return the part of a version string that precedes the first '+'.

  Args:
      version (str): A version string such as ``"2.4.0+cu121"``.

  Returns:
      str: The text before the first ``'+'``; the whole string if there is none.
  """
  release, _, _ = version.partition('+')
  return release

# Version string reported by the installed torch, and the same string with
# anything after the first '+' stripped (used to build the wheel index URL).
TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  """Turn a CUDA version into the tag used in PyG wheel index URLs.

  Args:
      version (str | None): CUDA version such as ``"12.1"``, or ``None`` when the
          installed PyTorch build has no CUDA support.

  Returns:
      str: ``"cu"`` followed by the version without dots (e.g. ``"cu121"``), or
      ``"cpu"`` when ``version`` is ``None``.
  """
  # torch.version.cuda is None on CPU-only builds; calling .replace on it would raise.
  if version is None:
    return 'cpu'
  return 'cu' + version.replace('.', '')

# CUDA version reported by the installed torch build, and the matching tag that is
# interpolated into the wheel index URLs of the install cell.
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

In [4]:
# Derive the wheel-index release from the installed torch instead of hard-coding it:
# keep major.minor and force the patch component to 0 (e.g. "2.4.1+cu121" -> "2.4.0").
# NOTE(review): this assumes the PyG wheel index always has an x.y.0 page for a
# supported minor release — confirm against the PyG installation notes.
TORCH = ".".join(torch.__version__.split("+")[0].split(".")[:2] + ["0"])

In [5]:
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_scatter-2.1.2%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt24cu121
Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_sparse-0.6.18%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (5.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt24cu121
Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
Collecting torch-cluster
 

In [6]:
!pip install loguru==0.7.2
!pip install autoflake==2.3.0
!pip install pytest==8.0.1
!pip install meshio==5.3.5
!pip install h5py==3.10.0

Collecting loguru==0.7.2
  Downloading loguru-0.7.2-py3-none-any.whl.metadata (23 kB)
Downloading loguru-0.7.2-py3-none-any.whl (62 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/62.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.5/62.5 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: loguru
Successfully installed loguru-0.7.2
Collecting autoflake==2.3.0
  Downloading autoflake-2.3.0-py3-none-any.whl.metadata (7.6 kB)
Collecting pyflakes>=3.0.0 (from autoflake==2.3.0)
  Downloading pyflakes-3.2.0-py2.py3-none-any.whl.metadata (3.5 kB)
Downloading autoflake-2.3.0-py3-none-any.whl (32 kB)
Downloading pyflakes-3.2.0-py2.py3-none-any.whl (62 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyflakes, autoflake
Successfully installed autoflake-2.3.0 pyflakes-3.2.0
Coll

# Fetching the dataset

In [7]:
!wget -O "cylinder.zip" "https://tinyurl.com/cylinder-idsc"
!unzip "cylinder.zip"

--2024-09-25 12:08:29--  https://tinyurl.com/cylinder-idsc
Resolving tinyurl.com (tinyurl.com)... 104.18.111.161, 104.17.112.233, 2606:4700::6811:70e9, ...
Connecting to tinyurl.com (tinyurl.com)|104.18.111.161|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://storage.googleapis.com/large-physics-model/datasets/cylinder/cylinder.zip [following]
--2024-09-25 12:08:29--  https://storage.googleapis.com/large-physics-model/datasets/cylinder/cylinder.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 172.253.118.207, 74.125.200.207, 74.125.130.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|172.253.118.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 91952009 (88M) [application/zip]
Saving to: ‘cylinder.zip’


2024-09-25 12:08:40 (8.91 MB/s) - ‘cylinder.zip’ saved [91952009/91952009]

Archive:  cylinder.zip
   creating: large-physics-model/data/
  inflating: large-physics-model

# Dataset

## First look

In [8]:
import pickle
# Load a single pickled graph sample for inspection.
# NOTE(review): pickle.load can execute arbitrary code — only use it on files from a
# trusted source.
with open("large-physics-model/data/cylinder_10.pkl", "rb") as openfile:
  graph = pickle.load(openfile)

In [9]:
# Display the repr of the loaded sample (attribute names with their shapes).
graph

Data(x=[1923, 4], edge_index=[2, 11070], edge_attr=[11070, 3], y=[1923, 2], pos=[1923, 2])

In [10]:
# Node feature matrix.
# NOTE(review): the columns look like [velocity x, velocity y, node type, time step]
# — TODO confirm against the dataset description.
print(graph.x.shape)
graph.x

torch.Size([1923, 4])


tensor([[ 3.3412e-01,  0.0000e+00,  4.0000e+00,  1.0000e-01],
        [ 3.0082e-01, -6.4449e-04,  0.0000e+00,  1.0000e-01],
        [ 1.6676e-01,  0.0000e+00,  4.0000e+00,  1.0000e-01],
        ...,
        [ 0.0000e+00,  0.0000e+00,  6.0000e+00,  1.0000e-01],
        [ 3.3250e-01,  2.9416e-02,  5.0000e+00,  1.0000e-01],
        [ 0.0000e+00,  0.0000e+00,  6.0000e+00,  1.0000e-01]], device='cuda:0')

In [11]:
# Edge connectivity as a 2-row index tensor; NodeBlock aggregates edge features onto
# the node indices stored in row 1.
print(graph.edge_index.shape)
graph.edge_index

torch.Size([2, 11070])


tensor([[   0,    0,    0,  ..., 1922, 1922, 1922],
        [   1,    2,    6,  ..., 1919, 1920, 1921]], device='cuda:0')

In [12]:
# Edge feature matrix (3 values per edge).
# NOTE(review): presumably [dx, dy, edge length] — TODO confirm.
print(graph.edge_attr.shape)
graph.edge_attr

torch.Size([11070, 3])


tensor([[-0.0123, -0.0016,  0.0124],
        [ 0.0000, -0.0082,  0.0082],
        [ 0.0000,  0.0148,  0.0148],
        ...,
        [ 0.0184,  0.0035,  0.0187],
        [ 0.0184,  0.0000,  0.0184],
        [ 0.0000,  0.0035,  0.0035]], device='cuda:0')

In [13]:
# Node coordinates (2 values per node).
print(graph.pos.shape)
graph.pos

torch.Size([1923, 2])


tensor([[0.0000, 0.3940],
        [0.0123, 0.3955],
        [0.0000, 0.4022],
        ...,
        [1.5816, 0.4100],
        [1.6000, 0.4065],
        [1.6000, 0.4100]], device='cuda:0')

In [14]:
# Per-node target. Simulator subtracts the current output columns of x from it to
# form the quantity the network is trained to predict.
print(graph.y.shape)
graph.y

torch.Size([1923, 2])


tensor([[ 3.3412e-01,  0.0000e+00],
        [ 3.0120e-01, -2.2341e-05],
        [ 1.6676e-01,  0.0000e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00],
        [ 3.1932e-01,  2.8717e-02],
        [ 0.0000e+00,  0.0000e+00]], device='cuda:0')

In [15]:
import enum

class NodeType(enum.IntEnum):
    """Integer codes identifying the kind of each mesh node.

    The node-type column of the node features is one-hot encoded with ``SIZE``
    classes in ``Simulator._build_input_graph``, and ``L2Loss`` builds its mask from
    ``NORMAL`` and ``OUTFLOW``.
    """
    NORMAL = 0
    OBSTACLE = 1
    AIRFOIL = 2
    HANDLE = 3
    INFLOW = 4
    OUTFLOW = 5
    WALL_BOUNDARY = 6
    # Not a node kind: the number of classes used for the one-hot encoding.
    SIZE = 9

# Create a dataset class

In [25]:
from torch.utils.data import Dataset as BaseDataset

class Dataset(BaseDataset):
    """Map-style dataset reading one pickled graph per index from a folder.

    Sample ``i`` is read from ``<folder_path>/cylinder_<i>.pkl``.

    Args:
        folder_path (str): Folder containing the ``cylinder_<i>.pkl`` files.
        trajectory_length (int): Number of samples the dataset reports. Defaults to 599.
    """

    def __init__(
        self,
        folder_path: str,
        trajectory_length: int = 599,
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.folder_path = folder_path
        self.files = None

        self.trajectory_length = trajectory_length

    def __len__(self):
        """Return the number of samples reported to the DataLoader."""
        return self.trajectory_length

    def __getitem__(self, i: int):
        """Load sample ``i`` and return its tensors as a plain dict.

        Args:
            i (int): Index of the sample; selects the file ``cylinder_<i>.pkl``.

        Returns:
            dict: The sample's ``x``, ``pos``, ``edge_index``, ``edge_attr`` and ``y``
            after the graph has been moved to ``self.device``.
        """
        # Resolve the file inside the folder given at construction time instead of a
        # hard-coded location.
        file_path = os.path.join(self.folder_path, f"cylinder_{i}.pkl")
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(file_path, "rb") as openfile:
            graph = pickle.load(openfile)
        graph = graph.to(self.device)

        graph_data = {
            "x": graph.x,
            "pos": graph.pos,
            "edge_index": graph.edge_index,
            "edge_attr": graph.edge_attr,
            "y": graph.y,
        }

        return graph_data


In [27]:
# Dataset over the folder extracted from cylinder.zip.
dataset = Dataset("large-physics-model/data/")

## Test the dataset class

- Create a dataset object
- Fetch the first item

# The Dataloader

In [29]:
from torch.utils.data import DataLoader

# One graph per loader batch: Epoch.run only reads index 0 of every collated batch and
# groups graphs itself according to the simulator's batch_size.
# NOTE(review): num_workers=0 presumably because Dataset moves tensors to the GPU
# inside __getitem__ — confirm before enabling worker processes.
train_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
)

# The Model


First, we need a function to build an MLP since we will use that a lot.

In [30]:
import torch
import torch.nn as nn

def build_mlp(
    in_size: int,
    hidden_size: int,
    out_size: int,
    nb_of_layers: int = 4,
    lay_norm: bool = True,
) -> nn.Module:
    """
    Builds a Multilayer Perceptron (MLP) using PyTorch.

    The network is ``nb_of_layers`` linear layers: one input layer, ``nb_of_layers - 2``
    hidden layers and one output layer. Every layer except the output layer is followed
    by a ReLU and, when ``lay_norm`` is set, a LayerNorm.

    Parameters:
        - in_size (int): The size of the input layer.
        - hidden_size (int): The size of the hidden layers.
        - out_size (int): The size of the output layer.
        - nb_of_layers (int, optional): The number of linear layers in the MLP, including the input and output layers. Defaults to 4.
        - lay_norm (bool, optional): Whether to add a LayerNorm after each hidden activation. Defaults to True.

    Returns:
        - nn.Module: The constructed MLP model.

    Raises:
        - ValueError: If ``nb_of_layers`` is smaller than 2 (an input and an output layer are always built).
    """
    if nb_of_layers < 2:
        raise ValueError(f"nb_of_layers must be at least 2, got {nb_of_layers}")

    # Input layer.
    layers = [nn.Linear(in_size, hidden_size), nn.ReLU()]
    if lay_norm:
        layers.append(nn.LayerNorm(hidden_size))

    # Hidden layers.
    for _ in range(nb_of_layers - 2):
        layers.append(nn.Linear(hidden_size, hidden_size))
        layers.append(nn.ReLU())
        if lay_norm:
            layers.append(nn.LayerNorm(hidden_size))

    # Output layer: added exactly once, after the loop, with no activation or norm.
    layers.append(nn.Linear(hidden_size, out_size))

    return nn.Sequential(*layers)

## The Encoder

Remember? We just want an MLP to encode the edges and one to encode the nodes.

In [31]:
from torch_geometric.data import Data

class Encoder(nn.Module):
    """Projects raw node and edge features into latent vectors of equal width.

    Two independent MLPs are used, one for the edge attributes and one for the node
    attributes; both output ``hidden_size`` features.

    Attributes:
        - edge_encoder (nn.Module): MLP applied to ``graph.edge_attr``.
        - node_encoder (nn.Module): MLP applied to ``graph.x``.

    Args:
        - edge_input_size (int): Number of input edge features. Defaults to 128.
        - node_input_size (int): Number of input node features. Defaults to 128.
        - hidden_size (int): Hidden and output width of both MLPs. Defaults to 128.
        - nb_of_layers (int): Number of layers passed to ``build_mlp``. Defaults to 4.
    """

    def __init__(
        self, edge_input_size=128, node_input_size=128, hidden_size=128, nb_of_layers=4
    ):
        super().__init__()

        # The two encoders differ only in their input width.
        shared_kwargs = dict(
            hidden_size=hidden_size,
            out_size=hidden_size,
            nb_of_layers=nb_of_layers,
        )
        self.edge_encoder = build_mlp(in_size=edge_input_size, **shared_kwargs)
        self.node_encoder = build_mlp(in_size=node_input_size, **shared_kwargs)

    def forward(self, graph: Data) -> Data:
        """Encode the node and edge attributes of ``graph``.

        Args:
            - graph (Data): Graph providing ``x``, ``edge_attr``, ``edge_index`` and ``pos``.

        Returns:
            - Data: New graph holding the encoded ``x`` and ``edge_attr``; ``edge_index``
              and ``pos`` are passed through unchanged.
        """
        encoded_nodes = self.node_encoder(graph.x)
        encoded_edges = self.edge_encoder(graph.edge_attr)

        return Data(
            x=encoded_nodes,
            edge_attr=encoded_edges,
            edge_index=graph.edge_index,
            pos=graph.pos,
        )

## And the decoder since we're already here

In [32]:
class Decoder(nn.Module):
    """Decodes latent node representations into output features.

    Only the node latents (``graph.x``) are decoded; edge attributes are not used.

    Attributes:
        decode_module (nn.Module): An MLP module used for decoding the latent representations.

    Args:
        hidden_size (int): The size of the hidden layers in the MLP. This is also the size of the latent representation.
        output_size (int): The size of the output layer.
        nb_of_layers (int): Number of layers passed to ``build_mlp``. Defaults to 4.
    """

    def __init__(
        self, hidden_size: int = 128, output_size: int = 2, nb_of_layers: int = 4
    ):

        super().__init__()

        self.decode_module = build_mlp(
            in_size=hidden_size,
            hidden_size=hidden_size,
            out_size=output_size,
            nb_of_layers=nb_of_layers
        )

    def forward(self, graph: Data) -> torch.Tensor:
        """Forward pass of the decoder.

        Args:
            graph (Data): A graph object from torch_geometric containing the latent representations of nodes.

        Returns:
            torch.Tensor: The decoded node features, i.e. ``decode_module`` applied to
            ``graph.x``. Note that a tensor is returned, not a graph object.
        """
        return self.decode_module(graph.x)

# Message Passing

We need to build:
- the edge block
- the node block
- the full message passing block

In [34]:
from torch_scatter import scatter_add

class EdgeBlock(nn.Module):
    """A block that updates the attributes of the edges in a graph based on the features of the
    two nodes each edge connects, as well as the current edge attributes.

    Attributes:
        _model_fn (callable): Function (typically an MLP) mapping the concatenated
            inputs to the new edge attributes.
    """

    def __init__(self, model_fn=None):

        super().__init__()
        self._model_fn = model_fn

    def forward(self, graph):
        """Forward pass of the EdgeBlock.

        Args:
            graph (Data): A graph containing node attributes, edge indices, and edge attributes.

        Returns:
            Data: A new graph with updated edge attributes; ``x``, ``edge_index`` and
            ``pos`` are passed through unchanged.
        """
        # For every edge: [edge features, features of node edge_index[0],
        # features of node edge_index[1]] concatenated along the feature axis.
        edge_inputs = torch.cat(
            [
                graph.edge_attr,
                graph.x[graph.edge_index[0]],
                graph.x[graph.edge_index[1]],
            ],
            dim=1,
        )

        edge_attr_ = self._model_fn(edge_inputs)

        return Data(
            x=graph.x, edge_attr=edge_attr_, edge_index=graph.edge_index, pos=graph.pos
        )

In [35]:
class NodeBlock(nn.Module):
    """A block that updates the attributes of the nodes in a graph based on the aggregated features
    of the incoming edges and the current node attributes.

    Attributes:
        _model_fn (callable): Function (typically an MLP) mapping the concatenated
            inputs to the new node attributes.
    """

    def __init__(self, model_fn=None):

        super().__init__()

        self._model_fn = model_fn

    def forward(self, graph):
        """Forward pass of the NodeBlock.

        Args:
            graph (Data): A graph containing node attributes, edge indices, and edge attributes.

        Returns:
            Data: A new graph with updated node attributes; ``edge_attr``,
            ``edge_index`` and ``pos`` are passed through unchanged.
        """
        receivers_index = graph.edge_index[1]
        # Sum the features of all edges pointing at each node; dim_size keeps one row
        # per node even for nodes without incoming edges.
        aggregated_edge_features = scatter_add(
            graph.edge_attr, receivers_index, dim=0, dim_size=graph.num_nodes
        )

        node_inputs = torch.cat([graph.x, aggregated_edge_features], dim=1)

        x_ = self._model_fn(node_inputs)

        return Data(
            x=x_, edge_attr=graph.edge_attr, edge_index=graph.edge_index, pos=graph.pos
        )

In [36]:
class GraphNetBlock(nn.Module):
    """A block that sequentially applies an EdgeBlock and a NodeBlock to update the attributes of
    both edges and nodes in a graph, then adds the block input back (residual connection).

    Attributes:
        edge_block (EdgeBlock): The block to update edge attributes.
        node_block (NodeBlock): The block to update node attributes.

    Args:
        hidden_size (int): Latent width of node and edge attributes. Defaults to 128.
        use_batch, use_gated_mlp, use_gated_lstm, use_gated_mha: Accepted for interface
            compatibility; not used by this implementation.
    """

    def __init__(
        self,
        hidden_size=128,
        use_batch=False,
        use_gated_mlp=False,
        use_gated_lstm=False,
        use_gated_mha=False,
    ):

        super().__init__()

        # Edge MLP input: edge latent + the latents of the two connected nodes.
        edge_input_dim = 3 * hidden_size
        # Node MLP input: node latent + aggregated edge latents.
        node_input_dim = 2 * hidden_size

        self.edge_block = EdgeBlock(model_fn=build_mlp(
            in_size=edge_input_dim,
            hidden_size=hidden_size,
            out_size=hidden_size
        ))
        self.node_block = NodeBlock(model_fn=build_mlp(
            in_size=node_input_dim,
            hidden_size=hidden_size,
            out_size=hidden_size
        ))

    def _apply_sub_block(self, graph):
        """Run the edge update followed by the node update."""
        graph = self.edge_block(graph)
        return self.node_block(graph)

    def forward(self, graph):
        """Apply one message-passing step with residual connections.

        Args:
            graph (Data): Graph whose ``x`` and ``edge_attr`` have width ``hidden_size``.

        Returns:
            Data: New graph whose ``x`` and ``edge_attr`` are the block input plus the
            update computed by the edge and node blocks.
        """
        # Keep a handle on the block input for the residual sum. The sub-blocks build
        # new Data objects rather than modifying their argument, so no copy is needed.
        graph_last = graph
        graph = self._apply_sub_block(graph)

        edge_attr = graph_last.edge_attr + graph.edge_attr
        x = graph_last.x + graph.x

        return Data(
            x=x, edge_attr=edge_attr, edge_index=graph.edge_index, pos=graph.pos
        )

# The full architecture

We want to build a model that given a graph:
- applies the encoder
- applies L message-passing steps
- returns the output of the decoder

In [None]:
class EncodeProcessDecode(nn.Module):
    """An Encode-Process-Decode model for graph neural networks.

    This model architecture is designed for processing graph-structured data. It consists of three main components:
    an encoder, a processor, and a decoder. The encoder maps input graph features to a latent space, the processor
    performs message passing and updates node and edge representations, and the decoder generates the final output
    from the processed node latents.

    Attributes:
        encoder (Encoder): The encoder component that transforms input graph features to a latent representation.
        decoder (Decoder): The decoder component that generates output from the processed graph.
        processer_list (nn.ModuleList): A list of GraphNetBlock modules for message passing and node updates.

    Parameters:
        message_passing_num (int): The number of message passing (GraphNetBlock) layers.
        node_input_size (int): The size of the input node features.
        edge_input_size (int): The size of the input edge features.
        output_size (int): The size of the output features.
        hidden_size (int, optional): The size of the hidden layers. Defaults to 128.
    """

    def __init__(
        self,
        message_passing_num,
        node_input_size,
        edge_input_size,
        output_size,
        hidden_size=128,
    ):

        super().__init__()
        self.encoder = Encoder(
            edge_input_size=edge_input_size,
            node_input_size=node_input_size,
            hidden_size=hidden_size,
        )

        self.decoder = Decoder(hidden_size=hidden_size, output_size=output_size)

        self.processer_list = nn.ModuleList(
            [
                GraphNetBlock(hidden_size=hidden_size)
                for _ in range(message_passing_num)
            ]
        )

    def forward(self, graph):
        """Forward pass of the Encode-Process-Decode model.

        Args:
            graph: The input graph data, providing the attributes read by the Encoder
                   (``x``, ``edge_attr``, ``edge_index``, ``pos``).

        Returns:
            The decoder output computed from the node latents after encoding and
            ``message_passing_num`` message-passing blocks.
        """
        graph = self.encoder(graph)
        for processer in self.processer_list:
            graph = processer(graph)
        return self.decoder(graph)

# We can now define a way to normalize our data and put everything together. We'll briefly discuss how it works, then trust that it does. It's not the most interesting part.

In [None]:
class Normalizer(nn.Module):
    """Normalizes feature rows with a running mean and standard deviation.

    Sums, squared sums and row counts are accumulated over the batches passed to
    ``forward``; normalization always uses the statistics accumulated so far.
    """

    def __init__(
        self,
        size,
        max_accumulations=10**5,
        std_epsilon=1e-8,
        name="Normalizer",
        device="cuda",
    ):
        """Initializes the Normalizer module.

        Args:
            size (int): Number of features per row of the input data.
            max_accumulations (int): Number of batches after which statistics are frozen.
            std_epsilon (float): Lower bound applied to the standard deviation.
            name (str): Name of the Normalizer.
            device (str): Device on which the statistics tensors are created.
        """
        super().__init__()
        self.name = name
        self._max_accumulations = max_accumulations

        def scalar(value):
            # 0-dim float32 tensor on the target device.
            return torch.tensor(
                value, dtype=torch.float32, requires_grad=False, device=device
            )

        def zero_row():
            # (1, size) float32 row of zeros on the target device.
            return torch.zeros(
                (1, size), dtype=torch.float32, requires_grad=False, device=device
            )

        self._std_epsilon = scalar(std_epsilon)
        self._acc_count = scalar(0)
        self._num_accumulations = scalar(0)
        self._acc_sum = zero_row()
        self._acc_sum_squared = zero_row()
        self._std_zeros = zero_row()

    def forward(self, batched_data, accumulate=True):
        """Normalizes input data and, if requested, accumulates its statistics first."""
        # Statistics are frozen once max_accumulations batches have been seen.
        if accumulate and self._num_accumulations < self._max_accumulations:
            self._accumulate(batched_data.detach())
        centered = batched_data - self._mean()
        return centered / self._std_with_epsilon()

    def inverse(self, normalized_batch_data):
        """Maps normalized data back using the current mean and standard deviation."""
        scale = self._std_with_epsilon()
        return normalized_batch_data * scale + self._mean()

    def _accumulate(self, batched_data):
        """Adds the row count, sum and squared sum of ``batched_data`` to the running totals."""
        self._acc_sum += torch.sum(batched_data, axis=0, keepdims=True)
        self._acc_sum_squared += torch.sum(batched_data**2, axis=0, keepdims=True)
        self._acc_count += batched_data.shape[0]
        self._num_accumulations += 1

    def _safe_count(self):
        """Accumulated row count, clamped from below to 1 to avoid dividing by zero."""
        one = torch.tensor(1.0, dtype=torch.float32, device=self._acc_count.device)
        return torch.maximum(self._acc_count, one)

    def _mean(self):
        """Running mean of each feature."""
        return self._acc_sum / self._safe_count()

    def _std_with_epsilon(self):
        """Running standard deviation of each feature, never smaller than the epsilon."""
        mean_of_squares = self._acc_sum_squared / self._safe_count()
        # Clamp at zero before the square root to absorb small negative round-off.
        variance = torch.maximum(self._std_zeros, mean_of_squares - self._mean() ** 2)
        return torch.maximum(torch.sqrt(variance), self._std_epsilon)

    def get_variable(self):
        """Returns the normalizer's settings and accumulated statistics as a dict."""
        state = {
            "_max_accumulations": self._max_accumulations,
            "_std_epsilon": self._std_epsilon,
            "_acc_count": self._acc_count,
            "_num_accumulations": self._num_accumulations,
            "_acc_sum": self._acc_sum,
            "_acc_sum_squared": self._acc_sum_squared,
            "name": self.name,
        }

        return state

In [None]:
from loguru import logger

class Simulator(nn.Module):
    """Wraps a graph network with input construction, normalization and checkpointing.

    From a raw sample the simulator builds the network input (selected feature columns,
    one-hot node type and, optionally, a time column), normalizes node features, edge
    features and the target delta, and can map the network output back to a prediction.
    """

    def __init__(
        self,
        node_input_size: int,
        edge_input_size: int,
        output_size: int,
        feature_index_start: int,
        feature_index_end: int,
        output_index_start: int,
        output_index_end: int,
        node_type_index: int,
        batch_size: int,
        model,
        device,
        model_dir="checkpoint/simulator.pth",
        time_index: int = None,
    ):
        """Initialize the Simulator module.

        Args:
            node_input_size (int): Width of the node features fed to the network, i.e.
                selected feature columns + ``NodeType.SIZE`` (+ 1 when ``time_index`` is set).
            edge_input_size (int): Size of edge input.
            output_size (int): Size of the output/prediction from the network.
            feature_index_start (int): Start index of features.
            feature_index_end (int): End index of features.
            output_index_start (int): Start index of output.
            output_index_end (int): End index of output.
            node_type_index (int): Index of node type.
            batch_size (int): Number of graphs per optimisation step (read by the epoch runner).
            model: The model to be used.
            device: The device to run the model on.
            model_dir (str): Default path of the checkpoint file.
            time_index (int): Index of time feature, or None to omit the time feature.
        """
        super().__init__()

        self.node_input_size = node_input_size
        self.edge_input_size = edge_input_size
        self.output_size = output_size

        self.feature_index_start = feature_index_start
        self.feature_index_end = feature_index_end
        self.node_type_index = node_type_index

        self.time_index = time_index

        self.output_index_start = output_index_start
        self.output_index_end = output_index_end

        self.model_dir = model_dir
        self.model = model.to(device)
        self._output_normalizer = Normalizer(
            size=output_size, name="output_normalizer", device=device
        )
        self._node_normalizer = Normalizer(
            size=node_input_size, name="node_normalizer", device=device
        )
        self._edge_normalizer = Normalizer(
            size=edge_input_size, name="edge_normalizer", device=device
        )

        self.device = device
        self.batch_size = batch_size

    def _get_pre_target(self, inputs: Data) -> torch.Tensor:
        """Return the current value of the predicted quantity (output columns of ``x``)."""
        return inputs.x[:, self.output_index_start : self.output_index_end]

    def _build_input_graph(
        self, inputs: Data, is_training: bool
    ) -> "tuple[Data, torch.Tensor]":
        """Build the normalized network input and the normalized training target.

        Args:
            inputs (Data): Raw sample with ``x``, ``y``, ``edge_attr``, ``edge_index``, ``pos``.
            is_training (bool): Whether the normalizers should accumulate statistics.

        Returns:
            tuple: ``(graph, target_delta_normalized)`` where ``graph`` holds the
            normalized node and edge features and the target is the normalized
            difference between ``inputs.y`` and the current output columns.
        """
        node_type = inputs.x[:, self.node_type_index]
        features = inputs.x[:, self.feature_index_start : self.feature_index_end]

        target = inputs.y
        pre_target = self._get_pre_target(inputs)

        target_delta = target - pre_target
        target_delta_normalized = self._output_normalizer(target_delta, is_training)

        one_hot_type = torch.nn.functional.one_hot(
            torch.squeeze(node_type.long()), NodeType.SIZE
        )

        node_features_list = [features, one_hot_type]
        # The time column is optional; indexing with None would insert an axis and
        # break the concatenation below.
        if self.time_index is not None:
            node_features_list.append(inputs.x[:, self.time_index].reshape(-1, 1))

        node_features = torch.cat(node_features_list, dim=1)

        node_features_normalized = self._node_normalizer(node_features, is_training)
        edge_features_normalized = self._edge_normalizer(
            inputs.edge_attr, is_training
        )

        graph = Data(
            x=node_features_normalized,
            pos=inputs.pos,
            edge_attr=edge_features_normalized,
            edge_index=inputs.edge_index,
        )

        return graph, target_delta_normalized

    def _build_outputs(
        self, inputs: Data, network_output: torch.Tensor
    ) -> torch.Tensor:
        """De-normalize the network output and add it to the current output columns."""
        pre_target = self._get_pre_target(inputs)
        update = self._output_normalizer.inverse(network_output)
        return pre_target + update

    def forward(self, inputs: Data):
        """Run the network on one sample.

        Returns:
            In training mode ``(network_output, target_delta_normalized)``; in eval mode
            additionally the de-normalized prediction as a third element.
        """
        graph, target_delta_normalized = self._build_input_graph(
            inputs=inputs, is_training=self.training
        )
        network_output = self.model(graph)

        if self.training:
            return network_output, target_delta_normalized
        return (
            network_output,
            target_delta_normalized,
            self._build_outputs(inputs=inputs, network_output=network_output),
        )

    def freeze_all(self):
        """Disable gradients for every parameter of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def load_checkpoint(self, ckpdir=None):
        """Restore the model weights and normalizer statistics from a checkpoint file.

        Args:
            ckpdir (str): Path of the checkpoint; defaults to ``self.model_dir``.
        """
        if ckpdir is None:
            ckpdir = self.model_dir
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        dicts = torch.load(ckpdir, map_location=torch.device(self.device))
        self.load_state_dict(dicts["model"])

        # Every remaining entry is named after a normalizer attribute and maps its
        # variable names to values.
        for key, variables in dicts.items():
            if key == "model":
                continue
            target = getattr(self, key)
            for para, value in variables.items():
                setattr(target, para, value)

        logger.success("Simulator model loaded checkpoint %s" % ckpdir)

    def save_checkpoint(self, savedir=None):
        """Save the model weights and normalizer statistics.

        Args:
            savedir (str): Path of the checkpoint file; defaults to ``self.model_dir``.
        """
        if savedir is None:
            savedir = self.model_dir

        # Create the folder of the file actually written; a bare file name has no
        # folder component and needs none.
        target_dir = os.path.dirname(savedir)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        to_save = {
            "model": self.state_dict(),
            "_output_normalizer": self._output_normalizer.get_variable(),
            "_node_normalizer": self._node_normalizer.get_variable(),
            "_edge_normalizer": self._edge_normalizer.get_variable(),
        }

        torch.save(to_save, savedir)

# Training

In [None]:
from torch.nn.modules.loss import _Loss

class L2Loss(_Loss):
    """Mean squared error restricted to NORMAL and OUTFLOW nodes."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def __name__(self):
        return "MSE"

    def forward(
        self, target_acceleration, network_output, node_type
    ):
        """Computes the L2 loss between target and prediction on the masked nodes.

        Args:
            target_acceleration (torch.Tensor): Target values, one row per node.
            network_output (torch.Tensor): Network prediction, same shape as the target.
            node_type (torch.Tensor): Node type code of every node (one value per node).

        Returns:
            torch.Tensor: Mean of the squared errors over the rows whose node type is
            ``NodeType.NORMAL`` or ``NodeType.OUTFLOW``.
        """
        mask = torch.logical_or(
            node_type == NodeType.NORMAL, node_type == NodeType.OUTFLOW
        )

        # Squared error per component, keeping only the rows selected by the mask.
        errors = ((target_acceleration - network_output) ** 2)[mask]
        return torch.mean(errors)

In [None]:
import numpy as np


class Meter:
    """Common interface for objects that track a statistic incrementally.

    The base class does nothing; subclasses override ``reset``, ``add`` and ``value``.
    """

    def reset(self):
        """Put the meter back into its initial state."""

    def add(self, value):
        """Feed one more observation to the meter.

        Args:
            value: Next result to include.
        """

    def value(self):
        """Return the statistic for everything added so far."""


class AverageValueMeter(Meter):
    """Tracks the running mean and standard deviation of the values it is given."""

    def __init__(self):
        super().__init__()
        self.reset()
        self.val = 0

    def add(self, value, n=1):
        """Record ``value`` and refresh ``mean`` and ``std``."""
        self.val = value
        self.sum += value
        self.var += value * value
        self.n += n

        count = self.n
        if count == 0:
            self.mean, self.std = np.nan, np.nan
            return

        if count == 1:
            self.mean = 0.0 + self.sum  # This is to force a copy in torch/numpy
            self.std = np.inf
            self.mean_old = self.mean
            self.m_s = 0.0
            return

        # Incremental update of the mean and of the sum of squared deviations.
        previous_mean = self.mean_old
        self.mean = previous_mean + (value - n * previous_mean) / float(count)
        self.m_s += (value - previous_mean) * (value - self.mean)
        self.mean_old = self.mean
        self.std = np.sqrt(self.m_s / (count - 1.0))

    def value(self):
        """Return the current ``(mean, std)`` pair."""
        return self.mean, self.std

    def reset(self):
        """Clear every accumulated quantity."""
        self.n = 0
        self.sum = 0.0
        self.var = 0.0
        self.val = 0.0
        self.mean = np.nan
        self.mean_old = 0.0
        self.m_s = 0.0
        self.std = np.nan

In [None]:
from tqdm import tqdm as tqdm
import sys

class Epoch:
    """Runs one pass over a dataloader, grouping graphs into optimisation batches.

    Subclasses implement ``batch_update`` (what to do with one batch of graphs) and may
    override ``on_epoch_start``.
    """

    def __init__(
        self,
        model,
        loss,
        stage_name,
        parameters,
        device="cpu",
        verbose=True,
        starting_step=0,
    ):
        """
        Args:
            model: Simulator-like module exposing ``batch_size``, ``training`` and
                ``save_checkpoint``.
            loss: Loss module; its ``__name__`` labels the progress-bar entry.
            stage_name (str): Label shown on the progress bar.
            parameters: Stored as-is; not used by this class.
            device: Device the model and loss are moved to.
            verbose (bool): Whether to display the progress bar.
            starting_step (int): Offset added to the step index when logging.
        """
        self.model = model
        self.loss = loss
        # Stored because run() uses it as the progress-bar description.
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device
        self.parameters = parameters
        self.step = 0
        self._to_device()

        self.full_batch_graph = []
        self.starting_step = starting_step

    def _to_device(self):
        """Move the model and the loss to ``self.device``."""
        self.model.to(self.device)
        self.loss.to(self.device)

    def _format_logs(self, logs):
        """Format a ``{name: value}`` dict as ``"name - value, ..."`` with 4 significant digits."""
        str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()]
        s = ", ".join(str_logs)
        return s

    def batch_update(self, x, y):
        """Process one batch of graphs and return its loss; implemented by subclasses."""
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook called at the beginning of ``run``."""
        pass

    def run(self, dataloader, writer=None, model_save_dir="checkpoint/simulator.pth"):
        """Iterate once over ``dataloader``.

        Args:
            dataloader: Iterable yielding dicts of collated tensors with keys ``x``,
                ``pos``, ``edge_index``, ``y`` and optionally ``edge_attr``.
            writer: TensorBoard-style writer (``add_scalar``/``flush``), or None to
                disable scalar logging.
            model_save_dir (str): Checkpoint path used every 200 steps.

        Returns:
            The mean loss over all processed batches (NaN if none was processed).
        """
        self.on_epoch_start()

        logs = {}
        loss_meter = AverageValueMeter()

        with tqdm(
            dataloader,
            desc=self.stage_name,
            file=sys.stdout,
            disable=not (self.verbose),
        ) as iterator:
            for graph_data in iterator:
                # Only the first element of each collated batch is consumed, so the
                # dataloader is expected to use batch_size=1.
                for indx in range(1):

                    input_graph = Data(
                        x=graph_data["x"][indx],
                        pos=graph_data["pos"][indx],
                        edge_index=graph_data["edge_index"][indx],
                        edge_attr=graph_data.get("edge_attr", [None])[indx],
                        y=graph_data["y"][indx],
                    ).to(self.device)

                    self.full_batch_graph.append(input_graph)

                if len(self.full_batch_graph) % self.model.batch_size == 0:

                    loss = self.batch_update(self.full_batch_graph, writer)

                    # update loss logs
                    loss_value = loss.cpu().detach().numpy()
                    loss_meter.add(loss_value)
                    loss_logs = {self.loss.__name__: loss_meter.mean}
                    logs.update(loss_logs)

                    if writer is not None:
                        tag = (
                            "Loss/train/value_per_step"
                            if self.model.training
                            else "Loss/test/value_per_step"
                        )
                        writer.add_scalar(
                            tag, loss_value, self.step + self.starting_step
                        )

                    if self.step % 200 == 0:
                        self.model.save_checkpoint(model_save_dir)
                        if writer is not None:
                            writer.flush()

                    self.step += 1
                    self.full_batch_graph = []

                    if self.verbose:
                        s = self._format_logs(logs)
                        iterator.set_postfix_str(s)

        return loss_meter.mean


class TrainEpoch(Epoch):
    """Epoch variant that optimises the model on every batch of graphs."""

    def __init__(
        self,
        model,
        loss,
        parameters,
        optimizer,
        device="cpu",
        verbose=True,
        starting_step=0,
        use_sub_graph=False,
    ):
        """Set up a training stage (stage name ``"train"``) driven by ``optimizer``."""
        super().__init__(
            model=model,
            loss=loss,
            stage_name="train",
            parameters=parameters,
            device=device,
            verbose=verbose,
            starting_step=starting_step,
        )
        self.optimizer = optimizer
        self.use_sub_graph = use_sub_graph

    def on_epoch_start(self):
        """Switch the model to training mode."""
        self.model.train()

    def batch_update(self, batch_graph, writer):
        """Run one optimisation step on a list of graphs.

        Args:
            batch_graph: Graphs forming the batch.
            writer: Unused here; accepted to match the call made by ``Epoch.run``.

        Returns:
            The loss averaged over the graphs of the batch.
        """
        self.optimizer.zero_grad()

        loss = 0
        for sample in batch_graph:
            # The node type is read from the raw features before the forward pass.
            node_type = sample.x[:, self.model.node_type_index]
            prediction, normalized_target = self.model(sample)
            loss += self.loss(normalized_target, prediction, node_type)
        loss /= len(batch_graph)

        loss.backward()
        # Clip the gradient norm at 10 before the optimizer step.
        nn.utils.clip_grad_norm_(self.model.parameters(), 10.0)
        self.optimizer.step()

        return loss

# We can now (I hope so) train a model

In [None]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("tensorboard")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Column layout of graph.x used below.
# NOTE(review): inferred from the sample printed earlier ([vx, vy, node type, time
# step]) — confirm against the dataset description.
VELOCITY_START, VELOCITY_END = 0, 2
NODE_TYPE_INDEX = 2
TIME_INDEX = 3

EDGE_INPUT_SIZE = 3  # width of graph.edge_attr
OUTPUT_SIZE = 2  # width of graph.y
# Network node input = velocity columns + one-hot node type + time column.
NODE_INPUT_SIZE = (VELOCITY_END - VELOCITY_START) + int(NodeType.SIZE) + 1

# NOTE(review): hyper-parameters chosen as reasonable starting values; tune as needed.
MESSAGE_PASSING_NUM = 15
BATCH_SIZE = 1
NB_EPOCHS = 10

train_load = train_loader
model = EncodeProcessDecode(
    message_passing_num=MESSAGE_PASSING_NUM,
    node_input_size=NODE_INPUT_SIZE,
    edge_input_size=EDGE_INPUT_SIZE,
    output_size=OUTPUT_SIZE,
)
loss = L2Loss()
simulator = Simulator(
    node_input_size=NODE_INPUT_SIZE,
    edge_input_size=EDGE_INPUT_SIZE,
    output_size=OUTPUT_SIZE,
    feature_index_start=VELOCITY_START,
    feature_index_end=VELOCITY_END,
    output_index_start=VELOCITY_START,
    output_index_end=VELOCITY_END,
    node_type_index=NODE_TYPE_INDEX,
    batch_size=BATCH_SIZE,
    model=model,
    device=device,
    model_dir="model.pth",
    time_index=TIME_INDEX,
)
optimizer = torch.optim.Adam(simulator.parameters(), lr=0.0001)

train_epoch = TrainEpoch(
    model=simulator,
    loss=loss,
    parameters={"message_passing_num": MESSAGE_PASSING_NUM, "batch_size": BATCH_SIZE},
    optimizer=optimizer,
    device=device,
)

for i in range(0, NB_EPOCHS):
    print("\nEpoch: {}".format(i))
    print("=== Training ===")
    train_loss = train_epoch.run(train_loader, writer, "model.pth")

    writer.add_scalar("Loss/train/mean_value_per_epoch", train_loss, i)
    writer.flush()

# Close the writer once, after the last epoch, rather than inside the loop.
writer.close()