In [None]:
# install
! pip install accelerate diffusers pandas torchvision transformers wandb

In [None]:
# login to wandb
# Runs the `wandb login` shell command. The training cells below report to
# Weights & Biases through `Accelerator(log_with="wandb")`.
! wandb login

In [None]:
# imports
import os
import shutil

from PIL import Image
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
import PIL
import numpy as np
import pandas as pd
import safetensors
import safetensors.torch
import torch
import torch.nn.functional as F
import wandb

In [None]:
# configs

# --- reproducibility and base model ---------------------------------------
seed = 42
model_id = "runwayml/stable-diffusion-v1-5"

# --- concepts to learn -----------------------------------------------------
# The three lists are consumed in parallel: trigger_tokens[i] is initialised
# from init_tokens[i] and paired with the captions/images under data_dirs[i].
trigger_tokens = ["<stained-glass>", "<abin-thomas>"]
init_tokens = ["painting", "person"]
data_dirs = ["", ""]  # each directory must hold `captions.csv` and `images/`

# --- optimisation schedule ---------------------------------------------------
batch_size = 1
num_steps = 1000
dataset_repeat = 100  # greater than (num_steps*batch_size)/num_images

# --- reporting / checkpointing (all frequencies are in training steps) -------
print_freq = 10
val_freq = 100
checkpoint_freq = 100
keep_n_checkpoints = 5

# --- validation --------------------------------------------------------------
validation_prompts = [
    "A <stained-glass> portrait of <abin-thomas>",
    "<abin-thomas> sitting on a chair",
    "A <stained-glass> of a puppy",
]
num_val_per_prompt = 4

# --- outputs -----------------------------------------------------------------
# Checkpoints and embeddings are written below
# <output_dir>/<wandb_project>/<wandb run name>/.
output_dir = ""
wandb_project = "stained-glass-abin-thomas"

In [None]:
# init and set seed
# The Accelerator is created with the "wandb" logging integration; `set_seed`
# is then called with the `seed` value from the config cell.
accelerator = Accelerator(log_with="wandb")
set_seed(seed)

In [None]:
# load pipeline components
def _load_component(component_cls, subfolder):
    """Load one pipeline component of `model_id` from the given sub-folder."""
    return component_cls.from_pretrained(model_id, subfolder=subfolder)


# Loaded in the same order as before: tokenizer, scheduler, text encoder, VAE, UNet.
tokenizer = _load_component(CLIPTokenizer, "tokenizer")
noise_scheduler = _load_component(DDPMScheduler, "scheduler")
text_encoder = _load_component(CLIPTextModel, "text_encoder")
vae = _load_component(AutoencoderKL, "vae")
unet = _load_component(UNet2DConditionModel, "unet")

In [None]:
# update tokenizer
# Register every trigger token as a new vocabulary entry and seed its embedding
# with a copy of the embedding of the paired initializer token.
assert len(init_tokens) == len(trigger_tokens), "Each trigger token needs exactly one initializer token."

trigger_token_ids = []
init_token_ids = []
for init_token, trigger_token in zip(init_tokens, trigger_tokens):
    # Take the initializer id from the tokenizer's own encoding of the word.
    # `convert_tokens_to_ids(init_token)` is a raw vocabulary lookup of the
    # string, which is not guaranteed to be the entry that `encode` produces
    # for that word (and falls back to the unknown token when absent), so the
    # id validated by the single-token check is the one that must be used.
    init_token_encoding = tokenizer.encode(init_token, add_special_tokens=False)
    assert len(init_token_encoding) == 1, "Initializer token must be a single token."
    init_token_id = init_token_encoding[0]

    new_token_count = tokenizer.add_tokens([trigger_token])
    assert new_token_count == 1, "Placeholder token must be a new token."
    trigger_token_id = tokenizer.convert_tokens_to_ids(trigger_token)

    # Grow the embedding table to cover the new token, then copy the
    # initializer row into the freshly added row.
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data
    with torch.no_grad():
        token_embeds[trigger_token_id] = token_embeds[init_token_id].clone()

    trigger_token_ids.append(trigger_token_id)
    init_token_ids.append(init_token_id)


In [None]:
# load dataset
class ImageCaptionDataset(Dataset):
    """Image/caption pairs for one or more trigger tokens.

    Each entry of ``data_dirs`` is paired with the entry of ``trigger_tokens``
    at the same position. A data directory must contain:

    * ``captions.csv`` - read without a header row; column 0 is an image file
      name and column 1 a caption template. Templates are rendered with
      ``str.format`` using the keyword arguments ``adjective`` and ``trigger``.
    * ``images/`` - the image files named in column 0.

    Args:
        data_dirs: sequence of data directory paths.
        tokenizer: callable used to tokenize the rendered caption; it must
            expose ``model_max_length`` and return an object with ``input_ids``.
        trigger_tokens: sequence of trigger tokens, one per data directory.
        repeat: how many times the set of images is repeated in ``__len__``.

    Raises:
        ValueError: if ``data_dirs`` and ``trigger_tokens`` differ in length.
    """

    _ADJECTIVES = ["", "good", "cropped", "clean", "bright", "cool", "nice", "small", "large", "dark", "weird"]

    def __init__(self, data_dirs, tokenizer, trigger_tokens, repeat=100):
        if len(data_dirs) != len(trigger_tokens):
            raise ValueError(
                f"Expected one trigger token per data directory, got "
                f"{len(data_dirs)} directories and {len(trigger_tokens)} trigger tokens."
            )

        self.data_dirs = data_dirs
        self.tokenizer = tokenizer
        self.repeat = repeat

        # Accumulate the rows of *every* data directory. (Previously the
        # accumulator was overwritten on each iteration and then extended with
        # itself, so only the last directory's rows were kept.)
        filepath_caption_trigger_tuples = []
        for data_dir, trigger_token in zip(data_dirs, trigger_tokens):
            df = pd.read_csv(os.path.join(data_dir, 'captions.csv'), header=None)
            filepath_caption_trigger_tuples.extend(
                (os.path.join(data_dir, 'images', filename), caption.replace('"', ''), trigger_token)
                for filename, caption in zip(df[0], df[1])
            )
        np.random.shuffle(filepath_caption_trigger_tuples)
        self.filepath_caption_trigger_map = {
            filepath: (caption, trigger) for filepath, caption, trigger in filepath_caption_trigger_tuples
        }
        # Positional view of the map keys so item lookup does not rebuild a
        # list of all keys on every access.
        self._filepaths = list(self.filepath_caption_trigger_map)

    def __len__(self):
        """Number of distinct images multiplied by ``repeat``."""
        return len(self.filepath_caption_trigger_map) * self.repeat

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "input_ids"}`` for the image at ``idx`` (modulo the image count).

        ``pixel_values`` is a float32 tensor in channels-first layout built from
        the RGB image resized to 512x512, randomly mirrored, and rescaled from
        [0, 255] to [-1, 1]. ``input_ids`` is the first row of the tokenizer
        output for the rendered caption, padded/truncated to
        ``tokenizer.model_max_length``.
        """
        filepath = self._filepaths[idx % len(self._filepaths)]
        caption, trigger_token = self.filepath_caption_trigger_map[filepath]
        caption = caption.format(adjective=np.random.choice(self._ADJECTIVES), trigger=trigger_token)

        # `convert` returns a new, fully loaded image, so the file handle can
        # be released as soon as the block exits.
        with Image.open(filepath) as source_image:
            image = source_image.convert("RGB")
        image = image.resize((512, 512), resample=PIL.Image.Resampling.BICUBIC)
        image = transforms.RandomHorizontalFlip(p=0.5)(image)
        pixels = (np.array(image).astype(np.uint8) / 127.5 - 1.0).astype(np.float32)

        example = {}
        example["pixel_values"] = torch.from_numpy(pixels).permute(2, 0, 1)
        example["input_ids"] = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        return example
    
# Build the dataset from the configured directories/tokens and wrap it in a
# shuffling DataLoader.
dataset = ImageCaptionDataset(
    data_dirs=data_dirs,
    tokenizer=tokenizer,
    trigger_tokens=trigger_tokens,
    repeat=dataset_repeat,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# prepare for training
# Disable gradients for the VAE, the UNet and every part of the text encoder
# listed below; the optimizer only receives the input-embedding parameters.
for _frozen_module in (
    vae,
    unet,
    text_encoder.text_model.encoder,
    text_encoder.text_model.final_layer_norm,
    text_encoder.text_model.embeddings.position_embedding,
):
    _frozen_module.requires_grad_(False)

_embedding_params = text_encoder.get_input_embeddings().parameters()
optimizer = torch.optim.AdamW(_embedding_params)
lr_scheduler = get_scheduler("constant", optimizer=optimizer)

text_encoder.train()

(
    text_encoder,
    optimizer,
    dataloader,
    lr_scheduler,
) = accelerator.prepare(text_encoder, optimizer, dataloader, lr_scheduler)

# Snapshot of the full embedding table; the training loop copies rows from it
# back into the model after each optimizer step.
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

unet.to(accelerator.device)
vae.to(accelerator.device)

global_step = 0

_tracker_config = {
    'seed': seed,
    'model_id': model_id,
    'trigger_tokens': ', '.join(trigger_tokens),
    'init_token': ', '.join(init_tokens),
    'batch_size': batch_size,
    'num_steps': num_steps,
    'print_freq': print_freq,
    'val_freq': val_freq,
    'checkpoint_freq': checkpoint_freq,
    'keep_n_checkpoints': keep_n_checkpoints,
    'dataset_repeat': dataset_repeat,
}
accelerator.init_trackers(wandb_project, config=_tracker_config)

# Output folders live under <output_dir>/<wandb_project>/<wandb run name>/.
_run_dir = os.path.join(output_dir, wandb_project, wandb.run.name)

checkpoint_dir = os.path.join(_run_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

save_dir = os.path.join(_run_dir, 'embeddings')
os.makedirs(save_dir, exist_ok=True)

In [None]:
# maybe load recent checkpoint
# The training loop writes checkpoints to "<checkpoint_dir>/checkpoint-<step>".
# Pick the one with the highest step, if any, and resume from it.
# NOTE(review): `checkpoint_dir` is derived from the current W&B run name, so a
# checkpoint is only found when that directory already holds earlier saves.
dirs = [
    d for d in os.listdir(checkpoint_dir)
    # Require a numeric step so unrelated entries cannot break the sort below.
    if d.startswith("checkpoint-") and d.split("-")[1].isdigit()
]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None

if path is not None:
    # Load from the selected checkpoint folder itself, not from its parent
    # directory (the state was saved to `checkpoint_dir/checkpoint-<step>`).
    accelerator.load_state(os.path.join(checkpoint_dir, path))
    global_step = int(path.split("-")[1])
    print(f"Loaded checkpoint from step {global_step}.")

In [None]:
# training loop
# Boolean mask over the vocabulary: True for every embedding row that must keep
# its original value, False for the trigger-token rows being learned. It does
# not change during training, so it is built once instead of on every step.
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
for trigger_token_id in trigger_token_ids:
    index_no_updates[trigger_token_id] = False

data_iter = iter(dataloader)
while global_step < num_steps:

    # Restart the dataloader when it is exhausted so training is bounded by
    # `num_steps` rather than by the dataset length.
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(dataloader)
        batch = next(data_iter)

    # Encode the images to latents, add noise at a random timestep per sample,
    # and let the UNet predict that noise conditioned on the caption encoding.
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * vae.config.scaling_factor
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
    timesteps = timesteps.long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
    target = noise
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    accelerator.backward(loss)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # The optimizer updates the whole embedding table; restore every row except
    # the trigger tokens from the snapshot taken before training.
    with torch.no_grad():
        accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
            index_no_updates
        ] = orig_embeds_params[index_no_updates]

    global_step += 1

    # Read the scalar loss once and reuse it for logging and printing.
    loss_value = loss.detach().item()
    accelerator.log({"loss": loss_value, "lr": lr_scheduler.get_last_lr()[0]}, step=global_step)

    if global_step % print_freq == 0:
        print(f"Step: {global_step} - Loss: {loss_value}")

    if global_step % checkpoint_freq == 0:
        print(f"Saving checkpoint at step {global_step}.")
        checkpoints = os.listdir(checkpoint_dir)
        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
        # Remove the oldest checkpoints so that, counting the one about to be
        # written, at most `keep_n_checkpoints` remain on disk.
        if len(checkpoints) >= keep_n_checkpoints:
            num_to_remove = len(checkpoints) - keep_n_checkpoints + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(checkpoint_dir, removing_checkpoint)
                shutil.rmtree(removing_checkpoint)
        save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)

    if global_step % val_freq == 0:
        print(f"Evaluating at step {global_step}.")
        # Assemble an inference pipeline around the current text encoder and
        # the already-loaded tokenizer, UNet and VAE.
        pipeline = DiffusionPipeline.from_pretrained(
            model_id,
            text_encoder=accelerator.unwrap_model(text_encoder),
            tokenizer=tokenizer,
            unet=unet,
            vae=vae,
            safety_checker=None,
        )
        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
        pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        # Re-seed the generator at every validation run so the sample sequence
        # starts from the same generator state each time.
        generator = torch.Generator(device=accelerator.device).manual_seed(seed)
        images = []
        for prompt in validation_prompts:
            for i in range(num_val_per_prompt):
                image = pipeline(prompt, num_inference_steps=25, generator=generator).images[0]
                images.append((prompt, i, image))
        tracker = accelerator.trackers[0]
        tracker.log({"validation": [wandb.Image(image, caption=f"{i}-{prompt}") for prompt, i, image in images]})
        del pipeline
        torch.cuda.empty_cache()

In [None]:
# save learned embeddings
# Write one tensor per trigger token (a single row of the input-embedding
# table, kept 2-D via the one-element slice) to a safetensors file keyed by the
# trigger token string.
weight_name = "learned_embeds.safetensors"
save_path = os.path.join(save_dir, weight_name)

embedding_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
learned_embeds_dict = {
    # `.clone()` makes each saved tensor own its memory instead of being a
    # view into the full embedding matrix (which is what the slice is when the
    # weights already live on the CPU).
    trigger_token: embedding_weight[trigger_token_id:(trigger_token_id + 1)].detach().cpu().clone()
    for trigger_token_id, trigger_token in zip(trigger_token_ids, trigger_tokens)
}
safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})

accelerator.end_training()