## Copyright 2022 Google LLC. Double-click for license information.

In [10]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Invert images using Null-text inversion

In [11]:
import os
# os.system("pip install rtpt")
from typing import Union
from tqdm import tqdm
import torch

from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch.nn.functional as nnf
import numpy as np
import ptp_utils
from torch.optim.adam import Adam
from PIL import Image
from huggingface_hub import login
from PIL import Image
import pandas as pd
from rtpt import RTPT

from prompt_engineering import get_precise_celeba_prompts, get_ff_prompts

In [12]:
def get_random_only_white_ff_sample(label_dir, sample_size, seed=None):
    """Draw ``sample_size`` distinct FairFace image numbers whose ``Race`` label is ``"White"``.

    Args:
        label_dir: path (or file-like object) of the FairFace label CSV. Despite
            the name it is a file; it is passed straight to ``pd.read_csv`` and
            must contain a ``Race`` column.
        sample_size: number of image numbers to draw, without replacement.
        seed: optional seed for a dedicated ``np.random.default_rng`` so the
            sample is reproducible. When ``None`` (default) the global NumPy RNG
            is used, exactly as before.

    Returns:
        ``np.ndarray`` of selected numbers, each being the 0-based CSV row
        position plus one.

    Raises:
        ValueError: if ``sample_size`` exceeds the number of "White" rows.
    """
    labels = pd.read_csv(label_dir)["Race"].to_numpy()
    # +1 turns 0-based row positions into 1-based image numbers
    # (presumably matching FairFace file names such as "train/1.jpg" — TODO confirm).
    only_white = np.flatnonzero(labels == "White") + 1
    rng = np.random if seed is None else np.random.default_rng(seed)
    return rng.choice(only_white, size=sample_size, replace=False)

In [13]:
# --- Run configuration ------------------------------------------------------
# Dataset to invert: CelebA when True, FairFace when False.
is_celeba = False
ff_label_dir = "fairface/dataset/labels/fairface_label_train.csv"

# Image numbers to invert. Selections used in earlier runs:
#   get_random_only_white_ff_sample(ff_label_dir, 400)
#   [i for i in range(1, 12 + 1)]
image_numbers = [16, 24, 27]

# When True: print the current image number / prompt and run the inversion verbosely.
debug_mode = False

# Dataset-specific input images, output directory for saved inversions, and prompts.
if not is_celeba:
    input_dir = "fairface/dataset/fairface-img-margin125-trainval/train"
    inversion_dir = "fairface/dataset/latents/only_white"
    original_prompts = get_ff_prompts(image_numbers, attr_path=ff_label_dir, random_race_desc=False)
else:
    input_dir = "CelebA/cropped"
    inversion_dir = "CelebA/latents_precise_prompts"
    original_prompts = get_precise_celeba_prompts(image_numbers)
print(original_prompts)

# Requested GPU; the model-loading cell falls back to CPU when CUDA is unavailable.
device = "cuda:15"

{16: 'a photo of a White woman', 24: 'a photo of a White woman', 27: 'a photo of a White woman in her fifties'}


To load Stable Diffusion with Diffusers, follow the instructions at https://huggingface.co/blog/stable_diffusion and set `MY_TOKEN` to your Hugging Face access token (keep the token itself out of the notebook).

In [14]:
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

# Logging in is only needed when the model is pulled from the Hugging Face Hub
# (e.g. "runwayml/stable-diffusion-v1-5"). Never hardcode the token in the notebook;
# read it from the environment instead. Any token that was ever committed here must
# be revoked at https://huggingface.co/settings/tokens.
# login(token=os.environ["HF_TOKEN"])

LOW_RESOURCE = False  # not read by the cells of this notebook; presumably used by ptp_utils
NUM_DDIM_STEPS = 50   # DDIM steps for inversion and null-text optimisation
GUIDANCE_SCALE = 7.5  # classifier-free guidance scale
MAX_NUM_WORDS = 77    # matches the 77-token CLIP context; presumably used by ptp_utils

device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')

# Local Stable Diffusion v1.4 checkpoint. Set SD_MODEL_PATH to use another
# location instead of editing the notebook.
model_name = os.environ.get(
    "SD_MODEL_PATH",
    "/workspaces/PromptToPrompt/StableDiffusion/src/models/stable-diffusion-v1-4",
)
ldm_stable = StableDiffusionPipeline.from_pretrained(model_name, scheduler=scheduler).to(device)

# try:
#     ldm_stable.disable_xformers_memory_efficient_attention()
# except AttributeError:
#     print("Attribute disable_xformers_memory_efficient_attention() is missing")
tokenizer = ldm_stable.tokenizer

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


## Null Text Inversion code

In [15]:
def load_512(image_path, left=0, right=0, top=0, bottom=0):
    """Load an image, cut the given margins, center-crop to a square and resize to 512x512.

    Args:
        image_path: path to an image file (``str`` or ``os.PathLike``), or an
            already-loaded image array (``H x W`` or ``H x W x C``) used as-is.
        left, right, top, bottom: pixel margins to remove. Each is clamped so
            that at least one column/row survives the crop.

    Returns:
        ``np.ndarray`` resized to 512x512 pixels (3 channels for files).
    """
    if isinstance(image_path, (str, os.PathLike)):
        # ``with`` closes the file handle. convert("RGB") also handles grayscale,
        # palette and RGBA files (alpha is dropped), which previously either broke
        # the ``[:, :, :3]`` slicing or kept extra channels.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
    else:
        image = image_path
    h, w = image.shape[:2]
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    # Clamp ``top`` against the image height (previously mistakenly used ``left``).
    top = min(top, h - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]
    h, w = image.shape[:2]
    # Center-crop the longer side so the image becomes square.
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((512, 512)))


class NullInversion:
    """Null-text inversion of a real image for a Stable Diffusion pipeline.

    ``invert`` encodes an image with the VAE and runs DDIM inversion conditioned
    on the prompt to obtain a latent trajectory. It then optimises one
    unconditional ("null-text") embedding per DDIM step so that
    classifier-free-guided sampling with ``GUIDANCE_SCALE`` retraces that
    trajectory.

    Relies on the notebook-level names ``NUM_DDIM_STEPS``, ``GUIDANCE_SCALE``,
    ``load_512``, ``ptp_utils`` and ``tqdm``.
    """

    # Latent scaling constant: applied after VAE encoding, undone before decoding.
    _LATENT_SCALE = 0.18215

    def __init__(self, model):
        """Wrap ``model``.

        Args:
            model: a diffusers Stable Diffusion pipeline. Its ``unet``, ``vae``,
                ``tokenizer``, ``text_encoder``, ``device`` and ``scheduler`` are
                used; the scheduler must provide ``alphas_cumprod`` and
                ``final_alpha_cumprod`` (e.g. ``DDIMScheduler``).
        """
        self.model = model
        self.tokenizer = self.model.tokenizer
        # Fixes scheduler.timesteps / num_inference_steps used by every step below.
        self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
        self.prompt = None
        # torch.cat([uncond, cond]) text embeddings; filled in by init_prompt().
        self.context = None

    @property
    def scheduler(self):
        """Shortcut for the wrapped pipeline's scheduler."""
        return self.model.scheduler

    def _step_stride(self):
        """Number of training timesteps between two consecutive DDIM steps."""
        return self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps

    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """Deterministic DDIM update (no noise term) from ``timestep`` to the previous, less noisy step.

        Uses ``final_alpha_cumprod`` when the previous timestep would be negative.
        """
        prev_timestep = timestep - self._step_stride()
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        # Predicted clean sample x_0 implied by the noise prediction.
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        return alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction

    def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        """Inverse DDIM update: move ``sample`` from one stride below ``timestep`` up to ``timestep``.

        The noise predicted at ``timestep`` is reused as the estimate for the lower
        timestep (the usual DDIM-inversion approximation). Uses
        ``final_alpha_cumprod`` when the lower timestep would be negative.
        """
        last_train_step = self.scheduler.config.num_train_timesteps - 1
        timestep, next_timestep = min(timestep - self._step_stride(), last_train_step), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
        return alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction

    def get_noise_pred_single(self, latents, t, context):
        """Run the UNet once on ``latents`` at timestep ``t`` with text embeddings ``context``."""
        return self.model.unet(latents, t, encoder_hidden_states=context)["sample"]

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        """Take one guided DDIM step.

        The UNet is run on a doubled batch against ``context`` (default
        ``self.context``, i.e. ``[uncond; cond]``).

        Args:
            is_forward: if True, step towards more noise with guidance scale 1
                (``next_step``). Otherwise step towards less noise with
                ``GUIDANCE_SCALE`` (``prev_step``).

        Returns:
            The updated latents.
        """
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else GUIDANCE_SCALE
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            return self.next_step(noise_pred, t, latents)
        return self.prev_step(noise_pred, t, latents)

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        """Decode latents with the VAE.

        Returns:
            For ``return_type == 'np'``: the first image of the batch as an
            ``H x W x C`` uint8 array. Otherwise: the raw, unclamped decoder
            output tensor.
        """
        latents = latents.detach() / self._LATENT_SCALE
        image = self.model.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        return image

    @torch.no_grad()
    def image2latent(self, image):
        """Encode an image into scaled VAE latents (posterior mean).

        Args:
            image: a 4-D ``torch.Tensor`` (treated as latents and returned
                unchanged), a ``PIL.Image.Image``, or an ``H x W x C`` array
                expected to hold values in [0, 255].
        """
        if isinstance(image, torch.Tensor) and image.dim() == 4:
            return image
        # Check against the Image.Image class; the PIL ``Image`` module itself is
        # never the type of an image object.
        if isinstance(image, Image.Image):
            image = np.array(image)
        image = torch.from_numpy(image).float() / 127.5 - 1  # [0, 255] -> [-1, 1]
        image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        latents = self.model.vae.encode(image)['latent_dist'].mean
        return latents * self._LATENT_SCALE

    @torch.no_grad()
    def init_prompt(self, prompt: str):
        """Build ``self.context = [empty-prompt embedding; prompt embedding]`` and store ``prompt``."""
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt

    @torch.no_grad()
    def ddim_loop(self, latent):
        """DDIM-invert ``latent`` up to the noisiest timestep, conditioned on the prompt only (no guidance).

        Returns:
            List of ``NUM_DDIM_STEPS + 1`` latents, from the input latent to the
            noisiest one.
        """
        _, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        timesteps = self.model.scheduler.timesteps
        for i in range(NUM_DDIM_STEPS):
            # scheduler.timesteps runs noisiest-first, so walk it backwards.
            t = timesteps[len(timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent

    @torch.no_grad()
    def ddim_inversion(self, image):
        """Encode ``image`` and DDIM-invert it.

        Returns:
            ``(image_rec, ddim_latents)``: the VAE reconstruction as a uint8 array
            and the ``ddim_loop`` trajectory.
        """
        latent = self.image2latent(image)
        image_rec = self.latent2image(latent)
        ddim_latents = self.ddim_loop(latent)
        return image_rec, ddim_latents

    def null_optimization(self, latents, num_inner_steps, epsilon, verbose):
        """Optimise one unconditional embedding per DDIM step.

        Walking down from the noisiest latent, the unconditional embedding is tuned
        with Adam. The goal is that a guided ``prev_step`` lands on the matching
        DDIM-inversion latent.

        Args:
            latents: trajectory from ``ddim_loop`` (input latent first, noisiest last).
            num_inner_steps: maximum Adam iterations per DDIM step (may be 0).
            epsilon: early-stop loss threshold, relaxed by ``2e-5`` per DDIM step.
            verbose: show a tqdm progress bar.

        Returns:
            List of ``NUM_DDIM_STEPS`` detached unconditional embeddings, one per
            entry of ``scheduler.timesteps`` in order.
        """
        uncond_embeddings, cond_embeddings = self.context.chunk(2)
        uncond_embeddings_list = []
        latent_cur = latents[-1]
        with tqdm(total=num_inner_steps * NUM_DDIM_STEPS, disable=not verbose) as bar:
            for i in range(NUM_DDIM_STEPS):
                # Warm-start from the embedding optimised at the previous step.
                uncond_embeddings = uncond_embeddings.clone().detach()
                uncond_embeddings.requires_grad = True
                # Learning rate decays linearly with the DDIM step index.
                optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
                latent_prev = latents[len(latents) - i - 2]
                t = self.model.scheduler.timesteps[i]
                with torch.no_grad():
                    noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
                steps_done = 0
                for _ in range(num_inner_steps):
                    noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                    noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
                    latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                    loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    steps_done += 1
                    bar.update()
                    if loss.item() < epsilon + i * 2e-5:
                        break
                # Advance the bar past skipped iterations. Unlike reusing the loop
                # variable, this also works when num_inner_steps == 0.
                bar.update(num_inner_steps - steps_done)
                uncond_embeddings_list.append(uncond_embeddings[:1].detach())
                with torch.no_grad():
                    # Step along the reconstructed trajectory with the optimised embedding.
                    context = torch.cat([uncond_embeddings, cond_embeddings])
                    latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        return uncond_embeddings_list

    def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
        """Invert a real image with null-text optimisation.

        Args:
            image_path: image path or array accepted by ``load_512``.
            prompt: text prompt describing the image.
            offsets: ``(left, right, top, bottom)`` crop margins for ``load_512``.
            num_inner_steps, early_stop_epsilon: see ``null_optimization``.
            verbose: print stage messages and show a progress bar.

        Returns:
            ``((image_gt, image_rec), latent_T, uncond_embeddings)``:
            - ``image_gt``: the 512x512 input.
            - ``image_rec``: its VAE reconstruction.
            - ``latent_T``: the noisiest DDIM-inverted latent.
            - ``uncond_embeddings``: the per-step optimised unconditional embeddings.
        """
        self.init_prompt(prompt)
        # NOTE(review): presumably installs plain attention (no prompt-to-prompt
        # controller) — confirm in ptp_utils.register_attention_control.
        ptp_utils.register_attention_control(self.model, None)
        image_gt = load_512(image_path, *offsets)
        if verbose:
            print("DDIM inversion...")
        image_rec, ddim_latents = self.ddim_inversion(image_gt)
        if verbose:
            print("Null-text optimization...")
        uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon, verbose)
        return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings

# Wrap the loaded pipeline; construction calls scheduler.set_timesteps(NUM_DDIM_STEPS).
null_inversion = NullInversion(model=ldm_stable)

In [16]:
print("Starting inversion...")
print(f"image interval: {(image_numbers[0], image_numbers[-1])}")
print(f"device in use: {device}\n")

# The output directory may not exist yet on a fresh checkout.
os.makedirs(inversion_dir, exist_ok=True)

rtpt = RTPT('DR', 'Null-Text_Inversion', len(image_numbers))
rtpt.start()

for img_nmb in tqdm(image_numbers):
    img_number = f"{img_nmb:06}" if is_celeba else str(img_nmb)
    output_path = os.path.join(inversion_dir, img_number + '.pt')

    # Resume support: skip images whose inversion has already been saved.
    if os.path.isfile(output_path):
        continue

    img_path = os.path.join(input_dir, f"{img_number}.jpg")
    if debug_mode:
        print(f"current image: {img_number}")

    original_prompt = original_prompts[img_nmb]

    if debug_mode:
        print(original_prompt)

    # Null-text inversion; offsets are (left, right, top, bottom) crop margins.
    _, latent, uncond_embeddings = null_inversion.invert(img_path, original_prompt, offsets=(0, 0, 0, 0), verbose=debug_mode)

    # Save to a temporary file and atomically rename it. An interrupted run
    # (e.g. KeyboardInterrupt) then never leaves a truncated .pt behind that the
    # skip check above would treat as finished.
    tmp_path = output_path + '.tmp'
    torch.save(
        {
            'prompt': original_prompt,
            'guidance_scale': GUIDANCE_SCALE,
            'num_inference_steps': NUM_DDIM_STEPS,
            'latents': latent,
            'uncond_embeddings': uncond_embeddings
        }, tmp_path)
    os.replace(tmp_path, output_path)
    rtpt.step()

Starting inversion...
image interval: (16, 27)
device in use: cuda:15



 33%|███▎      | 1/3 [01:56<03:52, 116.10s/it]


KeyboardInterrupt: 

In [None]:
# bb_path = "./datasets/celeba/list_bbox_celeba.txt"

# # helper method for get_square_detection, source: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/
# def expand2square(pil_img):
#     width, height = pil_img.size
#     if width == height:
#         return pil_img
#     elif width > height:
#         result = Image.new(pil_img.mode, (width, width), (0, 0, 0))
#         result.paste(pil_img, (0, (width - height) // 2))
#         return result
#     else:
#         result = Image.new(pil_img.mode, (height, height), (0, 0, 0))
#         result.paste(pil_img, ((height - width) // 2, 0))
#         return result


# # get square detection as input for prediction
# def get_square_detection(img_nmb, img_path, bboxes):
#     bbox = bboxes[img_nmb + 1].split()[1:]
#     bbox = [int(value) for value in bbox]

#     # get offsets from bbox
#     left, top, width, height = bbox
#     print("left", left, "top", top, "width", width, "height", height)
#     right = left + width
#     bottom = top + height
#     offsets = (left, top, right, bottom)

#     # prepare image for prediction
#     img = Image.open(img_path)
#     img = img.crop(offsets)
#     # print(offsets)
#     img.show()
#     img.thumbnail((224, 224), Image.Resampling.LANCZOS)
#     img = expand2square(img)
#     # print(img)
#     # img.show()
#     return img

# # load bounding boxes
# bboxes = []
# with open(bb_path, "r") as f:
#     bboxes = f.readlines()

# # get face detections to determine gender of the depicted persons 
# for img_nmb in [1727]: #tqdm(range(1, 202599 + 1)): # 101283: no detection (lying women)
#     img_number = f"{img_nmb:06}"
#     img_path = f"./CelebA/{img_number}.jpg"
#     face_detection = get_square_detection(img_nmb, img_path, bboxes)
#     # face_detection.show()
#     # face_detection.save(f"fairface/detected_faces_CelebA/{img_number}.jpg")
    
