# The notebook for training the text to image model

## Package Preparation

### Import packages

In [1]:
# !pip install -q datasets
# !pip install -q transformers
# !pip install -q accelerate
# !pip install -q git+https://github.com/huggingface/diffusers

In [2]:
import logging
import math
import os
import random
import glob
from pathlib import Path

import datasets
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset, Dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

### Check diffuser version & Save model card

In [3]:
check_min_version("0.16.0.dev0")

logger = get_logger(__name__, log_level="INFO")

In [4]:
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
    img_str = ""
    for i, image in enumerate(images):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
    """
    model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)

## Set Basic Arguments

### Saving Directory

In [5]:
#@markdown If model weights should be saved directly in google drive (takes around 4-5 GB).
save_to_gdrive = False #@param {type:"boolean"}
if save_to_gdrive:
    from google.colab import drive
    drive.mount('/content/drive')

#@markdown Name/Path of the initial model.
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" #@param {type:"string"}

#@markdown Enter the directory name to save model at.

# output_dir = "ml_stable_diffusion_weights/lora" #@param {type:"string"}
# if save_to_gdrive:
#     output_dir = "/content/drive/MyDrive/" + output_dir
# else:
#     output_dir = "/content/" + output_dir

# print(f"[*] Weights will be saved at {output_dir}")

# !mkdir -p $output_dir
output_dir = "lora_output"

### Configure Accelerator

In [6]:
logging_dir = os.path.join(output_dir, "logs")
accelerator_project_config = ProjectConfiguration(total_limit=None)

accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
        log_with="tensorboard",
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )



### Handle Repository Creation

In [7]:
if accelerator.is_main_process:
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)

### Load scheduler, tokenizer, models

In [8]:
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", revision=None
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=None
)
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=None)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=None
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)

text_encoder.requires_grad_(False)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

In [9]:
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16

### Move unet, vae, text_encoder to device

In [10]:
print(accelerator.device)

cuda


In [11]:
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

## Start adding LoRA weights to attention layers

    # It's important to realize here how many attention weights will be added and of which sizes
    # The sizes of the attention layers consist only of two different variables:
    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

    # Let's first see how many attention processors we will have to set.
    # For Stable Diffusion, it should be equal to:
    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
    # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
    # => 32 layers

### Set correct lora layers

In [12]:
lora_attn_procs = {}
for name in unet.attn_processors.keys():
  cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
  # print(name)
  if name.startswith("mid_block"):
    # print(unet.config.block_out_channels)
    hidden_size = unet.config.block_out_channels[-1]
  elif name.startswith("up_blocks"):
    block_id = int(name[len("up_blocks.")])
    hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    # print(hidden_size)
  elif name.startswith("down_blocks"):
    block_id = int(name[len("down_blocks.")])
    hidden_size = unet.config.block_out_channels[block_id]
    # print(hidden_size)

  lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)


### Initalize optimizers

In [13]:
#@markdown Parameters for adamW

optimizer_cls = torch.optim.AdamW
learning_rate = 1e-4 #@param {type:"number"}
adam_beta1 = 0.9 #@param {type:"number"}
adam_beta2 = 0.999 #@param {type:"number"}
adam_weight_decay = 1e-2 #@param {type:"number"}
adam_epsilon = 1e-08 #@param {type:"number"}

In [14]:
optimizer = optimizer_cls(
    lora_layers.parameters(),
    lr=learning_rate,
    betas=(adam_beta1,adam_beta2),
    weight_decay=adam_weight_decay,
    eps=adam_epsilon,
)

### Load Quickdraw Dataset

#### Read the class name

In [15]:
# !wget 'https://raw.githubusercontent.com/zaidalyafeai/zaidalyafeai.github.io/master/sketcher/mini_classes.txt'

--2023-05-02 06:49:20--  https://raw.githubusercontent.com/zaidalyafeai/zaidalyafeai.github.io/master/sketcher/mini_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 760 [text/plain]
Saving to: 'mini_classes.txt.1'


2023-05-02 06:49:20 (50.1 MB/s) - 'mini_classes.txt.1' saved [760/760]



In [16]:
f = open("mini_classes.txt","r")
# And for reading use
classes = f.readlines()
f.close()

In [17]:
classes = [c.replace('\n','').replace(' ','_') for c in classes]
print(len(classes))

100


Download Data

In [18]:
# !mkdir data

In [19]:
import urllib.request
from tqdm.auto import tqdm
def download():
    base = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'
    for c in tqdm(classes):        
        cls_url = c.replace('_', '%20')
        path = base+cls_url+'.npy'
        # print(path)
        urllib.request.urlretrieve(path, 'data/'+c+'.npy')

In [20]:
# download()

load the data

In [21]:
# def load_data(root, vfold_ratio=0.2, max_items_per_class= 4000 ):
#     all_files = glob.glob(os.path.join(root, '*.npy'))

#     #initialize variables 
#     x = np.empty([0, 784])
#     y = np.empty([0])
#     class_names = []

#     #load each data file 
#     for idx, file in enumerate(all_files):
#         data = np.load(file)
#         data = data[0: max_items_per_class, :]
#         labels = np.full(data.shape[0], idx)

#         x = np.concatenate((x, data), axis=0)
#         y = np.append(y, labels)

#         class_name, ext = os.path.splitext(os.path.basename(file))
#         class_names.append(class_name)

#     data = None
#     labels = None
    
#     #randomize the dataset 
#     permutation = np.random.permutation(y.shape[0])
#     x = x[permutation, :]
#     y = y[permutation]

#     #separate into training and testing 
#     vfold_size = int(x.shape[0]/100*(vfold_ratio*100))

#     x_test = x[0:vfold_size, :]
#     y_test = y[0:vfold_size]

#     x_train = x[vfold_size:x.shape[0], :]
#     y_train = y[vfold_size:y.shape[0]]
#     return x_train, y_train, x_test, y_test, class_names

In [22]:
# x_train, y_train, x_test, y_test, class_names = load_data('data')
# num_classes = len(class_names)
# image_size = 28

In [23]:
# print(len(x_train))

In [24]:
# import matplotlib.pyplot as plt
# from random import randint
# %matplotlib inline  
# idx = randint(0, len(x_train))
# plt.imshow(x_train[idx].reshape(28,28)) 
# print(class_names[int(y_train[idx].item())])

In [25]:
# print(len(classes))
# print(len(y_train))

In [26]:
def load_data_for_diffusion(root, max_items_per_class= 4000 ):
    all_files = glob.glob(os.path.join(root, '*.npy'))

    #initialize variables
    imgs = np.empty([0, 784])
    labels = []

    for idx, file in enumerate(all_files):
      data = np.load(file)
      data = data[0: max_items_per_class, :]

      class_name, ext = os.path.splitext(os.path.basename(file))
      labels.extend(["a scribble of" + class_name for i in range(data.shape[0])])

      imgs = np.concatenate((imgs, data), axis=0)


    return imgs, labels
    

In [27]:
imgs, labels = load_data_for_diffusion('data')

In [28]:
print(imgs.shape)
print(len(labels))

(400000, 784)
400000


In [29]:
labeled_imgs = pd.DataFrame(imgs.T, columns=labels)
labeled_imgs

Unnamed: 0,a scribble oflollipop,a scribble oflollipop.1,a scribble oflollipop.2,a scribble oflollipop.3,a scribble oflollipop.4,a scribble oflollipop.5,a scribble oflollipop.6,a scribble oflollipop.7,a scribble oflollipop.8,a scribble oflollipop.9,...,a scribble ofbeard,a scribble ofbeard.1,a scribble ofbeard.2,a scribble ofbeard.3,a scribble ofbeard.4,a scribble ofbeard.5,a scribble ofbeard.6,a scribble ofbeard.7,a scribble ofbeard.8,a scribble ofbeard.9
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
779,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
780,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
781,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
782,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


#### Tokenize labels

In [30]:
def tokenize_captions(labels):
    captions = []
    for caption in labels:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
            
    inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
    return inputs.input_ids

In [31]:
input_ids = tokenize_captions(labels=labels)
input_ids

tensor([[49406,   320,  2139,  ..., 49407, 49407, 49407],
        [49406,   320,  2139,  ..., 49407, 49407, 49407],
        [49406,   320,  2139,  ..., 49407, 49407, 49407],
        ...,
        [49406,   320,  2139,  ..., 49407, 49407, 49407],
        [49406,   320,  2139,  ..., 49407, 49407, 49407],
        [49406,   320,  2139,  ..., 49407, 49407, 49407]])

In [34]:
input_ids.shape

torch.Size([400000, 77])

#### Preprocess Images

In [39]:
from PIL import Image

imgs_RGB = [Image.fromarray(img.reshape(28,28)).convert("RGB") for img in imgs]


In [43]:
train_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]
)

In [46]:
pixel_values = [train_transforms(img) for img in imgs_RGB]

#### DataLoaders Creation

In [49]:
def collate_fn(data_dic):
    pixel_values = torch.stack([pixel for pixel in data_dic['pixel_values']])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([pixel for pixel in data_dic['input_ids']])
    return {"pixel_values": pixel_values, "input_ids": input_ids}
    

In [62]:
train_batch_size=16
train_dataset_dic = {"input_ids": input_ids, "pixel_values": pixel_values}
train_dataloader = torch.utils.data.DataLoader(
    train_dataset_dic,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_batch_size,
    num_workers=0,
)

### Training Preparations

#### Scheduler and math around the number of training steps

In [63]:
gradient_accumulation_steps = 1
num_train_epochs = 100
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch
lr_warmup_steps = 500

In [64]:
lr_scheduler = "constant"
lr_scheduler = get_scheduler(
    lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
    num_training_steps=max_train_steps * gradient_accumulation_steps,
)

#### prepare everything with accelerator

In [65]:
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

#### Recalculate our total training steps as the size of the training dataloader may have changed

In [66]:
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
max_train_steps = num_train_epochs * num_update_steps_per_epoch

### Train

In [68]:
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(labels)}")
logger.info(f"  Num Epochs = {num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {max_train_steps}")

global_step = 0
first_epoch = 0

In [71]:
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")

checkpointing_steps = 500

for epoch in range(first_epoch, num_train_epochs):
    unet.train()
    train_loss = 0.0
    for step, batch in enumverate(train_dataloader):
        with accelerator.accumulate(unet):
            # convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            
            # sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            
            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            
            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]
            
            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                
            
            # Predict the noise residual and compute loss
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            
            # Gather the losses across all processes
            avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps
            
            # Backpropagate
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = lora_layers.parameters()
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
            # update procress bar
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss":  train_loss}, step=global_step)
                train_loss = 0.0
                
                if global_step % checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")
                        
            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            
            if global_step >= args.max_train_steps:
                break
        
        if accelerator.is_main_process:
            validation_prompt = "a scribble of horse"

  0%|          | 0/100 [00:00<?, ?it/s]