In [1]:
from datasets import load_dataset
import datasets
from diffusers import AutoencoderKL
import torch
import torchvision.transforms as transforms

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Force offline mode so the dataset is served from the local cache.
# Comment this out if you haven't downloaded the dataset yet.
datasets.config.HF_HUB_OFFLINE = 1

In [4]:
# Load the three splits from the shared local cache directory, in the order
# train -> validation -> test.
train_ds, validation_ds, test_ds = (
    load_dataset(
        "tpremoli/CelebA-attrs",
        cache_dir="../../datasets/CelebA-attrs",
        split=split_name,
    )
    for split_name in ("train", "validation", "test")
)

Using the latest cached version of the dataset since tpremoli/CelebA-attrs couldn't be found on the Hugging Face Hub (offline mode is enabled).
Found the latest cached dataset configuration 'default' at ..\..\datasets\CelebA-attrs\tpremoli___celeb_a-attrs\default\0.0.0\ed9021d2871ceddbd3cf0fb642544bd7c60c5152 (last modified on Fri Oct  4 14:25:53 2024).
Using the latest cached version of the dataset since tpremoli/CelebA-attrs couldn't be found on the Hugging Face Hub (offline mode is enabled).
Found the latest cached dataset configuration 'default' at ..\..\datasets\CelebA-attrs\tpremoli___celeb_a-attrs\default\0.0.0\ed9021d2871ceddbd3cf0fb642544bd7c60c5152 (last modified on Fri Oct  4 14:25:53 2024).
Using the latest cached version of the dataset since tpremoli/CelebA-attrs couldn't be found on the Hugging Face Hub (offline mode is enabled).
Found the latest cached dataset configuration 'default' at ..\..\datasets\CelebA-attrs\tpremoli___celeb_a-attrs\default\0.0.0\ed9021d2871ceddbd3

In [5]:
# Check the raw size of the first training image before picking a resize target.
sample_image = train_ds[0]["image"]
print(sample_image.size)

(178, 218)


In [6]:
# Preprocessing applied to every image before it is handed to the VAE.
transform = transforms.Compose([
    # Resize takes (height, width): the output is 224 px tall and 176 px wide.
    transforms.Resize(size=(224, 176)),
    # Convert the image to a tensor.
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel: scales values to [-1, 1].
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [7]:
# Load the pretrained VAE (cached under ../../models/vae) and move it to the
# selected device in one step.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    cache_dir="../../models/vae",
).to(device)

In [8]:
@torch.no_grad()
def transform2(x):
    """Encode a single preprocessed image tensor into a sampled VAE latent.

    Runs under ``torch.no_grad()`` so no autograd graph is built and the
    returned tensor does not require grad.

    Args:
        x (torch.Tensor): One image tensor without a batch dimension
            (presumably the (3, H, W) output of ``transform`` - confirm).

    Returns:
        torch.Tensor: A sample drawn from the encoder's latent distribution,
        with the temporary batch dimension removed, on the CPU.
    """
    # The VAE expects a batch, so add a leading dimension of size 1.
    batch = x.to(device).unsqueeze(0)
    latent = vae.encode(batch).latent_dist.sample()
    return latent.squeeze(0).cpu()

In [9]:
import json
import os
from tqdm import tqdm

# Root output directory. process_ds writes "<SAVE_PATH>/<split>/latents/<index>.pt"
# and "<SAVE_PATH>/<split>/metadata.json" beneath it.
SAVE_PATH = "../../datasets/CelebA-attrs-latents"

# NOTE: dead code. This is an earlier one-sample-at-a-time version of
# process_ds; the batched definition below is the one that actually runs.
# It also shadowed the built-ins `dict` and `set`.
# def process_ds(ds, set):
#     dict = {}
#     for i in range(len(ds)):
#         sample = ds[i]
#         latents = transform2(transform(sample["image"])).to(torch.float16)
#         dict[i] = sample["prompt_string"]
#         # Pad filename to 8 digits
#         filename = f"{i:08d}.pt"
#         torch.save(latents, f"{SAVE_PATH}/{set}/latents/{filename}")

#     with open(f"{SAVE_PATH}/{set}/metadata.json", "w") as f:
#         json.dump(dict, f)    

def process_ds(ds, set_name, batch_size=32):
    """
    Encode every image of a dataset split with the VAE and write the results to disk.

    For each sample index ``i`` the sampled latent is saved as a float16 tensor
    to ``{SAVE_PATH}/{set_name}/latents/{i:08d}.pt`` and the sample's
    ``prompt_string`` is recorded under key ``i`` in
    ``{SAVE_PATH}/{set_name}/metadata.json`` (JSON turns the integer keys
    into strings).

    Args:
        ds (Dataset): Indexable dataset whose items provide "image" and "prompt_string".
        set_name (str): The name of the dataset split (e.g., 'train', 'validation', 'test').
        batch_size (int, optional): Number of samples encoded per VAE call.
            Must be positive. Defaults to 32.

    Raises:
        ValueError: If ``batch_size`` is not positive.

    Note:
        Latents are drawn with ``latent_dist.sample()`` and this function sets
        no seed, so repeated runs presumably produce different latent files.
    """
    # A non-positive step would either raise deep inside range() or silently
    # process nothing and write an empty metadata file.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    metadata = {}
    num_samples = len(ds)
    split_dir = f"{SAVE_PATH}/{set_name}"
    latents_dir = f"{split_dir}/latents"

    # Ensure the save directory exists
    os.makedirs(latents_dir, exist_ok=True)

    # Process the dataset in batches
    for start_idx in tqdm(range(0, num_samples, batch_size), desc=f"Processing {set_name}"):
        end_idx = min(start_idx + batch_size, num_samples)

        # Fetch each sample exactly once: indexing the dataset separately for
        # the image and the prompt would load every sample twice.
        samples = [ds[i] for i in range(start_idx, end_idx)]
        images = [transform(sample["image"]) for sample in samples]
        prompts = [sample["prompt_string"] for sample in samples]

        # Stack images into a batch tensor.
        # Shape: (batch, 3, 224, 176), given transform's Resize((224, 176)).
        batch_tensor = torch.stack(images).to(device)

        # Encode without tracking gradients, then sample from the posterior.
        with torch.no_grad():
            encoded = vae.encode(batch_tensor)
            # Shape: (batch, latent_channels, latent_h, latent_w) - presumably
            # (batch, 4, 28, 22) for this VAE; TODO confirm.
            latents = encoded.latent_dist.sample().cpu().half()

        # Save each latent vector and update metadata
        for offset, latent in enumerate(latents):
            idx = start_idx + offset
            metadata[idx] = prompts[offset]
            # clone() gives the slice its own storage; saving the view directly
            # would serialize the whole batch's storage into every file.
            torch.save(latent.clone(), f"{latents_dir}/{idx:08d}.pt")

    # Save metadata to JSON
    metadata_path = f"{split_dir}/metadata.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=4)

# Encode every split in turn: train, then validation, then test.
for split_name, split_ds in (
    ("train", train_ds),
    ("validation", validation_ds),
    ("test", test_ds),
):
    process_ds(split_ds, split_name)

  hidden_states = F.scaled_dot_product_attention(
Processing train: 100%|██████████| 5087/5087 [25:31<00:00,  3.32it/s]
Processing validation: 100%|██████████| 624/624 [03:04<00:00,  3.37it/s]
Processing test: 100%|██████████| 621/621 [03:03<00:00,  3.38it/s]
