# torch.compile with `DiscreteStates`

This short experiment shows that a `DiscreteStates` wrapper can safely flow through `torch.compile`. We instantiate a simple environment, grab its states/actions, and compare the eager and compiled results of a single `_step` call.


In [1]:
import torch
from gfn.gym.hypergrid import HyperGrid

# Resolve device (CUDA if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Instantiate a small environment and grab states/actions.
env = HyperGrid(ndim=2, height=4, device=device)
states = env.reset(batch_shape=4)
actions = env.actions_from_batch_shape((4,))
actions.tensor = torch.ones((4, 1), dtype=torch.long, device=device)

# Define a helper that takes raw tensors, rebuilds the wrappers, and returns the step result.
def step_once(states_tensor: torch.Tensor, actions_tensor: torch.Tensor) -> torch.Tensor:
    s = env.States(states_tensor)
    a = env.Actions(actions_tensor)
    return env._step(s, a).tensor

compiled_step = torch.compile(step_once, dynamic=True)

eager_out = step_once(states.tensor, actions.tensor)
compiled_out = compiled_step(states.tensor, actions.tensor)

print("Outputs match:", torch.equal(eager_out, compiled_out))
print("Output device:", compiled_out.device)
print("Example compiled output:\n", compiled_out)


Using device: cpu


W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] Graph break from `Tensor.item()`, consider setting:
W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0]     torch._dynamo.config.capture_scalar_outputs = True
W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] or:
W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] to include these operations in the captured graph.
W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] 
W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0] Graph break: from user code at:
W1126 23:25:12.035000 80441 site-packages/torch/_dynamo/variables/tensor.py:1047] [9/0]   File "/Users/jdv/code/torchgfn/src/gfn/env.py", line 3

Outputs match: True
Output device: cpu
Example compiled output:
 tensor([[0, 1],
        [0, 1],
        [0, 1],
        [0, 1]])


## Microbenchmark harness

The cells below build a small timing helper so we can compare `step_once` in eager mode vs the `torch.compile(..., dynamic=True)` variant under identical inputs. We run everything on CPU for consistency.


In [2]:
import math
import statistics
import warnings
from typing import Callable, Dict

import torch.utils.benchmark as benchmark


def _sync_if_needed() -> None:
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def benchmark_step_fn(
    step_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    label: str,
    states_tensor: torch.Tensor,
    actions_tensor: torch.Tensor,
    *,
    iters: int = 200,
) -> Dict[str, float]:
    """Time repeated calls to `step_fn` under identical inputs."""

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    warmup_iters = max(5, iters // 10)
    for _ in range(warmup_iters):
        step_fn(states_tensor, actions_tensor)
    _sync_if_needed()

    timer = benchmark.Timer(
        stmt="fn(states_tensor, actions_tensor)",
        globals={
            "fn": step_fn,
            "states_tensor": states_tensor,
            "actions_tensor": actions_tensor,
        },
        label=label,
        sub_label=f"device={states_tensor.device}",
        description="step_once microbenchmark",
    )
    result = timer.timeit(iters)
    std_ms = statistics.pstdev(result.raw_times) * 1000 if result.raw_times else float("nan")
    run_count = len(result.raw_times) if result.raw_times else iters
    return {
        "label": label,
        "mean_ms": result.mean * 1000,
        "std_ms": std_ms,
        "iters": run_count,
    }



In [3]:
benchmark_iters = 20000
results = []

results.append(
    benchmark_step_fn(
        step_once,
        label="Eager step_once",
        states_tensor=states.tensor,
        actions_tensor=actions.tensor,
        iters=benchmark_iters,
    )
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    results.append(
        benchmark_step_fn(
            compiled_step,
            label="torch.compile(step_once)",
            states_tensor=states.tensor,
            actions_tensor=actions.tensor,
            iters=benchmark_iters,
        )
    )
    compile_warning_messages = sorted({str(w.message) for w in caught})

results



[{'label': 'Eager step_once',
  'mean_ms': 0.10937216250458733,
  'std_ms': 0.0,
  'iters': 1},
 {'label': 'torch.compile(step_once)',
  'mean_ms': 0.3456614166498184,
  'std_ms': 0.0,
  'iters': 1}]

In [4]:
import torch._dynamo as dynamo


def _format_results(rows):
    header = f"{'Mode':<30} {'mean (ms)':>12} {'std (ms)':>12} {'iters':>8}"
    lines = [header, "-" * len(header)]
    for row in rows:
        lines.append(
            f"{row['label']:<30} {row['mean_ms']:>12.4f} {row['std_ms']:>12.4f} {row['iters']:>8d}"
        )
    return "\n".join(lines)


def _extract_count(report: str, prefix: str) -> int:
    for line in report.splitlines():
        if line.startswith(prefix):
            return int(line.split(":", 1)[1].strip())
    return -1


print(_format_results(results))

eager_mean = next(r for r in results if r["label"] == "Eager step_once")["mean_ms"]
compiled_mean = next(r for r in results if "torch.compile" in r["label"])["mean_ms"]
speedup = eager_mean / compiled_mean if compiled_mean else float("nan")
print(f"\nSpeedup (eager / compiled): {speedup:.3f}x")

compiled_report = dynamo.explain(step_once)(states.tensor, actions.tensor)
compiled_report_text = str(compiled_report)

graph_count = _extract_count(compiled_report_text, "Graph Count")
graph_breaks = _extract_count(compiled_report_text, "Graph Break Count")
break_reasons = sorted(
    {
        line.strip().split(":", 1)[1].strip()
        for line in compiled_report_text.splitlines()
        if line.strip().startswith("Reason:")
    }
)

print(
    f"\nDynamo summary -> Graphs: {graph_count}, Graph breaks: {graph_breaks}, "
    f"Break reasons: {break_reasons or ['None']}"
)

if compile_warning_messages:
    print("\nWarnings during compiled execution:")
    for msg in compile_warning_messages:
        print(f"  - {msg}")
else:
    print("\nWarnings during compiled execution: none captured")



Mode                              mean (ms)     std (ms)    iters
-----------------------------------------------------------------
Eager step_once                      0.1094       0.0000        1
torch.compile(step_once)             0.3457       0.0000        1

Speedup (eager / compiled): 0.316x

Dynamo summary -> Graphs: 12, Graph breaks: 11, Break reasons: ['Dynamic shape operator', 'Unsupported Tensor.item() call with capture_scalar_outputs=False']



## Full GFlowNet benchmark

The cell below reuses `train_hypergrid_optimized.py`'s benchmarking entry-point so we can time a larger training loop (Baseline vs compiled) directly from this notebook.


In [5]:
import importlib
import json
import sys
from pathlib import Path

# Add project root to path (notebook is in tutorials/notebooks/)
project_root = Path.cwd().parent.parent
sys.path.append(str(project_root))
from tutorials.examples import train_hypergrid_optimized as hypergrid_train


# Reload to pick up local edits without restarting the kernel.
importlib.reload(hypergrid_train)


def notebook_benchmark_run(
    *,
    compile_mode: str = "none",
    use_compile: bool = False,
    chunk_size: int = 0,
    n_iterations: int = 200,
    warmup_iters: int = 50,
    seed: int = 0,
    device: str = "cpu",
    label: str,
) -> dict:
    argv_backup = sys.argv
    try:
        sys.argv = [sys.argv[0]]
        args = hypergrid_train.parse_args()
    finally:
        sys.argv = argv_backup
    args.compile = use_compile
    args.compile_mode = compile_mode
    args.chunk_size = chunk_size
    args.n_iterations = n_iterations
    args.warmup_iters = warmup_iters
    args.seed = seed
    args.device = device
    args.benchmark = True
    args.use_vmap = False
    args.loss = "TB"
    args.batch_size = 16
    args.height = 32
    args.ndim = 2

    result = hypergrid_train.train_with_options(
        args,
        device=hypergrid_train.resolve_device(device),
        enable_compile=use_compile,
        use_vmap=False,
        warmup_iters=warmup_iters,
        quiet=True,
        timing=True,
        record_history=True,
        use_chunk=(chunk_size > 0),
    )
    result["label"] = label
    result["compile_mode"] = compile_mode if use_compile else "none"
    return result


scenarios = [
    dict(label="Eager", use_compile=False),
    dict(label="Compiled", use_compile=True, compile_mode="reduce-overhead"),
]

benchmark_runs = []
for scenario in scenarios:
    run_result = notebook_benchmark_run(**scenario)
    benchmark_runs.append(run_result)

benchmark_runs


calculated tensor of all states in 0.0009723345438639323 minutes
+ Environment has 1024 states
+ Environment log partition is 5.711750507354736




calculated tensor of all states in 0.0007189313570658366 minutes
+ Environment has 1024 states
+ Environment log partition is 5.711750507354736


[{'elapsed': 6.105575958034024,
  'losses': [7.572955131530762,
   2.491760015487671,
   3.2979516983032227,
   3.0194754600524902,
   0.9618801474571228,
   0.8673862218856812,
   1.6195870637893677,
   0.5269477367401123,
   1.0297844409942627,
   1.332466959953308,
   0.6973802447319031,
   1.6610175371170044,
   0.47196799516677856,
   1.980200171470642,
   2.484879970550537,
   1.153307557106018,
   0.5622124671936035,
   0.7249877452850342,
   1.2468236684799194,
   1.9157636165618896,
   1.5802578926086426,
   0.9950146675109863,
   0.9827088713645935,
   0.9594094753265381,
   2.0273141860961914,
   1.0678741931915283,
   1.7654989957809448,
   1.8363938331604004,
   0.5704580545425415,
   2.0948450565338135,
   0.8548241853713989,
   4.518639087677002,
   1.0827535390853882,
   1.2317500114440918,
   0.6395683288574219,
   1.3933279514312744,
   1.7131190299987793,
   1.1856663227081299,
   1.428055763244629,
   0.8084158897399902,
   0.37907153367996216,
   1.583935260772705,

In [6]:
import pandas as pd

benchmark_df = pd.DataFrame(benchmark_runs)
display(
    benchmark_df[
        [
            "label",
            "elapsed",
            "compile_mode",
            "effective_vmap",
            "chunk_size_effective",
        ]
    ]
)

baseline = benchmark_df.iloc[0]
print("Baseline label:", baseline["label"], "elapsed:", f"{baseline['elapsed']:.2f}s")
for idx in range(1, len(benchmark_df)):
    row = benchmark_df.iloc[idx]
    speedup = baseline["elapsed"] / row["elapsed"] if row["elapsed"] else float("inf")
    print(
        f"{row['label']} elapsed={row['elapsed']:.2f}s "
        f"({speedup:.2f}x vs baseline), compile_mode={row['compile_mode']}"
    )


Unnamed: 0,label,elapsed,compile_mode,effective_vmap,chunk_size_effective
0,Eager,6.105576,none,False,0
1,Compiled,6.345466,reduce-overhead,False,0


Baseline label: Eager elapsed: 6.11s
Compiled elapsed=6.35s (0.96x vs baseline), compile_mode=reduce-overhead


## Dynamo trace analysis

`torch._dynamo.explain` gives a per-graph summary: captured ops, guards, and where graph breaks (if any) occur. The cell below reuses the state/action tensors above and prints the explanation so you can confirm there is only one graph and zero breaks.


In [2]:
import torch._dynamo as dynamo

explanation = dynamo.explain(step_once)(states.tensor, actions.tensor)
print(explanation)


Graph Count: 16
Graph Break Count: 15
Op Count: 24
Break Reasons:
  Break Reason 1:
    Reason: Dynamic shape operator
  Explanation: Operator `aten.repeat_interleave.Tensor`'s output shape depends on input Tensor data.
  Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`

  Developer debug context: aten.repeat_interleave.Tensor

    User Stack:
      <FrameSummary file /var/folders/hd/jqxc7ns56l35q_zyk7xmptmw0000gn/T/ipykernel_75536/351529152.py, line 18 in step_once>
      <FrameSummary file /Users/jdv/code/torchgfn/src/gfn/env.py, line 680 in _step>
      <FrameSummary file /Users/jdv/code/torchgfn/src/gfn/env.py, line 306 in _step>
      <FrameSummary file /Users/jdv/code/torchgfn/src/gfn/actions.py, line 149 in __getitem__>
      <FrameSummary file /Users/jdv/code/torchgfn/src/gfn/actions.py, line 188 in _mask_select>
      <FrameSummary file /Users/jdv/code/torchgfn/src/gfn/utils/mask_select.py, line 66 in boolean_m