# **Tutorial 2: Decision Diffuser (DD) for D4RL-MuJoCo**

## 1. Introduction

In this tutorial, we’ll implement a minimal Decision Diffuser (DD) using CleanDiffuser. DD is a planning-based diffusion RL algorithm that uses classifier-free guidance (CFG) to generate high-performance decision trajectories. We’ll be using the D4RL-MuJoCo dataset for both training and evaluation. Along the way, we’ll dive into CFG and explore how to customize CFG models for specific tasks. Let’s start with an overview of how CFG works.

### 1.1 Classifier-free Guidance (CFG)

In a conditional generation task, we aim to sample from a conditional distribution $q_0(\bm x|\bm y)$. The score function can be written as:

$$
\nabla_{\bm x}\log q_t(\bm x_t|\bm y) = \nabla_{\bm x}\log q_t(\bm x_t) + \nabla_{\bm x}\log q_t(\bm y|\bm x_t),
$$

where the first term on the right is the score function of the unconditional distribution $q_t(\bm x_t)$, which we can estimate by training an unconditional diffusion model. The second term is the guidance term. Classifier guidance estimates it with a separately trained classifier; CFG avoids the classifier by rearranging the identity above:

$$
\nabla_{\bm x}\log q_t(\bm y|\bm x_t) = \nabla_{\bm x}\log q_t(\bm x_t|\bm y) - \nabla_{\bm x}\log q_t(\bm x_t).
$$

By training a conditional noise prediction model $\bm\epsilon_\theta(\bm x_t, t, \bm y)$, we can guide the sampling process without needing an additional classifier:

$$
\bar{\bm\epsilon}_\theta(\bm x_t, t, \bm y) = \bm\epsilon_\theta(\bm x_t, t) + w \cdot (\bm\epsilon_\theta(\bm x_t, t, \bm y) - \bm\epsilon_\theta(\bm x_t, t)),
$$

where $w$ represents the strength of the guidance. In practice, we use a dummy condition $\bm y = \bm\Phi$ for unconditional generation, meaning $\bm\epsilon_\theta(\bm x_t, t, \bm\Phi) = \bm\epsilon_\theta(\bm x_t, t)$.
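
To make the role of $w$ concrete, here is a minimal NumPy sketch of the guided noise combination. The function name `cfg_noise` and the toy vectors are illustrative assumptions, not part of CleanDiffuser; in practice both predictions come from the same network evaluated with and without the condition.

```python
import numpy as np


def cfg_noise(eps_cond: np.ndarray, eps_uncond: np.ndarray, w: float) -> np.ndarray:
    """Combine conditional and unconditional noise predictions with guidance strength w."""
    return eps_uncond + w * (eps_cond - eps_uncond)


# Toy predictions (hypothetical values, for illustration only)
eps_uncond = np.array([0.0, 1.0])
eps_cond = np.array([1.0, 3.0])

no_guidance = cfg_noise(eps_cond, eps_uncond, w=0.0)        # purely unconditional
plain_conditional = cfg_noise(eps_cond, eps_uncond, w=1.0)  # purely conditional
extrapolated = cfg_noise(eps_cond, eps_uncond, w=2.0)       # pushed past the conditional prediction
```

With $w = 0$ the guidance vanishes, with $w = 1$ we recover the conditional prediction, and larger values extrapolate further in the direction of the condition.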

In decision-making tasks, the condition $\bm y$ can represent highly complex, multimodal data like image-based observations, language instructions, or point clouds. Some implementations even use large transformers for multimodal fusion while utilizing smaller MLPs as the diffusion neural network backbone. In CleanDiffuser, we’ve decoupled the neural networks for diffusion $\bm\epsilon_\theta$ and conditions $\bm\zeta_\phi$ to facilitate development and debugging. Conditional diffusion models in CleanDiffuser are implemented as $\bm\epsilon_\theta(\bm x_t, t, \bm\zeta_\phi(\bm y))$, with a dummy condition $\bm\zeta_\phi(\bm\Phi) = \bm 0$. This is why in Tutorial 1, we used both a `NNDiffusion` and a `NNCondition` to build the diffusion model. The `NNDiffusion` corresponds to $\bm\epsilon_\theta$, and the `NNCondition` corresponds to $\bm\zeta_\phi$. If the condition is simple, we can use an `IdentityCondition` to pass it directly to `NNDiffusion`.
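
The zero dummy condition can be illustrated with a small NumPy sketch. This is not CleanDiffuser's implementation: `embed_condition`, `weight`, and `drop_prob` are hypothetical names, and the linear map merely stands in for $\bm\zeta_\phi$. The idea is that, during training, some condition embeddings are replaced by the zero vector so that the same network also learns the unconditional prediction.

```python
import numpy as np


def embed_condition(y: np.ndarray, weight: np.ndarray, drop_prob: float, rng: np.random.Generator):
    """Toy stand-in for zeta_phi: embed y, then randomly swap in the zero dummy embedding."""
    emb = y @ weight                              # (batch, emb_dim)
    keep = rng.random(y.shape[0]) >= drop_prob    # True where the condition is kept
    return emb * keep[:, None], keep


rng = np.random.default_rng(0)
y = np.ones((8, 1))              # scalar condition per sample
weight = np.full((1, 4), 2.0)    # hypothetical embedding weights

emb, keep = embed_condition(y, weight, drop_prob=0.25, rng=rng)
all_dropped, _ = embed_condition(y, weight, drop_prob=1.0, rng=rng)
none_dropped, _ = embed_condition(y, weight, drop_prob=0.0, rng=rng)
```

Rows where `keep` is `False` carry the dummy embedding $\bm 0$, which is what the network sees when it is asked for an unconditional prediction at sampling time.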

### 1.2 Diffusion Planners

DD is a diffusion planner that uses CFG to generate high-performance decision trajectories. The core idea is to generate high-quality decision trajectories and extract the first action to execute—much like MPC (Model Predictive Control) and other planning-based model-based RL algorithms. While those approaches rely on search methods and dynamic models to find optimal trajectories, diffusion planners achieve this through conditional generation.

To guide this generation process, we use a “high-performance” variable as the condition. One straightforward approach is to use the discounted return-to-go of the trajectory, $\sum_{s=t}^T \gamma^{s-t} r_s$, as the condition. This is a Monte Carlo estimation of the trajectory’s value. During training, we normalize these values to the range [0, 1], where 1 represents the highest performance. During inference, we use relatively high normalized values (e.g., 0.8-1.0) as conditions to generate high-performance trajectories. For more details, you can refer to [Diffuser](https://arxiv.org/abs/2205.09991) and [Decision Diffuser](https://arxiv.org/abs/2211.15657).
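
As a small worked example, the sketch below computes the discounted return-to-go for a toy reward sequence and rescales it to [0, 1]. The helper names and the min-max rescaling are assumptions made for illustration; a dataset-specific constant can be used for normalization instead.

```python
import numpy as np


def discounted_returns(rewards: np.ndarray, discount: float) -> np.ndarray:
    """Return-to-go at every step: R_t = r_t + discount * R_{t+1}."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + discount * running
        returns[t] = running
    return returns


def minmax_normalize(values: np.ndarray) -> np.ndarray:
    """Rescale values to [0, 1]; the small epsilon avoids division by zero."""
    lo, hi = values.min(), values.max()
    return (values - lo) / (hi - lo + 1e-8)


rewards = np.array([1.0, 1.0, 1.0])
rtg = discounted_returns(rewards, discount=0.5)   # [1.75, 1.5, 1.0]
normed = minmax_normalize(rtg)
```

The first step has the largest return-to-go because it still collects all future rewards, so it maps to a normalized value close to 1.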

## 2. Setting up the Environment and Dataset

For this tutorial, we’ll use the `halfcheetah-medium-v2` task from D4RL-MuJoCo as an example. D4RL-MuJoCo is a widely used offline RL benchmark, and this task requires controlling a HalfCheetah robot to move forward as fast as possible. CleanDiffuser provides a simple interface to load D4RL datasets. In this notebook, however, we load the Minari dataset `mujoco/halfcheetah/medium-v0`, recover its environment, and wrap the episodes in a small sequence dataset of our own.

In [None]:
import minari

# download=True fetches the dataset if it is not available locally
dataset_minari = minari.load_dataset('mujoco/halfcheetah/medium-v0', download=True)
env = dataset_minari.recover_environment()


In [12]:
import torch
from torch.utils.data import Dataset
import numpy as np
from typing import Dict, List, Tuple, Optional
from minari import MinariDataset


class GaussianNormalizer:
    def __init__(self, data: np.ndarray):
        self.mean = data.mean(axis=0)
        self.std = data.std(axis=0) + 1e-8

    def normalize(self, x: np.ndarray) -> np.ndarray:
        return (x - self.mean) / self.std

    def unnormalize(self, x: np.ndarray) -> np.ndarray:
        return x * self.std + self.mean


class MinariSequenceDataset(Dataset):
    """
    Converts Minari episodic dataset into fixed-length sequential samples for training.

    Args:
        minari_dataset (MinariDataset): Minari dataset instance.
        terminal_penalty (float): Penalty applied for terminal states (if not timeout).
        horizon (int): Length of each sequence.
        discount (float): Discount factor for computing Monte Carlo returns.
    """
    def __init__(
        self,
        minari_dataset: MinariDataset,
        terminal_penalty: float = -100.,
        horizon: int = 32,
        discount: float = 0.99,
    ):
        self.horizon = horizon
        self.discount = discount
        self.terminal_penalty = terminal_penalty

        # Step 1: Normalize all observations
        all_obs = np.concatenate([
            ep.observations for ep in minari_dataset.iterate_episodes()
        ], axis=0)
        self.normalizer = GaussianNormalizer(all_obs)

        self.sequences: List[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]] = []

        # Step 2: Iterate episodes and generate fixed-length sequence samples
        for ep in minari_dataset.iterate_episodes():
            obs = np.asarray(ep.observations, dtype=np.float32)
            act = np.asarray(ep.actions, dtype=np.float32)
            rew = np.asarray(ep.rewards, dtype=np.float32)

            # Ensure they are the same length
            min_len = min(len(obs), len(act), len(rew))
            if min_len < horizon:
                continue

            obs = obs[:min_len]
            act = act[:min_len]
            rew = rew[:min_len]
            terms = np.asarray(ep.terminations[:min_len], dtype=bool)
            truncs = np.asarray(ep.truncations[:min_len], dtype=bool)

            obs = self.normalizer.normalize(obs)

            # apply terminal penalty
            if terms[-1] and not truncs[-1]:
                rew[-1] = terminal_penalty

            # Compute returns
            returns = np.zeros_like(rew)
            returns[-1] = rew[-1]
            for t in reversed(range(len(rew) - 1)):
                returns[t] = rew[t] + discount * returns[t + 1]

            for start in range(0, min_len - horizon + 1):
                end = start + horizon
                self.sequences.append((
                    obs[start:end],
                    act[start:end],
                    rew[start:end][:, None],
                    returns[start][None]
                ))


    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx: int):
        obs_seq, act_seq, rew_seq, val = self.sequences[idx]
        return {
            'obs': {'state': torch.tensor(obs_seq, dtype=torch.float32)},
            'act': torch.tensor(act_seq, dtype=torch.float32),
            'rew': torch.tensor(rew_seq, dtype=torch.float32),
            'val': torch.tensor(val, dtype=torch.float32),
        }

    def get_normalizer(self):
        return self.normalizer


In [13]:
# from cleandiffuser.dataset.d4rl_mujoco_dataset import D4RLMuJoCoDataset

# horizon=4 is enough for halfcheetah tasks as suggested in Diffuser paper.
horizon = 4
dataset = MinariSequenceDataset(dataset_minari, terminal_penalty=-100, horizon=horizon)
obs_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]
print(f"Observation Dimension: {obs_dim}, Action Dimension: {act_dim}")

Observation Dimension: 17, Action Dimension: 6


## 3. Building the Diffusion Model

Unlike in Tutorial 1, here we need the diffusion model to generate decision trajectories, which look like this:

$$
\bm\tau = \left[
\begin{aligned}
&\bm s_0, \bm s_1, \dots, \bm s_{H-1} \\
&\bm a_0, \bm a_1, \dots, \bm a_{H-1}
\end{aligned}
\right],
$$

where $\bm s_t$ is the state at time $t$, $\bm a_t$ is the action at time $t$, and $H$ is the horizon. The trajectory $\bm\tau$ has the shape (H, obs_dim + act_dim). For this, we need a neural network backbone designed to generate sequences. We’ll use `DiT1d`, a modified version of DiT for 1D sequences. `DiT1d` expects the condition as a tensor of shape (batch_size, embed_dim), so we use an MLP `NNCondition` to map the scalar condition (the trajectory’s value) to a tensor of the required shape.
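
The layout of $\bm\tau$ can be checked with a few lines of NumPy. This is only an illustration with placeholder arrays; the dimensions match the HalfCheetah task used here (17 observation and 6 action dimensions) with a horizon of 4.

```python
import numpy as np

H, obs_dim, act_dim = 4, 17, 6

# Placeholder data: zeros for states, ones for actions, so the two parts are easy to tell apart
states = np.zeros((H, obs_dim), dtype=np.float32)
actions = np.ones((H, act_dim), dtype=np.float32)

# Each row holds one time step: [state, action]
tau = np.concatenate([states, actions], axis=-1)   # shape (H, obs_dim + act_dim)

first_state = tau[0, :obs_dim]    # the current observation
first_action = tau[0, obs_dim:]   # the action that will be executed
```

The same slicing, `[:, 0, obs_dim:]` on a batched trajectory, is what extracts the action to execute at evaluation time.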

> **Note:** The official DD generates only state trajectories and uses an inverse dynamics model $\bm a_t = \mathcal{I}(\bm s_t, \bm s_{t+1})$ to extract the action. In this tutorial, we skip the inverse dynamics model and directly use state-action trajectories.


In [14]:
import torch

from cleandiffuser.diffusion import ContinuousDiffusionSDE
from cleandiffuser.nn_condition import MLPCondition
from cleandiffuser.nn_diffusion import DiT1d

# Neural network backbones
nn_diffusion = DiT1d(
    x_dim=obs_dim + act_dim, emb_dim=128, d_model=320, n_heads=10, depth=2, timestep_emb_type="untrainable_fourier",
    x_seq_len=horizon
)
nn_condition = MLPCondition(in_dim=1, out_dim=128, hidden_dims=128, dropout=0.25)

# Mask
fix_mask = torch.zeros((horizon, obs_dim + act_dim))
fix_mask[0, :obs_dim] = 1.0
loss_weight = torch.ones((horizon, obs_dim + act_dim))
loss_weight[0, obs_dim:] = 10.0

planner = ContinuousDiffusionSDE(
    nn_diffusion,
    nn_condition,
    fix_mask=fix_mask,
    loss_weight=loss_weight,
)

We also define a `fix_mask` and `loss_weight`. The `fix_mask` is a binary tensor of the same shape as the generated data, where 1 indicates fixed data (known, such as the current observation), and 0 indicates data to be generated. The `loss_weight` assigns different weights to different parts of the data, with more weight on the first action (as this is the action we care most about).
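
To see what these two tensors do, here is a minimal NumPy sketch. It is an assumption about how a sampler and a loss might consume them, not CleanDiffuser's actual internals; `apply_fix_mask` and `weighted_mse` are hypothetical helpers.

```python
import numpy as np

H, obs_dim, act_dim = 4, 17, 6

fix_mask = np.zeros((H, obs_dim + act_dim))
fix_mask[0, :obs_dim] = 1.0            # the current observation is known
loss_weight = np.ones((H, obs_dim + act_dim))
loss_weight[0, obs_dim:] = 10.0        # emphasize the first action


def apply_fix_mask(x: np.ndarray, prior: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Overwrite the fixed entries of x with the known values stored in prior."""
    return x * (1.0 - mask) + prior * mask


def weighted_mse(pred: np.ndarray, target: np.ndarray, weight: np.ndarray, mask: np.ndarray) -> float:
    """Squared error, reweighted per entry and ignoring the fixed entries."""
    return float(((pred - target) ** 2 * weight * (1.0 - mask)).mean())


rng = np.random.default_rng(0)
x = rng.normal(size=(H, obs_dim + act_dim))   # a noisy sample
prior = np.zeros_like(x)
prior[0, :obs_dim] = 0.3                      # hypothetical current observation

x_fixed = apply_fix_mask(x, prior, fix_mask)
loss = weighted_mse(np.ones_like(x), np.zeros_like(x), loss_weight, fix_mask)
```

With a unit error everywhere, the first action contributes 6 × 10 = 60 to the sum, the remaining free entries contribute 69, and the 17 fixed entries contribute nothing, so the mean over all 92 entries is 129 / 92.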

## 4. Training the Diffusion Model

In [None]:
import numpy as np
import pytorch_lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint


class StateActionSequenceWrapper(torch.utils.data.Dataset):
    def __init__(self, dataset: torch.utils.data.Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

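    def __getattr__(self, name):
        # Avoid infinite recursion if `dataset` is not set yet (e.g. while unpickling in a worker)
        if name == "dataset":
            raise AttributeError(name)
        return getattr(self.dataset, name)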

    def __getitem__(self, idx):
        batch = self.dataset[idx]

        obs = batch["obs"]["state"]  # shape: (horizon, obs_dim)
        act = batch["act"]           # shape: (horizon, act_dim)
        val = batch["val"] / 580.0   # normalize scalar return

        return {
            "x0": torch.cat([obs, act], dim=-1),
            "condition_cfg": val,
        }


save_path = "results/tutorial2_dd_for_d4rl_mujoco/"

dataloader = torch.utils.data.DataLoader(
    StateActionSequenceWrapper(dataset), batch_size=512, shuffle=True, num_workers=4, persistent_workers=True
)

callback = ModelCheckpoint(dirpath=save_path, filename="dd-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_steps=500_000,
    deterministic=True,
    log_every_n_steps=200,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(planner, dataloader)

## 5. Evaluation

Inference with diffusion planners differs slightly from diffusion policies. First, we replace the first state in the prior with the current observation. Then, we generate a trajectory conditioned on a high-performance score. Finally, we extract and execute the first action. Here, we evaluate the trained model in the `halfcheetah-medium-v2` environment. We set the target score condition to `0.95` and the guidance strength to `15.0`.
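
The three steps above can be sketched as a receding-horizon loop. Everything in this block is a toy stand-in: `dummy_planner` replaces the diffusion sampler and `dummy_env_step` replaces the simulator, both hypothetical, so that the control flow can be run without a trained model.

```python
import numpy as np

obs_dim, act_dim, horizon = 3, 2, 4


def dummy_planner(prior: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the sampler: keep the known first state, fill the rest with 0.5."""
    traj = np.full_like(prior, 0.5)
    traj[:, 0, :obs_dim] = prior[:, 0, :obs_dim]
    return traj


def dummy_env_step(obs: np.ndarray, act: np.ndarray):
    """Hypothetical toy dynamics: the state drifts by the summed action, which is also the reward."""
    return obs + act.sum(), float(act.sum())


obs = np.zeros(obs_dim)
prior = np.zeros((1, horizon, obs_dim + act_dim))
total_reward = 0.0

for _ in range(3):
    prior[:, 0, :obs_dim] = obs                   # 1. write the current observation into the prior
    traj = dummy_planner(prior)                   # 2. generate a trajectory (conditioning omitted here)
    act = traj[0, 0, obs_dim:].clip(-1.0, 1.0)    # 3. extract the first action
    obs, reward = dummy_env_step(obs, act)        #    and execute it
    total_reward += reward
```

Only the first action of each plan is executed; the rest of the generated trajectory is discarded and a new plan is generated from the next observation.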

In [16]:
from gymnasium.vector import SyncVectorEnv  
from gymnasium.wrappers.vector import NormalizeReward  

n_seeds = 3

# device for evaluation
device = "cuda:0"

# loading from checkpoint
planner.load_state_dict(
    torch.load("results/tutorial2_dd_for_d4rl_mujoco/dd-step=500000.ckpt", map_location=device)["state_dict"]
)
planner.to(device).eval()

# evaluating
# env_eval = gym.vector.make("halfcheetah-medium-v2", num_envs=50)
env_eval = SyncVectorEnv([  
    lambda: dataset_minari.recover_environment(eval_env=True)   
    for _ in range(50)  
])
# env_eval = NormalizeReward(env_eval)   

dataset = MinariSequenceDataset(dataset_minari, terminal_penalty=-100, horizon=horizon)
normalizer = dataset.get_normalizer()
condition = torch.full((50, 1), 0.95, device=device)
prior = torch.zeros((50, horizon, obs_dim + act_dim), device=device)
scores = []
for _ in range(n_seeds):
    
    reset_output = env_eval.reset()
    obs = reset_output[0] if isinstance(reset_output, tuple) else reset_output
    all_done, ep_rew, t = False, 0.0, 0

    while not np.all(all_done):
        obs_normalized = normalizer.normalize(obs.astype(np.float32))
        obs_tensor = torch.tensor(obs_normalized, device=device, dtype=torch.float32)
        prior[:, 0, :obs_dim] = obs_tensor

        traj, log = planner.sample(
            prior,
            solver="ddpm",
            sample_steps=5,
            sampling_schedule="uniform_logsnr",
            condition_cfg=condition,
            w_cfg=15,
            use_ema=True,
            temperature=0.5,
        )
        act = traj[:, 0, obs_dim:].clip(-1.0, 1.0).cpu().numpy()

        step_result = env_eval.step(act)
        if len(step_result) == 5:  # newer gymnasium format: obs, rew, terminated, truncated, info
            obs, rew, terminated, truncated, info = step_result
            done = np.logical_or(terminated, truncated)
        else:  # older format: obs, rew, done, info
            obs, rew, done, info = step_result

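        # Only accumulate reward for environments whose episode has not finished yet
        ep_rew += rew * np.logical_not(all_done)
        t += 1
        all_done = np.logical_or(all_done, done)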
  
        print(f"[t={t}] rew: {rew}")  
    
    def normalize_d4rl_score(raw_score, env_name="halfcheetah"):  
        # These are approximate D4RL normalization constants  
        if "halfcheetah" in env_name.lower():  
            min_score = -280.178953  
            max_score = 12135.0  
        else:  
            # Add other environment constants as needed  
            min_score = 0  
            max_score = 1000  
        
        return 100 * (raw_score - min_score) / (max_score - min_score)  
  
    # In your evaluation loop:  
    raw_scores = ep_rew.mean()  
    normalized_score = normalize_d4rl_score(raw_scores, "halfcheetah")  
    scores.append(normalized_score)
print(f"D4RL score: {np.mean(scores)}+-{np.std(scores)}")



[t=1] rew: [-0.21253028 -0.39852241 -0.36648301 -0.5456739  -0.24516123 -0.33492638
 -0.44609571 -0.53283467  0.67221767 -0.29955345  0.51035066 -0.50173723
 -0.38128009  0.36581636  0.17489653 -0.40667822  0.04608589 -0.10764742
  0.41612897 -0.02651359 -0.19887308 -0.00944009  0.13406474  0.03207815
  0.40028082  0.22870271 -0.37958524  0.55479671 -0.58409227 -0.32595426
 -0.26836398  0.56897551 -0.50448955  0.18582013  0.3651735  -0.25770912
 -0.08680328  0.04176419 -0.1584922   0.48905725 -0.34877883  0.03569675
 -0.13712698 -0.40692688  0.12922699 -0.11859857  0.31739761 -0.55702289
 -0.09784548 -0.06052509]
[t=2] rew: [-0.04729451  0.33256014  0.1443133   0.15069743  0.02377837  0.09728272
  0.11591884 -0.18349802  0.83766016  0.1534757   0.49214921  0.12471432
 -0.01653687  0.94401701  0.51809849 -0.49977672 -0.1861461  -0.37761195
 -0.15294804  0.30868511  0.14922189 -0.43154054  0.52332042  0.5049627
  0.83428606  0.42446247  0.29823877  0.59937904 -0.88695318  0.0386863
  0.3

The results are promising! Despite not using the inverse dynamics model like the official DD, the performance remains competitive compared to other popular offline RL algorithms (see the table below).

||BC|CQL|IQL|DT|TT|Diffuser|DD (Official)|DD (Tutorial 2)|
|---|--|--|---|--|--|--------|-------------|--------------|
|HalfCheetah-Medium-v2|42.6|44.0|47.4|42.6|46.9|44.2|49.1+-1.0|48.0+-0.3|

In [18]:
# evaluating - single environment instance
env_eval = dataset_minari.recover_environment(eval_env=True, render_mode="rgb_array")

dataset = MinariSequenceDataset(dataset_minari, terminal_penalty=-100, horizon=horizon)
normalizer = dataset.get_normalizer()
condition = torch.full((1, 1), 0.95, device=device)  # batch of one environment
prior = torch.zeros((1, horizon, obs_dim + act_dim), device=device)
scores = []
episode_frames = []
video_length = 24 * 10  # maximum number of frames to record (10 seconds at 24 FPS)

for _ in range(n_seeds):
    reset_output = env_eval.reset()
    obs = reset_output[0] if isinstance(reset_output, tuple) else reset_output
    done, ep_rew, t = False, 0.0, 0  # Changed from all_done to done

    while not done:  # Changed from np.all(all_done) to not done
        obs_normalized = normalizer.normalize(obs.astype(np.float32))
        obs_tensor = torch.tensor(obs_normalized, device=device, dtype=torch.float32).unsqueeze(
            0
        )  # Add batch dimension
        prior[:, 0, :obs_dim] = obs_tensor

        traj, log = planner.sample(
            prior,
            solver="ddpm",
            sample_steps=5,
            sampling_schedule="uniform_logsnr",
            condition_cfg=condition,
            w_cfg=15,
            use_ema=True,
            temperature=0.5,
        )
        act = (
            traj[0, 0, obs_dim:].clip(-1.0, 1.0).cpu().numpy()
        )  # Remove batch dimension from action

        step_result = env_eval.step(act)
        if len(step_result) == 5:  # newer gymnasium format: obs, rew, terminated, truncated, info
            obs, rew, terminated, truncated, info = step_result
            done = terminated or truncated  # Changed from np.logical_or to simple or
        else:  # older format: obs, rew, done, info
            obs, rew, done, info = step_result

        ep_rew += rew
        t += 1

        print(f"[t={t}] rew: {rew}")
        frame = env_eval.render()  # Direct render call instead of env_eval.call("render")
        print(f"Frame shape: {frame.shape if frame is not None else 'No frame'}")
        if frame is not None:
            episode_frames.append(frame)  # Append single frame instead of extending
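        if len(episode_frames) >= video_length:
            # Enough frames collected: stop early. The episode is cut short,
            # so the score computed below does not reflect a full episode.
            break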

    def normalize_d4rl_score(raw_score, env_name="halfcheetah"):
        # These are approximate D4RL normalization constants
        if "halfcheetah" in env_name.lower():
            min_score = -280.178953
            max_score = 12135.0
        else:
            # Add other environment constants as needed
            min_score = 0
            max_score = 1000

        return 100 * (raw_score - min_score) / (max_score - min_score)

    # In your evaluation loop:
    normalized_score = normalize_d4rl_score(ep_rew, "halfcheetah")  # No .mean() needed
    scores.append(normalized_score)

env_eval.close()  # Close the single environment
print(f"D4RL score: {np.mean(scores)}+-{np.std(scores)}")




[t=1] rew: 0.07732871503440616
Frame shape: (480, 480, 3)
[t=2] rew: -0.5504078824765595
Frame shape: (480, 480, 3)
[t=3] rew: 0.08109215332702241
Frame shape: (480, 480, 3)
[t=4] rew: 1.1102365511078
Frame shape: (480, 480, 3)
[t=5] rew: 1.3096589142577484
Frame shape: (480, 480, 3)
[t=6] rew: 0.7485695732617019
Frame shape: (480, 480, 3)
[t=7] rew: 0.1451863631058078
Frame shape: (480, 480, 3)
[t=8] rew: 0.7794615877737493
Frame shape: (480, 480, 3)
[t=9] rew: 1.9119315813990287
Frame shape: (480, 480, 3)
[t=10] rew: 2.484796453831925
Frame shape: (480, 480, 3)
[t=11] rew: 2.3343213748498925
Frame shape: (480, 480, 3)
[t=12] rew: 1.0452810017739842
Frame shape: (480, 480, 3)
[t=13] rew: 0.8320379384173088
Frame shape: (480, 480, 3)
[t=14] rew: 2.352762653658005
Frame shape: (480, 480, 3)
[t=15] rew: 3.229851586749012
Frame shape: (480, 480, 3)
[t=16] rew: 3.4058665621481117
Frame shape: (480, 480, 3)
[t=17] rew: 3.6142248653848057
Frame shape: (480, 480, 3)
[t=18] rew: 2.558146921335

In [20]:
import imageio
import numpy as np

# Enhanced frame validation with shape correction
valid_frames = []
for i, frame in enumerate(episode_frames):
    if frame is not None and isinstance(frame, np.ndarray) and frame.size > 0:
        # Handle malformed 2D frames
        if len(frame.shape) == 2:
            if frame.shape == (1280, 3):
                # This appears to be a flattened or incorrectly shaped frame
                # Skip this frame or create a placeholder
                continue
            elif frame.shape[1] == 3:
                # Try to interpret as (height*width, 3) and reshape
                total_pixels = frame.shape[0]
                # Assume square-ish aspect ratio for reshaping
                height = int(np.sqrt(total_pixels))
                width = total_pixels // height
                try:
                    frame = frame.reshape(height, width, 3)
                except ValueError:
                    print(f"Could not reshape frame {i}")
                    continue

        # Ensure proper 3D format
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            # Ensure uint8 format
            if frame.dtype != np.uint8:
                if frame.max() <= 1.0:
                    frame = (frame * 255).astype(np.uint8)
                else:
                    frame = frame.astype(np.uint8)

            valid_frames.append(frame)
        else:
            print(f"Frame {i} still has invalid shape: {frame.shape}")

# Save video with fallback options
if valid_frames:
    try:
        writer = imageio.get_writer("results/dd_plan_visualization.mp4", fps=30, codec="libx264")
        for frame in valid_frames:
            writer.append_data(frame)
        writer.close()
        print(f"Video saved with {len(valid_frames)} valid frames")
    except Exception as e:
        print(f"MP4 failed: {e}")
        # Fallback to GIF
        try:
            imageio.mimsave("results/dd_plan_visualization.gif", valid_frames, fps=10)
            print(f"GIF saved with {len(valid_frames)} valid frames")
        except Exception as e2:
            print(f"Both MP4 and GIF failed: {e2}")
else:
    print("No valid frames to save - all frames were malformed")


MP4 failed: could not broadcast input array from shape (480,3) into shape (480,3,3)
GIF saved with 242 valid frames
