# Scalable Diffusion Models with Transformers (DiT)

This notebook samples from pre-trained DiT models and fine-tunes a DiT on a custom dataset. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. In the paper, the largest model (DiT-XL/2) outperforms all prior diffusion models on the class-conditional ImageNet 256×256 and 512×512 benchmarks.

[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](http://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)

# 1. Setup: installing the required libraries

We recommend using a GPU (Runtime > Change runtime type > Hardware accelerator > GPU); otherwise everything will be painfully slow.

In [None]:
# requirements.txt lists the libraries I use on my home PC

# !pip install -r requirements.txt

In [6]:
# for Colab

!pip install diffusers==0.18.2
!pip install accelerate
!pip install torchinfo


In [9]:
cd ../

e:\Programming\diffusion_transformer


In [10]:
# Most of these libraries are needed only for experiments inside the notebook.
# All important steps run the .py files from the terminal, and those files handle their own imports.

import DiT, os
os.chdir(r'DiT')

import torch
from torchvision.utils import save_image
import torchinfo
from diffusion import create_diffusion 
from diffusers.models import AutoencoderKL
from download import find_model  
from models import DiT_XL_2 
import shutil 
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

In [11]:
# show the model architecture and its parameter counts

image_size = 256
latent_size = int(image_size) // 8
model = DiT_XL_2(input_size=latent_size)

torchinfo.summary(model)

Layer (type:depth-idx)                   Param #
DiT                                      294,912
├─PatchEmbed: 1-1                        --
│    └─Conv2d: 2-1                       19,584
│    └─Identity: 2-2                     --
├─TimestepEmbedder: 1-2                  --
│    └─Sequential: 2-3                   --
│    │    └─Linear: 3-1                  296,064
│    │    └─SiLU: 3-2                    --
│    │    └─Linear: 3-3                  1,328,256
├─LabelEmbedder: 1-3                     --
│    └─Embedding: 2-4                    1,153,152
├─ModuleList: 1-4                        --
│    └─DiTBlock: 2-5                     --
│    │    └─LayerNorm: 3-4               --
│    │    └─Attention: 3-5               5,313,024
│    │    └─LayerNorm: 3-6               --
│    │    └─Mlp: 3-7                     10,622,592
│    │    └─Sequential: 3-8              7,969,536
│    └─DiTBlock: 2-6                     --
│    │    └─LayerNorm: 3-9               --
│    │    └─Attention
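
The parameter counts in the summary above can be reproduced by hand, which is a useful check on how the latent is turned into tokens. The sketch below is plain arithmetic, not code from the repository; the hyperparameters (hidden size 1152, patch size 2, 4 latent channels, MLP ratio 4, a 256-dimensional timestep frequency embedding, 1000 classes plus one "null" class) are assumptions taken from the public DiT-XL/2 configuration.

```python
# Assumed DiT-XL/2 hyperparameters (not read from the checkpoint).
image_size = 256
vae_downsample = 8      # the VAE shrinks each spatial side by 8
latent_channels = 4
patch_size = 2
hidden_size = 1152
mlp_ratio = 4
num_classes = 1000
freq_embed_size = 256


def linear_params(n_in, n_out):
    """Weights plus biases of a dense layer."""
    return n_in * n_out + n_out


latent_size = image_size // vae_downsample        # 32
num_tokens = (latent_size // patch_size) ** 2      # 256 patches per image

counts = {
    "pos_embed": num_tokens * hidden_size,
    "patch_embed_conv": linear_params(latent_channels * patch_size**2, hidden_size),
    "timestep_linear_1": linear_params(freq_embed_size, hidden_size),
    "timestep_linear_2": linear_params(hidden_size, hidden_size),
    "label_embedding": (num_classes + 1) * hidden_size,
    # One DiT block: qkv + output projection, MLP, adaLN modulation.
    "block_attention": linear_params(hidden_size, 3 * hidden_size)
    + linear_params(hidden_size, hidden_size),
    "block_mlp": linear_params(hidden_size, mlp_ratio * hidden_size)
    + linear_params(mlp_ratio * hidden_size, hidden_size),
    "block_adaln": linear_params(hidden_size, 6 * hidden_size),
}

for name, value in counts.items():
    print(f"{name:>18}: {value:,}")
```

The first number the summary lists directly under `DiT` (294,912) is the positional embedding: 256 tokens times 1152 channels.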

# 2. Loading the dataset

This example uses an existing drone dataset from Kaggle:

https://www.kaggle.com/datasets/dasmehdixtr/drone-dataset-uav

In [None]:
# Copy only the images into a separate folder.
# The target must follow a class-per-folder hierarchy: each class gets its own subfolder.
source = r'...drone_dataset_yolo\dataset_txt'
target = r'...\drones_for_dit\drones'
for name in os.listdir(source):
    # match on the file extension so label files with "jpg" in the name are skipped
    if name.lower().endswith('.jpg'):
        shutil.copy(os.path.join(source, name), target)
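
The same copy step can be wrapped in a small reusable function that also creates the class subfolder if it does not exist yet. This is only a sketch: the function name `copy_images`, the `IMAGE_SUFFIXES` set, and the arguments are hypothetical and not part of the DiT repository.

```python
import shutil
from pathlib import Path

# Assumption: which file extensions count as images; extend as needed.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def copy_images(source, target_root, class_name):
    """Copy image files from `source` into `target_root/class_name`.

    Returns the number of files copied. Non-image files (for example
    YOLO .txt label files) are skipped.
    """
    class_dir = Path(target_root) / class_name
    class_dir.mkdir(parents=True, exist_ok=True)
    copied = 0
    for path in sorted(Path(source).iterdir()):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            shutil.copy(path, class_dir / path.name)
            copied += 1
    return copied
```

Calling `copy_images(source, dataset_root, "drones")` once per class produces the one-folder-per-class layout that the feature extraction step expects.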

# 3. Extracting features from the dataset

The features are created with the encoder from AutoencoderKL.

In [12]:
# VAE architecture
vae_model = "stabilityai/sd-vae-ft-ema"
vae = AutoencoderKL.from_pretrained(vae_model) #.to(device)

In [13]:
# This part of the model is not trained: the encoder is used first to extract the features,
# and the decoder is used at the end to turn the latents back into images.
torchinfo.summary(vae)

Layer (type:depth-idx)                             Param #
AutoencoderKL                                      --
├─Encoder: 1-1                                     --
│    └─Conv2d: 2-1                                 3,584
│    └─ModuleList: 2-2                             --
│    │    └─DownEncoderBlock2D: 3-1                738,944
│    │    └─DownEncoderBlock2D: 3-2                2,690,304
│    │    └─DownEncoderBlock2D: 3-3                10,754,560
│    │    └─DownEncoderBlock2D: 3-4                9,443,328
│    └─UNetMidBlock2D: 2-3                         --
│    │    └─ModuleList: 3-5                        1,051,648
│    │    └─ModuleList: 3-6                        9,443,328
│    └─GroupNorm: 2-4                              1,024
│    └─SiLU: 2-5                                   --
│    └─Conv2d: 2-6                                 36,872
├─Decoder: 1-2                                     --
│    └─Conv2d: 2-7                                 18,944
│    └─ModuleList: 2-8
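
To make the encoder/decoder roles concrete, here is a minimal sketch of the round trip through the VAE. It assumes the `diffusers` `AutoencoderKL` interface (`encode(...).latent_dist.sample()` and `decode(...).sample`) and the scaling constant 0.18215 commonly used with Stable Diffusion VAEs; the helper names are hypothetical, and the actual `extract_features.py` may differ in details.

```python
# Assumption: latent scaling constant used with Stable Diffusion VAEs.
LATENT_SCALE = 0.18215


def encode_images(vae, images):
    """Map images to scaled latents.

    `images` is expected to be a float tensor in [-1, 1] with shape
    (B, 3, H, W); the result has shape (B, 4, H/8, W/8).
    """
    return vae.encode(images).latent_dist.sample() * LATENT_SCALE


def decode_latents(vae, latents):
    """Undo the scaling and map latents back to image space."""
    return vae.decode(latents / LATENT_SCALE).sample
```

With the `vae` loaded above, `encode_images(vae, batch)` yields the features the transformer is trained on, and `decode_latents(vae, samples)` is the final step of sampling.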

In [None]:
!torchrun \
--nnodes=1 \
--nproc_per_node=1 extract_features.py \
--data-path ...\drones_for_dit \
--features-path ...\dit_features2 \
--image-size 256 \
--global-batch-size 256 \
--global-seed 0 \
--vae ema \
--num-workers 0

# 4. Training the model on the features extracted by the encoder

In [None]:
!accelerate launch \
--mixed_precision fp16 train.py \
--feature-path ...\dit_features \
--results-dir ...\dit_weights \
--model DiT-S/2 \
--image-size 256 \
--num-classes 1 \
--epochs=10 \
--global-batch-size=16 \
--global-seed 0 \
--num-workers 0 \
--log-every=25 \
--ckpt-every=100
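
With a small dataset it is worth checking how many optimizer steps the settings above produce, because a run shorter than `--ckpt-every` steps would finish without saving a checkpoint. The helper below is a hypothetical sketch, not part of `train.py`; it assumes incomplete batches are dropped (set `drop_last=False` otherwise), and the dataset size of 1000 images is a made-up example.

```python
def training_schedule(num_images, global_batch_size=16, epochs=10,
                      log_every=25, ckpt_every=100, drop_last=True):
    """Estimate step counts for a training run (assumption-based sketch)."""
    if drop_last:
        steps_per_epoch = num_images // global_batch_size
    else:
        steps_per_epoch = -(-num_images // global_batch_size)  # ceiling division
    total_steps = steps_per_epoch * epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "log_lines": total_steps // log_every,
        "checkpoints": total_steps // ckpt_every,
    }


# Hypothetical dataset of 1000 images with the flags used above.
print(training_schedule(1000))
```

If `checkpoints` comes out as 0, lower `--ckpt-every` or raise `--epochs` before launching.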

# 5. Sampling

In [None]:
# sample a few images from the pre-trained ImageNet checkpoint

!python sample.py \
--model DiT-XL/2 \
--vae ema \
--image-size=256 \
--num-classes=1000 \
--cfg-scale 8 \
--num-sampling-steps=256 \
--seed 0 \
--ckpt ...\DiT-XL-2-256x256.pt
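
The `--cfg-scale` flag controls classifier-free guidance. As a rough illustration of what the scale does, the standard guidance rule extrapolates from the unconditional noise prediction toward the class-conditional one. The sketch below works on plain Python lists; `cfg_combine` is a hypothetical name, and the real sampler applies the same idea to tensors.

```python
def cfg_combine(eps_cond, eps_uncond, scale):
    """Classifier-free guidance: eps = eps_uncond + scale * (eps_cond - eps_uncond).

    scale = 1 returns the conditional prediction unchanged;
    larger values push the sample harder toward the class.
    """
    return [u + scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


cond = [0.5, -1.0]
uncond = [0.25, -0.5]
print(cfg_combine(cond, uncond, 1.0))  # identical to cond
print(cfg_combine(cond, uncond, 8.0))  # strongly guided
```

High scales such as 8 tend to give more recognizable but less diverse images.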

In [None]:
# sample many images; very useful if you have more than one GPU

!torchrun sample_ddp.py \
--model DiT-XL/2 \
--vae ema \
--sample-dir /path/to/generated/data \
--per-proc-batch-size 32 \
--num-fid-samples 400 \
--image-size 256 \
--num-classes 1 \
--cfg-scale 7 \
--num-sampling-steps 250 \
--global-seed 42 \
--ckpt /path/to/your/trained/models/weights
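
When sampling is split across processes, the requested sample count is typically rounded up to a multiple of the global batch size so that every GPU runs the same number of iterations. The sketch below reflects my understanding of that bookkeeping and should be treated as an assumption rather than a copy of `sample_ddp.py`; `ddp_sampling_plan` is a hypothetical helper.

```python
import math


def ddp_sampling_plan(num_samples, per_proc_batch_size, world_size):
    """Estimate how many images are generated overall and per GPU."""
    global_batch = per_proc_batch_size * world_size
    # Round up so that every process runs only full batches.
    total = math.ceil(num_samples / global_batch) * global_batch
    per_gpu = total // world_size
    return {
        "total_samples": total,
        "samples_per_gpu": per_gpu,
        "iterations_per_gpu": per_gpu // per_proc_batch_size,
    }


print(ddp_sampling_plan(400, 32, world_size=1))
print(ddp_sampling_plan(400, 32, world_size=2))
```

For 400 requested samples and a per-process batch of 32, a single GPU would generate 416 images in 13 iterations, while two GPUs would generate 448 images in 7 iterations each.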
