# TorchRL PPO Action Probing

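This notebook demonstrates how to set up a TorchRL PPO agent and use tdhook to probe its action representations. Linear-regression probes are fitted on the layer activations of the actor and critic networks, and each probe's R² score measures how well that layer linearly predicts the sampled action.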

## Setup

In [1]:
import importlib.util

DEV = True

if importlib.util.find_spec("google.colab") is not None:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"

In [2]:
if MODE == "colab":
    %pip install -q tdhook torchrl
elif MODE == "colab-dev":
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook torchrl

## Imports

In [3]:
import torch
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.envs import TransformedEnv, Compose, DoubleToFloat, StepCounter
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import MLP, ProbabilisticActor, NormalParamExtractor, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

from tdhook.latent.probing import Probing, ProbeManager

# Seed PyTorch's global RNG (all devices) so network initialisation and action
# sampling are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): only the GAE advantage module is constructed with this device;
# the actor, critic and collector are left on PyTorch's default (CPU) device.
# Confirm GAE copes with that mismatch on CUDA machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Create Environment

In [4]:
env_name = "InvertedDoublePendulum-v4"
base_env = GymEnv(env_name)
env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(),  # cast double-precision entries to float32 (the spec below reports float32)
        StepCounter(),    # adds a `step_count` entry to every step
    ),
)
# Seed the environment's reset RNG so the collected rollout is reproducible.
env.set_seed(0)

print(f"Observation space: {env.observation_spec}")
print(f"Action space: {env.action_spec}")

Observation space: Composite(
    observation: UnboundedContinuous(
        shape=torch.Size([11]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    step_count: BoundedDiscrete(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
        device=cpu,
        dtype=torch.int64,
        domain=discrete),
    device=None,
    shape=torch.Size([]),
    data_cls=None)
Action space: BoundedContinuous(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)

## Create Actor and Critic Networks

In [5]:
hidden_size = 32
# Depth of each MLP. Note that MLP's own `num_cells` argument expects the list
# of hidden-layer widths, which is built from these two values below.
num_hidden_layers = 6

obs_dim = env.observation_spec["observation"].shape[-1]
action_dim = env.action_spec.shape[-1]

# Actor trunk: the MLP emits 2 * action_dim values, which NormalParamExtractor
# splits into the `loc` and `scale` entries. The probing cell's regex targets
# this nesting (actor_network.module.0.module.0.<layer index>), so changing the
# structure here changes which layers get probed.
actor_module = TensorDictModule(
    nn.Sequential(
        MLP(in_features=obs_dim, out_features=2 * action_dim, num_cells=[hidden_size] * num_hidden_layers),
        NormalParamExtractor(),
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

# Stochastic policy: samples `action` from a TanhNormal parameterised by
# `loc`/`scale` and bounded by the action spec, and also records `action_log_prob`.
action_bounds = {"low": env.action_spec.space.low, "high": env.action_spec.space.high}
actor = ProbabilisticActor(
    module=actor_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs=action_bounds,
    return_log_prob=True,
)

# Critic: maps observations to a scalar `state_value` with the same trunk shape.
critic = TensorDictModule(
    MLP(in_features=obs_dim, out_features=1, num_cells=[hidden_size] * num_hidden_layers),
    in_keys=["observation"],
    out_keys=["state_value"],
)

print("Actor network created")
print("Critic network created")

Actor network created
Critic network created


## Create PPO Loss Module

In [6]:
# PPO / GAE hyperparameters, named so they are easy to find and tune.
GAMMA = 0.99          # discount factor
GAE_LAMBDA = 0.95     # GAE lambda
CLIP_EPSILON = 0.2    # PPO ratio clipping range
ENTROPY_COEFF = 1e-4  # weight of the entropy bonus
CRITIC_COEFF = 1.0    # weight of the critic loss

# Generalized Advantage Estimation: when applied to a batch it writes
# `advantage` and `value_target` using the critic's value predictions.
# NOTE(review): `device` may be CUDA while actor/critic/collector stay on CPU —
# confirm GAE handles that mismatch.
advantage_module = GAE(
    value_network=critic,
    gamma=GAMMA,
    lmbda=GAE_LAMBDA,
    average_gae=True,
    deactivate_vmap=True,
    device=device,
)

# Clipped PPO objective over the actor and critic; this is the module that the
# probing cell hooks into.
loss_module = ClipPPOLoss(
    actor_network=actor,
    critic_network=critic,
    clip_epsilon=CLIP_EPSILON,
    entropy_bonus=True,
    entropy_coeff=ENTROPY_COEFF,
    critic_coeff=CRITIC_COEFF,
    loss_critic_type="smooth_l1",
    functional=False,
)

print("PPO loss module created")

PPO loss module created


## Collect Sample Data

In [7]:
# Roll out the (untrained) actor in the environment for one batch of frames.
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=100,
    total_frames=100,  # equal to frames_per_batch: the collector yields exactly one batch
)

batch = next(iter(collector))
# The advantage module writes `advantage` and `value_target` into `batch`;
# its return value is not needed.
advantage_module(batch)

print(f"Batch shape: {batch.shape}")
# Materialise the keys view so the output is a plain, readable list.
print(f"Batch keys: {list(batch.keys(True, True))}")

Batch shape: torch.Size([100])
Batch keys: _TensorDictKeysView(['step_count', 'action', 'observation', 'done', 'terminated', 'truncated', ('next', 'observation'), ('next', 'step_count'), ('next', 'reward'), ('next', 'done'), ('next', 'terminated'), ('next', 'truncated'), ('next', 'state_value'), ('collector', 'traj_ids'), 'loc', 'scale', 'action_log_prob', 'state_value', 'advantage', 'value_target'],
    include_nested=True,
    leaves_only=True)


## Set Up Action Probing

In [8]:
# Randomly split the collected transitions 80/20 into probe-fitting (train) and
# evaluation (test) sets. A dedicated, seeded generator keeps the split
# reproducible independently of any other RNG use in the notebook.
# NOTE(review): consecutive steps of a rollout are correlated, so a random
# split may overstate held-out scores; a per-trajectory split would be stricter.
split_generator = torch.Generator().manual_seed(0)
indices = torch.randperm(batch.numel(), generator=split_generator)
split_idx = int(0.8 * batch.numel())
train_indices, test_indices = indices[:split_idx], indices[split_idx:]
train_batch = batch[train_indices]
test_batch = batch[test_indices]

# Probe factory: sklearn LinearRegression probes scored with R² against the labels.
probe_manager = ProbeManager(
    LinearRegression,
    {},  # presumably constructor kwargs for LinearRegression (defaults used) — confirm in tdhook
    lambda preds, labels: {"r2": r2_score(labels, preds)},
)

## Run Probing on Actor and Critic Layers

In [9]:
# Hook into actor and critic layers to probe action representations
with Probing(
    "td_module.(critic_network.module.\d+|actor_network.module.0.module.0.\d+)",
    probe_manager.probe_factory,
    additional_keys=["labels", "step_type"],
    relative=False,
).prepare(loss_module) as hooked_module:
    # Fit probes on training data
    train_batch["labels"] = train_batch["action"]
    train_batch["step_type"] = "fit"
    hooked_module(train_batch)

    # Evaluate probes on test data
    test_batch["labels"] = test_batch["action"]
    test_batch["step_type"] = "predict"
    hooked_module(test_batch)

## Display Results

In [10]:
# Print one R² score per hooked layer: first for the fitted (train) split, then
# for the held-out (test) split. Metrics are looked up lazily, section by section.
report_sections = (
    ("Training R² scores:", "fit_metrics"),
    ("\nTest R² scores:", "predict_metrics"),
)
for heading, metrics_attr in report_sections:
    print(heading)
    for layer_name, scores in getattr(probe_manager, metrics_attr).items():
        print(f"  {layer_name}: {scores['r2']:.3f}")

Training R² scores:
  td_module.actor_network.module.0.module.0.0_fwd: 0.057
  td_module.actor_network.module.0.module.0.1_fwd: 0.457
  td_module.actor_network.module.0.module.0.2_fwd: 0.457
  td_module.actor_network.module.0.module.0.3_fwd: 0.424
  td_module.actor_network.module.0.module.0.4_fwd: 0.424
  td_module.actor_network.module.0.module.0.5_fwd: 0.392
  td_module.actor_network.module.0.module.0.6_fwd: 0.391
  td_module.actor_network.module.0.module.0.7_fwd: 0.475
  td_module.actor_network.module.0.module.0.8_fwd: 0.429
  td_module.actor_network.module.0.module.0.9_fwd: 0.399
  td_module.actor_network.module.0.module.0.10_fwd: 0.376
  td_module.actor_network.module.0.module.0.11_fwd: 0.377
  td_module.actor_network.module.0.module.0.12_fwd: 0.003
  td_module.critic_network.module.0_fwd: 0.057
  td_module.critic_network.module.1_fwd: 0.380
  td_module.critic_network.module.2_fwd: 0.380
  td_module.critic_network.module.3_fwd: 0.315
  td_module.critic_network.module.4_fwd: 0.314
 

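**Note:** The actor and critic networks are untrained (randomly initialized weights), so these scores say little about learned action representations. The training R² values above are modest (roughly 0.0–0.5) but never negative: each probe is a least-squares linear regression evaluated on its own training data. Held-out (test) R² can be much lower and even negative. Some signal is expected even without training. The actions were sampled by this same actor, whose `loc`/`scale` outputs parameterize the `TanhNormal` action distribution, so intermediate actor activations are partially predictive of the action. The critic layers receive the same observations. After training the PPO agent, you would expect higher R² scores, indicating that the network layers have learned to represent action-relevant information.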