In [37]:
import torch
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

def load_models(model_path):
    """Load all required models from the saved path"""
    print("Loading models...")
    
    # Load models from saved weights
    unet = UNet2DConditionModel.from_pretrained(model_path,low_cpu_mem_usage=False).to(device)
    text_encoder = CLIPTextModel.from_pretrained(model_path).to(device)
    autoencoder = AutoencoderKL.from_pretrained(model_path).to(device)
    
    # Load tokenizer (this is usually loaded from original CLIP)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    
    # Load scheduler configuration
    scheduler = DDPMScheduler.from_config(os.path.join(model_path, "scheduler_config.json"))
    
    return unet, text_encoder, autoencoder, tokenizer, scheduler

Using device: cuda


In [38]:
def generate_image(prompt, num_inference_steps=50, image_size=32, guidance_scale=7.5):
    """Generate an image from a text prompt"""
    
    # Load models
    model_path = "./latent_diffusion_model"
    unet, text_encoder, autoencoder, tokenizer, scheduler = load_models(model_path)
    
    # Set models to evaluation mode
    unet.eval()
    text_encoder.eval()
    autoencoder.eval()
    
    # Encode prompt
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0]
    
    # Initialize random latents
    latents = torch.randn(
        (1, 4, image_size // 8, image_size // 8),
        device=device
    )
    
    # Set timesteps
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    
    # Denoising loop
    for t in scheduler.timesteps:
        print(f"Denoising step {t}", end="\r")
        
        # Prepare latent input
        latent_model_input = scheduler.scale_model_input(latents, timestep=t)
        
        # Predict noise
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings
            ).sample
        
        # Guidance scale (not necessary but can improve results)
        # latents = scheduler.step(noise_pred, t, latents, guidance_scale=guidance_scale).prev_sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    
    # Decode latents to image
    with torch.no_grad():
        latents = 1 / 0.18215 * latents
        image = autoencoder.decode(latents).sample
        
    # Normalize image
    image = (image / 2 + 0.5).clamp(0, 1)
    
    return image

In [39]:
# Example usage
prompt = "a beautiful sunset over mountains"
print(f"Generating image for prompt: {prompt}")
    
# Generate image
generated_image = generate_image(prompt)
    
# Save the generated image
save_image(generated_image, "generated_image.png")
print("\nImage saved as 'generated_image.png'")

Generating image for prompt: a beautiful sunset over mountains
Loading models...


Some weights of the model checkpoint at ./latent_diffusion_model were not used when initializing UNet2DConditionModel: ['decoder.up_blocks.1.resnets.0.norm1.weight', 'encoder.down_blocks.2.resnets.1.conv1.bias', 'encoder.down_blocks.1.resnets.0.conv1.bias', 'encoder.mid_block.attentions.0.group_norm.bias', 'decoder.up_blocks.1.resnets.1.conv2.bias', 'encoder.mid_block.attentions.0.group_norm.weight', 'encoder.mid_block.attentions.0.to_q.weight', 'decoder.up_blocks.0.resnets.1.conv2.weight', 'decoder.up_blocks.1.resnets.0.conv2.weight', 'decoder.mid_block.attentions.0.group_norm.bias', 'encoder.down_blocks.0.resnets.1.conv2.weight', 'encoder.down_blocks.1.resnets.0.norm2.weight', 'decoder.up_blocks.2.resnets.2.conv1.weight', 'encoder.down_blocks.3.resnets.1.conv1.weight', 'encoder.down_blocks.1.resnets.1.norm2.bias', 'quant_conv.weight', 'decoder.mid_block.resnets.0.norm2.weight', 'decoder.mid_block.resnets.0.norm1.weight', 'decoder.conv_norm_out.weight', 'encoder.mid_block.resnets.0.co

Denoising step 980

RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[1, 4, 4, 4] to have 3 channels, but got 4 channels instead