<span style="color: red;">Requirement when running in Google Colab</span>

In [None]:
!pip install diffusers

#  Chapter 7 - Stable Face

Now that we can reconstruct real-world images in Stable Diffusion and have learned how to replace attention layers, we reach the season finale: combining these methods to modify facial features arbitrarily.

## Set up the Stable Diffusion pipeline

In [None]:
import warnings
# NOTE: this hides *all* warnings, including library deprecation notices.
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
import matplotlib.pyplot as plt
from typing import Optional
from tqdm import tqdm


model_id = "stabilityai/stable-diffusion-2-1-base"
device = "cuda"

pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to(device)

# Two DDIM schedulers from the same config: the inverse one is used to walk a
# real image's latent towards noise, the regular one to denoise it again.
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

prompt = "A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background"

# encode_prompt returns a tuple; element [0] is the prompt embedding. CFG is
# disabled here because the conditional/unconditional pair is built manually.
cond_prompt_embeds = pipe.encode_prompt(prompt=prompt, device=device, num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
uncond_prompt_embeds = pipe.encode_prompt(prompt="", device=device, num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

# Order [uncond, cond] matches the later `noise_pred.chunk(2)` unpacking.
prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])

# Read latent geometry from the UNet config (direct attribute access such as
# `pipe.unet.in_channels` is deprecated in diffusers).
latent_channels = pipe.unet.config.in_channels
latent_size = pipe.unet.config.sample_size
initial_latents = torch.randn(
    (1, latent_channels, latent_size, latent_size),
    generator=torch.Generator().manual_seed(22),
).to(device)

## Load Training Data

If you saved the training data in the previous chapter, you can uncomment the code below, run this cell, and skip the Training Section.

In [None]:
# Optional: restore the inverted latent and the optimized per-step
# unconditional embeddings (QT) saved in the previous chapter, instead of
# re-running the Training Section below.
# loaded_data = torch.load("reconstruction_data.pt")

# real_image_initial_latents = loaded_data['initial_latent']
# QT = loaded_data['QT']

## Training Section
#### [ Skip to the Configure Attention Replacement Function section if you loaded the data ]

In [None]:
import torchvision
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import gc
import requests

# --- Load the real photo to reconstruct ------------------------------------
url1 = 'https://raw.githubusercontent.com/OutofAi/StableFace/main/photo.png'
filename1 = url1.split('/')[-1]
response1 = requests.get(url1, timeout=60)
# Fail loudly on an HTTP error instead of saving an error page to disk and
# crashing later inside PIL with an opaque "cannot identify image file".
response1.raise_for_status()
with open(filename1, 'wb') as f:
    f.write(response1.content)

# Open the file that was just written, and force 3-channel RGB: this drops
# any alpha channel and also converts grayscale / palette images, so the
# tensor below always has exactly 3 channels.
img = Image.open(filename1).convert("RGB")

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((512, 512)),
    torchvision.transforms.ToTensor()
])

# ToTensor yields values in [0, 1]; unsqueeze adds the batch dimension.
loaded_image = transform(img).to("cuda").unsqueeze(0)

with torch.no_grad():
    # Map pixels from [0, 1] to [-1, 1] for the VAE; latents are multiplied by
    # the VAE scaling factor (the display code divides it back out).
    encoded_image = pipe.vae.encode(loaded_image * 2 - 1)
    real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()

# --- DDIM inversion: walk the real latent back towards noise ---------------
num_inference_steps = 10

# With guidance_scale = 1 the CFG combination reduces to the conditional prediction.
guidance_scale = 1
inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = inverse_scheduler.timesteps

latents = real_image_latents

# inversed_latents[k] is the latent *before* inversion step k; these are the
# targets for the reconstruction optimization below.
inversed_latents = []

with torch.no_grad():

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        inversed_latents.append(latents)

        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_combined,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]


# initial state: the fully inverted latent that denoising starts from
real_image_initial_latents = latents

# --- Optimize one unconditional embedding per denoising step ---------------
# QT[i] starts as the empty-prompt embedding and is trained so that CFG
# denoising at step i reproduces the matching inverted latent.
W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)
QT = nn.Parameter(W_values.clone())


guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# NOTE(review): AdamW's momentum and weight decay act on the whole QT tensor,
# so rows other than QT[i] can still move slightly at step i.
optimizer = torch.optim.AdamW([QT], lr=0.008)

# Only QT is trained; freeze the VAE and the UNet.
pipe.vae.eval()
pipe.vae.requires_grad_(False)
pipe.unet.eval()
pipe.unet.requires_grad_(False)

# Tracked as a plain float so the autograd node of the last loss is not kept
# alive across steps/epochs and the schedule comparisons are plain bools.
last_loss = 1.0

for epoch in range(50):
    gc.collect()
    torch.cuda.empty_cache()

    intermediate_values = real_image_initial_latents.clone()

    # Early stop plus a step-down learning rate, driven by the final-step loss
    # of the previous epoch.
    if last_loss < 0.02:
        break
    elif last_loss < 0.03:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.003
    elif last_loss < 0.035:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.006

    for i in range(num_inference_steps):
        # Detach so step i back-propagates only into its own QT[i].
        latents = intermediate_values.detach().clone()

        t = timesteps[i]

        prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])

        latent_model_input = torch.cat([latents] * 2)

        noise_pred_model = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # Denoising step i should land on the latent recorded before inversion
        # step (N - 1 - i), i.e. walk the inversion trajectory backwards.
        loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        last_loss = loss.item()
        print(f"Loss (epoch {epoch} - Step {i}): {last_loss}")
    print(f"Reconstruction Loss (epoch {epoch}): {last_loss}")


## Configure Attention Replacement Function

In [None]:
def contextual_forward(self, skip_dimension, should_replace=False):
    """Build a replacement ``forward`` for an attention module.

    The returned closure reproduces a plain (non-fused) attention computation.
    For self-attention, it overwrites the attention scores of the edited
    ("dst") samples with those of the reconstruction ("src") samples before
    the softmax.

    Replacement only happens when the batch holds exactly four samples, ordered
    ``(uncond-src, uncond-dst, cond-src, cond-dst)``. That is the layout
    produced by duplicating ``[src, dst]`` latents for classifier-free
    guidance. Any other batch size is passed through untouched: ``chunk(4)``
    over the ``batch * heads`` axis would otherwise split heads instead of
    samples, either corrupting the scores or failing to unpack.

    Args:
        self: the attention module being patched (captured by the closure).
        skip_dimension: side length of the square token grid at which
            self-attention is replaced only if ``should_replace`` is True.
            At every other sequence length, self-attention is always replaced.
        should_replace: whether to also replace at the ``skip_dimension`` grid.

    Returns:
        A callable taking ``(hidden_states, encoder_hidden_states=None,
        attention_mask=None, temb=None)`` and returning the attention output.
    """

    def forward_modified(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,  # accepted for signature compatibility; unused
    ) -> torch.FloatTensor:
        residual = hidden_states

        is_cross = encoder_hidden_states is not None

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # Expand the mask per head as the module's own helper does, rather
            # than silently ignoring it.
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        query = self.head_to_batch_dim(query)
        key = self.head_to_batch_dim(key)
        value = self.head_to_batch_dim(value)

        attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2))
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        #############################################################
        ### The replacing process of attention maps happens here  ###
        #############################################################

        # Number of query tokens (H*W for a spatial self-attention layer).
        dimension_squared = hidden_states.shape[1]

        # Our experiments showed that this combination gives the best results
        # for facial changes, hence this specific replacement configuration.
        if (
            not is_cross
            and batch_size == 4
            and (should_replace or dimension_squared != skip_dimension * skip_dimension)
        ):
            ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
            # chunk() returns views, so copy_ writes into attention_scores itself.
            attn_scores_dst.copy_(attn_scores_src)
            ucond_attn_scores_dst.copy_(ucond_attn_scores_src)
        #############################################################

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = self.batch_to_head_dim(hidden_states)
        del attention_probs

        # NOTE(review): only to_out[0] is applied; any further to_out layers
        # (dropout in diffusers, TODO confirm) are skipped, which is only
        # equivalent in eval mode.
        hidden_states = self.to_out[0](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor

        return hidden_states

    return forward_modified

def apply_forward_function(unet, child=None, should_replace=False):
    """Recursively patch every ``Attention`` module under ``unet``.

    Each module whose class name is exactly ``Attention`` gets its ``forward``
    replaced by ``contextual_forward(module, unet.config.sample_size,
    should_replace)``. ``should_replace`` is set by the nearest enclosing
    container whose class name contains "Down" or "Mid" (True) or "Up"
    (False). Containers matching none of these inherit the value from their
    parent.

    Args:
        unet: root module. Its ``config.sample_size`` is used as the
            ``skip_dimension`` for every patched attention module.
        child: module currently being visited; ``None`` starts at ``unet``'s
            direct children.
        should_replace: value inherited from the enclosing container.
    """
    if child is None:
        for _, top_level_child in unet.named_children():
            apply_forward_function(unet, top_level_child, should_replace)
        return

    if child.__class__.__name__ == 'Attention':
        # Use the unet passed in, not a global pipeline object.
        child.forward = contextual_forward(child, unet.config.sample_size, should_replace)
        return

    if not hasattr(child, 'children'):
        return

    # The decision depends only on this container's class name, so compute it once.
    block_name = child.__class__.__name__
    if "Down" in block_name:
        should_replace = True
    elif "Up" in block_name:
        should_replace = False
    elif "Mid" in block_name:
        should_replace = True

    for sub_child in child.children():
        apply_forward_function(unet, sub_child, should_replace)

## Create a Display Function

In [None]:
def display_latents(latents, vae=None):
    """Decode every latent in a batch and show the images side by side.

    Args:
        latents: batch of latents that were multiplied by
            ``vae.config.scaling_factor``; the factor is divided out before
            decoding. Any batch size >= 1 works; panel ``k`` is titled
            ``Latent {k}``.
        vae: autoencoder used for decoding; defaults to the global ``pipe.vae``.

    Raises:
        ValueError: if ``latents`` is empty.
    """
    if vae is None:
        vae = pipe.vae

    if len(latents) == 0:
        raise ValueError("display_latents needs at least one latent")

    images = []
    with torch.no_grad():
        for latent in latents:
            # Decode one sample at a time (batch of 1) to keep decoder memory low.
            decoded = vae.decode(latent.unsqueeze(0) / vae.config.scaling_factor, return_dict=False)[0]
            # (1, C, H, W) -> (H, W, C) for imshow, mapped from [-1, 1] to [0, 1].
            image_np = decoded.squeeze(0).float().permute(1, 2, 0).detach().cpu()
            images.append((image_np / 2 + 0.5).clamp(0, 1))

    # squeeze=False keeps `axes` 2-D, so a single latent is handled too.
    fig, axes = plt.subplots(1, len(images), figsize=(6 * len(images), 6), squeeze=False)
    for k, (ax, image_np) in enumerate(zip(axes[0], images)):
        ax.imshow(image_np)
        ax.axis('off')
        ax.set_title(f'Latent {k}')

    plt.show()

## Apply Attention Replacement

In [None]:
new_prompt = "A photo of a woman, curly hair, light blonde and pink hair, smiling expression, grey background"

new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

guidance_scale = 7.5
# QT holds one optimized unconditional embedding per denoising step and is
# indexed as QT[i] below, so the step count must match it exactly.
num_inference_steps = QT.shape[0]
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# Two copies of the inverted latent: index 0 reconstructs the original
# ("src"), index 1 is generated with the new prompt ("dst").
latents = torch.cat([real_image_initial_latents] * 2)

with torch.no_grad():

    # Patch every Attention module of the UNet with the replacement forward.
    apply_forward_function(pipe.unet)

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        # Batch layout after duplicating for CFG: [uncond-src, uncond-dst,
        # cond-src, cond-dst]; both unconditional slots use the optimized QT[i].
        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=modified_prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    display_latents(latents)


## Other prompts

In [None]:
new_prompt = "A photo of a male, straight hair, light blonde and pink hair, smiling expression, grey background"

new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

guidance_scale = 7.5
# QT holds one optimized unconditional embedding per denoising step and is
# indexed as QT[i] below, so the step count must match it exactly.
num_inference_steps = QT.shape[0]
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# Two copies of the inverted latent: index 0 reconstructs the original
# ("src"), index 1 is generated with the new prompt ("dst").
latents = torch.cat([real_image_initial_latents] * 2)

with torch.no_grad():

    # Patch every Attention module of the UNet with the replacement forward.
    apply_forward_function(pipe.unet)

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        # Batch layout after duplicating for CFG: [uncond-src, uncond-dst,
        # cond-src, cond-dst]; both unconditional slots use the optimized QT[i].
        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=modified_prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    display_latents(latents)

In [None]:
new_prompt = "A photo of a woman, wavy hair, light blonde and pink hair, smiling expression, grey background, closed eyes"

new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

guidance_scale = 7.5
# QT holds one optimized unconditional embedding per denoising step and is
# indexed as QT[i] below, so the step count must match it exactly.
num_inference_steps = QT.shape[0]
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# Two copies of the inverted latent: index 0 reconstructs the original
# ("src"), index 1 is generated with the new prompt ("dst").
latents = torch.cat([real_image_initial_latents] * 2)

with torch.no_grad():

    # Patch every Attention module of the UNet with the replacement forward.
    apply_forward_function(pipe.unet)

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        # Batch layout after duplicating for CFG: [uncond-src, uncond-dst,
        # cond-src, cond-dst]; both unconditional slots use the optimized QT[i].
        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=modified_prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    display_latents(latents)

In [None]:
new_prompt = "A photo of a woman, wavy hair, dark blonde and pink hair, smiling expression, grey background, bangs"

new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

guidance_scale = 7.5
# QT holds one optimized unconditional embedding per denoising step and is
# indexed as QT[i] below, so the step count must match it exactly.
num_inference_steps = QT.shape[0]
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# Two copies of the inverted latent: index 0 reconstructs the original
# ("src"), index 1 is generated with the new prompt ("dst").
latents = torch.cat([real_image_initial_latents] * 2)

with torch.no_grad():

    # Patch every Attention module of the UNet with the replacement forward.
    apply_forward_function(pipe.unet)

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        # Batch layout after duplicating for CFG: [uncond-src, uncond-dst,
        # cond-src, cond-dst]; both unconditional slots use the optimized QT[i].
        modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=modified_prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    display_latents(latents)