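# Textual Inversion

Teach Stable Diffusion a new concept: learn an embedding for a placeholder token (here `<kali-dog>`) from a small folder of images, save it, and then use the token in prompts at inference time.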

# Common

In [None]:
import itertools
import math
import os
import random

import jsonargparse
import numpy as np
import torch
import torch.nn.functional as F
import yaml
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
    get_scheduler,
)
from PIL import Image
from torch import autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
torch_device = "cpu"
if torch.cuda.is_available():
    torch_device = "cuda"
print(f"Device:{torch_device}")

In [None]:
# Configuration.
# The YAML text below is parsed with ``yaml.safe_load`` and turned into a nested
# namespace, so settings are read as attributes (e.g. ``config.data.size``).
# ``optimization.learning_rate`` is the base rate; the training setup multiplies
# it by ``batch_size * batch_accum``. The ``{}`` in ``embedding_save_path`` is
# filled with a step index when the learned embedding is saved.
_config_yaml = """
data:
 image_dir: ../../images/kali 
 size: 512
 placeholder_token: <kali-dog>
 initializer_token: dog
 learnable_property: object
optimization:
 learning_rate: 0.000125
 batch_size: 4
 batch_accum: 4
 num_train_steps: 5000
architecture:
 model_name: runwayml/stable-diffusion-v1-5
 scheduler: CompVis/stable-diffusion-v1-4
 train_scheduler_type: DDPMScheduler
 inference_scheduler_type: LMSDiscreteScheduler
embedding_save_path: kali_saved_embedding_{}.bin
"""
_config_dict = yaml.safe_load(_config_yaml)
config = jsonargparse.dict_to_namespace(_config_dict)

In [None]:
def add_token(placeholder_token, text_encoder, tokenizer, embed):
    """Register ``placeholder_token`` with the tokenizer and give it an embedding.

    The tokenizer vocabulary is extended with the new token, the text encoder's
    input-embedding matrix is resized to ``len(tokenizer)`` rows, and the new
    row is initialised from ``embed``.

    Args:
        placeholder_token: The new token string to add (e.g. ``"<kali-dog>"``).
        text_encoder: Text encoder whose input embeddings are resized and written.
        tokenizer: Tokenizer the token is added to.
        embed: Either an ``int`` token id whose existing embedding row is copied
            into the new row, or a ``torch.Tensor`` holding the embedding itself.

    Returns:
        The id assigned to ``placeholder_token``.

    Raises:
        TypeError: If ``embed`` is neither an ``int`` nor a ``torch.Tensor``
            (checked before the tokenizer is modified).
        ValueError: If the tokenizer already contains ``placeholder_token``.
    """
    # Validate up front: otherwise the new row would silently keep whatever
    # values ``resize_token_embeddings`` initialised it with.
    if not isinstance(embed, (int, torch.Tensor)):
        raise TypeError(
            f"`embed` must be an int token id or a torch.Tensor, got {type(embed).__name__}."
        )

    # Use the argument (not a global config value): at inference the token comes
    # from the saved embedding file.
    num_added_tokens = tokenizer.add_tokens(placeholder_token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {placeholder_token}. Please pass a different "
            "`placeholder_token` that is not already in the tokenizer."
        )

    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data

    if isinstance(embed, int):
        # Copy the embedding of an existing token (the initializer token).
        token_embeds[placeholder_token_id] = token_embeds[embed]
    else:
        # A previously learned embedding vector.
        token_embeds[placeholder_token_id] = embed

    return placeholder_token_id


def get_models(config, scheduler_type):
    """Load the Stable Diffusion components used for training and inference.

    Args:
        config: Config namespace; reads ``config.architecture.model_name`` (repo
            for the VAE, tokenizer, text encoder and UNet) and
            ``config.architecture.scheduler`` (repo for the scheduler config).
        scheduler_type: ``"DDPMScheduler"`` or ``"LMSDiscreteScheduler"``.

    Returns:
        Tuple ``(vae, tokenizer, text_encoder, unet, scheduler)``.

    Raises:
        ValueError: If ``scheduler_type`` is not a supported scheduler name.
    """
    # Resolve the scheduler class first so a typo fails immediately instead of
    # after the (slow) model downloads, and never leaves ``scheduler`` unbound.
    if scheduler_type == "DDPMScheduler":
        scheduler_cls = DDPMScheduler
    elif scheduler_type == "LMSDiscreteScheduler":
        scheduler_cls = LMSDiscreteScheduler
    else:
        raise ValueError(
            f"Unsupported scheduler_type {scheduler_type!r}; "
            "expected 'DDPMScheduler' or 'LMSDiscreteScheduler'."
        )

    model_name = config.architecture.model_name
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")
    scheduler = scheduler_cls.from_config(
        config.architecture.scheduler, subfolder="scheduler"
    )

    return vae, tokenizer, text_encoder, unet, scheduler


def freeze_params(params):
    """Turn off gradient tracking for every parameter in ``params``.

    Args:
        params: Any iterable of parameters/tensors (e.g. ``module.parameters()``);
            it is consumed once.
    """
    for tensor in params:
        tensor.requires_grad_(False)


def generate_image(
    prompt,
    neg_prompt,
    vae,
    tokenizer,
    text_encoder,
    unet,
    scheduler,
    config,
    num_inference_steps=30,
    guidance_scale=8,
    seed=32,
):
    """Generate a single image for ``prompt`` using classifier-free guidance.

    Args:
        prompt: Text prompt (one string).
        neg_prompt: Prompt for the unconditional branch; ``""`` gives plain
            classifier-free guidance.
        vae, tokenizer, text_encoder, unet, scheduler: Stable Diffusion
            components, already moved to ``torch_device``.
        config: Config namespace; only ``config.data.size`` is read (output
            height and width in pixels).
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance weight.
        seed: Seed for the initial latent noise.

    Returns:
        The generated image as a ``PIL.Image.Image``.
    """
    batch_size = 1
    height = width = config.data.size
    # Latent scale factor; the training loop multiplies VAE latents by the same value.
    vae_scaling_factor = 0.18215

    with torch.no_grad():
        # Conditional and unconditional text embeddings. Both are padded AND
        # truncated to the same length so they can be concatenated below.
        text_input = tokenizer(
            [prompt],
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [neg_prompt] * batch_size,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        scheduler.set_timesteps(num_inference_steps)

        # A private CPU generator gives reproducible noise for ``seed`` without
        # reseeding PyTorch's global RNG.
        generator = torch.Generator().manual_seed(seed)
        latents = torch.randn(
            (batch_size, unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        ).to(torch_device)
        latents = latents * scheduler.init_noise_sigma

        with autocast("cuda"):
            for t in tqdm(scheduler.timesteps):
                # Duplicate the latents so one UNet call serves both guidance branches.
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                noise_pred = unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings
                ).sample

                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

                # x_t -> x_{t-1}
                latents = scheduler.step(noise_pred, t, latents).prev_sample

        image = vae.decode(latents / vae_scaling_factor).sample

    # [-1, 1] float NCHW -> uint8 NHWC
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return Image.fromarray(images[0])

# Training

## Dataset

In [None]:
# Python Dataset & Dataloaders
# Caption templates for training. Each "{}" is filled with the placeholder token
# by TextualInversionDataset, which picks one template at random per sample.
# ``object_templates`` is used unless learnable_property == "style".
object_templates = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]

# Templates used when learnable_property == "style".
# NOTE(review): "a cropped painting in the style of {}" and
# "a close-up painting in the style of {}" each appear twice, so random.choice
# samples them twice as often as the others — confirm whether intentional.
style_templates = [
    "a painting in the style of {}",
    "a rendering in the style of {}",
    "a cropped painting in the style of {}",
    "the painting in the style of {}",
    "a clean painting in the style of {}",
    "a dirty painting in the style of {}",
    "a dark painting in the style of {}",
    "a picture in the style of {}",
    "a cool painting in the style of {}",
    "a close-up painting in the style of {}",
    "a bright painting in the style of {}",
    "a cropped painting in the style of {}",
    "a good painting in the style of {}",
    "a close-up painting in the style of {}",
    "a rendition in the style of {}",
    "a nice painting in the style of {}",
    "a small painting in the style of {}",
    "a weird painting in the style of {}",
    "a large painting in the style of {}",
]


class TextualInversionDataset(Dataset):
    """Images from a folder paired with templated captions for textual inversion.

    Each item pairs one image from ``data_root`` with a caption built from a
    randomly chosen template (``style_templates`` when
    ``learnable_property == "style"``, otherwise ``object_templates``) filled
    with ``placeholder_token``. The image list is cycled ``repeats`` times, so
    ``len(dataset) == num_images * repeats``.

    Items are dicts with:
        ``"text_input"``: 1-D tensor of token ids, padded/truncated to
            ``tokenizer.model_max_length``.
        ``"image_input"``: float32 tensor of shape ``(3, size, size)`` with
            values in ``[-1, 1]``, horizontally flipped with probability ``flip_p``.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        placeholder_token,
        learnable_property="object",
        size=512,
        repeats=100,
        flip_p=0.5,
    ):
        """
        Args:
            data_root: Directory with the training images. Hidden files and
                sub-directories are ignored.
            tokenizer: Tokenizer used to encode captions.
            placeholder_token: Token substituted into every caption template.
            learnable_property: ``"style"`` selects style templates; any other
                value selects object templates.
            size: Side length in pixels that images are resized to.
            repeats: How many times the image list is repeated per epoch.
            flip_p: Probability of a horizontal flip.

        Raises:
            ValueError: If ``data_root`` contains no candidate image files.
        """
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.learnable_property = learnable_property
        self.size = size
        self.placeholder_token = placeholder_token
        self.flip_p = flip_p

        # Skip dotfiles (e.g. .DS_Store) and sub-directories, which PIL cannot
        # open; sort so the index -> image mapping is independent of listdir order.
        self.image_paths = sorted(
            os.path.join(self.data_root, file_name)
            for file_name in os.listdir(self.data_root)
            if not file_name.startswith(".")
            and os.path.isfile(os.path.join(self.data_root, file_name))
        )
        if not self.image_paths:
            raise ValueError(f"No image files found in {self.data_root!r}.")

        self.num_images = len(self.image_paths)
        self._length = self.num_images * repeats

        self.interpolation = Image.BICUBIC

        self.templates = (
            style_templates if learnable_property == "style" else object_templates
        )
        self.flip_transform = tfms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        text = random.choice(self.templates).format(self.placeholder_token)
        text_input = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # Decode inside ``with`` so the file handle is closed; convert()/resize()
        # return new in-memory images that outlive the open file.
        with Image.open(self.image_paths[i % self.num_images]) as raw_image:
            image = raw_image if raw_image.mode == "RGB" else raw_image.convert("RGB")
            image = image.resize((self.size, self.size), resample=self.interpolation)
        image = self.flip_transform(image)

        # uint8 [0, 255] -> float32 [-1, 1], then HWC -> CHW
        pixels = np.asarray(image, dtype=np.uint8)
        pixels = (pixels / 127.5 - 1.0).astype(np.float32)
        image_input = torch.from_numpy(pixels).permute(2, 0, 1)

        return {"text_input": text_input, "image_input": image_input}

## Setup Training

In [None]:
# Prepare model artifacts for training: load the models, register the
# placeholder token (initialised from the initializer token's embedding), freeze
# everything except the text encoder's token-embedding table, build the optimizer.
vae, tokenizer, text_encoder, unet, scheduler = get_models(config, config.architecture.train_scheduler_type)

# The initializer must encode to exactly one id; ``!= 1`` also rejects an empty
# encoding, which would otherwise fail below with an IndexError.
token_ids = tokenizer.encode(config.data.initializer_token, add_special_tokens=False)
if len(token_ids) != 1:
    raise ValueError("The initializer token must be a single token.")
init_token_id = token_ids[0]
placeholder_token_id = add_token(
    config.data.placeholder_token, text_encoder, tokenizer, init_token_id
)

# Only the text encoder's input (token) embeddings remain trainable.
freeze_params(vae.parameters())
freeze_params(unet.parameters())
params_to_freeze = itertools.chain(
    text_encoder.text_model.encoder.parameters(),
    text_encoder.text_model.final_layer_norm.parameters(),
    text_encoder.text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)

vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)

vae.eval()
unet.eval()
text_encoder.train()

# The base learning rate is scaled linearly by the effective batch size
# (batch_size * batch_accum).
# NOTE(review): the optimizer owns the whole embedding matrix; AdamW's default
# weight decay also acts on rows whose gradient is zeroed (a zero grad is not
# None), so those rows drift unless they are restored after each step.
optimizer = torch.optim.AdamW(
    text_encoder.get_input_embeddings().parameters(),
    lr=config.optimization.learning_rate * config.optimization.batch_accum * config.optimization.batch_size,
)

## Training Loop

In [None]:
train_dataset = TextualInversionDataset(
    data_root=config.data.image_dir,
    tokenizer=tokenizer,
    size=config.data.size,
    placeholder_token=config.data.placeholder_token,
    learnable_property=config.data.learnable_property,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.optimization.batch_size, shuffle=True
)

batch_accum = config.optimization.batch_accum
# One optimizer step per ``batch_accum`` micro-batches.
total_micro_steps = config.optimization.num_train_steps * batch_accum

# Snapshot of the embedding table. AdamW applies weight decay to every row of
# the parameter, including rows whose gradient is zeroed below, so all rows
# except the placeholder's are restored from this copy after each step.
embedding_layer = text_encoder.get_input_embeddings()
orig_embeds = embedding_layer.weight.data.clone()
keep_frozen = torch.arange(orig_embeds.shape[0], device=orig_embeds.device) != placeholder_token_id

data_iterator = iter(train_dataloader)
progress_bar = tqdm(total=config.optimization.num_train_steps, desc="Steps")
accum_loss = 0.0

for step in range(total_micro_steps):
    try:
        batch = next(data_iterator)
    except StopIteration:
        # Epoch exhausted: start a new (reshuffled) pass and take its first batch.
        data_iterator = iter(train_dataloader)
        batch = next(data_iterator)

    # The VAE is frozen; encode without building a graph.
    with torch.no_grad():
        latents = vae.encode(batch["image_input"].to(torch_device)).latent_dist.sample()
        latents = latents * 0.18215

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps, (bsz,), device=latents.device
    ).long()

    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
    encoder_hidden_states = text_encoder(batch["text_input"].to(torch_device))[0]
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
    loss = loss / batch_accum
    loss.backward()
    accum_loss += loss.detach().item()

    # Only the placeholder row should learn: zero every other row's gradient.
    embedding_layer.weight.grad[keep_frozen] = 0

    if (step + 1) % batch_accum == 0:
        optimizer.step()
        optimizer.zero_grad()
        # Undo weight decay on the rows that must stay frozen.
        with torch.no_grad():
            embedding_layer.weight[keep_frozen] = orig_embeds[keep_frozen]

        progress_bar.update(1)
        # ``accum_loss`` is the mean micro-batch loss over this accumulation window.
        progress_bar.set_postfix(loss=accum_loss, lr=optimizer.param_groups[0]["lr"])
        accum_loss = 0.0

progress_bar.close()

learned_embeds = embedding_layer.weight[placeholder_token_id]
# File name is formatted with the last micro-step index, matching what the
# inference cell loads.
final_step = total_micro_steps - 1
torch.save(
    {config.data.placeholder_token: learned_embeds.detach().cpu()},
    config.embedding_save_path.format(final_step),
)

# Inference

## Setup Model

In [None]:
# Load fresh models with the inference scheduler and inject the learned embedding.
vae, tokenizer, text_encoder, unet, scheduler = get_models(config, config.architecture.inference_scheduler_type)

# Same file name the training cell writes: the save path formatted with the last
# micro-step index (num_train_steps * batch_accum - 1).
embedding_path = config.embedding_save_path.format(
    config.optimization.num_train_steps * config.optimization.batch_accum - 1
)
# NOTE: torch.load unpickles the file; only load embedding files you trust.
saved_embeddings = torch.load(embedding_path, map_location="cpu")
placeholder_token, embedding = next(iter(saved_embeddings.items()))
add_token(placeholder_token, text_encoder, tokenizer, embedding)

vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)

vae.eval()
unet.eval()
text_encoder.eval()

## Generate Image

In [None]:
# Render one sample using the learned placeholder token (empty negative prompt).
generate_image(
    prompt="Oil painting of <kali-dog>",
    neg_prompt="",
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    unet=unet,
    scheduler=scheduler,
    config=config,
)