# Scalable Diffusion Models with Transformers (DiT)

This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. At publication, DiT-XL/2 outperformed all prior diffusion models on the class-conditional ImageNet 256x256 and 512x512 benchmarks.

[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](https://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)
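To make "transformers in place of U-Nets" concrete: DiT first splits the VAE latent into non-overlapping p x p patches and treats each flattened patch as one token; the "/2" in DiT-XL/2 is the patch size. The sketch below is a minimal NumPy illustration of that reshaping only, assuming a (C, H, W) latent. `patchify` is a hypothetical helper, and the actual DiT code also applies a learned linear projection to each patch and adds positional embeddings, both of which this omits.

```python
import numpy as np

def patchify(latent: np.ndarray, patch_size: int) -> np.ndarray:
    """Split a (C, H, W) latent into a sequence of flattened p x p patches.

    Returns an array of shape (T, p * p * C) with T = (H / p) * (W / p).
    """
    c, h, w = latent.shape
    if h % patch_size or w % patch_size:
        raise ValueError("H and W must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    # (C, gh, p, gw, p) -> (gh, gw, p, p, C) -> (gh * gw, p * p * C)
    x = latent.reshape(c, gh, patch_size, gw, patch_size)
    x = x.transpose(1, 3, 2, 4, 0)
    return x.reshape(gh * gw, patch_size * patch_size * c)

latent = np.random.randn(4, 32, 32).astype(np.float32)  # 256x256 image -> 4x32x32 latent
tokens = patchify(latent, patch_size=2)
print(tokens.shape)  # (256, 16)
```

For a 256x256 image the latent is 4x32x32, so DiT-XL/2 sees (32/2)^2 = 256 tokens of 2*2*4 = 16 values each, in row-major patch order.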

# 1. Setup

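We recommend using a GPU (in Colab: Runtime > Change runtime type > Hardware accelerator > GPU). On an Apple Silicon Mac, PyTorch can use the MPS (Metal) backend instead, which is what this run uses. Run the cells below to install PyTorch, check which accelerator is available, clone the DiT GitHub repo, and set up the imports. The install and clone steps only need to run once.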

In [4]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu


Looking in indexes: https://download.pytorch.org/whl/cpu
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cpu/torchaudio-2.4.1-cp38-cp38-macosx_11_0_arm64.whl (1.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 3.1 MB/s eta 0:00:00
Installing collected packages: torchaudio
Successfully installed torchaudio-2.4.1

Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch

print("Is MPS (Metal) available?", torch.backends.mps.is_available())
print("Is CUDA available?", torch.cuda.is_available())


Is MPS (Metal) available? True
Is CUDA available? False
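The check above reports MPS but not CUDA, and the setup cell that follows only tests for MPS. A minimal sketch of a device choice that also covers CUDA machines, assuming the preference order CUDA, then MPS, then CPU. `pick_device` is a hypothetical helper, not part of the DiT repo; it takes the two availability flags as arguments so it can be tested without any GPU.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred PyTorch device string: CUDA, then MPS (Apple Metal), then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

# The flags printed above for this MacBook: MPS available, CUDA not available.
print(pick_device(cuda_available=False, mps_available=True))  # mps
```

In the notebook you would call it as `device = pick_device(torch.cuda.is_available(), torch.backends.mps.is_available())`.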


In [None]:
import os, sys
# Clone the DiT repo (if the folder already exists, git prints an error and leaves it untouched)
!git clone https://github.com/facebookresearch/DiT.git
if os.path.basename(os.getcwd()) != 'DiT':  # avoid descending into DiT/DiT when the cell is re-run
    os.chdir('DiT')
sys.path.insert(0, os.getcwd())  # make the repo's modules (diffusion, download, models) importable
!pip install diffusers timm --upgrade
# DiT imports:
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)  # sampling only; no gradients needed
# Use Apple's MPS (Metal) backend on Apple Silicon Macs, otherwise fall back to CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

fatal: destination path 'DiT' already exists and is not an empty directory.


# Download DiT-XL/2 Models

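You can choose between a 512x512 model and a 256x256 model by setting `image_size`. You can swap out the LDM VAE, too: set `vae_model` to `"stabilityai/sd-vae-ft-ema"` or `"stabilityai/sd-vae-ft-mse"`.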
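A small sketch of how the two settings in the next cell relate, assuming the DiT-XL/2 patch size of 2 and the 8x spatial downsampling of the Stable Diffusion VAE. `dit_config` is a hypothetical helper; the checkpoint name follows the `f"DiT-XL-2-{image_size}x{image_size}.pt"` pattern that the next cell passes to `find_model`.

```python
# Options exposed by this notebook.
SUPPORTED_IMAGE_SIZES = (256, 512)
SUPPORTED_VAES = ("stabilityai/sd-vae-ft-ema", "stabilityai/sd-vae-ft-mse")

def dit_config(image_size: int, vae_model: str, patch_size: int = 2) -> dict:
    """Derive the latent size, token count, and checkpoint name for DiT-XL/2."""
    if image_size not in SUPPORTED_IMAGE_SIZES:
        raise ValueError(f"image_size must be one of {SUPPORTED_IMAGE_SIZES}")
    if vae_model not in SUPPORTED_VAES:
        raise ValueError(f"vae_model must be one of {SUPPORTED_VAES}")
    latent_size = image_size // 8  # the VAE downsamples each spatial side by 8
    num_tokens = (latent_size // patch_size) ** 2  # patches seen by the transformer
    return {
        "latent_size": latent_size,
        "num_tokens": num_tokens,
        "checkpoint": f"DiT-XL-2-{image_size}x{image_size}.pt",
        "vae_model": vae_model,
    }

print(dit_config(256, "stabilityai/sd-vae-ft-ema"))
print(dit_config(512, "stabilityai/sd-vae-ft-mse")["num_tokens"])  # 1024
```

The 512x512 model works on a 64x64 latent and therefore attends over 1024 tokens, four times as many as the 256x256 model, so expect sampling to be noticeably slower.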

In [10]:
# Set image parameters
image_size = 256  # 256 or 512
vae_model = "stabilityai/sd-vae-ft-ema"  # or "stabilityai/sd-vae-ft-mse"
latent_size = image_size // 8  # the VAE downsamples each spatial side by 8

# Load model
model = DiT_XL_2(input_size=latent_size).to(device)
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")  # downloads the pre-trained checkpoint on first use
model.load_state_dict(state_dict)
model.eval()  # evaluation mode for inference
vae = AutoencoderKL.from_pretrained(vae_model).to(device)

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


# 2. Sample from Pre-trained DiT Models

You can customize the random seed, the number of sampling steps, the classifier-free guidance scale, and the class labels to sample. For the full list of ImageNet class indices, [see this list](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).
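`cfg_scale` sets the strength of classifier-free guidance. The next cell duplicates the noise `z`, pairs one copy with the chosen labels and the other with the null label 1000, and runs both halves through the model in one batch. A minimal NumPy sketch of that combination step, assuming a model that maps `(x, t, y)` to a noise prediction; `guided_eps` and `toy_model` are hypothetical stand-ins. DiT's own `forward_with_cfg` differs in details (per its source comments, it guides only the first three output channels by default and also returns the remaining channels, including the learned variance).

```python
import numpy as np

NULL_CLASS = 1000  # ImageNet classes are 0-999; index 1000 stands for "no class"

def guided_eps(model, x: np.ndarray, t: np.ndarray, y: np.ndarray, cfg_scale: float) -> np.ndarray:
    """Classifier-free guidance with a single batched forward pass.

    x: (n, C, H, W) noisy latents, t: (n,) timesteps, y: (n,) class labels.
    """
    n = x.shape[0]
    x_in = np.concatenate([x, x], axis=0)  # same noisy latents for both halves
    t_in = np.concatenate([t, t], axis=0)
    y_in = np.concatenate([y, np.full(n, NULL_CLASS)], axis=0)
    eps = model(x_in, t_in, y_in)
    eps_cond, eps_uncond = eps[:n], eps[n:]
    # Extrapolate from the unconditional prediction toward the conditional one.
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

def toy_model(x, t, y):
    """Stand-in for DiT: the prediction shifts by 1 whenever a real class label is given."""
    shift = np.where(y == NULL_CLASS, 0.0, 1.0)[:, None, None, None]
    return x + shift

x = np.zeros((2, 4, 32, 32))
t = np.array([10, 10])
y = np.array([207, 360])
out = guided_eps(toy_model, x, t, y, cfg_scale=4.0)
print(out[0, 0, 0, 0])  # 4.0
```

With `cfg_scale=1` the result equals the conditional prediction; larger values push samples further toward the requested class at some cost in diversity.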

In [11]:
# Set user inputs
seed = 0
torch.manual_seed(seed)
num_sampling_steps = 250
cfg_scale = 4
class_labels = torch.tensor([207, 360, 387, 974, 88, 979, 417, 279], dtype=torch.int64, device=device)
samples_per_row = 4

# Create diffusion object
diffusion = create_diffusion(str(num_sampling_steps))

# Create sampling noise
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, dtype=torch.float32, device=device)
y = class_labels.to(device)  # Keep labels as int64 for embeddings

# Setup classifier-free guidance
z = torch.cat([z, z], 0)
y_null = torch.full((n,), 1000, dtype=torch.int64, device=device)  # Null class labels
y = torch.cat([y, y_null], 0)

# Model input dictionary
model_kwargs = dict(y=y, cfg_scale=cfg_scale)

# Sample images
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False, 
    model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
samples = vae.decode(samples / 0.18215).sample  # undo the SD VAE latent scaling factor, then decode to pixels

# Save and display images
save_image(samples, "sample.png", nrow=int(samples_per_row), normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)

  0%|          | 0/250 [00:33<?, ?it/s]


TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
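The traceback is cut off, so the failing line is not shown. A likely cause, assuming the repo's `diffusion` package is unchanged: the object returned by `create_diffusion` keeps its noise schedule (for example `betas` and `alphas_cumprod`) as float64 NumPy arrays and converts slices of them to tensors on the sampling device, and MPS cannot hold float64. One hedged workaround is to cast those arrays to float32 right after creating the diffusion object, accepting a small loss of precision; another is to sample on CPU. `cast_float64_arrays` and `FakeDiffusion` below are hypothetical and only illustrate the cast.

```python
import numpy as np

def cast_float64_arrays(obj) -> list:
    """Cast every float64 NumPy array attribute of `obj` to float32, in place.

    Returns the names of the attributes that were converted.
    """
    converted = []
    for name, value in list(vars(obj).items()):
        if isinstance(value, np.ndarray) and value.dtype == np.float64:
            setattr(obj, name, value.astype(np.float32))
            converted.append(name)
    return converted

# Stand-in for the object returned by create_diffusion(); only used to show the effect.
class FakeDiffusion:
    def __init__(self):
        self.betas = np.linspace(1e-4, 0.02, 1000)  # float64 by default
        self.num_timesteps = 1000                    # non-array attributes are left alone

d = FakeDiffusion()
print(cast_float64_arrays(d), d.betas.dtype)  # ['betas'] float32
```

In the sampling cell this would be called as `cast_float64_arrays(diffusion)` immediately after `diffusion = create_diffusion(str(num_sampling_steps))` and before `p_sample_loop`.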