In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
import numpy as np
from tqdm import tqdm
from PIL import Image
import pandas as pd

In [3]:
import logging

# Set up logging for debugging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class CelebADataset(Dataset):
    """
    PyTorch Dataset for the CelebA dataset using .csv files.
    Supports image-only access or conditioning via attributes, landmarks, and bounding boxes.
    
    Args:
        root_dir (str): Path to the directory containing the dataset CSVs and image folder.
        split (str): One of {'train', 'valid', 'test', 'all'}.
        target_types (list): Subset of ['attr', 'landmarks', 'bbox'].
        transform (callable): Optional transform to apply to images.
        target_transform (callable): Optional transform to apply to targets.
        return_dict (bool): If True, returns targets as a dictionary.
    """
    def __init__(
        self,
        root_dir,
        split='all',
        target_types=None,
        transform=None,
        target_transform=None,
        return_dict=False
    ):
        self.root_dir = root_dir
        self.img_folder = os.path.join(root_dir, "img_align_celeba")
        self.split = split
        self.target_types = target_types if target_types else []
        self.transform = transform
        self.target_transform = target_transform
        self.return_dict = return_dict

        if not os.path.isdir(self.img_folder):
            raise FileNotFoundError(f"Image directory not found: {self.img_folder}")

        # Load image split
        split_path = os.path.join(root_dir, "list_eval_partition.csv")
        partition_df = pd.read_csv(split_path)
        if split != 'all':
            split_map = {'train': 0, 'valid': 1, 'test': 2}
            partition_df = partition_df[partition_df['partition'] == split_map[split]]
        self.image_files = partition_df['image_id'].tolist()

        # Load all metadata into dictionaries
        self.target_data = {}
        for target_type in self.target_types:
            df = self._load_csv(target_type)
            # Create a subset of the dataframe containing only images in our split
            df_subset = df[df['image_id'].isin(self.image_files)]
            self.target_data[target_type] = {}
            
            # Process each row individually to avoid type conversion issues
            for _, row in df_subset.iterrows():
                img_id = row['image_id']
                try:
                    self.target_data[target_type][img_id] = self._convert_to_tensor(row, target_type)
                except Exception as e:
                    print(f"Error processing {img_id} for {target_type}: {e}")
                    # Skip problematic entries rather than failing completely

    def _convert_to_tensor(self, row, target_type):
        """Convert row data to appropriate tensor based on target type"""
        if target_type == 'attr':
            # Attributes are -1/1 in the csv, convert to 0/1
            # Access values as a list rather than numpy array to handle mixed types
            values = [int(row[col]) for col in row.index[1:]]  # Skip image_id
            values = [(v + 1) // 2 for v in values]  # Convert -1/1 to 0/1
            return torch.tensor(values, dtype=torch.float32)
        elif target_type == 'landmarks':
            # For landmarks, ensure we're getting numeric values
            values = [float(row[col]) for col in row.index[1:]]
            return torch.tensor(values, dtype=torch.float32)
        else:  # 'bbox'
            # For bbox, convert to float values
            values = [float(row[col]) for col in row.index[1:]]
            return torch.tensor(values, dtype=torch.float32)

    def _load_csv(self, target_type):
        """Load the appropriate CSV file based on target type"""
        if target_type == 'attr':
            path = os.path.join(self.root_dir, 'list_attr_celeba.csv')
            # Convert string values to integers when loading
            df = pd.read_csv(path)
            # Convert all non-image_id columns to numeric
            for col in df.columns[1:]:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            return df
        elif target_type == 'landmarks':
            path = os.path.join(self.root_dir, 'list_landmarks_align_celeba.csv')
            df = pd.read_csv(path)
            # Convert all non-image_id columns to numeric
            for col in df.columns[1:]:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            return df
        elif target_type == 'bbox':
            path = os.path.join(self.root_dir, 'list_bbox_celeba.csv')
            df = pd.read_csv(path)
            # Convert all non-image_id columns to numeric
            for col in df.columns[1:]:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            return df
        else:
            raise ValueError(f"Unsupported target type: {target_type}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_id = self.image_files[idx]
        img_path = os.path.join(self.img_folder, img_id)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        if not self.target_types:
            return image

        targets = []
        for t in self.target_types:
            # Get target data for this image, empty tensor if not found
            value = self.target_data.get(t, {}).get(img_id, torch.tensor([]))
            targets.append(value)

        if self.target_transform:
            targets = self.target_transform(targets)

        if self.return_dict:
            return image, {k: v for k, v in zip(self.target_types, targets)}
        elif len(targets) == 1:
            return image, targets[0]
        return image, tuple(targets)


def get_celeba_dataloader(
    root_dir='.',
    split='train',
    batch_size=32,
    image_size=64,
    target_types=None,
    transform=None,
    target_transform=None,
    num_workers=4,
    shuffle=True,
    return_dict=False
):
    """
    Create a DataLoader for the CelebA dataset.
    
    Args:
        root_dir (str): Path to the dataset root
        split (str): One of {'train', 'valid', 'test', 'all'}
        batch_size (int): Batch size
        image_size (int): Size to resize images to (square)
        target_types (list): Subset of ['attr', 'landmarks', 'bbox']
        transform (callable): Transform to apply to images
        target_transform (callable): Transform to apply to targets
        num_workers (int): Number of worker processes for data loading
        shuffle (bool): Whether to shuffle the data
        return_dict (bool): If True, targets are returned as dictionary
        
    Returns:
        DataLoader: PyTorch DataLoader for the CelebA dataset
    """
    # Standard transformation pipeline for images
    if split == 'train' and transform is None:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)), # Resize to square
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to [-1, 1]
        ])
    
    elif transform is None:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),  
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
        ])    

    dataset = CelebADataset(
        root_dir=root_dir,
        split=split,
        target_types=target_types,
        transform=transform,
        return_dict=return_dict
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True
    )

In [4]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_dim=None):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        
        self.act = nn.SiLU()
        
        if time_embed_dim is not None:
            self.time_mlp = nn.Linear(time_embed_dim, out_channels)
        
    def forward(self, x, time_emb=None):
        h = self.act(self.norm1(self.conv1(x)))
        
        if self.time_embed_dim is not None and time_emb is not None:
            time_emb = self.act(self.time_mlp(time_emb))
            h = h + time_emb.unsqueeze(-1).unsqueeze(-1)
            
        h = self.act(self.norm2(self.conv2(h)))
        return h

In [5]:
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In [6]:
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, time_dim=128, features=[64, 128, 256, 512]):
        super().__init__()
        self.time_dim = time_dim
        
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )
        
        # Initial convolution
        self.init_conv = nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1)
        
        # Encoder pathway
        self.downs = nn.ModuleList()
        for i in range(len(features) - 1):
            self.downs.append(nn.ModuleList([
                Block(features[i], features[i], time_dim),
                Block(features[i], features[i+1], time_dim),
                nn.MaxPool2d(kernel_size=2)
            ]))
            
        # Middle blocks
        self.middle = nn.ModuleList([
            Block(features[-1], features[-1], time_dim),
            Block(features[-1], features[-1], time_dim)
        ])
        
        # Decoder pathway
        self.ups = nn.ModuleList()
        for i in reversed(range(len(features) - 1)):
            self.ups.append(nn.ModuleList([
                Block(features[i+1], features[i+1], time_dim),
                Block(features[i+1], features[i], time_dim),
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
            ]))
            
        # Final convolution
        self.final_conv = nn.Sequential(
            Block(features[0], features[0], time_dim),
            nn.Conv2d(features[0], out_channels, kernel_size=1)
        )
        
    def forward(self, x, t):
        # Time embedding
        t = self.time_mlp(t)
        
        # Initial convolution
        x = self.init_conv(x)
        
        # Store residual connections
        residuals = []
        
        # Encoder
        for down_block1, down_block2, downsample in self.downs:
            x = down_block1(x, t)
            x = down_block2(x, t)
            residuals.append(x)
            x = downsample(x)
            
        # Middle
        x = self.middle[0](x, t)
        x = self.middle[1](x, t)
        
        # Decoder
        for up_block1, up_block2, upsample in self.ups:
            x = up_block1(x, t)
            x = up_block2(x, t)
            x = upsample(x)
            residual = residuals.pop()
            
            # Handle potential size mismatch
            if x.shape != residual.shape:
                x = F.interpolate(x, size=residual.shape[2:], mode="bilinear", align_corners=False)
                
            x = x + residual  # Skip connection
            
        # Final convolution
        return self.final_conv(x)

In [7]:
def extract(a, t, x_shape):
    """Extract coefficients at specified timesteps t and reshape to match x_shape"""
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

In [8]:
class DiffusionModel:
    def __init__(self, device, num_timesteps=500, beta_start=1e-4, beta_end=0.02):
        self.device = device
        self.num_timesteps = num_timesteps
        
        # Define beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
        # Define alphas
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Pre-calculate diffusion parameters
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1. / self.alphas)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion process"""
        if noise is None:
            noise = torch.randn_like(x_start)
            
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, denoise_model, x_start, t, noise=None):
        """Calculate training loss for the denoising model"""
        if noise is None:
            noise = torch.randn_like(x_start)
            
        # Add noise to the input image according to timestep t
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        
        # Predict the noise using the model
        predicted_noise = denoise_model(x_noisy, t)
        
        # Calculate loss
        loss = F.mse_loss(predicted_noise, noise)
        return loss

    @torch.no_grad()
    def p_sample(self, model, x, t, t_index):
        """Sample from the model at timestep t"""
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
        
        # Equation 11 in the paper
        # Use model to predict the mean
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
        )
        
        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            # Algorithm 2 in the paper
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, model, shape):
        """Generate samples from the model using the sampling loop"""
        device = next(model.parameters()).device
        b = shape[0]
        
        # Start from pure noise
        img = torch.randn(shape, device=device)
        
        # Progressively denoise the image
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling timesteps', total=self.num_timesteps):
            # Create a batch of the same timestep
            t = torch.full((b,), i, device=device, dtype=torch.long)
            # Sample from p(x_{t-1} | x_t)
            img = self.p_sample(model, img, t, i)
            
        # Samples are in [-1, 1] range
        return img

    @torch.no_grad()
    def sample(self, model, batch_size=16, channels=3, img_size=64):
        """Simple interface for sampling from the model"""
        return self.p_sample_loop(model, shape=(batch_size, channels, img_size, img_size))

In [None]:
def train_diffusion(diffusion, model, dataloader, optimizer, num_epochs=10, device="cpu"):
    """Train the diffusion model"""
    os.makedirs("diffusion_outputs", exist_ok=True)
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
        progress_bar.set_description(f"Epoch {epoch+1}/{num_epochs}")
        
        for i, batch in progress_bar:
            # print(len(batch))
            # print(type(batch))
            # print(batch.shape)
            images = batch
            
            
            images = images.to(device)
            optimizer.zero_grad()
            
            # Sample random timesteps
            t = torch.randint(0, diffusion.num_timesteps, (images.shape[0],), device=device).long()
            
            # Calculate loss
            loss = diffusion.p_losses(model, images, t)
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            progress_bar.set_postfix(loss=running_loss/(i+1))
            
        # Save checkpoint
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        torch.save(checkpoint, f"diffusion_outputs/checkpoint_epoch_{epoch+1}.pt")
            
        # Generate and save samples
        if (epoch + 1) % 2 == 0 or epoch == num_epochs - 1:
            model.eval()
            sample_images = diffusion.sample(model, batch_size=4, img_size=64)
            # Rescale from [-1, 1] to [0, 1]
            sample_images = (sample_images + 1) / 2
            save_image(sample_images, f"diffusion_outputs/samples_epoch_{epoch+1}.png", nrow=2)
            
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")
    
    # Save the final model
    torch.save(model.state_dict(), "diffusion_outputs/final_model.pt")

In [None]:
test_dataloader = get_celeba_dataloader(
    root_dir='./celeba',
    split='train',
    batch_size=32,
    image_size=128,
    target_types=['attr'],
    return_dict=True
)

for images, targets in test_dataloader:
        print(f"Image batch shape: {images.shape}")
        if isinstance(targets, dict):
            for target_type, target_tensor in targets.items():
                print(f"{target_type} shape: {target_tensor.shape}")
        else:
            print(f"Target shape: {targets.shape}")
        break

# Show the first 5 images in the batch 'images'
from matplotlib import pyplot as plt
imgs = images[:5]

for img in imgs:
    img = img.permute(1, 2, 0)  # Change from (C, H, W) to (H, W, C)
    img = (img + 1) / 2  # Rescale to [0, 1]
    img = img.numpy()
    # img = np.clip(img, 0, 1)  # Ensure values are in [0, 1]
    plt.imshow(img)
    plt.axis('off')
    plt.show()  # Move plt.show() outside the loop to display all images at once

Image batch shape: torch.Size([32, 3, 128, 128])
attr shape: torch.Size([32, 40])


In [10]:
def generate_images(model, diffusion, output_dir, total_images=10000, batch_size=16, device="cpu"):
    """Generate a large number of images efficiently"""
    os.makedirs(output_dir, exist_ok=True)
    
    model.eval()
    
    num_batches = (total_images + batch_size - 1) // batch_size
    image_count = 0
    
    for batch_idx in tqdm(range(num_batches), desc="Generating batches"):
        # For the last batch, adjust batch size if needed
        current_batch_size = min(batch_size, total_images - image_count)
        
        # Generate a batch of images
        with torch.no_grad():
            batch_images = diffusion.sample(model, batch_size=current_batch_size, img_size=64)
            # Rescale from [-1, 1] to [0, 1]
            batch_images = (batch_images + 1) / 2
            
            # Save as a grid for inspection
            if batch_idx % 10 == 0:
                save_image(batch_images, os.path.join(output_dir, f"grid_{batch_idx}.png"), nrow=4)
            
            # Save individual images
            for i in range(current_batch_size):
                filename = os.path.join(output_dir, f"img_{image_count:05d}.png")
                save_image(batch_images[i], filename)
                image_count += 1
                
                if image_count >= total_images:
                    break

In [11]:
import logging

def main():
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Create output directories
    os.makedirs("diffusion_outputs", exist_ok=True)
    os.makedirs("diffusion_outputs/generated", exist_ok=True)
    
    # Dataset preparation
    transform = transforms.Compose([
        transforms.CenterCrop(160),
        transforms.Resize(64),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
    ])
    
    # Load dataset
    # try:
    #     image_dir = "celebA/celeba/img_align_celeba"
    #     dataset = datasets.ImageFolder(root=image_dir, transform=transform)
    #     dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    #     print(f"Loaded dataset with {len(dataset)} images")
    # except Exception as e:
    #     # print(f"Error loading dataset: {e}")
    #     logging.error(f"Error loading dataset: {e}", exc_info=True, stack_info=True)
    #     print("Make sure the CelebA dataset is available at the specified path")
    #     return
    
    dataloader = get_celeba_dataloader(
        root_dir='./celeba',
        split='train',
        batch_size=32,
        image_size=64,
        return_dict=True
    )
    
    # Initialize model
    model = SimpleUNet(in_channels=3, out_channels=3, time_dim=128).to(device)
    
    # Initialize diffusion process (using fewer timesteps for a more practical demonstration)
    diffusion = DiffusionModel(device=device, num_timesteps=500)
    
    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    # Train model
    print("Starting training...")
    train_diffusion(diffusion, model, dataloader, optimizer, num_epochs=10, device=device)
    
    # Generate images
    print("Generating 10,000 images...")
    generate_images(model, diffusion, "diffusion_outputs/generated", total_images=10000, batch_size=16, device=device)
    
    print("Done! 10,000 images saved in: diffusion_outputs/generated")

In [14]:
main()

Using device: cuda
Starting training...


Epoch 1/10:   0%|          | 0/5087 [00:00<?, ?it/s]

32
<class 'torch.Tensor'>
torch.Size([32, 3, 64, 64])





ValueError: too many values to unpack (expected 2)

In [15]:
import os

os.getcwd()

'/home/sawyer/01_classwork/genVision-celebA'

In [16]:
from pathlib import Path

test_img_dir = Path("celebA/celeba/img_align_celeba")
test_img_dir.exists()

True