# **Tutorial 1. DBC (Diffusion Behavior Cloning) for FrankaKitchen**
## 1. Introduction

In this tutorial, we demonstrate how to implement a basic DBC (Diffusion Behavior Cloning) agent using CleanDiffuser. DBC is an imitation learning algorithm that reproduces the behavior in an offline demonstration dataset by using a diffusion model to generate actions conditioned on the current observation. The idea is the same as in diffusion-based image generation, except that DBC generates actions conditioned on the state $\bm s$ instead of images.
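To make the idea concrete, here is a minimal NumPy sketch of the standard discrete-time DDPM training objective applied to behavior cloning: noise is added to a demonstrated action, and a network that also sees the state is trained to predict that noise. The linear noise schedule, `T = 1000`, and the placeholder network are assumptions for illustration only; CleanDiffuser's `ContinuousDiffusionSDE` uses its own continuous-time formulation.

```python
import numpy as np

# Discrete-time DDPM schedule with linear betas (Ho et al., 2020). Illustration only;
# CleanDiffuser's ContinuousDiffusionSDE uses its own schedule.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)  # \bar{alpha}_t for t = 0, ..., T-1


def add_noise(a0, t, rng):
    """Forward process: a_t = sqrt(abar_t) * a_0 + sqrt(1 - abar_t) * eps."""
    eps = rng.standard_normal(a0.shape)
    abar = alpha_bars[t][:, None]  # (bs, 1), broadcasts over act_dim
    a_t = np.sqrt(abar) * a0 + np.sqrt(1.0 - abar) * eps
    return a_t, eps


def dbc_loss(eps_model, states, actions, rng):
    """Epsilon-prediction loss; the state s enters only as a condition."""
    t = rng.integers(0, T, size=actions.shape[0])
    a_t, eps = add_noise(actions, t, rng)
    eps_pred = eps_model(a_t, t, states)  # hypothetical network eps_theta(a_t, t, s)
    return np.mean((eps_pred - eps) ** 2)


def zero_model(a_t, t, s):
    """Placeholder "network" that always predicts zero noise."""
    return np.zeros_like(a_t)


rng = np.random.default_rng(0)
states = rng.standard_normal((512, 60))      # (bs, obs_dim)
actions = rng.uniform(-1, 1, size=(512, 9))  # (bs, act_dim), normalized actions
loss = dbc_loss(zero_model, states, actions, rng)
print(f"loss of a network that always predicts zero noise: {loss:.3f}")
```

A network that ignores its inputs scores a loss close to 1 (the variance of the noise); training drives the loss below that by exploiting the noisy action, the step, and the state.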

Imitation learning relies on expert demonstrations. In this tutorial, we tackle the RelayKitchen task, in which a 9-DoF position-controlled Franka robot interacts with a kitchen environment. The environment includes an openable microwave, four turnable oven burners, an oven light switch, a movable kettle, two hinged cabinets, and a sliding cabinet door. The dataset contains 566 human demonstrations of various activities, such as opening the microwave, turning on the oven light, and moving the kettle. The goal is to train an agent that imitates these demonstrations and completes as many tasks as possible within a fixed episode length.

Let’s begin by downloading the expert demonstrations!

In [8]:
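# Each `!` command runs in its own subshell, so a separate `! cd` would not persist;
# use explicit paths instead. The data is expected under dev/kitchen (see below).
# ! mkdir -p ./dev
# ! wget -P ./dev https://diffusion-policy.cs.columbia.edu/data/training/kitchen.zip
# ! unzip ./dev/kitchen.zip -d ./dev
# ! rm ./dev/kitchen.zip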
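If `wget` or `unzip` is unavailable (for example on Windows), the same download can be sketched in plain Python with the standard library. `download_and_extract` is a hypothetical helper, not part of CleanDiffuser, and we assume the archive unpacks into a `kitchen/` folder, matching the `dev/kitchen` path used when loading the dataset.

```python
import os
import urllib.request
import zipfile


def download_and_extract(url: str, dest_dir: str) -> str:
    """Download a .zip archive into dest_dir, extract it there, then delete the archive."""
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, zip_path)  # also accepts file:// URLs
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    os.remove(zip_path)
    return dest_dir


# Uncomment to fetch the demonstrations (equivalent to the shell commands above):
# download_and_extract("https://diffusion-policy.cs.columbia.edu/data/training/kitchen.zip", "dev")
```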

## 2. Setting up the Environment and the Dataset

CleanDiffuser provides a straightforward interface to set up the environment and dataset. The code below shows how to create a gym-like environment and a PyTorch Dataset class for the FrankaKitchen task.

In [1]:
import gym

from cleandiffuser.dataset.kitchen_dataset import KitchenDataset
from cleandiffuser.env import kitchen  # noqa: F401  (imported for its side effect: registering the "kitchen-*" gym environments)

env = gym.make("kitchen-all-v0")
dataset = KitchenDataset("dev/kitchen", horizon=1, pad_before=0, pad_after=0, abs_action=False)

data = dataset[0]
obs, act = data["state"], data["action"]
obs_dim, act_dim = dataset.obs_dim, dataset.act_dim
print(f"Finish loading data. Observation shape: {obs.shape}. Action shape {act.shape}.")

Reading configurations for Franka
[40m[97mInitializing Franka sim[0m


  logger.warn(f"Box bound precision lowered by casting to {self.dtype}")


Finish loading data. Observation shape: (1, 60). Action shape (1, 9).


## 3. Building the Diffusion Model
Following the DBC approach, we use a diffusion model to generate expert-like actions conditioned on the current observation. We use `ContinuousDiffusionSDE` (sampled later with the DDPM solver), with `PearceMlp` as the noise-prediction backbone and `PearceObsCondition` as the observation-encoding condition network. Once the networks are set up, building the diffusion model is simply a matter of combining them!

In [2]:
import torch

from cleandiffuser.diffusion import ContinuousDiffusionSDE
from cleandiffuser.nn_condition import PearceObsCondition
from cleandiffuser.nn_diffusion import PearceMlp

nn_diffusion = PearceMlp(
    x_dim=act_dim, condition_horizon=1, emb_dim=128, hidden_dim=512, timestep_emb_type="untrainable_fourier"
)
""" nn.Module: xt (bs, act_dim) x t (bs, ) x condition (bs, condition_horizon * emb_dim) -> eps_theta (bs, act_dim) """
nn_condition = PearceObsCondition(obs_dim=obs_dim, emb_dim=128, flatten=True, dropout=0.0)
""" nn.Module: obs (bs, condition_horizon, obs_dim) x t (bs, ) -> condition (bs, condition_horizon * emb_dim) if `flatten` else (bs, condition_horizon, emb_dim) """

# The (normalized) actions lie in [-1, 1], so we set `x_max` and `x_min` to keep the generated actions in that range.
actor = ContinuousDiffusionSDE(
    nn_diffusion,
    nn_condition,
    x_max=torch.full((act_dim,), fill_value=1.0),
    x_min=torch.full((act_dim,), fill_value=-1.0),
)
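The `x_max`/`x_min` bounds only make sense because the model works in a normalized action space; the evaluation code later calls `normalizer["action"].normalize`/`unnormalize` to move between spaces. The sketch below assumes that normalizer is a per-dimension min-max map to [-1, 1]. `MinMaxNormalizer` is a hypothetical stand-in, not CleanDiffuser's class.

```python
import numpy as np


class MinMaxNormalizer:
    """Hypothetical per-dimension min-max normalizer that maps data to [-1, 1]."""

    def __init__(self, data, eps=1e-8):
        self.min = data.min(axis=0)
        self.max = data.max(axis=0)
        self.scale = np.maximum(self.max - self.min, eps)  # guards constant dimensions

    def normalize(self, x):
        return 2.0 * (x - self.min) / self.scale - 1.0

    def unnormalize(self, x):
        return (x + 1.0) / 2.0 * self.scale + self.min


rng = np.random.default_rng(0)
raw_actions = rng.normal(loc=0.3, scale=2.0, size=(1000, 9))  # stand-in for dataset actions
action_normalizer = MinMaxNormalizer(raw_actions)
normed = action_normalizer.normalize(raw_actions)  # what the diffusion model is trained on

# A generated action is kept inside the [-1, 1] box (the role of x_min / x_max),
# then mapped back to the environment's action units before env.step.
generated = np.clip(rng.normal(size=(1, 9)), -1.0, 1.0)
env_action = action_normalizer.unnormalize(generated)
```

Clipping in normalized space guarantees that every unnormalized action stays within the range seen in the demonstrations.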

## 4. Training the Diffusion Model

### 4.1 PyTorch Lightning Approach
All diffusion models in CleanDiffuser are implemented as `LightningModule`s, so they can be trained directly with a PyTorch Lightning `Trainer`. PyTorch Lightning simplifies training deep learning models and supports features such as distributed training, mixed-precision training, and automatic checkpointing with just a few lines of code. To set up the `Trainer`, you need:

- A CleanDiffuser `DiffusionModel`.
- A PyTorch `DataLoader` that yields each batch as a dictionary. By default, the keys are `x0`, `condition_cfg`, and `condition_cg`. `x0` holds samples from the target distribution and is required. `condition_cfg` and `condition_cg` are the conditions used for classifier-free guidance (CFG) and classifier guidance (CG), respectively; both are optional and can be set to `None` or omitted when unused.
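For reference, classifier-free guidance is commonly written as below, where $\bm a_t$ is the noisy action at diffusion step $t$, $\varnothing$ denotes a dropped (empty) condition, and $w$ is the guidance weight. We assume, without having checked CleanDiffuser's source, that its `w_cfg` argument follows this convention. Under it, the `w_cfg=1.0` used during evaluation reduces to the plain conditional prediction, which fits training `PearceObsCondition` with `dropout=0.0` (if, as we assume, that argument is the condition-dropout rate), since no unconditional model is learned.

$$
\hat{\epsilon}_\theta(\bm a_t, t, \bm s) = \epsilon_\theta(\bm a_t, t, \varnothing) + w\left(\epsilon_\theta(\bm a_t, t, \bm s) - \epsilon_\theta(\bm a_t, t, \varnothing)\right),
\qquad
w = 1 \;\Longrightarrow\; \hat{\epsilon}_\theta(\bm a_t, t, \bm s) = \epsilon_\theta(\bm a_t, t, \bm s).
$$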

Here’s an example of how to set up the Trainer and train the diffusion model.

**NOTE:** The PyTorch Lightning `Trainer` expects batches in exactly this dictionary format. You can either write a Dataset class that returns it directly, or wrap an existing dataset to adapt its output; the `BC_Wrapper` below takes the second approach. `KitchenDataset` organizes a batch as `batch = {"state": torch.Tensor of shape (batch_size, horizon, state_dim), "action": torch.Tensor of shape (batch_size, horizon, action_dim)}`. `BC_Wrapper` converts it to the required format: `{"x0": torch.Tensor of shape (batch_size, action_dim), "condition_cfg": torch.Tensor of shape (batch_size, state_dim)}`.
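The shape change performed by the wrapper can be sketched with NumPy alone. The random items below are stand-ins for `KitchenDataset` samples, and `collate` only imitates what PyTorch's default collate function does (stacking along a new batch axis); both helpers are hypothetical.

```python
import numpy as np

horizon, obs_dim, act_dim = 1, 60, 9
rng = np.random.default_rng(0)

# Stand-ins for dataset items: {"state": (horizon, obs_dim), "action": (horizon, act_dim)}
items = [
    {"state": rng.standard_normal((horizon, obs_dim)), "action": rng.uniform(-1, 1, (horizon, act_dim))}
    for _ in range(4)
]


def to_bc_sample(item):
    """Keep the first (with horizon=1, the only) step: action -> x0, state -> condition_cfg."""
    return {"x0": item["action"][0], "condition_cfg": item["state"][0]}


def collate(samples):
    """Stack per-sample arrays along a new batch axis, like DataLoader's default collate."""
    return {key: np.stack([s[key] for s in samples]) for key in samples[0]}


batch = collate([to_bc_sample(it) for it in items])
print({k: v.shape for k, v in batch.items()})  # {'x0': (4, 9), 'condition_cfg': (4, 60)}
```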


In [3]:
# Set the precision of float32 matrix multiplications. "highest" (the default) keeps full float32 precision;
# "high" and "medium" let PyTorch use faster TF32/bfloat16-based kernels on supported GPUs at a small precision cost.
torch.set_float32_matmul_precision("high")


In [5]:
import pytorch_lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint


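class BC_Wrapper(torch.utils.data.Dataset):
    """Adapts KitchenDataset samples to the {"x0", "condition_cfg"} format expected by the Trainer."""

    def __init__(self, dataset: torch.utils.data.Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: forward to the wrapped dataset (e.g. `sampler`).
        # The guard avoids infinite recursion when `dataset` is not set yet, which happens when
        # DataLoader workers unpickle the wrapper (e.g. with the "spawn" start method).
        if name == "dataset":
            raise AttributeError(name)
        return getattr(self.dataset, name)

    def __getitem__(self, idx):
        data = self.sampler.sample_sequence(idx)
        # With horizon=1, index 0 selects the single timestep: action -> x0, state -> condition_cfg.
        return {
            "x0": data["action"][0],
            "condition_cfg": data["state"][0],
        }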


save_path = "results/tutorial1_dbc_for_kitchen/"

dataloader = torch.utils.data.DataLoader(
    BC_Wrapper(dataset), batch_size=512, shuffle=True, num_workers=4, persistent_workers=True
)

callback = ModelCheckpoint(dirpath=save_path, filename="dbc-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_steps=200_000,
    deterministic=True,
    log_every_n_steps=200,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(actor, dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/dynias/CleanDiffuser-lightning/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/home/dynias/CleanDiffuser-lightning/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /home/dynias/CleanDiffuser-lightning/notebooks/result


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_steps=200000` reached.


One of the key advantages of PyTorch Lightning is its native support for distributed training. By simply setting the `devices` argument to a list of GPU IDs, you can train the model on multiple GPUs. For more advanced features, such as mixed precision training or gradient accumulation, you can refer to the [PyTorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html).
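As a sketch of the multi-GPU and mixed-precision settings mentioned above, the helper below only assembles keyword arguments for `L.Trainer`; it does not start training. `make_trainer_kwargs` is hypothetical. `strategy`, `precision`, and `accumulate_grad_batches` are standard `Trainer` arguments, but the values used here (`"ddp_notebook"`, `"16-mixed"`) assume a recent PyTorch Lightning release, so check them against your installed version.

```python
def make_trainer_kwargs(gpu_ids, mixed_precision=False, accumulate=1, per_device_batch=512, notebook=True):
    """Build keyword arguments for L.Trainer(...) and report the effective batch size."""
    devices = list(gpu_ids)
    kwargs = {
        "accelerator": "gpu",
        "devices": devices,
        "max_steps": 200_000,
        "log_every_n_steps": 200,
        "accumulate_grad_batches": accumulate,  # one optimizer step per `accumulate` batches
    }
    if len(devices) > 1:
        # DDP runs one process per GPU; inside Jupyter, Lightning needs the notebook variant.
        kwargs["strategy"] = "ddp_notebook" if notebook else "ddp"
    if mixed_precision:
        kwargs["precision"] = "16-mixed"  # Lightning 2.x spelling; 1.x used precision=16
    # Under DDP every process draws its own batch of `per_device_batch` samples.
    effective_batch = per_device_batch * len(devices) * accumulate
    return kwargs, effective_batch


kwargs, effective_batch = make_trainer_kwargs([0, 1], mixed_precision=True, accumulate=2)
print(kwargs["strategy"], kwargs["precision"], effective_batch)  # ddp_notebook 16-mixed 2048
# trainer = L.Trainer(**kwargs, default_root_dir=save_path, callbacks=[callback])
```

When scaling out, keep the effective batch size in mind: doubling the GPUs doubles the number of samples per optimizer step, which may call for retuning the learning rate.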

### 4.2 Manual Updating Approach

For some offline RL algorithms, the training process may involve multiple components that need to be updated in a specific order. In such cases, you can manually update the diffusion model using the `update_diffusion` method. This method processes a batch of data, performs one update step, and returns the loss. Below is an example of a single training step using this approach.
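Extending a single step into a full loop could look like the sketch below. `manual_train` and `infinite_batches` are hypothetical helpers; the only CleanDiffuser call they rely on is `update_diffusion(x0=..., condition_cfg=...)` and the `"diffusion_loss"` entry it returns, as in the cell that follows. Re-iterating the `DataLoader` every epoch, rather than wrapping it in `itertools.cycle` (which caches the first epoch and replays it), keeps `shuffle=True` effective.

```python
def infinite_batches(dataloader):
    """Yield batches forever, re-iterating the loader each epoch so shuffling is redone."""
    while True:
        for batch in dataloader:
            yield batch


def manual_train(actor, dataloader, n_steps, device=None, log_every=1000):
    """Call actor.update_diffusion once per batch and collect the reported losses."""
    losses = []
    batches = infinite_batches(dataloader)
    for step in range(1, n_steps + 1):
        batch = next(batches)
        x0, cond = batch["x0"], batch["condition_cfg"]
        if device is not None:  # move tensors next to the model, e.g. device="cuda:0"
            x0, cond = x0.to(device), cond.to(device)
        log = actor.update_diffusion(x0=x0, condition_cfg=cond)
        losses.append(log["diffusion_loss"])
        if step % log_every == 0:
            print(f"step {step}: diffusion_loss={log['diffusion_loss']:.4f}")
    return losses


# Usage (not executed here):
# losses = manual_train(actor, dataloader, n_steps=10_000, device="cuda:0")
```

In an algorithm with several components, the body of the loop is where the other updates (for example a critic) would be interleaved in the required order.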

In [6]:
batch = next(iter(dataloader))

update_log = actor.update_diffusion(x0=batch["x0"], condition_cfg=batch["condition_cfg"])

print(update_log)

{'diffusion_loss': 0.060467831790447235}


## 5. Evaluation

After training, we evaluate the model by sampling actions from it and running the agent in the environment. The `sample` method generates actions from the model. Below is an example that evaluates the model in a vectorized environment with 50 parallel copies of the kitchen.
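The evaluation loop sums each environment's rewards into `ep_rew` and, per round, records the fraction of environments with `ep_rew >= k` for k = 1..5. Assuming each unit of reward marks one completed subtask (as the "Task ... completed!" messages suggest), the metric can be written as a small vectorized helper; `success_rates` is hypothetical. Because every round uses the same number of environments, averaging the per-round rates equals pooling all episodes.

```python
import numpy as np


def success_rates(tasks_completed, max_tasks=5):
    """Fraction of episodes that completed at least k tasks, for k = 1..max_tasks."""
    tasks_completed = np.asarray(tasks_completed)
    thresholds = np.arange(1, max_tasks + 1)[:, None]  # (max_tasks, 1)
    return (tasks_completed[None, :] >= thresholds).mean(axis=1)


# Five hypothetical episodes that completed 0, 1, 2, 4 and 5 tasks.
print(success_rates([0, 1, 2, 4, 5]))  # [0.8 0.6 0.4 0.4 0.2]
```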

In [4]:
import numpy as np

n_seeds = 5
success_rate_for_n_tasks = np.zeros(5)

# device for evaluation
device = "cuda:0"

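# loading from checkpoint (a trusted file from our own run; weights_only=False lets the full Lightning checkpoint unpickle)
actor.load_state_dict(
    torch.load(
        "results/tutorial1_dbc_for_kitchen/dbc-step=200000.ckpt", map_location=device, weights_only=False
    )["state_dict"]
)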
actor.to(device).eval()

# evaluating
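num_envs = 50
env_eval = gym.vector.make("kitchen-all-v0", num_envs=num_envs)
normalizer = dataset.get_normalizer()

for _ in range(n_seeds):
    obs, all_done, ep_rew, t = env_eval.reset(), False, 0, 0

    while not np.all(all_done):
        # Normalize observations as in training and cast to float32 to match the model weights.
        obs = torch.tensor(normalizer["state"].normalize(obs), device=device, dtype=torch.float32)

        # Same shape as the generated actions, on the same device as the model.
        prior = torch.zeros((num_envs, act_dim), device=device)

        act, log = actor.sample(prior, solver="ddpm", sample_steps=5, condition_cfg=obs, w_cfg=1.0)
        act = act.cpu().numpy()
        act = normalizer["action"].unnormalize(act)

        obs, rew, done, info = env_eval.step(act)
        # Vector envs auto-reset finished sub-envs, so only count rewards of still-running episodes.
        ep_rew += rew * np.logical_not(all_done)
        all_done = np.logical_or(all_done, done)
        t += 1

        print(f"[t={t}] Task completed: {ep_rew}")

    # Fraction of the environments that completed at least i + 1 tasks in this round.
    for i in range(5):
        success_rate_for_n_tasks[i] += (ep_rew >= i + 1).sum() / num_envs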

env_eval.close()
success_rate_for_n_tasks /= n_seeds

print(f"Success rate (>= n tasks): {success_rate_for_n_tasks}")

  torch.load("results/tutorial1_dbc_for_kitchen/dbc-step=200000.ckpt", map_location=device)["state_dict"]


Reading configurations for Franka
[40m[97mInitializing Franka sim[0m


/home/dynias/CleanDiffuser-lightning/.venv/lib/python3.10/site-packages/gym/spaces/box.py:84: [33mWARN: Box bound precision lowered by casting to float32[0m



In [25]:
import torch  
  
# device for evaluation
device = "cuda:0"

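# loading from checkpoint (a trusted file from our own run; weights_only=False lets the full Lightning checkpoint unpickle)
actor.load_state_dict(
    torch.load(
        "results/tutorial1_dbc_for_kitchen/dbc-step=200000.ckpt", map_location=device, weights_only=False
    )["state_dict"]
)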
actor.to(device).eval()

# Use single environment for visualization  
env_single = gym.make("kitchen-all-v0")  
frames = []  
  
# Run one episode for visualization  
obs = env_single.reset()  
done = False  
t = 0
ep_rew = 0  

with torch.no_grad():
    while not done and t < 280:  # max episode steps  
        obs = torch.tensor(normalizer["state"].normalize(obs), device=device, dtype=torch.float32)

        # Add batch dimension if missing  
        if obs.dim() == 1:  
            obs = obs.unsqueeze(0)  

        prior = torch.zeros((1, act_dim), dtype=torch.float32, device=device)

        act, log = actor.sample(prior, solver="ddpm", sample_steps=5, condition_cfg=obs, w_cfg=1.0)
        act = act.cpu().numpy()
        act = normalizer["action"].unnormalize(act)[0]

        obs, rew, done, info = env_single.step(act)
        ep_rew += rew
        t += 1

        # Capture frame  
        frame = env_single.render(mode='rgb_array')  
        frames.append(frame)  

env_single.close()

  torch.load("results/tutorial1_dbc_for_kitchen/dbc-step=200000.ckpt", map_location=device)["state_dict"]


Reading configurations for Franka
[40m[97mInitializing Franka sim[0m
Task kettle completed!
Task bottom burner completed!
Task light switch completed!
Task hinge cabinet completed!


In [26]:
import imageio  
import numpy as np  
  
# Enhanced frame validation with shape correction  
valid_frames = []  
for i, frame in enumerate(frames):  
    if frame is not None and isinstance(frame, np.ndarray) and frame.size > 0:            
        # Handle malformed 2D frames  
        if len(frame.shape) == 2:  
            if frame.shape == (1280, 3):  
                # This appears to be a flattened or incorrectly shaped frame  
                # Skip this frame or create a placeholder  
                continue  
            elif frame.shape[1] == 3:  
                # Try to interpret as (height*width, 3) and reshape  
                total_pixels = frame.shape[0]  
                # Assume square-ish aspect ratio for reshaping  
                height = int(np.sqrt(total_pixels))  
                width = total_pixels // height  
                try:
                    frame = frame.reshape(height, width, 3)
                except ValueError:  # pixel count does not factor into height * width
                    print(f"Could not reshape frame {i}")
                    continue
          
        # Ensure proper 3D format  
        if len(frame.shape) == 3 and frame.shape[2] == 3:  
            # Ensure uint8 format  
            if frame.dtype != np.uint8:  
                if frame.max() <= 1.0:  
                    frame = (frame * 255).astype(np.uint8)  
                else:  
                    frame = frame.astype(np.uint8)  
              
            valid_frames.append(frame)  
        else:  
            print(f"Frame {i} still has invalid shape: {frame.shape}")  
  
# Save video with fallback options  
if valid_frames:  
    try:  
        writer = imageio.get_writer("results/kitchen_plan_visualization.mp4", fps=30, codec='libx264')  
        for frame in valid_frames:  
            writer.append_data(frame)  
        writer.close()  
        print(f"Video saved with {len(valid_frames)} valid frames")  
    except Exception as e:  
        print(f"MP4 failed: {e}")  
        # Fallback to GIF  
        try:  
            imageio.mimsave("results/kitchen_plan_visualization.gif", valid_frames, duration=100)  # ms per frame (10 fps)
            print(f"GIF saved with {len(valid_frames)} valid frames")  
        except Exception as e2:  
            print(f"Both MP4 and GIF failed: {e2}")  
else:  
    print("No valid frames to save - all frames were malformed")

MP4 failed: could not broadcast input array from shape (1280,3) into shape (1280,3,3)


/home/dynias/CleanDiffuser-lightning/.venv/lib/python3.10/site-packages/imageio/plugins/pillow.py:409: The keyword `fps` is no longer supported. Use `duration`(in ms) instead, e.g. `fps=50` == `duration=20` (1000 * 1/50).


GIF saved with 280 valid frames
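The frame validation above can be condensed into a small helper that keeps only well-formed `(H, W, 3)` images and converts them to `uint8`. This is a sketch: `to_video_frame` and its `min_side` threshold are our own assumptions. It deliberately drops malformed arrays, such as the `(1280, 3)` frames in the MP4 error, instead of guessing a reshape, because the original height and width cannot be recovered from such an array.

```python
import numpy as np


def to_video_frame(frame, min_side=16):
    """Return frame as an (H, W, 3) uint8 array, or None if it is not a usable RGB image."""
    if not isinstance(frame, np.ndarray) or frame.ndim != 3 or frame.shape[2] != 3:
        return None  # e.g. a (1280, 3) array: the true height/width are unknown
    if min(frame.shape[:2]) < min_side:
        return None  # degenerate image, e.g. only 3 pixels wide
    if frame.dtype != np.uint8:
        if np.issubdtype(frame.dtype, np.floating) and frame.max() <= 1.0:
            frame = frame * 255.0  # float image in [0, 1]
        frame = np.clip(frame, 0, 255).astype(np.uint8)
    return frame


frames_demo = [np.zeros((64, 64, 3), dtype=np.float32), np.zeros((1280, 3)), None]
clean = [f for f in (to_video_frame(x) for x in frames_demo) if f is not None]
# imageio.mimsave("results/kitchen_plan_visualization.gif", clean, duration=100)  # 100 ms per frame
```

If every frame is rejected, the problem lies in how `render` is called rather than in the writer, and no reshaping heuristic will produce a meaningful video.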


The results are strong. Compared with the official report, the tutorial model achieves a higher success rate at every task threshold while using only 5 sampling steps and no history observations, whereas the official model uses 50 sampling steps and 2 history observations. The tutorial numbers are averaged over 5 evaluation rounds of 50 parallel episodes each. We believe that more advanced solvers and sampling schedules could further improve the model's performance.

|Model|Sampling Steps|History Observations|Tasks ≥ 1 (%)|Tasks ≥ 2 (%)|Tasks ≥ 3 (%)|Tasks ≥ 4 (%)|Tasks ≥ 5 (%)|
|---|---|---|---|---|---|---|---|
|Official|50|2|99|94|82|68|2|
|Tutorial 1|5|0|100|99.2|94.8|77.6|4|