In [1]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from typing import Tuple, Optional, Any
import math
from torch import Tensor
from einops import rearrange
from diffusers import DDPMScheduler
from torchvision import transforms, datasets
from tqdm import tqdm
from IPython.display import clear_output
import einops
from torch.utils.data.dataset import Dataset
from dataclasses import dataclass
from typing import List, Union, Dict, Any, Optional, BinaryIO
from PIL import Image, ImageColor, ImageDraw, ImageFont
import abc  # Abstract base class
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import DDPMScheduler, DDPMPipeline
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from accelerate import Accelerator
from accelerate.utils import set_seed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Check PyTorch installation
def check_torch_cuda():
    import torch
    import sys
    print("Current PyTorch installation:")
    print(f"Python version: {sys.version}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        return True
    else:
        print("\nTo install PyTorch with CUDA support, run:")
        print("pip uninstall torch torchvision -y")
        if sys.platform.startswith('win'):
            print("pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121")
        else:
            print("pip install torch torchvision")
        return False

has_cuda = check_torch_cuda()

Current PyTorch installation:
Python version: 3.11.8 (tags/v3.11.8:db85d51, Feb  6 2024, 22:03:32) [MSC v.1937 64 bit (AMD64)]
PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA version: 12.6


In [3]:
# Install PyTorch with CUDA support
import sys
import subprocess

def install_torch_cuda():
    subprocess.check_call([sys.executable, "-m", "pip", "install", 
                          "torch", "torchvision", 
                          "--index-url", "https://download.pytorch.org/whl/cu121"])
    print("Please restart the kernel to use the newly installed PyTorch version.")

In [4]:
# Detailed GPU/CPU diagnostics
print("\n=== System Information ===")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Try to get CUDA version and GPU info
cuda_available = torch.cuda.is_available()
if cuda_available:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("\nNo CUDA-capable GPU found. Checking system information...")
    import subprocess
    try:
        # Try to run nvidia-smi
        nvidia_smi = subprocess.check_output("nvidia-smi", shell=True).decode()
        print("\nnvidia-smi output:")
        print(nvidia_smi)
    except:
        print("\nNVIDIA driver not found. System details:")
        import platform
        print(f"OS: {platform.platform()}")
        print(f"Python version: {platform.python_version()}")

# Initialize accelerator based on available hardware
use_cpu = not cuda_available
if use_cpu:
    print("\nWARNING: No GPU detected. Using CPU mode (this will be much slower).")
    print("For optimal performance, please ensure:")
    print("1. You have an NVIDIA GPU installed")
    print("2. NVIDIA drivers are properly installed")
    print("3. PyTorch is installed with CUDA support")
    print("\nInstallation commands for PyTorch with CUDA:")
    print("pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121")
    quit()

accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision='no',  # Disable mixed precision since we're handling FP16 manually
    cpu=use_cpu
)

print(f"\n=== Accelerator Configuration ===")
print(f"Using device: {accelerator.device}")
print(f"Mixed precision type: {accelerator.mixed_precision}")
print(f"Distributed type: {accelerator.distributed_type}")
print(f"Num processes: {accelerator.num_processes}")

if not use_cpu:
    # Print GPU memory info if available
    print(f"\n=== GPU Memory Usage ===")
    print(f"Allocated: {torch.cuda.memory_allocated() // 1024 // 1024}MB")
    print(f"Cached: {torch.cuda.memory_reserved() // 1024 // 1024}MB")


=== System Information ===
PyTorch version: 2.8.0+cu126
CUDA available: True
CUDA version: 12.6
GPU count: 1
GPU 0: NVIDIA GeForce RTX 3070

=== Accelerator Configuration ===
Using device: cuda
Mixed precision type: fp16
Distributed type: DistributedType.NO
Num processes: 1

=== GPU Memory Usage ===
Allocated: 0MB
Cached: 0MB


In [5]:
# Set seed for reproducibility
set_seed(42)

# Custom UNet2DConditionModel with reduced parameters
class CustomUNet2DConditionModel(UNet2DConditionModel):
    def __init__(self, **kwargs):
        super().__init__(
            sample_size=32,  # CIFAR-10 images are 32x32
            in_channels=3,   # RGB images have 3 channels
            out_channels=3,  # Output also needs 3 channels
            layers_per_block=2,  # Reduced from default
            block_out_channels=(32, 64, 64, 32),  # Reduced from default
            down_block_types=(
                "DownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "UpBlock2D",
            ),
            cross_attention_dim=512,  # CLIP embedding dimension
            **kwargs
        )

# Load models with explicit dtype
unet = CustomUNet2DConditionModel().to(dtype=torch.float16)

In [6]:
# Load text encoder and tokenizer with safetensors
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    use_safetensors=True,
    torch_dtype=torch.float16  # Use float16 for mixed precision
)
text_encoder.to(accelerator.device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Freeze text encoder parameters
text_encoder.requires_grad_(False)

`torch_dtype` is deprecated! Use `dtype` instead!


CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e

In [7]:
tokenizer.model_max_length ## 77

77

In [8]:
TOKENIZER_MAX_LENGTH = 8
captions_sample = ['"A handwritten digit 1']
text_input = tokenizer(captions_sample, padding="max_length", max_length=TOKENIZER_MAX_LENGTH, truncation=True, return_tensors="pt")
# Move input_ids tensor to the correct device
text_input['input_ids'] = text_input['input_ids'].to(accelerator.device)
text_embeddings = text_encoder(text_input.input_ids)[0]
print(f"Text embedding shape: {text_embeddings.shape}")  # Print shape for verification

Text embedding shape: torch.Size([1, 8, 512])


In [9]:
num_of_params = sum([p.numel() for p in unet.parameters()])
print("Number of trainable parameters in the model: " + str(f"{num_of_params:,}"))

Number of trainable parameters in the model: 3,141,539


In [10]:
path_to_dataset = "../../datasets"

# CIFAR-10 classes for captions
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Load CIFAR-10 dataset
cifar10_dataset = torchvision.datasets.CIFAR10(
    root=path_to_dataset, 
    download=True, 
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
    ])
)

noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")
train_dataloader = DataLoader(cifar10_dataset, batch_size=256, shuffle=True)
next(iter(train_dataloader))[0].shape

torch.Size([256, 3, 32, 32])

In [None]:
# Training parameters
num_epochs = 1
learning_rate = 1e-4  # Reduced learning rate for stability
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)

# Prepare for distributed training
unet, optimizer, train_dataloader, text_encoder = accelerator.prepare(
    unet, optimizer, train_dataloader, text_encoder
)

# Training loop
progress_bar = tqdm(range(num_epochs * len(train_dataloader)), desc="Training")

for epoch in range(num_epochs):
    unet.train()
    for batch in train_dataloader:
        with accelerator.accumulate(unet):
            images, labels = batch
            
            # Create text captions from labels using CIFAR-10 class names
            captions = [f"An image of a {cifar10_classes[label]}" for label in labels]

            # Use images directly as latents since we're not using VAE
            latents = images.to(device=accelerator.device, dtype=torch.float16)

            # Encode text
            text_input = tokenizer(captions, padding="max_length", max_length=TOKENIZER_MAX_LENGTH, truncation=True, return_tensors="pt")
            text_input = text_input.to(accelerator.device)
            text_embeddings = text_encoder(text_input.input_ids)[0].to(dtype=torch.float16)

            # Add noise to latents
            noise = torch.randn_like(latents, dtype=torch.float16)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict noise residual
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample

            # Calculate loss (use float32 for better numerical stability in loss computation)
            loss = F.mse_loss(noise_pred.float(), noise.float())

            # Backpropagate and optimize
            accelerator.backward(loss)
            
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)
            
            optimizer.step()
            optimizer.zero_grad()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss.detach().item())

    if accelerator.is_main_process:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

Training:   0%|          | 0/196 [00:00<?, ?it/s]

ValueError: Attempting to unscale FP16 gradients.

In [None]:
num_epochs=5
steps = 0

# Training loop with accelerator
progress_bar = tqdm(range(num_epochs * len(train_dataloader)), desc="Training")

for epoch in range(num_epochs):
    unet.train()
    for batch in train_dataloader:
        with accelerator.accumulate(unet):
            images, labels = batch
            
            # Create text captions from labels
            captions = [f"A handwritten digit {label}" for label in labels]

            # Use images directly as latents
            latents = images.to(dtype=torch.float16)

            # Encode text
            text_input = tokenizer(captions, padding="max_length", max_length=TOKENIZER_MAX_LENGTH, truncation=True, return_tensors="pt")
            text_input = text_input.to(accelerator.device)
            text_embeddings = text_encoder(text_input.input_ids)[0]

            # Add noise to latents
            noise = torch.randn_like(latents, dtype=torch.float16)
            timesteps = torch.randint(0, 1000, (latents.shape[0],), device=accelerator.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict noise
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample

            # Calculate loss
            loss = F.mse_loss(noise_pred.float(), noise.float())  # Convert to float32 for loss calculation

            # Backpropagate and optimize
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            steps += 1
            if steps % 25 == 0 and accelerator.is_main_process:
                print(f"steps {steps}, Loss: {loss.item()}")

        progress_bar.update(1)
        if accelerator.is_main_process:
            progress_bar.set_postfix(loss=loss.detach().item())

    if accelerator.is_main_process:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

In [None]:
# Set models to evaluation mode
text_encoder.eval()
unet.eval()

# Inference parameters
num_inference_steps = 50

@torch.no_grad()
def generate_image(prompt):
    if accelerator.is_main_process:
        # Tokenize and encode the text prompt
        text_input = tokenizer(prompt, padding="max_length", max_length=TOKENIZER_MAX_LENGTH, truncation=True, return_tensors="pt")
        text_input = text_input.to(accelerator.device)
        text_embeddings = text_encoder(text_input.input_ids)[0].to(dtype=torch.float16)

        # Initialize latents / create noisy image
        torch.manual_seed(42)
        latents = torch.randn((1, unet.config.in_channels, 32, 32), device=accelerator.device, dtype=torch.float16)

        # Denoise scheduler
        scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", num_train_timesteps=1000)
        scheduler.set_timesteps(num_inference_steps)

        # Denoising loop
        for t in tqdm(scheduler.timesteps):
            # Prepare latent model input
            latent_model_input = scheduler.scale_model_input(latents, t)
            latent_model_input = latent_model_input.to(dtype=torch.float16)

            # Predict noise residual
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # Compute previous noisy sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample
            latents = latents.to(dtype=torch.float16)  # Ensure latents stay in float16

        # Post-process image
        image = latents
        image = (image / 2 + 0.5).clamp(0, 1)
        
        return image
    return None

# Generate an image
if accelerator.is_main_process:
    prompt = "An image of a cat"
    generated_image = generate_image(prompt)

    # Display or save the image
    if generated_image is not None:
        image = generated_image[0]
        image = image.permute(1, 2, 0)  # Change from CxHxW to HxWxC for display
        image = image.cpu().numpy()
        image = (image * 255).round().astype("uint8")
        
        plt.figure(figsize=(6, 6))
        plt.imshow(image)
        plt.axis('off')
        plt.show()

In [None]:
## save unet and text_encoder
# Save the trained model
# torch.save(unet.state_dict(), "./unet.pth")


In [None]:
# Set models to evaluation mode
text_encoder.eval()
unet.eval()

# Inference parameters
num_inference_steps = 50
guidance_scale = 8

@torch.no_grad()
def generate_image(prompt, num_images=1):
    if accelerator.is_main_process:
        # Tokenize and encode the text prompt
        text_input = tokenizer(prompt, padding="max_length", max_length=TOKENIZER_MAX_LENGTH, truncation=True, return_tensors="pt")
        text_input = text_input.to(accelerator.device)
        text_embeddings = text_encoder(text_input.input_ids)[0].to(dtype=torch.float16)

        # Prepare unconditioned embeddings for classifier free guidance
        uncond_input = tokenizer([""] * num_images, padding="max_length", max_length=TOKENIZER_MAX_LENGTH, return_tensors="pt")
        uncond_input = uncond_input.to(accelerator.device)
        uncond_embeddings = text_encoder(uncond_input.input_ids)[0].to(dtype=torch.float16)

        # Concatenate text embeddings with unconditional embeddings
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Initialize latents
        latents = torch.randn((num_images, unet.config.in_channels, 32, 32), device=accelerator.device, dtype=torch.float16)

        # Denoise scheduler
        scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", num_train_timesteps=1000)
        scheduler.set_timesteps(num_inference_steps)

        for t in tqdm(scheduler.timesteps):
            # Expand latents for classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            latent_model_input = latent_model_input.to(dtype=torch.float16)
            
            # Predict noise residual
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # Perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute previous noisy sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample
            latents = latents.to(dtype=torch.float16)  # Ensure latents stay in float16

        # Post-process image
        image = latents
        image = (image / 2 + 0.5).clamp(0, 1)
        return image
    return None

# Generate an image
if accelerator.is_main_process:
    prompt = "An image of a cat"
    generated_images = generate_image(prompt)

    # Display or save the image
    if generated_images is not None:
        image = generated_images[0]
        image = image.permute(1, 2, 0)  # Change from CxHxW to HxWxC for display
        image = image.cpu().numpy()
        image = (image * 255).round().astype("uint8")
        
        plt.figure(figsize=(6, 6))
        plt.imshow(image)
        plt.axis('off')
        plt.show()

In [None]:
# Define the guidance scale values
guidance_scales = [0, 5, 10, 20, 50, 100]

# Initialize a figure with a grid of subplots
fig, axs = plt.subplots(nrows=1, ncols=len(guidance_scales), figsize=(20, 4))

# Generate an image for each guidance scale value
for i, guidance_scale in enumerate(guidance_scales):
    # Set the guidance scale
    # global guidance_scale
    # guidance_scale = guidance_scale

    # Generate an image
    prompt = "A handwritten digit 0"
    generated_image = generate_image(prompt)

    # Display or save the image
    image = generated_image[0]  # Remove the extra dimension
    image = image.squeeze(0)
    image = image.detach().cpu().numpy()  # Remove batch dimension
    image = (image * 255).round().astype("uint8")

    # Display the image in the current subplot
    axs[i].imshow(image, cmap='gray')  # Display as grayscale image
    axs[i].set_title(f'Guidance Scale: {guidance_scale}')
    axs[i].axis('off')

# Layout so plots do not overlap
fig.tight_layout()

# Display the plot
plt.show()