# **Tutorial 4. Try More Diffusion Backbones**

## 1. Introduction
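Want to try more backbone models, such as EDM, Rectified Flow, and Consistency Models? Let's do it in this tutorial! For simplicity, we return to the Relay Kitchen task from Tutorial 1 and set up the environment, dataset, model, training, and evaluation in the familiar way. The difference is that this time we make generation harder: instead of a single action, the diffusion model must generate a sequence of `Ta` actions. We will also try several backbone models and compare how they differ.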

First, let's set up the dataset and the neural networks as before!

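```python
import pytorch_lightning as L
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

from cleandiffuser.dataset.kitchen_dataset import KitchenDataset
from cleandiffuser.nn_condition import PearceObsCondition
from cleandiffuser.nn_diffusion import DiT1d


class BC_Wrapper(torch.utils.data.Dataset):
    """Turns each KitchenDataset window into (target action sequence, observation history)."""

    def __init__(self, dataset: KitchenDataset, To: int, Ta: int):
        self.dataset = dataset
        self.To, self.Ta = To, Ta

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        # Only called when normal lookup fails: forward e.g. obs_dim / act_dim to the wrapped dataset.
        # Guard against infinite recursion before `self.dataset` exists (e.g. during unpickling).
        if name == "dataset":
            raise AttributeError(name)
        return getattr(self.dataset, name)

    def __getitem__(self, idx):
        batch = self.dataset[idx]
        # x0: the Ta actions from the current step on; condition_cfg: the To observations up to it
        return {"x0": batch["action"][self.To - 1 :], "condition_cfg": batch["state"][: self.To]}


L.seed_everything(0, workers=True)
save_path = "../results/tutorial4_try_more_diffusion_backbones/"

To = 2   # observation frames: 1 history + 1 current
Ta = 12  # length of the generated action sequence

# Each window has To + Ta - 1 = 13 steps; the padding lets every timestep act as the current step
dataset = KitchenDataset("../dev/kitchen", horizon=To + Ta - 1, pad_before=To - 1, pad_after=Ta - 1, abs_action=True)
obs_dim, act_dim = dataset.obs_dim, dataset.act_dim

# The condition network yields 128 dims per observation frame (flattened), hence emb_dim=128 * To
nn_diffusion = DiT1d(
    x_dim=act_dim, emb_dim=128 * To, d_model=384, n_heads=12, depth=4, timestep_emb_type="untrainable_fourier"
)
nn_condition = PearceObsCondition(obs_dim=obs_dim, emb_dim=128, flatten=True, dropout=0)

dataloader = torch.utils.data.DataLoader(
    BC_Wrapper(dataset, To, Ta), batch_size=256, shuffle=True, num_workers=4, persistent_workers=True
)
```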

```text
Seed set to 0
Abs action dataset found. Loading...
```

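You may have noticed that the dataset uses a particular horizon and padding this time. This is because we generate an action sequence of length `Ta = 12` from `To = 2` frames of observations (1 history frame and 1 current frame), so each sample in the dataset is a sequence of length `To + Ta - 1 = 13`: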

$$
\begin{aligned}
&a_{0},~\boldsymbol{a}_{1},~a_{2},~a_{3},~\ldots,~a_{12} \\
&o_{0},~\boldsymbol{o}_{1},~o_{2},~o_{3},~\ldots,~o_{12}
\end{aligned}
$$
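In the sequence above, the bold entries are the current variables: the model is conditioned on $o_0, o_1$ and generates $a_1, \ldots, a_{12}$, which is exactly what `BC_Wrapper` slices out with `action[To - 1:]` and `state[:To]`. To generate sequential data, we use `DiT1d` as the neural network backbone of the diffusion model (we used it in Tutorial 2 to generate state-action sequences), and `PearceObsCondition` to encode the multi-frame observations. Because `flatten=True` concatenates the 128-dimensional embeddings of the `To` frames, `DiT1d` is given `emb_dim=128 * To`.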
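To make the windowing concrete, here is a hypothetical NumPy re-implementation of how one padded window could be cut and then split the way `BC_Wrapper` does. It is a sketch under one assumption: padding repeats the first/last frame of the episode, which may not match `KitchenDataset`'s internals. `make_window` is an illustrative name, not a CleanDiffuser function.

```python
import numpy as np


def make_window(actions: np.ndarray, states: np.ndarray, start: int, To: int, Ta: int) -> dict:
    """Hypothetical sketch: cut one padded window and split it the way BC_Wrapper does."""
    horizon = To + Ta - 1
    pad_before, pad_after = To - 1, Ta - 1

    def pad(x: np.ndarray) -> np.ndarray:
        # Assumption: pad by repeating the first and last frames of the episode
        head = np.repeat(x[:1], pad_before, axis=0)
        tail = np.repeat(x[-1:], pad_after, axis=0)
        return np.concatenate([head, x, tail], axis=0)

    a_win = pad(actions)[start : start + horizon]
    s_win = pad(states)[start : start + horizon]
    # The current step sits at index To - 1: keep Ta actions from there on and the To states up to it
    return {"x0": a_win[To - 1 :], "condition_cfg": s_win[:To]}


T = 20  # toy episode length
actions = np.arange(T, dtype=np.float32)[:, None]  # toy action at step k is k
states = 100.0 + np.arange(T, dtype=np.float32)[:, None]  # toy state at step k is 100 + k

first = make_window(actions, states, start=0, To=2, Ta=12)
last = make_window(actions, states, start=T - 1, To=2, Ta=12)
print(first["x0"].ravel(), first["condition_cfg"].ravel())
print(last["x0"].ravel(), last["condition_cfg"].ravel())
```

With `pad_before = To - 1` and `pad_after = Ta - 1`, an episode of length `T` yields exactly `T` windows, so every timestep, including the first and the last, appears once as the current step.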

```python
import gym
import numpy as np

from cleandiffuser.diffusion import DiffusionModel
from cleandiffuser.env import kitchen  # imported for its side effect (presumably registers the kitchen envs)


def evaluate(actor: DiffusionModel, sample_steps: int, **kwargs):
    # 50 kitchen environments run in parallel
    env = gym.vector.make("kitchen-all-v0", 50, use_abs_action=True)
    normalizers = dataset.get_normalizer()
    prior = torch.zeros((50, Ta, act_dim))
    avg_sr = []

    for _ in range(3):  # 3 evaluation rounds
        obs, condition, all_done, ep_rew, t = env.reset(), None, False, 0, 0

        while t < 280:
            obs = torch.tensor(normalizers["state"].normalize(obs), device=actor.device, dtype=torch.float32)
            if condition is None:
                # First call: no history yet, so repeat the reset observation for both (To = 2) frames
                condition = obs.unsqueeze(1).repeat(1, 2, 1)
            else:
                condition[:, 1] = obs  # current frame

            # Generate a sequence of Ta actions for every environment
            act, log = actor.sample(prior, sample_steps=sample_steps, condition_cfg=condition, w_cfg=1.0, **kwargs)
            act = normalizers["action"].unnormalize(act.cpu().numpy())

            # Execute only the first 4 of the Ta actions, then replan
            for i in range(4):
                obs, rew, done, _ = env.step(act[:, i])

                all_done = np.logical_or(all_done, done)
                ep_rew += rew
                t += 1

                if all_done.all():
                    break

                if i == 2:
                    # The observation after the 3rd action becomes the history frame of the next call
                    condition[:, 0] = torch.tensor(
                        normalizers["state"].normalize(obs), device=actor.device, dtype=torch.float32
                    )

            print(f"[t={t}] ep_rew={ep_rew}")

        # success_rate[i]: fraction of environments that completed at least i + 1 tasks
        success_rate = np.zeros(5)
        for i in range(5):
            success_rate[i] = (ep_rew > i).sum() / 50

        avg_sr.append(success_rate)

    print(np.mean(avg_sr, axis=0))

    env.close()
```
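The condition update in `evaluate` is easy to misread, so the following sketch replays only its bookkeeping, with no environment or model. An observation is represented by the number of environment steps taken before it was returned (0 is the reset observation); `replanning_history` is a hypothetical helper written for this illustration.

```python
from typing import List, Tuple


def replanning_history(total_steps: int, exec_steps: int = 4) -> List[Tuple[int, int]]:
    """Return the (history, current) observation indices seen at each planning call."""
    pairs = []
    history = None
    obs = 0  # index of the observation returned by reset
    t = 0
    while t < total_steps:
        if history is None:
            history = obs  # first call: duplicate the reset observation
        pairs.append((history, obs))  # condition[:, 0], condition[:, 1]
        for i in range(exec_steps):
            t += 1
            obs = t  # observation returned by this env step
            if i == exec_steps - 2:
                history = obs  # becomes the history frame of the next call
    return pairs


print(replanning_history(total_steps=12))  # [(0, 0), (3, 4), (7, 8)]
```

Each call after the first sees a history frame exactly one step before the current frame, matching the consecutive $o_0, o_1$ pairs in the training windows; only the very first call has to duplicate the reset observation.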

## 2. Rectified Flow

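[Rectified Flow](https://arxiv.org/abs/2209.03003) is characterized by its straight ODE flow. This property lets it generate high-quality samples with very few sampling steps and also makes it easier to distill. Moreover, repeatedly applying the Reflow procedure makes the flow straighter and straighter, further improving the model. Let's start training!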
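A minimal NumPy sketch of the Rectified Flow training target and its Euler sampler, assuming the convention that data sits at $t=0$ and Gaussian noise at $t=1$ (`ContinuousRectifiedFlow` may use a different time convention or parameterization). The toy data is a single point, for which the ideal velocity field is known in closed form, so it shows why a perfectly straight flow can be integrated exactly even with one step.

```python
import numpy as np


def rf_interpolate(x0: np.ndarray, x1: np.ndarray, t: float) -> np.ndarray:
    # Straight-line path between data x0 (t = 0) and noise x1 (t = 1)
    return (1.0 - t) * x0 + t * x1


def rf_loss(v_pred: np.ndarray, x0: np.ndarray, x1: np.ndarray) -> float:
    # Regress the predicted velocity onto the constant velocity dx_t/dt = x1 - x0
    return float(np.mean((v_pred - (x1 - x0)) ** 2))


def euler_sample(velocity_fn, x1: np.ndarray, steps: int) -> np.ndarray:
    # Integrate dx/dt = v(x, t) from t = 1 (noise) back to t = 0 (data) with Euler steps
    x = x1.copy()
    ts = np.linspace(1.0, 0.0, steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        x = x + (t_next - t_cur) * velocity_fn(x, t_cur)
    return x


# Toy data: every sample equals the point c, so the ideal velocity is known exactly
c = np.array([0.5, -0.25])


def ideal_velocity(x: np.ndarray, t: float) -> np.ndarray:
    # On the path x_t = (1 - t) * c + t * x1 we have x1 - c = (x_t - c) / t
    return (x - c) / t


x1 = np.random.default_rng(0).standard_normal(2)  # one noise sample
for steps in (1, 3, 5):
    # Every step count lands on c (up to floating-point error) because the path is straight
    print(steps, euler_sample(ideal_velocity, x1, steps))
```

Real action distributions are not a single point, so the learned flow is only approximately straight and few-step sampling incurs some error; Reflow targets exactly this error by straightening the flow further.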

```python
from cleandiffuser.diffusion import ContinuousRectifiedFlow

actor1 = ContinuousRectifiedFlow(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),  # bounds of the normalized action sequence
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
)

# Save a checkpoint every 10k steps, e.g. bc1-step=100000.ckpt
callback = ModelCheckpoint(dirpath=save_path, filename="bc1-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0, 1, 2, 3],  # 4 GPUs
    max_steps=100_000,
    deterministic=True,
    log_every_n_steps=50,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(actor1, dataloader)
```

```text
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4

Epoch 740:  74%|███████▍  | 100/135 [00:07<00:02, 12.60it/s, v_num=10, diffusion_loss=0.00685]
`Trainer.fit` stopped: `max_steps=100000` reached.
```


```python
# Evaluation with 1-step sampling
evaluate(actor1, sample_steps=1)
```


```python
# Evaluation with 3-step sampling
evaluate(actor1, sample_steps=3)
```


```python
# Evaluation with 5-step sampling
evaluate(actor1, sample_steps=5)
```


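Success rates (%) of Rectified Flow with 1, 3, and 5 sampling steps. Column "≥k" is the share of environments that completed at least k tasks, averaged over the 3 evaluation rounds of 50 parallel environments in `evaluate`.

| Sampling steps | ≥1 task | ≥2 tasks | ≥3 tasks | ≥4 tasks | ≥5 tasks |
|---|---|---|---|---|---|
| 1 | 99.3 | 88.7 | 70.7 | 58.0 | 0.0 |
| 3 | 100.0 | 99.3 | 96.7 | 90.0 | 0.0 |
| 5 | 100.0 | 100.0 | 99.3 | 96.0 | 0.7 |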
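The table columns follow the `success_rate` loop at the end of `evaluate`. Here is a vectorized NumPy sketch of the same computation, assuming (as the table header does) that each completed task contributes a reward of 1. `success_table` is a hypothetical helper, and the episode returns below are made-up toy numbers, not results.

```python
import numpy as np


def success_table(ep_rews_per_round, max_tasks: int = 5) -> np.ndarray:
    """Percent of environments with at least k completed tasks (k = 1..max_tasks), averaged over rounds."""
    thresholds = np.arange(max_tasks)  # "ep_rew > i" means at least i + 1 tasks
    per_round = [(np.asarray(r)[:, None] > thresholds).mean(axis=0) for r in ep_rews_per_round]
    return 100.0 * np.mean(per_round, axis=0)


# Hypothetical episode returns: 3 rounds of 4 environments (toy numbers only)
rounds = [[4, 3, 4, 2], [4, 4, 1, 3], [5, 4, 3, 4]]
print(success_table(rounds))
```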

## 3. EDM

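[EDM](https://arxiv.org/abs/2206.00364) shows that choosing a noise schedule is equivalent to choosing a time schedule (it simply sets $\sigma(t) = t$), and proposes a set of improved design choices for diffusion models, such as its distinctive preconditioning. Let's see how it performs on this task!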
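For reference, here is a NumPy sketch of two formulas from the EDM paper: the preconditioning coefficients that wrap a raw network $F$ into a denoiser $D(x, \sigma) = c_\text{skip}(\sigma)\,x + c_\text{out}(\sigma)\,F(c_\text{in}(\sigma)\,x,\ c_\text{noise}(\sigma))$, and the $\rho$-spaced noise levels used for sampling. The constants ($\sigma_\text{data}=0.5$, $\sigma_\text{min}=0.002$, $\sigma_\text{max}=80$, $\rho=7$) are the paper's defaults and are assumptions here; `ContinuousEDM` may use different values.

```python
import numpy as np

SIGMA_DATA = 0.5  # EDM-paper default; assumed, not read from ContinuousEDM


def edm_preconditioning(sigma: np.ndarray, sigma_data: float = SIGMA_DATA):
    """Coefficients of D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = sigma * sigma_data / np.sqrt(denom)
    c_in = 1.0 / np.sqrt(denom)
    c_noise = 0.25 * np.log(sigma)
    return c_skip, c_out, c_in, c_noise


def karras_sigmas(n: int, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0) -> np.ndarray:
    """n noise levels from sigma_max down to sigma_min, evenly spaced in sigma ** (1 / rho)."""
    ramp = np.linspace(0.0, 1.0, n)
    max_inv_rho, min_inv_rho = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


sigmas = karras_sigmas(18)
c_skip, c_out, c_in, c_noise = edm_preconditioning(np.array([1e-6, 0.5, 80.0]))
print(sigmas)
print(c_skip, c_out)
```

At small $\sigma$, $c_\text{skip} \approx 1$ and $c_\text{out} \approx 0$, so the denoiser passes its input through almost unchanged; at large $\sigma$ the output comes almost entirely from the network. The EDM sampler also appends a final $\sigma = 0$ after these levels.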

```python
from cleandiffuser.diffusion import ContinuousEDM

actor2 = ContinuousEDM(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
)

callback = ModelCheckpoint(dirpath=save_path, filename="bc2-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0, 1, 2, 3],
    max_steps=100_000,
    deterministic=True,
    log_every_n_steps=50,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(actor2, dataloader)
```

```text
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4

Epoch 740:  74%|███████▍  | 100/135 [00:08<00:02, 11.89it/s, v_num=12, diffusion_loss=0.00803]
`Trainer.fit` stopped: `max_steps=100000` reached.
```


```python
# Evaluation with 1-step sampling
evaluate(actor2, sample_steps=1)
```


```python
# Evaluation with 3-step sampling
evaluate(actor2, sample_steps=3)
```


```python
# Evaluation with 5-step sampling
evaluate(actor2, sample_steps=5)
```


## 4. Consistency Models

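[Consistency Models](https://arxiv.org/abs/2303.01469) are a new family of generative models. A consistency model is trained to map a sample at any time on a probability-flow ODE trajectory to the sample at $t=0$ on the same trajectory, so it supports one-step generation natively, and it can also improve sample quality by iterating (multistep sampling). There are two ways to train Consistency Models: Consistency Training, which learns from scratch, and Consistency Distillation, which distills knowledge from a pretrained EDM. Let's see how they perform on this task, starting with Consistency Training!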
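A NumPy sketch of the consistency-function parameterization from the Consistency Models paper and of its multistep sampler. The skip/out coefficients enforce the boundary condition $f(x, \sigma_\text{min}) = x$ by construction. The constants are the paper's defaults, and `toy_F` is a hypothetical stand-in for the trained network; neither is taken from `ContinuousConsistencyModel`.

```python
import numpy as np

SIGMA_MIN, SIGMA_DATA = 0.002, 0.5  # CM-paper defaults; assumed here


def consistency_fn(F, x: np.ndarray, sigma: float) -> np.ndarray:
    """f(x, sigma) = c_skip * x + c_out * F(x, sigma), built so that f(x, SIGMA_MIN) = x."""
    c_skip = SIGMA_DATA**2 / ((sigma - SIGMA_MIN) ** 2 + SIGMA_DATA**2)
    c_out = SIGMA_DATA * (sigma - SIGMA_MIN) / np.sqrt(SIGMA_DATA**2 + sigma**2)
    return c_skip * x + c_out * F(x, sigma)


def multistep_sample(F, shape, sigmas, rng) -> np.ndarray:
    """Jump from pure noise at sigmas[0]; each later level re-noises the estimate and jumps again."""
    x = consistency_fn(F, sigmas[0] * rng.standard_normal(shape), sigmas[0])
    for sigma in sigmas[1:]:
        z = rng.standard_normal(shape)
        x = consistency_fn(F, x + np.sqrt(sigma**2 - SIGMA_MIN**2) * z, sigma)
    return x


def toy_F(x: np.ndarray, sigma: float) -> np.ndarray:
    # Hypothetical stand-in for the trained network (nn_diffusion + nn_condition)
    return np.tanh(x)


one_step = multistep_sample(toy_F, (12, 2), sigmas=[80.0], rng=np.random.default_rng(0))
three_step = multistep_sample(toy_F, (12, 2), sigmas=[80.0, 10.0, 1.0], rng=np.random.default_rng(0))
```

With a single noise level the sampler is one-step generation; every additional level re-noises the current estimate and maps it back, which is how sample quality can be improved iteratively.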

```python
from cleandiffuser.diffusion import ContinuousConsistencyModel

actor3 = ContinuousConsistencyModel(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
    curriculum_cycle=100_000,
)

callback = ModelCheckpoint(dirpath=save_path, filename="bc3-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0, 1, 2, 3],
    max_steps=100_000,
    deterministic=True,
    log_every_n_steps=50,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(actor3, dataloader)
```

```text
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4

Epoch 15:  87%|████████▋ | 55/63 [00:04<00:00, 11.03it/s, v_num=4, unweighted_loss=0.0494, diffusion_loss=0.193]
`Trainer.fit` stopped: `max_steps=1000` reached.
```
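Consistency Training has no teacher, so the Consistency Models paper gradually increases the number of discretization steps $N$ during training. Below is a sketch of that curriculum $N(k)$ with the paper's $s_0 = 2$ and $s_1 = 150$. Treating `curriculum_cycle=100_000` as the total number of curriculum steps $K$ is an assumption about CleanDiffuser, not something this tutorial states; `ct_num_steps` is a hypothetical name.

```python
import math


def ct_num_steps(k: int, K: int, s0: int = 2, s1: int = 150) -> int:
    """CM-paper curriculum N(k): number of discretization steps at training step k of K."""
    return math.ceil(math.sqrt(k / K * ((s1 + 1) ** 2 - s0**2) + s0**2) - 1) + 1


K = 100_000  # assumed to correspond to curriculum_cycle
print([ct_num_steps(k, K) for k in (0, 25_000, 50_000, 100_000)])
```

Early in training the model only has to be consistent across a few coarse noise levels, which is an easier target; finer discretizations arrive as training progresses.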


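```python
# Evaluation with 1-step sampling (consistency models support one-step generation)
evaluate(actor3, sample_steps=1)
```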


Now let's try Consistency Distillation! Conveniently, we have already trained an EDM in Section 3, which can serve as the teacher.
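Consistency Distillation replaces the unknown ODE trajectory with one solver step of the teacher. Here is a NumPy sketch of that step for the EDM probability-flow ODE $\mathrm{d}x/\mathrm{d}\sigma = (x - D(x, \sigma))/\sigma$ and of the resulting squared-L2 loss between adjacent noise levels. Reading `distillation_N=18` as the number of noise levels $N$ is an assumption; all function names are hypothetical, and in the paper `f_target` is an EMA copy of the student.

```python
import numpy as np


def teacher_euler_step(D, x_next: np.ndarray, sigma_next: float, sigma_cur: float) -> np.ndarray:
    """One Euler step of dx/dsigma = (x - D(x, sigma)) / sigma from sigma_next down to sigma_cur."""
    d = (x_next - D(x_next, sigma_next)) / sigma_next
    return x_next + (sigma_cur - sigma_next) * d


def cd_loss(f_student, f_target, D, x0: np.ndarray, sigmas: np.ndarray, n: int, rng) -> float:
    """Squared-L2 consistency-distillation loss between adjacent levels sigmas[n] < sigmas[n + 1]."""
    x_next = x0 + sigmas[n + 1] * rng.standard_normal(x0.shape)
    x_cur = teacher_euler_step(D, x_next, sigmas[n + 1], sigmas[n])
    return float(np.mean((f_student(x_next, sigmas[n + 1]) - f_target(x_cur, sigmas[n])) ** 2))


c = np.array([0.3, -0.7])  # toy data: a single point


def point_denoiser(x: np.ndarray, sigma: float) -> np.ndarray:
    # Ideal denoiser (and ideal consistency function) for the toy data: always predicts c
    return np.broadcast_to(c, x.shape)


# 18 noise levels in ascending order, rho = 7 spacing as in EDM
sigmas = (0.002 ** (1 / 7) + np.linspace(0.0, 1.0, 18) * (80.0 ** (1 / 7) - 0.002 ** (1 / 7))) ** 7
loss = cd_loss(point_denoiser, point_denoiser, point_denoiser, c, sigmas, n=5, rng=np.random.default_rng(0))
print(loss)  # 0.0: the ideal consistency function is already consistent
```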

```python
from cleandiffuser.diffusion import ContinuousConsistencyModel, ContinuousEDM

# Rebuild the EDM from Section 3 and load its final checkpoint as the distillation teacher
actor2 = ContinuousEDM(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
)
actor2.load_state_dict(torch.load(save_path + "bc2-step=100000.ckpt", map_location="cpu")["state_dict"])
actor2.eval().cpu()  # teacher in eval mode on CPU

# actor2 and actor4 are built from the same nn_diffusion / nn_condition objects; if the wrappers
# do not copy them, the student also starts from the loaded EDM weights.
actor4 = ContinuousConsistencyModel(
    nn_diffusion,
    nn_condition,
    ema_rate=0.999,
    x_max=torch.full((Ta, act_dim), 1.0),
    x_min=torch.full((Ta, act_dim), -1.0),
    optimizer_params={"lr": 5e-4},
    distillation_N=18,
)

callback = ModelCheckpoint(dirpath=save_path, filename="bc4-{step}", every_n_train_steps=10_000)

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],  # single GPU
    max_steps=100_000,
    deterministic=True,
    log_every_n_steps=50,
    default_root_dir=save_path,
    callbacks=[callback],
)

trainer.fit(actor4, dataloader)
```

```text
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/dzb/miniforge3/envs/cleandiffuser/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/dzb/github/CleanDiffuser/results/tutorial4_try_more_diffusion_backbones exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5]

  | Name         | Type       | Params | Mode 
----------------------------------------------------
0 | model        | ModuleDict | 11.2 M | train
1 | model_ema    | ModuleDict | 11.2 M | eval 
  | other params | n/a      

Epoch 3:  97%|█████████▋| 244/252 [00:16<00:00, 15.07it/s, v_num=5, unweighted_loss=0.0494, diffusion_loss=0.200]
`Trainer.fit` stopped: `max_steps=1000` reached.
```
