# **Tutorial 3. DQL (Diffusion Q-Learning) for D4RL-MuJoCo**

## 1. Introduction

In the previous two tutorials, we explored supervised learning tasks where we provided input data to the model, allowing it to start training immediately. This tutorial, however, introduces a more complex reinforcement learning task. In Diffusion Q-Learning (DQL), we alternate between updating a diffusion model and a Q-function, which makes it more challenging to handle compared to standard supervised methods. DQL is an effective diffusion-based actor-critic algorithm. Similar to TD3BC, DQL’s policy update involves two loss components: diffusion loss and Q-maximization loss. The diffusion loss serves as a behavior cloning loss:

$$
\mathcal L_{\text{policy}}(\theta) = \mathcal L_{\text{diffusion}}(\theta) - \eta \cdot Q_\phi(\bm s, \pi_\theta(\bm a | \bm s)). \tag{1}
$$

The critic is a Q-function trained via temporal difference (TD) learning:

$$
\mathcal L_{\text{critic}}(\phi) = \left[Q_\phi(\bm s, \bm a) - (r + \gamma \cdot Q_{\phi^-}(\bm s', \pi_\theta(\bm a' | \bm s')))\right]^2. \tag{2}
$$

As with any actor-critic algorithm, the diffusion model’s training influences the Q-function, and in turn, the Q-function impacts the training of the diffusion model.

## 2. Setting up the Environment and Dataset

For this tutorial, we will use the `halfcheetah-medium-v2` environment from D4RL-MuJoCo, a popular offline RL benchmark. The `halfcheetah-medium-v2` task involves controlling a halfcheetah robot to move forward as fast as possible. CleanDiffuser provides a simple interface to load the D4RL datasets.


In [1]:
import d4rl
import gym

from cleandiffuser.dataset.d4rl_mujoco_dataset import D4RLMuJoCoTDDataset

env = gym.make("halfcheetah-medium-v2")
dataset = D4RLMuJoCoTDDataset(d4rl.qlearning_dataset(env), normalize_reward=True)
obs_dim, act_dim = dataset.obs_dim, dataset.act_dim

No module named 'flow'
/home/dzb/miniforge3/envs/cleandiffuser/lib/python3.9/site-packages/glfw/__init__.py:914: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'
No module named 'carla'
pybullet build time: Nov 28 2023 23:52:03
  from .autonotebook import tqdm as notebook_tqdm
  logger.warn(
load datafile: 100%|██████████| 21/21 [00:03<00:00,  5.31it/s]


## 3. Building the Diffusion Model

In DQL, the diffusion model functions as the policy network. Like in Tutorial 1, this diffusion model must be designed to generate actions based on the input states.

In [2]:
import torch
import torch.nn as nn

from cleandiffuser.diffusion import DiscreteDiffusionSDE
from cleandiffuser.nn_diffusion import IDQLMlp
from cleandiffuser.nn_condition import MLPCondition

device = "cuda:1"

nn_diffusion = IDQLMlp(x_dim=act_dim, emb_dim=64, timestep_emb_type="positional")

# The label dropout rate is set to 0.0 as we do not use Classifier-free Guidance.
nn_condition = MLPCondition(in_dim=obs_dim, out_dim=64, hidden_dims=64, dropout=0.0)

actor = DiscreteDiffusionSDE(
    nn_diffusion,
    nn_condition,
    diffusion_steps=5,
    x_max=torch.full((act_dim,), fill_value=1.0),
    x_min=torch.full((act_dim,), fill_value=-1.0),
).to(device)

Double Q functions with LayerNorm and Mish activation.

In [3]:
from copy import deepcopy


class TwinQ(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim: int = 256):
        super().__init__()
        self.Q1 = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        )
        self.Q2 = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        )

    def both(self, obs, act):
        q1, q2 = self.Q1(torch.cat([obs, act], -1)), self.Q2(torch.cat([obs, act], -1))
        return q1, q2

    def forward(self, obs, act):
        return torch.min(*self.both(obs, act))


critic = TwinQ(obs_dim, act_dim, hidden_dim=256).to(device)
critic_target = deepcopy(critic).requires_grad_(False).eval().to(device)

## 4. Training the Diffusion Model in Manual-update Style

To implement the complex update steps required for DQL, we can use `update_diffusion` for manual updates. This allows for more flexible logic in the update process. By using this method, there is no need to add wrappers to the dataset to match the format required by PyTorch Lightning.

In [7]:
from cleandiffuser.utils import loop_dataloader, FreezeModules
import numpy as np
import os

save_path = "../results/tutorial3_dql_for_d4rl_mujoco/"
if not os.path.exists(save_path):
    os.makedirs(save_path)

batch_size = 2048
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True, drop_last=True
)
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4)
actor.configure_manual_optimizers()

step = 0
log = dict.fromkeys(["bc_loss", "policy_q_loss", "critic_td_loss", "target_q"], 0.0)
prior = torch.zeros((batch_size, act_dim), device=device)

for batch in loop_dataloader(dataloader):
    obs, next_obs = batch["obs"]["state"].to(device), batch["next_obs"]["state"].to(device)
    act = batch["act"].to(device)
    rew = batch["rew"].to(device)
    tml = batch["tml"].to(device)

    # --- Critic Update ---
    actor.eval()
    critic.train()

    q1, q2 = critic.both(obs, act)

    next_act, _ = actor.sample(
        prior,
        solver="ddpm",
        n_samples=batch_size,
        sample_steps=5,
        condition_cfg=next_obs,
        w_cfg=1.0,
        requires_grad=False,
    )

    with torch.no_grad():
        target_q = torch.min(*critic_target.both(next_obs, next_act))

    target_q = (rew + (1 - tml) * 0.99 * target_q).detach()

    critic_td_loss = (q1 - target_q).pow(2).mean() + (q2 - target_q).pow(2).mean()

    critic_optim.zero_grad()
    critic_td_loss.backward()
    critic_optim.step()

    log["critic_td_loss"] += critic_td_loss.item()
    log["target_q"] += target_q.mean().item()

    # --- Actor Update ---
    actor.train()
    critic.eval()

    bc_loss = actor.loss(act, obs)

    new_act, _ = actor.sample(
        prior,
        solver="ddpm",
        n_samples=batch_size,
        sample_steps=5,
        condition_cfg=obs,
        w_cfg=1.0,
        use_ema=False,
        requires_grad=True,
    )

    with FreezeModules([critic]):
        q1_actor, q2_actor = critic.both(obs, new_act)

    if np.random.uniform() > 0.5:
        policy_q_loss = -q1_actor.mean() / q2_actor.abs().mean().detach()
    else:
        policy_q_loss = -q2_actor.mean() / q1_actor.abs().mean().detach()

    # eta=1.0 for halfcheetah-medium-v2
    actor_loss = bc_loss + policy_q_loss * 1.0

    actor.manual_optimizers["diffusion"].zero_grad()
    actor_loss.backward()
    actor.manual_optimizers["diffusion"].step()

    log["bc_loss"] += bc_loss.item()
    log["policy_q_loss"] += policy_q_loss.item()

    step += 1

    # --- EMA Update ---
    if step % 5 == 0:
        if step >= 1000:
            actor.ema_update()
        for param, target_param in zip(critic.parameters(), critic_target.parameters()):
            target_param.data.copy_(0.995 * param.data + (1 - 0.995) * target_param.data)

    if step % 1000 == 0:
        log = {k: v / 1000 for k, v in log.items()}
        print(f"[{step}] {log}")
        log = dict.fromkeys(["bc_loss", "policy_q_loss", "critic_td_loss", "target_q"], 0.0)

    if step % 50_000 == 0:
        actor.save(save_path + f"actor_step={step}.ckpt")
        torch.save(critic.state_dict(), save_path + f"critic_step={step}.ckpt")
        torch.save(critic_target.state_dict(), save_path + f"critic_target_step={step}.ckpt")

    if step >= 200_000:
        break

[1000] {'bc_loss': 0.1427274187579751, 'policy_q_loss': -0.9908457914963364, 'critic_td_loss': 5.12707413457334, 'target_q': 13.563100869894027}
[2000] {'bc_loss': 0.08299058483541012, 'policy_q_loss': -0.9975408108830452, 'critic_td_loss': 5.182138008594513, 'target_q': 43.36930799484253}
[3000] {'bc_loss': 0.06470946248620749, 'policy_q_loss': -0.9976939737200737, 'critic_td_loss': 3.4761104559898377, 'target_q': 83.71667295837402}
[4000] {'bc_loss': 0.060685421597212554, 'policy_q_loss': -0.9967594578266143, 'critic_td_loss': 2.8494967802762985, 'target_q': 89.67914269256592}
[5000] {'bc_loss': 0.058046852555125955, 'policy_q_loss': -0.996119406580925, 'critic_td_loss': 2.3522411046028138, 'target_q': 90.79142153167724}
[6000] {'bc_loss': 0.0559096395149827, 'policy_q_loss': -0.9963647210597992, 'critic_td_loss': 2.033193099975586, 'target_q': 92.29960679626465}
[7000] {'bc_loss': 0.05453023571893573, 'policy_q_loss': -0.996669750213623, 'critic_td_loss': 2.0170140857696532, 'target

## 5. Evaluation

Diffusion policies with critics often use a method like rejection sampling to determine which action to take. These algorithms generate several candidate actions based on the current state. The actions are then scored (using Q-values or another metric), and one action is randomly selected, with higher-scoring actions being more likely to be chosen.

In [6]:
import einops
import numpy as np

n_seeds = 3
save_path = "../results/tutorial3_dql_for_d4rl_mujoco/"

# Loading the trained model
actor.load(save_path + "actor_step=200000.ckpt")
actor.eval()
critic.load_state_dict(torch.load(save_path + "critic_step=200000.ckpt", map_location=device))
critic_target.load_state_dict(torch.load(save_path + "critic_target_step=200000.ckpt", map_location=device))
critic.eval()
critic_target.eval()

env_eval = gym.vector.make("halfcheetah-medium-v2", num_envs=50)
normalizer = dataset.get_normalizer()

num_candidates = 128
temperature = 300.0

prior = torch.zeros((num_candidates * 50, act_dim))
scores = []

for _ in range(n_seeds):
    obs, all_done, ep_rew, t = env_eval.reset(), False, 0, 0

    while not np.all(all_done):
        obs = torch.tensor(normalizer.normalize(obs), dtype=torch.float32, device=device)
        obs = einops.repeat(obs, "b d -> (b n) d", n=num_candidates)

        act, _ = actor.sample(prior, sample_steps=5, condition_cfg=obs, w_cfg=1.0)

        with torch.no_grad():
            q = critic(obs, act)
            q = einops.rearrange(q, "(b n) 1 -> b n 1", n=num_candidates)
            q = q - q.mean(1, keepdim=True)
            act = einops.rearrange(act, "(b n) d -> b n d", n=num_candidates)

            w = torch.softmax(q * temperature, 1)
            idx = torch.multinomial(w.squeeze(-1), 1).squeeze(1)
            act = act[torch.arange(act.size(0)), idx]
            act = act.cpu().numpy()

        obs, rew, done, info = env_eval.step(act)

        all_done = np.logical_or(all_done, done)
        ep_rew += rew
        t += 1

        print(f"Step: {t}, Reward: {np.round(rew, 2)}")

    scores.append(env.get_normalized_score(ep_rew.mean()) * 100.0)

print(f"D4RL score: {np.mean(scores)}+-{np.std(scores)}")
env_eval.close()

  logger.warn(


Step: 1, Reward: [-0.26 -0.09 -0.67 -0.65  0.08 -0.51 -0.35  0.01 -0.44 -0.14 -0.25 -0.05
 -0.42 -0.32 -0.21  0.05 -0.35 -0.    0.09 -0.84 -0.88 -0.6   0.28 -0.57
 -0.37  0.02 -0.17  0.05  0.22  0.2   0.03 -0.03 -0.45 -0.21 -0.29 -0.5
 -0.63 -0.23 -0.35  0.1   0.31 -0.71 -0.12 -0.19 -0.56 -0.31 -0.35 -0.56
 -0.3   0.28]
Step: 2, Reward: [-0.76 -0.89 -1.06 -1.12 -0.23 -1.05 -0.92 -0.33 -1.   -0.53 -0.73 -0.71
 -1.24 -0.95 -0.58 -0.12 -0.96 -0.36 -0.19 -1.5  -1.15 -1.21 -0.02 -1.03
 -1.01 -0.63 -0.79 -0.18  0.08 -0.14 -0.49 -0.   -1.08 -0.83 -0.93 -0.91
 -1.25 -0.85 -1.07 -0.42  0.09 -1.17 -0.6  -0.7  -0.95 -1.11 -0.85 -1.27
 -0.81  0.03]
Step: 3, Reward: [-1.15 -0.66 -0.27 -0.66 -0.75 -0.57 -0.45 -0.68 -0.45 -0.38 -0.94 -0.66
 -0.97 -0.89 -0.57 -0.12 -1.12 -0.79 -0.31 -0.41 -0.37 -0.8  -0.07 -0.9
 -0.56 -1.07 -0.59 -0.31 -0.45 -0.54 -0.76  0.51 -0.66 -0.41 -0.74 -0.13
 -0.72 -0.35 -0.86 -0.92 -0.33 -0.58 -0.71 -0.46 -0.52 -1.25 -0.51 -1.24
 -0.52 -0.82]
Step: 4, Reward: [-0.37 -0.18  0.

In our test, the results are slightly lower than those reported in the original DQL paper (*Tutorial 3*: 50.7 ± 0.1 vs *Official:* 51.1 ± 0.5). However, it’s important to note that we only trained for 200k steps as an example, whereas the original paper trained for 2 million steps. With a longer training time, the performance should improve.
