In [30]:
import torch
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os

In [31]:
class ShapeDebugger:
    @staticmethod
    def check_shape(tensor, expected_shape, layer_name):
        if tensor.shape != expected_shape:
            raise ValueError(
                f"Shape mismatch in {layer_name}:\n"
                f"Expected shape: {expected_shape}\n"
                f"Got shape: {tensor.shape}"
            )
        print(f"✓ {layer_name} shape: {tensor.shape}")


In [32]:
def load_models(model_path):
    """Load all required models from the saved path"""
    print("Loading models...")
    
    # Load models from saved weights
    unet = UNet2DConditionModel.from_pretrained(model_path, low_cpu_mem_usage=False).to(device)
    text_encoder = CLIPTextModel.from_pretrained(model_path).to(device)
    
    # Load autoencoder from original Stable Diffusion weights instead of local path
    autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    
    # Load tokenizer (this is usually loaded from original CLIP)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    
    # Load scheduler configuration
    scheduler = DDPMScheduler.from_config(os.path.join(model_path, "scheduler_config.json"))
    
    return unet, text_encoder, autoencoder, tokenizer, scheduler

In [33]:
# Example usage
device = "cuda" if torch.cuda.is_available() else "cpu"

def generate_image(prompt, num_inference_steps=50, image_size=512, guidance_scale=7.5):
    """Generate an image from a text prompt with shape debugging"""
    
    debug = ShapeDebugger()
    model_path = "./latent_diffusion_model"
    unet, text_encoder, autoencoder, tokenizer, scheduler = load_models(model_path)
    
    # Set evaluation mode
    unet.eval()
    text_encoder.eval()
    autoencoder.eval()
    
    # Text embedding pipeline with shape checking
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    
    debug.check_shape(
        text_input.input_ids, 
        torch.Size([1, 77]), 
        "Tokenizer output"
    )
    
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0]
        debug.check_shape(
            text_embeddings,
            torch.Size([1, 77, text_encoder.config.hidden_size]),
            "Text encoder output"
        )
    
    # Latent initialization with shape checking
    # Changed the latent shape to match VAE expectations
    latents = torch.randn(
        (1, 4, image_size // 8, image_size // 8),  # Changed to maintain aspect ratio
        device=device
    )
    debug.check_shape(
        latents,
        torch.Size([1, 4, image_size // 8, image_size // 8]),
        "Initial latents"
    )
    
    # Denoising loop with shape checking
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    
    for t in scheduler.timesteps:
        print(f"\nDenoising step {t}")
        
        latent_model_input = scheduler.scale_model_input(latents, timestep=t)
        debug.check_shape(
            latent_model_input,
            torch.Size([1, 4, image_size // 8, image_size // 8]),
            "Scaled model input"
        )
        
        with torch.no_grad():
            try:
                noise_pred = unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings
                ).sample
                
                debug.check_shape(
                    noise_pred,
                    torch.Size([1, 4, image_size // 8, image_size // 8]),
                    "UNet output"
                )
            except Exception as e:
                print("\nError in UNet forward pass:")
                print(f"Input shapes:")
                print(f"- latent_model_input: {latent_model_input.shape}")
                print(f"- timestep: {t.shape if isinstance(t, torch.Tensor) else 'scalar'}")
                print(f"- text_embeddings: {text_embeddings.shape}")
                raise e
        
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        debug.check_shape(
            latents,
            torch.Size([1, 4, image_size // 8, image_size // 8]),
            "Updated latents"
        )
    
    # Decoding with shape checking
    with torch.no_grad():
        latents = 1 / 0.18215 * latents
        debug.check_shape(
            latents,
            torch.Size([1, 4, image_size // 8, image_size // 8]),
            "Scaled latents for decoder"
        )
        
        try:
            image = autoencoder.decode(latents).sample
            debug.check_shape(
                image,
                torch.Size([1, 3, image_size, image_size]),
                "Decoded image"
            )
        except Exception as e:
            print("\nError in VAE decoding:")
            print(f"Input latents shape: {latents.shape}")
            raise e
    
    # Normalize image
    image = (image / 2 + 0.5).clamp(0, 1)
    
    return image

In [34]:


print(f"Using device: {device}")
    
try:
        prompt = "a beautiful sunset over mountains"
        print(f"\nGenerating image for prompt: '{prompt}'")
        generated_image = generate_image(prompt)
        
        from torchvision.utils import save_image
        save_image(generated_image, "generated_image.png")
        print("\nImage saved successfully!")
        
except Exception as e:
        print(f"\n❌ Error during generation: {str(e)}")

Using device: cuda

Generating image for prompt: 'a beautiful sunset over mountains'
Loading models...


Some weights of the model checkpoint at ./latent_diffusion_model were not used when initializing UNet2DConditionModel: ['encoder.mid_block.resnets.1.norm2.weight', 'encoder.conv_norm_out.weight', 'encoder.mid_block.resnets.0.norm1.bias', 'encoder.down_blocks.0.resnets.1.conv1.bias', 'decoder.mid_block.resnets.0.norm2.weight', 'encoder.down_blocks.3.resnets.0.conv1.weight', 'decoder.mid_block.attentions.0.to_out.0.bias', 'decoder.conv_norm_out.weight', 'decoder.up_blocks.3.resnets.2.norm2.weight', 'encoder.down_blocks.1.resnets.0.conv2.weight', 'encoder.down_blocks.3.resnets.0.conv2.bias', 'encoder.mid_block.attentions.0.to_k.bias', 'encoder.mid_block.resnets.1.conv2.weight', 'decoder.mid_block.attentions.0.group_norm.bias', 'encoder.down_blocks.1.resnets.0.conv_shortcut.bias', 'encoder.down_blocks.1.resnets.0.norm2.bias', 'encoder.mid_block.resnets.1.norm2.bias', 'decoder.up_blocks.1.upsamplers.0.conv.weight', 'decoder.up_blocks.2.resnets.0.conv_shortcut.weight', 'encoder.down_blocks.3

✓ Tokenizer output shape: torch.Size([1, 77])
✓ Text encoder output shape: torch.Size([1, 77, 512])
✓ Initial latents shape: torch.Size([1, 4, 64, 64])

Denoising step 980
✓ Scaled model input shape: torch.Size([1, 4, 64, 64])

Error in UNet forward pass:
Input shapes:
- latent_model_input: torch.Size([1, 4, 64, 64])
- timestep: torch.Size([])
- text_embeddings: torch.Size([1, 77, 512])

❌ Error during generation: Given groups=1, weight of size [128, 3, 3, 3], expected input[1, 4, 64, 64] to have 3 channels, but got 4 channels instead
