<span style="color: red;">Requirement when running in Google Colab</span>

In [None]:
!pip install diffusers

# Chapter 6 - Real-world image reconstruction

Having achieved our goal of applying facial changes to a generated (synthetic) image, we would now like to do the same for any arbitrary real image. In this chapter we therefore take a step back from attention-layer modifications and focus on whether there is a combination of prompt embeddings and initial latents that, when fed to the model, leads to the generation of any given image. The approach is taken from the paper Null-text Inversion for Editing Real Images using Guided Diffusion Models (https://arxiv.org/abs/2211.09794), which precisely and efficiently reconstructs real images through the Stable Diffusion sampling process given a relevant prompt.

As before, this cell has been copied from the previous chapter, so you can run it and move on to the next cell.

In [None]:
import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
import matplotlib.pyplot as plt
from typing import Optional
from tqdm import tqdm


model_id = "stabilityai/stable-diffusion-2-1-base"

pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to("cuda")

prompt = "A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background"

prompt_embeds = pipe.encode_prompt(prompt=prompt, negative_prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=True)

cond_prompt_embeds = prompt_embeds[0]
uncond_prompt_embeds = prompt_embeds[1]

prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])

initial_latents = torch.randn((1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size), generator=torch.Generator().manual_seed(22)).to("cuda")

Before we can start the process, we first need to load a real-world image and convert it from pixel space to latent space through the VAE.

In [None]:
import torchvision
from PIL import Image
import requests

url1 = 'https://raw.githubusercontent.com/OutofAi/StableFace/main/photo.png'
filename1 = url1.split('/')[-1]
response1 = requests.get(url1)
with open(filename1, 'wb') as f:
    f.write(response1.content)


img = Image.open(filename1).convert("RGB")

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((512, 512)),
    torchvision.transforms.ToTensor()
])

loaded_image = transform(img).to("cuda").unsqueeze(0)

if loaded_image.shape[1] == 4:
    loaded_image = loaded_image[:,:3,:,:]
    
with torch.no_grad():
    encoded_image = pipe.vae.encode(loaded_image*2 - 1)
    real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()
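The preprocessing above rests on two conventions that are easy to get wrong: ToTensor yields values in [0, 1] while the VAE expects [-1, 1], and a 512×512 image should encode to a 1×4×64×64 latent. A minimal sketch, assuming the usual Stable Diffusion VAE layout (4 latent channels, 8× spatial downsampling); the helper names are hypothetical and not part of diffusers. The multiplication by pipe.vae.config.scaling_factor is a separate step, applied as in the cell above.

```python
import torch

# Assumed Stable Diffusion VAE layout: 4 latent channels, 8x spatial downsampling
VAE_DOWNSAMPLE = 8
LATENT_CHANNELS = 4

def to_vae_range(image_01: torch.Tensor) -> torch.Tensor:
    """Map a [0, 1] tensor (as produced by ToTensor) to the VAE's [-1, 1] input range."""
    return image_01 * 2 - 1

def from_vae_range(image_pm1: torch.Tensor) -> torch.Tensor:
    """Map a decoded [-1, 1] tensor back to [0, 1] for display."""
    return (image_pm1 / 2 + 0.5).clamp(0, 1)

def expected_latent_shape(batch: int, height: int, width: int) -> tuple:
    """Latent shape the VAE encoder should return for a given pixel-space size."""
    return (batch, LATENT_CHANNELS, height // VAE_DOWNSAMPLE, width // VAE_DOWNSAMPLE)

# Dummy image standing in for the downloaded photo
dummy = torch.rand(1, 3, 512, 512)
vae_input = to_vae_range(dummy)
print(vae_input.min().item() >= -1.0, vae_input.max().item() <= 1.0)  # → True True
print(expected_latent_shape(1, 512, 512))  # → (1, 4, 64, 64)
```

from_vae_range is a common alternative to the per-image min/max normalization used for display later in this notebook; clamping keeps absolute brightness, whereas min/max stretches each image's contrast.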



The idea is simple. The DDIM sampling step is deterministic, and since the UNet only predicts the noise, we can run the same update in reverse with an inverse scheduler: starting from the latents of a real image, we add noise step by step until we reach its corresponding initial latent noise. Later, we can feed that latent noise, together with a relevant prompt, back through Stable Diffusion to reconstruct the image. Note that the inversion is only approximate: at each inverse step the noise is predicted at the current, less noisy latent rather than at the next one, which works well only when the steps are small.
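To make the "run the update in reverse" idea concrete, here is a toy sketch in plain PyTorch, not the diffusers implementation: a deterministic DDIM update (eta = 0) between two alpha_bar levels, used once to add noise and once to remove it. The alpha_bar values and both noise predictors are made up for illustration. With a predictor that ignores its input, inversion followed by sampling recovers the starting latent exactly; with an input-dependent predictor (like a real UNet) the round trip leaves a small error, which is the approximation mentioned above.

```python
import math
import torch

def ddim_move(z, eps, a_from, a_to):
    """Deterministic DDIM update (eta = 0) between two alpha_bar levels.
    a_to > a_from removes noise (sampling); a_to < a_from adds noise (inversion)."""
    x0_pred = (z - math.sqrt(1 - a_from) * eps) / math.sqrt(a_from)
    return math.sqrt(a_to) * x0_pred + math.sqrt(1 - a_to) * eps

# Toy alpha_bar values from almost clean to very noisy (an assumption for this demo,
# not Stable Diffusion's real schedule)
alphas_bar = [0.99, 0.9, 0.7, 0.4, 0.1]

def invert(z0, eps_fn):
    z = z0
    for a_t, a_next in zip(alphas_bar[:-1], alphas_bar[1:]):
        # noise is predicted at the current latent, since the one at z_next is unknown
        z = ddim_move(z, eps_fn(z), a_t, a_next)
    return z

def sample(z_T, eps_fn):
    z = z_T
    rev = alphas_bar[::-1]
    for a_t, a_prev in zip(rev[:-1], rev[1:]):
        z = ddim_move(z, eps_fn(z), a_t, a_prev)
    return z

torch.manual_seed(0)
z0 = torch.randn(4, 8, 8, dtype=torch.float64)
fixed_eps = torch.randn(4, 8, 8, dtype=torch.float64)

def constant_eps(z):
    return fixed_eps              # ignores z, so inversion is exact

def input_dependent_eps(z):
    return 0.5 * torch.tanh(z)    # depends on z, like a real UNet

const_err = (sample(invert(z0, constant_eps), constant_eps) - z0).abs().max().item()
dep_err = (sample(invert(z0, input_dependent_eps), input_dependent_eps) - z0).abs().max().item()
print(f"constant predictor round-trip error: {const_err:.2e}")
print(f"input-dependent predictor round-trip error: {dep_err:.2e}")
```

The error in the second case comes purely from reusing the noise predicted at the less noisy latent; more, smaller steps shrink it, which is one reason inversion quality depends on num_inference_steps.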

So we load the inverse scheduler, along with the regular DDIM scheduler, from the same model:

In [None]:
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")


During the inversion process we won't be using CFG, so technically we don't need the unconditional prompt embedding. For simplicity we keep it and set guidance_scale to 1, which cancels out the unconditional prediction and effectively disables CFG.
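A quick check of why guidance_scale = 1 disables CFG, using random tensors as stand-ins for the UNet output (the shapes are an assumption matching a 64×64 latent). The combination is the same one used in the cells of this chapter, and the batch is ordered [uncond, cond] like prompt_embeds_combined, so chunk(2) splits it the same way.

```python
import torch

def classifier_free_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    """Same combination used in this notebook: uncond + w * (text - uncond)."""
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

torch.manual_seed(0)
# Dummy UNet output for a batch of two latents, ordered [uncond, cond]
noise_pred = torch.randn(2, 4, 64, 64)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

guided_w1 = classifier_free_guidance(noise_pred_uncond, noise_pred_text, 1.0)
guided_w75 = classifier_free_guidance(noise_pred_uncond, noise_pred_text, 7.5)

print(torch.allclose(guided_w1, noise_pred_text, atol=1e-5))   # → True
print(torch.allclose(guided_w75, noise_pred_text, atol=1e-5))  # → False
```

With guidance_scale = 1 the unconditional half of the batch is computed and then discarded. One could call the UNet on the conditional embedding alone to halve the work during inversion; keeping the two-element batch simply lets the inversion loop mirror the sampling loop.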

In [None]:
num_inference_steps = 10

# notice we disabled the CFG here by setting guidance scale as 1
guidance_scale = 1
inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = inverse_scheduler.timesteps

latents = real_image_latents

inversed_latents = []

with torch.no_grad():

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        inversed_latents.append(latents)

        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_combined,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]


        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # use the inverse_scheduler instead of the scheduler
        latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]


# initial state
real_image_initial_latents = latents

Let's display the latents at each step of the inversion process.

In [None]:
import torch
import matplotlib.pyplot as plt

def display_latents(latents):
    with torch.no_grad():
        num_latents = len(latents)
        images_np = []

        for latent in latents:
            image = pipe.vae.decode(latent / pipe.vae.config.scaling_factor, return_dict=False)[0]
            image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
            image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())
            images_np.append(image_np)

        # Calculate the figure size based on the number of latents
        fig_width = min(20, 2 * num_latents)  # Max width of 20, 2 inches per image
        fig_height = 2  # Fixed height for all images

        fig, axes = plt.subplots(1, num_latents, figsize=(fig_width, fig_height))

        if num_latents == 1:
            axes = [axes]  # Ensure axes is always iterable

        for i, (ax, image_np) in enumerate(zip(axes, images_np)):
            ax.imshow(image_np)
            ax.axis('off')
            ax.set_title(f'Latent {i}')

        plt.tight_layout()
        plt.show()

In [None]:
display_latents(inversed_latents)

In theory, we should now be able to use these initial latents to reconstruct the real-world image through the Stable Diffusion sampling loop.

In [None]:
guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

latents = real_image_initial_latents

with torch.no_grad():

    for i, t in tqdm(enumerate(timesteps)):

        latent_model_input = torch.cat([latents] * 2)
        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_combined,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image_np = image.squeeze(0).float().permute(1,2,0).detach().cpu()
    image_np = image_np - image_np.min()
    image_np = image_np / image_np.max()
    plt.imshow(image_np)

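As you can see, the result is less than satisfactory: it resembles the real-world image to a degree, but it doesn't capture all of its intricate details. The discrepancy mainly stems from CFG. The inversion was run with guidance_scale = 1, whereas sampling uses 7.5, so at every step the guided noise prediction differs from the one used during inversion; these errors accumulate and pull the trajectory away from the real image. With that in mind, we can make use of the unconditional prompt embedding to correct this deviation: for each step, we look for an unconditional embedding that moves the generation back onto the inversion trajectory and towards the desired result. In the paper this is called null-text optimization, and it is performed around the DDIM inversion trajectory, which serves as the pivot (what the paper calls pivotal inversion).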
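Writing out the CFG combination used in the sampling loop shows where the drift comes from. A worked derivation in the paper's notation as we understand it ($\varnothing$ is the null-text embedding, $C$ the prompt embedding, $w$ the guidance scale):

$$\tilde{\epsilon}_\theta(z_t, t, C, \varnothing) = \epsilon_\theta(z_t, t, \varnothing) + w\left(\epsilon_\theta(z_t, t, C) - \epsilon_\theta(z_t, t, \varnothing)\right)$$

For $w = 1$ this reduces to $\epsilon_\theta(z_t, t, C)$, the prediction used during inversion. For any other $w$ the guided prediction is offset from it by

$$\tilde{\epsilon}_\theta(z_t, t, C, \varnothing) - \epsilon_\theta(z_t, t, C) = (w - 1)\left(\epsilon_\theta(z_t, t, C) - \epsilon_\theta(z_t, t, \varnothing)\right)$$

so with $w = 7.5$ every step applies an offset that the inversion never accounted for. Because $\varnothing$ appears in this offset, changing the unconditional embedding at each step is enough to steer the sampling trajectory without touching the prompt $C$ or the model weights.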

The (modified) figure below, taken from the original paper, shows the process: at each step, we compute the mean squared error between the generated latent and the corresponding inversion latent, and use it to move the generation back to its original route. The null-text is the unconditional (negative) prompt embedding, which is what we are trying to replace here.

![image.png](assets/image6.png)

((modified) image source: https://arxiv.org/abs/2211.09794)
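Written out, the per-step objective being minimized is, following the paper's formulation as we read it:

$$\min_{\varnothing_t} \left\lVert z^{*}_{t-1} - z_{t-1}\left(\bar{z}_t, \varnothing_t, C\right) \right\rVert_2^2$$

Here $z^{*}_{t-1}$ is the latent from the DDIM inversion trajectory (the pivot), $\bar{z}_t$ is the latent produced by the previous, already-optimized step (starting from $\bar{z}_T = z^{*}_T$, the inverted noise), and $z_{t-1}(\cdot)$ is one DDIM sampling step with CFG using the null-text embedding $\varnothing_t$. In the training code that follows, $\varnothing_t$ corresponds to QT[i], $z^{*}_{t-1}$ to inversed_latents[len(timesteps) - 1 - i], and the sampled latent to intermediate_values. F.mse_loss with reduction="mean" divides the squared norm by the number of elements, which does not change the minimizer.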

In [None]:
import torch.nn as nn

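# One learnable unconditional ("null-text") embedding per sampling step,
# initialized from the empty-prompt embedding: shape (num_inference_steps, 77, hidden_dim)
W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)
QT = nn.Parameter(W_values.clone())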

In [None]:
import torch.nn.functional as F
import gc

guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps


optimizer = torch.optim.AdamW([QT], lr=0.008)


# Freeze the VAE and the UNet: only the null-text embeddings in QT are optimized
pipe.vae.eval()
pipe.vae.requires_grad_(False)
pipe.unet.eval()
pipe.unet.requires_grad_(False)

# depending on VRAM you can increase the chunk size to increase training speed
# 10GB VRAM set chunk_size = 1
# 20GB VRAM set chunk_size = 5
# 40GB VRAM set chunk_size = 10
chunk_size = 5
num_chunks = num_inference_steps // chunk_size
residual = num_inference_steps - num_chunks * chunk_size
last_loss = 1.0

if residual > 0:
    num_chunks += 1

for epoch in range(50):
    gc.collect()
    torch.cuda.empty_cache()

    # every epoch starts again from the inverted initial noise
    intermediate_values = real_image_initial_latents.clone()

    # stop once the last step's loss is low enough; lower the learning rate as it decreases
    if last_loss < 0.02:
        break
    elif last_loss < 0.03:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.003
    elif last_loss < 0.035:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.006

    for chunk_iter in range(num_chunks):

        accumulated_loss = 0

        range_start = chunk_iter * chunk_size
        range_end = range_start + chunk_size

        if chunk_iter == num_chunks - 1 and residual > 0:
            range_end = range_start + residual

        for i in range(range_start, range_end):

            # detach so that gradients only flow through the current step, into QT[i]
            latents = intermediate_values.detach().clone()

            t = timesteps[i]

            prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])

            latent_model_input = torch.cat([latents] * 2)

            noise_pred_model = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # compare with the matching latent of the inversion trajectory (stored in reverse order)
            loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")
            accumulated_loss += loss
            # keep a plain float so the autograd graph is not held alive
            last_loss = loss.item()

        optimizer.zero_grad()
        accumulated_loss.backward()
        optimizer.step()

        # divide by the actual number of steps in this range (the last one may be shorter)
        print(f"Average Loss (epoch {epoch} - Step range {range_start}:{range_end}): {accumulated_loss.item() / (range_end - range_start)}")
    print(f"Reconstruction Loss (epoch {epoch}): {last_loss}")
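
The chunked training loop can be summarized by the step ranges it visits. A sketch with a hypothetical chunk_ranges helper (not part of the notebook), written with math.ceil instead of the num_chunks/residual bookkeeping: within each range the per-step losses are summed and a single optimizer step is taken, so GPU memory grows with the range length, because the graphs of all UNet calls in that range are kept until backward().

```python
import math

def chunk_ranges(num_steps: int, chunk_size: int) -> list:
    """Return the (start, end) step ranges visited by the chunked training loop."""
    num_chunks = math.ceil(num_steps / chunk_size)
    return [(c * chunk_size, min((c + 1) * chunk_size, num_steps)) for c in range(num_chunks)]

print(chunk_ranges(10, 5))  # → [(0, 5), (5, 10)]
print(chunk_ranges(10, 3))  # → [(0, 3), (3, 6), (6, 9), (9, 10)]
```

chunk_size = 1 gives one optimizer step per sampling step (lowest memory), while chunk_size = num_inference_steps gives one optimizer step per epoch over the whole trajectory.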



In [None]:
guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

with torch.no_grad():
    gc.collect()
    torch.cuda.empty_cache()
    # start from the inverted noise; no gradients are needed for this final sampling pass
    intermediate_values = real_image_initial_latents.clone()

    for i, t in enumerate(timesteps):
        latents_value = intermediate_values

        # the optimized null-text embedding for step i replaces the empty-prompt embedding
        prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])

        latent_model_input = torch.cat([latents_value] * 2)

        # Predict the noise residual
        noise_pred_model = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        intermediate_values = scheduler.step(noise_pred, t, latents_value, return_dict=False)[0]


    image = pipe.vae.decode(intermediate_values / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image_np = image.squeeze(0).float().permute(1,2,0).detach().cpu()
    image_np = image_np - image_np.min()
    image_np = image_np / image_np.max()
    plt.imshow(image_np)


# Save Training Data
If you want to avoid re-running the training process in the next chapter, you can run the next cell to save the relevant data, and then skip the Training section of the next chapter.

In [None]:
combined_data = {
    'initial_latent': real_image_initial_latents,
    'QT': QT
}
torch.save(combined_data, "reconstruction_data.pt")
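A minimal sketch of saving and reloading these values, under the assumption that you want a file that loads on any device. save_reconstruction and load_reconstruction are hypothetical helpers, not part of the notebook or diffusers. Detaching and moving to CPU before saving stores plain tensors, map_location="cpu" avoids requiring a GPU at load time, and QT is re-wrapped in nn.Parameter so it can be optimized further. The dummy shapes assume Stable Diffusion 2.1 base with 10 inference steps (77 tokens, 1024-dim text embeddings).

```python
import os
import tempfile

import torch
import torch.nn as nn

def save_reconstruction(path, initial_latent, qt):
    """Hypothetical helper: store plain CPU tensors so the file loads on any device."""
    torch.save({"initial_latent": initial_latent.detach().cpu(),
                "QT": qt.detach().cpu()}, path)

def load_reconstruction(path, device="cpu"):
    """Hypothetical helper: reload the tensors and re-wrap QT as a trainable Parameter."""
    data = torch.load(path, map_location="cpu")
    initial_latent = data["initial_latent"].to(device)
    qt = nn.Parameter(data["QT"].to(device))
    return initial_latent, qt

# Dummy tensors with the shapes assumed above
latent = torch.randn(1, 4, 64, 64)
qt = nn.Parameter(torch.randn(10, 77, 1024))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "reconstruction_data.pt")
    save_reconstruction(path, latent, qt)
    loaded_latent, loaded_qt = load_reconstruction(path)

print(torch.equal(loaded_latent, latent), tuple(loaded_qt.shape))  # → True (10, 77, 1024)
```

In the next chapter you would pass device="cuda" to load_reconstruction so the tensors land on the same device as the pipeline.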