# How to: Create and Use a ModelGraph

A `ModelGraph` is the computational backbone of a ModularML `Experiment`. It organizes
one or more `ModelNode`s (and optionally `MergeNode`s) into a directed acyclic graph (DAG)
that handles:

- **Shape inference:** Automatically determines input/output shapes for every node during `build()`.
- **Topological execution:** Ensures nodes execute in dependency order during forward, training, and evaluation passes.
- **Global optimizer management:** Optionally shares a single optimizer across all trainable nodes for end-to-end gradient flow.
- **Freeze / unfreeze control:** Selectively disable training for subsets of the graph.
- **Graph mutation:** Add, remove, replace, or insert nodes dynamically.
- **Serialization & checkpointing:** Save and restore the full graph structure and learned weights.

```
FeatureSet ──> ModelNode("Encoder") ──> ModelNode("Regressor")

FeatureSet ─┬─> ModelNode("A") ──┐
            │                    ├─> ConcatNode ──> ModelNode("Head")
            └─> ModelNode("B") ──┘
```

This notebook covers:

1. [Creating a ModelGraph](#1-creating-a-modelgraph)
2. [Building the Graph](#2-building-the-graph)
3. [Graph Properties](#3-graph-properties)
4. [Forward Pass](#4-forward-pass)
5. [Graph Mutation](#5-graph-mutation)
6. [Freezing and Unfreezing](#6-freezing-and-unfreezing)
7. [Optimizer Management](#7-optimizer-management)
8. [Serialization](#8-serialization)
9. [Checkpointing](#9-checkpointing)
10. [Summary](#10-summary)

In [1]:
import numpy as np
import torch

from modularml import (
    ConcatNode,
    Experiment,
    FeatureSet,
    ModelGraph,
    ModelNode,
    Optimizer,
)
from modularml.models.torch import SequentialMLP

# Use the "overwrite" registration policy so nodes can be re-created under the
# same label without raising a warning on every overwrite.
exp = Experiment(label="create_modelgraph", registration_policy="overwrite")

We'll use a simple synthetic dataset throughout this notebook: 500 samples of a 10-point feature with a scalar target.

In [2]:
rng = np.random.default_rng(42)

fs = FeatureSet.from_dict(
    label="SensorData",
    data={
        "voltage": list(rng.standard_normal((500, 10))),
        "soh": list(rng.standard_normal((500, 1))),
    },
    feature_keys="voltage",
    target_keys="soh",
)
fs_ref = fs.reference(features="voltage", targets="soh")
print(fs)

FeatureSet(label='SensorData', n_samples=500)


---

## 1. Creating a ModelGraph

A `ModelGraph` is constructed from a list of `GraphNode` instances and an optional shared `Optimizer`.

```python
    ModelGraph(
        nodes: list[str | GraphNode] | None,
        optimizer: Optimizer | None = None,
        label: str = "model-graph",
    )
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nodes` | `list[str \| GraphNode] \| None` | (required) | Nodes comprising the graph. Pass node instances or their string labels. If `None`, all registered `GraphNode`s in the active `ExperimentContext` are used. |
| `optimizer` | `Optimizer \| None` | `None` | A shared optimizer for end-to-end training. If provided, all trainable nodes must share the same backend. |
| `label` | `str` | `"model-graph"` | A human-readable label for this graph. |

### 1.1 Simple Linear Graph

The simplest graph is a linear chain: `FeatureSet -> ModelNode`.

In [3]:
node = ModelNode(
    label="SimpleMLP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
)

mg = ModelGraph(
    nodes=[node],
    optimizer=Optimizer(opt="adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
    label="simple-graph",
)
print(f"Label:  {mg.label}")
print(f"Nodes:  {mg.node_labels}")
print(f"Built:  {mg.is_built}")

Label:  simple-graph
Nodes:  {'SimpleMLP'}
Built:  False


### 1.2 Multi-Node Chain

Chain multiple `ModelNode`s by passing one as the `upstream_ref` of the next.

`ModelGraph` provides a `visualize()` method, which we'll use to show the topology after each change.

In [4]:
encoder = ModelNode(
    label="Encoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
)

regressor = ModelNode(
    label="Regressor",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=16),
    upstream_ref=encoder,
)

mg_chain = ModelGraph(
    nodes=[encoder, regressor],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_chain.node_labels}")

mg_chain.visualize()

Node labels: {'Encoder', 'Regressor'}


```mermaid
flowchart LR
	n2 e0@--> n0
	n0 e1@-->|"(1, 8)"| n1
	n1 e2@--> n3

	n0@{ label: "<b>ModelNode</b><br>'Encoder'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'Regressor'", shape: rect }
	n2@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n3@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::FeatureSet
	n3:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
```

### 1.3 Branching Graph with MergeNode

Use `ConcatNode` (a `MergeNode`) to combine outputs from parallel branches.

```
FeatureSet ─┬─> EncoderA ──┐
            │              ├─> ConcatNode ──> Head
            └─> EncoderB ──┘
```

In [5]:
enc_a = ModelNode(
    label="EncoderA",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
enc_b = ModelNode(
    label="EncoderB",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

merge = ConcatNode(
    label="Merge",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets="first",
)

head = ModelNode(
    label="Head",
    model=SequentialMLP(n_layers=1, hidden_dim=8),
    upstream_ref=merge,
)

mg_branch = ModelGraph(
    nodes=[enc_a, enc_b, merge, head],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_branch.node_labels}")

mg_branch.visualize()

Node labels: {'Merge', 'EncoderA', 'Head', 'EncoderB'}


```mermaid
flowchart LR
	n4 e0@--> n0
	n4 e1@--> n1
	n1 e2@-->|"(1, 8)"| n2
	n0 e3@-->|"(1, 4)"| n2
	n2 e4@--> n3
	n3 e5@--> n5

	n0@{ label: "<b>ModelNode</b><br>'EncoderB'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'EncoderA'", shape: rect }
	n2@{ label: "<b>ConcatNode</b><br>'Merge'", shape: rect }
	n3@{ label: "<b>ModelNode</b><br>'Head'", shape: rect }
	n4@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n5@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::MergeNode
	n3:::ModelNode
	n4:::FeatureSet
	n5:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef MergeNode stroke-width: 2px, stroke-dasharray: 0, stroke: #565656, fill: #B1B1B1, color:#000000;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
	class e4 NoAnimation
	class e5 NoAnimation
```

### 1.4 Referencing Nodes by Label

Instead of passing node instances, you can pass their string labels. The graph will look them up in the active `ExperimentContext`.

In [6]:
mg_by_label = ModelGraph(
    nodes=["EncoderA", "EncoderB", "Merge", "Head"],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Node labels: {mg_by_label.node_labels}")

mg_by_label.visualize()

Node labels: {'Merge', 'EncoderA', 'Head', 'EncoderB'}


```mermaid
flowchart LR
	n4 e0@--> n0
	n4 e1@--> n1
	n1 e2@-->|"(1, 8)"| n2
	n0 e3@-->|"(1, 4)"| n2
	n2 e4@--> n3
	n3 e5@--> n5

	n0@{ label: "<b>ModelNode</b><br>'EncoderB'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'EncoderA'", shape: rect }
	n2@{ label: "<b>ConcatNode</b><br>'Merge'", shape: rect }
	n3@{ label: "<b>ModelNode</b><br>'Head'", shape: rect }
	n4@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n5@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::MergeNode
	n3:::ModelNode
	n4:::FeatureSet
	n5:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef MergeNode stroke-width: 2px, stroke-dasharray: 0, stroke: #565656, fill: #B1B1B1, color:#000000;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
	class e4 NoAnimation
	class e5 NoAnimation
```

### 1.5 Without a Global Optimizer

If no global optimizer is provided, each `ModelNode` must define its own local optimizer. This is useful when different nodes need different optimizers or learning rates (stage-wise training).

In [7]:
node_with_opt = ModelNode(
    label="StageWiseMLP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
    optimizer=Optimizer("adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
)

mg_no_global = ModelGraph(
    nodes=[node_with_opt],
    optimizer=None,
)
print(f"Global optimizer: {mg_no_global.backend}")

Global optimizer: None


---

## 2. Building the Graph

`ModelGraph.build()` performs the following steps in topological order:

1. **Validates** the DAG structure (no cycles, all upstream references resolved).
2. **Infers** input and output shapes for each node from upstream outputs and FeatureSet shapes.
3. **Builds** each node's underlying model (lazy initialization).
4. **Builds** the global optimizer (if provided) with parameters from all trainable nodes.

```python
    ModelGraph.build(*, force: bool = False)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `force` | `bool` | `False` | If `True`, rebuilds even if the graph is already built. |
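
The validation and ordering steps above can be illustrated with the standard library alone. This is a minimal sketch, not ModularML's implementation: the `upstream` map and the `execution_order` helper are hypothetical names, and `"SensorData"` merely stands in for the `FeatureSet`.

```python
from graphlib import CycleError, TopologicalSorter

# Hypothetical adjacency map: node label -> labels of its upstream sources.
# "SensorData" stands in for the FeatureSet feeding the two encoders.
upstream = {
    "EncoderA": ["SensorData"],
    "EncoderB": ["SensorData"],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}


def execution_order(upstream_map):
    """Return labels in dependency order, or raise ValueError on a cycle."""
    try:
        return list(TopologicalSorter(upstream_map).static_order())
    except CycleError as err:
        # The second element of args holds the detected cycle.
        raise ValueError(f"Graph contains a cycle: {err.args[1]}") from err


order = execution_order(upstream)
print(order)
```

Every node appears after all of its upstream sources, so `Merge` always follows both encoders and `Head` always comes last. The relative order of the two encoders is not significant.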

In [8]:
mg_branch.build()
print(f"Built: {mg_branch.is_built}")

for node in mg_branch.nodes.values():
    in_shape = node.input_shape if hasattr(node, "input_shape") else list(node.input_shapes.values())
    out_shape = getattr(node, "output_shape", None)
    print(f"  {node.label}: {in_shape} -> {out_shape}")

mg_branch.visualize()  # Note how all edges now show the input/output shapes

Built: True
  EncoderA: (1, 10) -> (1, 8)
  EncoderB: (1, 10) -> (1, 4)
  Merge: [(1, 8), (1, 4)] -> (1, 12)
  Head: (1, 12) -> (1, 1)


```mermaid
flowchart LR
	n4 e0@-->|"(1, 10)"| n0
	n4 e1@-->|"(1, 10)"| n1
	n1 e2@-->|"(1, 8)"| n2
	n0 e3@-->|"(1, 4)"| n2
	n2 e4@-->|"(1, 12)"| n3
	n3 e5@-->|"(1, 1)"| n5

	n0@{ label: "<b>ModelNode</b><br>'EncoderB'  &lt;torch&gt;", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'EncoderA'  &lt;torch&gt;", shape: rect }
	n2@{ label: "<b>ConcatNode</b><br>'Merge'", shape: rect }
	n3@{ label: "<b>ModelNode</b><br>'Head'  &lt;torch&gt;", shape: rect }
	n4@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n5@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::MergeNode
	n3:::ModelNode
	n4:::FeatureSet
	n5:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef MergeNode stroke-width: 2px, stroke-dasharray: 0, stroke: #565656, fill: #B1B1B1, color:#000000;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
	class e4 NoAnimation
	class e5 NoAnimation
```

### Shape Inference Details

During `build()`, shapes propagate through the graph as follows:

- **Head nodes** (inputs from a `FeatureSet`): Input shape is pulled directly from the referenced `FeatureSet` data.
- **Intermediate nodes**: Input shape equals the output shape of their upstream node.
- **Tail nodes** (no downstream consumers): If no `output_shape` is specified on the model, it defaults to the target shape propagated from the upstream `FeatureSet`.
- **MergeNodes**: Both feature and target output shapes are determined by a dummy forward pass through the merge logic.

You generally do not need to specify `input_shape` on your models, since `build()` infers it. Specifying `output_shape` is recommended for all non-tail nodes, because only tail nodes can fall back to the `FeatureSet` target shape.
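
To make the propagation rules concrete, here is a simplified sketch that reproduces the shapes printed for the branching graph. It is an illustration under assumptions, not the library's code: the `nodes` spec format and `infer_shapes` are hypothetical, and the merge shape is computed arithmetically rather than by the dummy forward pass that ModularML is described as using.

```python
# Hypothetical node specs: label -> (kind, upstream labels, declared output shape).
# A declared output shape of None means "infer it".
FEATURE_SHAPE = (1, 10)  # per-sample feature shape of the FeatureSet
TARGET_SHAPE = (1, 1)    # per-sample target shape of the FeatureSet

nodes = {
    "EncoderA": ("model", ["SensorData"], (1, 8)),
    "EncoderB": ("model", ["SensorData"], (1, 4)),
    "Merge": ("concat", ["EncoderA", "EncoderB"], None),
    "Head": ("model", ["Merge"], None),
}


def infer_shapes(nodes, order):
    """Walk nodes in dependency order and fill in (input, output) shapes."""
    out_shapes = {"SensorData": FEATURE_SHAPE}
    shapes = {}
    for label in order:
        kind, ups, declared = nodes[label]
        in_shapes = [out_shapes[u] for u in ups]
        if kind == "concat":
            # Concatenate along the last axis: leading dims must agree.
            lead = in_shapes[0][:-1]
            assert all(s[:-1] == lead for s in in_shapes)
            out = (*lead, sum(s[-1] for s in in_shapes))
            shapes[label] = (in_shapes, out)
        else:
            # Simplification: any model without a declared output shape falls
            # back to the target shape (the text describes this for tail nodes).
            out = declared if declared is not None else TARGET_SHAPE
            shapes[label] = (in_shapes[0], out)
        out_shapes[label] = out
    return shapes


shapes = infer_shapes(nodes, ["EncoderA", "EncoderB", "Merge", "Head"])
for label, (in_shape, out_shape) in shapes.items():
    print(f"{label}: {in_shape} -> {out_shape}")
```

`Head` declares no output shape, so it picks up `(1, 1)` from the target, while `Merge` ends up with `(1, 12)` because 8 + 4 features are joined along the last axis.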

### Rebuilding

Calling `build()` on an already-built graph is a no-op unless `force=True`.

In [9]:
# No-op (already built)
mg_branch.build()

# Force rebuild (e.g., after modifying graph structure)
mg_branch.build(force=True)
print(f"Rebuilt: {mg_branch.is_built}")

Rebuilt: True


---

## 3. Graph Properties

After building, the graph exposes several useful properties for inspecting its structure.

In [10]:
print(f"Label:       {mg_branch.label}")
print(f"Built:       {mg_branch.is_built}")
print(f"Backend:     {mg_branch.backend}")
print(f"Node labels: {mg_branch.node_labels}")

Label:       model-graph
Built:       True
Backend:     torch
Node labels: {'Merge', 'EncoderA', 'Head', 'EncoderB'}


### Head and Tail Nodes

- **Head nodes**: Nodes whose inputs come directly from a `FeatureSet` (no upstream `GraphNode` dependencies).
- **Tail nodes**: Nodes whose outputs are not consumed by any other node in the graph.
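
Both definitions reduce to simple set operations on the graph's edges. The sketch below assumes a hypothetical `upstream` map of labels and is only meant to show the idea; it does not use the ModularML API.

```python
# Hypothetical adjacency map: node label -> labels of its upstream sources.
upstream = {
    "EncoderA": ["SensorData"],
    "EncoderB": ["SensorData"],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}

graph_nodes = set(upstream)

# Head nodes: no upstream source is itself a node in the graph.
head_nodes = sorted(
    label for label, ups in upstream.items() if not graph_nodes.intersection(ups)
)

# Tail nodes: no other node lists them as an upstream source.
consumed = {u for ups in upstream.values() for u in ups}
tail_nodes = sorted(graph_nodes - consumed)

print(head_nodes)  # ['EncoderA', 'EncoderB']
print(tail_nodes)  # ['Head']
```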

In [11]:
print("Head nodes (receive FeatureSet data):")
for n in mg_branch.head_nodes.values():
    print(f"  - {n.label}")

print("\nTail nodes (produce final outputs):")
for n in mg_branch.tail_nodes.values():
    print(f"  - {n.label}")

Head nodes (receive FeatureSet data):
  - EncoderA
  - EncoderB

Tail nodes (produce final outputs):
  - Head


### Accessing Individual Nodes

Nodes are stored in a dict keyed by `node_id`. These IDs are globally unique, which is why a node can be referenced by its label, ID, or instance at any point in an Experiment.

You can iterate over nodes or access by label.

In [12]:
# All nodes (keyed by node_id)
for n_id, node in mg_branch.nodes.items():
    print(f"  {node.label}  (id={n_id[:8]}...)")

  EncoderA  (id=ece9d534...)
  EncoderB  (id=26c998cf...)
  Merge  (id=04f48fd3...)
  Head  (id=fc21af39...)


---

## 4. Forward Pass

Once built, you can execute a forward pass through the graph. The graph handles data routing between nodes in topological order.

```python
    ModelGraph.forward(
        inputs: dict[tuple[str, FeatureSetReference], TForward],
        *,
        active_nodes: list[str | GraphNode] | None = None,
    ) -> dict[str, TForward]
```


| Parameter | Type | Description |
|-----------|------|-------------|
| `inputs` | `dict` | Mapping of `(head_node_id, FeatureSetReference)` to input data. Each head node needs its upstream `FeatureSet` data. |
| `active_nodes` | `list \| None` | Optional subset of nodes to execute. Upstream dependencies are included automatically. If `None`, all nodes run. |

**Returns:** A dict mapping `node_id` to that node's output data for every executed node.

In [13]:
from modularml.core.data.sample_data import SampleData
from modularml.utils.data.data_format import DataFormat

# Prepare input data
fsv = fs_ref.resolve()
sample_data = SampleData(
    features=fsv.get_features(fmt=DataFormat.TORCH),
    targets=fsv.get_targets(fmt=DataFormat.TORCH),
)

# Build the inputs dict: (head_node_id, featureset_ref) -> data
inputs = {}
for n_id, node in mg_branch.head_nodes.items():
    for ref in node.get_upstream_refs():
        inputs[(n_id, ref)] = sample_data

print(f"Number of input entries: {len(inputs)}")

Number of input entries: 2


In [14]:
# Execute forward pass
with torch.no_grad():
    outputs = mg_branch.forward(inputs)

print("Outputs per node:")
for n_id, out in outputs.items():
    node_label = mg_branch.nodes[n_id].label
    print(f"  {node_label}: features={out.features.shape}")

Outputs per node:
  EncoderB: features=torch.Size([500, 1, 4])
  EncoderA: features=torch.Size([500, 1, 8])
  Merge: features=torch.Size([500, 1, 12])
  Head: features=torch.Size([500, 1, 1])


### Active Nodes

You can restrict the forward pass to a subset of the graph using `active_nodes`. All required upstream dependencies are automatically included.

If we mark only `Merge` as active, its upstream nodes (`EncoderA` and `EncoderB`) must still be executed.
The `Head` node, however, sits downstream of `Merge` and is skipped.
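
One way to picture how the executed subset is chosen is a backward walk from the active nodes. This is a sketch under assumptions: `required_nodes` and the `upstream` map are hypothetical and do not reflect how ModularML resolves dependencies internally.

```python
# Hypothetical adjacency map: node label -> labels of its upstream sources.
upstream = {
    "EncoderA": ["SensorData"],
    "EncoderB": ["SensorData"],
    "Merge": ["EncoderA", "EncoderB"],
    "Head": ["Merge"],
}


def required_nodes(upstream_map, active):
    """Return the active nodes plus every graph node they depend on."""
    required = set()
    stack = list(active)
    while stack:
        label = stack.pop()
        if label in required or label not in upstream_map:
            continue  # already visited, or an external source such as a FeatureSet
        required.add(label)
        stack.extend(upstream_map[label])
    return required


print(sorted(required_nodes(upstream, ["Merge"])))
# ['EncoderA', 'EncoderB', 'Merge']
```

`Head` is absent from the result because nothing in the active set depends on it.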

In [15]:
# Only execute Merge (plus its upstream dependencies, EncoderA and EncoderB)
with torch.no_grad():
    partial_outputs = mg_branch.forward(inputs, active_nodes=[merge])

print("Executed nodes:")
for n_id in partial_outputs:
    print(f"  - {mg_branch.nodes[n_id].label}")

Executed nodes:
  - EncoderB
  - EncoderA
  - Merge


---

## 5. Graph Mutation

`ModelGraph` provides several methods to modify the graph structure after creation. All mutation methods return `self` for method chaining.

After any structural change, the graph automatically revalidates connections and recomputes the topological order. You will need to call `build()` again to reinitialize shapes and optimizers.

### 5.1 `add_node()`

Add a new node to the graph. The node must already be connected to existing nodes via its `upstream_ref`.

In [16]:
# Start with a simple single-node graph
base_node = ModelNode(
    label="Base",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
mg_mut = ModelGraph(
    nodes=[base_node],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_mut.node_labels}")

# Add a downstream node
added_node = ModelNode(
    label="Added",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=base_node,
)
mg_mut.add_node(added_node)
print(f"After:  {mg_mut.node_labels}")

mg_mut.visualize()

Before: {'Base'}
After:  {'Base', 'Added'}


```mermaid
flowchart LR
	n2 e0@--> n0
	n0 e1@-->|"(1, 4)"| n1
	n1 e2@--> n3

	n0@{ label: "<b>ModelNode</b><br>'Base'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'Added'", shape: rect }
	n2@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n3@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::FeatureSet
	n3:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
```

### 5.2 `remove_node()`

Remove a node from the graph. Downstream nodes are reconnected to the removed node's upstream sources.

```
Given: A -> B -> C
Remove B:
Result: A -> C
```
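
The rewiring rule can be sketched on a plain dict of edges. The `remove_node` function below is a hypothetical stand-alone helper written for illustration, not the `ModelGraph` method.

```python
def remove_node(upstream_map, label):
    """Return a new map without `label`; its consumers inherit its upstream sources."""
    inherited = upstream_map[label]
    rewired = {}
    for node, ups in upstream_map.items():
        if node == label:
            continue
        new_ups = []
        for u in ups:
            # Replace a reference to the removed node with that node's own sources.
            new_ups.extend(inherited if u == label else [u])
        rewired[node] = new_ups
    return rewired


chain = {"A": ["FeatureSet"], "B": ["A"], "C": ["B"]}
print(remove_node(chain, "B"))  # {'A': ['FeatureSet'], 'C': ['A']}
```

Note that after such a change the downstream node receives a different input shape, which is why a rebuild is needed before the graph is used again.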

In [17]:
# Create a 3-node chain
n1 = ModelNode(
    label="N1",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
n2 = ModelNode(
    label="N2",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=8),
    upstream_ref=n1,
)
n3 = ModelNode(
    label="N3",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=n2,
)
mg_rem = ModelGraph(
    nodes=[n1, n2, n3],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_rem.node_labels}")

# Remove the middle node
mg_rem.remove_node("N2")
print(f"After:  {mg_rem.node_labels}")

mg_rem.visualize()

Before: {'N2', 'N1', 'N3'}
After:  {'N3', 'N1'}


```mermaid
flowchart LR
	n2 e0@--> n0
	n0 e1@-->|"(1, 8)"| n1
	n1 e2@--> n3

	n0@{ label: "<b>ModelNode</b><br>'N1'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'N3'", shape: rect }
	n2@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n3@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::FeatureSet
	n3:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
```

### 5.3 `replace_node()`

Replace an existing node with a new one, preserving all upstream and downstream connections.

In [18]:
# Create a simple chain
old_enc = ModelNode(
    label="OldEncoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
reg = ModelNode(
    label="Reg",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=old_enc,
)
mg_rep = ModelGraph(
    nodes=[old_enc, reg],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_rep.node_labels}")

# Replace with a deeper encoder
new_enc = ModelNode(
    label="NewEncoder",
    model=SequentialMLP(output_shape=(1, 8), n_layers=3, hidden_dim=64),
    upstream_ref=fs_ref,
)
mg_rep.replace_node(old_node="OldEncoder", new_node=new_enc)
print(f"After:  {mg_rep.node_labels}")

mg_rep.visualize()

Before: {'Reg', 'OldEncoder'}
After:  {'Reg', 'NewEncoder'}


```mermaid
flowchart LR
	n2 e0@--> n0
	n0 e1@-->|"(1, 8)"| n1
	n1 e2@--> n3

	n0@{ label: "<b>ModelNode</b><br>'NewEncoder'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'Reg'", shape: rect }
	n2@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n3@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::FeatureSet
	n3:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
```

### 5.4 `insert_node_between()`

Insert a new node between two already-connected nodes.

```
Given: A -> B
Insert C between A and B:
Result: A -> C -> B
```
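
The same dict-of-edges view shows what the splice does to each reference. This is a hypothetical helper for illustration only; in particular, raising `ValueError` for unconnected nodes is an assumption of the sketch, not documented library behavior.

```python
def insert_between(upstream_map, new, upstream, downstream):
    """Return a new map with `new` spliced into the upstream -> downstream edge."""
    if upstream not in upstream_map[downstream]:
        raise ValueError(f"{downstream!r} is not fed by {upstream!r}")
    rewired = {node: list(ups) for node, ups in upstream_map.items()}
    # The downstream node now reads from the new node instead of the old upstream.
    rewired[downstream] = [new if u == upstream else u for u in rewired[downstream]]
    # Any upstream reference the new node had before is discarded.
    rewired[new] = [upstream]
    return rewired


graph = {"A": ["FeatureSet"], "B": ["A"]}
print(insert_between(graph, new="C", upstream="A", downstream="B"))
# {'A': ['FeatureSet'], 'B': ['C'], 'C': ['A']}
```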

In [19]:
a = ModelNode(
    label="A",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
b = ModelNode(
    label="B",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=a,
)
mg_ins = ModelGraph(
    nodes=[a, b],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
print(f"Before: {mg_ins.node_labels}")

c = ModelNode(
    label="C",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,  # will be overwritten by insert
)
mg_ins.insert_node_between(new_node=c, upstream=a, downstream=b)
print(f"After:  {mg_ins.node_labels}")

# Verify connectivity
for node in mg_ins.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f"  {node.label} <- {ups}")
    
mg_ins.visualize()

Before: {'A', 'B'}
After:  {'A', 'C', 'B'}
  A <- ['SensorData']
  B <- ['C']
  C <- ['A']


```mermaid
flowchart LR
	n3 e0@--> n0
	n0 e1@-->|"(1, 8)"| n1
	n1 e2@-->|"(1, 4)"| n2
	n2 e3@--> n4

	n0@{ label: "<b>ModelNode</b><br>'A'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'C'", shape: rect }
	n2@{ label: "<b>ModelNode</b><br>'B'", shape: rect }
	n3@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n4@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::ModelNode
	n3:::FeatureSet
	n4:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
```

### 5.5 `insert_node_before()` and `insert_node_after()`

- `insert_node_before(new_node, downstream=...)`: Insert before an existing node, taking over all its upstream connections.
- `insert_node_after(new_node, upstream=...)`: Insert after an existing node as an additional downstream consumer.

In [20]:
p = ModelNode(
    label="P",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
q = ModelNode(
    label="Q",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=p,
)
mg_ib = ModelGraph(
    nodes=[p, q],
    optimizer=Optimizer(opt="adam", backend="torch"),
)

# Insert a node before Q (takes over Q's upstream connections)
pre_q = ModelNode(
    label="PreQ",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,  # placeholder; overwritten by insert_node_before
)
mg_ib.insert_node_before(new_node=pre_q, downstream=q)
print("After insert_node_before:")
for node in mg_ib.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f"  {node.label} <- {ups}")
    
mg_ib.visualize()

After insert_node_before:
  P <- ['SensorData']
  Q <- ['PreQ']
  PreQ <- ['P']


```mermaid
flowchart LR
	n3 e0@--> n0
	n0 e1@-->|"(1, 8)"| n1
	n1 e2@-->|"(1, 4)"| n2
	n2 e3@--> n4

	n0@{ label: "<b>ModelNode</b><br>'P'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'PreQ'", shape: rect }
	n2@{ label: "<b>ModelNode</b><br>'Q'", shape: rect }
	n3@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n4@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::ModelNode
	n3:::FeatureSet
	n4:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
```

In [21]:
# Insert a node after P (adds a new branch)
post_p = ModelNode(
    label="PostP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=fs_ref,  # placeholder; overwritten by insert_node_after
)
mg_ib.insert_node_after(new_node=post_p, upstream=p)
print("After insert_node_after:")
for node in mg_ib.nodes.values():
    ups = [r.node_label for r in node.get_upstream_refs()]
    print(f"  {node.label} <- {ups}")

print(f"\nTail nodes: {[n.label for n in mg_ib.tail_nodes.values()]}")

mg_ib.visualize()

After insert_node_after:
  P <- ['SensorData']
  Q <- ['PreQ']
  PreQ <- ['P']
  PostP <- ['P']

Tail nodes: ['Q', 'PostP']


```mermaid
flowchart LR
	n4 e0@--> n0
	n0 e1@-->|"(1, 8)"| n1
	n0 e2@-->|"(1, 8)"| n2
	n1 e3@-->|"(1, 4)"| n3
	n2 e4@--> n5
	n3 e5@--> n6

	n0@{ label: "<b>ModelNode</b><br>'P'", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'PreQ'", shape: rect }
	n2@{ label: "<b>ModelNode</b><br>'PostP'", shape: rect }
	n3@{ label: "<b>ModelNode</b><br>'Q'", shape: rect }
	n4@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n5@{ label: " ", shape: circle }
	n6@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::ModelNode
	n3:::ModelNode
	n4:::FeatureSet
	n5:::OutputTerminal
	n6:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
	class e4 NoAnimation
	class e5 NoAnimation
```

---

## 6. Freezing and Unfreezing

Freezing prevents a node's parameters from being updated during training. This is useful for transfer learning, multi-stage training, or keeping pretrained components fixed.

```python
    ModelGraph.freeze(nodes: list[str | GraphNode] | None = None)
    ModelGraph.unfreeze(nodes: list[str | GraphNode] | None = None)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nodes` | `list \| None` | `None` | Nodes to freeze/unfreeze (by label, ID, or instance). If `None`, applies to all trainable nodes. |

In [22]:
# Using the branching graph from Section 1.3
mg_branch.build(force=True)

# Freeze specific nodes
mg_branch.freeze(nodes=[enc_a])
print(f"Frozen nodes: {[n.label for n in mg_branch.frozen_nodes.values()]}")
mg_branch.visualize(show_frozen=True)

# Unfreeze
mg_branch.unfreeze(nodes=[enc_a])
print(f"Frozen nodes: {[n.label for n in mg_branch.frozen_nodes.values()]}")


Frozen nodes: ['EncoderA']


```mermaid
flowchart LR
	n4 e0@-->|"(1, 10)"| n0
	n4 e1@-->|"(1, 10)"| n1
	n1 e2@-->|"(1, 8)"| n2
	n0 e3@-->|"(1, 4)"| n2
	n2 e4@-->|"(1, 12)"| n3
	n3 e5@-->|"(1, 1)"| n5

	n0@{ label: "<b>ModelNode</b><br>'EncoderB'  &lt;torch&gt;", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'EncoderA'  &lt;torch&gt; · frozen", shape: rect }
	n2@{ label: "<b>ConcatNode</b><br>'Merge'", shape: rect }
	n3@{ label: "<b>ModelNode</b><br>'Head'  &lt;torch&gt;", shape: rect }
	n4@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n5@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNodeFrozen
	n2:::MergeNode
	n3:::ModelNode
	n4:::FeatureSet
	n5:::OutputTerminal
	classDef MergeNode stroke-width: 2px, stroke-dasharray: 0, stroke: #565656, fill: #B1B1B1, color:#000000;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef ModelNodeFrozen stroke-width: 2px, stroke-dasharray: 0, stroke: #90CAF9, fill: #E3F2FD, color:#666666;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
	class e4 NoAnimation
	class e5 NoAnimation
```

Frozen nodes: []


In [23]:
# Freeze all nodes at once
mg_branch.freeze()
print(f"All frozen: {[n.label for n in mg_branch.frozen_nodes.values()]}")

# Unfreeze all
mg_branch.unfreeze()
print(f"All unfrozen: {[n.label for n in mg_branch.frozen_nodes.values()]}")

All frozen: ['EncoderA', 'EncoderB', 'Head']
All unfrozen: []


### Frozen Nodes and the Optimizer

When using a global optimizer, the optimizer is automatically rebuilt to exclude frozen nodes' parameters before each training step. This means frozen nodes will not accumulate gradients and their weights remain unchanged.
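
Conceptually, rebuilding the optimizer amounts to collecting parameters only from nodes that are not frozen. The sketch below uses a hypothetical `ToyNode` class and string placeholders for parameters; it is not a ModularML class and does not touch a real optimizer.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical stand-in for a trainable node; not a ModularML class."""

    label: str
    params: list[str] = field(default_factory=list)
    frozen: bool = False


def trainable_params(nodes):
    """Collect parameters from every node that is not frozen."""
    return [p for node in nodes if not node.frozen for p in node.params]


nodes = [
    ToyNode("EncoderA", ["enc_a.weight", "enc_a.bias"]),
    ToyNode("EncoderB", ["enc_b.weight", "enc_b.bias"]),
    ToyNode("Head", ["head.weight", "head.bias"]),
]

nodes[0].frozen = True  # freeze EncoderA
print(trainable_params(nodes))
# ['enc_b.weight', 'enc_b.bias', 'head.weight', 'head.bias']
```

An optimizer constructed from this filtered list never sees `EncoderA`'s parameters, so a step cannot change them.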

---

## 7. Optimizer Management

The `ModelGraph` supports two training modes based on whether a global optimizer is provided:

### Global Optimizer (Graph-Wise Training)

When a global `Optimizer` is set on the `ModelGraph`:
- A single forward pass runs through the entire graph.
- All losses are accumulated.
- A single backward pass computes gradients across all unfrozen nodes.
- The global optimizer steps once.

This enables **end-to-end gradient flow** through the full graph, which is the most common training paradigm.
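
To make the four steps concrete, here is a minimal sketch with two scalar "nodes" and hand-written gradients. It assumes a toy chain (`h = w_enc * x`, `y = w_head * h`) with a squared-error loss and plain gradient descent. The function `graph_wise_step` is hypothetical and is not the `ModelGraph.train_step` implementation.

```python
def graph_wise_step(w_enc, w_head, x, target, lr=0.1):
    """One end-to-end update of a two-node chain (toy example)."""
    # 1) Single forward pass through the whole graph.
    h = w_enc * x
    y = w_head * h

    # 2) Accumulate losses (a single squared-error term here).
    loss = (y - target) ** 2

    # 3) Single backward pass: the chain rule carries the head's error
    #    signal back into the encoder.
    dloss_dy = 2.0 * (y - target)
    grad_head = dloss_dy * h
    grad_enc = dloss_dy * w_head * x

    # 4) The shared optimizer steps once, updating both nodes together.
    return w_enc - lr * grad_enc, w_head - lr * grad_head, loss


w_enc, w_head = 1.0, 1.0
w_enc, w_head, first_loss = graph_wise_step(w_enc, w_head, x=1.0, target=2.0)
_, _, second_loss = graph_wise_step(w_enc, w_head, x=1.0, target=2.0)
print(first_loss, second_loss)  # the second loss is smaller than the first
```

Note that the encoder's gradient depends on `w_head`: that dependency is what "end-to-end gradient flow" refers to.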

### No Global Optimizer (Stage-Wise Training)

When `optimizer=None` on the `ModelGraph`:
- Each `ModelNode` must have its own local `Optimizer`.
- Nodes are trained independently in topological order.
- Each node performs its own forward pass, loss computation, backward pass, and optimizer step.

This is useful when you need different optimizers per node, or when certain nodes should not share gradient flow.
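
As a rough illustration of what "topological order" means for the branching graph used in this notebook, the sketch below uses Python's standard-library `graphlib`. The dependency mapping is written by hand for illustration; it is not read from a `ModelGraph`, and ModularML may compute its ordering differently.

```python
from graphlib import TopologicalSorter

# Map each node to the set of nodes it depends on (its upstream inputs).
dependencies = {
    "EncoderA": set(),
    "EncoderB": set(),
    "Merge": {"EncoderA", "EncoderB"},
    "Head": {"Merge"},
}

# static_order() yields every node after all of its dependencies.
order = list(TopologicalSorter(dependencies).static_order())
print(order)
```

Both encoders always appear before `Merge`, and `Head` comes last; the relative order of the two encoders is not constrained by the graph.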

### Inspecting Optimizer Parameters

After at least one training step (or after calling `build()`), you can inspect which nodes contribute parameters to the global optimizer.

In [24]:
mg_branch.build(force=True)

opt_info = mg_branch.get_optimizer_parameters()
print(f"Backend: {opt_info['backend']}")
print(f"Contributing nodes: {len(opt_info['contributing_nodes'])}")
print(f"Total parameters: {len(opt_info['parameters'])}")

Backend: torch
Contributing nodes: 3
Total parameters: 6


### Backend Constraints

When using a global optimizer, all trainable nodes must share the same backend (e.g., all PyTorch). A `RuntimeError` is raised if backends conflict.

Mixed-backend graphs (e.g., PyTorch encoder + scikit-learn head) must use stage-wise training (no global optimizer).
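
The constraint can be sketched as a simple consistency check. The helper `resolve_global_backend` and the plain-string backend names below are assumptions made for illustration; only the use of `RuntimeError` is taken from the description above.

```python
def resolve_global_backend(node_backends):
    """Return the single backend shared by all trainable nodes.

    `node_backends` maps a node label to a backend name. Both the function
    and the mapping are hypothetical; they are not part of ModularML.
    """
    backends = set(node_backends.values())
    if len(backends) != 1:
        msg = f"A global optimizer needs one backend, got: {sorted(backends)}"
        raise RuntimeError(msg)
    return backends.pop()


print(resolve_global_backend({"EncoderA": "torch", "Head": "torch"}))  # torch
```

A mapping such as `{"Encoder": "torch", "Head": "scikit"}` would raise, which corresponds to the case where stage-wise training is required.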

---

## 8. Serialization

`ModelGraph` supports full serialization: saving and loading both the graph structure (config) and learned weights (state).

### Config Serialization

`get_config()` captures the graph structure (node configs, optimizer config) without learned weights. `from_config()` reconstructs the graph from a config dict.

In [25]:
config = mg_branch.get_config()
print(f"Config keys: {list(config.keys())}")
print(f"Number of node configs: {len(config['nodes'])}")
print(f"Optimizer config: {config['optimizer'] is not None}")

Config keys: ['label', 'nodes', 'optimizer']
Number of node configs: 4
Optimizer config: True


### State Serialization

`get_state()` captures the learned weights and optimizer state. `set_state()` restores them.

In [26]:
state = mg_branch.get_state()
print(f"State keys: {list(state.keys())}")
print(f"Number of node states: {len(state['nodes'])}")
print(f"Is built: {state['is_built']}")

State keys: ['nodes', 'optimizer', 'opt_built_from_node_ids', 'is_built']
Number of node states: 4
Is built: True
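
The split between config (structure) and state (learned values) follows a common pattern, sketched below on a toy class. `ToyLayer` is hypothetical; it borrows the method names used above but is not a ModularML class and omits optimizer state entirely.

```python
import copy


class ToyLayer:
    """Hypothetical component that separates structure from learned values."""

    def __init__(self, n_inputs):
        self.n_inputs = n_inputs
        self.weights = [0.0] * n_inputs

    def get_config(self):
        # Structure only: enough to rebuild an untrained copy.
        return {"n_inputs": self.n_inputs}

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def get_state(self):
        # Learned values only, copied so later training cannot mutate them.
        return {"weights": copy.deepcopy(self.weights)}

    def set_state(self, state):
        self.weights = copy.deepcopy(state["weights"])


trained = ToyLayer(n_inputs=3)
trained.weights = [0.5, -1.0, 2.0]

# Rebuild the structure first, then restore the learned values.
clone = ToyLayer.from_config(trained.get_config())
clone.set_state(trained.get_state())
print(clone.weights)  # [0.5, -1.0, 2.0]
```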


### Save and Load to Disk

Use `save()` and `load()` for persistent serialization. The file includes both config and state.

In [27]:
from pathlib import Path
from tempfile import TemporaryDirectory

SAVE_DIR = TemporaryDirectory()

# Save
save_path = mg_branch.save(Path(SAVE_DIR.name) / "my_graph", overwrite=True)
print(f"Saved to: {save_path}")

# Load
# Note that we need to allow overwriting because all reloaded node labels/IDs
# collide with those defined in this notebook
mg_loaded = ModelGraph.load(save_path, overwrite=True)
print(f"Loaded graph labels: {mg_loaded.node_labels}")

mg_loaded.visualize()

─────────────────────────────── INFO - Node ID Collision ───────────────────────────────
 The loaded ModelNode has an overlapping ID with existing ModelNode 'EncoderB'.
 'EncoderB' will be overwritten in the active ExperimentContext.
────────────────────────────────────────────────────────────────────────────────────────
─────────────────────────────── INFO - Node ID Collision ───────────────────────────────
 The loaded ModelNode has an overlapping ID with existing ModelNode 'EncoderA'.
 'EncoderA' will be overwritten in the active ExperimentContext.
────────────────────────────────────────────────────────────────────────────────────────
─────────────────────────────── INFO - Node ID Collision ───────────────────────────────
 The loaded ConcatNode has an overlapping ID with existing ConcatNode 'Merge'. 'Merge'
 will be overwritten in the active ExperimentContext.
────────────────────────────────────────────────────────────────────────────────────────
─────────────────────────────── INF

───────────────────────────── INFO - ModelGraph Collision ──────────────────────────────
 The existing ModelGraph 'model-graph' will be overwritten with the loaded ModelGraph.
────────────────────────────────────────────────────────────────────────────────────────


Saved to: /var/folders/21/fsx4ddjs3fg2wgpl7_ksh0k00000gn/T/tmp15ud7b0t/my_graph.mg.mml
Loaded graph labels: {'Merge', 'EncoderA', 'Head', 'EncoderB'}


```mermaid
flowchart LR
	n4 e0@-->|"(1, 10)"| n0
	n4 e1@-->|"(1, 10)"| n1
	n1 e2@-->|"(1, 8)"| n2
	n0 e3@-->|"(1, 4)"| n2
	n2 e4@--> n3
	n3 e5@-->|"(1, 1)"| n5

	n0@{ label: "<b>ModelNode</b><br>'EncoderB'  &lt;torch&gt;", shape: rect }
	n1@{ label: "<b>ModelNode</b><br>'EncoderA'  &lt;torch&gt;", shape: rect }
	n2@{ label: "<b>ConcatNode</b><br>'Merge'", shape: rect }
	n3@{ label: "<b>ModelNode</b><br>'Head'  &lt;torch&gt;", shape: rect }
	n4@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500", shape: rect }
	n5@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::ModelNode
	n2:::MergeNode
	n3:::ModelNode
	n4:::FeatureSet
	n5:::OutputTerminal
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef MergeNode stroke-width: 2px, stroke-dasharray: 0, stroke: #565656, fill: #B1B1B1, color:#000000;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
	class e3 NoAnimation
	class e4 NoAnimation
	class e5 NoAnimation
```

---

## 9. Checkpointing

Checkpointing allows you to save and restore the full state of a `ModelGraph` at a specific point during training. Unlike `save()` / `load()`, which create a new `ModelGraph` instance, checkpointing restores state into an existing graph.

```python
ModelGraph.save_checkpoint(
    filepath: Path,
    *,
    overwrite: bool = False,
    meta: dict[str, Any] | None = None,
) -> Path

ModelGraph.restore_checkpoint(filepath: Path) -> ModelGraph
```

| Parameter | Type | Description |
|-----------|------|-------------|
| `filepath` | `Path` | Location to save/load the checkpoint. |
| `overwrite` | `bool` | Whether to overwrite an existing file. |
| `meta` | `dict` | Optional metadata to attach to the checkpoint (must be pickle-able). |

In [28]:
# Save a checkpoint (includes model weights and optimizer state)
ckpt_path = mg_branch.save_checkpoint(
    Path(SAVE_DIR.name) / "checkpoint_epoch5",
    overwrite=True,
    meta={"epoch": 5, "val_loss": 0.032},
)
print(f"Checkpoint saved to: {ckpt_path}")

Checkpoint saved to: /var/folders/21/fsx4ddjs3fg2wgpl7_ksh0k00000gn/T/tmp15ud7b0t/checkpoint_epoch5.ckpt.mml


In [29]:
# Restore the checkpoint into the existing graph
mg_branch.restore_checkpoint(ckpt_path)
print(f"Restored. Built: {mg_branch.is_built}")

Restored. Built: True
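
The "restore into an existing object" idea can be sketched with the standard library alone. The helpers `save_toy_checkpoint` and `restore_toy_checkpoint` are hypothetical, and the pickle layout used here is an assumption for illustration; it does not describe the `.ckpt.mml` file format.

```python
import pickle
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional


def save_toy_checkpoint(path: Path, state: dict, meta: Optional[dict] = None) -> Path:
    """Write state plus optional metadata to one pickle file."""
    with path.open("wb") as f:
        pickle.dump({"state": state, "meta": meta or {}}, f)
    return path


def restore_toy_checkpoint(path: Path, target: dict) -> dict:
    """Load a checkpoint into an existing dict and return its metadata."""
    with path.open("rb") as f:
        payload = pickle.load(f)
    target.clear()
    target.update(payload["state"])
    return payload["meta"]


with TemporaryDirectory() as tmp:
    live_state = {"weights": [0.5, -1.0]}
    ckpt = save_toy_checkpoint(Path(tmp) / "toy.ckpt", live_state, meta={"epoch": 5})

    live_state["weights"] = [9.9, 9.9]  # training continues and drifts
    meta = restore_toy_checkpoint(ckpt, live_state)  # roll back in place

print(live_state, meta)  # {'weights': [0.5, -1.0]} {'epoch': 5}
```

The same `live_state` object is reused throughout, which mirrors the difference from `load()`: nothing new is constructed, only the contents are rolled back.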


---

## 10. Summary

### Constructor

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nodes` | `list[str \| GraphNode] \| None` | (required) | Nodes comprising the graph. |
| `optimizer` | `Optimizer \| None` | `None` | Shared optimizer for graph-wise training. |
| `label` | `str` | `"model-graph"` | Human-readable label. |

### Properties

| Property | Type | Description |
|----------|------|-------------|
| `.nodes` | `dict[str, GraphNode]` | All nodes keyed by `node_id`. |
| `.node_labels` | `set[str]` | Unique node labels. |
| `.head_nodes` | `dict[str, GraphNode]` | Nodes receiving FeatureSet input. |
| `.tail_nodes` | `dict[str, GraphNode]` | Nodes with no downstream consumers. |
| `.is_built` | `bool` | Whether `build()` has been called. |
| `.backend` | `Backend \| None` | Backend of the global optimizer, or `None`. |
| `.frozen_nodes` | `dict[str, GraphNode]` | Currently frozen trainable nodes. |

### Methods

| Method | Description |
|--------|-------------|
| `build(force=False)` | Build all nodes and the global optimizer. |
| `forward(inputs, active_nodes=None)` | Execute a forward pass through the graph. |
| `train_step(ctx, losses, active_nodes=None)` | Execute a single training step (graph-wise or stage-wise). |
| `eval_step(ctx, losses, active_nodes=None)` | Execute a forward-only evaluation step (no gradients). |
| `fit_step(ctx, losses=None, active_nodes=None)` | Fit batch-fit nodes (e.g., scikit-learn) in topological order. |
| `freeze(nodes=None)` | Freeze nodes to prevent training. |
| `unfreeze(nodes=None)` | Unfreeze nodes to allow training. |
| `add_node(node)` | Add a node to the graph. |
| `remove_node(node)` | Remove a node, reconnecting neighbors. |
| `replace_node(old_node, new_node)` | Replace a node, preserving connections. |
| `insert_node_between(new_node, upstream, downstream)` | Insert between two connected nodes. |
| `insert_node_before(new_node, downstream)` | Insert before an existing node. |
| `insert_node_after(new_node, upstream)` | Insert after an existing node. |
| `get_config()` / `from_config()` | Config serialization (structure only). |
| `get_state()` / `set_state()` | State serialization (includes weights). |
| `save(filepath)` / `load(filepath)` | Full serialization to/from disk. |
| `save_checkpoint(filepath, meta=None)` | Save a training checkpoint. |
| `restore_checkpoint(filepath)` | Restore state from a checkpoint. |

### Training Modes

| Mode | When | Behavior |
|------|------|----------|
| **Graph-wise** | Global `Optimizer` provided | Single forward + backward pass across all nodes. End-to-end gradient flow. |
| **Stage-wise** | No global optimizer (`None`) | Each node trains independently with its own optimizer. |

### Next Steps

- **Experiment:** Use `Experiment` to combine a `ModelGraph` with training phases,
  loss functions, and evaluation. It is the primary user-facing entry point.

- **ModelNode:** See how individual nodes wrap models and handle forward passes.

- **MergeNode:** Learn how to combine parallel branches with `ConcatNode`.
