## An unofficial implementation of the paper "Prompt Tuning Inversion for Text-Driven Image Editing Using Diffusion Models"
https://arxiv.org/abs/2305.04441

This implementation uses DDIM for both inversion and sampling.

## First, download the diffusion model

In [None]:
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
import torch

# Stable Diffusion v1.4 checkpoint on the Hugging Face Hub.
model_id = "CompVis/stable-diffusion-v1-4"
# Use the GPU when one is available, otherwise fall back to the CPU.
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the text-to-image pipeline and move it to `device`. Its scheduler is
# replaced by a DDIMScheduler in the DDIM inversion cell.
ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)

## Load your own image

In [2]:
from PIL import Image
import numpy as np

def load_image(img_path):
    """Load an image file as a normalized tensor ready for the VAE encoder.

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        A float tensor of shape (1, 3, H, W) with values in [-1, 1], placed on
        the module-level ``device``.
    """
    # convert("RGB") guarantees exactly three channels: RGBA/palette images
    # would otherwise give 4 channels and grayscale images a 2-D array, both
    # of which break the permute below or the VAE encoder. The `with` block
    # closes the underlying file handle.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert("RGB"))

    # Map uint8 [0, 255] to float [-1, 1].
    img_tensor = torch.from_numpy(img).float() / 127.5 - 1
    # HWC -> CHW, add a batch dimension, and move to the model's device.
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(device)
    return img_tensor

img_path = "your_image_path"
image = load_image(img_path)
# Encode into the VAE latent space and apply the 0.18215 latent scaling factor
# (divided out again before decoding). Gradients are not needed here, so skip
# building an autograd graph through the VAE encoder.
with torch.no_grad():
    x_0 = (ldm_stable.vae.encode(image).latent_dist.mode() * 0.18215).float()

## DDIM Inversion

In [None]:
from tqdm import tqdm
from typing import Union

# Replace the pipeline's default scheduler with DDIM using the scaled-linear
# beta schedule (0.00085 -> 0.012). The same 50 steps are walked by inversion,
# prompt tuning and editing.
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
ldm_stable.scheduler = scheduler
ldm_stable.scheduler.set_timesteps(50)

def encode_text(model, prompts):
    """Return text-encoder hidden states for ``prompts``.

    Prompts are tokenized with padding and truncation to the tokenizer's
    maximum length, then encoded without tracking gradients.

    Args:
        model: Pipeline exposing ``tokenizer``, ``text_encoder`` and ``device``.
        prompts: A prompt string (or whatever batch form the tokenizer accepts).

    Returns:
        The first element of the text encoder's output.
    """
    max_len = model.tokenizer.model_max_length
    tokens = model.tokenizer(
        prompts,
        padding="max_length",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokens.input_ids.to(model.device)
    with torch.no_grad():
        outputs = model.text_encoder(token_ids)
    return outputs[0]

def next_step(model, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
    """Take one deterministic DDIM inversion step up to ``timestep``.

    ``sample`` is treated as the latent at ``timestep - stride`` (the previous,
    less noisy timestep), where ``stride`` is the spacing between inference
    timesteps. The clean sample is estimated from ``model_output`` and then
    re-noised to ``timestep`` along the same noise direction.

    Args:
        model: Pipeline whose scheduler provides ``alphas_cumprod``,
            ``final_alpha_cumprod``, ``num_inference_steps`` and
            ``config.num_train_timesteps``.
        model_output: Noise prediction at the current (less noisy) latent.
        timestep: Timestep to move the sample to.
        sample: Current latent.

    Returns:
        The latent at ``timestep``.
    """
    sched = model.scheduler
    stride = sched.config.num_train_timesteps // sched.num_inference_steps
    current_t = min(timestep - stride, 999)
    target_t = timestep

    # Before the first training timestep there is no alphas_cumprod entry, so
    # the scheduler's final_alpha_cumprod is used instead.
    if current_t >= 0:
        alpha_cur = sched.alphas_cumprod[current_t]
    else:
        alpha_cur = sched.final_alpha_cumprod
    alpha_next = sched.alphas_cumprod[target_t]

    # Estimate of the clean sample from the current latent and noise.
    x0_pred = (sample - (1 - alpha_cur) ** 0.5 * model_output) / alpha_cur ** 0.5
    # Re-noise the estimate to the target timestep.
    return alpha_next ** 0.5 * x0_pred + (1 - alpha_next) ** 0.5 * model_output

def get_noise_pred(model, latent, t, context, cfg_scale):
    """Predict noise for ``latent`` at timestep ``t`` with classifier-free guidance.

    The latent is stacked twice so the UNet sees one copy per half of
    ``context``. The first half of the output is treated as the unconditional
    prediction and the second half as the conditioned one. ``context`` is
    presumably two stacked embeddings (uncond, cond); callers here always pass
    two stacked copies.

    Returns:
        ``uncond + cfg_scale * (cond - uncond)``.
    """
    doubled = torch.cat([latent, latent])
    unet_out = model.unet(doubled, t, encoder_hidden_states=context)
    uncond, cond = unet_out["sample"].chunk(2)
    guidance = cond - uncond
    return uncond + cfg_scale * guidance

@torch.no_grad()
def ddim_inversion(model, w0, cfg_scale):
    """Invert a clean latent into increasingly noisy latents with DDIM.

    The context consists of two copies of the empty-prompt embedding, so both
    guidance halves are identical and the guided prediction matches the
    unconditional one.

    Args:
        model: Pipeline with a DDIM scheduler on which ``set_timesteps`` has
            been called.
        w0: Clean latent to invert (e.g. the VAE-encoded image).
        cfg_scale: Guidance scale forwarded to ``get_noise_pred``.

    Returns:
        List of ``(t, latent)`` pairs, one per inference step. Entries follow
        ``scheduler.timesteps`` in reverse, so with the descending DDIM
        timesteps ``result[-1][1]`` is the noisiest latent. ``w0`` itself is
        not included.
    """
    empty = encode_text(model, "")
    context = torch.cat([empty, empty])
    latent = w0.clone().detach()
    timesteps = model.scheduler.timesteps
    trajectory = []
    for step in tqdm(range(model.scheduler.num_inference_steps)):
        # Walk the timesteps from the last entry back to the first.
        t = timesteps[len(timesteps) - 1 - step]
        noise = get_noise_pred(model, latent, t, context, cfg_scale)
        latent = next_step(model, noise, t, latent)
        trajectory.append((t, latent))
    return trajectory

# Invert the encoded image. xts[-1][1] (the noisiest latent) is the starting
# point for both prompt tuning and editing.
xts = ddim_inversion(ldm_stable, x_0, 0)

## Prompt tuning process
(Don't forget to set your target prompt below.)

In [None]:
from torch.optim import AdamW

def ddim_forward(model, noise, t, latent):
    """Take one denoising step with the model's scheduler.

    Args:
        model: Pipeline whose scheduler implements ``step``.
        noise: Noise prediction for ``latent`` at timestep ``t``.
        t: Current timestep.
        latent: Current latent.

    Returns:
        The ``prev_sample`` entry of the scheduler's step output.
    """
    step_output = model.scheduler.step(noise, t, latent)
    return step_output["prev_sample"]

def prompt_tuning(model, xts, target_prompt, cfg_scale, beta, N):
    """Tune the target-prompt context per timestep so denoising retraces the inversion.

    Start from the noisiest inverted latent. At every DDIM step, optimize the
    context (two stacked copies of the target-prompt embedding) for ``N``
    AdamW iterations. The objective is that one guided denoising step from the
    current reconstructed latent lands on the corresponding inverted latent.
    The reconstruction then advances one step with the updated context, and
    that context warm-starts the next timestep.

    Args:
        model: Pipeline with a configured DDIM scheduler.
        xts: Output of ``ddim_inversion``: ``(t, latent)`` pairs, noisiest last.
        target_prompt: Prompt whose embedding initializes the tuned context.
        cfg_scale: Guidance scale used during tuning.
        beta: Currently unused; kept for interface compatibility.
        N: Number of optimization iterations per timestep.

    Returns:
        List of detached context tensors, one per scheduler timestep.
        ``c_list[i]`` is the context tuned for ``model.scheduler.timesteps[i]``.
        The final timestep has no inverted target (the clean latent is not in
        ``xts``), so it reuses the last tuned context.
    """
    n_steps = model.scheduler.num_inference_steps
    target_embedding = encode_text(model, target_prompt)
    c_T = torch.concat([target_embedding, target_embedding]).requires_grad_(True)
    opt = AdamW([c_T])

    # Only c_T is optimized. Freezing the UNet avoids computing and storing
    # gradients for all of its parameters. The caller's flags are restored
    # afterwards.
    unet_params = list(model.unet.parameters())
    unet_flags = [p.requires_grad for p in unet_params]
    model.unet.requires_grad_(False)

    c_list = []
    z_prd_t = xts[-1][1].detach()
    try:
        for i in range(n_steps - 1):
            t = model.scheduler.timesteps[i]
            # Inverted latent one step less noisy than z_prd_t. It is a fixed
            # target and must not require grad (or be mutated in xts).
            z_t_minus_1 = xts[len(xts) - i - 2][1].detach()
            for _ in range(N):
                opt.zero_grad()
                noise_prd = get_noise_pred(model, z_prd_t, t, c_T, cfg_scale)
                z_prd_t_minus_1 = ddim_forward(model, noise_prd, t, z_prd_t)
                difference = torch.norm(z_prd_t_minus_1 - z_t_minus_1, p=2) ** 2
                difference.backward()
                opt.step()
            print(f"step {i + 1}/{n_steps - 1} done")

            # Snapshot (not alias) the context tuned for this timestep.
            c_list.append(c_T.detach().clone())

            # Advance the reconstruction with the *updated* context. No graph
            # is needed for this step.
            with torch.no_grad():
                noise_prd_again = get_noise_pred(model, z_prd_t, t, c_T, cfg_scale)
                z_prd_t = ddim_forward(model, noise_prd_again, t, z_prd_t)
    finally:
        for p, flag in zip(unet_params, unet_flags):
            p.requires_grad_(flag)

    # The last step (to the clean latent) has no target in xts, so reuse the
    # most recently tuned context for it.
    c_list.append(c_T.detach().clone())
    return c_list

# Prompt describing the desired edit.
target_prompt = "your own prompt"
# NOTE(review): `beta` is passed to prompt_tuning but is not used inside it.
beta = 0.1
c_list = prompt_tuning(ldm_stable, xts, target_prompt, cfg_scale=5, beta=beta, N=1)


## Editing process

In [None]:
@torch.no_grad()
def editing(model, xts, c_list, target_prompt, eta):
    """Denoise the inverted latent while blending tuned and target contexts.

    At scheduler timestep ``i`` the context is
    ``(1 - eta) * c_list[i] + eta * c_target``, where ``c_target`` stacks two
    copies of the target-prompt embedding. Noise is predicted with
    ``cfg_scale=1``, which reduces guidance to the conditioned half.

    Args:
        model: Pipeline with a configured DDIM scheduler.
        xts: Output of ``ddim_inversion``; denoising starts from ``xts[-1][1]``.
        c_list: Per-timestep contexts; needs at least ``num_inference_steps``
            entries.
        target_prompt: Prompt describing the edit.
        eta: Interpolation weight toward the target-prompt context.

    Returns:
        The edited latent after the final denoising step.
    """
    latent = xts[-1][1]
    embedded = encode_text(model, target_prompt)
    c_target = torch.concat([embedded, embedded])
    for idx in range(model.scheduler.num_inference_steps):
        t = model.scheduler.timesteps[idx]
        blended = (1 - eta) * c_list[idx] + eta * c_target
        noise = get_noise_pred(model, latent, t, blended, cfg_scale=1)
        latent = ddim_forward(model, noise, t, latent)
    return latent

# eta weights the blend toward the target-prompt context (1 - eta toward the
# tuned reconstruction contexts).
z_0 = editing(ldm_stable, xts, c_list, target_prompt, eta=0.9)

## Decode the latent and save the image

In [8]:
from PIL import Image, ImageDraw ,ImageFont
import torchvision.transforms as T

def tensor_to_pil(tensor_imgs):
    """Convert image tensors with values in [-1, 1] to a list of PIL images.

    Args:
        tensor_imgs: A batched image tensor (images along the first dimension,
            each C x H x W), or a list of such tensors, concatenated first.

    Returns:
        One PIL image per entry along the first dimension.
    """
    if isinstance(tensor_imgs, list):
        tensor_imgs = torch.cat(tensor_imgs)
    # detach()/cpu() let tensors that carry autograd history or live on the
    # GPU convert without extending the graph.
    tensor_imgs = (tensor_imgs.detach().cpu() / 2 + 0.5).clamp(0, 1)
    to_pil = T.ToPILImage()
    return [to_pil(img) for img in tensor_imgs]

def add_margin(pil_img, top=0, right=0, bottom=0, left=0, color=(255, 255, 255)):
    """Return a new image with ``pil_img`` surrounded by a solid-color border.

    Args:
        pil_img: Source PIL image.
        top, right, bottom, left: Border widths in pixels.
        color: Fill color for the border area.

    Returns:
        A new image in the same mode as ``pil_img``, enlarged by the margins,
        with the source pasted at offset ``(left, top)``.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + left + right, src_h + top + bottom)
    padded = Image.new(pil_img.mode, canvas_size, color)
    padded.paste(pil_img, (left, top))
    return padded

def image_grid(imgs, rows=1, cols=None, size=None, titles=None, text_pos=(0, 0)):
    """Paste images into one grid image, optionally with a caption above each.

    Args:
        imgs: A list of PIL images, a batched image tensor in [-1, 1], or a
            list of such tensors (converted with ``tensor_to_pil``).
        rows: Number of grid rows.
        cols: Number of grid columns; defaults to the number of images.
        size: If given, every image is first resized to ``size`` x ``size``.
        titles: Optional per-image captions, drawn in a 20 px band added above
            each image.
        text_pos: Position of each caption inside its band.

    Returns:
        An RGB PIL image containing the grid.

    Raises:
        ValueError: If there are fewer images than ``rows * cols``.
    """
    if isinstance(imgs, list) and isinstance(imgs[0], torch.Tensor):
        imgs = torch.cat(imgs)
    if isinstance(imgs, torch.Tensor):
        imgs = tensor_to_pil(imgs)

    if size is not None:
        imgs = [img.resize((size, size)) for img in imgs]
    if cols is None:
        cols = len(imgs)
    if len(imgs) < rows * cols:
        raise ValueError(f"need at least rows*cols={rows * cols} images, got {len(imgs)}")

    top = 20  # caption band height; also the extra offset used below
    w, h = imgs[0].size
    # If the second image's height differs from the first, that height is used
    # for row spacing and every image after the first is shifted down by `top`.
    delta = 0
    if len(imgs) > 1 and imgs[1].size[1] != h:
        delta = top
        h = imgs[1].size[1]

    font = None
    if titles is not None:
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
                                      size=20, encoding="unic")
        except OSError:
            # FreeMono is not installed on every system; use PIL's built-in font.
            font = ImageFont.load_default()
        h = top + h

    grid = Image.new('RGB', size=(cols * w, rows * h + delta))
    for i, img in enumerate(imgs):
        if titles is not None:
            img = add_margin(img, top=top, bottom=0, left=0)
            draw = ImageDraw.Draw(img)
            draw.text(text_pos, titles[i], (0, 0, 0), font=font)
        y_offset = delta if i > 0 else 0
        grid.paste(img, box=(i % cols * w, i // cols * h + y_offset))

    return grid


# Undo the 0.18215 latent scaling applied at encoding time and decode with the
# VAE. Only the pixels are needed, so skip building an autograd graph through
# the VAE decoder.
with torch.no_grad():
    z_decode = ldm_stable.vae.decode(1 / 0.18215 * z_0).sample
if z_decode.dim() < 4:
    # Add a batch dimension so image_grid receives a batched tensor.
    z_decode = z_decode[None, :, :, :]
img = image_grid(z_decode)
save_path = "your save path"
img.save(save_path)
