# 🔧 System Diagnostic - Run This First!

Before training, we need to verify that the GPU is properly accessible. This cell will check:
1. Slurm job allocation
2. GPU hardware detection
3. CUDA environment variables
4. PyTorch CUDA compatibility
5. Common issues and solutions

**Run the diagnostic cell below FIRST before proceeding with training!**

In [1]:
#!/usr/bin/env python3
"""Comprehensive GPU Diagnostic for HPC Cluster.

Runs four checks and prints a pass/fail summary. Each check stores its result
in a module-level flag that the summary reads:

1. Slurm GPU allocation                       -> ``slurm_allocated``
2. CUDA_VISIBLE_DEVICES exposes >= 1 GPU      -> ``cuda_env_ok``
3. ``nvidia-smi`` runs successfully           -> ``hardware_ok``
4. PyTorch can see a CUDA device              -> ``pytorch_ok``
"""

import subprocess
import sys
import os

# Slurm variables holding a GPU *count* ("0" means none) versus a list of GPU
# *indices* (where "0" is simply GPU #0 and still means a GPU was allocated).
_SLURM_GPU_COUNT_VARS = ('SLURM_GPUS', 'SLURM_GPUS_ON_NODE')
_SLURM_GPU_INDEX_VARS = ('SLURM_JOB_GPUS',)

# Seconds to wait for nvidia-smi. The first query on a node may be slow (e.g.
# when the driver is not in persistence mode), so a very short timeout can
# report a healthy GPU as missing.
_NVIDIA_SMI_TIMEOUT_S = 30


def _slurm_var_indicates_gpu(var: str, value: str) -> bool:
    """Return True if Slurm variable *var* with raw *value* signals a GPU allocation.

    *value* is the environment string, or the sentinel 'NOT SET' when absent.
    Variables that are neither GPU counts nor GPU index lists never count.
    """
    if value in ('NOT SET', ''):
        return False
    if var in _SLURM_GPU_INDEX_VARS:
        # Any non-empty index list, including "0", is an allocation.
        return True
    if var in _SLURM_GPU_COUNT_VARS:
        return value != '0'
    return False


def _cuda_visible_devices_ok(value: str) -> bool:
    """Return True if CUDA_VISIBLE_DEVICES is set and does not hide every GPU.

    *value* is the environment string, or 'NOT SET' when absent. An empty
    string or "-1" is set but makes no device visible to CUDA.
    """
    return value.strip() not in ('NOT SET', '', '-1')


print("="*80)
print(" üîç COMPREHENSIVE GPU DIAGNOSTIC FOR HPC CLUSTER")
print("="*80)
print()

# ============================================================================
# 1. SLURM JOB INFORMATION
# ============================================================================
print("1Ô∏è‚É£  SLURM JOB ALLOCATION")
print("-" * 80)

slurm_vars = {
    'SLURM_JOB_ID': 'Job ID',
    'SLURM_JOB_NODELIST': 'Assigned Node(s)',
    'SLURM_NODEID': 'Node ID',
    'SLURM_GPUS': 'Total GPUs Allocated',
    'SLURM_GPUS_ON_NODE': 'GPUs on This Node',
    'SLURM_JOB_GPUS': 'GPU IDs Allocated',
    'SLURM_CPUS_ON_NODE': 'CPUs on Node',
    'SLURM_MEM_PER_NODE': 'Memory per Node',
}

slurm_allocated = False
for var, desc in slurm_vars.items():
    value = os.environ.get(var, 'NOT SET')
    print(f"  {desc:25s}: {value}")
    if _slurm_var_indicates_gpu(var, value):
        slurm_allocated = True

print()
if slurm_allocated:
    print("  ‚úÖ Slurm has allocated GPU(s) to this job")
else:
    print("  ‚ö†Ô∏è  WARNING: No GPU allocation detected by Slurm!")
    print("     This job may not have requested GPU resources.")
    print()

# ============================================================================
# 2. CUDA ENVIRONMENT VARIABLES
# ============================================================================
print()
print("2Ô∏è‚É£  CUDA ENVIRONMENT VARIABLES")
print("-" * 80)

cuda_vars = {
    'CUDA_VISIBLE_DEVICES': 'Which GPUs are visible to CUDA',
    'CUDA_HOME': 'CUDA installation directory',
    'CUDA_PATH': 'CUDA path',
    'CUDA_ROOT': 'CUDA root directory',
    'LD_LIBRARY_PATH': 'Library path (includes CUDA libs)',
}

for var, desc in cuda_vars.items():
    value = os.environ.get(var, 'NOT SET')
    if var == 'LD_LIBRARY_PATH' and value != 'NOT SET':
        # LD_LIBRARY_PATH is usually long; only show the CUDA-looking entries
        # (case-insensitive match on "cuda").
        cuda_parts = [p for p in value.split(':') if 'cuda' in p.lower()]
        if cuda_parts:
            print(f"  {var:25s}: {cuda_parts[0]} (and {len(cuda_parts)-1} more)")
        else:
            print(f"  {var:25s}: (no CUDA paths found)")
    else:
        print(f"  {var:25s}: {value}")

cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT SET')
cuda_env_ok = _cuda_visible_devices_ok(cuda_visible)

print()
if cuda_env_ok:
    print("  ‚úÖ CUDA_VISIBLE_DEVICES is set")
elif cuda_visible == 'NOT SET':
    print("  ‚ö†Ô∏è  WARNING: CUDA_VISIBLE_DEVICES not set!")
    print("     GPUs may not be visible to applications.")
    print()
else:
    # Set, but to a value ("" or "-1") that hides every device.
    print(f"  WARNING: CUDA_VISIBLE_DEVICES={cuda_visible!r} hides all GPUs!")
    print("     No GPUs will be visible to CUDA applications.")
    print()

# ============================================================================
# 3. GPU HARDWARE DETECTION
# ============================================================================
print()
print("3Ô∏è‚É£  GPU HARDWARE DETECTION (nvidia-smi)")
print("-" * 80)

try:
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free',
         '--format=csv'],
        capture_output=True, text=True, timeout=_NVIDIA_SMI_TIMEOUT_S
    )

    if result.returncode == 0:
        print(result.stdout)
        print("  ‚úÖ GPU hardware detected successfully")
        hardware_ok = True
    else:
        print(f"  ‚ùå nvidia-smi failed with error:\n{result.stderr}")
        hardware_ok = False
except FileNotFoundError:
    print("  ‚ùå nvidia-smi command not found!")
    hardware_ok = False
except subprocess.TimeoutExpired:
    print(f"  ERROR: nvidia-smi did not respond within {_NVIDIA_SMI_TIMEOUT_S}s")
    hardware_ok = False
except Exception as e:
    print(f"  ‚ùå Error running nvidia-smi: {e}")
    hardware_ok = False

print()

# ============================================================================
# 4. PYTORCH CUDA DETECTION
# ============================================================================
print("4Ô∏è‚É£  PYTORCH CUDA DETECTION")
print("-" * 80)

try:
    import torch
except ImportError:
    print("  ‚ùå PyTorch is not installed!")
    pytorch_ok = False
else:
    try:
        print(f"  PyTorch Version: {torch.__version__}")
        print(f"  PyTorch Built with CUDA: {torch.version.cuda}")
        cuda_available = torch.cuda.is_available()
        print(f"  CUDA Available: {cuda_available}")

        if cuda_available:
            device_count = torch.cuda.device_count()
            print(f"  CUDA Device Count: {device_count}")
            for i in range(device_count):
                print(f"    GPU {i}: {torch.cuda.get_device_name(i)}")
                props = torch.cuda.get_device_properties(i)
                print(f"      Total Memory: {props.total_memory / 1024**3:.2f} GB")
            print()
            print("  ‚úÖ PyTorch can access GPU(s)!")
            pytorch_ok = True
        else:
            print()
            print("  ‚ùå PyTorch CANNOT access GPU!")
            pytorch_ok = False
    except Exception as e:
        # CUDA queries can raise (e.g. a RuntimeError from a failed CUDA
        # initialization); report it instead of aborting before the summary.
        print(f"  ERROR: PyTorch CUDA query failed: {e!r}")
        pytorch_ok = False

print()

# ============================================================================
# 5. SUMMARY
# ============================================================================
print("="*80)
print(" üìã SUMMARY")
print("="*80)
print()

all_checks = {
    'Slurm GPU Allocation': slurm_allocated,
    'CUDA Environment': cuda_env_ok,
    'GPU Hardware (nvidia-smi)': hardware_ok,
    'PyTorch CUDA Access': pytorch_ok,
}

for check, status in all_checks.items():
    status_icon = "‚úÖ" if status else "‚ùå"
    print(f"  {status_icon} {check}")

print()

if all(all_checks.values()):
    print("üéâ ALL CHECKS PASSED! GPU is ready for training.")
else:
    print("‚ö†Ô∏è  ISSUES DETECTED. Review the diagnostic output above.")

print()
print("="*80)

 üîç COMPREHENSIVE GPU DIAGNOSTIC FOR HPC CLUSTER

1Ô∏è‚É£  SLURM JOB ALLOCATION
--------------------------------------------------------------------------------
  Job ID                   : 3030862
  Assigned Node(s)         : gpu8
  Node ID                  : 0
  Total GPUs Allocated     : NOT SET
  GPUs on This Node        : 1
  GPU IDs Allocated        : 0
  CPUs on Node             : 2
  Memory per Node          : 16384

  ‚úÖ Slurm has allocated GPU(s) to this job

2Ô∏è‚É£  CUDA ENVIRONMENT VARIABLES
--------------------------------------------------------------------------------
  CUDA_VISIBLE_DEVICES     : 0
  CUDA_HOME                : /prefix/software/CUDA/11.8.0
  CUDA_PATH                : /prefix/software/CUDA/11.8.0
  CUDA_ROOT                : /prefix/software/CUDA/11.8.0
  LD_LIBRARY_PATH          : /prefix/software/CUDA/11.8.0/nvvm/lib64 (and 2 more)

  ‚úÖ CUDA_VISIBLE_DEVICES is set

3Ô∏è‚É£  GPU HARDWARE DETECTION (nvidia-smi)
--------------------------------------

# Train Conditional UNet2D on MS-COCO with Text Prompts + VAE

This notebook trains a **text-conditional** diffusion model on MS-COCO using:
- **Pretrained VAE** (AutoencoderKL from Stable Diffusion) to encode images to latents
- **Pretrained CLIP** text encoder for text embeddings
- **UNet2DConditionModel** (with cross-attention for text conditioning)
- **DDPM scheduler** for training and inference

## Key Design Choices
- MS-COCO images resized to 256x256
- Images normalized to [-1,1] for VAE compatibility
- VAE encodes to 32x32x4 latents
- CLIP encodes text captions to 77x512 embeddings (the hidden size of `openai/clip-vit-base-patch32`)
- Only the UNet is trained; VAE and CLIP are frozen
- **Text-conditional generation** - generate images from text prompts

## Memory Optimizations for GPU
- **Batch size**: 16 (adjustable based on GPU memory)
- **Image size**: 256x256
- **Mixed precision**: Enabled (FP16)
- **Cache clearing**: Periodic GPU cache clearing to prevent fragmentation
- **DataLoader**: num_workers=2, pin_memory=True for efficiency
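- **Dataset caching**: COCO is downloaded once via Hugging Face `datasets` into a local cache directory (the `use_streaming` flag is not currently used by the loader)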

# 🎨 MS-COCO Text-to-Image Conditional Diffusion Training

## What's Different from Unconditional Training?

1. **Dataset**: MS-COCO (120K images with text captions) instead of LSUN Churches
2. **Model Type**: **UNet2DConditionModel** (with cross-attention layers)
3. **Text Encoder**: CLIP text encoder for processing captions
4. **Text Conditioning**: Cross-attention with text embeddings
5. **Captions**: 5 captions per image (human-written descriptions)
6. **Generation**: Text-to-image - generate from text prompts!

## Quick Start

1. Run the diagnostic cell to verify GPU access
2. Configure training parameters (adjust batch size for your GPU)
3. Start training - checkpoints and sample images saved every 2000 steps
4. Monitor progress through text-to-image samples
5. Generate images from custom text prompts!

## About MS-COCO

- **Dataset**: Microsoft Common Objects in Context (COCO)
- **Size**: ~120,000 training images
- **Captions**: 5 detailed human-written captions per image
- **Content**: Diverse everyday scenes, objects, people, animals
- **Download**: Automatically downloaded and cached via Hugging Face `datasets` (no manual download needed)
- **Quality**: High-quality images and detailed natural language descriptions

## 1. Imports and Setup

In [2]:
from __future__ import annotations

import os
import math
import random
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

from diffusers import DDPMScheduler
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## 2. Helper Functions

In [3]:
def get_device() -> torch.device:
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = mps_backend is not None and mps_backend.is_available()
    return torch.device("mps" if use_mps else "cpu")


def build_transforms(image_size: int = 256) -> transforms.Compose:
    """Return the COCO image pipeline: square resize, tensor conversion, [-1, 1] scaling.

    The resize is to exactly ``(image_size, image_size)`` with bilinear
    interpolation, so non-square images are stretched rather than cropped.
    """
    target_hw = (image_size, image_size)
    pipeline = [
        transforms.Resize(target_hw, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel maps ToTensor's [0, 1] range to [-1, 1],
        # the range the VAE expects.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(pipeline)


def seed_everything(seed: int):
    """Seed every RNG this notebook draws from: Python ``random``, NumPy and torch.

    ``random`` matters because the COCO dataset wrapper uses it to pick captions
    and to drop them for classifier-free guidance. DataLoader worker processes
    (num_workers > 0) derive their own per-worker seeds from the torch RNG,
    which is seeded here.

    Args:
        seed: integer seed applied to all generators (and every CUDA device).
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def encode_text(text: str, tokenizer, text_encoder, device):
    """Tokenize *text* with the CLIP tokenizer and return the encoder's first output.

    Tokens are padded/truncated to ``tokenizer.model_max_length`` and moved to
    *device*. The encoder runs under ``torch.no_grad()``. The return value is
    ``text_encoder(input_ids)[0]`` — presumably the last hidden state of shape
    (batch, max_length, hidden_size); confirm for the encoder in use.
    """
    token_ids = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    with torch.no_grad():
        encoder_outputs = text_encoder(token_ids)
    return encoder_outputs[0]

## 3. Configuration

In [4]:
@dataclass
class TrainConfig:
    """Hyperparameters and paths for the COCO text-to-image UNet training run.

    Only ``dataset_root`` is required; every other field has a default.
    """
    # NOTE(review): make_dataloader caches to ../../dataset_cache and does not
    # read this field — confirm whether anything else uses it.
    dataset_root: str
    output_dir: str = "./outputs/train10_coco_text2img"  # checkpoints and sample images
    batch_size: int = 16
    num_epochs: int = 500
    lr: float = 1e-4
    num_train_timesteps: int = 1000
    image_size: int = 256  # square input size; the UNet latent side is image_size // 8
    seed: int = 42  # passed to seed_everything before the models are built
    mixed_precision: bool = True
    checkpoint_interval: int = 2000  # Save checkpoint every N steps
    # UNet size for COCO (text-conditional)
    unet_block_out_channels: tuple[int, ...] = (128, 256, 512, 512)
    layers_per_block: int = 2
    # Text conditioning
    # NOTE(review): make_dataloader does not read use_streaming; the dataset is
    # downloaded and cached in full regardless of this flag.
    use_streaming: bool = True
    classifier_free_guidance_prob: float = 0.1  # probability a caption is replaced by "" (CFG training)


# Configure training parameters - OPTIMIZED FOR COCO TEXT-TO-IMAGE
config = TrainConfig(
    dataset_root="../../datasets",
    output_dir="./outputs/train10_coco_text2img",
    batch_size=16,                      # adjust based on GPU memory
    num_epochs=500,
    lr=1e-4,
    image_size=256,                     # full resolution for high-quality images
    mixed_precision=True,               # enable mixed precision to save memory
    checkpoint_interval=2000,           # save a checkpoint every 2000 steps
    use_streaming=True,
    classifier_free_guidance_prob=0.1,  # fraction of captions dropped for CFG
)

# Echo the effective settings, one "label: value" line each.
_config_summary = [
    ("Device", get_device()),
    ("Batch size", config.batch_size),
    ("Epochs", config.num_epochs),
    ("Learning rate", config.lr),
    ("Image size", config.image_size),
    ("Output directory", config.output_dir),
    ("Mixed precision", config.mixed_precision),
    ("Checkpoint interval", f"{config.checkpoint_interval} steps"),
    ("Classifier-free guidance", f"{config.classifier_free_guidance_prob * 100}% unconditional"),
    ("Streaming dataset", config.use_streaming),
]
for _label, _value in _config_summary:
    print(f"{_label}: {_value}")

print("\nDataset: MS-COCO (Text-Conditional Image Generation)")
# Rough estimate assuming ~120k training images; the real count is only known
# once the dataset has been loaded.
_approx_total_steps = (120000 // config.batch_size) * config.num_epochs
print(f"Total training steps: ~{_approx_total_steps:,} (approx)")

Device: cuda
Batch size: 16
Epochs: 500
Learning rate: 0.0001
Image size: 256
Output directory: ./outputs/train10_coco_text2img
Mixed precision: True
Checkpoint interval: 2000 steps
Classifier-free guidance: 10.0% unconditional
Streaming dataset: True

Dataset: MS-COCO (Text-Conditional Image Generation)
Total training steps: ~3,750,000 (approx)


## 4. Load Models

In [5]:
def create_models(device: torch.device, config: TrainConfig):
    """Load the frozen VAE and CLIP text encoder/tokenizer and build a fresh UNet.

    Args:
        device: device every model is moved to.
        config: reads ``image_size``, ``layers_per_block`` and
            ``unet_block_out_channels``.

    Returns:
        Tuple ``(vae, text_encoder, tokenizer, unet)``. ``vae`` and
        ``text_encoder`` have gradients disabled and are in eval mode; ``unet``
        is randomly initialized and trainable.
    """
    # The tokenizer and text encoder must come from the same CLIP checkpoint.
    clip_model_name = "openai/clip-vit-base-patch32"

    # Pretrained VAE (frozen)
    print("Loading pretrained VAE...")
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
    vae.requires_grad_(False)
    vae.eval()
    vae.to(device)

    # Pretrained CLIP text encoder (frozen)
    print("Loading pretrained CLIP text encoder...")
    text_encoder = CLIPTextModel.from_pretrained(clip_model_name)
    text_encoder.requires_grad_(False)
    text_encoder.eval()
    text_encoder.to(device)

    # CLIP tokenizer
    print("Loading CLIP tokenizer...")
    tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)

    print("Creating UNet2DConditionModel (text-conditional)...")
    # Conditional UNet with cross-attention for text conditioning
    unet = UNet2DConditionModel(
        sample_size=config.image_size // 8,  # latent side length: 32 for 256x256
        in_channels=4,
        out_channels=4,
        layers_per_block=config.layers_per_block,
        block_out_channels=config.unet_block_out_channels,
        down_block_types=(
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
            "CrossAttnUpBlock2D",
        ),
        # Must equal the width of the text embeddings fed to cross-attention,
        # i.e. the text encoder's hidden size (512 for clip-vit-base-patch32).
        # Reading it from the loaded config keeps the two in sync if the CLIP
        # checkpoint is ever swapped.
        cross_attention_dim=text_encoder.config.hidden_size,
    ).to(device)

    num_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    print(f"UNet trainable parameters: {num_params:,}")

    return vae, text_encoder, tokenizer, unet


# Initialize models
device = get_device()
seed_everything(config.seed)
vae, text_encoder, tokenizer, unet = create_models(device, config)

Loading pretrained VAE...
Loading pretrained CLIP text encoder...
Loading CLIP tokenizer...
Creating UNet2DConditionModel (text-conditional)...
UNet trainable parameters: 139,680,132


## 5. Prepare Dataset

## 5.1. Load MS-COCO Dataset with Captions

In [6]:
def _extract_caption(example) -> str:
    """Return one caption string from a COCO example, or "" if none is found.

    Handles the layouts the loader below may return:

    * ``example['caption']`` as a plain string, or as a dict with a
      ``'caption'`` key (the layout expected from shunk031/MSCOCO — not
      verified here).
    * ``example['sentences']`` as a list of dicts with a ``'raw'`` key (or of
      strings), from which one entry is picked with :func:`random.choice`.
    * ``example['sentences']`` as a single dict, whose ``'raw'`` is either a
      string or a list of strings (HF returns sequence-of-dict features
      column-wise, e.g. ``{'raw': [...], ...}``).
    """
    if 'caption' in example:
        caption_data = example['caption']
        if isinstance(caption_data, dict) and 'caption' in caption_data:
            return caption_data['caption']
        if isinstance(caption_data, str):
            return caption_data
        return ""

    if 'sentences' in example:
        sentences = example['sentences']
        if isinstance(sentences, dict):
            # random.choice() on a dict indexes it with an int and raises
            # KeyError, so dict-shaped records are unpacked explicitly.
            raw = sentences.get('raw', '')
            if isinstance(raw, (list, tuple)):
                return str(random.choice(raw)) if raw else ""
            return raw
        if sentences:
            chosen = random.choice(sentences)
            if isinstance(chosen, dict):
                return chosen.get('raw', '')
            return str(chosen)
    return ""


class COCOTextImageDataset(Dataset):
    """Map-style dataset yielding ``{'image': tensor, 'caption': str}`` examples.

    Args:
        hf_dataset: indexable dataset supporting ``len()`` and integer indexing,
            whose examples are dicts with an ``'image'`` entry (presumably a PIL
            image — ``.mode`` / ``.convert('RGB')`` are used) plus caption fields
            understood by ``_extract_caption``.
        transform: callable applied to the RGB image.
        cfg_prob: probability of replacing the caption with "" so the model
            also learns the unconditional case (classifier-free guidance).
        max_retries: how many consecutive examples to try before giving up.
    """

    def __init__(self, hf_dataset, transform, cfg_prob=0.1, max_retries=5):
        self.dataset = hf_dataset
        self.transform = transform
        self.cfg_prob = cfg_prob
        self.max_retries = max_retries

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load example *idx*, advancing to the next index if it is unusable.

        Raises:
            RuntimeError: if ``max_retries`` consecutive examples have no image.
            Exception: the last loading error, if the final attempt fails.
        """
        start_idx = idx
        for retry in range(self.max_retries):
            try:
                example = self.dataset[idx]

                image = example.get('image', None)
                if image is None:
                    idx = (idx + 1) % len(self.dataset)
                    continue

                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image = self.transform(image)

                caption = _extract_caption(example)

                # For classifier-free guidance: randomly drop text
                if random.random() < self.cfg_prob:
                    caption = ""

                return {'image': image, 'caption': caption}

            except Exception as e:
                print(f"Error loading example {idx}: {e}, retrying...")
                idx = (idx + 1) % len(self.dataset)
                if retry == self.max_retries - 1:
                    raise

        # Every attempt hit an example without an image. Returning None here
        # would only surface later as an opaque collate error in the DataLoader.
        raise RuntimeError(
            f"No image found in {self.max_retries} consecutive examples starting at index {start_idx}"
        )


def make_dataloader(config: "TrainConfig") -> DataLoader:
    """Load MS-COCO with captions into a local HF cache and wrap it in a DataLoader.

    Tries ``shunk031/MSCOCO`` (2017 captions) first and falls back to
    ``HuggingFaceM4/COCO`` on any error. Both are loaded with
    ``datasets.load_dataset`` into ``../../dataset_cache``: they are downloaded
    and cached in full, not streamed (``config.use_streaming`` is not read).

    Args:
        config: reads ``image_size``, ``classifier_free_guidance_prob`` and
            ``batch_size``.

    Returns:
        A shuffling DataLoader (2 workers, pinned memory, ``drop_last=True``)
        over ``COCOTextImageDataset`` batches with ``'image'`` and ``'caption'``.
    """
    from datasets import load_dataset

    print("Loading MS-COCO 2017 Captions dataset...")

    # Use cache directory to avoid re-downloading
    cache_dir = os.path.abspath("../../dataset_cache")
    print(f"üìÅ Dataset download/cache directory (absolute path): {cache_dir}")

    try:
        # Try shunk031/MSCOCO with cache
        print(f"Attempting to load shunk031/MSCOCO dataset...")
        ds = load_dataset(
            "shunk031/MSCOCO",
            year=2017,
            coco_task="captions",
            split="train",
            cache_dir=cache_dir,
        )
        print(f"‚úÖ Loaded shunk031/MSCOCO dataset")
    except Exception as e:
        # repr() so exceptions with an empty message (e.g. download timeouts)
        # still show their type in the log.
        print(f"‚ö†Ô∏è  Could not load shunk031/MSCOCO: {e!r}")
        print("Trying to use pre-downloaded HuggingFaceM4/COCO from cache...")

        # Fallback: use the already downloaded HuggingFaceM4/COCO.
        # NOTE: this repository runs custom loading code; `datasets` asks for
        # confirmation interactively ([y/N]) unless trust_remote_code is passed,
        # which can block a non-interactive batch job.
        ds = load_dataset(
            "HuggingFaceM4/COCO",
            split="train",
            cache_dir=cache_dir,
        )
        print(f"‚úÖ Loaded HuggingFaceM4/COCO from cache")

    print(f"Dataset size: {len(ds)}")

    tfms = build_transforms(config.image_size)
    dataset = COCOTextImageDataset(ds, tfms, config.classifier_free_guidance_prob)

    print(f"Wrapped dataset size: {len(dataset)}")

    # Use multiple workers for faster data loading
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True,  # Drop incomplete batches
    )


# Build the training DataLoader (downloads/caches COCO on first run) between
# two separator lines, then report its size.
_separator = "=" * 80
print(_separator)
dataloader = make_dataloader(config)
print(_separator)
_load_report = (
    "‚úÖ Dataset loaded successfully",
    f"Batch size: {config.batch_size}",
    f"Number of batches: ~{len(dataloader)}",
)
print(*_load_report, sep="\n")


Loading MS-COCO 2017 Captions dataset...
üìÅ Dataset download/cache directory (absolute path): /RG/rg-miray/doshlom4/final_project/notebooks/dataset_cache
Attempting to load shunk031/MSCOCO dataset...


Downloading data:  30%|‚ñà‚ñà‚ñâ       | 5.79G/19.3G [05:00<12:01, 18.8MB/s]  

‚ö†Ô∏è  Could not load shunk031/MSCOCO: 
Trying to use pre-downloaded HuggingFaceM4/COCO from cache...


The repository for HuggingFaceM4/COCO contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/HuggingFaceM4/COCO.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y



Downloading data:   0%|          | 0.00/36.7M [00:00<?, ?B/s][A
Downloading data:   0%|          | 16.4k/36.7M [00:00<07:07, 85.9kB/s][A
Downloading data:   0%|          | 49.2k/36.7M [00:00<04:30, 136kB/s] [A
Downloading data:   0%|          | 115k/36.7M [00:00<02:39, 229kB/s] [A
Downloading data:   0%|          | 164k/36.7M [00:00<02:32, 240kB/s][A
Downloading data:   1%|          | 344k/36.7M [00:00<01:14, 491kB/s][A
Downloading data:   1%|‚ñè         | 508k/36.7M [00:01<00:59, 613kB/s][A
Downloading data:   2%|‚ñè         | 852k/36.7M [00:01<00:36, 997kB/s][A
Downloading data:   3%|‚ñé         | 1.20M/36.7M [00:01<00:28, 1.25MB/s][A
Downloading data:   5%|‚ñå         | 1.87M/36.7M [00:01<00:17, 1.95MB/s][A
Downloading data:   7%|‚ñã         | 2.54M/36.7M [00:01<00:14, 2.42MB/s][A
Downloading data:  11%|‚ñà         | 3.87M/36.7M [00:02<00:08, 3.78MB/s][A
Downloading data:  14%|‚ñà‚ñç        | 5.23M/36.7M [00:02<00:06, 4.78MB/s][A
Downloading data:  19%|‚ñà‚ñâ        | 

FSTimeoutError: 

## 5.2. Dataset Statistics

In [None]:
# Print dataset statistics computed from the DataLoader built above, instead of
# a hard-coded ~120,000: the fallback dataset may have a different row count.
# The DataLoader uses drop_last=True, so len(dataloader) is exactly the number
# of full batches per epoch (presumably one optimizer step each).
num_examples = len(dataloader.dataset)
steps_per_epoch = len(dataloader)

print("="*60)
print(" MS-COCO TEXT-TO-IMAGE DATASET STATISTICS")
print("="*60)
print(f"Number of training examples: {num_examples:,}")
print(f"Captions per image: 5 (human-written)")
print(f"Batch size: {config.batch_size}")
print(f"Steps per epoch: {steps_per_epoch:,}")
print(f"Total epochs: {config.num_epochs}")
print(f"Total training steps: {steps_per_epoch * config.num_epochs:,}")
print(f"Checkpoint interval: every {config.checkpoint_interval:,} steps")
print(f"Checkpoints per epoch: ~{steps_per_epoch // config.checkpoint_interval}")
print(f"Classifier-free guidance: {config.classifier_free_guidance_prob * 100}% unconditional")
print("="*60)


## 6. Visualize Sample Data

In [None]:
# Visualize a batch with captions
print("Loading sample batch...")
sample_batch = next(iter(dataloader))
sample_images = sample_batch['image']
sample_captions = sample_batch['caption']

print(f"Batch shape: {sample_images.shape}")
print(f"\nSample captions:")
for i, caption in enumerate(sample_captions[:4]):
    print(f"  {i+1}. {caption}")

# 2x8 grid = 16 panels; with a smaller batch the spare panels stay blank.
fig, axes = plt.subplots(2, 8, figsize=(18, 5))
for i, ax in enumerate(axes.flat):
    if i < len(sample_images):
        img = sample_images[i].permute(1, 2, 0).cpu().numpy()
        # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1]. Clip so float round-off
        # just outside [0, 1] does not trigger imshow's clipping warning.
        img = np.clip((img + 1) / 2, 0.0, 1.0)
        ax.imshow(img)
        # Add caption as title (truncated to 50 characters)
        if i < len(sample_captions):
            caption = sample_captions[i]
            if len(caption) > 50:
                caption = caption[:50] + "..."
            ax.set_title(caption, fontsize=7, wrap=True)
    # Hide ticks and frames on every panel, including unused ones.
    ax.axis('off')

plt.suptitle("Sample MS-COCO Images with Captions", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

del sample_batch, sample_images, sample_captions
import gc
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


## 7. Checkpoint Management Functions

In [None]:
def find_latest_checkpoint(output_dir: str):
    """Return the path of the newest UNet checkpoint in *output_dir*, or None.

    Candidates are ``*.pt`` files whose names start with ``unet_step_`` or
    ``unet_epoch_`` (other names, e.g. ``unet_final.pt``, are ignored).
    "Newest" means the largest modification time. Returns None if the
    directory does not exist or contains no candidates.
    """
    if not os.path.exists(output_dir):
        return None

    candidates = (
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith('.pt') and name.startswith(('unet_step_', 'unet_epoch_'))
    )
    # On equal mtimes, max() keeps the first candidate in listdir order.
    return max(candidates, key=os.path.getmtime, default=None)


def load_checkpoint(checkpoint_path: str, unet, optimizer=None):
    """Restore *unet* (and optionally *optimizer*) from a checkpoint file.

    Two file layouts are accepted:

    * a dict with ``'unet'`` weights plus optional ``'optimizer'``,
      ``'global_step'`` (or legacy ``'step'``), ``'epoch'``, ``'batch_losses'``
      and ``'epoch_losses'`` entries;
    * a bare UNet state dict, whose metadata all default to zero / empty.

    The file is loaded onto CPU first. Optimizer state is restored only when
    *optimizer* is given and the file contains it.

    Returns:
        dict with ``'global_step'``, ``'epoch'``, ``'batch_losses'`` and
        ``'epoch_losses'``.
    """
    print(f"Loading checkpoint from: {checkpoint_path}")
    payload = torch.load(checkpoint_path, map_location='cpu')

    if not isinstance(payload, dict):
        unet.load_state_dict(payload)
        metadata = {'global_step': 0, 'epoch': 0, 'batch_losses': [], 'epoch_losses': []}
    else:
        # Full checkpoints nest the weights under 'unet'; older files are the
        # state dict itself.
        weights = payload['unet'] if 'unet' in payload else payload
        unet.load_state_dict(weights)

        if optimizer is not None and 'optimizer' in payload:
            optimizer.load_state_dict(payload['optimizer'])

        metadata = {
            'global_step': payload.get('global_step', payload.get('step', 0)),
            'epoch': payload.get('epoch', 0),
            'batch_losses': payload.get('batch_losses', []),
            'epoch_losses': payload.get('epoch_losses', []),
        }

    print(f"Resumed from step {metadata['global_step']}, epoch {metadata['epoch']}")
    return metadata


def save_checkpoint(path: str, unet, optimizer, global_step: int, epoch: int,
                   batch_losses: List[float], epoch_losses: List[float]):
    """Save UNet weights, optimizer state and training progress to *path*.

    Args:
        path: destination file (``torch.save`` format).
        unet: model whose ``state_dict()`` is stored under ``'unet'``.
        optimizer: optimizer whose ``state_dict()`` is stored under
            ``'optimizer'``. May be ``None``, in which case the key is omitted
            (``load_checkpoint`` treats it as optional).
        global_step: step counter, stored as-is.
        epoch: epoch counter, stored as-is.
        batch_losses: loss history, stored as-is.
        epoch_losses: loss history, stored as-is.

    The file is first written to ``path + '.tmp'`` and then renamed over *path*
    with ``os.replace``. A job killed mid-save therefore never leaves a
    truncated ``.pt`` file that ``find_latest_checkpoint`` would pick up as the
    newest checkpoint.
    """
    checkpoint = {
        'unet': unet.state_dict(),
        'global_step': global_step,
        'epoch': epoch,
        'batch_losses': batch_losses,
        'epoch_losses': epoch_losses,
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()

    tmp_path = f"{path}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a partial temp file behind; the error still propagates.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Check for existing checkpoints. If one is found, its weights are loaded into
# `unet` in place (the model keeps them afterwards), its metadata is reported,
# and validation samples are rendered when the sampling helper is available.
latest_checkpoint = find_latest_checkpoint(config.output_dir)
if latest_checkpoint:
    print(f"‚úÖ Found existing checkpoint: {latest_checkpoint}")
    print("=" * 80)
    print("üîç VALIDATING CHECKPOINT BEFORE RESUMING TRAINING")
    print("=" * 80)

    # Load checkpoint weights into `unet` (optimizer state is not restored here)
    print("\nLoading checkpoint for validation...")
    temp_metadata = load_checkpoint(latest_checkpoint, unet, None)

    print(f"\nüìä Checkpoint Information:")
    print(f"   Global Step: {temp_metadata['global_step']:,}")
    print(f"   Epoch: {temp_metadata['epoch']}")
    print(f"   Batch Losses Recorded: {len(temp_metadata['batch_losses']):,}")
    print(f"   Epoch Losses Recorded: {len(temp_metadata['epoch_losses'])}")

    if temp_metadata['batch_losses']:
        recent_losses = temp_metadata['batch_losses'][-100:]
        print(f"   Recent Avg Loss (last 100 batches): {sum(recent_losses)/len(recent_losses):.4f}")

    validation_samples_path = os.path.join(config.output_dir, f"checkpoint_validation_step_{temp_metadata['global_step']}.png")

    # generate_checkpoint_samples is defined in section 8, i.e. *after* this
    # cell. On a fresh top-to-bottom run it does not exist yet, and calling it
    # directly would raise NameError, so look it up first.
    sample_fn = globals().get('generate_checkpoint_samples')
    if sample_fn is not None:
        # Generate test samples to verify the checkpoint works
        print(f"\nüé® Generating test samples from checkpoint...")
        # NOTE(review): generate_checkpoint_samples calls unet.eval(); confirm
        # the training loop switches back with unet.train().
        sample_fn(
            unet, vae, text_encoder, tokenizer, device, config,
            save_path=validation_samples_path,
            global_step=temp_metadata['global_step'],
            num_samples=16,
            num_inference_steps=50
        )
        print(f"‚úÖ Checkpoint validation complete!")
        print(f"   Validation samples saved: {validation_samples_path}")
    else:
        print("\nSkipping validation samples: generate_checkpoint_samples is not defined yet.")
        print("   Run section 8, then re-run this cell to render them.")

    print(f"\n‚ñ∂Ô∏è  Training will resume from step {temp_metadata['global_step']}, epoch {temp_metadata['epoch']}")
    print(f"   Next checkpoint will be saved at step {((temp_metadata['global_step'] // config.checkpoint_interval) + 1) * config.checkpoint_interval}")
    print("=" * 80)
else:
    print("No existing checkpoints found. Starting training from scratch.")


### Checkpoint System Features

The checkpoint system provides **complete automatic checkpoint management**:

1. **Automatic Detection & Loading**: 
   - On startup, automatically searches `output_dir` for the most recently modified `unet_step_*.pt` / `unet_epoch_*.pt` checkpoint
   - Loads checkpoint and validates it by generating test images
   - Displays checkpoint metadata (step, epoch, loss history)
   - Seamlessly resumes training from the exact point where it stopped

2. **Complete State Preservation**: Each checkpoint stores:
   - ✅ UNet model weights (all parameters)
   - ✅ Optimizer state (e.g. Adam moment estimates and per-group learning rates; no separate LR-scheduler state)
   - ✅ Global step count (exact training progress)
   - ✅ Current epoch number
   - ✅ Complete batch loss history (all losses since start)
   - ✅ Complete epoch loss history

3. **Training Visualization**: Saved with every checkpoint:
   - 📊 Training loss plots (batch + epoch)
   - 🎨 Text-to-image sample generations (16 diverse prompts)

4. **Multiple Checkpoint Types**:
   - **Step checkpoints**: `unet_step_2000.pt`, `unet_step_4000.pt`, ... (every 2000 steps)
   - **Epoch checkpoints**: `unet_epoch_1.pt`, `unet_epoch_2.pt`, ... (after each epoch)
   - **Final checkpoint**: `unet_final.pt` (at completion of all 500 epochs; not picked up by automatic resume, which only considers `unet_step_*` / `unet_epoch_*` files)
   - **Sample images**: `samples_step_2000.png`, `samples_epoch_1.png`, `samples_final.png`

5. **Validation Before Resume**:
   - Generates 16 test images from the checkpoint to verify model quality
   - Shows recent average loss (last 100 batches)
   - Calculates next checkpoint step
   - Saves validation images as `checkpoint_validation_step_XXXX.png`

**To Resume Training**: Simply re-run the training cell (section 11) - the system automatically:
- Finds the latest checkpoint
- Validates it
- Resumes from the exact step and epoch
- Continues with the same optimizer state and loss history

**No manual intervention needed!** üöÄ

## 7b. Test Checkpoint System (Optional)

**Run this cell to verify the checkpoint system is working correctly**

This will test:
- ✅ Checkpoint directory creation
- ✅ Checkpoint saving
- ✅ Checkpoint loading
- ✅ Metadata preservation

In [None]:
# Test checkpoint system: round-trip checkpoints through save_checkpoint,
# load_checkpoint and find_latest_checkpoint in a scratch directory that is
# removed afterwards, even if a step raises. load_checkpoint writes the saved
# weights back into `unet`; they are the weights that were just saved, so the
# model is unchanged.
import shutil

print("="*80)
print(" üß™ TESTING CHECKPOINT SYSTEM")
print("="*80)

test_dir = os.path.join(config.output_dir, "checkpoint_test")
# Use a real (throwaway) optimizer, as training does: save_checkpoint calls
# optimizer.state_dict(), so passing None can fail. AdamW allocates no moment
# buffers before its first step(), so this costs almost no memory.
test_optimizer = torch.optim.AdamW(unet.parameters(), lr=config.lr)
failed_tests = []

try:
    # Test 1: Directory creation
    os.makedirs(test_dir, exist_ok=True)
    print("‚úÖ Test 1: Directory creation - PASSED")

    # Test 2: Save test checkpoint
    test_ckpt_path = os.path.join(test_dir, "test_checkpoint.pt")
    test_batch_losses = [0.5, 0.4, 0.3]
    test_epoch_losses = [0.4]
    save_checkpoint(
        test_ckpt_path, unet, test_optimizer,
        global_step=100, epoch=1,
        batch_losses=test_batch_losses,
        epoch_losses=test_epoch_losses
    )
    print(f"‚úÖ Test 2: Checkpoint save - PASSED")
    print(f"   Saved to: {test_ckpt_path}")

    # Test 3: Load test checkpoint and compare every metadata field
    if os.path.exists(test_ckpt_path):
        loaded_metadata = load_checkpoint(test_ckpt_path, unet, test_optimizer)
        expected_metadata = {
            'global_step': 100,
            'epoch': 1,
            'batch_losses': test_batch_losses,
            'epoch_losses': test_epoch_losses,
        }
        mismatched = [key for key, want in expected_metadata.items() if loaded_metadata[key] != want]
        if mismatched:
            print(f"   Test 3: Checkpoint load - FAILED (metadata mismatch: {mismatched})")
            failed_tests.append("Test 3")
        else:
            print("‚úÖ Test 3: Checkpoint load - PASSED")
            print(f"   Step: {loaded_metadata['global_step']}")
            print(f"   Epoch: {loaded_metadata['epoch']}")
            print(f"   Losses: {len(loaded_metadata['batch_losses'])} batch, {len(loaded_metadata['epoch_losses'])} epoch")
    else:
        print("‚ùå Test 3: Checkpoint load - FAILED (file not found)")
        failed_tests.append("Test 3")

    # Test 4: Find latest checkpoint (only unet_step_* / unet_epoch_* names qualify)
    test_ckpt2_path = os.path.join(test_dir, "unet_step_200.pt")
    save_checkpoint(test_ckpt2_path, unet, test_optimizer, 200, 2, [0.3, 0.2], [0.25])
    latest = find_latest_checkpoint(test_dir)
    # Compare the file name exactly; a substring check on the full path could
    # match "200" elsewhere in the directory path.
    if latest and os.path.basename(latest) == "unet_step_200.pt":
        print("‚úÖ Test 4: Find latest checkpoint - PASSED")
        print(f"   Latest: {os.path.basename(latest)}")
    else:
        print(f"‚ùå Test 4: Find latest checkpoint - FAILED")
        print(f"   Found: {latest}")
        failed_tests.append("Test 4")
finally:
    # Cleanup test files
    if os.path.exists(test_dir):
        shutil.rmtree(test_dir)
        print("\nüßπ Cleanup: Test files removed")
    del test_optimizer

print("\n" + "="*80)
if failed_tests:
    print(f" CHECKPOINT TESTS FAILED: {', '.join(failed_tests)}")
    print("="*80)
else:
    print(" ‚úÖ ALL CHECKPOINT TESTS PASSED!")
    print("="*80)
    print("\nThe checkpoint system is working correctly and ready for training.")
print("\nThe checkpoint system is working correctly and ready for training.")


## 8. Sample Generation Function

In [None]:
def generate_checkpoint_samples(
    unet, vae, text_encoder, tokenizer, device, config, 
    save_path: str, global_step: int, num_samples: int = 16, num_inference_steps: int = 50
):
    """Generate a grid of text-to-image samples at a training checkpoint and save it.

    Sample ``i`` uses prompt ``prompts[i % len(prompts)]`` and seed ``42 + i``. Grids
    from different checkpoints therefore share prompts and initial noise and can be
    compared side by side. The global CPU/CUDA RNG state is forked for the duration of
    sampling, so these fixed seeds do not reset the random stream that the training
    loop draws noise and timesteps from.

    Args:
        unet: Denoiser. It is put in eval mode while sampling and restored to its
            previous train/eval mode afterwards, even if sampling raises.
        vae: Decodes final latents to images.
        text_encoder, tokenizer: Passed to ``encode_text`` to embed each prompt.
        device: Device on which latents are sampled.
        config: Must provide ``image_size``. ``num_train_timesteps`` is used when
            present so the schedule matches training; otherwise 1000.
        save_path: Path of the PNG grid to write.
        global_step: Training step, used in the log message and figure title.
        num_samples: Number of images; laid out 8 per row (16 -> 2 x 8).
        num_inference_steps: Denoising steps per image.
    """
    print(f"\nGenerating {num_samples} text-to-image samples at step {global_step}...")
    
    # Predefined interesting prompts for visualization
    prompts = [
        "A cat sitting on a couch",
        "A person riding a bicycle",
        "A bowl of fruit on a table",
        "A dog playing in a park",
        "A bird flying in the sky",
        "A car parked on the street",
        "A pizza on a plate",
        "People walking on a beach",
        "A train at a station",
        "A horse in a field",
        "A laptop computer on a desk",
        "A vase with flowers",
        "A person skiing down a mountain",
        "An airplane in the sky",
        "A teddy bear on a bed",
        "A clock on the wall",
    ]
    
    # Use the same noise schedule length as training (train() builds its scheduler from
    # config.num_train_timesteps); 1000 is kept as the fallback.
    scheduler = DDPMScheduler(
        beta_schedule="squaredcos_cap_v2",
        num_train_timesteps=getattr(config, "num_train_timesteps", 1000),
    )
    scheduler.set_timesteps(num_inference_steps)
    
    # Grid sized from num_samples instead of a fixed 2 x 8, so larger requests are not truncated.
    ncols = 8
    nrows = max(1, -(-num_samples // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 3 * nrows), squeeze=False)
    
    was_training = unet.training
    unet.eval()
    # Fork the CPU RNG and every visible CUDA device's RNG so the per-sample reseeding
    # below is undone on exit and does not leak into the training loop.
    cuda_devices = list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []
    try:
        with torch.no_grad(), torch.random.fork_rng(devices=cuda_devices):
            for i, ax in enumerate(axes.flat):
                if i >= num_samples:
                    ax.axis('off')
                    continue
                
                # Get prompt and encode
                prompt = prompts[i % len(prompts)]
                text_embeddings = encode_text(prompt, tokenizer, text_encoder, device)
                
                # Fixed per-sample seed for comparability across checkpoints
                torch.manual_seed(42 + i)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed_all(42 + i)
                
                latents = torch.randn((1, 4, config.image_size // 8, config.image_size // 8), device=device)
                
                # Denoising loop with text conditioning
                for t in scheduler.timesteps:
                    latent_model_input = scheduler.scale_model_input(latents, t)
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                    latents = scheduler.step(noise_pred, t, latents).prev_sample
                
                # Undo the 0.18215 latent scaling applied during training, decode, and
                # map the decoder output from [-1, 1] to [0, 1].
                image = vae.decode(latents / 0.18215).sample
                image = (image / 2 + 0.5).clamp(0, 1).detach().cpu()
                
                ax.imshow(image[0].permute(1, 2, 0).numpy())
                ax.axis('off')
                ax.set_title(prompt[:40], fontsize=7)
        
        fig.suptitle(f"Text-to-Image Samples at Step {global_step}", fontsize=16, y=0.98)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure and restore the caller's train/eval mode.
        plt.close(fig)
        unet.train(was_training)
    
    print(f"Sample images saved: {save_path}")


## 9. Training Loop

In [None]:
def _save_training_artifacts(
    tag: str,
    unet, optimizer, vae, text_encoder, tokenizer, device, config,
    global_step: int,
    epoch: int,
    batch_losses: List[float],
    epoch_losses: List[float],
) -> None:
    """Write ``unet_{tag}.pt``, ``training_loss_{tag}.png`` and ``samples_{tag}.png`` to ``config.output_dir``."""
    ckpt_path = os.path.join(config.output_dir, f"unet_{tag}.pt")
    save_checkpoint(ckpt_path, unet, optimizer, global_step, epoch, batch_losses, epoch_losses)
    print(f"\nCheckpoint saved: {ckpt_path}")

    plot_path = os.path.join(config.output_dir, f"training_loss_{tag}.png")
    save_loss_plot(batch_losses, epoch_losses, plot_path)
    print(f"Training plot saved: {plot_path}")

    samples_path = os.path.join(config.output_dir, f"samples_{tag}.png")
    generate_checkpoint_samples(
        unet, vae, text_encoder, tokenizer, device, config,
        save_path=samples_path, global_step=global_step
    )

    # Sampling allocates extra activations; release cached blocks before training continues.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def train(config: TrainConfig, vae, text_encoder, tokenizer, unet, dataloader, device, resume_from_checkpoint: str = None):
    """Train the text-conditioned UNet with the noise-prediction MSE objective.

    The VAE and text encoder are only used under ``torch.no_grad()``, and only the
    UNet's parameters are optimized (AdamW, ``config.lr``). A checkpoint, loss plot and
    sample grid are written to ``config.output_dir`` in three cases:
    every ``config.checkpoint_interval`` steps (``*_step_N``), at the end of each epoch
    (``*_epoch_N``), and once after training (``*_final``).

    Args:
        config: Training configuration (output_dir, num_train_timesteps, lr,
            mixed_precision, num_epochs, checkpoint_interval, image_size, ...).
        vae, text_encoder, tokenizer: Used to encode images to latents and captions
            to embeddings.
        unet: Model being trained.
        dataloader: Yields dicts with ``'image'`` tensors and ``'caption'`` strings.
        device: torch.device to train on. fp16 AMP is used only when it is CUDA.
        resume_from_checkpoint: Optional checkpoint path. When given, ``load_checkpoint``
            loads it into ``unet``/``optimizer`` and the step counter, epoch and loss
            histories are restored. Step checkpoints record the epoch in progress, so
            resuming from one restarts that epoch from its first batch.

    Returns:
        (batch_losses, epoch_losses): per-step losses and per-epoch average losses,
        including any history restored from the checkpoint.
    """
    from contextlib import nullcontext

    os.makedirs(config.output_dir, exist_ok=True)

    # Noise scheduler for training
    noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps, beta_schedule="squaredcos_cap_v2")

    optimizer = torch.optim.AdamW(unet.parameters(), lr=config.lr)

    # fp16 autocast and loss scaling only apply on CUDA. One flag drives both, so they stay
    # in sync, and no float32 autocast is requested on CPU (an unsupported autocast dtype there).
    use_amp = bool(config.mixed_precision) and device.type == "cuda"
    # NOTE: the scaler is not passed to save_checkpoint, so its scale restarts from the default on resume.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Loss tracking and training state
    batch_losses = []
    epoch_losses = []
    global_step = 0
    start_epoch = 0

    # Resume from checkpoint if available
    if resume_from_checkpoint:
        metadata = load_checkpoint(resume_from_checkpoint, unet, optimizer)
        global_step = metadata['global_step']
        start_epoch = metadata['epoch']
        batch_losses = metadata['batch_losses']
        epoch_losses = metadata['epoch_losses']
        print(f"Resuming training from step {global_step}, epoch {start_epoch}")

    unet.train()

    for epoch in range(start_epoch, config.num_epochs):
        epoch_loss_sum = 0.0
        epoch_batch_count = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.num_epochs}")

        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(device, non_blocking=True)
            captions = batch['caption']

            with torch.no_grad():
                # Frozen VAE encode; 0.18215 is the latent scale that sampling divides out again.
                latents = vae.encode(images).latent_dist.sample() * 0.18215
                # Frozen CLIP text encoding, one caption at a time, concatenated on the batch axis.
                text_embeddings = torch.cat(
                    [encode_text(caption, tokenizer, text_encoder, device) for caption in captions],
                    dim=0,
                )

            # Sample noise and timestep; add noise
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict the added noise with text conditioning
            amp_context = torch.autocast(device_type=device.type, dtype=torch.float16) if use_amp else nullcontext()
            with amp_context:
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
                loss = nn.functional.mse_loss(noise_pred, noise)

            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            # Record loss
            loss_value = loss.item()
            batch_losses.append(loss_value)
            epoch_loss_sum += loss_value
            epoch_batch_count += 1

            global_step += 1
            pbar.set_postfix({"loss": f"{loss_value:.4f}", "step": global_step})

            # Clear cache every 50 batches to prevent fragmentation
            if batch_idx % 50 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Periodic step checkpoint; records the 0-based epoch still in progress.
            if global_step % config.checkpoint_interval == 0:
                _save_training_artifacts(
                    f"step_{global_step}", unet, optimizer, vae, text_encoder, tokenizer, device, config,
                    global_step, epoch, batch_losses, epoch_losses,
                )

        # Record epoch average loss (0.0 if the dataloader yielded nothing)
        avg_epoch_loss = epoch_loss_sum / epoch_batch_count if epoch_batch_count > 0 else 0.0
        epoch_losses.append(avg_epoch_loss)
        print(f"Epoch {epoch+1}/{config.num_epochs} - Average Loss: {avg_epoch_loss:.4f}")

        # Epoch checkpoint records epoch + 1 (epochs completed), so resuming starts the next epoch.
        _save_training_artifacts(
            f"epoch_{epoch+1}", unet, optimizer, vae, text_encoder, tokenizer, device, config,
            global_step, epoch + 1, batch_losses, epoch_losses,
        )

    # Final checkpoint, plot and samples
    _save_training_artifacts(
        "final", unet, optimizer, vae, text_encoder, tokenizer, device, config,
        global_step, config.num_epochs, batch_losses, epoch_losses,
    )

    return batch_losses, epoch_losses


def save_loss_plot(batch_losses: List[float], epoch_losses: List[float], save_path: str):
    """Render per-batch and per-epoch loss curves to ``save_path`` without displaying them.

    A panel is left blank when its loss list is empty. The figure is closed after
    saving, so repeated calls during training do not accumulate open figures.
    """
    fig, (batch_ax, epoch_ax) = plt.subplots(1, 2, figsize=(14, 5))

    if len(batch_losses) > 0:
        batch_ax.plot(batch_losses, linewidth=0.8, alpha=0.7)
        batch_ax.set(xlabel="Batch", ylabel="Loss", title="Training Loss per Batch")
        batch_ax.grid(True, alpha=0.3)

    if len(epoch_losses) > 0:
        # Epochs are numbered from 1 on the x-axis
        epoch_numbers = range(1, len(epoch_losses) + 1)
        epoch_ax.plot(epoch_numbers, epoch_losses, marker='o', linewidth=2)
        epoch_ax.set(xlabel="Epoch", ylabel="Average Loss", title="Training Loss per Epoch")
        epoch_ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)


## 10. Memory Optimization (Clear GPU Cache)

In [None]:
# Clear GPU cache before training to maximize available memory
import os

if torch.cuda.is_available():
    # PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator is initialized.
    # It must therefore be set before anything touches the GPU. That includes the
    # reset_peak_memory_stats() call below, which is why it is set first here. If an
    # earlier cell already used the GPU, the setting only reaches child processes;
    # export it in the job script instead.
    if torch.cuda.is_initialized():
        print("Note: CUDA is already initialized; PYTORCH_CUDA_ALLOC_CONF set now will not "
              "affect this process. Export it before starting Python to take effect.")
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Print current memory status
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print("Memory cache cleared and optimized for training")

## 11. Run Training

**⚠️ Important Notes:**

### Checkpoint Resumption
- **Automatic**: The training will automatically resume from the latest checkpoint if one exists
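- **Validation**: No separate validation runs before resuming; sample grids are generated at every step checkpoint, at the end of each epoch, and after training
- **State Preservation**: The UNet weights, optimizer state, step count, epoch, and loss history are restored. The mixed-precision grad-scaler state and the position within an epoch are not saved, so resuming from a mid-epoch (step) checkpoint restarts that epoch from its first batch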

### Training Duration
- **Total epochs**: 500
- **Steps per epoch**: ~7,500
- **Total steps**: ~3,750,000
- **Checkpoint frequency**: Every 2,000 steps
- **Expected checkpoints**: ~1,875 step checkpoints + 500 epoch checkpoints

### Monitoring Progress
- Check `./outputs/train10_coco_text2img/` for:
  - `samples_step_XXXX.png` - Generated text-to-image samples
  - `training_loss_step_XXXX.png` - Training loss plots
  - `unet_step_XXXX.pt` - Model checkpoints

### To Resume After Interruption
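Simply **re-run this cell** - the system will:
1. Detect the latest checkpoint automatically
2. Load the model weights, optimizer state, step count, and loss history from it
3. Resume training from the step recorded in that checkpoint (a mid-epoch checkpoint restarts its epoch from the first batch)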

**No configuration needed - just run the cell!** 🚀

In [None]:
# Resume automatically: train() only loads a checkpoint when resume_from_checkpoint is
# truthy, so a falsy result from find_latest_checkpoint (presumably None when the
# output directory holds no checkpoints) starts a fresh run.
latest_checkpoint = find_latest_checkpoint(config.output_dir)

batch_losses, epoch_losses = train(
    config=config,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    dataloader=dataloader,
    device=device,
    resume_from_checkpoint=latest_checkpoint,
)


## 12. Visualize Training Loss

In [None]:
def plot_losses(batch_losses: List[float], epoch_losses: List[float], output_dir: str = None):
    """Show per-batch and per-epoch training loss side by side.

    When ``output_dir`` is truthy, the figure is also saved there as
    ``training_loss.png`` before being shown. The epoch panel stays empty
    when ``epoch_losses`` is empty.
    """
    _, (batch_ax, epoch_ax) = plt.subplots(1, 2, figsize=(14, 5))

    batch_ax.plot(batch_losses, linewidth=0.8, alpha=0.7)
    batch_ax.set(xlabel="Batch", ylabel="Loss", title="Training Loss per Batch")
    batch_ax.grid(True, alpha=0.3)

    if len(epoch_losses) > 0:
        epoch_numbers = range(1, len(epoch_losses) + 1)
        epoch_ax.plot(epoch_numbers, epoch_losses, marker='o', linewidth=2)
        epoch_ax.set(xlabel="Epoch", ylabel="Average Loss", title="Training Loss per Epoch")
        epoch_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if output_dir:
        destination = os.path.join(output_dir, "training_loss.png")
        plt.savefig(destination, dpi=150)
        print(f"Loss plots saved to {destination}")

    plt.show()


# Plot the losses
plot_losses(batch_losses, epoch_losses, output_dir=config.output_dir)

## 13. Text-to-Image Sampling Function

In [None]:
@torch.no_grad()
def sample(
    prompt: str,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    seed: Optional[int] = None,
):
    """Generate an image from a text prompt using the notebook's global models.

    Reads the module-level ``unet``, ``vae``, ``text_encoder``, ``tokenizer``,
    ``config`` and ``device``, and leaves ``unet`` in eval mode.

    Args:
        prompt: Text description of the image to generate.
        num_inference_steps: Number of DDPM denoising steps.
        guidance_scale: Classifier-free guidance scale (higher = more prompt adherence).
            Values > 1.0 mix in an unconditional (empty-prompt) prediction; 1.0 or
            less disables guidance.
        seed: Optional seed applied to the global CPU and CUDA RNGs for reproducibility.

    Returns:
        CPU tensor holding a batch of one decoded image with values in [0, 1].

    NOTE(review): train() always conditions on the real caption (there is no caption
    dropout), so the empty-prompt branch used for guidance was never trained
    explicitly. Compare against ``guidance_scale=1.0`` if guided samples look off.
    """
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # Match the schedule length used in training (train() uses config.num_train_timesteps);
    # fall back to 1000 when the config does not define it.
    scheduler = DDPMScheduler(
        beta_schedule="squaredcos_cap_v2",
        num_train_timesteps=getattr(config, "num_train_timesteps", 1000),
    )
    scheduler.set_timesteps(num_inference_steps)

    use_cfg = guidance_scale > 1.0

    # Encode text prompt
    text_embeddings = encode_text(prompt, tokenizer, text_encoder, device)
    if use_cfg:
        # Batch [unconditional, conditional] so one UNet call produces both predictions.
        uncond_embeddings = encode_text("", tokenizer, text_encoder, device)
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Init random latents in latent space
    latents = torch.randn((1, 4, config.image_size // 8, config.image_size // 8), device=device)

    unet.eval()
    for t in tqdm(scheduler.timesteps, desc=f"Generating '{prompt}'"):
        latent_model_input = torch.cat([latents] * 2) if use_cfg else latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict noise with text conditioning
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Perform classifier-free guidance
        if use_cfg:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the 0.18215 training latent scale, decode, and map [-1, 1] -> [0, 1]
    image = vae.decode(latents / 0.18215).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    return image.detach().cpu()


## 14. Generate Text-to-Image Samples

In [None]:
# Generate a 4 x 4 grid of images, one per prompt, each with a fixed seed (42 + index)
# so the grid is reproducible; the grid is saved to the output directory and shown.
test_prompts = [
    "A cat sitting on a couch",
    "A person riding a bicycle",
    "A bowl of fruit on a table",
    "A dog playing in a park",
    "A bird flying in the sky",
    "A car parked on the street",
    "A pizza on a plate",
    "People walking on a beach",
    "A train at a station",
    "A horse in a field",
    "A laptop computer on a desk",
    "A vase with flowers",
    "A person skiing down a mountain",
    "An airplane in the sky",
    "A teddy bear on a bed",
    "A clock on the wall",
]

num_inference_steps = 50
guidance_scale = 7.5

fig, axes = plt.subplots(4, 4, figsize=(18, 18))
for i, ax in enumerate(axes.flat):
    # Blank out any grid cell without a prompt
    if i >= len(test_prompts):
        ax.axis('off')
        continue

    prompt = test_prompts[i]
    print(f"Generating {i+1}/{len(test_prompts)}: '{prompt}'")
    generated_image = sample(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=42 + i,
    )

    # First image of the batch, channels moved last for imshow
    img = generated_image[0].permute(1, 2, 0).numpy()
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(prompt, fontsize=9, wrap=True)

plt.suptitle("Text-to-Image Samples from MS-COCO Model", fontsize=16, y=0.995)
plt.tight_layout()
plt.savefig(os.path.join(config.output_dir, "text_to_image_samples.png"), dpi=150)
plt.show()


## 15. Generate Custom Text-to-Image Sample

In [None]:
# Generate a single sample from a custom prompt, save it to
# <output_dir>/custom_text_to_image.png, and display it inline.
from torchvision.utils import save_image

# CUSTOMIZE YOUR PROMPT HERE
custom_prompt = "A golden retriever dog playing with a ball in a sunny park"
print(f"Generating image for prompt: '{custom_prompt}'")

generated_image = sample(prompt=custom_prompt, num_inference_steps=50, guidance_scale=7.5, seed=42)

output_path = os.path.join(config.output_dir, "custom_text_to_image.png")
save_image(generated_image, output_path)
print(f"Sample saved to: {output_path}")

# First image of the batch, reordered CHW -> HWC for matplotlib
img = generated_image[0].permute(1, 2, 0).numpy()
custom_fig, custom_ax = plt.subplots(figsize=(8, 8))
custom_ax.imshow(img)
custom_ax.set_title(f"Generated: '{custom_prompt}'", fontsize=12, wrap=True)
custom_ax.axis('off')
custom_fig.tight_layout()
plt.show()



Downloading data:  42%|████▏     | 5.63G/13.5G [05:19<07:09, 18.4MB/s]