<span style="color: red;">Requirement when running in Google Colab</span>

In [None]:
# Install the diffusers package, which provides the pipeline and scheduler classes imported below.
!pip install diffusers

#  Appendix 1 - Stable Face XL

This is a more recent and significantly faster approach than the one explained in our chapters. It requires no null-prompt optimisation; the premise of this approach relies on the attentions generated by the inverse latents during the inverse steps, so by storing and realigning the latent in each step we achieve a similar editing capability.

One of the key differences here is the use of DiffusionPipeline instead of a specific Stable Diffusion pipeline, which gives the code more generality, and obviously the use of Stable Diffusion XL instead of the Stable Diffusion 2.1 base used in other chapters.

In [None]:
import warnings
warnings.filterwarnings("ignore")
from diffusers import DiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
import matplotlib.pyplot as plt
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

# Checkpoint identifier passed to every from_pretrained() call below.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
# Guidance scale used later for the editing (denoising) pass.
model_guidance_scale = 7.5
# Side length, in pixels, that the input image is resized to before encoding.
default_resolution = 1024

# Load the pipeline with float16 weights and move it to the GPU.
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

# Two schedulers built from the checkpoint's "scheduler" subfolder: the inverse one is
# swapped into the pipeline for the inversion pass, the regular one for the editing pass.
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")


## Load Image

Load the image, resize it accordingly and encode it with the pipeline's Variational Auto-Encoder (VAE). We use float16 models, and since the VAE of XL models tends to overflow in that precision, we call upcast_vae so that the latent is encoded in float32 precision.

In [17]:
import torchvision
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import gc
import requests

# Download the example photo into the working directory.
url1 = 'https://raw.githubusercontent.com/OutofAi/StableFace/main/photo.png'
filename1 = url1.split('/')[-1]
# A timeout keeps the cell from hanging forever on a stalled connection.
response1 = requests.get(url1, timeout=60)
# Fail loudly on an HTTP error instead of saving an error page as the image file.
response1.raise_for_status()
with open(filename1, 'wb') as f:
    f.write(response1.content)


# Open the file that was just written (rather than a second hard-coded name) and force
# three colour channels, so RGBA, palette and grayscale inputs all yield an RGB tensor.
img = Image.open(filename1).convert("RGB")

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((default_resolution, default_resolution)),
    torchvision.transforms.ToTensor()
])

# Add a leading batch dimension and move the image to the GPU.
loaded_image = transform(img).to("cuda").unsqueeze(0)

with torch.no_grad():
    # Run the VAE in float32; the accompanying text notes that it overflows in float16.
    pipe.upcast_vae()
    # Rescale pixel values from [0, 1] to [-1, 1] before encoding.
    encoded_image = pipe.vae.encode(loaded_image*2 - 1)
    # NOTE(review): sample() draws from the latent distribution, so this cell is not
    # deterministic across runs unless the RNG is seeded.
    real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()

## Inverse Step

As this method doesn't require optimisation, we use the default pipeline as it is, without dissecting it; the only addition is a callback in the pipeline that stores the latents at each inverse step.

In [None]:
# Number of inversion steps; store_latent compares the step index against this value.
num_inference_steps = 10

# Guidance scale for the inversion pass (the editing pass uses model_guidance_scale instead).
guidance_scale = 1
inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = inverse_scheduler.timesteps

# Cast the encoded image latents to float16, the dtype the pipeline was loaded with.
latents = real_image_latents.half()

# Index 0 holds the latents of the real image; store_latent appends the latents of
# every inversion step except the last one.
inversed_latents = [latents]

def store_latent(pipe, step, timestep, callback_kwargs):
    """Step-end callback that records the latents produced by each inversion step.

    The latents found in ``callback_kwargs["latents"]`` are appended to the
    module-level ``inversed_latents`` list for every step except the one whose
    index equals ``num_inference_steps - 1`` (presumably because the latents of
    that last step are the ones returned by the pipeline call itself).

    Args:
        pipe: the calling pipeline (unused).
        step: zero-based index of the step that just finished.
        timestep: scheduler timestep of that step (unused).
        callback_kwargs: dict of tensors handed over by the pipeline.

    Returns:
        ``callback_kwargs``, unchanged.
    """
    current_latents = callback_kwargs["latents"]

    is_last_step = step == num_inference_steps - 1
    if not is_last_step:
        # Only a reference is stored; no tensor operation is performed here.
        inversed_latents.append(current_latents)

    return callback_kwargs

with torch.no_grad():

    # NOTE(review): the processor reset below is left disabled. If the editing cell
    # further down was interrupted before its own reset ran, replacing processors
    # may still be installed on the UNet when this cell is re-run -- confirm.
    # replace_attention_processor(pipe.unet, True)

    # Source prompt describing the original photo.
    prompt = "A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background"

    # Swap in the inverse scheduler for this pass.
    pipe.scheduler = inverse_scheduler
    # Start from the real image latents, ask for latents as output, and let
    # store_latent collect the intermediate latents after each step.
    latents = pipe(prompt=prompt,
        negative_prompt="",
        guidance_scale = guidance_scale,
        output_type="latent",
        return_dict=False,
        num_inference_steps=num_inference_steps,
        latents=latents,
        callback_on_step_end=store_latent,
        callback_on_step_end_tensor_inputs=["latents"],)[0]


# Result of the full inversion; the editing pass starts from these latents.
real_image_initial_latents = latents

## Configure Attention Replacement Class

In [19]:
class AttnReplaceProcessor(AttnProcessor2_0):
    """Attention processor that can overwrite target attention scores with source ones.

    The attention is computed explicitly (scaled ``bmm`` followed by softmax) so
    that the raw scores can be edited before the softmax. When ``replace_all`` is
    true, the score tensor is split into four equal chunks along its first axis
    and, going by the unpacking order below, treated as
    ``[uncond source, uncond target, cond source, cond target]``; the source
    scores are copied over the target scores.

    Replacement is skipped for layers whose token count equals
    ``sample_size ** 2`` or ``(sample_size // 2) ** 2`` (``sample_size`` is read
    from the module-level ``pipe.unet.config``), and whenever the batch size is
    not a multiple of four.

    NOTE: ``attention_mask`` and ``temb`` are accepted for signature
    compatibility only; this processor does not use them.

    Args:
        replace_all: enable (True) or disable (False) the score replacement.
    """

    def __init__(self, replace_all):
        super().__init__()
        self.replace_all = replace_all

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run attention for ``attn`` on ``hidden_states`` and return the new hidden states."""

        residual = hidden_states

        input_ndim = hidden_states.ndim

        # Flatten a 4-D feature map into a token sequence.
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Without encoder states this is self-attention: keys/values come from the input.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))

        #############################################################
        ### The replacing process of attention maps happens here  ###
        #############################################################

        dimension_squared = hidden_states.shape[1]

        skip_dimension_1 = pipe.unet.config.sample_size
        skip_dimension_2 = pipe.unet.config.sample_size // 2
        skipped_token_counts = (
            skip_dimension_1 * skip_dimension_1,
            skip_dimension_2 * skip_dimension_2,
        )

        # chunk(4) splits the (batch * heads) axis. The four chunks only line up with
        # whole samples when the batch size is a multiple of four; otherwise the copy
        # would mix heads of the same sample or fail on mismatched chunk sizes, so the
        # replacement is skipped and plain attention is computed.
        batch_is_splittable = batch_size % 4 == 0

        if self.replace_all and batch_is_splittable and dimension_squared not in skipped_token_counts:
            ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
            # The chunks are views, so the in-place copies edit attention_scores itself.
            attn_scores_dst.copy_(attn_scores_src)
            ucond_attn_scores_dst.copy_(ucond_attn_scores_src)

        #############################################################

        attention_probs = attention_scores.softmax(dim=-1)
        # Free the large score tensor as early as possible.
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        del attention_probs

        # Output projection followed by the module's dropout layer, as in the
        # reference diffusers processors.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the 4-D layout if the input was a feature map.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def replace_attention_processor(unet, clear=False):
    """Install ``AttnReplaceProcessor`` on the self-attention ("attn1") modules of ``unet``.

    Args:
        unet: model whose ``named_modules()`` are scanned for attention modules.
        clear: when False, attn1 modules located under a ``down``, ``mid`` or
            ``up`` block receive a processor with replacement enabled. When
            True, every attn1 module receives a processor with replacement
            disabled.
    """
    replaceable_layer_types = ("down", "mid", "up")

    for name, module in unet.named_modules():
        # Match the attention module itself, not its to_q / to_k / to_v / to_out
        # children, by looking at the last path component only.
        if name.split(".")[-1] != "attn1":
            continue

        if clear:
            module.processor = AttnReplaceProcessor(False)
            continue

        # First path component is e.g. "down_blocks", "mid_block" or "up_blocks".
        layer_type = name.split(".")[0].split("_")[0]
        if layer_type in replaceable_layer_types:
            module.processor = AttnReplaceProcessor(True)


## Apply New Prompt

In [None]:
# Target prompt: the source prompt with "straight hair" changed to "curly hair".
new_prompt = "A photo of a woman, curly hair, light blonde and pink hair, smiling expression, grey background"
# Source prompt (same text as used for the inversion pass).
prompt = "A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background"

# The editing pass uses the model's regular guidance scale.
guidance_scale = model_guidance_scale
# Same step count as the inversion pass, so that the stored latents line up
# one-to-one with the denoising steps.
num_inference_steps = 10
scheduler.set_timesteps(num_inference_steps, device="cuda")
# adjust_latent reads len(timesteps) to index into inversed_latents.
timesteps = scheduler.timesteps

# Two copies of the inverted latents: index 0 is paired with the source prompt,
# index 1 with the new prompt.
initial_latents = torch.cat([real_image_initial_latents] * 2)

def adjust_latent(pipe, step, timestep, callback_kwargs):
    """Step-end callback that realigns the latents with the stored inversion latents.

    After each denoising step, the source latent (index 0) is replaced by the
    latent recorded at the matching inversion step, and the offset between that
    recorded latent and the current source latent is added to the edited
    latent (index 1).

    The attention processors are not touched here: they are installed once
    before the pipeline call, so reinstalling them on every step would only
    repeat loop-invariant work.

    Args:
        pipe: the calling pipeline (unused).
        step: zero-based index of the denoising step that just finished.
        timestep: scheduler timestep of that step (unused).
        callback_kwargs: dict of tensors handed over by the pipeline; its
            ``"latents"`` entry is modified in place.

    Returns:
        ``callback_kwargs`` with the adjusted latents.
    """
    with torch.no_grad():
        latents = callback_kwargs["latents"]

        # Denoising walks the inversion trajectory backwards, hence the reversed index.
        reference = inversed_latents[len(timesteps) - 1 - step].detach()

        # Shift the edited latent by the source latent's offset from the reference;
        # this must happen before latents[0] is overwritten.
        latents[1] = latents[1] + (reference - latents[0])
        latents[0] = reference

    return callback_kwargs


with torch.no_grad():

    # Enable attention replacement for the duration of the pipeline call.
    replace_attention_processor(pipe.unet)

    try:
        pipe.scheduler = scheduler
        # Batch of two prompts: index 0 follows the source prompt, index 1 the new
        # prompt. adjust_latent realigns the latents after each step.
        latents = pipe(prompt=[prompt, new_prompt],
            guidance_scale = guidance_scale,
            output_type="latent",
            return_dict=False,
            num_inference_steps=num_inference_steps,
            latents=initial_latents,
            callback_on_step_end=adjust_latent,
            callback_on_step_end_tensor_inputs=["latents"],)[0]
    finally:
        # Always switch replacement off again, even if the pipeline raises or the
        # cell is interrupted, so later cells do not run with replacing processors.
        replace_attention_processor(pipe.unet, True)


## Utility Function

In [21]:
def display_latents(latents):
    """Decode every latent of a batch with the pipeline's VAE and plot them in one row.

    Each latent is divided by the VAE scaling factor, decoded on its own,
    mapped from [-1, 1] to [0, 1] and shown in its own subplot titled
    ``Latent <index>``. Nothing is drawn for an empty batch.

    Args:
        latents: batch of latents indexed along the first dimension
            (assumed to be scaled by ``pipe.vae.config.scaling_factor``).
    """
    with torch.no_grad():
        # Decode in float32, matching the latents cast to float below.
        pipe.upcast_vae()

        images = []
        # Decode one latent at a time rather than the whole batch in one call.
        for latent in latents:
            decoded = pipe.vae.decode(latent.float().unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
            # (1, C, H, W) -> (H, W, C) on the CPU, as imshow expects channels last.
            image = decoded.squeeze(0).float().permute(1, 2, 0).detach().cpu()
            images.append((image / 2 + 0.5).clamp(0, 1))

        if not images:
            return

        # squeeze=False keeps axes two-dimensional even for a single image.
        fig, axes = plt.subplots(1, len(images), figsize=(6 * len(images), 6), squeeze=False)

        for index, (axis, image) in enumerate(zip(axes[0], images)):
            axis.imshow(image)
            axis.axis('off')
            axis.set_title(f'Latent {index}')

        plt.show()

In [None]:
# Show the result for the source prompt (index 0) next to the result for the new prompt (index 1).
display_latents(latents)