# How to: Create and Use a MergeNode

A `MergeNode` is a computational node that combines outputs from multiple upstream nodes into a single output. It is the counterpart to `ModelNode` in a `ModelGraph`: while a `ModelNode` accepts exactly one input, a `MergeNode` accepts two or more.

Currently, ModularML provides one concrete implementation:

- **`ConcatNode`** — Concatenates inputs along a specified axis, with optional padding for mismatched dimensions.

```
ComputeNode (abstract)
├── ModelNode       # Single-input, wraps a model
└── MergeNode       # Multi-input, merges upstream outputs (abstract)
    └── ConcatNode  # Concatenates along an axis
```

This notebook covers:

- {ref}`04-create-mergenode-when-to-use-a-mergenode`
- {ref}`04-create-mergenode-creating-a-concatnode`
- {ref}`04-create-mergenode-feature-axis-behavior`
- {ref}`04-create-mergenode-per-domain-axes`
- {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`
- {ref}`04-create-mergenode-padding-mismatched-dimensions`
- {ref}`04-create-mergenode-building-a-graph-with-mergenodes`
- {ref}`04-create-mergenode-forward-pass`
- {ref}`04-create-mergenode-key-properties-and-methods`
    

In [None]:
import numpy as np
import torch

from modularml import (
    ConcatNode,
    Experiment,
    FeatureSet,
    ModelGraph,
    ModelNode,
    Optimizer,
)
from modularml.core.topology.merge_nodes.merge_strategy import MergeStrategy
from modularml.models.torch import SequentialMLP

# Note that we don't need to explicitly create an Experiment right away.
# We do it here so we can disable the warning raised when creating multiple
# nodes with the same name (`registration_policy` is what controls this).
exp = Experiment(label="create_mergenode", registration_policy="overwrite")

We'll use a simple synthetic dataset: 200 samples of a 10-point feature with a scalar target.

In [None]:
rng = np.random.default_rng(42)

fs = FeatureSet.from_dict(
    label="Data A",
    data={
        "X": list(rng.standard_normal((200, 10))),
        "Y": list(rng.standard_normal((200, 1))),
    },
    feature_keys="X",
    target_keys="Y",
)

fs_ref = fs.reference(features="X", targets="Y")
print(fs)

---

(04-create-mergenode-when-to-use-a-mergenode)=
## When to Use a MergeNode


A `MergeNode` is needed when your model graph has **multiple parallel branches** that must be combined before continuing to a downstream node. Common patterns include:

- **Multi-encoder fusion:** Several encoders process the same (or different) inputs, and their representations are concatenated before a final regressor.
- **Feature augmentation:** A raw feature path is concatenated with a learned embedding path.
- **Ensemble merging:** Outputs from several models are merged (by concatenation, averaging, etc.) for downstream processing.

```
FeatureSet ─┬─> EncoderA ──┐
            │              ├─> ConcatNode ──> Regressor
            └─> EncoderB ──┘
```

---

(04-create-mergenode-creating-a-concatnode)=
## Creating a ConcatNode


`ConcatNode` concatenates multiple inputs along a specified axis.

```python
ConcatNode(
    label: str,
    upstream_refs: list[ExperimentNode | ExperimentNodeReference],
    concat_axis: int = 0,
    *,
    concat_axis_targets: int | str | MergeStrategy | ExperimentNodeReference = -1,
    concat_axis_tags: int | str | MergeStrategy | ExperimentNodeReference = -1,
    pad_inputs: bool = False,
    pad_mode: str = "constant",
    pad_value: float = 0.0,
)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `label` | `str` | (required) | Unique name for this node. |
| `upstream_refs` | `list` | (required) | List of upstream nodes or references to merge. |
| `concat_axis` | `int` | `0` | Axis along which to concatenate **features** (see {ref}`04-create-mergenode-feature-axis-behavior`). |
| `concat_axis_targets` | `int \| str \| MergeStrategy \| ExperimentNodeReference` | `-1` | Strategy for merging **targets** (see {ref}`04-create-mergenode-per-domain-axes` and {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`). |
| `concat_axis_tags` | `int \| str \| MergeStrategy \| ExperimentNodeReference` | `-1` | Strategy for merging **tags** (same semantics as `concat_axis_targets`). |
| `pad_inputs` | `bool` | `False` | Whether to pad inputs to align non-concat dimensions. |
| `pad_mode` | `str` | `"constant"` | Padding mode: `"constant"`, `"reflect"`, `"replicate"`, or `"circular"`. |
| `pad_value` | `float` | `0.0` | Fill value when `pad_mode="constant"`. |

The `concat_axis` parameter controls how **features** are merged and is the primary axis used for shape inference during `ModelGraph.build()`. Targets, tags, and sample UUIDs each have their own merge behavior (see {ref}`04-create-mergenode-per-domain-axes` and {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`).

This tutorial uses the `ModelGraph` class to build graphs of connected `ModelNode`s and `ConcatNode`s.

Details on the `ModelGraph` class are provided in {doc}`03_create_modelgraph`.

In [None]:
def create_model_graph(
    output_shape_a: tuple[int, ...],
    output_shape_b: tuple[int, ...],
    concat_axis: int,
):
    """
    Build a two-encoder graph to demonstrate different feature concatenation axes.

    Args:
        output_shape_a (tuple[int, ...]):
            Output shape of encoder A (excluding batch dimension).
        output_shape_b (tuple[int, ...]):
            Output shape of encoder B (excluding batch dimension).
        concat_axis (int):
            The feature concatenation axis.

    """
    enc_a = ModelNode(
        label="EncoderA",
        model=SequentialMLP(output_shape=output_shape_a, n_layers=1, hidden_dim=16),
        upstream_ref=fs_ref,
    )
    enc_b = ModelNode(
        label="EncoderB",
        model=SequentialMLP(output_shape=output_shape_b, n_layers=1, hidden_dim=16),
        upstream_ref=fs_ref,
    )
    merge = ConcatNode(
        label="Merge",
        upstream_refs=[enc_a, enc_b],
        concat_axis=concat_axis,
        pad_inputs=True,
    )

    reg = ModelNode(
        label="Regressor",
        model=SequentialMLP(n_layers=1, hidden_dim=8),
        upstream_ref=merge,
    )

    mg = ModelGraph(
        nodes=[enc_a, enc_b, merge, reg],
        optimizer=Optimizer(opt="adam", backend="torch"),
    )
    mg.build()

    print(merge)
    for k, inp_shape in merge.input_shapes.items():
        print(f" - Data from {k.resolve()}: {inp_shape}")
    print(f" - Merged output shape: {merge.output_shape}")

    return mg


mg = create_model_graph(output_shape_a=(1, 10), output_shape_b=(1, 5), concat_axis=0)
mg.visualize()

---

(04-create-mergenode-feature-axis-behavior)=
## Feature Axis Behavior


The `concat_axis` parameter controls which dimension the **feature** inputs are concatenated along.
All axis values are relative to the **data shape excluding the batch dimension**.

For example, with upstream output shapes of `(1, 8)` (excluding batch), a training batch of size 32 produces tensors of shape `(32, 1, 8)`. Here, `concat_axis=0` refers to the `1` dimension and `concat_axis=1` refers to the `8` dimension.

| `concat_axis` | Behavior | Example: `(1, 8)` + `(1, 8)` |
|---------------|----------|-------------------------------|
| `0` | Concat along first data dim | `(2, 8)` |
| `1` | Concat along second data dim | `(1, 16)` |
| `-1` | Concat along last data dim | `(1, 16)` — same as `axis=1` here |

When non-concat dimensions don't match, the node will raise a `ValueError` unless `pad_inputs=True` (see {ref}`04-create-mergenode-padding-mismatched-dimensions`).
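The shape rule in the table above can be sketched in plain Python. The helper `concat_output_shape` is hypothetical (it is not part of ModularML) and only mirrors the rule as described here: shapes exclude the batch dimension, sizes on the concat axis add up, and every other axis must match when padding is disabled.

```python
def concat_output_shape(shapes, axis):
    """Infer the merged shape (batch dim excluded) for a plain concatenation."""
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise ValueError(f"All inputs must have {ndim} dims, got {shapes}")
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim}-D inputs")
    axis = axis % ndim  # normalize negative axes, e.g. -1 -> ndim - 1

    out = []
    for dim in range(ndim):
        sizes = [s[dim] for s in shapes]
        if dim == axis:
            out.append(sum(sizes))  # concat axis: sizes add up
        elif len(set(sizes)) == 1:
            out.append(sizes[0])  # non-concat axis: sizes must agree
        else:
            raise ValueError(f"Mismatch on non-concat axis {dim}: {sizes}")
    return tuple(out)


print(concat_output_shape([(1, 8), (1, 8)], axis=0))    # (2, 8)
print(concat_output_shape([(1, 8), (1, 8)], axis=1))    # (1, 16)
print(concat_output_shape([(1, 8), (1, 16)], axis=-1))  # (1, 24)
```

Calling the helper with `(2, 8)` and `(3, 6)` on `axis=0` raises a `ValueError`, which corresponds to the unpadded case described above.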

In [None]:
# concat_axis=0: stack along first data dim
# (1, 8) + (1, 8) -> (2, 8)
mg = create_model_graph((1, 8), (1, 8), concat_axis=0)
mg.visualize()

In [None]:
# concat_axis=1: concat along second data dim
# (1, 8) + (1, 8) -> (1, 16)
mg = create_model_graph((1, 8), (1, 8), concat_axis=1)
mg.visualize()

In [None]:
# concat_axis=-1: concat along last dim (useful when ndim may vary)
# (1, 8) + (1, 16) -> (1, 24)
mg = create_model_graph((1, 8), (1, 16), concat_axis=-1)
mg.visualize()

---

(04-create-mergenode-per-domain-axes)=
## Per-Domain Axes


When a `ConcatNode` merges data from upstream nodes, it processes each domain of the `SampleData` independently:

| Domain | Parameter | Default | Description |
|--------|-----------|---------|-------------|
| **Features** | `concat_axis` | `0` | Primary axis, also used for shape inference. Always int-based. |
| **Targets** | `concat_axis_targets` | `-1` | Concatenation axis or aggregation strategy (see {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`). |
| **Tags** | `concat_axis_tags` | `-1` | Concatenation axis or aggregation strategy. |
| **Sample UUIDs** | *(fixed)* | `-1` | Always concatenated along the last axis. Not configurable. |

By default, all domains use int-based concatenation. An `int` value specifies the axis along which to concatenate, with the same semantics as the feature `concat_axis`. For 1-D arrays (the most common case for targets, tags, and sample UUIDs), `-1` is equivalent to `axis=0`.

To use a non-concatenation strategy for targets or tags, see {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`.
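Because axes are given relative to the per-sample shape, they have to be shifted by one before they can index a batched array. The NumPy sketch below illustrates that mapping with a hypothetical helper, `to_batched_axis`. It is an assumption about how such a translation could work, not ModularML's internal code.

```python
import numpy as np


def to_batched_axis(data_axis, data_ndim):
    """Map an axis given without the batch dim to the axis of a batched array."""
    if not -data_ndim <= data_axis < data_ndim:
        raise ValueError(f"axis {data_axis} out of range for {data_ndim}-D data")
    # Normalize negatives first, then skip over the leading batch axis.
    return (data_axis % data_ndim) + 1


# Two upstream target arrays: 200 samples, each sample a 1-D array of length 1.
targets_a = np.zeros((200, 1))
targets_b = np.ones((200, 1))

axis_last = to_batched_axis(-1, data_ndim=1)
axis_zero = to_batched_axis(0, data_ndim=1)
merged = np.concatenate([targets_a, targets_b], axis=axis_last)

print(axis_last, axis_zero)  # 1 1
print(merged.shape)          # (200, 2)
```

For 1-D per-sample data both `-1` and `0` resolve to the same batched axis, which is why the two are interchangeable for typical targets and tags.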

In [None]:
# Example: concat features along axis 0, targets along last axis (default)
enc_a = ModelNode(
    label="EncoderA",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
enc_b = ModelNode(
    label="EncoderB",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
merge = ConcatNode(
    label="Merge",
    upstream_refs=[enc_a, enc_b],
    concat_axis=0,  # features: (1,8) + (1,8) -> (2,8)
    concat_axis_targets=-1,  # targets: concat along last axis (default)
    concat_axis_tags=-1,  # tags: concat along last axis (default)
)
print(f"Feature axis:       {merge.concat_axis}")
print(f"Target strategy:    {merge.target_strategy}")
print(f"Tags strategy:      {merge.tags_strategy}")

---

(04-create-mergenode-target-and-tag-aggregation-strategies)=
## Target and Tag Aggregation Strategies


When concatenating features from multiple upstream nodes, the default behavior is to also concatenate the associated targets and tags. This is often undesirable: for example, if both encoders receive the same FeatureSet targets, concatenation duplicates them and doubles the size of the target axis.

The `concat_axis_targets` and `concat_axis_tags` parameters accept several types to control how these domains are merged:

| Value | Type | Behavior |
|-------|------|----------|
| `-1` (default) | `int` | Concatenate along last axis (original behavior). |
| Any `int` | `int` | Concatenate along the specified axis. |
| `"first"` | `str` or `MergeStrategy.FIRST` | Use targets/tags from the **first** upstream input only. |
| `"last"` | `str` or `MergeStrategy.LAST` | Use targets/tags from the **last** upstream input only. |
| `"mean"` | `str` or `MergeStrategy.MEAN` | Element-wise mean across all inputs (shapes must match). |
| `enc_a` | `ExperimentNode` or `ExperimentNodeReference` | Use targets/tags from a **specific** upstream input. |

When a non-concat strategy is used, any upstream inputs with `None` data for that domain are silently filtered out.

Strings are automatically converted to `MergeStrategy` enum values, so `"first"` and `MergeStrategy.FIRST` are equivalent.
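The semantics in the table can be illustrated with a small NumPy sketch. `merge_domain` is a hypothetical stand-in written for this example, not the library's implementation, and it omits the select-by-reference case because that depends on ModularML's node references.

```python
import numpy as np


def merge_domain(values, strategy=-1):
    """Merge per-input arrays of shape (batch, ...) using one strategy.

    strategy: int (concat axis, batch dim excluded), "first", "last", or "mean".
    """
    if isinstance(strategy, int):
        # Concatenation needs every input to carry data for this domain.
        data_ndim = values[0].ndim - 1
        axis = (strategy % data_ndim) + 1  # skip the batch axis
        return np.concatenate(values, axis=axis)

    # Non-concat strategies ignore inputs that carry no data for this domain.
    present = [v for v in values if v is not None]
    if strategy == "first":
        return present[0]
    if strategy == "last":
        return present[-1]
    if strategy == "mean":
        return np.mean(np.stack(present), axis=0)  # shapes must match
    raise ValueError(f"Unknown strategy: {strategy!r}")


y = np.arange(6, dtype=float).reshape(3, 2)
print(merge_domain([y, y]).shape)                             # (3, 4)
print(merge_domain([y, y], "first").shape)                    # (3, 2)
print(np.array_equal(merge_domain([y, None, y], "mean"), y))  # True
```

With identical targets on every input, `"first"`, `"last"`, and `"mean"` all return the original values, while the default concatenation widens the last axis.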

In [None]:
# Strategy: "first" - use targets from the first upstream input only (enc_a)
merge_first = ConcatNode(
    label="MergeFirst",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets="first",
)
print(f"target_strategy: {merge_first.target_strategy}")

# Strategy: MergeStrategy enum (equivalent to string)
merge_mean = ConcatNode(
    label="MergeMean",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets=MergeStrategy.MEAN,
)
print(f"target_strategy: {merge_mean.target_strategy}")

# Strategy: select by reference — use targets from a specific upstream node
merge_ref = ConcatNode(
    label="MergeRef",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
    concat_axis_targets=enc_a,  # use EncoderA's targets
)
print(f"target_strategy: {merge_ref.target_strategy.node_label}")

### Comparing Strategies on a Forward Pass

Let's run the same data through merge nodes with different target strategies to see how the output targets differ.

In [None]:
from modularml.core.data.sample_data import SampleData
from modularml.utils.data.data_format import DataFormat

# First, build enc_a and enc_b by building a graph around one of the merge nodes defined above (merge_first)
reg_demo = ModelNode(
    label="Reg_demo",
    model=SequentialMLP(n_layers=1, hidden_dim=8),
    upstream_ref=merge_first,
)
mg = ModelGraph(
    nodes=[enc_a, enc_b, merge_first, reg_demo],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()


# Now build the remaining merge nodes manually (enc_a and enc_b are already built)
input_shapes = {
    enc_a.reference(): enc_a.output_shape,
    enc_b.reference(): enc_b.output_shape,
}
for m in [merge_mean, merge_ref]:
    m.build(input_shapes=input_shapes, includes_batch_dim=False, backend="torch")

# Also build a default-concat merge for comparison
merge_concat = ConcatNode(
    label="MergeConcat",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,
)
merge_concat.build(input_shapes=input_shapes, includes_batch_dim=False, backend="torch")

# Create sample data
fsv = fs_ref.resolve()
sample_data = SampleData(
    features=fsv.get_features(fmt=DataFormat.TORCH),
    targets=fsv.get_targets(fmt=DataFormat.TORCH),
)

with torch.no_grad():
    out_a = enc_a(sample_data)
    out_b = enc_b(sample_data)
    merge_inputs = {enc_a.reference(): out_a, enc_b.reference(): out_b}

    out_concat = merge_concat.forward(merge_inputs)
    out_first = merge_first.forward(merge_inputs)
    out_mean = merge_mean.forward(merge_inputs)
    out_ref = merge_ref.forward(merge_inputs)

print(f"Input targets shape:              {sample_data.targets.shape}")
print(f"concat (default) targets shape:   {out_concat.targets.shape}")
print(f"'first' strategy targets shape:   {out_first.targets.shape}")
print(f"'mean' strategy targets shape:    {out_mean.targets.shape}")
print(f"select-by-ref targets shape:      {out_ref.targets.shape}")
print()
print(
    f"Targets match (first == ref):     {torch.equal(out_first.targets, out_ref.targets)}",
)
print(
    f"Targets match (first == mean):    {torch.equal(out_first.targets, out_mean.targets)}",
)

In this example both encoders receive the same FeatureSet targets, so:

- **concat (default):** Targets are doubled — `(200, 1)` + `(200, 1)` → `(200, 2)`.
- **"first":** Only the first input's targets are kept — shape stays `(200, 1)`.
- **"mean":** Element-wise average of identical targets — shape stays `(200, 1)`, values unchanged.
- **select-by-ref (`enc_a`):** Identical to "first" here since `enc_a` is the first input.

The "first"/"last" and select-by-reference strategies are most useful when upstream nodes have different targets, or when you want to pass through a specific node's targets unchanged.

---

(04-create-mergenode-padding-mismatched-dimensions)=
## Padding Mismatched Dimensions


When inputs have different shapes in non-concat dimensions, `ConcatNode` can automatically pad the shorter tensors to match the longest one.

Consider two encoders with outputs `(2, 8)` and `(3, 6)`, concatenated along axis 0 (first data dim):
- **Concat dim:** `2 + 3 = 5`
- **Non-concat dim:** `max(8, 6) = 8` (shorter tensor is padded)
- **Result:** `(5, 8)`
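The arithmetic above can be reproduced with NumPy. This is a minimal sketch that assumes padding is appended at the end of each non-concat axis with a constant fill; `pad_and_concat` is a hypothetical helper, not a ModularML function.

```python
import numpy as np


def pad_and_concat(arrays, axis, pad_value=0.0):
    """Pad non-concat dims to the per-axis maximum, then concatenate.

    Arrays are batched, shape (batch, ...); `axis` excludes the batch dim.
    """
    data_ndim = arrays[0].ndim - 1
    cat_axis = (axis % data_ndim) + 1  # position in the batched array
    target = np.max([a.shape for a in arrays], axis=0)

    padded = []
    for a in arrays:
        # Pad only at the end of each axis; leave batch and concat axes alone.
        widths = [
            (0, 0) if d in (0, cat_axis) else (0, int(target[d]) - a.shape[d])
            for d in range(a.ndim)
        ]
        padded.append(np.pad(a, widths, mode="constant", constant_values=pad_value))
    return np.concatenate(padded, axis=cat_axis)


wide = np.ones((4, 2, 8))  # batch of 4, per-sample shape (2, 8)
tall = np.ones((4, 3, 6))  # batch of 4, per-sample shape (3, 6)
merged = pad_and_concat([wide, tall], axis=0)

print(merged.shape)                    # (4, 5, 8)
print(np.unique(merged[:, 2:5, 6:8]))  # [0.]
```

The per-sample result is `(5, 8)`, and the last two columns of the rows contributed by the `(3, 6)` input hold only the fill value.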

In [None]:
# Two encoders with different output shapes in BOTH dimensions
enc_wide = ModelNode(
    label="WideEncoder",
    model=SequentialMLP(output_shape=(2, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

enc_tall = ModelNode(
    label="TallEncoder",
    model=SequentialMLP(output_shape=(3, 6), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

# Concat on axis 0 with padding enabled
# dim 0: concatenated (2+3=5), dim 1: padded to max(8,6)=8
merge_padded = ConcatNode(
    label="PaddedMerge",
    upstream_refs=[enc_wide, enc_tall],
    concat_axis=0,
    concat_axis_targets="first",  # avoid target concatenation doubling
    pad_inputs=True,
    pad_mode="constant",
    pad_value=0.0,
)

reg = ModelNode(
    label="Regressor",
    model=SequentialMLP(output_shape=(1, 1), n_layers=1, hidden_dim=8),
    upstream_ref=merge_padded,
)

mg = ModelGraph(
    nodes=[enc_wide, enc_tall, merge_padded, reg],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()
mg.visualize()


print(merge_padded)
for k, inp_shape in merge_padded.input_shapes.items():
    print(f" - Data from {k.resolve()}: {inp_shape}")
print(f" - Merged output shape: {merge_padded.output_shape}")

### Without Padding

If `pad_inputs=False` (the default) and non-concat dimensions don't match, a `ValueError` is raised at build time with a helpful message.

In [None]:
merge_no_pad = ConcatNode(
    label="NoPadMerge",
    upstream_refs=[enc_wide, enc_tall],
    concat_axis=0,
    pad_inputs=False,
)

try:
    merge_no_pad.build(
        input_shapes={
            enc_wide.reference(): enc_wide.output_shape,
            enc_tall.reference(): enc_tall.output_shape,
        },
        includes_batch_dim=False,
    )
except ValueError as e:
    print(f"ValueError: {e}")

---

(04-create-mergenode-building-a-graph-with-mergenodes)=
## Building a Graph with MergeNodes


In practice, you don't need to build `MergeNode`s manually. `ModelGraph.build()` handles
shape inference and build order for all nodes, including merge nodes.

We already saw this in the `create_model_graph` helper above. Here's the full pattern with a non-default target strategy:

In [None]:
enc_a = ModelNode(
    label="EncoderA",
    model=SequentialMLP(output_shape=(1, 8), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)
enc_b = ModelNode(
    label="EncoderB",
    model=SequentialMLP(output_shape=(1, 4), n_layers=1, hidden_dim=16),
    upstream_ref=fs_ref,
)

merge = ConcatNode(
    label="Merge",
    upstream_refs=[enc_a, enc_b],
    concat_axis=-1,  # concat features along last axis
    concat_axis_targets=enc_a,  # use only enc_a's targets
)

regressor = ModelNode(
    label="Regressor",
    model=SequentialMLP(n_layers=1, hidden_dim=8),
    upstream_ref=merge,
)

mg = ModelGraph(
    nodes=[enc_a, enc_b, merge, regressor],
    optimizer=Optimizer(opt="adam", backend="torch"),
)
mg.build()
mg.visualize()

print("Graph built successfully!")
for node in mg.nodes.values():
    in_shapes = None
    out_shape = None
    if hasattr(node, "input_shape"):
        in_shapes = node.input_shape
    elif hasattr(node, "input_shapes"):
        in_shapes = node.input_shapes
    if hasattr(node, "output_shape"):
        out_shape = node.output_shape

    print(f"  {node.label}: {in_shapes} -> {out_shape}")

Here `concat_axis_targets=enc_a` tells the merge node to use **EncoderA's targets** as the output targets instead of concatenating targets from both inputs. This is passed as an `ExperimentNode` instance (which is automatically converted to an `ExperimentNodeReference`).

The graph correctly infers:
- **EncoderA:** input `(1, 10)` → output `(1, 8)`
- **EncoderB:** input `(1, 10)` → output `(1, 4)`
- **Merge:** `(1, 8)` + `(1, 4)` along last axis → `(1, 12)` features, targets selected from EncoderA
- **Regressor:** input `(1, 12)` → output `(1, 1)`

---

(04-create-mergenode-forward-pass)=
## Forward Pass


Forward passes through a `MergeNode` work the same as through a `ModelNode`. The merge node
accepts `SampleData`, `RoleData`, or `Batch` and returns the same type.

When running through a `ModelGraph`, this is all handled automatically. Below we trace
a manual forward pass to show how data flows through each node, using the `concat_axis_targets=enc_a` merge node from {ref}`04-create-mergenode-building-a-graph-with-mergenodes`.

In [None]:
# Create SampleData from the FeatureSet reference (SampleData and DataFormat were imported above)
fsv = fs_ref.resolve()
sample_data = SampleData(
    features=fsv.get_features(fmt=DataFormat.TORCH),
    targets=fsv.get_targets(fmt=DataFormat.TORCH),
)
print(f"Input features shape: {sample_data.features.shape}")
print(f"Input targets shape:  {sample_data.targets.shape}")

In [None]:
# Trace through each node manually
with torch.no_grad():
    out_a = enc_a(sample_data)
    out_b = enc_b(sample_data)
    print(f"EncoderA features: {out_a.features.shape}")
    print(f"EncoderA targets:  {out_a.targets.shape}")
    print(f"EncoderB features: {out_b.features.shape}")
    print(f"EncoderB targets:  {out_b.targets.shape}")

    # Merge expects a dict of {reference: data}
    merge_inputs = {
        enc_a.reference(): out_a,
        enc_b.reference(): out_b,
    }
    out_merge = merge.forward(merge_inputs)
    print(f"\nMerge features:    {out_merge.features.shape}")
    print(f"Merge targets:     {out_merge.targets.shape}  (selected from EncoderA)")

    out_final = regressor(out_merge)
    print(f"Regressor output:  {out_final.features.shape}")

Notice that features are concatenated along the last axis (`concat_axis=-1`): per-sample shapes `(1, 8)` + `(1, 4)` → `(1, 12)`. Because `concat_axis_targets=enc_a`, the merged targets keep the shape of the original FeatureSet targets, `(200, 1)` including the batch dimension; they are not concatenated.

Compare this with the default behavior (shown in {ref}`04-create-mergenode-target-and-tag-aggregation-strategies`), where targets would be `(200, 2)` due to concatenation.

### Verifying Padded Forward Pass

Let's verify that the padded merge node (from {ref}`04-create-mergenode-padding-mismatched-dimensions`) produces the expected shapes and that padded regions are filled with zeros.

In [None]:
print(f"PaddedMerge output_shape: {merge_padded.output_shape}")

# Forward pass
with torch.no_grad():
    out_wide = enc_wide(sample_data)
    out_tall = enc_tall(sample_data)
    print(f"WideEncoder output: {out_wide.features.shape}")
    print(f"TallEncoder output: {out_tall.features.shape}")

    padded_inputs = {
        enc_wide.reference(): out_wide,
        enc_tall.reference(): out_tall,
    }
    out_padded = merge_padded.forward(padded_inputs)
    print(f"Padded merge output: {out_padded.features.shape}")

    # Verify padding: TallEncoder (3, 6) is padded to (3, 8)
    # After concat on axis 0: rows 0:2 from WideEncoder, rows 2:5 from TallEncoder
    # Columns 6:8 of TallEncoder's contribution should be zero
    padded_region = out_padded.features[:, 2:5, 6:8].numpy()
    print(f"Padded region values (should be all zeros): {np.unique(padded_region)}")

---

(04-create-mergenode-key-properties-and-methods)=
## Key Properties and Methods


### MergeNode (base class)

| Property / Method | Description |
|-------------------|-------------|
| `.is_built` | Whether shape inference has been completed. |
| `.output_shape` | Output shape (no batch dim) after merging. |
| `.input_shapes` | Dict mapping each upstream reference to its input shape. |
| `.backend` | Backend enum, or `None` if not set. |
| `merge(x)` | Forward pass on a list of `SampleData`, `RoleData`, or `Batch`. |
| `forward(inputs)` | Forward pass from a dict of `{reference: data}`. |
| `apply_merge(values, domain=...)` | Abstract method that subclasses implement. Receives a `domain` string to allow per-domain merge logic. |

### ConcatNode

| Property / Method | Description |
|-------------------|-------------|
| `.concat_axis` | The axis along which **features** are concatenated (`int`). |
| `.target_strategy` | Strategy for merging targets: `int` (concat axis), `MergeStrategy`, or `ExperimentNodeReference`. |
| `.tags_strategy` | Strategy for merging tags (same types as `target_strategy`). |
| `.target_axis` | Convenience property — returns the int axis when `target_strategy` is `int`. Raises `TypeError` otherwise. |
| `.tags_axis` | Convenience property — returns the int axis when `tags_strategy` is `int`. Raises `TypeError` otherwise. |
| `.pad_inputs` | Whether padding is enabled. |
| `.pad_mode` | Padding mode (`"constant"`, `"reflect"`, etc.). |
| `.pad_value` | Fill value for constant padding. |

### MergeStrategy Enum

| Value | Description |
|-------|-------------|
| `MergeStrategy.CONCAT` | Concatenate along an axis (requires an int axis). |
| `MergeStrategy.FIRST` | Use data from the first upstream input. |
| `MergeStrategy.LAST` | Use data from the last upstream input. |
| `MergeStrategy.MEAN` | Element-wise mean across inputs (shapes must match). |
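The equivalence between strings and enum members can be sketched with the standard-library `enum` module. `Strategy` below is a stand-in for `MergeStrategy`, and its member values are assumptions made for this example. The `normalize_strategy` helper is also hypothetical.

```python
from enum import Enum


class Strategy(Enum):
    """Stand-in for MergeStrategy; the member values here are assumptions."""

    CONCAT = "concat"
    FIRST = "first"
    LAST = "last"
    MEAN = "mean"


def normalize_strategy(value):
    """Coerce user input into an int axis or a Strategy member."""
    if isinstance(value, Strategy):
        return value
    if isinstance(value, bool):
        raise TypeError("bool is not a valid concat axis")
    if isinstance(value, int):
        return value  # plain int: treated as a concat axis
    if isinstance(value, str):
        return Strategy(value)  # raises ValueError for unknown names
    raise TypeError(f"Unsupported strategy type: {type(value).__name__}")


print(normalize_strategy("first") is Strategy.FIRST)  # True
print(normalize_strategy(-1))                         # -1
```

Normalizing once at construction time lets the merge logic branch on a small, fixed set of types instead of re-parsing strings on every forward pass.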

### Next Steps

- **ModelGraph:** See how `ModelNode`s and `MergeNode`s are composed into a full computational graph with automatic shape inference.
- **Experiment:** Use `Experiment` to combine a `ModelGraph` with training phases, loss functions, and evaluation.
- **Custom MergeNode:** Subclass `MergeNode` and implement `apply_merge()` for custom merging strategies (e.g., averaging, attention-based fusion).