# Run/Tumble TD with runner (same simulation)

Companion to the manual `run_tumble_td` demo. This notebook uses `runner.stream` to step the same environment and the same run/tumble temporal-derivative (TD) policy, then plots the resulting trajectory and action statistics.

In [None]:
# Path and plotting setup
import sys, pathlib

# Find the directory that contains "src": use the current working directory if
# it has one, otherwise fall back to its parent. "<root>/src" is then put first
# on sys.path (presumably so plume_nav_sim below is imported from this
# checkout; confirm against the repo layout).
repo_root = pathlib.Path.cwd()
if not (repo_root / "src").exists():
    repo_root = repo_root.parent
sys.path.insert(0, str(repo_root / "src"))

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats

# Select the matplotlib_inline backend and request PNG figure output.
mpl.use("module://matplotlib_inline.backend_inline")
set_matplotlib_formats("png")
from IPython.display import display

# Project imports; these must come after the sys.path insertion above.
import plume_nav_sim as pns
from plume_nav_sim.policies.run_tumble_td import RunTumbleTemporalDerivativePolicy
from plume_nav_sim.runner import runner as r
import seaborn as sns

In [None]:
# Simulation settings: kept identical to the manual demo so both notebooks
# step the same scenario.
seed = 123
max_steps = 500
grid_size = (64, 64)
start_location = (16, 16)
source_location = (48, 48)

# Environment built from the settings above; frames are requested as RGB arrays.
env = pns.make_env(
    render_mode="rgb_array",
    reward_type="step_penalty",
    observation_type="concentration",
    action_type="run_tumble",
    max_steps=max_steps,
    start_location=start_location,
    source_location=source_location,
    grid_size=grid_size,
)

# The policy is seeded with the same value that is later passed to the runner.
policy = RunTumbleTemporalDerivativePolicy(eps_seed=seed, threshold=1e-6)

# Tabular event collection
import pandas as pd


def stream_to_df(env, policy, *, seed: int, render: bool = True):
    """Step one episode through ``r.stream`` and tabulate the emitted events.

    Each event contributes one row with its step index, action, first
    observation element (stored as ``c``), reward, termination flags and,
    when ``ev.info`` is a dict carrying ``"agent_xy"``, the agent position.
    Iteration stops at the first event flagged terminated or truncated.

    Args:
        env: Environment handed through to ``r.stream``.
        policy: Policy handed through to ``r.stream``.
        seed: Seed forwarded to ``r.stream``.
        render: Forwarded to ``r.stream``; events may then carry a frame.

    Returns:
        A ``(df, final_frame)`` tuple. ``df`` is sorted by ``t`` and has the
        columns ``t, action, c, reward, terminated, truncated, x, y`` plus the
        derived ``dc`` (step-to-step difference of ``c``, 0.0 for the first
        row) and ``action_label`` ("RUN"/"TUMBLE" for actions 0/1, otherwise
        the action rendered as a string). ``final_frame`` is the last
        non-``None`` ``ev.frame`` seen, or ``None`` if there was none.
        If the stream yields no events, ``df`` is empty but keeps every column.
    """
    columns = ["t", "action", "c", "reward", "terminated", "truncated", "x", "y"]
    rows = []
    final_frame = None
    for ev in r.stream(env, policy, seed=seed, render=render):
        # Position is optional: leave it missing when info does not provide it.
        x = y = None
        if isinstance(ev.info, dict) and "agent_xy" in ev.info:
            x, y = ev.info["agent_xy"]
        rows.append(
            {
                "t": int(ev.t),
                "action": int(ev.action),
                "c": float(ev.obs[0]),
                "reward": float(ev.reward),
                "terminated": bool(ev.terminated),
                "truncated": bool(ev.truncated),
                "x": x,
                "y": y,
            }
        )
        if ev.frame is not None:
            final_frame = ev.frame
        if ev.terminated or ev.truncated:
            break

    # Passing the column list keeps the schema intact for an empty stream;
    # without it, sorting an empty frame by "t" raises KeyError.
    df = pd.DataFrame(rows, columns=columns).sort_values("t").reset_index(drop=True)
    # astype(float) is a no-op for populated frames and gives the empty frame
    # a numeric dtype so diff() is well defined.
    df["dc"] = df["c"].astype(float).diff().fillna(0.0)
    df["action_label"] = (
        df["action"].map({0: "RUN", 1: "TUMBLE"}).fillna(df["action"].astype(str))
    )
    return df, final_frame


# Run the episode once, keeping the event table and the last rendered frame.
df, final_frame = stream_to_df(env, policy, render=True, seed=seed)
print(df.head(5))
print("Steps:", df.shape[0], "Total reward:", df["reward"].sum())

In [None]:
# Trajectory overlay with seaborn
# Prefer the last frame captured while streaming; otherwise ask the env for one.
# NOTE(review): render() is called with an explicit mode here - confirm the
# env's render signature accepts it.
if isinstance(final_frame, np.ndarray):
    frame = final_frame
else:
    frame = env.render("rgb_array")
grid_w, grid_h = grid_size
sx, sy = source_location

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(frame)
# Limits are set after imshow; the y limits are reversed (grid_h at the bottom).
ax.set_xlim(0, grid_w)
ax.set_ylim(grid_h, 0)

# Only steps with a recorded position can be drawn.
path = df.dropna(subset=["x", "y"])
if not path.empty:
    sns.lineplot(
        x="x",
        y="y",
        data=path.sort_values("t"),
        ax=ax,
        estimator=None,
        sort=False,  # preserve the temporal order of the trajectory
        linewidth=1,
        markersize=3,
        marker="o",
        color="yellow",
    )
    # First and last recorded positions.
    ax.scatter(
        [path["x"].iloc[0]], [path["y"].iloc[0]], marker="^", s=36, c="lime", label="start"
    )
    ax.scatter(
        [path["x"].iloc[-1]],
        [path["y"].iloc[-1]],
        s=30,
        c="magenta",
        label="end",
    )

# Hollow red square at the configured source location.
ax.scatter(
    [sx],
    [sy],
    label="source",
    linewidths=1.5,
    edgecolors="red",
    facecolors="none",
    s=60,
    marker="s",
)
ax.legend(loc="upper right")
ax.set_title("Run/Tumble TD with runner: final frame with trajectory")
display(fig)
plt.close(fig)

In [None]:
# Action counts and time series (seaborn)
run_count = int((df["action"] == 0).sum())
tumble_count = int((df["action"] == 1).sum())
print("RUN:", run_count, "TUMBLE:", tumble_count)

fig, ax = plt.subplots(1, 2, figsize=(10, 3))
# Pin the category order so RUN is always the first (green) bar and TUMBLE the
# second (orange) one. Without `order`, bar position and colour follow whichever
# action happens to appear first in the episode.
sns.countplot(
    x="action_label",
    data=df,
    ax=ax[0],
    order=["RUN", "TUMBLE"],
    palette=["#4caf50", "#ff9800"],
)
ax[0].set_title("Action counts")

# Concentration and its step-to-step difference on a shared time axis.
sns.lineplot(x="t", y="c", data=df, ax=ax[1], label="c")
sns.lineplot(x="t", y="dc", data=df, ax=ax[1], label="dc")
ax[1].set_title("Concentration and dC")
ax[1].legend()
plt.tight_layout()
display(fig)
plt.close(fig)

# Last cell of the notebook: release the environment.
env.close()