# How Graph Neural Networks Actually Work

This notebook builds GNNs **from scratch** so you can see exactly what happens at every step. No black boxes.

---

## Table of Contents
1. [Why Graphs? The Problem with Regular Neural Networks](#1)
2. [Graphs as Data Structures](#2)
3. [The Core Idea: Message Passing (by hand)](#3)
4. [Building a GNN Layer from Scratch in PyTorch](#4)
5. [Full GNN Model: Node Classification Example](#5)
6. [Edge Features — What MeshGraphNets Adds](#6)
7. [Summary: The GNN Mental Model](#7)

---

<a id='1'></a>
# 1. Why Graphs? The Problem with Regular Neural Networks

Regular neural networks assume structured input:

| Network | Assumes input is... | Example |
|---------|--------------------|---------|
| MLP | Fixed-size vector | `[age, height, weight]` |
| CNN | Regular grid | Images (pixels on a grid) |
| RNN | Sequence | Text, time series |

But much real-world data is **irregular**:
- Social networks (users + friendships)
- Molecules (atoms + bonds)
- **Simulation meshes** (nodes + elements) ← this is why MeshGraphNets uses GNNs
- Citation networks, road networks, power grids...

You **can't** flatten a mesh into a grid without losing structure. GNNs operate **directly on the graph**.

---

<a id='2'></a>
# 2. Graphs as Data Structures

A graph is just: **nodes** (with features) + **edges** (connections between nodes).

```
    (0)───(1)
     │   / │
     │  /  │
     │ /   │
    (2)───(3)
```

Let's build this in code.

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

np.set_printoptions(precision=3, suppress=True)
torch.manual_seed(42)

# =============================================
# STEP 1: Define the graph structure
# =============================================

# Node features: each node has a 3-dimensional feature vector
# Think of this as: [temperature, pressure, velocity] at each mesh node
node_features = np.array([
    [1.0, 0.0, 0.5],   # Node 0
    [0.0, 1.0, 0.3],   # Node 1
    [0.5, 0.5, 0.8],   # Node 2
    [0.2, 0.8, 0.1],   # Node 3
])

# Edge list: pairs of (source, target)
# Undirected edges → we list both directions
edges = [
    (0, 1), (1, 0),  # 0 ↔ 1
    (0, 2), (2, 0),  # 0 ↔ 2
    (1, 2), (2, 1),  # 1 ↔ 2
    (1, 3), (3, 1),  # 1 ↔ 3
    (2, 3), (3, 2),  # 2 ↔ 3
]

# Separate into source and target arrays (common format)
src = [e[0] for e in edges]
dst = [e[1] for e in edges]

print(f"Nodes: {len(node_features)}, each with {node_features.shape[1]} features")
print(f"Edges: {len(edges)} (directed, so 5 undirected edges × 2)")
print(f"\nNode features:\n{node_features}")
print(f"\nEdge list (src → dst):")
for s, d in edges:
    print(f"  {s} → {d}")

Nodes: 4, each with 3 features
Edges: 10 (directed, so 5 undirected edges × 2)

Node features:
[[1.  0.  0.5]
 [0.  1.  0.3]
 [0.5 0.5 0.8]
 [0.2 0.8 0.1]]

Edge list (src → dst):
  0 → 1
  1 → 0
  0 → 2
  2 → 0
  1 → 2
  2 → 1
  1 → 3
  3 → 1
  2 → 3
  3 → 2


In [3]:
# =============================================
# ADJACENCY MATRIX: another way to represent edges
# =============================================

# A[i][j] = 1 if there's an edge from i to j
num_nodes = len(node_features)
A = np.zeros((num_nodes, num_nodes))
for s, d in edges:
    A[s][d] = 1

print("Adjacency matrix A:")
print(A)
print()

# Each row tells you who a node's neighbors are:
for i in range(num_nodes):
    neighbors = np.where(A[i] == 1)[0]
    print(f"Node {i}'s neighbors: {neighbors.tolist()}")

Adjacency matrix A:
[[0. 1. 1. 0.]
 [1. 0. 1. 1.]
 [1. 1. 0. 1.]
 [0. 1. 1. 0.]]

Node 0's neighbors: [1, 2]
Node 1's neighbors: [0, 2, 3]
Node 2's neighbors: [0, 1, 3]
Node 3's neighbors: [1, 2]
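
The edge list and the adjacency matrix describe the same graph, and a third common form is the adjacency list (node → list of neighbors). The sketch below converts between them in plain Python; the helper names `to_adjacency_list` and `to_adjacency_matrix` are made up for illustration and do not come from any library.

```python
# Sketch: one graph, three equivalent representations.
# Helper names are illustrative, not from any library.

edges = [
    (0, 1), (1, 0),
    (0, 2), (2, 0),
    (1, 2), (2, 1),
    (1, 3), (3, 1),
    (2, 3), (3, 2),
]
num_nodes = 4


def to_adjacency_list(edges, num_nodes):
    """Map each node to the sorted list of nodes it has an edge to."""
    adj = {i: [] for i in range(num_nodes)}
    for s, d in edges:
        adj[s].append(d)
    return {i: sorted(nbrs) for i, nbrs in adj.items()}


def to_adjacency_matrix(edges, num_nodes):
    """Dense 0/1 matrix with A[s][d] = 1 for every edge s -> d."""
    A = [[0] * num_nodes for _ in range(num_nodes)]
    for s, d in edges:
        A[s][d] = 1
    return A


adj_list = to_adjacency_list(edges, num_nodes)
A = to_adjacency_matrix(edges, num_nodes)
degree = [sum(row) for row in A]  # out-degree; equals in-degree for undirected graphs

print(adj_list)
print(degree)
```

The dense matrix stores an entry for every pair of nodes, while the edge list stores only the edges that exist, which is why large sparse graphs are usually kept in edge-list form.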


---

<a id='3'></a>
# 3. The Core Idea: Message Passing (by hand)

**This is the entire idea behind GNNs.** Everything else is details.

Each GNN layer does three things:

```
For each node i:
  1. MESSAGE:    Collect features from all neighbors j
  2. AGGREGATE:  Combine those messages (sum, mean, max)
  3. UPDATE:     Compute new feature for node i using its own feature + aggregated messages
```

That's it. Let's do it **manually with numpy** first — no PyTorch, no libraries.

In [5]:
# =============================================
# MESSAGE PASSING BY HAND (pure numpy)
# =============================================

print("=" * 60)
print("STEP-BY-STEP MESSAGE PASSING (1 layer, no learning yet)")
print("=" * 60)

X = node_features.copy()  # Shape: (4 nodes, 3 features)

print(f"\nOriginal node features X:")
for i in range(num_nodes):
    print(f"  Node {i}: {X[i]}")

# ----- STEP 1: MESSAGE -----
# For each node, gather the features of its neighbors
print(f"\n--- Step 1: MESSAGE (gather neighbor features) ---")
for i in range(num_nodes):
    neighbors = np.where(A[i] == 1)[0]
    print(f"  Node {i} receives messages from neighbors {neighbors.tolist()}:")
    for j in neighbors:
        print(f"    message from {j}: {X[j]}")

# ----- STEP 2: AGGREGATE (sum) -----
# Sum up all neighbor features for each node
# This is literally just matrix multiplication: A @ X
aggregated = A @ X

print(f"\n--- Step 2: AGGREGATE (sum of neighbor features) ---")
print(f"  This is just A @ X (matrix multiply!)")
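for i in range(num_nodes):
    neighbors = np.where(A[i] == 1)[0]
    manual_sum = sum(X[j] for j in neighbors)
    # Sanity check: the explicit loop agrees with the matrix product
    assert np.allclose(manual_sum, aggregated[i])
    print(f"  Node {i}: sum of neighbors {neighbors.tolist()} = {aggregated[i]}")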

# ----- STEP 3: UPDATE -----
# Combine own features with aggregated neighbor features
# Simplest version: just add them
X_new = X + aggregated

print(f"\n--- Step 3: UPDATE (own features + aggregated) ---")
for i in range(num_nodes):
    print(f"  Node {i}: {X[i]} + {aggregated[i]} = {X_new[i]}")

print(f"\n{'=' * 60}")
print(f"After 1 round of message passing, each node's features")
print(f"now contain information from its IMMEDIATE neighbors.")
print(f"After 2 rounds → 2-hop neighborhood.")
print(f"After k rounds → k-hop neighborhood.")
print(f"{'=' * 60}")

STEP-BY-STEP MESSAGE PASSING (1 layer, no learning yet)

Original node features X:
  Node 0: [1.  0.  0.5]
  Node 1: [0.  1.  0.3]
  Node 2: [0.5 0.5 0.8]
  Node 3: [0.2 0.8 0.1]

--- Step 1: MESSAGE (gather neighbor features) ---
  Node 0 receives messages from neighbors [1, 2]:
    message from 1: [0.  1.  0.3]
    message from 2: [0.5 0.5 0.8]
  Node 1 receives messages from neighbors [0, 2, 3]:
    message from 0: [1.  0.  0.5]
    message from 2: [0.5 0.5 0.8]
    message from 3: [0.2 0.8 0.1]
  Node 2 receives messages from neighbors [0, 1, 3]:
    message from 0: [1.  0.  0.5]
    message from 1: [0.  1.  0.3]
    message from 3: [0.2 0.8 0.1]
  Node 3 receives messages from neighbors [1, 2]:
    message from 1: [0.  1.  0.3]
    message from 2: [0.5 0.5 0.8]

--- Step 2: AGGREGATE (sum of neighbor features) ---
  This is just A @ X (matrix multiply!)
  Node 0: sum of neighbors [1, 2] = [0.5 1.5 1.1]
  Node 1: sum of neighbors [0, 2, 3] = [1.7 1.3 1.4]
  Node 2: sum of neighbors

In [None]:
# =============================================
# WHY 15 ROUNDS IN MESHGRAPHNETS?
# =============================================

# Let's see how information propagates through the graph.
# After k rounds, node i has information from all nodes within k hops.

# Each round a node keeps its own features and adds its neighbors', so the
# propagation matrix is (A + I). Entry (i, j) of (A + I)^k is non-zero exactly
# when node j is within k hops of node i.
print("Reachability after k message-passing rounds:")
print("(non-zero means information can flow between those nodes)\n")

A_hat = A + np.eye(num_nodes)  # adjacency with self-loops
A_power = np.eye(num_nodes)    # (A + I)^0 = identity (each node knows itself)
for k in range(1, 4):
    A_power = A_power @ A_hat
    reachable = (A_power > 0).astype(int)
    print(f"After {k} round(s):")
    print(reachable)
    
    # Check if fully connected
    if reachable.all():
        print(f"→ ALL nodes can reach ALL other nodes after {k} rounds!\n")
        break
    print()

print("MeshGraphNets uses 15 rounds because simulation meshes are")
print("much larger: each round moves information only one hop, so more")
print("rounds give every node a wider receptive field (up to 15 hops).")
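
To make the "one hop per round" claim concrete, here is a small sketch on a path graph (an illustrative graph, not the one used above). Only node 0 starts with a non-zero signal, and each round applies the same own-plus-neighbors update as the by-hand example.

```python
import numpy as np

# Path graph 0 - 1 - 2 - 3 - 4 (illustrative, not the graph used above)
num_nodes = 5
A = np.zeros((num_nodes, num_nodes))
for i in range(num_nodes - 1):
    A[i, i + 1] = 1
    A[i + 1, i] = 1

# One-hot signal: only node 0 carries information at the start
x = np.zeros(num_nodes)
x[0] = 1.0

reached_per_round = []
for k in range(1, num_nodes):
    x = x + A @ x  # own value + sum of neighbor values
    reached = np.nonzero(x)[0].tolist()
    reached_per_round.append(reached)
    print(f"After {k} round(s), node 0's signal has reached: {reached}")
```

The signal needs 4 rounds to reach node 4, which is 4 hops away. A graph whose farthest nodes are hundreds of hops apart would need correspondingly more rounds before those nodes could influence each other.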

---

<a id='4'></a>
# 4. Building a GNN Layer from Scratch in PyTorch

Now let's add **learnable parameters**. The simplest GNN layer (like GCN) does:

$$h_i^{(l+1)} = \sigma\left( W \cdot \text{AGGREGATE}\left(\{h_j^{(l)} : j \in \mathcal{N}(i) \cup \{i\}\}\right) \right)$$

In plain English:
1. Gather neighbor features (including self)
2. Aggregate (sum or mean)
3. Multiply by a learnable weight matrix W
4. Apply activation function (ReLU)

That's a full GNN layer. Let's build it.

In [None]:
# =============================================
# GNN LAYER FROM SCRATCH
# =============================================

class GNNLayer(nn.Module):
    """One layer of a basic Graph Neural Network.
    
    Does: h_i' = ReLU( W · mean({h_j : j ∈ neighbors(i) ∪ {i}}) )
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        # The only learnable parameters: the weight W and bias b of one linear map
        self.linear = nn.Linear(in_features, out_features)
    
    def forward(self, X, edge_index):
        """
        X:          (num_nodes, in_features)  — node feature matrix
        edge_index: (2, num_edges)            — [src_nodes; dst_nodes]
        """
        src, dst = edge_index[0], edge_index[1]
        num_nodes = X.shape[0]
        
        # ----- STEP 1: MESSAGE -----
        # Gather source node features for each edge
        messages = X[src]  # Shape: (num_edges, in_features)
        # messages[k] = feature vector of the SOURCE node of edge k
        
        # ----- STEP 2: AGGREGATE (mean) -----
        # For each destination node, average all incoming messages.
        # Done here with an explicit loop for clarity (a hand-written scatter-mean).
        agg = torch.zeros(num_nodes, X.shape[1])
        count = torch.zeros(num_nodes, 1)
        
        # Accumulate messages at destination nodes
        for k in range(len(src)):
            agg[dst[k]] += messages[k]
            count[dst[k]] += 1
        
        # Add self-loop (each node includes its OWN features)
        agg += X
        count += 1
        
        # Mean aggregation
        agg = agg / count
        
        # ----- STEP 3: UPDATE -----
        # Linear transformation + activation
        out = self.linear(agg)  # agg @ W.T + b, applied to every node at once
        out = F.relu(out)
        
        return out


# ---- Test it ----
X = torch.tensor(node_features, dtype=torch.float32)
edge_index = torch.tensor([src, dst], dtype=torch.long)

layer = GNNLayer(in_features=3, out_features=4)  # 3 input features → 4 output features

print(f"Input shape:  {X.shape}  (4 nodes, 3 features each)")
output = layer(X, edge_index)
print(f"Output shape: {output.shape}  (4 nodes, 4 features each)")
print(f"\nOutput (each node now has a 4-dim learned representation):")
for i in range(num_nodes):
    print(f"  Node {i}: {output[i].detach().numpy()}")

In [None]:
# =============================================
# VECTORIZED VERSION (how it's actually done)
# =============================================
# The loop above is slow. In practice, we use scatter operations.
# Here's the clean version:

class GNNLayerFast(nn.Module):
    """Same as above but vectorized using scatter_add."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
    
    def forward(self, X, edge_index):
        src, dst = edge_index[0], edge_index[1]
        num_nodes = X.shape[0]
        
        # Message: gather source features
        messages = X[src]  # (num_edges, features)
        
        # Aggregate: scatter_add groups messages by destination
        agg = torch.zeros_like(X)
        agg.scatter_add_(0, dst.unsqueeze(1).expand_as(messages), messages)
        
        # Add self-loop + compute mean
        degree = torch.zeros(num_nodes)
        degree.scatter_add_(0, dst, torch.ones(len(dst)))
        degree = degree + 1  # +1 for self-loop
        agg = (agg + X) / degree.unsqueeze(1)
        
        # Update
        return F.relu(self.linear(agg))


# Verify both produce same results (with same weights)
layer_fast = GNNLayerFast(3, 4)
layer_fast.load_state_dict(layer.state_dict())  # copy W and b from the loop version

output_fast = layer_fast(X, edge_index)
print(f"Outputs match: {torch.allclose(output, output_fast, atol=1e-6)}")
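
`scatter_add_` is not the only way to group messages by destination. As a sketch, the same aggregation can be written with `Tensor.index_add_` and `torch.bincount`, which avoids expanding the index tensor by hand. The graph is rebuilt inside the block so it runs on its own.

```python
import torch

# Same 4-node graph as above, rebuilt here so the block is self-contained
src = torch.tensor([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
dst = torch.tensor([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])
X = torch.tensor([
    [1.0, 0.0, 0.5],
    [0.0, 1.0, 0.3],
    [0.5, 0.5, 0.8],
    [0.2, 0.8, 0.1],
])
num_nodes = X.shape[0]

messages = X[src]  # (num_edges, features)

# index_add_ adds row k of `messages` into row dst[k] of `agg`
agg = torch.zeros_like(X)
agg.index_add_(0, dst, messages)

# bincount counts how many edges point at each node (the in-degree)
in_degree = torch.bincount(dst, minlength=num_nodes)

# Mean over neighbors plus self, matching what GNNLayerFast computes before the linear map
mean_agg = (agg + X) / (in_degree + 1).unsqueeze(1)
print(mean_agg)
```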

### Let's visualize what just happened

```
BEFORE (raw features):              AFTER 1 GNN layer:

Node 0: [temp=1.0, pres=0.0, vel=0.5]    Node 0: [?, ?, ?, ?]  ← now 4-dim,
Node 1: [temp=0.0, pres=1.0, vel=0.3]    Node 1: [?, ?, ?, ?]     encodes info
Node 2: [temp=0.5, pres=0.5, vel=0.8]    Node 2: [?, ?, ?, ?]     from neighbors
Node 3: [temp=0.2, pres=0.8, vel=0.1]    Node 3: [?, ?, ?, ?]

What happened to Node 0:
  1. Gathered features from neighbors {1, 2} and self {0}
  2. Averaged them: mean([1,0,.5], [0,1,.3], [.5,.5,.8]) = [0.5, 0.5, 0.53]
  3. Applied W @ [0.5, 0.5, 0.53] + b, then ReLU
  4. Node 0 now "knows about" the temperature/pressure/velocity of its neighbors
```
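
The Node 0 walkthrough can be checked numerically. The sketch below uses NumPy and the matrix form of mean aggregation with self-loops; `W` and `b` are made-up values chosen for readability, not the randomly initialized weights of the layer above.

```python
import numpy as np

X = np.array([
    [1.0, 0.0, 0.5],
    [0.0, 1.0, 0.3],
    [0.5, 0.5, 0.8],
    [0.2, 0.8, 0.1],
])
A = np.array([
    [0, 1, 1, 0],
    [1, 0, 1, 1],
    [1, 1, 0, 1],
    [0, 1, 1, 0],
], dtype=float)

A_hat = A + np.eye(4)                    # add self-loops
deg = A_hat.sum(axis=1, keepdims=True)   # neighbors + self, per node
H = (A_hat @ X) / deg                    # row i = mean over N(i) plus i itself

# Illustrative weights (assumed values, not trained or sampled)
W = np.array([
    [1.0, -1.0, 0.0],
    [0.0, 1.0, -1.0],
    [-1.0, 0.0, 1.0],
    [1.0, 1.0, 1.0],
])
b = np.zeros(4)

out = np.maximum(H @ W.T + b, 0.0)       # linear map, then ReLU

print("mean for Node 0:", H[0])
print("output for Node 0:", out[0])
```

Negative pre-activations are clipped to zero by the ReLU, so some of the 4 output features can be exactly 0 for a given node.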

---

<a id='5'></a>
# 5. Full GNN Model: Node Classification Example

Let's build a complete GNN that actually **trains** on a task. We'll create a small graph where node color (label) depends on neighborhood structure, and train the GNN to predict it.

In [None]:
# =============================================
# CREATE A SMALL DATASET
# =============================================
# 
# Graph: two clusters connected by a bridge
#
#   Cluster A (label=0)       Cluster B (label=1)
#     0 ── 1                    4 ── 5
#     │  ╲ │                    │  ╲ │
#     2 ── 3 ──── bridge ──── 6 ── 7
#
# Task: predict which cluster each node belongs to

torch.manual_seed(42)

# Node features: random (the GNN must learn from STRUCTURE, not features)
X_train = torch.randn(8, 3)

# Labels: 0 for cluster A (nodes 0-3), 1 for cluster B (nodes 4-7)
y_train = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

# Edges (undirected → list both directions)
edge_pairs = [
    # Cluster A (dense connections)
    (0,1),(1,0), (0,2),(2,0), (0,3),(3,0), (1,3),(3,1), (2,3),(3,2),
    # Bridge
    (3,6),(6,3),
    # Cluster B (dense connections)
    (4,5),(5,4), (4,6),(6,4), (4,7),(7,4), (5,7),(7,5), (6,7),(7,6),
]
edge_index_train = torch.tensor([[e[0] for e in edge_pairs],
                                  [e[1] for e in edge_pairs]])

print(f"Graph: {X_train.shape[0]} nodes, {len(edge_pairs)} directed edges")
print(f"Labels: {y_train.tolist()}")
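
Hand-written edge lists are easy to get wrong, so a quick sanity check can help before training. The sketch below (plain Python, with an illustrative `label` helper) checks that every edge has its reverse, that nothing is duplicated, and that only one undirected edge crosses between the clusters.

```python
edge_pairs = [
    # Cluster A
    (0, 1), (1, 0), (0, 2), (2, 0), (0, 3), (3, 0), (1, 3), (3, 1), (2, 3), (3, 2),
    # Bridge
    (3, 6), (6, 3),
    # Cluster B
    (4, 5), (5, 4), (4, 6), (6, 4), (4, 7), (7, 4), (5, 7), (7, 5), (6, 7), (7, 6),
]


def label(node):
    """Cluster label used in this example: nodes 0-3 -> 0, nodes 4-7 -> 1."""
    return 0 if node < 4 else 1


edge_set = set(edge_pairs)
no_duplicates = len(edge_set) == len(edge_pairs)
symmetric = all((d, s) in edge_set for s, d in edge_pairs)
no_self_loops = all(s != d for s, d in edge_pairs)

# Undirected edges whose endpoints have different labels
bridges = sorted({tuple(sorted(e)) for e in edge_pairs if label(e[0]) != label(e[1])})

print(no_duplicates, symmetric, no_self_loops, bridges)
```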

In [None]:
# =============================================
# FULL GNN MODEL (2 layers + classifier)
# =============================================

class GNN(nn.Module):
    def __init__(self, in_features, hidden_dim, num_classes):
        super().__init__()
        self.layer1 = GNNLayerFast(in_features, hidden_dim)  # Our custom layer!
        self.layer2 = GNNLayerFast(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, X, edge_index):
        # Layer 1: each node learns about 1-hop neighbors
        h = self.layer1(X, edge_index)  # (N, hidden_dim)
        
        # Layer 2: each node now learns about 2-hop neighbors
        h = self.layer2(h, edge_index)  # (N, hidden_dim)
        
        # Classify each node
        out = self.classifier(h)  # (N, num_classes)
        return out


model = GNN(in_features=3, hidden_dim=16, num_classes=2)
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters())}")
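
The parameter count printed above can be derived by hand, since every learnable piece of this model is a dense layer with a weight matrix and a bias vector. A short sketch of the arithmetic (`linear_params` is an illustrative helper):

```python
def linear_params(in_features, out_features):
    """Parameters in a dense layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


in_features, hidden_dim, num_classes = 3, 16, 2

counts = {
    "layer1.linear": linear_params(in_features, hidden_dim),   # 3*16 + 16
    "layer2.linear": linear_params(hidden_dim, hidden_dim),    # 16*16 + 16
    "classifier": linear_params(hidden_dim, num_classes),      # 16*2 + 2
}
total = sum(counts.values())

print(counts)
print("total:", total)
```

The count depends only on the feature sizes, not on the number of nodes or edges: the same weights are shared across every node in the graph.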

In [None]:
# =============================================
# TRAIN THE GNN
# =============================================

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

print("Training...")
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    
    # Forward pass: GNN predicts label for EVERY node at once
    logits = model(X_train, edge_index_train)  # (8, 2)
    
    # Loss: compare predictions to ground truth
    loss = loss_fn(logits, y_train)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    if epoch % 50 == 0 or epoch == 199:
        preds = logits.argmax(dim=1)
        acc = (preds == y_train).float().mean()
        print(f"  Epoch {epoch:3d}: loss={loss.item():.4f}, acc={acc:.2f}")

# Final predictions
model.eval()
with torch.no_grad():
    preds = model(X_train, edge_index_train).argmax(dim=1)
    
print(f"\nFinal predictions: {preds.tolist()}")
print(f"Ground truth:      {y_train.tolist()}")
print(f"Correct:           {['✓' if p==y else '✗' for p,y in zip(preds, y_train)]}")
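
For readers wondering what the loss value means, here is a NumPy sketch of what cross-entropy computes with the default mean reduction: a softmax over the class logits, followed by the average negative log-probability of the correct class. The function name and test logits are illustrative.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean negative log-probability of the correct class."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


labels = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# Equal logits for both classes: the model is guessing, loss = ln(2)
uniform_logits = np.zeros((8, 2))
loss_uniform = cross_entropy(uniform_logits, labels)

# Confident and correct logits: loss close to 0
confident_logits = np.where(labels[:, None] == np.arange(2), 10.0, -10.0)
loss_confident = cross_entropy(confident_logits, labels)

print(f"guessing:  {loss_uniform:.4f}")
print(f"confident: {loss_confident:.2e}")
```

With two classes, a loss near 0.69 therefore means the model is no better than a coin flip, and training should drive it well below that.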

### Why this works

The GNN can separate the two clusters because message passing mixes features along edges:
- Nodes in cluster A are densely connected to each other
- Nodes in cluster B are densely connected to each other
- Only 1 edge bridges the two clusters

After 2 message-passing layers, each node has aggregated features from its 2-hop neighborhood, so nodes in the same cluster end up with similar representations that the classifier can separate. Keep in mind that we train and evaluate on the same 8 nodes, so this shows the model can fit this graph, not that it generalizes to unseen ones.

**MeshGraphNets builds on the same message-passing idea**, but instead of predicting a class label, it predicts the **next-step physics** (acceleration, velocity change) at each mesh node.
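
To see what a 2-hop neighborhood covers in this graph, the sketch below lists, for each node, every node it can reach in at most 2 hops. It uses plain Python sets; `k_hop` is an illustrative helper.

```python
edge_pairs = [
    (0, 1), (1, 0), (0, 2), (2, 0), (0, 3), (3, 0), (1, 3), (3, 1), (2, 3), (3, 2),
    (3, 6), (6, 3),
    (4, 5), (5, 4), (4, 6), (6, 4), (4, 7), (7, 4), (5, 7), (7, 5), (6, 7), (7, 6),
]

neighbors = {n: set() for n in range(8)}
for s, d in edge_pairs:
    neighbors[s].add(d)


def k_hop(node, k):
    """All nodes reachable from `node` in at most k hops (including itself)."""
    seen = {node}
    frontier = {node}
    for _ in range(k):
        frontier = {m for n in frontier for m in neighbors[n]} - seen
        seen |= frontier
    return seen


for n in range(8):
    print(f"Node {n}: {sorted(k_hop(n, 2))}")
```

Most nodes see mainly their own cluster plus the bridge endpoint on the other side. The bridge nodes 3 and 6 see much more of the graph, so their aggregated features are the most mixed between the two clusters.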

---

<a id='6'></a>
# 6. Edge Features — What MeshGraphNets Adds

Our basic GNN only uses **node features**. But in MeshGraphNets, **edges also have features** (relative position, distance, etc.). This is crucial:

| Basic GNN | MeshGraphNets |
|-----------|---------------|
| Message = neighbor's node features | Message = f(sender features, receiver features, **edge features**) |
| Edge just means "connected" | Edge carries **geometric information** (relative displacement, distance) |

Let's build a GNN layer with edge features.

In [None]:
# =============================================
# GNN LAYER WITH EDGE FEATURES
# (A simplified version of the block MeshGraphNets uses)
# =============================================

class EdgeGNNLayer(nn.Module):
    """GNN layer that uses edge features in message computation.
    
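    Message:   m_ij = MLP_edge([h_j; h_i; e_ij])   (j = sender, i = receiver)
    Aggregate: agg_i = sum(m_ij for j in neighbors(i))
    Update:    h_i' = MLP_node([h_i; agg_i])
    
    This is a simplified version of the MeshGraphNets processor block;
    the full model also updates the edge features in every round.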
    """
    def __init__(self, node_features, edge_features, hidden_dim):
        super().__init__()
        # Edge MLP: takes [src_node; dst_node; edge_feat] → message
        self.edge_mlp = nn.Sequential(
            nn.Linear(node_features * 2 + edge_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        
        # Node MLP: takes [node_feat; aggregated_messages] → updated node
        self.node_mlp = nn.Sequential(
            nn.Linear(node_features + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        
        # LayerNorm (used in MeshGraphNets)
        self.norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, X, edge_index, edge_attr):
        """
        X:         (num_nodes, node_features)
        edge_index: (2, num_edges)
        edge_attr:  (num_edges, edge_features) — e.g., relative position, distance
        """
        src, dst = edge_index[0], edge_index[1]
        num_nodes = X.shape[0]
        
        # ----- MESSAGE -----
        # Concatenate: [sender_features, receiver_features, edge_features]
        edge_input = torch.cat([X[src], X[dst], edge_attr], dim=1)
        messages = self.edge_mlp(edge_input)  # (num_edges, hidden_dim)
        
        # ----- AGGREGATE (sum) -----
        agg = torch.zeros(num_nodes, messages.shape[1])
        agg.scatter_add_(0, dst.unsqueeze(1).expand_as(messages), messages)
        
        # ----- UPDATE -----
        node_input = torch.cat([X, agg], dim=1)  # [own features; aggregated]
        out = self.node_mlp(node_input)
        
        # Residual connection + LayerNorm (MeshGraphNets style)
        # out = self.norm(X + out)  # Would need matching dims in practice
        
        return out


print("EdgeGNNLayer — the building block of MeshGraphNets")
print("=" * 55)
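
The residual line in the layer above is commented out because a residual needs the input and output sizes to match. One way this is commonly arranged, assumed here rather than taken from the notebook, is to encode the nodes to `hidden_dim` first and only then apply the update. The sketch below shows such a node update; the class name `ResidualNodeUpdate` and the placement of `LayerNorm` at the end of the MLP are illustrative choices.

```python
import torch
import torch.nn as nn


class ResidualNodeUpdate(nn.Module):
    """Hypothetical node update with a residual connection.

    Assumes node features already have size `hidden_dim`, so the
    input and the MLP output can be added together.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),  # normalize the update before adding it
        )

    def forward(self, h, agg):
        # h, agg: (num_nodes, hidden_dim)
        return h + self.mlp(torch.cat([h, agg], dim=1))


torch.manual_seed(0)
h = torch.randn(4, 8)    # already-encoded node features (random stand-ins)
agg = torch.randn(4, 8)  # aggregated messages (random stand-ins)

block = ResidualNodeUpdate(hidden_dim=8)
h_new = block(h, agg)
print(h_new.shape)
```

Because the block only adds an update to `h`, stacking many of them keeps a direct path from the first layer to the last, which tends to make deep stacks easier to train.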

In [None]:
# =============================================
# DEMO: Mesh with spatial edge features
# =============================================

# Imagine a 2D mesh with 4 nodes at these positions:
#   Node 0: (0, 0)     Node 1: (1, 0)
#   Node 2: (0, 1)     Node 3: (1, 1)

positions = torch.tensor([
    [0.0, 0.0],  # Node 0
    [1.0, 0.0],  # Node 1
    [0.0, 1.0],  # Node 2
    [1.0, 1.0],  # Node 3
])

# Node features: e.g., temperature at each node
X_mesh = torch.tensor([
    [100.0],  # Node 0: hot
    [50.0],   # Node 1: warm
    [50.0],   # Node 2: warm
    [0.0],    # Node 3: cold
])

# Edges (same as before)
edge_index_mesh = torch.tensor([src, dst])

# ===== THE KEY PART: Edge features =====
# For each edge (i→j), compute RELATIVE displacement and distance
# This is exactly what MeshGraphNets does!

src_pos = positions[edge_index_mesh[0]]  # (num_edges, 2)
dst_pos = positions[edge_index_mesh[1]]  # (num_edges, 2)

relative_disp = dst_pos - src_pos           # x_j - x_i  (relative position)
distance = torch.norm(relative_disp, dim=1, keepdim=True)  # |x_j - x_i|

# Edge features = [relative_displacement, distance]
edge_attr_mesh = torch.cat([relative_disp, distance], dim=1)  # (num_edges, 3)

print("Edge features (what MeshGraphNets uses):")
print(f"{'Edge':<10} {'Rel. Disp (x,y)':<25} {'Distance':<10}")
print("-" * 45)
for k, (s, d) in enumerate(edges):  # the 4-node mesh edges, in edge_index_mesh order
    print(f"{s} → {d:<6} {str(edge_attr_mesh[k][:2].tolist()):<25} {edge_attr_mesh[k][2].item():.3f}")
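
A useful property of these relative features is that they do not change when the whole mesh is moved. The NumPy sketch below recomputes the same edge features for the mesh and for a copy shifted by (100, 100); `edge_features` is an illustrative helper.

```python
import numpy as np

positions = np.array([
    [0.0, 0.0],  # Node 0
    [1.0, 0.0],  # Node 1
    [0.0, 1.0],  # Node 2
    [1.0, 1.0],  # Node 3
])
src = np.array([0, 1, 0, 2, 1, 2, 1, 3, 2, 3])
dst = np.array([1, 0, 2, 0, 2, 1, 3, 1, 3, 2])


def edge_features(pos):
    """[relative displacement, distance] for every edge src -> dst."""
    rel = pos[dst] - pos[src]
    dist = np.linalg.norm(rel, axis=1, keepdims=True)
    return np.concatenate([rel, dist], axis=1)


feats = edge_features(positions)
feats_shifted = edge_features(positions + np.array([100.0, 100.0]))

print("unchanged after shift:", np.allclose(feats, feats_shifted))
print("edge 1 -> 2 (the diagonal):", feats[4])
```

Absolute positions would differ by 100 in every coordinate between the two meshes, so a model fed absolute positions would have to learn the same interaction separately for each location.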

In [None]:
# =============================================
# RUN THE EDGE-AWARE GNN LAYER
# =============================================

edge_layer = EdgeGNNLayer(
    node_features=1,   # Just temperature
    edge_features=3,   # [rel_x, rel_y, distance]
    hidden_dim=8
)

out = edge_layer(X_mesh, edge_index_mesh, edge_attr_mesh)

print(f"Input:  {X_mesh.shape}  — 4 nodes, 1 feature (temperature)")
print(f"Output: {out.shape}  — 4 nodes, 8 features (learned representation)")
print(f"\nEach node's representation now encodes:")
print(f"  - Its own temperature")
print(f"  - Neighbors' temperatures")
print(f"  - WHERE those neighbors are (via edge features)")
print(f"  - HOW FAR they are (via distance)")
print(f"\nThis is why MeshGraphNets can learn physics — the edge features")
print(f"encode the spatial relationships that govern physical interactions.")

---

<a id='7'></a>
# 7. Summary: The GNN Mental Model

## The 3-Step Recipe (memorize this)

```
Every GNN layer does:
┌─────────────────────────────────────────────────────────────────┐
│                                                                 │
│  1. MESSAGE:    For each edge j → i, compute a message          │
│                 m_ij = f(h_i, h_j, e_ij)                        │
│                         ↑     ↑     ↑                           │
│                   receiver  sender  edge                        │
│                       node   node    features                   │
│                                                                 │
│  2. AGGREGATE:  For each node i, combine incoming messages      │
│                 agg_i = Σ m_ij   (or mean, or max)              │
│                         j∈N(i)                                  │
│                                                                 │
│  3. UPDATE:     Compute new node feature                        │
│                 h_i' = g(h_i, agg_i)                            │
│                         ↑      ↑                                │
│                       old    aggregated                         │
│                       self   neighbor info                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```
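
The recipe can be written as one generic function with the message and update steps passed in as arguments. This is a minimal NumPy sketch with sum aggregation; `message_passing` and its parameters are illustrative names. Plugging in the simplest choices reproduces the by-hand result from Section 3.

```python
import numpy as np


def message_passing(h, edges, edge_attr, message_fn, update_fn, msg_dim):
    """One round of message passing with sum aggregation.

    h:         (num_nodes, d) array of node features
    edges:     list of (j, i) pairs meaning "j sends to i"
    edge_attr: one entry per edge (may be None if message_fn ignores it)
    """
    agg = np.zeros((h.shape[0], msg_dim))
    for (j, i), e in zip(edges, edge_attr):
        agg[i] += message_fn(h[i], h[j], e)       # 1. MESSAGE, 2. AGGREGATE
    return np.stack([update_fn(h[i], agg[i])      # 3. UPDATE
                     for i in range(h.shape[0])])


X = np.array([
    [1.0, 0.0, 0.5],
    [0.0, 1.0, 0.3],
    [0.5, 0.5, 0.8],
    [0.2, 0.8, 0.1],
])
edges = [(0, 1), (1, 0), (0, 2), (2, 0), (1, 2),
         (2, 1), (1, 3), (3, 1), (2, 3), (3, 2)]

# Simplest choices: the message is the sender's features, the update adds them
out = message_passing(
    X, edges, edge_attr=[None] * len(edges),
    message_fn=lambda h_i, h_j, e: h_j,
    update_fn=lambda h_i, agg_i: h_i + agg_i,
    msg_dim=X.shape[1],
)
print(out)
```

Swapping `message_fn` and `update_fn` for small neural networks, and passing real edge features, turns the same loop into the edge-aware layer from Section 6.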

## How GNN Variants Differ

| Variant | Message function | Aggregation | Key difference |
|---------|-----------------|-------------|----------------|
| **GCN** | W · h_j (just transform neighbor) | Normalized sum | Simplest |
| **GAT** | α_ij · W · h_j (attention-weighted) | Weighted sum | Learns which neighbors matter more |
| **GraphSAGE** | W · h_j | Mean, Max, or LSTM | Supports sampling for large graphs |
| **MPNN** | MLP([h_i; h_j; e_ij]) | Sum | Uses edge features (general) |
| **MeshGraphNets** | MLP([h_i; h_j; e_ij]) + residual + LayerNorm | Sum | MPNN + dual-space edges + noise training |
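
The "normalized sum" in the GCN row is commonly implemented as the symmetric normalization D^(-1/2) (A + I) D^(-1/2), where D holds the node degrees including self-loops. A NumPy sketch on the 4-node graph from Section 2:

```python
import numpy as np

A = np.array([
    [0, 1, 1, 0],
    [1, 0, 1, 1],
    [1, 1, 0, 1],
    [0, 1, 1, 0],
], dtype=float)

A_hat = A + np.eye(4)                  # add self-loops
d = A_hat.sum(axis=1)                  # degrees including the self-loop
D_inv_sqrt = np.diag(1.0 / np.sqrt(d))

# Entry (i, j) becomes 1 / sqrt(d_i * d_j) wherever an edge (or self-loop) exists
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt

np.set_printoptions(precision=3, suppress=True)
print(A_norm)
```

Compared with a plain mean, which divides only by the receiver's degree, this weighting also shrinks messages coming from high-degree senders.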

## From Basic GNN → MeshGraphNets

```
Basic GNN (what we built):
  ✓ Node features
  ✓ Message passing
  ✓ Learnable weights

MeshGraphNets adds:
  + Edge features (relative displacement, distance)
  + Two edge types (mesh-space + world-space)
  + 15 message-passing rounds (not 2)
  + Residual connections + LayerNorm
  + Predicts derivatives (not states)
  + Noise injection during training
  + Separate encoder/processor/decoder
```
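
"Predicts derivatives (not states)" means the network output is integrated forward in time to get the next state. The sketch below shows the idea with a stand-in function in place of a trained model; `predict_acceleration`, the spring dynamics, and the semi-implicit Euler integrator are all assumptions made for illustration, not details taken from this notebook.

```python
import numpy as np


def predict_acceleration(position, velocity):
    """Stand-in for a trained model (hypothetical): a unit spring pulling toward 0."""
    return -position


def rollout(position, velocity, dt, steps):
    """Integrate predicted derivatives forward in time (semi-implicit Euler)."""
    trajectory = [position.copy()]
    for _ in range(steps):
        acc = predict_acceleration(position, velocity)
        velocity = velocity + dt * acc        # update velocity from the prediction
        position = position + dt * velocity   # then update position
        trajectory.append(position.copy())
    return np.stack(trajectory)


traj = rollout(np.array([1.0]), np.array([0.0]), dt=0.1, steps=3)
print(traj.ravel())
```

Each predicted state becomes the input for the next step, so small prediction errors can compound over a long rollout. Noise injection during training is aimed at making the model more robust to that kind of drift.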

## Key Intuitions

1. **GNN layers ≈ diffusion**: Each layer spreads information 1 hop. After L layers, each node knows about its L-hop neighborhood.

2. **Why it works for physics**: Physical interactions are LOCAL — a node's next-state depends on its nearby neighbors. This is exactly what message passing computes.

3. **Why RELATIVE edge features**: If you use `x_j - x_i` instead of absolute positions, the model learns "how neighbors interact" — which is the same everywhere in space. A spring behaves the same whether it's at coordinates (0,0) or (100,100).

4. **Why SUM not MEAN for physics**: Sum preserves the signal that "this node has 10 neighbors pushing it" vs "1 neighbor pushing it." Mean would lose that distinction. Physical forces ADD up.
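
The sum-versus-mean point is easy to check with numbers. In this NumPy sketch every neighbor sends the same illustrative "push" message, and we compare a node with 1 neighbor against a node with 10.

```python
import numpy as np

push = np.array([1.0, 0.0])  # every neighbor sends the same message

one_neighbor = np.stack([push] * 1)     # shape (1, 2)
ten_neighbors = np.stack([push] * 10)   # shape (10, 2)

sum_1, sum_10 = one_neighbor.sum(axis=0), ten_neighbors.sum(axis=0)
mean_1, mean_10 = one_neighbor.mean(axis=0), ten_neighbors.mean(axis=0)

print("sum: ", sum_1, "vs", sum_10)    # differs by a factor of 10
print("mean:", mean_1, "vs", mean_10)  # identical
```

With mean aggregation the two nodes receive the same input and cannot be told apart, while sum aggregation keeps the total.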