# Scalable Diffusion Models with Transformers (DiT)

This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. DiT outperforms all prior diffusion models on the class-conditional ImageNet 256×256 and 512×512 benchmarks.

[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](http://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)
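To make "transformers in place of U-Nets" concrete: the transformer does not see the latent as a 2-D feature map but as a sequence of patch tokens. The sketch below groups a `C×H×W` latent into non-overlapping `p×p` patches and flattens each one. It is a plain-Python illustration under stated assumptions, not the repo's code: the helper name `patchify` and the element order inside each token are hypothetical, and DiT's own patch embedding additionally applies a learned linear projection to each patch (omitted here).

```python
from typing import List

Latent = List[List[List[float]]]  # indexed as [channel][row][col]


def patchify(latent: Latent, p: int) -> List[List[float]]:
    """Split a C x H x W latent into (H/p)*(W/p) tokens of length p*p*C.

    Hypothetical helper for illustration: tokens are emitted row-major over the
    patch grid, and each token lists its values channel-first.
    """
    channels = len(latent)
    height, width = len(latent[0]), len(latent[0][0])
    if height % p or width % p:
        raise ValueError("H and W must be divisible by the patch size p")

    tokens = []
    for top in range(0, height, p):       # walk the patch grid row by row
        for left in range(0, width, p):
            token = [
                latent[c][top + i][left + j]
                for c in range(channels)
                for i in range(p)
                for j in range(p)
            ]
            tokens.append(token)
    return tokens


# A 4-channel 4x4 latent (the sampling cell also uses 4 latent channels), patch size 2.
latent = [[[float(c * 100 + r * 10 + col) for col in range(4)] for r in range(4)]
          for c in range(4)]
tokens = patchify(latent, p=2)
print(len(tokens), len(tokens[0]))  # 4 tokens, each 2*2*4 = 16 values
```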

# 1. Setup

We recommend using GPUs (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the DiT GitHub repo, install its dependencies (`diffusers`, `timm`), and set up PyTorch. You only have to run this once.

In [2]:
!git clone https://github.com/facebookresearch/DiT.git
import os
os.chdir('DiT')  # run from inside the cloned repo so its modules can be imported
os.environ['PYTHONPATH'] = '/env/python:/content/DiT'
!pip install diffusers timm --upgrade
# DiT imports:
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")


GPU not found. Using CPU instead.


## Download DiT-XL/2 Models

You can choose between a 512x512 model and a 256x256 model. You can also swap out the VAE that decodes latents back into images (`sd-vae-ft-mse` or `sd-vae-ft-ema`).
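As a quick sanity check on what the two resolutions mean for the transformer: the notebook computes `latent_size = int(image_size) // 8`, and the "/2" in DiT-XL/2 is the patch size, so the sequence length is `(latent_size // 2) ** 2`. A small sketch; the function name `dit_sequence_shape` is hypothetical.

```python
def dit_sequence_shape(image_size: int, vae_downsample: int = 8, patch_size: int = 2):
    """Return (latent_size, tokens_per_side, num_tokens) for a square image.

    Assumes the VAE downsamples by 8 per side (latent_size = image_size // 8)
    and DiT-XL/2's patch size of 2.
    """
    latent_size = image_size // vae_downsample
    tokens_per_side = latent_size // patch_size
    return latent_size, tokens_per_side, tokens_per_side ** 2


for size in (256, 512):
    latent, side, n_tokens = dit_sequence_shape(size)
    print(f"{size}x{size} image -> {latent}x{latent} latent -> {side}x{side} = {n_tokens} tokens")
```

The 512x512 model processes four times as many tokens per forward pass as the 256x256 model, so expect noticeably slower sampling.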

In [5]:
image_size = "512" #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-mse" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
latent_size = int(image_size) // 8
# Load model:
model = DiT_XL_2(input_size=latent_size).to(device)
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
model.load_state_dict(state_dict)
model.eval() # important!
vae = AutoencoderKL.from_pretrained(vae_model).to(device)  # required by vae.decode() in the sampling cell

def calculate_parameters(model):
    """Print the model's parameter count and a per-layer breakdown."""
    total_params = sum(p.numel() for p in model.parameters())

    # Running totals for the main DiT components, keyed by top-level submodule
    components = {
        "Embedding Layers": 0,      # x_embedder (patch embedding) + pos_embed
        "Transformer Blocks": 0,    # blocks.* (attention, MLP and per-block adaLN)
        "Final Layers": 0,          # final_layer
        "Conditioning Modules": 0,  # t_embedder (timestep) + y_embedder (class label)
        "Other": 0,                 # anything not matched above
    }

    # Per-layer details
    layer_details = []

    # Walk every module, counting only the parameters it owns directly
    # (recurse=False), so each parameter is counted exactly once.
    for name, module in model.named_modules():
        params = sum(p.numel() for p in module.parameters(recurse=False))
        if params == 0:
            continue

        # Classify by top-level submodule name ("" is the root DiT module,
        # which owns pos_embed directly)
        top = name.split(".")[0]
        if top in ("", "x_embedder"):
            components["Embedding Layers"] += params
        elif top == "blocks":
            components["Transformer Blocks"] += params
        elif top == "final_layer":
            components["Final Layers"] += params
        elif top in ("t_embedder", "y_embedder"):
            components["Conditioning Modules"] += params
        else:
            components["Other"] += params

        # Record details for leaf modules only
        if not list(module.children()):
            layer_details.append({
                "name": name,
                "type": type(module).__name__,
                "params": params,
                "shape": [f"{p_name}: {tuple(param.shape)}"
                          for p_name, param in module.named_parameters(recurse=False)],
            })

    # Print the summary
    print(f"\n{' Model Parameter Analysis ':=^80}")
    print(f"Total parameters: {total_params:,} (~{total_params/1e6:.1f}M)")
    print("\nParameters by component:")
    for k, v in components.items():
        print(f"• {k:<20} {v/1e6:>6.2f}M ({v/total_params:.1%})")

    # Print the per-layer breakdown
    print("\nPer-layer parameter breakdown:")
    for detail in layer_details:
        print(f"\n▌ Layer: {detail['name']}")
        print(f"  Type: {detail['type']}")
        print(f"  Params: {detail['params']:,}")
        print("  Parameter shapes:")
        for shape_info in detail['shape']:
            print(f"    └ {shape_info}")

# Run the analysis
calculate_parameters(model)

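The component totals printed by `calculate_parameters` can be cross-checked by hand. The sketch below recomputes DiT-XL/2's parameter count from layer shapes alone. The configuration is an assumption based on the DiT repo's `models.py` as we understand it: hidden size 1152, 28 blocks, MLP ratio 4, biases on every linear layer, LayerNorms without learnable affine parameters, 8 output channels (the model also predicts variance), and a 1001-row label table that includes the null class. The helpers `linear` and `dit_param_count` are hypothetical; adjust the constants if your checkout differs. The number of attention heads does not affect the count.

```python
def linear(fan_in: int, fan_out: int) -> int:
    """Parameters of a Linear layer with bias."""
    return fan_in * fan_out + fan_out


def dit_param_count(image_size: int = 256, hidden: int = 1152, depth: int = 28,
                    mlp_ratio: int = 4, patch: int = 2, in_ch: int = 4,
                    num_classes: int = 1000, freq_dim: int = 256) -> dict:
    """Analytic parameter count for a DiT model (assumed DiT-XL/2 defaults)."""
    out_ch = 2 * in_ch                      # assumed: predicting variance doubles outputs
    num_patches = (image_size // 8 // patch) ** 2

    block = (
        linear(hidden, 3 * hidden)              # attention qkv projection
        + linear(hidden, hidden)                # attention output projection
        + linear(hidden, mlp_ratio * hidden)    # MLP fc1
        + linear(mlp_ratio * hidden, hidden)    # MLP fc2
        + linear(hidden, 6 * hidden)            # adaLN modulation (shift/scale/gate x2)
    )                                           # LayerNorms assumed affine-free: 0 params

    parts = {
        "x_embedder": in_ch * hidden * patch * patch + hidden,  # Conv2d(kernel=stride=patch) + bias
        "pos_embed": num_patches * hidden,                       # fixed table, still a Parameter
        "t_embedder": linear(freq_dim, hidden) + linear(hidden, hidden),
        "y_embedder": (num_classes + 1) * hidden,                # +1 row for the null class
        "blocks": depth * block,
        "final_layer": linear(hidden, patch * patch * out_ch) + linear(hidden, 2 * hidden),
    }
    parts["total"] = sum(parts.values())    # summed before "total" is added
    return parts


counts = dit_param_count(256)
for name, n in counts.items():
    print(f"{name:<12} {n:>12,}")
```

Under these assumptions the 256x256 model comes out at about 675M parameters, in line with the figure commonly quoted for DiT-XL/2; the 512x512 model differs only in its larger positional-embedding table. In `calculate_parameters` terms, `x_embedder + pos_embed` corresponds to "Embedding Layers" and `t_embedder + y_embedder` to "Conditioning Modules".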

# 2. Sample from Pre-trained DiT Models

You can customize several sampling options. For the full list of ImageNet classes, see [this list](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).
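`num_sampling_steps` controls how many denoising steps the sampler takes. DiT is trained with a 1000-step diffusion schedule, and sampling with fewer steps uses a subset of those timesteps. The sketch below shows one way to pick an evenly spaced subset; it is in the spirit of the timestep respacing used by the improved-diffusion-style code that `create_diffusion` builds on, but the helper name `evenly_spaced_timesteps` and its exact rounding are assumptions, not the repo's API.

```python
def evenly_spaced_timesteps(num_train_steps: int, num_sampling_steps: int) -> list:
    """Pick `num_sampling_steps` roughly evenly spaced indices in [0, num_train_steps - 1].

    Hypothetical helper: spans the full range, so both the first (0) and the
    last (num_train_steps - 1) training timesteps are included.
    """
    if not 1 < num_sampling_steps <= num_train_steps:
        raise ValueError("need 1 < num_sampling_steps <= num_train_steps")
    stride = (num_train_steps - 1) / (num_sampling_steps - 1)
    return [round(i * stride) for i in range(num_sampling_steps)]


steps = evenly_spaced_timesteps(1000, 250)
print(len(steps), steps[:5], steps[-3:])
```

Fewer steps means a larger stride between timesteps: sampling is faster, at some cost in sample quality.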

In [None]:
# Set user inputs:
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
samples_per_row = 4 #@param {type:"number"}

# Create diffusion object:
diffusion = create_diffusion(str(num_sampling_steps))

# Create sampling noise:
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)

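# Set up classifier-free guidance: duplicate the noise and pair the copies with
# the null label 1000 (one past the last ImageNet class) as the unconditional branch
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)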

# Sample images:
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False,
    model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
samples = vae.decode(samples / 0.18215).sample  # undo the VAE latent scale factor, then decode to pixels

# Save and display images:
save_image(samples, "sample.png", nrow=int(samples_per_row),
           normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)
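The sampling cell runs classifier-free guidance by doubling the batch: the first `n` rows carry the requested labels, the last `n` carry the null label `1000`, and `model.forward_with_cfg` blends the two predictions using `cfg_scale`. The sketch below shows the standard blending rule, `uncond + cfg_scale * (cond - uncond)`, on plain Python lists. It illustrates that rule only; the helper `apply_cfg` is hypothetical, and the repository's `forward_with_cfg` may differ in details (for example, which output channels it guides).

```python
from typing import List


def apply_cfg(batch: List[List[float]], cfg_scale: float) -> List[List[float]]:
    """Blend a doubled batch [cond_0..cond_{n-1}, uncond_0..uncond_{n-1}].

    Returns n guided predictions: uncond + cfg_scale * (cond - uncond).
    """
    if len(batch) % 2:
        raise ValueError("expected an even batch: conditional half + unconditional half")
    n = len(batch) // 2
    cond, uncond = batch[:n], batch[n:]   # same split as samples.chunk(2, dim=0)
    return [
        [u + cfg_scale * (c - u) for c, u in zip(c_row, u_row)]
        for c_row, u_row in zip(cond, uncond)
    ]


# Two samples, each a tiny 3-value "noise prediction"
preds = [[1.0, 0.5, 0.0],   # cond, sample 0
         [2.0, 2.0, 2.0],   # cond, sample 1
         [0.0, 0.5, 1.0],   # uncond, sample 0
         [1.0, 1.0, 1.0]]   # uncond, sample 1
print(apply_cfg(preds, cfg_scale=4.0))  # [[4.0, 0.5, -3.0], [5.0, 5.0, 5.0]]
```

With `cfg_scale = 1` the result equals the conditional prediction; larger values push samples further toward the requested class, typically trading diversity for fidelity.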