In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from diffusers import StableDiffusionPipeline,StableDiffusionImg2ImgPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
from io import BytesIO
import requests


In [3]:
# Model checkpoint and compute device used throughout the notebook.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"
# The embedding optimization and UNet fine-tuning draw random timesteps and
# noise; seed torch (CPU and CUDA) so Restart & Run All is reproducible.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Load the img2img pipeline in half precision and move it onto the selected
# device. Attention slicing evaluates attention in chunks, trading some speed
# for a lower peak memory footprint.
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to(device)
pipeline.enable_attention_slicing()

In [5]:
def load_image(image_url, size=(512, 512)):
    """Download an image and return it as a (1, 3, H, W) float16 tensor on `device`.

    Pixel values are scaled to [0, 1]. Code that feeds this tensor straight
    into the VAE encoder must rescale it to [-1, 1] first (``x * 2 - 1``).
    The img2img pipeline is passed this [0, 1] tensor directly and is
    expected to normalize it itself -- verify for the installed diffusers.

    Args:
        image_url: HTTP(S) URL of the image to download.
        size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.

    Returns:
        A float16 tensor of shape ``(1, 3, size[1], size[0])`` on `device`.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        PIL.UnidentifiedImageError: if the body is not an image (e.g. an
            HTML interstitial page served instead of the file).
    """
    # Fail fast on a hung connection or an HTTP error instead of handing an
    # error page to PIL.
    response = requests.get(image_url, timeout=60)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB").resize(size)
    image = np.array(image).astype(np.float32) / 255.0
    # HWC -> CHW, then add a leading batch dimension.
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device)
    return image.half()


input_image_url = "https://drive.google.com/uc?export=download&id=123Hn9lNcqlQ_leMygrI2HT5g021fs9sA"
input_image = load_image(input_image_url)
target_text = "Bird sitting on human hand"

In [None]:
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.float16
).to(device)

# Encode the target prompt once. These frozen embeddings are both the starting
# point of the optimization and the anchor used by the regularizer.
text_inputs = tokenizer(
    target_text,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(device)
with torch.no_grad():
    target_embeddings = text_encoder(input_ids)[0]
# Keep the trainable copy in fp32; only the UNet forward runs in fp16.
target_embeddings = target_embeddings.float()
optimized_embeddings = target_embeddings.clone().detach().requires_grad_(True)
optimizer = AdamW([optimized_embeddings], lr=2e-6, eps=1e-4)
mse_loss = nn.MSELoss()


def embedding_regularization(optimized, original, weight=10):
    """Return ``weight`` times the MSE between ``optimized`` and ``original``.

    Penalizes the optimized embeddings for drifting away from the encoded
    target prompt.
    """
    return weight * torch.nn.functional.mse_loss(optimized, original)


with torch.no_grad():
    # load_image returns pixels in [0, 1], but the VAE works in [-1, 1]
    # (its decoder output is mapped back with `x / 2 + 0.5` when sampling).
    latents = pipeline.vae.encode(input_image * 2 - 1).latent_dist.sample() * 0.18215
latents = latents.half()

# Only the embeddings are trained here. Freezing the UNet stops backward()
# from materialising -- and accumulating -- fp16 gradients for every UNet
# weight, which would otherwise leak into the fine-tuning stage. Gradients
# still flow through the UNet into `optimized_embeddings`.
pipeline.unet.requires_grad_(False)

num_steps = 3
for step in range(num_steps):
    t = torch.randint(
        0, pipeline.scheduler.config.num_train_timesteps, (1,), device=device
    ).long()
    noise = torch.randn_like(latents)
    noisy_latents = pipeline.scheduler.add_noise(latents, noise, t)
    with torch.autocast(device_type="cuda", enabled=(device == "cuda")):
        noise_pred = pipeline.unet(
            noisy_latents, t, encoder_hidden_states=optimized_embeddings.half()
        ).sample
    # Compute the losses in fp32 for numerical stability.
    loss = mse_loss(noise_pred.float(), noise.float())
    reg_loss = embedding_regularization(optimized_embeddings, target_embeddings)
    total_loss = loss + reg_loss
    # Check the quantity that is actually backpropagated.
    if not torch.isfinite(total_loss):
        print(f"Loss is nan or inf at step {step}")
        break
    optimizer.zero_grad()
    total_loss.backward()
    grad = optimized_embeddings.grad
    if grad is not None and not torch.isfinite(grad).all():
        # Leave the training loop *before* optimizer.step(), so a non-finite
        # update is never applied to the embeddings.
        print(f"Found nan or inf in gradients of optimized_embeddings at step {step}")
        break
    optimizer.step()
    if not torch.isfinite(optimized_embeddings).all():
        print(f"Found nan or inf in optimized_embeddings at step {step}")
        break
    if step % 20 == 0 or step == num_steps - 1:
        print(f"Step {step}/{num_steps}, Loss: {total_loss.item()}")
optimized_embeddings = optimized_embeddings.detach()

In [None]:
unet = pipeline.unet
unet.train()
# Recompute activations during backward to save memory.
unet.enable_gradient_checkpointing()
# Only the UNet is fine-tuned; the VAE and the text encoder stay frozen.
for param in pipeline.vae.parameters():
    param.requires_grad = False
for param in text_encoder.parameters():
    param.requires_grad = False
for param in unet.parameters():
    param.data = param.data.half()
    param.requires_grad = True
# NOTE(review): the UNet weights are fp16 (~1e-3 relative precision). Adam
# steps of order lr=5e-7 are far below that for most weights and are likely
# rounded away; fp32 master weights or a larger lr may be needed for this
# stage to change the model at all -- confirm empirically.
optimizer_unet = AdamW(unet.parameters(), lr=5e-7, eps=1e-4)
# Discard gradients left on the UNet by earlier cells (e.g. the embedding
# optimization backpropagates through this same UNet) before accumulating.
optimizer_unet.zero_grad(set_to_none=True)
num_finetune_steps = 2
accumulation_steps = 8
# `latents` must be the VAE latents of the input image from the embedding cell.
# The conditioning does not change between steps, so cast it once.
embeddings = optimized_embeddings.to(device).half()
for step in range(num_finetune_steps):
    t = torch.randint(
        0, pipeline.scheduler.config.num_train_timesteps, (1,), device=device
    ).long()
    noise = torch.randn_like(latents)
    noisy_latents = pipeline.scheduler.add_noise(latents, noise, t)
    noise_pred = unet(noisy_latents, t, encoder_hidden_states=embeddings).sample
    # fp32 loss, scaled so gradients accumulated over a window are averaged.
    loss = mse_loss(noise_pred.float(), noise.float()) / accumulation_steps
    if not torch.isfinite(loss):
        print(f"Loss is NaN or Inf at step {step}")
        break
    loss.backward()
    # Step at the end of every accumulation window AND on the final step:
    # otherwise a run shorter than one window (2 < 8 here) never updates.
    is_last_step = step + 1 == num_finetune_steps
    if (step + 1) % accumulation_steps == 0 or is_last_step:
        optimizer_unet.step()
        optimizer_unet.zero_grad()
    if step % 20 == 0:
        # Report the unscaled per-step loss.
        print(f"Fine-tuning Step {step}/{num_finetune_steps}, Loss: {loss.item() * accumulation_steps}")
unet.eval()

In [8]:
# Blend the optimized embedding with the original target-prompt embedding:
# eta=0 uses only the optimized embedding, eta=1 only the target embedding.
eta = 0.05
interpolated_embeddings = eta * target_embeddings + (1 - eta) * optimized_embeddings
interpolated_embeddings = interpolated_embeddings.to(device).half()
guidance_scale = 4
num_inference_steps = 50
with torch.no_grad():
    # The unconditional branch of classifier-free guidance is the encoding of
    # the empty prompt, which is what diffusers' Stable Diffusion pipelines use
    # when no negative prompt is given. An all-zeros tensor is not a
    # text-encoder output.
    uncond_inputs = tokenizer(
        "",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_inputs.input_ids.to(device))[0].half()
    # Batch order [uncond, cond] must match the chunk(2) split below.
    embeddings = torch.cat([uncond_embeddings, interpolated_embeddings])

    generator = torch.Generator(device=device)
    generator.manual_seed(40)
    scheduler = pipeline.scheduler
    scheduler.set_timesteps(num_inference_steps)
    # 64x64 latents decode to 512x512 pixels. A separate name keeps the
    # input-image `latents` from the embedding cell intact for re-runs.
    sample_latents = torch.randn(
        (1, unet.config.in_channels, 64, 64), device=device, generator=generator
    ).half()
    # Scale the starting noise as the scheduler expects (a no-op for
    # schedulers whose init_noise_sigma is 1.0).
    sample_latents = sample_latents * scheduler.init_noise_sigma
    for t in scheduler.timesteps:
        latent_model_input = torch.cat([sample_latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        noise_pred = unet(
            latent_model_input,
            t,
            encoder_hidden_states=embeddings,
        ).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        sample_latents = scheduler.step(noise_pred, t, sample_latents).prev_sample
    # Undo the latent scaling, decode, and map VAE output from [-1, 1] to [0, 1].
    edited_image = pipeline.vae.decode(1 / 0.18215 * sample_latents).sample
    edited_image = (edited_image / 2 + 0.5).clamp(0, 1)
    edited_image = (
        edited_image.cpu().permute(0, 2, 3, 1).numpy()[0] * 255
    ).astype(np.uint8)
    edited_image = Image.fromarray(edited_image)
edited_image.save("embedding.png")


In [None]:
# Baseline: plain img2img with a freshly loaded copy of the base model. It uses
# neither the fine-tuned UNet nor the optimized embeddings from above.
# NOTE(review): this keeps a second full pipeline on the device alongside
# `pipeline`; free one of them if memory is tight.
pipeline_final = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16
).to(device)
pipeline_final.enable_attention_slicing()
strength = 0.46
guidance_scale = 19
final_seed = 40
# Seed the sampler so the saved image is reproducible across runs.
final_generator = torch.Generator(device=device).manual_seed(final_seed)
with torch.no_grad():
    # `input_image` holds pixels in [0, 1]; the pipeline is expected to do its
    # own normalization for tensor inputs.
    edited_image = pipeline_final(
        prompt=target_text,
        image=input_image,
        strength=strength,
        guidance_scale=guidance_scale,
        num_inference_steps=50,
        generator=final_generator,
    ).images[0]

edited_image.save("final_edited_image.png")