In [5]:
import torch
from PIL import Image
import numpy as np
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import os


device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "stabilityai/stable-diffusion-2-base"
torch_dtype = torch.float16


vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float16).to(device)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=torch.float16).to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
print("All components loaded successfully.")


prompt = "Generate an image of top only with following description: Floral drees suitable for party"
negative_prompt = "picture containing pants, jeans trounsers, any type of background beside white"
guidance_scale = 8.0
num_inference_steps = 50
seed = 42
generator = torch.Generator(device=device).manual_seed(seed)


text_input_ids = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
conditional_embeddings = text_encoder(text_input_ids)[0]
print(f"Conditional text embeddings shape: {conditional_embeddings.shape}")

uncond_input_ids = tokenizer(negative_prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids.to(device)
unconditional_embeddings = text_encoder(uncond_input_ids)[0]
print(f"Unconditional text embeddings shape: {unconditional_embeddings.shape}")

text_embeddings = torch.cat([unconditional_embeddings, conditional_embeddings])
print(f"Combined text embeddings shape (for CFG): {text_embeddings.shape}")


latents = torch.randn(1, vae.config.latent_channels, unet.config.sample_size, unet.config.sample_size, generator=generator, device=device, dtype=torch.float16)
print(f"Initial latents shape: {latents.shape}")
print(f"Initial latents min: {latents.min().item():.4f}, max: {latents.max().item():.4f}, mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")


scheduler.set_timesteps(num_inference_steps, device=device)
print(f"Scheduler timesteps set: {scheduler.timesteps.shape[0]} steps.")
print(f"Scheduler timesteps: {scheduler.timesteps.tolist()}")

latents = latents * scheduler.init_noise_sigma
print(f"Latents after initial setup (with init_noise_sigma scaling): {latents.min().item():.4f}, max: {latents.max().item():.4f}, mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")

print("Starting denoising process...")
with torch.no_grad():
    for i, t in enumerate(scheduler.timesteps):

        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        print(f"  Step {i+1} (Timestep {t.item()}): Noise Pred Min: {noise_pred.min().item():.4f}, Max: {noise_pred.max().item():.4f}, Mean: {noise_pred.mean().item():.4f}")

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents).prev_sample
        print(f"  Step {i+1} (Timestep {t.item()}): Latents Min: {latents.min().item():.4f}, Max: {latents.max().item():.4f}, Mean: {latents.mean().item():.4f}, Std: {latents.std().item():.4f}")

print("Denoising process completed.")


latents = 1 / vae.config.scaling_factor * latents
print(f"Latents (before VAE decode) min: {latents.min().item():.4f}, max: {latents.max().item():.4f}, mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")

print("Decoding latents to image...")
image = vae.decode(latents).sample
print("Image decoded.")
print(f"Decoded image min: {image.min().item():.4f}, max: {image.max().item():.4f}, mean: {image.mean().item():.4f}")


image = torch.clamp(image, -1, 1)
image = (image / 2 + 0.5)
image = torch.clamp(image, 0, 1)
img = image.squeeze(0)
img = (img * 255).byte()
img = img.permute(1, 2, 0).detach().cpu().numpy()
img = Image.fromarray(img)
print("Image prepared for display.")

All components loaded successfully.
Conditional text embeddings shape: torch.Size([1, 77, 1024])
Unconditional text embeddings shape: torch.Size([1, 77, 1024])
Combined text embeddings shape (for CFG): torch.Size([2, 77, 1024])
Initial latents shape: torch.Size([1, 4, 64, 64])
Initial latents min: -3.9277, max: 3.9316, mean: 0.0031, std: 0.9971
Scheduler timesteps set: 50 steps.
Scheduler timesteps: [999, 979, 959, 939, 919, 899, 879, 859, 839, 819, 799, 779, 759, 739, 719, 699, 679, 659, 639, 619, 599, 579, 559, 539, 519, 500, 480, 460, 440, 420, 400, 380, 360, 340, 320, 300, 280, 260, 240, 220, 200, 180, 160, 140, 120, 100, 80, 60, 40, 20]
Latents after initial setup (with init_noise_sigma scaling): -3.9277, max: 3.9316, mean: 0.0031, std: 0.9971
Starting denoising process...
  Step 1 (Timestep 999): Noise Pred Min: -3.9238, Max: 3.9336, Mean: 0.0030
  Step 1 (Timestep 999): Latents Min: -3.9355, Max: 3.9277, Mean: 0.0052, Std: 0.9966
  Step 2 (Timestep 979): Noise Pred Min: -3.9297,

In [None]:
img