# How to: Create and Use an Experiment

An `Experiment` is the top-level orchestrator in ModularML. It coordinates:

- **Phases** — units of work such as training (`TrainPhase`), evaluation (`EvalPhase`), or batch fitting (`FitPhase`)
- **Phase Groups** — named collections of phases that execute in order
- **Callbacks** — hooks at phase, group, and experiment boundaries
- **Checkpointing** — automatic saving and restoring of experiment state
- **Execution History** — records of every run for reproducibility

> **Note:** This notebook covers the `Experiment` API and how phases are registered,
> organized, and executed. Phase-specific details (configuration, advanced usage) are
> covered in dedicated notebooks:
> [`train_phases.ipynb`](train_phases.ipynb),
> [`eval_phases.ipynb`](eval_phases.ipynb), and
> [`fit_phases.ipynb`](fit_phases.ipynb).

This notebook covers:

1. [Creating an Experiment](#1-creating-an-experiment)
2. [Setting Up a Model Graph](#2-setting-up-a-model-graph)
3. [Defining Phases](#3-defining-phases)
4. [The Execution Plan](#4-the-execution-plan)
5. [Running Phases](#5-running-phases)
6. [Running the Full Execution Plan](#6-running-the-full-execution-plan)
7. [Preview Mode](#7-preview-mode)
8. [Execution History](#8-execution-history)
9. [Phase Groups](#9-phase-groups)
10. [Experiment Callbacks](#10-experiment-callbacks)
11. [Checkpointing](#11-checkpointing)
12. [Serialization](#12-serialization)
13. [Summary](#13-summary)

In [1]:
import numpy as np
import torch

import modularml as mml
from modularml import (
    AppliedLoss,
    EvalPhase,
    Experiment,
    FeatureSet,
    FitPhase,
    InputBinding,
    Loss,
    ModelGraph,
    ModelNode,
    Optimizer,
    TrainPhase,
)
from modularml.core.experiment.phases.phase_group import PhaseGroup
from modularml.samplers import SimpleSampler


---

## 1. Creating an Experiment

An `Experiment` is created with a label and, optionally, a `registration_policy` (which controls how duplicate node labels are handled), an existing context, a checkpointing configuration, and experiment-level callbacks.

```python
Experiment(
    label: str,
    registration_policy: str | None = None,
    ctx: ExperimentContext | None = None,
    checkpointing: Checkpointing | None = None,
    callbacks: list[ExperimentCallback] | None = None,
)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `label` | `str` | (required) | Name for this experiment. |
| `registration_policy` | `str \| None` | `None` | How to handle duplicate node labels: `"raise"`, `"overwrite"`, or `"rename"`. |
| `ctx` | `ExperimentContext \| None` | `None` | Context to associate with. If `None`, a new context is created. |
| `checkpointing` | `Checkpointing \| None` | `None` | Experiment-level checkpointing configuration. |
| `callbacks` | `list[ExperimentCallback] \| None` | `None` | Experiment-level callbacks for phase/group boundaries. |

In [2]:
exp = Experiment(label="my_experiment", registration_policy="overwrite")
print(f"Experiment: {exp.label}")
print(f"Context:    {exp.ctx}")

Experiment: my_experiment
Context:    <modularml.core.experiment.experiment_context.ExperimentContext object at 0x317904d30>


### 1.1 Registration Policy

The `registration_policy` determines what happens when two nodes share the same label.
This is primarily useful in notebook environments where cells may be re-executed.

| Policy | Behavior |
|--------|----------|
| `"raise"` | Raises an error on duplicate labels (default). |
| `"overwrite"` | Silently replaces the existing node. |
| `"rename"` | Assigns a unique suffix to the new node's label. |
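The three policies can be mimicked with a small stand-alone registry. This is a hypothetical sketch, not ModularML code: `ToyRegistry` is invented for illustration, and the `_1`, `_2`, ... suffix scheme used for `"rename"` is an assumption. The real suffix format may differ.

```python
class ToyRegistry:
    """Hypothetical label registry illustrating the three registration policies."""

    def __init__(self, policy: str = "raise") -> None:
        if policy not in ("raise", "overwrite", "rename"):
            raise ValueError(f"Unknown policy: {policy!r}")
        self.policy = policy
        self.nodes: dict[str, object] = {}

    def register(self, label: str, node: object) -> str:
        """Register `node` under `label` and return the label actually used."""
        if label not in self.nodes:
            self.nodes[label] = node
            return label
        if self.policy == "raise":
            raise ValueError(f"Duplicate label: {label!r}")
        if self.policy == "overwrite":
            self.nodes[label] = node  # replace the existing entry in place
            return label
        # "rename": append the first free numeric suffix (assumed format)
        i = 1
        while f"{label}_{i}" in self.nodes:
            i += 1
        new_label = f"{label}_{i}"
        self.nodes[new_label] = node
        return new_label


reg = ToyRegistry(policy="rename")
print(reg.register("MLP", "first"))   # MLP
print(reg.register("MLP", "second"))  # MLP_1
```

Re-running a notebook cell is exactly the "register the same label again" case, which is why `"overwrite"` is convenient during interactive work.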

### 1.2 Creating from an Active Context

If nodes have already been registered in the current `ExperimentContext`,
you can bind a new `Experiment` to that existing context with `from_active_context()`.
This retains all previously registered nodes.

```python
exp = Experiment.from_active_context(
    label="my_experiment",
    registration_policy="overwrite",
)
```

---

## 2. Setting Up a Model Graph

Before defining phases, we need a `ModelGraph` with at least one `ModelNode` and a
`FeatureSet` to supply data. The `Experiment` automatically tracks the `ModelGraph`
registered in its context.

For details on creating model graphs, see the
[`create_modelgraph.ipynb`](create_modelgraph.ipynb) notebook.

In [3]:
# Create synthetic data
rng = np.random.default_rng(42)

fs = FeatureSet.from_dict(
    label="SensorData",
    data={
        "voltage": list(rng.standard_normal((500, 10))),
        "soh": list(rng.standard_normal((500, 1))),
    },
    feature_keys="voltage",
    target_keys="soh",
)

# Create a train/test split
fs.split_random(
    ratios={
        "train": 0.8,
        "test": 0.2,
    },
    seed=13,
)
print(fs)
print(f"Splits: {fs.available_splits}")
fs.visualize()

FeatureSet(label='SensorData', n_samples=500)
Splits: ['train', 'test']


```mermaid
flowchart LR
	n0 e0@-->|"n=400"| n1
	n0 e1@-->|"n=100"| n2

	n0@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500<br>──────────<br><b>features</b><br>  voltage.raw: (10,)<br><b>targets</b><br>  soh.raw: ()", shape: rect }
	n1@{ label: "<b>Split</b><br>'train'<br>n=400<br>──────────<br><b>features</b><br>  voltage.raw: (10,)<br><b>targets</b><br>  soh.raw: ()<br>──────────<br><b>overlap</b><br>  test: 0", shape: rect }
	n2@{ label: "<b>Split</b><br>'test'<br>n=100<br>──────────<br><b>features</b><br>  voltage.raw: (10,)<br><b>targets</b><br>  soh.raw: ()<br>──────────<br><b>overlap</b><br>  train: 0", shape: rect }
	n0:::FeatureSet
	n1:::FeatureSetView
	n2:::FeatureSetView
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef FeatureSetView stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #F3E5F5, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
```

In [4]:
from modularml.models.torch import SequentialMLP

# Reference defining which columns feed into the model
fs_ref = fs.reference(features="voltage", targets="soh")

# Create model node
node = ModelNode(
    label="MLP",
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=32),
    upstream_ref=fs_ref,
)

# Create model graph with a global optimizer
graph = ModelGraph(
    label="SimpleGraph",
    nodes=[node],
    optimizer=Optimizer("adam", opt_kwargs={"lr": 1e-3}, backend="torch"),
)

# Build the graph (infers shapes)
graph.build()
graph.visualize()

print(f"Experiment model_graph: {exp.model_graph}")

```mermaid
flowchart LR
	n1 e0@-->|"(1, 10)"| n0
	n0 e1@-->|"(1, 1)"| n2

	n0@{ label: "<b>ModelNode</b><br>'MLP'  &lt;torch&gt;", shape: rect }
	n1@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500<br>splits: train, test", shape: rect }
	n2@{ label: " ", shape: circle }
	n0:::ModelNode
	n1:::FeatureSet
	n2:::OutputTerminal
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
```

Experiment model_graph: <modularml.core.topology.model_graph.ModelGraph object at 0x31796d630>


---

## 3. Defining Phases

Phases are the executable units of an `Experiment`. Each phase type handles a
different style of model execution:

| Phase | Purpose | Key Concept |
|-------|---------|-------------|
| `TrainPhase` | Mini-batch gradient training | Requires a `Sampler` and `Loss` |
| `EvalPhase` | Forward-only evaluation | No sampler; runs on full split |
| `FitPhase` | Batch fitting (e.g., scikit-learn) | Entire dataset passed at once |

All phases require **input bindings** that connect `FeatureSet` data to head
`GraphNode`s in the model graph.

### 3.1 Input Bindings

An `InputBinding` defines how data flows from a `FeatureSet` into a head `GraphNode`
during a specific phase. There are two constructors:

- **`InputBinding.for_training(...)`** — requires a `Sampler` to generate batches
- **`InputBinding.for_evaluation(...)`** — passes data directly (no sampler)

| Parameter | `for_training` | `for_evaluation` |
|-----------|:-:|:-:|
| `node` | required | required |
| `sampler` | required | — |
| `upstream` | required\* | required\* |
| `split` | optional | optional |

\* Can be `None` if the node has exactly one upstream `FeatureSet`.
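The footnote's rule (an omitted `upstream` is only unambiguous when there is exactly one candidate) can be expressed as a small helper. `resolve_upstream` is hypothetical and works on plain label strings; it mirrors the rule stated above rather than ModularML's internal resolution logic.

```python
from __future__ import annotations


def resolve_upstream(upstream: str | None, node_upstreams: list[str]) -> str:
    """Return the FeatureSet label a binding should read from (hypothetical helper)."""
    if upstream is not None:
        if upstream not in node_upstreams:
            raise ValueError(f"{upstream!r} is not upstream of this node")
        return upstream
    if len(node_upstreams) == 1:
        return node_upstreams[0]  # unambiguous, so it can be auto-resolved
    raise ValueError(
        f"upstream=None is ambiguous: node has {len(node_upstreams)} upstream FeatureSets"
    )


print(resolve_upstream(None, ["SensorData"]))  # SensorData
```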

In [5]:
# Training binding: requires a sampler
train_binding = InputBinding.for_training(
    node=node,
    sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),
    upstream=None,  # auto-resolved (node has one upstream FeatureSet)
    split="train",
)
print(f"Train binding node: {train_binding.node_id[:8]}...")
print(f"Train binding split: {train_binding.split}")

Train binding node: 818054a3...
Train binding split: train


In [6]:
# Evaluation binding: no sampler needed
eval_binding = InputBinding.for_evaluation(
    node=node,
    upstream=None,
    split="test",
)
print(f"Eval binding split: {eval_binding.split}")

Eval binding split: test


### 3.2 Defining a Loss

Training phases require at least one `AppliedLoss`, which binds a `Loss` function to
a specific `ModelNode` and specifies what inputs the loss receives.

```python
AppliedLoss(
    loss: Loss,
    on: str | ModelNode,
    inputs: list[str] | dict[str, str],
    weight: float = 1.0,
    label: str | None = None,
)
```

The `inputs` argument uses string references to resolve data at runtime:
- `"outputs"` — the model node's predictions
- `"targets"` — the target data passed through the model node
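Conceptually, the string references act as keys into the data produced for the node at each step. The sketch below is hypothetical: it models one batch as a plain dictionary with `"outputs"` and `"targets"` entries, uses a hand-written MSE, and scales the result by `weight`. For the dictionary form of `inputs`, it assumes keys are loss-argument names and values are references; this is an assumption, not documented behavior.

```python
from __future__ import annotations

from typing import Callable


def mse(a: list[float], b: list[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def apply_loss(
    loss_fn: Callable[..., float],
    inputs: list[str] | dict[str, str],
    node_data: dict[str, list[float]],
    weight: float = 1.0,
) -> float:
    """Resolve string references against node_data, call loss_fn, scale by weight."""
    if isinstance(inputs, dict):
        # assumed: keys are loss-argument names, values are data references
        kwargs = {arg: node_data[ref] for arg, ref in inputs.items()}
        return weight * loss_fn(**kwargs)
    # list form: resolved references are passed positionally, in order
    return weight * loss_fn(*(node_data[ref] for ref in inputs))


batch = {"outputs": [0.5, 1.0, 2.0], "targets": [0.0, 1.0, 1.0]}
print(apply_loss(mse, ["outputs", "targets"], batch))
print(apply_loss(mse, {"a": "outputs", "b": "targets"}, batch, weight=2.0))
```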

In [7]:
mse_loss = AppliedLoss(
    loss=Loss("mse", backend="torch"),
    on=node,
    inputs=["outputs", "targets"],
)
print(f"Loss: {mse_loss.label}")
print(f"Applied on: {mse_loss.node_id[:8]}...")

Loss: mse
Applied on: 818054a3...


### 3.3 Creating a TrainPhase

A `TrainPhase` performs mini-batch gradient training over one or more epochs.

There are two ways to create a `TrainPhase`:

1. **Default constructor** — provide `InputBinding`s explicitly
2. **`from_split()` convenience** — auto-generates bindings from a split name

In [8]:
# Option A: Using explicit InputBindings
train_phase = TrainPhase(
    label="train",
    input_sources=[train_binding],
    losses=[mse_loss],
    n_epochs=3,
)
print(f"TrainPhase: {train_phase.label}")
print(f"  n_epochs: {train_phase.n_epochs}")
print(f"  losses:   {[ls.label for ls in train_phase.losses]}")

train_phase.visualize()

TrainPhase: train
  n_epochs: 3
  losses:   ['mse']


```mermaid
flowchart LR
	n0 e1@-->|"(1, 1)"| n2
	n1 e2@-->|"split: train"| n3
	n3 e3@-->|"(1, 10)"| n0
	n4 e4@-.-> n0

	n0@{ label: "<b>ModelNode</b><br>'MLP'  &lt;torch&gt;<br>──────────<br><b>features</b><br>  voltage<br><b>targets</b><br>  soh", shape: rect }
	n1@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500<br>splits: train, test", shape: rect }
	n2@{ label: " ", shape: circle }
	n3@{ label: "<b>Sampler</b><br>SimpleSampler<br>batch_size: 32<br>split: train", shape: rect }
	n4@{ label: "<b>AppliedLoss</b><br>'mse'<br>loss: mse", shape: stadium }
	n0:::ModelNode
	n1:::FeatureSet
	n2:::OutputTerminal
	n3:::FeatureSampler
	n4:::AppliedLoss
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef AppliedLoss stroke-width: 2px, stroke-dasharray: 0, stroke: #D50000, fill: #FFCDD2, color:#000000;
	classDef FeatureSampler stroke-width: 2px, stroke-dasharray: 0, stroke: #FF6D00, fill: #FFE0B2, color:#000000;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef DashMediumAnimation stroke-dasharray: 9,5, stroke-dashoffset: 100, animation: dash 3s linear infinite;
	classDef NoAnimation stroke-dasharray: 0;
	class e1 NoAnimation
	class e2 DashMediumAnimation
	class e3 DashMediumAnimation
	class e4 NoAnimation
```

In [9]:
# Option B: Using the from_split() convenience constructor
# This auto-generates InputBindings for all active head nodes
train_phase_b = TrainPhase.from_split(
    label="train_from_split",
    split="train",
    sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),
    losses=[mse_loss],
    n_epochs=3,
)
print(f"TrainPhase (from_split): {train_phase_b.label}")

train_phase_b.visualize()

TrainPhase (from_split): train_from_split


```mermaid
flowchart LR
	n0 e1@-->|"(1, 1)"| n2
	n1 e2@-->|"split: train"| n3
	n3 e3@-->|"(1, 10)"| n0
	n4 e4@-.-> n0

	n0@{ label: "<b>ModelNode</b><br>'MLP'  &lt;torch&gt;<br>──────────<br><b>features</b><br>  voltage<br><b>targets</b><br>  soh", shape: rect }
	n1@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500<br>splits: train, test", shape: rect }
	n2@{ label: " ", shape: circle }
	n3@{ label: "<b>Sampler</b><br>SimpleSampler<br>batch_size: 32<br>split: train", shape: rect }
	n4@{ label: "<b>AppliedLoss</b><br>'mse'<br>loss: mse", shape: stadium }
	n0:::ModelNode
	n1:::FeatureSet
	n2:::OutputTerminal
	n3:::FeatureSampler
	n4:::AppliedLoss
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef ModelNode stroke-width: 2px, stroke-dasharray: 0, stroke: #2962FF, fill: #BBDEFB, color:#000000;
	classDef AppliedLoss stroke-width: 2px, stroke-dasharray: 0, stroke: #D50000, fill: #FFCDD2, color:#000000;
	classDef FeatureSampler stroke-width: 2px, stroke-dasharray: 0, stroke: #FF6D00, fill: #FFE0B2, color:#000000;
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;

	classDef DashMediumAnimation stroke-dasharray: 9,5, stroke-dashoffset: 100, animation: dash 3s linear infinite;
	classDef NoAnimation stroke-dasharray: 0;
	class e1 NoAnimation
	class e2 DashMediumAnimation
	class e3 DashMediumAnimation
	class e4 NoAnimation
```

### 3.4 Creating an EvalPhase

An `EvalPhase` runs a forward pass over a `FeatureSet` split without any gradient
computation. All graph nodes are automatically frozen during evaluation.

In [10]:
# Using the from_split() convenience constructor
eval_phase = EvalPhase.from_split(
    label="eval",
    split="test",
    losses=[mse_loss],
)
print(f"EvalPhase: {eval_phase.label}")

eval_phase.visualize()

EvalPhase: eval


```mermaid
flowchart LR
	n1 e0@-->|"(1, 10) split: test"| n0
	n0 e1@-->|"(1, 1)"| n2
	n3 e2@-.-> n0

	n0@{ label: "<b>ModelNode</b><br>'MLP'  &lt;torch&gt;<br>──────────<br><b>features</b><br>  voltage<br><b>targets</b><br>  soh", shape: rect }
	n1@{ label: "<b>FeatureSet</b><br>'SensorData'<br>n=500<br>splits: train, test", shape: rect }
	n2@{ label: " ", shape: circle }
	n3@{ label: "<b>AppliedLoss</b><br>'mse'<br>loss: mse", shape: stadium }
	n0:::ModelNodeFrozen
	n1:::FeatureSet
	n2:::OutputTerminal
	n3:::AppliedLoss
	classDef FeatureSet stroke-width: 2px, stroke-dasharray: 0, stroke: #AA00FF, fill: #E1BEE7, color:#000000;
	classDef OutputTerminal stroke-width: 2px, stroke-dasharray: 0, stroke: #424242, fill: #616161, color:#FFFFFF;
	classDef ModelNodeFrozen stroke-width: 2px, stroke-dasharray: 0, stroke: #90CAF9, fill: #E3F2FD, color:#666666;
	classDef AppliedLoss stroke-width: 2px, stroke-dasharray: 0, stroke: #D50000, fill: #FFCDD2, color:#000000;

	classDef NoAnimation stroke-dasharray: 0;
	class e0 NoAnimation
	class e1 NoAnimation
	class e2 NoAnimation
```

### 3.5 Creating a FitPhase

A `FitPhase` fits batch-fit models (like scikit-learn estimators) on the entire
dataset at once. It has no epochs or sampling. By default, fitted nodes are frozen
after fitting.

```python
fit_phase = FitPhase.from_split(
    label="fit_rf",
    split="train",
    freeze_after_fit=True,  # default
)
```

> **Note:** FitPhase is only relevant when your `ModelGraph` contains scikit-learn
> (batch-fit) model nodes. We will not use it in the running examples below since
> our graph uses PyTorch models.
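The difference from mini-batch training, and the role of `freeze_after_fit`, can be sketched without scikit-learn. `MeanRegressor` and `run_fit_phase` are hypothetical stand-ins: the model sees every sample in one `fit()` call, and freezing simply marks it so later phases leave it alone. Refusing to refit a frozen node is a toy choice, not documented ModularML behavior.

```python
from __future__ import annotations


class MeanRegressor:
    """Hypothetical batch-fit model: predicts the mean of the targets it was fit on."""

    def __init__(self) -> None:
        self.mean_: float | None = None
        self.frozen = False

    def fit(self, targets: list[float]) -> None:
        if self.frozen:
            # toy behavior: a frozen node is not refit
            raise RuntimeError("node is frozen; refusing to refit")
        self.mean_ = sum(targets) / len(targets)

    def predict(self, n: int) -> list[float]:
        return [self.mean_] * n


def run_fit_phase(
    model: MeanRegressor, targets: list[float], freeze_after_fit: bool = True
) -> MeanRegressor:
    model.fit(targets)  # one call on every sample in the split: no epochs, no sampler
    if freeze_after_fit:
        model.frozen = True  # later phases should leave the fitted parameters alone
    return model


fitted = run_fit_phase(MeanRegressor(), [1.0, 2.0, 3.0])
print(fitted.predict(2), fitted.frozen)  # [2.0, 2.0] True
```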

---

## 4. The Execution Plan

Every `Experiment` has an `execution_plan` property — a `PhaseGroup` that defines the
order in which phases execute when you call `experiment.run()`.

Phases are added with `add_phase()` and execute in the order they are registered.

In [11]:
# Access the execution plan
plan = exp.execution_plan
print(f"Execution plan: {plan}")
print(f"Currently empty: {len(plan.all) == 0}")

Execution plan: PhaseGroup(label=my_experiment, entries=0)
Currently empty: True


In [12]:
# Register phases in execution order
plan.add_phase(train_phase)
plan.add_phase(eval_phase)

print(f"Plan entries: {len(plan.all)}")
for i, entry in enumerate(plan.all):
    print(f"  [{i}] {entry.label} ({type(entry).__name__})")

Plan entries: 2
  [0] train (TrainPhase)
  [1] eval (EvalPhase)


### 4.1 Accessing Phases

Phases can be accessed by position (index) or by label.

In [13]:
# By index
first_phase = plan[0]
print(f"By index:  {first_phase.label}")

# By label
train_ref = plan["train"]
print(f"By label:  {train_ref.label}")

# Type-safe accessors
tp = plan.get_train_phase("train")
ep = plan.get_eval_phase("eval")
print(f"TrainPhase: {tp.label}, EvalPhase: {ep.label}")

By index:  train
By label:  train
TrainPhase: train, EvalPhase: eval


### 4.2 Removing Phases

Phases can be removed by index, label, or instance.

In [14]:
# Remove by label
plan.remove_phase("eval")
print(f"After remove: {[e.label for e in plan.all]}")

# Re-add for later examples
plan.add_phase(eval_phase)
print(f"After re-add: {[e.label for e in plan.all]}")

After remove: ['train']
After re-add: ['train', 'eval']
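The lookup and removal rules (by position, by label, or by the phase object itself) can be sketched with a plain ordered list. `ToyPhase` and `ToyPhaseGroup` are hypothetical stand-ins; the real `PhaseGroup` may validate keys and report errors differently.

```python
class ToyPhase:
    def __init__(self, label: str) -> None:
        self.label = label


class ToyPhaseGroup:
    """Hypothetical ordered container: registration order is execution order."""

    def __init__(self, label: str) -> None:
        self.label = label
        self._entries: list[ToyPhase] = []

    @property
    def all(self) -> list[ToyPhase]:
        return list(self._entries)  # copy, so callers cannot reorder the plan

    def add_phase(self, phase: ToyPhase) -> None:
        self._entries.append(phase)

    def _index_of(self, key) -> int:
        if isinstance(key, int):  # by position
            return key
        if isinstance(key, str):  # by label
            for i, entry in enumerate(self._entries):
                if entry.label == key:
                    return i
            raise KeyError(key)
        for i, entry in enumerate(self._entries):  # by instance (identity)
            if entry is key:
                return i
        raise KeyError(key)

    def __getitem__(self, key) -> ToyPhase:
        return self._entries[self._index_of(key)]

    def remove_phase(self, key) -> ToyPhase:
        return self._entries.pop(self._index_of(key))


toy_plan = ToyPhaseGroup("my_experiment")
train_toy, eval_toy = ToyPhase("train"), ToyPhase("eval")
toy_plan.add_phase(train_toy)
toy_plan.add_phase(eval_toy)
toy_plan.remove_phase("eval")  # by label
toy_plan.add_phase(eval_toy)   # re-added at the end
print([p.label for p in toy_plan.all])  # ['train', 'eval']
```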


### 4.3 Convenience Methods

The execution plan also provides convenience methods to construct and register
phases in a single call:

```python
plan.add_train_phase(
    label="train",
    input_sources=[...],
    losses=[...],
    n_epochs=5,
)

plan.add_eval_phase(
    label="eval",
    input_sources=[...],
    losses=[...],
)
```

Aliases `add_train()`, `add_training()`, `add_eval()`, and `add_evaluation()` are also available.

---

## 5. Running Phases

Phases can be run individually with `run_phase()`, regardless of whether they
are registered on the execution plan. Each run mutates experiment state and
records an entry in `history`.

In [15]:
# Run the training phase
train_results = exp.run_phase(train_phase)
print("Training completed.")
print(f"  History entries: {len(exp.history)}")


In [16]:
# Run the evaluation phase
eval_results = exp.run_phase(eval_phase)
print("Evaluation completed.")
print(f"  History entries: {len(exp.history)}")

Evaluation completed.
  History entries: 2


### 5.1 Display Options

Each phase type accepts display-related keyword arguments to control progress bars:

**TrainPhase:**

| Parameter | Default | Description |
|-----------|---------|-------------|
| `show_sampler_progress` | `True` | Show progress for batch creation |
| `show_training_progress` | `True` | Show epoch-level progress bar |
| `persist_progress` | `IN_NOTEBOOK` | Keep progress bars visible after completion |
| `persist_epoch_progress` | `IN_NOTEBOOK` | Keep per-epoch bars visible |

**EvalPhase:**

| Parameter | Default | Description |
|-----------|---------|-------------|
| `show_eval_progress` | `False` | Show evaluation progress bar |
| `persist_progress` | `IN_NOTEBOOK` | Keep progress bars visible after completion |

---

## 6. Running the Full Execution Plan

Calling `experiment.run()` executes all phases registered on the execution plan,
in the order they were added. This is the primary entry point for running a
complete experiment.

In [17]:
# Run the full execution plan (train -> eval)
results = exp.run()
print("Full run completed.")
print(f"  History entries: {len(exp.history)}")


`run()` returns a `PhaseGroupResults` object that contains results from all
executed phases. Individual phase results can be accessed by label.

In [18]:
# Inspect results
print(f"Result type: {type(results).__name__}")
print(f"Contained results: {results.flatten()}")

Result type: PhaseGroupResults
Contained results: {'train': TrainResults(label='train', epochs=3), 'eval': EvalResults(label='eval', batches=1)}


---

## 7. Preview Mode

Sometimes you want to run a phase or group without permanently changing experiment
state. The `preview_phase()` and `preview_group()` methods do exactly this:

1. Capture the current experiment state
2. Execute the phase/group
3. Restore the original state

Preview runs are **not** recorded in `history`, and checkpointing is disabled.

In [19]:
history_before = len(exp.history)

# Preview does not mutate state
preview_res = exp.preview_phase(eval_phase)

history_after = len(exp.history)
print(f"History before: {history_before}")
print(f"History after:  {history_after}")
print(f"State was restored: {history_before == history_after}")

History before: 3
History after:  3
State was restored: True
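The capture, execute, restore pattern can be sketched with `copy.deepcopy` and `try`/`finally`. `ToyExperiment`, its dictionary `state`, and the `step` function are hypothetical stand-ins; the real `Experiment` snapshots much richer state (model weights, optimizer state, context), and how it does so is not shown here.

```python
import copy


class ToyExperiment:
    """Hypothetical stand-in holding mutable state and a run history."""

    def __init__(self) -> None:
        self.state = {"weights": [0.0, 0.0]}
        self.history: list[str] = []
        self.checkpointing_enabled = True

    def run_phase(self, phase_fn, label: str):
        result = phase_fn(self.state)  # the phase mutates self.state
        self.history.append(label)     # and the run is recorded
        return result

    def preview_phase(self, phase_fn):
        snapshot = copy.deepcopy(self.state)  # 1. capture state
        saved_flag = self.checkpointing_enabled
        self.checkpointing_enabled = False    # no checkpoints during a preview
        try:
            return phase_fn(self.state)       # 2. execute (not recorded in history)
        finally:
            self.state = snapshot             # 3. restore, even if the phase raised
            self.checkpointing_enabled = saved_flag


def step(state):
    """Toy 'phase': nudges every weight, then reports their sum."""
    state["weights"] = [w + 1.0 for w in state["weights"]]
    return sum(state["weights"])


exp_toy = ToyExperiment()
print(exp_toy.preview_phase(step))       # 2.0 (computed on temporary state)
print(exp_toy.state, exp_toy.history)    # state unchanged, history still empty
```

Restoring in `finally` is the important design choice: a preview that fails halfway should not leave the experiment in a half-mutated state.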


---

## 8. Execution History

Every call to `run_phase()`, `run_group()`, or `run()` records an `ExperimentRun`
in `experiment.history`. Each run captures:

- Label, start/end timestamps, and status
- Phase results (losses, outputs, etc.)
- Execution metadata (timing per phase)

In [20]:
for i, run in enumerate(exp.history):
    print(
        f"  Run {i}: label={run.label!r}, "
        f"status={run.status}, "
        f"duration={run.ended_at - run.started_at}"
    )

  Run 0: label='train', status=completed, duration=0:00:00.481113
  Run 1: label='eval', status=completed, duration=0:00:00.006110
  Run 2: label='my_experiment', status=completed, duration=0:00:00.822204


In [21]:
# Access the most recent run
last = exp.last_run
print(f"Last run: {last.label}")
print(f"  Status:  {last.status}")
print(f"  Results: {type(last.results).__name__}")

Last run: my_experiment
  Status:  completed
  Results: PhaseGroupResults
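A run record with start/end timestamps and a status can be sketched with a dataclass. `ToyRun` and `record_run` are hypothetical; only the field names `label`, `status`, `started_at`, `ended_at`, and `results` come from the example above, and the `"failed"` status string is an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta


@dataclass
class ToyRun:
    """Hypothetical record mirroring the fields printed from ExperimentRun above."""

    label: str
    started_at: datetime
    ended_at: datetime | None = None
    status: str = "running"
    results: object = field(default_factory=dict)

    @property
    def duration(self) -> timedelta | None:
        return None if self.ended_at is None else self.ended_at - self.started_at


def record_run(history: list[ToyRun], label: str, fn) -> ToyRun:
    """Run `fn`, timing it, and append a ToyRun to `history` whether or not it succeeds."""
    run = ToyRun(label=label, started_at=datetime.now())
    try:
        run.results = fn()
        run.status = "completed"
    except Exception:
        run.status = "failed"  # assumed status name
        raise
    finally:
        run.ended_at = datetime.now()
        history.append(run)
    return run


toy_history: list[ToyRun] = []
record_run(toy_history, "train", lambda: {"loss": 0.1})
print(toy_history[-1].label, toy_history[-1].status)  # train completed
```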


---

## 9. Phase Groups

A `PhaseGroup` is a named collection that organizes phases into logical blocks.
Phase groups can be nested (a group can contain other groups), enabling
hierarchical experiment structures.

The experiment's `execution_plan` is itself a `PhaseGroup`.

In [22]:
# Create a sub-group for a train-eval cycle
cycle = PhaseGroup(label="train_eval_cycle")

cycle.add_phase(
    TrainPhase.from_split(
        label="cycle_train",
        split="train",
        sampler=SimpleSampler(batch_size=32, shuffle=True, seed=42),
        losses=[mse_loss],
        n_epochs=2,
    ),
)
cycle.add_phase(
    EvalPhase.from_split(
        label="cycle_eval",
        split="test",
        losses=[mse_loss],
    ),
)

print(f"Group: {cycle}")
print(f"Entries: {[e.label for e in cycle.all]}")

Group: PhaseGroup(label=train_eval_cycle, entries=2)
Entries: ['cycle_train', 'cycle_eval']


In [23]:
# Run the group directly
group_results = exp.run_group(cycle)
print(f"Group results: {group_results.flatten()}")


### 9.1 Nesting Groups

Groups can be nested within the execution plan or within other groups.
Use `add_group()` to nest a `PhaseGroup` inside another.

In [24]:
# Build a nested plan
outer = PhaseGroup(label="outer")

inner = PhaseGroup(label="inner")
inner.add_phase(
    TrainPhase.from_split(
        label="inner_train",
        split="train",
        sampler=SimpleSampler(batch_size=64, shuffle=True, seed=0),
        losses=[mse_loss],
        n_epochs=1,
    ),
)

outer.add_group(inner)
outer.add_phase(
    EvalPhase.from_split(
        label="outer_eval",
        split="test",
        losses=[mse_loss],
    ),
)

# flatten() unrolls all nested groups into execution order
print(f"Flattened: {[p.label for p in outer.flatten()]}")

Flattened: ['inner_train', 'outer_eval']
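The flattening rule is a depth-first walk in registration order, where a nested group expands in place. `ToyGroup` is a hypothetical stand-in that stores phases as plain labels; it reproduces the output above but is not `PhaseGroup` itself.

```python
class ToyGroup:
    """Hypothetical group whose entries are phase labels (str) or nested ToyGroups."""

    def __init__(self, label: str) -> None:
        self.label = label
        self.entries: list = []

    def add_phase(self, label: str) -> None:
        self.entries.append(label)

    def add_group(self, group: "ToyGroup") -> None:
        self.entries.append(group)

    def flatten(self) -> list[str]:
        """Depth-first, in registration order: a nested group expands where it was added."""
        out: list[str] = []
        for entry in self.entries:
            if isinstance(entry, ToyGroup):
                out.extend(entry.flatten())
            else:
                out.append(entry)
        return out


outer_toy = ToyGroup("outer")
inner_toy = ToyGroup("inner")
inner_toy.add_phase("inner_train")
outer_toy.add_group(inner_toy)
outer_toy.add_phase("outer_eval")
print(outer_toy.flatten())  # ['inner_train', 'outer_eval']
```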


### 9.2 PhaseGroup API

| Method | Description |
|--------|-------------|
| `add_phase(phase)` | Register a phase. |
| `add_group(group)` | Register a nested group. |
| `add_train_phase(...)` | Construct and register a `TrainPhase`. |
| `add_eval_phase(...)` | Construct and register an `EvalPhase`. |
| `remove_phase(key)` | Remove a phase by index, label, or instance. |
| `remove_group(key)` | Remove a group by index, label, or instance. |
| `clear()` | Remove all entries. |
| `flatten()` | Unroll all nested groups into a flat list of phases. |
| `get_phase(key)` | Get a phase by index or label. |
| `get_train_phase(key)` | Get a `TrainPhase` by index or label. |
| `get_eval_phase(key)` | Get an `EvalPhase` by index or label. |
| `get_group(key)` | Get a nested `PhaseGroup` by index or label. |
| `items()` | Iterate over `(label, entry)` pairs. |

---

## 10. Experiment Callbacks

Experiment-level callbacks (`ExperimentCallback`) fire at phase and group
boundaries during `run()`. They are distinct from phase-level `Callback`s that
fire at batch/epoch boundaries within a single phase.

| Hook | Trigger |
|------|---------|
| `on_experiment_start(experiment)` | Before the execution plan begins |
| `on_experiment_end(experiment)` | After the execution plan completes |
| `on_phase_start(experiment, phase)` | Before each phase executes |
| `on_phase_end(experiment, phase)` | After each phase completes |
| `on_group_start(experiment, group)` | Before each group executes |
| `on_group_end(experiment, group)` | After each group completes |
| `on_exception(experiment, phase, exception)` | On unhandled exception |

Callbacks are registered via the constructor or `add_callback()`:

```python
exp = Experiment(
    label="my_exp",
    callbacks=[my_callback],
)

# Or add later
exp.add_callback(another_callback)
```
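The order in which hooks fire can be illustrated with a small stand-alone runner. Everything here is hypothetical: `RecordingCallback` only borrows the hook names and argument order from the table above (phases and groups are passed as plain labels rather than objects). The runner's behavior, such as firing group hooks only for nested groups and re-raising after `on_exception`, is an assumption rather than documented ModularML behavior.

```python
class RecordingCallback:
    """Hypothetical callback that records which hook fired, in order."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def on_experiment_start(self, experiment) -> None:
        self.events.append("experiment_start")

    def on_experiment_end(self, experiment) -> None:
        self.events.append("experiment_end")

    def on_phase_start(self, experiment, phase) -> None:
        self.events.append(f"phase_start:{phase}")

    def on_phase_end(self, experiment, phase) -> None:
        self.events.append(f"phase_end:{phase}")

    def on_group_start(self, experiment, group) -> None:
        self.events.append(f"group_start:{group}")

    def on_group_end(self, experiment, group) -> None:
        self.events.append(f"group_end:{group}")

    def on_exception(self, experiment, phase, exception) -> None:
        self.events.append(f"exception:{phase}:{type(exception).__name__}")


def run_plan(experiment, plan, callbacks, execute_phase) -> None:
    """Walk a plan whose entries are phase labels (str) or nested (label, entries) tuples."""

    def fire(hook: str, *args) -> None:
        for cb in callbacks:
            getattr(cb, hook)(experiment, *args)

    def walk(entries) -> None:
        for entry in entries:
            if isinstance(entry, tuple):  # nested group
                label, children = entry
                fire("on_group_start", label)
                walk(children)
                fire("on_group_end", label)
            else:  # single phase
                fire("on_phase_start", entry)
                try:
                    execute_phase(entry)
                except Exception as exc:
                    fire("on_exception", entry, exc)
                    raise  # assumed: the error still propagates to the caller
                fire("on_phase_end", entry)

    fire("on_experiment_start")
    walk(plan)
    fire("on_experiment_end")


rec = RecordingCallback()
run_plan(None, ["train", ("cycle", ["cycle_train", "cycle_eval"])], [rec],
         execute_phase=lambda label: None)
print(rec.events)
```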

---

## 11. Checkpointing

Experiment-level checkpointing automatically saves the full experiment state to
disk at configurable lifecycle hooks. This is useful for fault tolerance and
resumption.

Experiment checkpointing only supports `mode="disk"` (in-memory snapshots of the
full experiment state would be too large).

### 11.1 Configuring Checkpointing

Checkpointing is configured via the `Checkpointing` class and passed at
construction time or via `set_checkpointing()`.

Valid `save_on` hooks for experiment-level checkpointing:

| Hook | When |
|------|------|
| `"phase_start"` | Before each phase |
| `"phase_end"` | After each phase |
| `"group_start"` | Before each group |
| `"group_end"` | After each group |
| `"experiment_start"` | Before `run()` begins |
| `"experiment_end"` | After `run()` completes |

```python
from modularml import Checkpointing

exp = Experiment(
    label="checkpointed_exp",
    checkpointing=Checkpointing(
        mode="disk",
        save_on=["phase_end"],
        directory="./checkpoints",
    ),
)
```
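The two constraints stated in this section (disk mode only, and a fixed set of `save_on` hook names) can be encoded as a small validation helper plus a trigger check. `validate_experiment_checkpointing` and `should_save` are hypothetical helpers written for illustration; they are not part of the ModularML API.

```python
VALID_SAVE_ON = frozenset({
    "phase_start", "phase_end",
    "group_start", "group_end",
    "experiment_start", "experiment_end",
})


def validate_experiment_checkpointing(mode: str, save_on: list[str]) -> None:
    """Reject settings that experiment-level checkpointing does not support."""
    if mode != "disk":
        raise ValueError('experiment-level checkpointing only supports mode="disk"')
    unknown = sorted(set(save_on) - VALID_SAVE_ON)
    if unknown:
        raise ValueError(f"invalid save_on hooks: {unknown}")


def should_save(hook: str, save_on: list[str]) -> bool:
    """True when the lifecycle hook that just fired is one the user opted into."""
    return hook in save_on


validate_experiment_checkpointing("disk", ["phase_end"])
fired = ["experiment_start", "phase_start", "phase_end",
         "phase_start", "phase_end", "experiment_end"]
print([h for h in fired if should_save(h, ["phase_end"])])  # ['phase_end', 'phase_end']
```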

### 11.2 Manual Checkpointing

You can also save and restore checkpoints manually.

In [25]:
from pathlib import Path
from tempfile import TemporaryDirectory

CKPT_DIR = TemporaryDirectory()

# Set the checkpoint directory
exp.set_checkpoint_dir(Path(CKPT_DIR.name))

# Save a checkpoint
ckpt_path = exp.save_checkpoint("after_training", overwrite=True)
print(f"Checkpoint saved to: {ckpt_path}")
print(f"Available checkpoints: {list(exp.available_checkpoints.keys())}")

Checkpoint saved to: /var/folders/21/fsx4ddjs3fg2wgpl7_ksh0k00000gn/T/tmpcbx2h5de/after_training.ckpt.mml
Available checkpoints: ['after_training']


In [26]:
# Restore from a checkpoint (by name or path)
exp.restore_checkpoint("after_training")
print("Checkpoint restored.")

Checkpoint restored.


### 11.3 Disabling Checkpointing

Use the `disable_checkpointing()` context manager to temporarily suppress all
checkpointing (both experiment-level and TrainPhase-level).

```python
with exp.disable_checkpointing():
    exp.run_phase(train_phase)  # No checkpoints saved
```
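A context manager of this kind is typically built with `contextlib.contextmanager`: switch a flag off on entry and restore the *previous* value in `finally`, so nesting and exceptions are handled correctly. `ToyCheckpointer` is a hypothetical stand-in with a single flag; the real method covers both experiment-level and TrainPhase-level checkpointing.

```python
from contextlib import contextmanager


class ToyCheckpointer:
    """Hypothetical object that only 'saves' (records a name) while enabled."""

    def __init__(self) -> None:
        self.enabled = True
        self.saved: list[str] = []

    def maybe_save(self, name: str) -> None:
        if self.enabled:
            self.saved.append(name)

    @contextmanager
    def disable_checkpointing(self):
        previous = self.enabled
        self.enabled = False
        try:
            yield self
        finally:
            self.enabled = previous  # restored even if the body raises; safe to nest


ck = ToyCheckpointer()
ck.maybe_save("phase_end:train")
with ck.disable_checkpointing():
    ck.maybe_save("phase_end:eval")  # suppressed
print(ck.saved)  # ['phase_end:train']
```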

---

## 12. Serialization

An `Experiment` can be fully serialized to disk via `save()` and reloaded with `load()`.
This includes the model graph state, execution plan, and execution history.

In [27]:
SAVE_DIR = TemporaryDirectory()

# Save the experiment
save_path = exp.save(Path(SAVE_DIR.name) / "my_experiment", overwrite=True)
print(f"Experiment saved to: {save_path}")

Experiment saved to: /var/folders/21/fsx4ddjs3fg2wgpl7_ksh0k00000gn/T/tmpvpfqkphf/my_experiment.exp.mml


In [28]:
# Load the experiment
loaded_exp = Experiment.load(save_path, overwrite=True)
print(f"Loaded experiment: {loaded_exp.label}")
print(f"  Model graph: {loaded_exp.model_graph}")

─────────────────────────────── INFO - Node ID Collision ───────────────────────────────
 Loaded FeatureSet is identical to 'SensorData' in the existing ExperimentContext.
 Returning 'FeatureSet(label='SensorData', n_samples=500)'.
────────────────────────────────────────────────────────────────────────────────────────
─────────────────────────────── INFO - Node ID Collision ───────────────────────────────
 The loaded ModelNode has an overlapping ID with existing ModelNode 'MLP'. 'MLP' will
 be overwritten in the active ExperimentContext.
────────────────────────────────────────────────────────────────────────────────────────
───────────────────────────── INFO - ModelGraph Collision ──────────────────────────────
 Loaded ModelGraph is identical to 'SimpleGraph' in the existing ExperimentContext.
 Returning 'SimpleGraph'.
────────────────────────────────────────────────────────────────────────────────────────
 The serialized experiment contains 1 checkpoint(s), but no `checkpoint_dir` w

Loaded experiment: my_experiment
  Model graph: <modularml.core.topology.model_graph.ModelGraph object at 0x31796d630>


The `get_config()` and `get_state()` methods provide lower-level access to the
experiment's structure and mutable state for custom serialization workflows.

```python
config = exp.get_config()   # Structure (label, plan, policy)
state = exp.get_state()     # Mutable state (context, history, checkpoints)

# Restore
exp.set_state(state)
```
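The config/state split can be illustrated with a dataclass whose two halves round-trip through JSON independently. `ToySerializableExperiment` is hypothetical and reduces everything to JSON-friendly values; the real state also includes the context and checkpoints, which are not modeled here.

```python
from __future__ import annotations

import copy
import json
from dataclasses import dataclass, field


@dataclass
class ToySerializableExperiment:
    """Hypothetical experiment split into structure (config) and mutable state."""

    label: str
    registration_policy: str | None = None
    plan: list[str] = field(default_factory=list)      # phase labels, in order
    history: list[dict] = field(default_factory=list)  # one dict per completed run

    def get_config(self) -> dict:
        # structure only: enough to rebuild an experiment that has no runs yet
        return {
            "label": self.label,
            "registration_policy": self.registration_policy,
            "plan": list(self.plan),
        }

    @classmethod
    def from_config(cls, config: dict) -> ToySerializableExperiment:
        return cls(
            label=config["label"],
            registration_policy=config["registration_policy"],
            plan=list(config["plan"]),
        )

    def get_state(self) -> dict:
        return {"history": copy.deepcopy(self.history)}

    def set_state(self, state: dict) -> None:
        self.history = copy.deepcopy(state["history"])


original = ToySerializableExperiment("my_experiment", "overwrite", ["train", "eval"])
original.history.append({"label": "train", "status": "completed"})

# Round-trip both halves through JSON text, as a file-based workflow might
config_text = json.dumps(original.get_config())
state_text = json.dumps(original.get_state())

rebuilt = ToySerializableExperiment.from_config(json.loads(config_text))
rebuilt.set_state(json.loads(state_text))
print(rebuilt == original)  # True
```

Keeping structure and state separate means the same config can be rebuilt fresh (empty history) or combined with a saved state to resume.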

---

## 13. Summary

### Experiment Constructor

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `label` | `str` | (required) | Name for this experiment. |
| `registration_policy` | `str \| None` | `None` | `"raise"`, `"overwrite"`, or `"rename"`. |
| `ctx` | `ExperimentContext \| None` | `None` | Context to bind to. |
| `checkpointing` | `Checkpointing \| None` | `None` | Auto-checkpoint configuration. |
| `callbacks` | `list[ExperimentCallback] \| None` | `None` | Experiment-level callbacks. |

### Experiment Properties

| Property | Type | Description |
|----------|------|-------------|
| `ctx` | `ExperimentContext` | The associated context. |
| `model_graph` | `ModelGraph \| None` | The registered model graph. |
| `execution_plan` | `PhaseGroup` | Phases to run on `run()`. |
| `history` | `list[ExperimentRun]` | All completed runs. |
| `last_run` | `ExperimentRun \| None` | Most recent run. |
| `checkpointing` | `Checkpointing \| None` | Checkpoint configuration. |
| `available_checkpoints` | `dict[str, Path]` | Saved checkpoint registry. |
| `exp_callbacks` | `list[ExperimentCallback]` | Registered callbacks. |

### Experiment Methods

| Method | Description |
|--------|-------------|
| `run()` | Execute the full execution plan. |
| `run_phase(phase)` | Execute a single phase (records history). |
| `run_group(group)` | Execute a phase group (records history). |
| `preview_phase(phase)` | Execute a phase without mutating state. |
| `preview_group(group)` | Execute a group without mutating state. |
| `add_callback(cb)` | Register an experiment-level callback. |
| `set_checkpointing(ckpt)` | Attach/replace checkpointing configuration. |
| `set_checkpoint_dir(path)` | Set the checkpoint save directory. |
| `save_checkpoint(name)` | Manually save a checkpoint. |
| `restore_checkpoint(name)` | Restore from a saved checkpoint. |
| `disable_checkpointing()` | Context manager to suppress checkpointing. |
| `save(filepath)` | Serialize experiment to disk. |
| `load(filepath)` | Load experiment from disk. |
| `get_config()` / `from_config()` | Config serialization. |
| `get_state()` / `set_state()` | State serialization. |

### Phase Types

| Phase | Module | Use Case |
|-------|--------|----------|
| `TrainPhase` | `modularml` | Mini-batch gradient training with epochs and sampling. |
| `EvalPhase` | `modularml` | Forward-only evaluation on a data split. |
| `FitPhase` | `modularml` | Batch fitting for scikit-learn models. |

### Next Steps

- **TrainPhase:** Detailed training configuration, batch scheduling, and
  TrainPhase-level checkpointing — see [`train_phases.ipynb`](train_phases.ipynb).

- **EvalPhase:** Evaluation strategies, batched evaluation, and metrics —
  see [`eval_phases.ipynb`](eval_phases.ipynb).

- **FitPhase:** Batch-fit workflows for scikit-learn models —
  see [`fit_phases.ipynb`](fit_phases.ipynb).
