# 1. Setup

We recommend using a GPU (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to import the DiT modules and set up PyTorch; it assumes the DiT GitHub repo has already been cloned and is on the import path. You only have to run this once per session.

In [17]:

import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from download import find_model
from models import DiT_B_4
from PIL import Image
from IPython.display import display
from models import DiT_models
# Disable autograd globally for the rest of the notebook.
torch.set_grad_enabled(False)
# Prefer CUDA when it is available; otherwise fall back to the CPU and say so.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("GPU not found. Using CPU instead.")

Load models

In [22]:
# NOTE(review): the Colab form options below (256, 512) do not include the value 32 in use.
image_size = 32 #@param [256, 512]
# vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
# The image size is passed to the model unchanged as its input_size.
latent_size = image_size

# Build the DiT-B/4 network, then move it to the selected device.
model = DiT_models["DiT-B/4"](input_size=latent_size, in_channels=12)
model = model.to(device)

# Restore the weights stored under the checkpoint's "model" key.
# torch.load unpickles the file, so only point this at checkpoints you trust.
checkpoint_path = "./results/004-DiT-B-4/checkpoints/000060000.pt"
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
model.load_state_dict(checkpoint["model"])

# Switch to evaluation mode; as the cell's last expression this also displays the module.
model.eval() # important!

DiT(
  (x_embedder): PatchEmbed_3d(
    (proj): Conv3d(12, 768, kernel_size=(4, 4, 4), stride=(4, 4, 4))
    (norm): Identity()
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=768, bias=True)
      (1): SiLU()
      (2): Linear(in_features=768, out_features=768, bias=True)
    )
  )
  (blocks): ModuleList(
    (0-11): 12 x DiTBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=False)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=False)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='tanh')
        (drop1): Dropout(p=0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(

# 2. Sample from Pre-trained DiT Models


In [24]:
# Set user inputs:
import numpy as np
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)  # seed before drawing the initial noise so runs are repeatable
num_sampling_steps = 1000 #@param {type:"slider", min:0, max:1000, step:1}
# NOTE(review): cfg_scale and samples_per_row are not used anywhere in this cell.
cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
samples_per_row = 4 #@param {type:"number"}
num_samples = 5  # number of volumes drawn in one batch
num_channels = 12  # must equal the in_channels the model was built with

# Create diffusion object:
diffusion = create_diffusion(str(num_sampling_steps))

# Create sampling noise: one cubic volume of side latent_size per sample.
z = torch.randn(
    num_samples, num_channels, latent_size, latent_size, latent_size, device=device
)

# Sample images, starting the reverse process from z:
samples = diffusion.p_sample_loop(
    model.forward, z.shape, z, clip_denoised=False,
    progress=True, device=device
)
print(samples.shape)

# Persist the batch as a NumPy array; move it to host memory first.
np.save("./sample_npy.npy", samples.detach().cpu().numpy())


  0%|          | 0/1000 [00:00<?, ?it/s]

torch.Size([5, 12, 32, 32, 32])
