In [None]:
#Level 1
# !pip install --upgrade diffusers
# !pip install --upgrade transformers
# !pip install --upgrade tokenizers
# !pip install --upgrade datasets
from diffusers import StableDiffusionPipeline

In [None]:
import getpass
import os

# Never commit a Hugging Face token to a notebook: anyone who can read the file
# can use it. Read it from the HF_TOKEN environment variable, or prompt for it
# interactively when the variable is not set. (A token that was ever committed
# here should be revoked in the Hugging Face account settings.)
access_token = os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face access token: ")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=access_token).to("cuda") #use revision='fp16' and torch_dtype=torch.float16 for low memory

In [None]:
import os

prompt = "a photo of a horse riding an astronaut on Mars"
image = pipe(prompt).images[0]
# PIL does not create missing parent directories, so make sure ./images exists
# before saving; otherwise a fresh checkout fails with FileNotFoundError.
os.makedirs("./images", exist_ok=True)
image.save("./images/horse_rides_astronaut.png")

In [None]:
# level 2
from torch import autocast
from PIL import Image, ImageDraw

In [None]:
def dummy(images, **kwargs):
    """Pass-through replacement for a safety checker.

    Accepts and ignores any keyword arguments, and returns the images
    unchanged together with a constant ``False`` flag.
    """
    flagged = False
    return images, flagged

def image_grid(imgs, rows, cols):
    """Tile equally sized PIL images into one ``rows`` x ``cols`` image.

    Images are placed left to right, top to bottom, in list order. The cell
    size is taken from the first image.

    Args:
        imgs: list of PIL images; its length must equal ``rows * cols``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
        ValueError: if ``imgs`` is empty (there is no image to size the grid from).
    """
    assert len(imgs) == rows*cols, (
        f"expected rows*cols = {rows*cols} images, got {len(imgs)}"
    )
    if not imgs:
        raise ValueError("image_grid needs at least one image")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))

    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col*w, row*h))
    return grid

In [None]:
# Replace the pipeline's safety checker with the pass-through `dummy` function.
pipe.safety_checker = dummy
n_images = 3
prompts = [
    "masterpiece, best quality, a photo of a horse riding an astronaut, trending on artstation, photorealistic, qhd, rtx on, 8k"
] * n_images
with autocast("cuda"):
    images = pipe(prompts, num_inference_steps=28).images
# Size the grid from n_images so that changing the batch size above cannot trip
# image_grid's length check (the column count used to be hardcoded to 3).
image_grid(images, rows=1, cols=n_images)

In [18]:
#level 3
from diffusers import UNet2DConditionModel, StableDiffusionPipeline, AutoencoderKL, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from torch.nn import functional as F
from torch import autocast
import numpy as np

from tqdm.auto import tqdm

In [None]:
# Each component below has a commented-out one-time setup pair: download the
# weights from the Hub, then save them locally. The active line loads the local copy.

# vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder='vae', use_auth_token=access_token)
# vae.save_pretrained('./models/vae')
vae = AutoencoderKL.from_pretrained('./models/vae/').to("cuda")

# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# tokenizer.save_pretrained('./tokenizers/')
tokenizer = CLIPTokenizer.from_pretrained('./tokenizers/')
# text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")
# text_encoder.save_pretrained('./models/text_encoder')
text_encoder = CLIPTextModel.from_pretrained('./models/text_encoder/').to("cuda")

# NOTE: './models/sd_v1-5/' holds only the UNet (see the save_pretrained call below).
# model = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder='unet', use_auth_token=access_token).to("cuda")
# model.save_pretrained('./models/sd_v1-5')
model = UNet2DConditionModel.from_pretrained('./models/sd_v1-5/').to("cuda")

# Global scheduler read by produce_latents; a later cell rebinds this name to a DDIM scheduler.
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear', num_train_timesteps=1000)

In [None]:
# Show the loaded UNet (the notebook renders the repr of a cell's last expression).
model

In [None]:
# Show the UNet's configuration object.
model.config

In [None]:
def get_text_embeds(prompt):
  """Encode prompts for classifier-free guidance.

  Args:
    prompt: a single prompt string or a list of prompt strings.

  Returns:
    A tensor made by concatenating, along dim 0, the text-encoder output for
    one empty prompt per input prompt (unconditional half) followed by the
    text-encoder output for the prompts themselves (conditional half).
  """
  # A bare string used to make len(prompt) count characters, which produced one
  # unconditional embedding per character instead of one per prompt.
  if isinstance(prompt, str):
    prompt = [prompt]

  # Tokenize the prompts and a matching batch of empty prompts, both padded to
  # the tokenizer's maximum length so the two halves can be concatenated.
  text_input = tokenizer(
      prompt, padding='max_length', max_length=tokenizer.model_max_length,
      truncation=True, return_tensors='pt')
  uncond_input = tokenizer(
      [''] * len(prompt), padding='max_length',
      max_length=tokenizer.model_max_length, return_tensors='pt')

  # Inference only: no autograd graph is needed for the embeddings.
  with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to("cuda"))[0]
    uncond_embeddings = text_encoder(uncond_input.input_ids.to("cuda"))[0]

  # Unconditional half first; produce_latents splits the prediction in this order.
  return torch.cat([uncond_embeddings, text_embeddings])

# Smoke test: embed a single prompt and inspect the values and the shape.
test_embeds = get_text_embeds(['an amazingly cool anime character'])
print(test_embeds)
print(test_embeds.shape)

In [None]:
def produce_latents(text_embeddings, height=512, width=512,
                    num_inference_steps=28, guidance_scale=11, latents=None,
                    return_all_latents=False):
  """Run the denoising loop with the global `model` (UNet) and `scheduler`.

  Args:
    text_embeddings: output of get_text_embeds (unconditional half followed by
      conditional half along dim 0).
    height, width: target image size in pixels; latents are 1/8 of this per side.
    num_inference_steps: number of scheduler steps.
    guidance_scale: classifier-free guidance weight.
    latents: optional starting noise; sampled with torch.randn when None.
    return_all_latents: when True, return every intermediate latent
      concatenated along dim 0 instead of only the final one.

  Returns:
    The final latents, or all recorded latents when return_all_latents is True.
  """
  if latents is None:
    latents = torch.randn((text_embeddings.shape[0] // 2, model.in_channels, \
                           height // 8, width // 8))
  latents = latents.to("cuda")

  scheduler.set_timesteps(num_inference_steps)
  # Scale the unit-variance noise up to the scheduler's first sigma.
  latents = latents * scheduler.sigmas[0]

  latent_hist = [latents]
  with autocast('cuda'):
    # Wrap the timesteps themselves (not the enumerate object) so tqdm knows the
    # total number of steps and can render a real progress bar.
    for i, t in enumerate(tqdm(scheduler.timesteps)):
      # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
      latent_model_input = torch.cat([latents] * 2)
      sigma = scheduler.sigmas[i]
      latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

      # predict the noise residual
      with torch.no_grad():
        noise_pred = model(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

      # perform guidance
      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
      noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

      # compute the previous noisy sample x_t -> x_t-1
      # NOTE(review): the step index `i` is passed where newer diffusers releases
      # expect the timestep `t`; confirm against the installed diffusers version.
      latents = scheduler.step(noise_pred, i, latents)['prev_sample']
      latent_hist.append(latents)

  if not return_all_latents:
    return latents

  all_latents = torch.cat(latent_hist, dim=0)
  return all_latents

# Smoke test: run the sampler on the test embeddings and inspect the final latents.
test_latents = produce_latents(test_embeds)
print(test_latents)
print(test_latents.shape)

In [None]:
def decode_img_latents(latents):
  """Decode scaled VAE latents into a list of PIL images.

  Args:
    latents: batch of latents scaled by 0.18215 (the factor this notebook
      applies when encoding and undoes here).

  Returns:
    A list with one PIL image per latent in the batch.
  """
  # Undo the 0.18215 latent scaling before handing the latents to the VAE.
  latents = 1 / 0.18215 * latents

  with torch.no_grad():
    imgs = vae.decode(latents)['sample']

  # Shift/scale the decoder output and clamp it into [0, 1].
  imgs = (imgs / 2 + 0.5).clamp(0, 1)
  # Channels-last float32 array on the CPU, ready for PIL.
  imgs = imgs.detach().cpu().permute(0, 2, 3, 1).float().numpy()
  # Values are already in [0, 1], so scale straight to [0, 255]. The previous
  # `(imgs + 1.0) * 127.5` re-applied the [-1, 1] mapping a second time and
  # squeezed every pixel into [127.5, 255], washing the images out.
  imgs = (imgs * 255).round().astype(np.uint8)
  pil_images = [Image.fromarray(image) for image in imgs]
  return pil_images

# Decode the test latents and show the first image.
imgs = decode_img_latents(test_latents)
imgs[0]

In [None]:
def prompt_to_img(prompts, height=512, width=512, num_inference_steps=28, guidance_scale=11, latents=None):
    """Turn one prompt or a list of prompts into PIL images.

    Chains the three stages defined in this notebook: text embedding,
    latent denoising, and VAE decoding. A single string is treated as a
    one-element list.
    """
    prompt_batch = [prompts] if isinstance(prompts, str) else prompts

    # Stage 1: prompts -> text embeddings
    conditioning = get_text_embeds(prompt_batch)

    # Stage 2: text embeddings -> denoised latents
    denoised = produce_latents(
        conditioning,
        height=height,
        width=width,
        latents=latents,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )

    # Stage 3: latents -> images
    return decode_img_latents(denoised)

In [None]:
# Generate four images from the same prompt
# (positional arguments: height, width, num_inference_steps, guidance_scale).
imgs = prompt_to_img(['Super cool fantasty knight, intricate armor, 8k']*4, 512, 512, 28, 11)

In [None]:
# Show the four generated images as a 2x2 grid.
image_grid(imgs, rows=2, cols=2)

In [None]:
# level 3.5 - similar images and img2img
from diffusers import DDIMScheduler

def prompt_to_img(prompts, height=512, width=512, num_inference_steps=50,
                  guidance_scale=7.5, latents=None, return_all_latents=False,
                  batch_size=2):
  """Turn prompts into PIL images, decoding the latents in small batches.

  Same pipeline as the earlier definition (embed -> denoise -> decode), but
  the latents are decoded `batch_size` at a time, and intermediate latents
  can be requested with `return_all_latents`.
  """
  prompt_batch = [prompts] if isinstance(prompts, str) else prompts

  # Stage 1: prompts -> text embeddings
  conditioning = get_text_embeds(prompt_batch)

  # Stage 2: text embeddings -> latents (optionally the whole history)
  latents = produce_latents(
      conditioning,
      height=height,
      width=width,
      latents=latents,
      num_inference_steps=num_inference_steps,
      guidance_scale=guidance_scale,
      return_all_latents=return_all_latents,
  )

  # Stage 3: decode chunk by chunk and collect the images in order
  decoded = []
  for start in tqdm(range(0, len(latents), batch_size)):
    chunk = latents[start:start+batch_size]
    decoded += decode_img_latents(chunk)
  return decoded

In [None]:
prompt = 'Steampunk airship bursting through the clouds, cyberpunk art'
# Create the starting noise explicitly so the following cells can perturb and reuse it.
latents = torch.randn((1, model.in_channels, 512 // 8, 512 // 8))
img = prompt_to_img(prompt, num_inference_steps=20, latents=latents)[0]
img

In [None]:
def perturb_latents(latents, scale=0.1):
  """Blend fresh Gaussian noise into `latents` and re-standardize the result.

  `scale` is the weight of the new noise and `1 - scale` the weight of the
  original latents. The blend is then shifted and scaled using its own
  overall mean and standard deviation.
  """
  jitter = torch.randn_like(latents)
  mixed = (1 - scale) * latents + scale * jitter
  centered = mixed - mixed.mean()
  return centered / mixed.std()

In [None]:
# Re-run the same prompt from a perturbed copy of the starting noise.
new_latents = perturb_latents(latents, 0.4)
img = prompt_to_img(prompt, num_inference_steps=20, latents=new_latents)[0]
img

In [None]:
# Generate the image that the following cells encode and feed back in as an img2img starting point.
prompt = 'Upright squid'
img = prompt_to_img(prompt, num_inference_steps=30)[0]
img

In [None]:
def encode_img_latents(imgs):
  """Encode PIL images into scaled VAE latents.

  Args:
    imgs: a single image, or a list/tuple of equally sized images
      (assumed to be RGB PIL images, as produced elsewhere in this notebook).

  Returns:
    Latents sampled from the VAE's latent distribution, multiplied by
    0.18215 (the same factor decode_img_latents divides out).
  """
  if not isinstance(imgs, (list, tuple)):
    imgs = [imgs]

  # Stack to one batch, rescale 0..255 -> 0..1, move channels first, then map to -1..1.
  img_arr = np.stack([np.array(img) for img in imgs], axis=0)
  img_arr = img_arr / 255.0
  img_arr = torch.from_numpy(img_arr).float().permute(0, 3, 1, 2)
  img_arr = 2 * (img_arr - 0.5)

  # Inference only: without no_grad the VAE (whose parameters still require
  # grad) would record an autograd graph and keep it alive with the result.
  with torch.no_grad():
    latent_dists = vae.encode(img_arr.to("cuda"))
    latent_samples = latent_dists.latent_dist.sample()

  # Out-of-place scaling so the sampled tensor itself is not mutated.
  return latent_samples * 0.18215

In [None]:
# Round-trip check: encode the generated image to latents and decode it back.
img_latents = encode_img_latents([img])
dec_img = decode_img_latents(img_latents)[0]
dec_img

In [None]:
# New scheduler for img-to-img
# This rebinds the global `scheduler` that produce_latents reads, so sampling
# cells run after this one use DDIM instead of the LMS scheduler created earlier.
scheduler = DDIMScheduler(
    beta_start=0.00085, beta_end=0.012,
    beta_schedule='scaled_linear', num_train_timesteps=1000)

In [None]:
def produce_latents(text_embeddings, height=512, width=512,
                    num_inference_steps=50, guidance_scale=7.5, latents=None,
                    return_all_latents=False, start_step=10):
  """Denoise latents with the global scheduler, optionally starting part-way (img2img).

  Args:
    text_embeddings: output of get_text_embeds (unconditional half followed by
      conditional half along dim 0).
    height, width: target image size in pixels; latents are 1/8 of this per side.
    num_inference_steps: number of scheduler steps to schedule.
    guidance_scale: classifier-free guidance weight.
    latents: starting latents, e.g. from encode_img_latents; random noise when None.
    return_all_latents: when True, return every intermediate latent
      concatenated along dim 0 instead of only the final one.
    start_step: index into the scheduler's timesteps at which denoising begins.
      When > 0 the starting latents are first noised to that timestep's level.
      Note the default is 10 here, while prompt_to_img passes 0 by default.

  Returns:
    The final latents, or all recorded latents when return_all_latents is True.
  """
  if latents is None:
    latents = torch.randn((text_embeddings.shape[0] // 2, model.in_channels, \
                           height // 8, width // 8))
  latents = latents.to("cuda")

  scheduler.set_timesteps(num_inference_steps)
  if start_step > 0:
    # img2img: add noise matching the timestep we are about to start from,
    # one (identical) timestep entry per latent in the batch.
    start_timestep = scheduler.timesteps[start_step]
    start_timesteps = start_timestep.repeat(latents.shape[0]).long()

    noise = torch.randn_like(latents)
    latents = scheduler.add_noise(latents, noise, start_timesteps)

  latent_hist = [latents]
  with autocast('cuda'):
    # Iterate tqdm over the timesteps directly so the bar shows the real total
    # (wrapping an enumerate object hides the length from tqdm).
    for t in tqdm(scheduler.timesteps[start_step:]):
      # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
      latent_model_input = torch.cat([latents] * 2)

      # predict the noise residual
      with torch.no_grad():
        noise_pred = model(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

      # perform guidance
      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
      noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

      # compute the previous noisy sample x_t -> x_t-1
      latents = scheduler.step(noise_pred, t, latents)['prev_sample']
      latent_hist.append(latents)

  if not return_all_latents:
    return latents

  all_latents = torch.cat(latent_hist, dim=0)
  return all_latents

def prompt_to_img(prompts, height=512, width=512, num_inference_steps=50,
                  guidance_scale=7.5, latents=None, return_all_latents=False,
                  batch_size=2, start_step=0):
  """Turn prompts into PIL images, with optional img2img via `start_step`.

  Embeds the prompts, denoises (forwarding `start_step` to produce_latents),
  then decodes the resulting latents `batch_size` at a time. A single string
  prompt is treated as a one-element list.
  """
  prompt_batch = [prompts] if isinstance(prompts, str) else prompts

  # Stage 1: prompts -> text embeddings
  conditioning = get_text_embeds(prompt_batch)

  # Stage 2: text embeddings -> latents
  latents = produce_latents(
      conditioning,
      height=height,
      width=width,
      latents=latents,
      num_inference_steps=num_inference_steps,
      guidance_scale=guidance_scale,
      return_all_latents=return_all_latents,
      start_step=start_step,
  )

  # Stage 3: decode chunk by chunk and collect the images in order
  decoded = []
  for start in tqdm(range(0, len(latents), batch_size)):
    chunk = latents[start:start+batch_size]
    decoded += decode_img_latents(chunk)
  return decoded

In [None]:
# img2img: start from the encoded image latents and skip the first 20 of the 30 scheduled steps.
prompt = 'Squidward'
img = prompt_to_img(prompt, num_inference_steps=30, latents=img_latents,
                    start_step=20)[0]
img

In [None]:
# Level 4 - AUTOMATIC1111
#!git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui
#https://rentry.org/voldy

In [None]:
#Level 5 - Deforum, Vid2Vid, Textual Inversion Dreambooth, Negative Prompts, Fine-Tuning, etc.
import hashlib
import itertools
import math
import os
import random
from contextlib import nullcontext
from pathlib import Path

from accelerate import Accelerator
from diffusers import DDPMScheduler, PNDMScheduler
from diffusers.optimization import get_scheduler
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPFeatureExtractor

In [None]:
#Textual Inversion Locally Add "perfect prompts"

# Accelerator for the textual-inversion run: gradients are accumulated over
# 4 batches per optimizer step, and mixed precision is disabled.
accelerator = Accelerator(
    gradient_accumulation_steps=4,
    mixed_precision="no",
)

In [None]:
if accelerator.is_main_process:
    # The function is os.makedirs; `os.make_dirs` does not exist and raised AttributeError.
    os.makedirs("./textual_inversion_outputs/", exist_ok=True)

# Reload the tokenizer from disk so re-running this cell starts without the placeholder token.
tokenizer = CLIPTokenizer.from_pretrained('./tokenizers/')

placeholder = "<pokemon-sprite>"
initializer = "sprite"
num_added_tokens = tokenizer.add_tokens(placeholder)
if num_added_tokens == 0:
    # Trailing spaces keep the concatenated message readable.
    raise ValueError(
        f"The tokenizer already contains the token {placeholder}. "
        "Please pass a different placeholder that isn't already in "
        "the tokenizer."
    )

token_ids = tokenizer.encode(initializer, add_special_tokens=False)
# Exactly one id is required: more than one means the word is split into several
# tokens, and zero would otherwise fail below with an unhelpful IndexError.
if len(token_ids) != 1:
    raise ValueError(
        "The initializer token must be a single token, "
        "try something shorter."
    )
initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder)

In [None]:
# Reload fresh copies of the three models from disk for training
# (not moved to a device here; a later cell handles device placement).
text_encoder = CLIPTextModel.from_pretrained('./models/text_encoder/')
vae = AutoencoderKL.from_pretrained('./models/vae/')
model = UNet2DConditionModel.from_pretrained('./models/sd_v1-5/')

In [None]:
# Resize the embedding matrix to the tokenizer's new length, which includes the placeholder token.
text_encoder.resize_token_embeddings(len(tokenizer))

# Start the placeholder's embedding as a copy of the initializer token's embedding.
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

In [None]:
def freeze_params(params):
    """Turn off gradient tracking for every parameter in the iterable `params`."""
    for tensor in params:
        tensor.requires_grad = False
        
# The VAE and the UNet are not trained at all.
freeze_params(vae.parameters())
freeze_params(model.parameters())

# In the text encoder, freeze the transformer encoder, the final layer norm and
# the position embeddings; the remaining parameters (presumably just the token
# embedding table - verify against the model definition) stay trainable.
params_to_freeze = itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters()
)
freeze_params(params_to_freeze)

In [None]:
learning_rate = 5e-04 * 4 * 1 * accelerator.num_processes #lr * gradient accumulation steps * training_batch_size * number of processes

optimizer = torch.optim.AdamW(
    text_encoder.get_input_embeddings().parameters(), #Only optimize the embeddings
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08
)

# Training noise scheduler. Only arguments DDPMScheduler actually defines are
# passed: `set_alpha_to_one`, `skip_prk_steps` and `steps_offset` come from the
# PNDM/DDIM scheduler configs and make the constructor raise TypeError here.
noise_scheduler = DDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000,
    clip_sample=False
)

In [None]:
# Caption templates used when learnable_property == "style"; "{}" is replaced
# by the placeholder token.
imagenet_style_templates_small = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]

# Caption templates used for every other learnable_property (e.g. "object").
# TextualInversionDataset references this name, but it was never defined, so
# building the dataset with learnable_property="object" raised NameError.
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

class TextualInversionDataset(Dataset):
    """Image/caption dataset for textual-inversion training.

    Every file in `data_root` is treated as a training image. Each item pairs
    one image (resized to `size` x `size`, randomly flipped, scaled to
    [-1, 1], channels first) with the token ids of a caption built from a
    randomly chosen template and the placeholder token.

    Args:
        data_root: directory whose entries are all used as image paths.
        tokenizer: tokenizer used to encode the caption.
        learnable_property: "style" selects the style templates, anything
            else selects the object templates.
        size: output side length in pixels.
        repeats: how many times the image list is repeated when set == "train".
        interpolation: one of "linear", "bilinear", "bicubic", "lanczos".
        flip_p: probability of a horizontal flip.
        set: "train" multiplies the dataset length by `repeats`.
        placeholder_token: string substituted into the caption template.
        center_crop: crop the largest centered square before resizing.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        learnable_property="object",  # [object, style]
        size=512,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",
        placeholder_token="*",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        # Use the `Image` module this notebook imports (`from PIL import Image`);
        # the bare name `PIL` is never imported. `Image.LINEAR` was an alias of
        # BILINEAR that newer Pillow releases removed, so map "linear" directly.
        self.interpolation = {
            "linear": Image.BILINEAR,
            "bilinear": Image.BILINEAR,
            "bicubic": Image.BICUBIC,
            "lanczos": Image.LANCZOS,
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        # Open inside a context manager so the file handle is released right
        # away; convert() returns a loaded copy that outlives the file.
        with Image.open(self.image_paths[i % self.num_images]) as raw_image:
            image = raw_image.convert("RGB")

        placeholder_string = self.placeholder_token
        text = random.choice(self.templates).format(placeholder_string)

        example["input_ids"] = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Keep the largest centered square.
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]

        image = Image.fromarray(img)
        image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip_transform(image)
        image = np.array(image).astype(np.uint8)
        # Scale 0..255 to -1..1.
        image = (image / 127.5 - 1.0).astype(np.float32)

        # HWC -> CHW
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example

# Build the training set from every file in ./training/, repeated 100 times,
# captioned with the object templates and the placeholder token.
train_dataset = TextualInversionDataset(
        data_root="./training/",
        tokenizer=tokenizer,
        size=512,
        placeholder_token=placeholder,
        repeats=100,
        learnable_property="object",
        center_crop=False,
        set="train",
    )
# One image per batch, reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

In [None]:
# NOTE(review): `overrode_max_train_steps` is assigned here but not read anywhere in the visible notebook.
overrode_max_train_steps = False
# The divisor 4 matches gradient_accumulation_steps=4 on the Accelerator:
# one optimizer update per 4 batches.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 4)

# Constant learning rate, no warmup. 3000 appears to be the target number of
# optimizer updates (the training loop stops at global_step >= 3000).
lr_scheduler = get_scheduler(
    "constant",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=3000 * 4,
)

# Let accelerate wrap the trainable model, optimizer, dataloader and LR scheduler.
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    text_encoder, optimizer, train_dataloader, lr_scheduler
)

# The frozen models are not passed through prepare(); move them to the device by hand.
vae.to(accelerator.device)
model.to(accelerator.device)

vae.eval()
model.eval()

# Recomputed because prepare() may change the dataloader's length.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 4)
num_train_epochs = math.ceil(3000 / num_update_steps_per_epoch)

# Train!
# per-device batch size (1) * number of processes * gradient accumulation steps (4)
total_batch_size = 1 * accelerator.num_processes * 4

In [None]:
# Progress bar counts optimizer updates (3000 total); hidden on non-main processes.
progress_bar = tqdm(range(3000), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0

for epoch in range(num_train_epochs):
    text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(text_encoder):
            
            # Encode the images to latents, sample from the latent distribution,
            # and apply the 0.18215 latent scaling used throughout this notebook.
            latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
            latents = latents * 0.18215

            # Noise to add, and one random training timestep per sample.
            noise = torch.randn(latents.shape).to(latents.device)
            bsz = latents.shape[0]
            
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            ).long()

            # Forward diffusion: noise the latents to the sampled timesteps.
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            noise_pred = model(noisy_latents, timesteps, encoder_hidden_states).sample

            # Per-sample MSE between predicted and true noise, averaged over the batch.
            loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
            accelerator.backward(loss)

            # With several processes the model is wrapped, so reach through `.module`.
            if accelerator.num_processes > 1:
                grads = text_encoder.module.get_input_embeddings().weight.grad
            else:
                grads = text_encoder.get_input_embeddings().weight.grad

            # Zero the gradient rows of every token except the placeholder, so the
            # gradient step only moves the new embedding.
            # NOTE(review): the optimizer is AdamW with weight_decay=1e-2; decoupled
            # weight decay may still shrink the other rows despite zero gradients - confirm.
            index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
            grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            # Checkpoint the learned placeholder embedding every 500 updates.
            if global_step % 500 == 0:
                learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
                learned_embeds_dict = {placeholder: learned_embeds.detach().cpu()}
                torch.save(learned_embeds_dict, os.path.join("./textual_inversion_outputs/", "learned_embeds.bin"))

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)

        # NOTE: this only leaves the inner (batch) loop; the epoch loop itself is
        # bounded by num_train_epochs.
        if global_step >= 3000:
            break

    accelerator.wait_for_everyone()

# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
    pipeline = StableDiffusionPipeline(
        text_encoder=accelerator.unwrap_model(text_encoder),
        vae=vae,
        unet=model,
        tokenizer=tokenizer,
        # PNDMScheduler defines no `clip_sample` argument (it belongs to the
        # DDPM/DDIM schedulers); passing it made the constructor raise TypeError.
        scheduler=PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear', num_train_timesteps=1000, set_alpha_to_one=False, skip_prk_steps=True, steps_offset=1),
        safety_checker=None,
        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
    )
    pipeline.save_pretrained("./textual_inversion_outputs/")

    # Also save the learned placeholder embedding on its own, keyed by the placeholder token.
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
    learned_embeds_dict = {placeholder: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, os.path.join("./textual_inversion_outputs/", "learned_embeds.bin"))

accelerator.end_training()

In [None]:
    """Compare this script to this:
    !git clone https://github.com/justinpinkney/stable-diffusion.git
    %cd stable-diffusion
    !pip install --upgrade pip
    !pip install -r requirements.txt

    !pip install --upgrade keras # on lambda stack we need to upgrade keras
    !pip uninstall -y torchtext # on colab we need to remove torchtext
    
    from datasets import load_dataset
    ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
    sample = ds[0]
    display(sample["image"].resize((256, 256)))
    print(sample["text"])
    
    BATCH_SIZE = 4
    N_GPUS = 1
    ACCUMULATE_BATCHES = 1

    gpu_list = ",".join((str(x) for x in range(N_GPUS))) + ","
    print(f"Using GPUs: {gpu_list}")
    
    !(python main.py \
        -t \
        --base configs/stable-diffusion/pokemon.yaml \
        --gpus "$gpu_list" \
        --scale_lr False \
        --num_nodes 1 \
        --check_val_every_n_epoch 10 \
        --finetune_from "$ckpt_path" \
        data.params.batch_size="$BATCH_SIZE" \
        lightning.trainer.accumulate_grad_batches="$ACCUMULATE_BATCHES" \
        data.params.validation.params.n_gpus="$NUM_GPUS" \
    )
    !(python scripts/txt2img.py \
        --prompt 'robotic cat with wings' \
        --outdir 'outputs/generated_pokemon' \
        --H 512 --W 512 \
        --n_samples 4 \
        --config 'configs/stable-diffusion/pokemon.yaml' \
        --ckpt 'path/to/your/checkpoint')
    """

In [3]:
#Dreambooth Locally Fine-Tune stable diffusion
# New Accelerator for the DreamBooth run: no gradient accumulation, fp16 mixed
# precision. This rebinds the `accelerator` used by the textual-inversion cells.
accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision="fp16",
)

Defaulting to user installation because normal site-packages is not writeable
Collecting pip
  Using cached pip-22.3.1-py3-none-any.whl (2.1 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 22.3
    Uninstalling pip-22.3:
      Successfully uninstalled pip-22.3
Successfully installed pip-22.3.1


In [4]:
# One dict per concept to learn: the prompt and image directory for the
# instance (subject) images, and the prompt and image directory for the
# generic class images used when prior preservation is enabled.
concepts_list = [
    {
        "instance_prompt": "photo of a zwx pokemon",
        "class_prompt": "photo of a pokemon",
        "instance_data_dir": "training/pokemon",
        "class_data_dir": "classes/pokemon"
    }
]

In [None]:
class_image_dir = Path(concepts_list[0]["class_data_dir"])
# iterdir() raises FileNotFoundError on a missing directory, and the generation
# cell saves new class images into it, so create it up front.
class_image_dir.mkdir(parents=True, exist_ok=True)
# Number of entries already present; used as the index offset for new class images.
curr_class_images = len(list(class_image_dir.iterdir()))

In [None]:
##Only run this if you have new images that you're adding to a dataset
# Half precision on CUDA, full precision on any other device.
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
# NOTE(review): "./models/sd_v1-5/" is the directory an earlier (commented-out)
# cell saved only the UNet into; confirm it holds a complete pipeline before
# relying on this load.
pipeline = StableDiffusionPipeline.from_pretrained(
    "./models/sd_v1-5/",
    vae=AutoencoderKL.from_pretrained(
        "./models/vae/"
    ),
    torch_dtype=torch_dtype,
    safety_checker=None
)
pipeline.to(accelerator.device)


class PromptDataset(Dataset):
    """Dataset that serves one fixed prompt `num_samples` times.

    Each item is a dict carrying the prompt and the item's own index, so the
    work of generating class images can be split across dataloader batches.
    """

    def __init__(self, prompt, num_samples):
        self.prompt, self.num_samples = prompt, num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}
    
# num_samples is the number of NEW class images to generate; with 0 the loop below does nothing.
sample_dataset = PromptDataset(concepts_list[0]["class_prompt"], num_samples=0) #Tell it the number of new images
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=1)

sample_dataloader = accelerator.prepare(sample_dataloader)


# Generate class images under autocast and inference mode (no autograd bookkeeping).
with torch.autocast("cuda"), torch.inference_mode():
    for example in tqdm(
        sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
    ):
        images = pipeline(example["prompt"]).images

        for i, image in enumerate(images):
            # File name = running index (offset by the images already on disk)
            # plus a SHA-1 of the pixel data.
            hash_image = hashlib.sha1(image.tobytes()).hexdigest()
            image_filename = class_image_dir / f"{example['index'][i] + curr_class_images}-{hash_image}.jpg"
            image.save(image_filename)
            
# Drop the generation pipeline and release cached GPU memory before training.
del pipeline
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Reload fresh models for DreamBooth. The VAE is frozen; the text encoder and
# the UNet keep their gradients (both are handed to the optimizer later).
vae = AutoencoderKL.from_pretrained('./models/vae/').to("cuda")
vae.requires_grad_(False)
tokenizer = CLIPTokenizer.from_pretrained('./tokenizers/')
text_encoder = CLIPTextModel.from_pretrained('./models/text_encoder/').to("cuda")
model = UNet2DConditionModel.from_pretrained('./models/sd_v1-5/').to("cuda")

In [None]:
# If you want to use 8-bit Adam
try:
    import bitsandbytes as bnb
except ImportError:
    # Keep `bnb` bound so the selection below cannot raise NameError.
    bnb = None
    print("bitsandbytes is not installed; falling back to torch.optim.AdamW")

# Always bind optimizer_class. Previously it was only set when the import
# failed (a path that crashed first on the undefined `bnb`), so a successful
# import left it undefined for the optimizer cell.
optimizer_class = bnb.optim.AdamW8bit if bnb is not None else torch.optim.AdamW

In [None]:
# Train the UNet and the text encoder together.
params_to_optimize = (itertools.chain(model.parameters(), text_encoder.parameters()))

optimizer = optimizer_class(
    params_to_optimize,
    lr=1e-6, #If you're training your text encoder too, else 5e-6
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08
)

# Training noise scheduler. Only arguments DDPMScheduler actually defines are
# passed: `set_alpha_to_one`, `skip_prk_steps` and `steps_offset` come from the
# PNDM/DDIM scheduler configs and make the constructor raise TypeError here.
noise_scheduler = DDPMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000,
    clip_sample=False
)

In [None]:
class DreamBoothDataset(Dataset):
    """
    Pairs instance images (and, with prior preservation, class images) with
    their tokenized prompts for DreamBooth fine-tuning.

    Files are collected from each concept's `instance_data_dir` and
    `class_data_dir`. The dataset length is the larger of the two image
    counts, and indices wrap around the shorter list.
    """

    def __init__(
        self,
        concepts_list,
        tokenizer,
        with_prior_preservation=True,
        size=512,
        center_crop=False,
        num_class_images=None,
        pad_tokens=False,
        hflip=False
    ):
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.pad_tokens = pad_tokens
        self.with_prior_preservation = with_prior_preservation

        # Lists of (image path, prompt) pairs.
        self.instance_images_path = []
        self.class_images_path = []
        for concept in concepts_list:
            self.instance_images_path += [
                (entry, concept["instance_prompt"])
                for entry in Path(concept["instance_data_dir"]).iterdir()
                if entry.is_file()
            ]
            if not with_prior_preservation:
                continue
            class_entries = [
                (entry, concept["class_prompt"])
                for entry in Path(concept["class_data_dir"]).iterdir()
                if entry.is_file()
            ]
            # Keep at most num_class_images per concept (None keeps all).
            self.class_images_path += class_entries[:num_class_images]

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)
        self.num_class_images = len(self.class_images_path)
        self._length = max(self.num_class_images, self.num_instance_images)

        cropper = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        self.image_transforms = transforms.Compose(
            [
                # hflip=False makes the flip probability 0.
                transforms.RandomHorizontalFlip(0.5 * hflip),
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                cropper,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def _load_rgb(self, path):
        # Open an image and convert it to RGB only when it is in another mode.
        picture = Image.open(path)
        return picture if picture.mode == "RGB" else picture.convert("RGB")

    def _tokenize(self, prompt):
        # Token ids for a prompt; padded to the model maximum only when pad_tokens is set.
        return self.tokenizer(
            prompt,
            padding="max_length" if self.pad_tokens else "do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def __getitem__(self, index):
        inst_path, inst_prompt = self.instance_images_path[index % self.num_instance_images]
        example = {
            "instance_images": self.image_transforms(self._load_rgb(inst_path)),
            "instance_prompt_ids": self._tokenize(inst_prompt),
        }

        if self.with_prior_preservation:
            cls_path, cls_prompt = self.class_images_path[index % self.num_class_images]
            example["class_images"] = self.image_transforms(self._load_rgb(cls_path))
            example["class_prompt_ids"] = self._tokenize(cls_prompt)

        return example
    
# Build the DreamBooth training set with prior preservation: at most 50 class
# images per concept, prompts padded to the tokenizer's maximum length, no flips.
train_dataset = DreamBoothDataset(
    concepts_list=concepts_list,
    tokenizer=tokenizer,
    with_prior_preservation=True,
    size=512,
    center_crop=False,
    num_class_images=50,
    pad_tokens=True,
    hflip=False
)

In [None]:
def collate(examples):
    """Merge dataset examples into one batch of token ids and pixel values.

    Instance entries come first. When the examples carry class (prior
    preservation) entries, those are appended after them, so both halves go
    through the model in a single forward pass.

    Args:
        examples: list of dicts as returned by DreamBoothDataset.__getitem__.

    Returns:
        A dict with "input_ids" (padded token ids, as a tensor) and
        "pixel_values" (stacked float image tensors in contiguous memory format).
    """
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Class entries exist only when the dataset was built with prior preservation.
    if examples and "class_images" in examples[0]:
        input_ids += [example["class_prompt_ids"] for example in examples]
        # The images belong here: this line used to append "class_prompt_ids"
        # (token id lists), which torch.stack below cannot combine with image tensors.
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    # Pad the token id lists to a common length and convert them to a tensor.
    input_ids = tokenizer.pad(
        {"input_ids": input_ids},
        padding=True,
        return_tensors="pt",
    ).input_ids

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values
    }
    return batch

# One example per batch; `collate` appends the class image and prompt to the
# instance ones, so each batch carries both.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True, collate_fn=collate, pin_memory=True
)

In [None]:
weight_type = "fp16"
# torch needs a dtype object, not the string label. The string used to be
# passed straight to `.to(dtype=...)`, and the loop below referenced a
# `weight_dtype` that was never defined.
weight_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(weight_type, torch.float32)
vae.to(accelerator.device, dtype=weight_dtype)

# Pre-compute the VAE latent distribution of every batch once, so the VAE is
# not needed during training. The token ids are cached alongside, unencoded.
latents_cache = []
text_encoder_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
    with torch.no_grad():
        batch['pixel_values'] = batch['pixel_values'].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
        batch['input_ids'] = batch['input_ids'].to(accelerator.device, non_blocking=True)
        latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

        text_encoder_cache.append(batch['input_ids'])
        
class LatentsDataset(Dataset):
    """Index-aligned view over two pre-computed caches.

    Item ``i`` is the pair ``(latents_cache[i], text_encoder_cache[i])``.
    """

    def __init__(self, latents_cache, text_encoder_cache):
        """Keep references to both caches; nothing is copied."""
        self.latents_cache, self.text_encoder_cache = latents_cache, text_encoder_cache

    def __len__(self):
        """Number of items, taken from the latents cache alone."""
        return len(self.latents_cache)

    def __getitem__(self, index):
        """Return the ``(latent entry, text entry)`` tuple stored at ``index``."""
        latent_entry = self.latents_cache[index]
        text_entry = self.text_encoder_cache[index]
        return latent_entry, text_entry
    
# Rebind train_dataset / train_dataloader to the cached latents and token ids;
# from here on these names no longer refer to the raw-image dataset.
train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
# collate_fn is the identity, so every batch is a one-element list holding a
# (latent_dist, input_ids) tuple — the training loop reads it as batch[0].
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True
)

# All latents have been cached above, so drop this reference to the VAE and
# ask PyTorch to release its unused cached CUDA memory.
del vae
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Not read by any of the visible cells; kept for parity with the training script.
overrode_max_train_steps = False
# The divisor 1 presumably stands for gradient accumulation steps — TODO confirm.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 1)

# Constant learning rate, no warmup, 800 total training steps.
lr_scheduler = get_scheduler(
    "constant",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=800,
)

In [None]:
# Pass the modules, optimizer, dataloader and LR scheduler through
# accelerator.prepare() and rebind the returned objects.
# NOTE(review): the first argument is `model`, yet the result is bound to `unet`
# and the training loop only uses `unet` — confirm `model` is defined and is the UNet.
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(model, text_encoder, optimizer, train_dataloader, lr_scheduler)

In [None]:
# Recompute steps per epoch from the dataloader returned by accelerator.prepare().
# The divisor 1 presumably stands for gradient accumulation steps — TODO confirm.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / 1)
# 800 is the same step budget given to get_scheduler and checked by the training loop.
num_train_epochs = math.ceil(800 / num_update_steps_per_epoch)

# Presumably batch size * processes * accumulation steps — TODO confirm.
# Not read by any of the visible cells.
total_batch_size = 1 * accelerator.num_processes * 1

In [None]:
def save_weights(step, pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"):
    """Save the fine-tuned pipeline and render a few sample images.

    Only the main process does any work; every other process returns
    immediately without touching the disk or printing.

    Args:
        step: current global training step. Accepted for the callers'
            convenience; the output directory does not depend on it.
        pretrained_model_name_or_path: base pipeline passed to
            ``StableDiffusionPipeline.from_pretrained``, which supplies the
            components not overridden here. Defaults to the model id loaded
            at the top of this notebook.

    Side effects:
        Writes the pipeline to ``./dreambooth_outputs/`` and four PNG samples
        to ``./dreambooth_outputs/samples/``, then prints the save location.
    """
    if not accelerator.is_main_process:
        return

    save_dir = "./dreambooth_outputs/"

    text_enc_model = accelerator.unwrap_model(text_encoder)
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    # from_pretrained requires the base model as its first positional argument;
    # the keyword arguments replace individual components with the trained ones.
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet),
        text_encoder=text_enc_model,
        vae=AutoencoderKL.from_pretrained(
            './models/vae/'
        ),
        safety_checker=None,
        scheduler=scheduler,
        torch_dtype=torch.float16,
    )
    pipeline.save_pretrained(save_dir)

    pipeline = pipeline.to(accelerator.device)
    g_cuda = torch.Generator(device=accelerator.device).manual_seed(8855) #This is arbitrary, I just like this one.
    pipeline.set_progress_bar_config(disable=True)
    sample_dir = os.path.join(save_dir, "samples")
    os.makedirs(sample_dir, exist_ok=True)
    with torch.autocast("cuda"), torch.inference_mode():
        for i in tqdm(range(4)): #This number is how many samples of the new model to save
            images = pipeline(
                "photo of a zwx pokemon",
                negative_prompt="", #Add negative prompts
                guidance_scale=7.5, #default
                num_inference_steps=50, #default
                generator=g_cuda
            ).images
            images[0].save(os.path.join(sample_dir, f"{i}.png"))

    # Drop the sampling pipeline and release unused cached CUDA memory.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Printed only by the process that saved, where save_dir is defined.
    print(f"[*] Weights saved at {save_dir}")
    
class AverageMeter:
    """Accumulates weighted values and exposes their running mean as ``avg``."""

    def __init__(self, name=None):
        """Store the optional label and start from an empty state."""
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running total, the weight count and the average."""
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the running average."""
        weighted_value = val * n
        self.sum += weighted_value
        self.count += n
        self.avg = self.sum / self.count
    
# State shared with the training loop: step counter, running loss and the
# context manager wrapped around the text-encoder forward pass.
global_step = 0
loss_avg = AverageMeter()
text_enc_context = nullcontext()

# One tick per training step (800 in total); hidden on non-main local processes.
progress_bar = tqdm(range(800), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Train_steps")

In [None]:
# DreamBooth training loop with prior preservation: runs until 800 global steps
# have been taken, then saves the weights and ends training.
for epoch in range(num_train_epochs):
    unet.train()
    text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            with torch.no_grad():
                # The dataloader's identity collate_fn yields a one-element list,
                # so batch[0] is the cached (latent_dist, input_ids) tuple.
                latent_dist = batch[0][0]
                # 0.18215 is presumably the Stable Diffusion VAE latent scaling factor — TODO confirm.
                latents = latent_dist.sample() * 0.18215
                
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            
            # Draw one random timestep per latent in the batch.
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            
            with text_enc_context:
                encoder_hidden_states = text_encoder(batch[0][1])[0]
                
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            
            # Split along dim 0 into two halves; this relies on collate() placing
            # the instance items first and the class (prior) items second.
            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
            noise, noise_prior = torch.chunk(noise, 2, dim=0)
            
            # Instance loss: mean over dims 1-3 per sample, then mean over samples.
            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1,2,3]).mean()
            
            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
            
            # The prior-preservation term is added with weight 1.0.
            loss = loss + 1.0*prior_loss
            
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            # Running loss average, weighted by bsz (latents.shape[0]).
            loss_avg.update(loss.detach(), bsz)
            
        # Refresh the progress-bar postfix every 10 global steps (including step 0).
        if not global_step % 10:
            logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            
        # NOTE(review): global_step is tested before it is incremented and the loop
        # breaks once it reaches 800, so this branch appears unreachable here;
        # the save after the loop is the one that runs — confirm.
        if global_step > 0 and not global_step % 800:
            save_weights(global_step)
            
        progress_bar.update(1)
        global_step += 1
        
        # Stop the inner loop once the 800-step budget is used up.
        if global_step >= 800:
            break
        
    # Called once per epoch, after the inner loop finishes or breaks.
    accelerator.wait_for_everyone()
    
# Final save after training, followed by Accelerate's end-of-training hook.
save_weights(global_step)

accelerator.end_training()