# **Tutorial 4: Experimenting with More Diffusion Backbones**

## 1. Introduction

Curious to try out more backbone models like EDM, Rectified Flow, or Consistency Models? In this tutorial, we’ll explore how to implement them! To keep things simple, we’ll revisit the RelayKitchen task from **Tutorial 1**. We’ll follow the familiar steps to set up the environment, dataset, model, training, and evaluation. However, this time, we’ll increase the complexity of the diffusion model by requiring it to generate **a sequence of actions** of length `Ta`, instead of just a single action. 

In addition, we’ll experiment with different backbone models to observe their differences and performance.

Let’s start by setting up the dataset and neural network, just like we did before!

In [1]:
import pytorch_lightning as L
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

from cleandiffuser.dataset.kitchen_dataset import KitchenDataset
from cleandiffuser.nn_condition import PearceObsCondition
from cleandiffuser.nn_diffusion import DiT1d


class BC_Wrapper(torch.utils.data.Dataset):
    def __init__(self, dataset: KitchenDataset, To: int, Ta: int):
        self.dataset = dataset
        self.To, self.Ta = To, Ta

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        return getattr(self.dataset, name)

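    def __getitem__(self, idx):
        batch = self.dataset[idx]
        # Actions from the current step onward (Ta steps) are the generation target;
        # the first To observations are the condition.
        return {"x0": batch["action"][self.To - 1 :], "condition_cfg": batch["state"][: self.To]}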


L.seed_everything(0, workers=True)
save_path = "./results/tutorial4_try_more_diffusion_backbones/"

To = 2
Ta = 12

dataset = KitchenDataset("./dev/kitchen", horizon=To + Ta - 1, pad_before=To - 1, pad_after=Ta - 1, abs_action=True)
obs_dim, act_dim = dataset.obs_dim, dataset.act_dim

nn_diffusion = DiT1d(
    x_dim=act_dim, emb_dim=128 * To, d_model=384, n_heads=12, depth=4, timestep_emb_type="untrainable_fourier",
    x_seq_len=Ta
)
nn_condition = PearceObsCondition(obs_dim=obs_dim, emb_dim=128, flatten=True, dropout=0)

dataloader = torch.utils.data.DataLoader(
    BC_Wrapper(dataset, To, Ta), batch_size=256, shuffle=True, num_workers=4, persistent_workers=True
)

Seed set to 0


Abs action dataset found. Loading...


In [2]:
from cleandiffuser.diffusion import ContinuousRectifiedFlow

actor1 = ContinuousRectifiedFlow(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
)

You may have noticed that this time the dataset uses special horizon and padding values. This is because we need to generate a sequence of **12 actions** (`Ta`) based on **2 frames of observations** (`To`: 1 historical and 1 current). The current step is shared by both, so each sample in the dataset is a sequence of length `To + Ta - 1 =` **13**.

$$
\begin{aligned}
a_{0},~\boldsymbol{a}_{1},~a_{2},~a_{3}, \ldots, a_{11},~a_{12} \\
o_{0},~\boldsymbol{o}_{1},~o_{2},~o_{3}, \ldots, o_{11},~o_{12}
\end{aligned}
$$

As shown in the diagram, the **bolded parts** represent the current variables. To generate sequence data, we utilize **DiT1d** as the neural network backbone for the diffusion model (previously used in **Tutorial 2** for generating state-action sequences). We also employ **PearceObsCondition** to encode the multi-frame observations.
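
To make the indexing concrete, here is a minimal sketch of how one 13-step sample splits into the observation condition and the action target. It uses plain step indices as a hypothetical stand-in for real data; only `To`, `Ta`, and the two slices come from the code above.

```python
To, Ta = 2, 12               # observation frames, action-sequence length
horizon = To + Ta - 1        # 13 steps per sample

# Hypothetical stand-in for one dataset sample: step indices instead of real data.
steps = list(range(horizon))  # 0, 1, ..., 12

obs_window = steps[:To]       # observations fed to the condition encoder
act_window = steps[To - 1:]   # actions the diffusion model must generate

print(obs_window)  # [0, 1]
print(act_window)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
```

Step 1 (the current step) appears in both windows, which is why the horizon is `To + Ta - 1` rather than `To + Ta`.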

In [3]:
import gym
import numpy as np

from cleandiffuser.diffusion import DiffusionModel
from cleandiffuser.env import kitchen


def evaluate(actor: DiffusionModel, sample_steps: int, **kwargs):
    env = gym.vector.make("kitchen-all-v0", 50, use_abs_action=True)
    normalizers = dataset.get_normalizer()
    prior = torch.zeros((50, Ta, act_dim))
    avg_sr = []

    for _ in range(3):
        obs, condition, all_done, ep_rew, t = env.reset(), None, False, 0, 0

        while t < 280:
            obs = torch.tensor(normalizers["state"].normalize(obs), device=actor.device, dtype=torch.float32)
            if condition is None:
                # First step: fill both observation frames with the initial observation.
                condition = obs.unsqueeze(1).repeat(1, 2, 1)
            else:
                condition[:, 1] = obs

            act, log = actor.sample(prior, sample_steps=sample_steps, condition_cfg=condition, w_cfg=1.0, **kwargs)
            act = normalizers["action"].unnormalize(act.cpu().numpy())

            # Execute the first 4 of the Ta generated actions, then replan.
            for i in range(4):
                obs, rew, done, _ = env.step(act[:, i])

                all_done = np.logical_or(all_done, done)
                ep_rew += rew
                t += 1

                if all_done.all():
                    break

                # The observation after the third action becomes the historical frame.
                if i == 2:
                    condition[:, 0] = torch.tensor(
                        normalizers["state"].normalize(obs), device=actor.device, dtype=torch.float32
                    )

            print(f"[t={t}] ep_rew={ep_rew}")

        success_rate = np.zeros(5)
        for i in range(5):
            success_rate[i] = (ep_rew > i).sum() / 50

        avg_sr.append(success_rate)

    print(np.mean(avg_sr, axis=0))
    print(log)

    env.close()

In [94]:
# Evaluation with 5-step sampling
evaluate(actor1, sample_steps=5)

Reading configurations for Franka
Initializing Franka sim

The following variant of `evaluate` runs a single environment and records the rollout as a video.

In [7]:
import gym
import numpy as np

from cleandiffuser.diffusion import DiffusionModel
from cleandiffuser.env import kitchen

import imageio


def evaluate_and_record_video(actor: DiffusionModel, sample_steps: int, save_path: str, **kwargs):
    env = gym.make("kitchen-all-v0", use_abs_action=True)
    normalizers = dataset.get_normalizer()
    prior = torch.zeros((Ta, act_dim))
    avg_sr = []
    frames = []

    for _ in range(3):
        obs, condition, all_done, ep_rew, t = env.reset(), None, False, 0, 0

        while t < 280:
            obs = torch.tensor(normalizers["state"].normalize(obs), device=actor.device, dtype=torch.float32)
            # print(f"obs shape: {obs.shape}")
            if condition is None:
                condition = obs.repeat(1, 2, 1)
            else:
                condition[:, 1] = obs
            # print(f"condition shape: {condition.shape}")

            act, log = actor.sample(prior, sample_steps=sample_steps, condition_cfg=condition, w_cfg=1.0, **kwargs)
            act = normalizers["action"].unnormalize(act.cpu().numpy())[0]

            for i in range(4):
                # print(f"act: {act.shape}, act[{i}]: {act[i].shape}")
                obs, rew, done, _ = env.step(act[i])

                all_done = np.logical_or(all_done, done)
                ep_rew += rew
                t += 1
                # Render and store frame
                frame = env.render(mode="rgb_array")
                if frame is not None:
                    frames.append(np.uint8(frame))

                if all_done.all():
                    break

                if i == 2:
                    condition[:, 0] = torch.tensor(
                        normalizers["state"].normalize(obs), device=actor.device, dtype=torch.float32
                    )

            print(f"[t={t}] ep_rew={ep_rew}")

        success_rate = np.zeros(5)
        for i in range(5):
            success_rate[i] = (ep_rew > i)

        avg_sr.append(success_rate)

    print(np.mean(avg_sr, axis=0))
    print(log)

    env.close()


    # Save frames to video
    if frames:
        try:
            imageio.mimwrite(save_path, frames, fps=10, codec='libx264', format='FFMPEG')
            print(f"Video saved to {save_path}")
        except Exception as e:
            print(f"MP4 failed: {e}, trying GIF fallback...")
            try:
                gif_path = save_path.replace(".mp4", ".gif")
                imageio.mimsave(gif_path, frames, fps=10)
                print(f"GIF saved to {gif_path}")
            except Exception as e2:
                print(f"Both MP4 and GIF failed: {e2}")
    else:
        print("No valid frames to save - all frames were malformed")


In [8]:
# Evaluation with 5-step sampling
evaluate_and_record_video(actor1, sample_steps=5, save_path="./results/tutorial4_try_more_diffusion_backbones/actor1_steps5.mp4")

Reading configurations for Franka
Initializing Franka sim
[t=4] ep_rew=0.0
[t=8] ep_rew=0.0
[t=12] ep_rew=0.0
[t=16] ep_rew=0.0
[t=20] ep_rew=0.0
[t=24] ep_rew=0.0
[t=28] ep_rew=0.0
[t=32] ep_rew=0.0
[t=36] ep_rew=0.0
[t=40] ep_rew=0.0
Task kettle completed!
[t=44] ep_rew=1.0
[t=48] ep_rew=1.0
[t=52] ep_rew=1.0
[t=56] ep_rew=1.0
[t=60] ep_rew=1.0
[t=64] ep_rew=1.0
[t=68] ep_rew=1.0
[t=72] ep_rew=1.0
[t=76] ep_rew=1.0
[t=80] ep_rew=1.0
[t=84] ep_rew=1.0
[t=88] ep_rew=1.0
[t=92] ep_rew=1.0
[t=96] ep_rew=1.0
[t=100] ep_rew=1.0
[t=104] ep_rew=1.0
[t=108] ep_rew=1.0
[t=112] ep_rew=1.0
[t=116] ep_rew=1.0
[t=120] ep_rew=1.0
[t=124] ep_rew=1.0
[t=128] ep_rew=1.0
[t=132] ep_rew=1.0
[t=136] ep_rew=1.0
Task bottom burner completed!
[t=140] ep_rew=2.0
[t=144] ep_rew=2.0
[t=148] ep_rew=2.0
[t=152] ep_rew=2.0
[t=156] ep_rew=2.0
[t=160] ep_rew=2.0
[t=164] ep_rew=2.0
[t=168] ep_rew=2.0
[t=172] ep_rew=2.0
[t=176] ep_rew=2.0
[t=180] ep_rew=2.0
Task light switch completed!
[t=184] ep_rew=3.

## 2. Rectified Flow

[Rectified Flow](https://arxiv.org/abs/2209.03003) features a straight ODE flow, which enables it to generate high-quality samples even with very few sampling steps. Moreover, it demonstrates excellent distillation efficiency. With each iteration of the *Reflow* procedure, the flow becomes increasingly straight, continually improving the model’s performance.
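
To see why straightness matters for few-step sampling, consider the following toy sketch. It is not the `ContinuousRectifiedFlow` sampler; the two hand-written velocity fields and all names are assumptions made for illustration. Both flows carry `x_start` to `target` at `t = 1`, but only the straight one has a constant velocity, so only that one is integrated exactly by a single Euler step.

```python
import numpy as np


def euler_sample(velocity, x_start, n_steps):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with n_steps Euler steps."""
    x = np.asarray(x_start, dtype=np.float64).copy()
    dt = 1.0 / n_steps
    for k in range(n_steps):
        x = x + dt * velocity(x, k * dt)
    return x


x_start = np.array([2.0, 1.0])    # hypothetical noise sample
target = np.array([0.5, -0.25])   # hypothetical data sample
direction = target - x_start
normal = np.array([-direction[1], direction[0]]) / np.linalg.norm(direction)


def straight_velocity(x, t):
    # Straight path x_start + t * direction: the velocity never changes.
    return direction


def curved_velocity(x, t):
    # Curved path x_start + t * direction + 0.5 * sin(pi * t) * normal.
    return direction + 0.5 * np.pi * np.cos(np.pi * t) * normal


errors = {}
for n in (1, 3, 5):
    err_straight = np.linalg.norm(euler_sample(straight_velocity, x_start, n) - target)
    err_curved = np.linalg.norm(euler_sample(curved_velocity, x_start, n) - target)
    errors[n] = (err_straight, err_curved)
    print(f"{n}-step  straight error={err_straight:.3f}  curved error={err_curved:.3f}")
```

The straight flow reaches the target with any number of steps, while the curved flow's error shrinks only as steps are added. A learned flow is never perfectly straight, which is consistent with more sampling steps still helping in practice.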

Let’s begin training!

In [35]:
callback = ModelCheckpoint(dirpath=save_path, filename="bc1-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_steps=100_000,
    deterministic=True,
    log_every_n_steps=50,
    default_root_dir=save_path,
    callbacks=[callback],
)

# Uncomment to train from scratch; the next cell loads a saved checkpoint instead.
# trainer.fit(actor1, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [5]:
# device for evaluation
device = "cuda:0"

# loading from checkpoint
actor1.load_state_dict(
    torch.load("results/tutorial4_try_more_diffusion_backbones/bc1-step=100000.ckpt", map_location=device)["state_dict"]
)
actor1.to(device).eval()

ContinuousRectifiedFlow(
  (model): ModuleDict(
    (diffusion): DiT1d(
      (map_noise): UntrainableFourierEmbedding()
      (x_proj): Linear(in_features=9, out_features=384, bias=True)
      (t_proj): Sequential(
        (0): Linear(in_features=256, out_features=384, bias=True)
        (1): SiLU()
        (2): Linear(in_features=384, out_features=384, bias=True)
      )
      (cond_proj): Sequential(
        (0): Linear(in_features=256, out_features=384, bias=True)
        (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
      (blocks): ModuleList(
        (0-3): 4 x DiTBlock(
          (sa_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
          (sa_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
          )
          (ffn_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
          (mlp): Sequential(
            (0): Linear(in_features=384, out_features=1536

In [61]:
# Evaluation with 1-step sampling
evaluate(actor1, sample_steps=1)

Reading configurations for Franka
Initializing Franka sim

In [62]:
# Evaluation with 3-step sampling
evaluate(actor1, sample_steps=3)

Reading configurations for Franka
Initializing Franka sim

In [63]:
# Evaluation with 5-step sampling
evaluate(actor1, sample_steps=5)

Reading configurations for Franka
Initializing Franka sim

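Thanks to Rectified Flow’s straightness property, the model maintains decent performance even with 1-step sampling, and performance improves as the number of sampling steps increases. Each entry in the table below is the percentage of evaluation episodes (3 rounds of 50 parallel environments) that completed at least the given number of tasks.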

|Completed Tasks|>=1|>=2|>=3|>=4|>=5|
|---|---|---|--|--|--|
|1-step|99.3|88.7|70.7|58.0|0.0|
|3-step|100.0|99.3|96.7|90.0|0.0|
|5-step|100.0|100.0|99.3|96.0|0.7|
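
For reference, a row of such a table can be computed from per-episode task counts as in the sketch below. The helper name `success_table` and the example rewards are made up for illustration and are not results from this tutorial. The `evaluate` function above uses `ep_rew > i` for `i = 0..4`, which is equivalent to `>= k` for `k = 1..5` when rewards are whole task counts.

```python
import numpy as np


def success_table(ep_rew, max_tasks=5):
    """Percentage of episodes that completed at least k tasks, for k = 1..max_tasks."""
    ep_rew = np.asarray(ep_rew)
    return [100.0 * float((ep_rew >= k).mean()) for k in range(1, max_tasks + 1)]


# Made-up task counts for 6 episodes, for illustration only.
example_rewards = [4, 4, 3, 2, 4, 1]
table = success_table(example_rewards)
print([round(v, 1) for v in table])
```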

## 3. EDM

[EDM](https://arxiv.org/abs/2206.00364) demonstrates the equivalence of noise schedule and time schedule, and introduces several optimal designs for diffusion models, including its unique preconditioning techniques. Let’s see how EDM performs on this task!
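
As a rough illustration of those design choices, the sketch below writes out the preconditioning coefficients and the noise-level schedule as given in the EDM paper. The constants (`sigma_data = 0.5`, `sigma_min = 0.002`, `sigma_max = 80`, `rho = 7`) are the paper's image-generation defaults; whether `ContinuousEDM` uses the same values is an assumption here, and the function names are hypothetical.

```python
import math

SIGMA_DATA = 0.5  # paper default; assumed here, not read from ContinuousEDM


def edm_precondition(sigma, sigma_data=SIGMA_DATA):
    """Coefficients of D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)."""
    total = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / total
    c_out = sigma * sigma_data / math.sqrt(total)
    c_in = 1.0 / math.sqrt(total)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise


def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, spaced evenly in sigma^(1/rho)."""
    if n_steps == 1:
        return [sigma_max]
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    return [(hi + i / (n_steps - 1) * (lo - hi)) ** rho for i in range(n_steps)]


sigmas = karras_sigmas(5)
for sigma in sigmas:
    c_skip, c_out, c_in, c_noise = edm_precondition(sigma)
    print(f"sigma={sigma:8.3f}  c_skip={c_skip:.4f}  c_out={c_out:.4f}")
```

At low noise `c_skip` approaches 1, so the denoiser mostly passes its input through; at high noise it approaches 0 and the network has to predict the clean sample almost entirely on its own.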

In [9]:
from cleandiffuser.diffusion import ContinuousEDM

actor2 = ContinuousEDM(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
)

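In [10]:
import os

# PyTorch's deterministic mode needs this cuBLAS workspace setting on CUDA 10.2+.
# Set it before any CUDA work is done.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

# Only needed if you enable deterministic mode manually.
torch.use_deterministic_algorithms(True)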


In [11]:

callback = ModelCheckpoint(dirpath=save_path, filename="bc2-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_steps=100_000,
    deterministic=True,
    log_every_n_steps=50,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(actor2, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX 5000 Ada Generation Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/dynias/CleanDiffuser-lightning/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /home/dynias/CleanDiffuser-lightning/notebooks/results/tutorial4_try_more_diffusion_backbones exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params | Mode 
----------------------------------------------------
0 | model        | ModuleDict | 11.3 M | train
1 | model_ema    | ModuleDict | 11.3 M | eval 


`Trainer.fit` stopped: `max_steps=100000` reached.


In [12]:
# device for evaluation
device = "cuda:0"

# loading from checkpoint
actor2.load_state_dict(
    torch.load("results/tutorial4_try_more_diffusion_backbones/bc2-step=100000.ckpt", map_location=device)["state_dict"]
)
actor2.to(device).eval()

ContinuousEDM(
  (model): ModuleDict(
    (diffusion): DiT1d(
      (map_noise): UntrainableFourierEmbedding()
      (x_proj): Linear(in_features=9, out_features=384, bias=True)
      (t_proj): Sequential(
        (0): Linear(in_features=256, out_features=384, bias=True)
        (1): SiLU()
        (2): Linear(in_features=384, out_features=384, bias=True)
      )
      (cond_proj): Sequential(
        (0): Linear(in_features=256, out_features=384, bias=True)
        (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
      (blocks): ModuleList(
        (0-3): 4 x DiTBlock(
          (sa_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
          (sa_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
          )
          (ffn_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
          (mlp): Sequential(
            (0): Linear(in_features=384, out_features=1536, bias=Tru

In [13]:
# Evaluation with 1-step sampling
evaluate(actor2, sample_steps=1)

Reading configurations for Franka
Initializing Franka sim

In [14]:
# Evaluation with 3-step sampling
evaluate(actor2, sample_steps=3)

Reading configurations for Franka
Initializing Franka sim

In [15]:
# Evaluation with 5-step sampling
evaluate(actor2, sample_steps=5)

Reading configurations for Franka
Initializing Franka sim

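The table below shows EDM’s performance. With 1-step sampling it falls far behind Rectified Flow (only 1.3% of episodes complete at least two tasks, compared with 88.7%), but increasing the number of sampling steps rapidly closes the gap: with 5 steps, 99.3% of episodes complete at least four tasks.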

|Completed Tasks|>=1|>=2|>=3|>=4|>=5|
|---|---|---|--|--|--|
|1-step|86.0|1.3|0.0|0.0|0.0|
|3-step|100.0|98.7|97.3|90.7|1.3|
|5-step|100.0|100.0|100.0|99.3|0.0|

## 4. Consistency Models

[Consistency Models](https://arxiv.org/abs/2303.01469) are a new class of generative models designed to predict $t=0$ samples from any point on the same ODE flow. As a result, Consistency Models naturally support one-step generation, and they can further improve sample quality through iterative refinement. Consistency Models offer two training methods: Consistency Training, which learns from scratch, and Consistency Distillation, which distills knowledge from a pre-trained EDM model.
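
The two ideas in this paragraph, mapping any point on the flow to a $t=0$ sample and refining it iteratively, can be sketched as follows. The parameterization and the multistep loop follow the Consistency Models paper; the constants, the time grid, and `toy_network` are placeholders chosen for illustration and are not taken from `ContinuousConsistencyModel`.

```python
import math

import numpy as np

SIGMA_DATA = 0.5  # paper default, assumed here
EPS = 0.002       # smallest time; f(x, EPS) must return x unchanged


def consistency_fn(network, x, t):
    """f(x, t) = c_skip(t) * x + c_out(t) * network(x, t), with f(x, EPS) = x."""
    c_skip = SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)
    c_out = SIGMA_DATA * (t - EPS) / math.sqrt(SIGMA_DATA**2 + t**2)
    return c_skip * x + c_out * network(x, t)


def multistep_sample(network, x_T, times, rng):
    """times is decreasing and starts at T; one entry gives one-step generation."""
    x = consistency_fn(network, x_T, times[0])
    for tau in times[1:]:
        # Re-noise the current estimate to level tau, then map it back to t = EPS.
        noisy = x + math.sqrt(tau**2 - EPS**2) * rng.standard_normal(x.shape)
        x = consistency_fn(network, noisy, tau)
    return x


def toy_network(x, t):
    # Hypothetical placeholder for a trained network.
    return np.zeros_like(x)


rng = np.random.default_rng(0)
T = 80.0
x_T = T * rng.standard_normal((12, 9))  # (Ta, act_dim) as in this tutorial
one_step = multistep_sample(toy_network, x_T, [T], rng)
three_step = multistep_sample(toy_network, x_T, [T, 10.0, 1.0], rng)
print(one_step.shape, three_step.shape)
```

The boundary condition `f(x, EPS) = x` is built into `c_skip` and `c_out`, so it holds for any network, trained or not.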

Let’s evaluate Consistency Distillation on this task! Since we’ve already trained an EDM model, we can compare the performance.

In [16]:
from cleandiffuser.diffusion import ContinuousConsistencyModel, ContinuousEDM

actor2 = ContinuousEDM(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
)
actor2.load_state_dict(torch.load(save_path + "bc2-step=100000.ckpt", map_location="cpu")["state_dict"])
actor2.eval().cpu()

actor3 = ContinuousConsistencyModel(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
    edm=actor2,
    distillation_N=18,
)

In [17]:

callback = ModelCheckpoint(dirpath=save_path, filename="bc4-{step}", every_n_train_steps=20_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_steps=100_000,
    deterministic=True,
    log_every_n_steps=50,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(actor3, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type          | Params | Mode 
----------------------------------------------------
0 | model     | ModuleDict    | 11.3 M | train
1 | model_ema | ModuleDict    | 11.3 M | eval 
2 | edm       | ContinuousEDM | 22.6 M | eval 
----------------------------------------------------
11.3 M    Trainable params
22.6 M    Non-trainable params
34.0 M    Total params
135.817   Total estimated model params size (MB)
79        Modules in train mode
160       Modules in eval mode


`Trainer.fit` stopped: `max_steps=100000` reached.


In [18]:
# device for evaluation
device = "cuda:0"

# loading from checkpoint
actor3.load_state_dict(
    torch.load("results/tutorial4_try_more_diffusion_backbones/bc4-step=100000.ckpt", map_location=device)["state_dict"]
)
actor3.to(device).eval()

ContinuousConsistencyModel(
  (model): ModuleDict(
    (diffusion): DiT1d(
      (map_noise): UntrainableFourierEmbedding()
      (x_proj): Linear(in_features=9, out_features=384, bias=True)
      (t_proj): Sequential(
        (0): Linear(in_features=256, out_features=384, bias=True)
        (1): SiLU()
        (2): Linear(in_features=384, out_features=384, bias=True)
      )
      (cond_proj): Sequential(
        (0): Linear(in_features=256, out_features=384, bias=True)
        (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
      (blocks): ModuleList(
        (0-3): 4 x DiTBlock(
          (sa_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
          (sa_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
          )
          (ffn_norm): LayerNorm((384,), eps=1e-06, elementwise_affine=False)
          (mlp): Sequential(
            (0): Linear(in_features=384, out_features=1

In [19]:
# Evaluation with 1-step sampling
evaluate(actor3, sample_steps=1)

Reading configurations for Franka
Initializing Franka sim

Consistency Models perform significantly better in single-step generation than their EDM teacher model: 59.3% of episodes complete at least four tasks, compared with 0.0% for 1-step EDM. When inference speed is critical for the task, Consistency Models offer an excellent solution.

|Completed Tasks|>=1|>=2|>=3|>=4|>=5|
|---|---|---|--|--|--|
|EDM-1-step|86.0|1.3|0.0|0.0|0.0|
|CM-1-step|100.0|92.0|78.0|59.3|0.7|