In [4]:
import torch
import random
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler
import math
import glob
import random
import json
def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
# Sinusoidal timestep embedding for diffusion steps
def get_timestep_embedding(timesteps, embedding_dim):
    half_dim = embedding_dim // 2
    emb = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) *
        -(torch.log(torch.tensor(10000.0)) / half_dim)
    )
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # Handle odd embedding dimensions
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb

# Residual block with time and context embeddings
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, context_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(min(32, in_channels), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(min(32, out_channels), out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.context_proj = nn.Linear(context_dim, out_channels) if context_dim else None
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb, context=None):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add time embedding
        t_proj = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_proj

        # Add context embedding if available
        if self.context_proj is not None and context is not None:
            context_pooled = context.mean(dim=1)  # [batch, context_dim]
            context_proj = self.context_proj(context_pooled)[:, :, None, None]
            h = h + context_proj

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        return h + self.shortcut(x)

# Cross-attention to integrate text embeddings
class CrossAttention(nn.Module):
    def __init__(self, channels, context_dim):
        super().__init__()
        self.channels = channels
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(context_dim, channels)
        self.value = nn.Linear(context_dim, channels)
        self.out = nn.Linear(channels, channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x, context):
        if context is None:
            return x

        B, C, H, W = x.shape
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x_norm = self.norm(x_flat)

        q = self.query(x_norm)  # [B, H*W, C]
        k = self.key(context)   # [B, seq_len, C]
        v = self.value(context) # [B, seq_len, C]

        scale = (C ** -0.5)
        attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_out = torch.bmm(attn_weights, v)
        attn_out = self.out(attn_out)

        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x + attn_out

# Self-attention block for image features
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        qkv = self.qkv(h).reshape(B, 3, C, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        scale = (C ** -0.5)
        attn = torch.bmm(q.transpose(1, 2), k) * scale
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(v, attn.transpose(1, 2))
        out = out.reshape(B, C, H, W)
        return self.proj(out) + x

# U-Net model updated for 256x256 latents
class UNetConditional(nn.Module):
    def __init__(self, in_channels=4, base_channels=128, context_dim=768):
        super().__init__()
        self.time_emb_dim = base_channels * 4
        from types import SimpleNamespace
        self.config = SimpleNamespace()
        self.config._diffusers_version = "0.34.0"
        self.config.in_channels = in_channels
        self.config.out_channels = in_channels
        self.config.sample_size = 256  # Updated for 256x256 latents
        self.config.layers_per_block = 2
        self.config.block_out_channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]
        self.config.attention_head_dim = 8
        self.config.cross_attention_dim = context_dim

        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(base_channels, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim),
        )

        # Input projection
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder
        self.down1 = ResidualBlock(base_channels, base_channels * 2, self.time_emb_dim, context_dim)
        self.downsample1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
        self.cross1 = CrossAttention(base_channels * 2, context_dim)

        self.down2 = ResidualBlock(base_channels * 2, base_channels * 4, self.time_emb_dim, context_dim)
        self.downsample2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, stride=2, padding=1)
        self.cross2 = CrossAttention(base_channels * 4, context_dim)

        self.down3 = ResidualBlock(base_channels * 4, base_channels * 8, self.time_emb_dim, context_dim)
        self.downsample3 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
        self.cross3 = CrossAttention(base_channels * 8, context_dim)

        # Middle
        self.middle1 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)
        self.middle_attn = AttentionBlock(base_channels * 8)
        self.middle2 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)

        # Decoder
        self.up3 = ResidualBlock(base_channels * 16, base_channels * 4, self.time_emb_dim, context_dim)
        self.upsample3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)
        self.cross_up3 = CrossAttention(base_channels * 4, context_dim)

        self.up2 = ResidualBlock(base_channels * 8, base_channels * 2, self.time_emb_dim, context_dim)
        self.upsample2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)
        self.cross_up2 = CrossAttention(base_channels * 2, context_dim)

        self.up1 = ResidualBlock(base_channels * 4, base_channels, self.time_emb_dim, context_dim)
        self.upsample1 = nn.ConvTranspose2d(base_channels, base_channels, 4, stride=2, padding=1)

        # Output
        self.output_conv = nn.Sequential(
            nn.GroupNorm(min(32, base_channels), base_channels),
            nn.SiLU(),
            nn.Conv2d(base_channels, in_channels, 3, padding=1)
        )

    def forward(self, x, t, context, cfg_scale=1.0):
        t_emb = get_timestep_embedding(t, self.time_emb_dim // 4)
        t_emb = self.time_mlp(t_emb)

        def denoise(x, t_emb, context):
            h = self.input_conv(x)

            # Encoder
            h1 = self.down1(h, t_emb, context)
            h1_cross = self.cross1(h1, context)
            h1_down = self.downsample1(h1_cross)

            h2 = self.down2(h1_down, t_emb, context)
            h2_cross = self.cross2(h2, context)
            h2_down = self.downsample2(h2_cross)

            h3 = self.down3(h2_down, t_emb, context)
            h3_cross = self.cross3(h3, context)
            h3_down = self.downsample3(h3_cross)

            # Middle
            h_mid = self.middle1(h3_down, t_emb, context)
            h_mid = self.middle_attn(h_mid)
            h_mid = self.middle2(h_mid, t_emb, context)

            # Decoder
            h3_cross_resized = F.interpolate(h3_cross, size=h_mid.shape[-2:], mode='nearest')
            h = self.up3(torch.cat([h_mid, h3_cross_resized], dim=1), t_emb, context)
            h = self.upsample3(h)
            h = self.cross_up3(h, context)

            h2_cross_resized = F.interpolate(h2_cross, size=h.shape[-2:], mode='nearest')
            h = self.up2(torch.cat([h, h2_cross_resized], dim=1), t_emb, context)
            h = self.upsample2(h)
            h = self.cross_up2(h, context)

            h1_cross_resized = F.interpolate(h1_cross, size=h.shape[-2:], mode='nearest')
            h = self.up1(torch.cat([h, h1_cross_resized], dim=1), t_emb, context)
            h = self.upsample1(h)

            return self.output_conv(h)

        if cfg_scale == 1.0 or context is None:
            return denoise(x, t_emb, context)

        uncond = denoise(x, t_emb, context=None)
        cond = denoise(x, t_emb, context)
        return uncond + cfg_scale * (cond - uncond)
import torch
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
import sys
import os

# This script is uploaded to Hugging Face for users to download and use with KahabMinGenT2Im-v1.pt
# Author: Mohammed Kahab K
# It contains the custom UNetConditional class and pipeline for 256x256 image generation



def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

def generate_images_direct(unet_path="KahabMinGenT2Im-v1.pt", device="cuda", output_dir="output", prompt=None, num_inference_steps=1000):
    """Generate 256x256 images with a custom UNet and user-specified text prompt"""
    seed_everything(42)
    print(f"Using device: {device}")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Load components
    print("Loading VAE...")
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device).eval().requires_grad_(False)

    print("Loading tokenizer and text encoder...")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval().requires_grad_(False)

    print("Loading trained UNet...")
    unet = UNetConditional(in_channels=4, base_channels=128, context_dim=768)
    checkpoint = torch.load(unet_path, map_location=device, weights_only=True)
    unet.load_state_dict(checkpoint)
    unet = unet.to(device).eval()

    # Create scheduler
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Get prompt from user if not provided
    if prompt is None:
        # Check if running in Jupyter
        if 'ipykernel' in sys.modules:
            prompt = input("Enter your text prompt (e.g., 'A futuristic smartphone'): ").strip()
        else:
            prompt = ""  # Will be handled by argparse default or user input
        if not prompt:
            prompt = "A futuristic smartphone"  # Default prompt if empty

    test_prompts = [prompt]

    print(f"üé® Generating 256x256 images with {num_inference_steps} inference steps...")
    for i, prompt in enumerate(test_prompts):
        print(f"Generating: {prompt}")
        try:
            with torch.no_grad():
                # Encode prompt
                inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    truncation=True,
                    max_length=77,
                    return_tensors="pt"
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                text_embeddings = text_encoder(**inputs).last_hidden_state
                print(f"Text embeddings shape: {text_embeddings.shape}, device: {text_embeddings.device}")

                # Create random latents for 256x256 output (256/8 = 32 due to VAE scaling)
                latents = torch.randn(1, 4, 32, 32, device=device, dtype=torch.float32)
                print(f"Initial latents shape: {latents.shape}, device: {latents.device}")

                # Set timesteps
                scheduler.set_timesteps(num_inference_steps)

                # Denoising loop
                for t in tqdm(scheduler.timesteps, desc=f"Denoising {prompt}"):
                    t_tensor = torch.tensor([t], device=device, dtype=torch.long)
                    noise_pred = unet(latents, t_tensor, context=text_embeddings)
                    latents = scheduler.step(noise_pred, t, latents).prev_sample

                print(f"Final latents shape: {latents.shape}")

                # Decode latents to image
                latents = latents / 0.18215
                images = vae.decode(latents).sample
                images = (images / 2 + 0.5).clamp(0, 1)  # Denormalize
                images = images.cpu().permute(0, 2, 3, 1).numpy()
                image = Image.fromarray((images[0] * 255).astype(np.uint8))

                # Save
                filename = f"{output_dir}/generated_256_{i+1}_{prompt.replace(' ', '_')}.png"
                image.save(filename)
                print(f"‚úÖ Saved: {filename}")

        except Exception as e:
            print(f"‚ùå Error generating '{prompt}': {e}")
            print(f"Error type: {type(e).__name__}")
            continue

def main():
    # Check if running in Jupyter
    if 'ipykernel' in sys.modules:
        generate_images_direct(
            unet_path="KahabMinGenT2Im-v1.pt",
            device="cuda" if torch.cuda.is_available() else "cpu",
            output_dir="output",
            prompt=None
        )
    else:
        parser = argparse.ArgumentParser(description="Generate images with custom UNet and text prompt")
        parser.add_argument("--unet_path", type=str, default="KahabMinGenT2Im-v1.pt", help="Path to UNet checkpoint")
        parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use (cuda or cpu)")
        parser.add_argument("--output_dir", type=str, default="output", help="Output directory for generated images")
        parser.add_argument("--prompt", type=str, default=None, help="Text prompt for image generation")
        args = parser.parse_args()

        generate_images_direct(
            unet_path=args.unet_path,
            device=args.device,
            output_dir=args.output_dir,
            prompt=args.prompt
        )

if __name__ == "__main__":
    main()

Using device: cuda
Loading VAE...
Loading tokenizer and text encoder...
Loading trained UNet...


RuntimeError: Error(s) in loading state_dict for UNetConditional:
	Missing key(s) in state_dict: "time_mlp.0.weight", "time_mlp.0.bias", "time_mlp.2.weight", "time_mlp.2.bias", "input_conv.weight", "input_conv.bias", "down1.norm1.weight", "down1.norm1.bias", "down1.conv1.weight", "down1.conv1.bias", "down1.norm2.weight", "down1.norm2.bias", "down1.conv2.weight", "down1.conv2.bias", "down1.time_mlp.weight", "down1.time_mlp.bias", "down1.context_proj.weight", "down1.context_proj.bias", "down1.shortcut.weight", "down1.shortcut.bias", "downsample1.weight", "downsample1.bias", "cross1.query.weight", "cross1.query.bias", "cross1.key.weight", "cross1.key.bias", "cross1.value.weight", "cross1.value.bias", "cross1.out.weight", "cross1.out.bias", "cross1.norm.weight", "cross1.norm.bias", "down2.norm1.weight", "down2.norm1.bias", "down2.conv1.weight", "down2.conv1.bias", "down2.norm2.weight", "down2.norm2.bias", "down2.conv2.weight", "down2.conv2.bias", "down2.time_mlp.weight", "down2.time_mlp.bias", "down2.context_proj.weight", "down2.context_proj.bias", "down2.shortcut.weight", "down2.shortcut.bias", "downsample2.weight", "downsample2.bias", "cross2.query.weight", "cross2.query.bias", "cross2.key.weight", "cross2.key.bias", "cross2.value.weight", "cross2.value.bias", "cross2.out.weight", "cross2.out.bias", "cross2.norm.weight", "cross2.norm.bias", "down3.norm1.weight", "down3.norm1.bias", "down3.conv1.weight", "down3.conv1.bias", "down3.norm2.weight", "down3.norm2.bias", "down3.conv2.weight", "down3.conv2.bias", "down3.time_mlp.weight", "down3.time_mlp.bias", "down3.context_proj.weight", "down3.context_proj.bias", "down3.shortcut.weight", "down3.shortcut.bias", "downsample3.weight", "downsample3.bias", "cross3.query.weight", "cross3.query.bias", "cross3.key.weight", "cross3.key.bias", "cross3.value.weight", "cross3.value.bias", "cross3.out.weight", "cross3.out.bias", "cross3.norm.weight", "cross3.norm.bias", "middle1.norm1.weight", "middle1.norm1.bias", "middle1.conv1.weight", "middle1.conv1.bias", "middle1.norm2.weight", "middle1.norm2.bias", "middle1.conv2.weight", "middle1.conv2.bias", "middle1.time_mlp.weight", "middle1.time_mlp.bias", "middle1.context_proj.weight", "middle1.context_proj.bias", "middle_attn.norm.weight", "middle_attn.norm.bias", "middle_attn.qkv.weight", "middle_attn.qkv.bias", "middle_attn.proj.weight", "middle_attn.proj.bias", "middle2.norm1.weight", "middle2.norm1.bias", "middle2.conv1.weight", "middle2.conv1.bias", "middle2.norm2.weight", "middle2.norm2.bias", "middle2.conv2.weight", "middle2.conv2.bias", "middle2.time_mlp.weight", "middle2.time_mlp.bias", "middle2.context_proj.weight", "middle2.context_proj.bias", "up3.norm1.weight", "up3.norm1.bias", "up3.conv1.weight", "up3.conv1.bias", "up3.norm2.weight", "up3.norm2.bias", "up3.conv2.weight", "up3.conv2.bias", "up3.time_mlp.weight", "up3.time_mlp.bias", "up3.context_proj.weight", "up3.context_proj.bias", "up3.shortcut.weight", "up3.shortcut.bias", "upsample3.weight", "upsample3.bias", "cross_up3.query.weight", "cross_up3.query.bias", "cross_up3.key.weight", "cross_up3.key.bias", "cross_up3.value.weight", "cross_up3.value.bias", "cross_up3.out.weight", "cross_up3.out.bias", "cross_up3.norm.weight", "cross_up3.norm.bias", "up2.norm1.weight", "up2.norm1.bias", "up2.conv1.weight", "up2.conv1.bias", "up2.norm2.weight", "up2.norm2.bias", "up2.conv2.weight", "up2.conv2.bias", "up2.time_mlp.weight", "up2.time_mlp.bias", "up2.context_proj.weight", "up2.context_proj.bias", "up2.shortcut.weight", "up2.shortcut.bias", "upsample2.weight", "upsample2.bias", "cross_up2.query.weight", "cross_up2.query.bias", "cross_up2.key.weight", "cross_up2.key.bias", "cross_up2.value.weight", "cross_up2.value.bias", "cross_up2.out.weight", "cross_up2.out.bias", "cross_up2.norm.weight", "cross_up2.norm.bias", "up1.norm1.weight", "up1.norm1.bias", "up1.conv1.weight", "up1.conv1.bias", "up1.norm2.weight", "up1.norm2.bias", "up1.conv2.weight", "up1.conv2.bias", "up1.time_mlp.weight", "up1.time_mlp.bias", "up1.context_proj.weight", "up1.context_proj.bias", "up1.shortcut.weight", "up1.shortcut.bias", "upsample1.weight", "upsample1.bias", "output_conv.0.weight", "output_conv.0.bias", "output_conv.2.weight", "output_conv.2.bias". 
	Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "scheduler_state_dict", "loss". 