# **Tutorial 3. DQL (Diffusion Q-learning) for D4RL-MuJoCo**

## 1 Introduction

The first two tutorials covered standard supervised learning tasks, where we simply fed inputs to the model and trained it. This tutorial tackles a more complex reinforcement learning task in which updates to a diffusion model alternate with updates to a Q function, which is awkward to express as a plain supervised training loop. DQL is a simple and effective diffusion-based actor-critic algorithm. As in TD3+BC, its policy update combines two losses, a diffusion loss and a Q-maximizing loss, where the former is simply a behavior cloning loss:
$$
\begin{equation}
\mathcal L_{\text{policy}}(\theta)=\mathcal L_{\text{diffusion}}(\theta)-\eta\cdot Q_\phi(\bm s, \pi_\theta(\bm a|\bm s)). \tag{1}
\end{equation}
$$
Its critic is a Q function for the current policy, trained with TD learning:
$$
\begin{equation}
\mathcal L_{\text{critic}}(\phi)=\left(Q_\phi(\bm s, \bm a)-(r + \gamma\cdot Q_{\phi^-}(\bm s', \pi_\theta(\bm a'|\bm s')))\right)^2. \tag{2}
\end{equation}
$$
 
As in any actor-critic algorithm, the two updates are coupled: the diffusion policy supplies the actions used in the critic's TD target, and the critic supplies the Q-values that the policy tries to maximize.
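
Equations (1) and (2) can be made concrete with a few scalars. The following is a minimal sketch in plain Python, not the tutorial's training code: the function names and all numbers are hypothetical and chosen only for illustration.

```python
def policy_loss(diffusion_loss: float, q_value: float, eta: float) -> float:
    """Eq. (1): behavior-cloning (diffusion) loss minus eta times the Q-value."""
    return diffusion_loss - eta * q_value


def critic_loss(q_sa: float, reward: float, gamma: float, q_target_next: float) -> float:
    """Eq. (2): squared TD error against the bootstrapped target."""
    td_target = reward + gamma * q_target_next
    return (q_sa - td_target) ** 2


# Hypothetical values for a single transition.
actor_loss = policy_loss(diffusion_loss=0.06, q_value=2.0, eta=1.0)
td_loss = critic_loss(q_sa=10.0, reward=1.0, gamma=0.99, q_target_next=10.0)
print(actor_loss, td_loss)
```

A larger `eta` pushes the policy toward high-Q actions, while a smaller one keeps it closer to the behavior data.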

## 2 Setting up the Environment and the Dataset

Here we use `halfcheetah-medium-v2` from D4RL-MuJoCo as an example. D4RL-MuJoCo is a widely used offline RL benchmark, and `halfcheetah-medium-v2` requires controlling a HalfCheetah robot so that it runs forward as fast as possible. CleanDiffuser already provides a simple interface for loading D4RL datasets.

In [5]:
import d4rl
import gym

from cleandiffuser.dataset.d4rl_mujoco_dataset import D4RLMuJoCoTDDataset

env = gym.make("halfcheetah-medium-v2")
dataset = D4RLMuJoCoTDDataset(d4rl.qlearning_dataset(env), normalize_reward=True)
obs_dim, act_dim = dataset.obs_dim, dataset.act_dim


## 3 Building the Diffusion Model

The diffusion model in DQL plays the role of the policy network. As in Tutorial 1, it should generate actions conditioned on the input states.
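
To show what "generating an action conditioned on a state" means for a few-step diffusion policy, here is a toy DDPM-style sampler in plain Python. It is a sketch under strong assumptions and not how `DiscreteDiffusionSDE` is implemented: the action is one-dimensional, the noise schedule values are hypothetical, and `toy_noise_predictor` stands in for a trained network by returning the exact noise for a dataset whose only action in state `s` is `tanh(s)`.

```python
import math
import random


def make_schedule(num_steps: int, beta_start: float = 0.1, beta_end: float = 0.7):
    """Linear beta schedule (hypothetical values); returns betas, alphas, alpha_bars."""
    betas = [beta_start + (beta_end - beta_start) * t / (num_steps - 1) for t in range(num_steps)]
    alphas = [1.0 - b for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)
    return betas, alphas, alpha_bars


def toy_noise_predictor(x_t: float, state: float, alpha_bar_t: float) -> float:
    """Stand-in for the network: exact noise when the data is a point mass at tanh(state)."""
    target_action = math.tanh(state)
    return (x_t - math.sqrt(alpha_bar_t) * target_action) / math.sqrt(1.0 - alpha_bar_t)


def sample_action(state: float, num_steps: int = 5, seed: int = 0) -> float:
    rng = random.Random(seed)
    betas, alphas, alpha_bars = make_schedule(num_steps)
    x = rng.gauss(0.0, 1.0)  # start from Gaussian noise
    for t in reversed(range(num_steps)):
        eps_hat = toy_noise_predictor(x, state, alpha_bars[t])
        # DDPM reverse-step mean computed from the predicted noise
        x = (x - betas[t] / math.sqrt(1.0 - alpha_bars[t]) * eps_hat) / math.sqrt(alphas[t])
        if t > 0:  # no noise is added on the final step
            x += math.sqrt(betas[t]) * rng.gauss(0.0, 1.0)
    return max(-1.0, min(1.0, x))  # clip to the action bounds [-1, 1]


action = sample_action(state=0.3)
print(action, math.tanh(0.3))
```

Because the stand-in predictor is exact, the final denoising step recovers `tanh(state)` whatever noise was drawn earlier. A learned network only approximates this, so real samples vary.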

In [6]:
import torch
import torch.nn as nn

from cleandiffuser.diffusion import DiscreteDiffusionSDE
from cleandiffuser.nn_diffusion import IDQLMlp
from cleandiffuser.nn_condition import MLPCondition

device = "cuda:0"

nn_diffusion = IDQLMlp(x_dim=act_dim, emb_dim=64, timestep_emb_type="positional")

# The label dropout rate is set to 0.0 as we do not use Classifier-free Guidance.
nn_condition = MLPCondition(in_dim=obs_dim, out_dim=64, hidden_dims=64, dropout=0.0)

actor = DiscreteDiffusionSDE(
    nn_diffusion,
    nn_condition,
    diffusion_steps=5,
    x_max=torch.full((act_dim,), fill_value=1.0),
    x_min=torch.full((act_dim,), fill_value=-1.0),
).to(device)

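The critic uses twin Q functions. Each one is an MLP with LayerNorm after every hidden layer, Tanh after the first hidden layer, and Mish after the remaining ones. `forward` returns the minimum of the two estimates.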
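
The twin heads matter most when forming the TD target. As a rough illustration of how the training loop in Section 4 uses them, the scalar sketch below bootstraps from the smaller of the two target estimates and drops the bootstrap term on terminal transitions. The function names and numbers are hypothetical. The discount of 0.99 matches the value used in the loop.

```python
def clipped_double_q_target(reward, terminal, q1_next, q2_next, gamma=0.99):
    """Bootstrap from the smaller target estimate; no bootstrap on terminal steps."""
    return reward + (1.0 - terminal) * gamma * min(q1_next, q2_next)


def twin_td_loss(q1, q2, target):
    """Both heads regress onto the same target."""
    return (q1 - target) ** 2 + (q2 - target) ** 2


# Hypothetical values for one non-terminal and one terminal transition.
target = clipped_double_q_target(reward=1.0, terminal=0.0, q1_next=10.0, q2_next=12.0)
terminal_target = clipped_double_q_target(reward=1.0, terminal=1.0, q1_next=10.0, q2_next=12.0)
loss = twin_td_loss(q1=10.0, q2=11.0, target=target)
print(target, terminal_target, loss)
```

Taking the minimum is a common way to counteract the overestimation that a single bootstrapped Q function tends to accumulate.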

In [7]:
from copy import deepcopy


class TwinQ(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_dim: int = 256):
        super().__init__()
        self.Q1 = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        )
        self.Q2 = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        )

    def both(self, obs, act):
        q1, q2 = self.Q1(torch.cat([obs, act], -1)), self.Q2(torch.cat([obs, act], -1))
        return q1, q2

    def forward(self, obs, act):
        return torch.min(*self.both(obs, act))


critic = TwinQ(obs_dim, act_dim, hidden_dim=256).to(device)
critic_target = deepcopy(critic).requires_grad_(False).eval().to(device)
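
`critic_target` is a frozen copy of the critic that the training loop in Section 4 moves toward the online critic by parameter interpolation. The sketch below isolates that interpolation on plain Python lists. `soft_update` and the parameter values are hypothetical. The loop in this tutorial puts the weight 0.995 on the online parameters, whereas TD3-style implementations typically put a small weight such as 0.005 on them. Both are shown for comparison.

```python
def soft_update(target_params, online_params, online_weight):
    """Return target parameters interpolated toward the online ones."""
    return [
        online_weight * p + (1.0 - online_weight) * t
        for p, t in zip(online_params, target_params)
    ]


online = [1.0, -2.0]
target = [0.0, 0.0]

fast = soft_update(target, online, online_weight=0.995)  # target nearly copies the critic
slow = soft_update(target, online, online_weight=0.005)  # target changes slowly
print(fast, slow)
```

With a weight of 0.995 the target is almost a delayed copy of the critic. With 0.005 it is a long-horizon moving average.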

## 4 Training the Diffusion Model in Manual-update Style

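To implement complex update logic, we can update the diffusion model manually, for example with `update_diffusion`. The loop below goes one step further: it calls `configure_manual_optimizers` and steps `actor.manual_optimizers["diffusion"]` directly, so the behavior cloning loss and the Q loss can be combined before a single backward pass. This way, we don't need to wrap the dataset to match the format required by PyTorch Lightning.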
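
One detail of the actor update in the loop below is worth isolating: the Q term is divided by the mean absolute value of the other Q head, with that denominator detached from the graph. The plain-Python sketch here uses hypothetical Q-values and treats the denominator as a constant. It shows that the resulting loss does not depend on the overall scale of the Q-values, which is consistent with the logged `policy_q_loss` staying close to -1 while `target_q` grows.

```python
def normalized_q_loss(q_main, q_other):
    """-mean(q_main) / mean(|q_other|), with the denominator treated as a constant."""
    mean_main = sum(q_main) / len(q_main)
    scale = sum(abs(q) for q in q_other) / len(q_other)
    return -mean_main / scale


# Hypothetical Q-values from the two heads on a batch of three samples.
q1 = [80.0, 90.0, 100.0]
q2 = [82.0, 88.0, 99.0]

loss = normalized_q_loss(q1, q2)
loss_scaled = normalized_q_loss([10 * q for q in q1], [10 * q for q in q2])
print(loss, loss_scaled)
```

Since the Q term stays on a scale of about 1, a single `eta` can balance it against the behavior cloning loss even as the Q-values change during training.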

In [8]:
from cleandiffuser.utils import loop_dataloader, FreezeModules
import numpy as np
import os

save_path = "../results/tutorial3_dql_for_d4rl_mujoco/"
os.makedirs(save_path, exist_ok=True)

batch_size = 2048
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=4, persistent_workers=True, drop_last=True
)
critic_optim = torch.optim.Adam(critic.parameters(), lr=3e-4)
actor.configure_manual_optimizers()

step = 0
log = dict.fromkeys(["bc_loss", "policy_q_loss", "critic_td_loss", "target_q"], 0.0)
prior = torch.zeros((batch_size, act_dim), device=device)

for batch in loop_dataloader(dataloader):
    obs, next_obs = batch["obs"]["state"].to(device), batch["next_obs"]["state"].to(device)
    act = batch["act"].to(device)
    rew = batch["rew"].to(device)
    tml = batch["tml"].to(device)

    # --- Critic Update ---
    actor.eval()
    critic.train()

    q1, q2 = critic.both(obs, act)

    next_act, _ = actor.sample(
        prior,
        solver="ddpm",
        n_samples=batch_size,
        sample_steps=5,
        condition_cfg=next_obs,
        w_cfg=1.0,
        requires_grad=False,
    )

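    # Clipped double-Q target: take the smaller of the two target estimates.
    with torch.no_grad():
        target_q = torch.min(*critic_target.both(next_obs, next_act))

    # Discount is 0.99; (1 - tml) drops the bootstrap term when tml (the terminal flag) is 1.
    target_q = (rew + (1 - tml) * 0.99 * target_q).detach()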

    critic_td_loss = (q1 - target_q).pow(2).mean() + (q2 - target_q).pow(2).mean()

    critic_optim.zero_grad()
    critic_td_loss.backward()
    critic_optim.step()

    log["critic_td_loss"] += critic_td_loss.item()
    log["target_q"] += target_q.mean().item()

    # --- Actor Update ---
    actor.train()
    critic.eval()

    bc_loss = actor.loss(act, obs)

    new_act, _ = actor.sample(
        prior,
        solver="ddpm",
        n_samples=batch_size,
        sample_steps=5,
        condition_cfg=obs,
        w_cfg=1.0,
        use_ema=False,
        requires_grad=True,
    )

    with FreezeModules([critic]):
        q1_actor, q2_actor = critic.both(obs, new_act)

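    # Pick one Q head at random and normalize it by the detached magnitude of the other,
    # so the Q term stays on a scale of about 1 regardless of how large the Q-values are.
    if np.random.uniform() > 0.5:
        policy_q_loss = -q1_actor.mean() / q2_actor.abs().mean().detach()
    else:
        policy_q_loss = -q2_actor.mean() / q1_actor.abs().mean().detach()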

    # eta=1.0 for halfcheetah-medium-v2
    actor_loss = bc_loss + policy_q_loss * 1.0

    actor.manual_optimizers["diffusion"].zero_grad()
    actor_loss.backward()
    actor.manual_optimizers["diffusion"].step()

    log["bc_loss"] += bc_loss.item()
    log["policy_q_loss"] += policy_q_loss.item()

    step += 1

    # --- EMA Update ---
    if step % 5 == 0:
        if step >= 1000:
            actor.ema_update()
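        # Soft update of the target critic. Note that 0.995 is the weight on the online
        # parameters here, so the target follows the critic closely.
        for param, target_param in zip(critic.parameters(), critic_target.parameters()):
            target_param.data.copy_(0.995 * param.data + (1 - 0.995) * target_param.data)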

    if step % 1000 == 0:
        log = {k: v / 1000 for k, v in log.items()}
        print(f"[{step}] {log}")
        log = dict.fromkeys(["bc_loss", "policy_q_loss", "critic_td_loss", "target_q"], 0.0)

    if step % 50_000 == 0:
        actor.save(save_path + f"actor_step={step}.ckpt")
        torch.save(critic.state_dict(), save_path + f"critic_step={step}.ckpt")
        torch.save(critic_target.state_dict(), save_path + f"critic_target_step={step}.ckpt")

    if step >= 200_000:
        break

[1000] {'bc_loss': 0.12377934498339892, 'policy_q_loss': -0.9846072461605072, 'critic_td_loss': 7.595048444822431, 'target_q': 18.960346080839635}
[2000] {'bc_loss': 0.08238512057811022, 'policy_q_loss': -0.98832589417696, 'critic_td_loss': 8.06298545050621, 'target_q': 40.818730871200565}
[3000] {'bc_loss': 0.06406415082886815, 'policy_q_loss': -0.9965834428668022, 'critic_td_loss': 3.769754119157791, 'target_q': 81.05738888549804}
[4000] {'bc_loss': 0.060024288821965456, 'policy_q_loss': -0.9959487068653107, 'critic_td_loss': 2.715253163576126, 'target_q': 88.40225748443603}
[5000] {'bc_loss': 0.057553786966949703, 'policy_q_loss': -0.9958168971538544, 'critic_td_loss': 2.2044705004692076, 'target_q': 89.54628697967529}
[6000] {'bc_loss': 0.055614226009696725, 'policy_q_loss': -0.9962961921691894, 'critic_td_loss': 2.152037105202675, 'target_q': 90.65434854888916}
[7000] {'bc_loss': 0.05411395792290569, 'policy_q_loss': -0.9968310383558273, 'critic_td_loss': 2.0946031621694563, 'targ