In [1]:
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
import torch



In [2]:
# Pretrained Stable Diffusion v1-4 autoencoder, loaded from the "vae" subfolder
# of that repo id in full float32 precision; the Trainer cell receives it as `vae`.
model_id = "CompVis/stable-diffusion-v1-4"
vae = AutoencoderKL.from_pretrained(
    model_id,
    subfolder="vae",
    torch_dtype=torch.float32,
)

# 2 ** (number of VAE blocks - 1): the formula diffusers pipelines use for the
# VAE's spatial downsampling factor.
vae_scale_factor = pow(2, len(vae.config.block_out_channels) - 1)
# NOTE(review): image_processor is not referenced in the visible cells —
# confirm it is still needed.
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

In [None]:
# Run configuration — every value a reader might tune lives here.
SEED = 0
TRAIN_DATA_DIR = './data/train'
CROP_SIZE = 128              # passed to the Trainer as crop_size
DIFFUSION_IMAGE_SIZE = 128   # passed to GaussianDiffusion as image_size
TIMESTEPS = 400
SAMPLING_TIMESTEPS = 250
TRAIN_NUM_STEPS = 200_000

# Seed torch's global RNG so weight initialisation (and any torch-based
# sampling/shuffling inside the Trainer) is repeatable on a fresh kernel.
torch.manual_seed(SEED)

# channels = 4: the Unet takes 4-channel inputs — presumably VAE latents
# rather than RGB pixels (TODO confirm against the Trainer fork).
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True,
    channels = 4,
).to(torch.float32)

# NOTE(review): if the Trainer VAE-encodes CROP_SIZE-pixel crops before diffusion,
# each latent side is CROP_SIZE // vae_scale_factor (as computed in the VAE cell),
# and DIFFUSION_IMAGE_SIZE would need to match that — confirm.
diffusion = GaussianDiffusion(
    model,
    image_size = DIFFUSION_IMAGE_SIZE,
    timesteps = TIMESTEPS,                    # number of diffusion steps
    sampling_timesteps = SAMPLING_TIMESTEPS,  # fewer than TIMESTEPS -> DDIM sampling for faster inference
)

trainer = Trainer(
    diffusion,
    TRAIN_DATA_DIR,
    train_batch_size = 16,
    train_lr = 8e-5,
    train_num_steps = TRAIN_NUM_STEPS,  # total training steps
    gradient_accumulate_every = 2,      # gradient accumulation steps (effective batch 16 * 2 = 32)
    ema_decay = 0.995,                  # exponential moving average decay
    amp = False,                        # mixed precision OFF: train in full float32
    calculate_fid = False,              # skip FID computation during training
    save_best_and_latest_only = False,
    crop_size = CROP_SIZE,
    # NOTE(review): hard-coded 1 rather than the vae_scale_factor computed from the
    # VAE config earlier; what this Trainer-fork argument means (spatial factor vs.
    # latent scaling) is not visible here — confirm the intended value.
    vae_scale_factor = 1,
    vae = vae,
)

trainer.train()

  0%|          | 0/200000 [00:00<?, ?it/s]

  self.gen = func(*args, **kwds)
