In [None]:
import numpy as np
import os
from tqdm import tqdm

import torch
# Permit TF32 math for CUDA matmuls and for cuDNN operations.
# Set right after importing torch, before any model or tensor work happens.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms

from diffusers.models import AutoencoderKL

In [None]:
# Choose the compute device: the GPU at index `gpu_indx` when CUDA is
# available, otherwise fall back to the CPU.
gpu_indx = 0
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device(gpu_indx)
else:
    device = torch.device("cpu")

In [None]:
# --- Run configuration ---------------------------------------------------
# Side length (in pixels) that every image is resized/cropped to.
image_size = 256
assert image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."

# Number of images encoded per forward pass.
batch_size = 32

# Root directory handed to ImageFolder, and the directory the per-image
# latent .npy files are written to. Both default to the current directory as
# placeholders; point them at the real locations before running.
dataset_dir = "."
latent_save_dir = "."

# Create the output directory up front so saving cannot fail on a missing path.
os.makedirs(latent_save_dir, exist_ok=True)

In [None]:
# Load the pretrained "stabilityai/sd-vae-ft-ema" autoencoder, then move it
# onto the selected device.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
vae = vae.to(device)

In [None]:
# Data pipeline.
# Preprocessing: resize, centre-crop to image_size x image_size, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
preprocess_steps = [
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
transform = transforms.Compose(preprocess_steps)

# Images are read from `dataset_dir` through ImageFolder.
dataset = ImageFolder(dataset_dir, transform=transform)

# shuffle=False keeps the iteration order fixed from run to run.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

In [None]:
# Encode every image to its VAE latent and write one .npy file per image.
# Files are named by a running counter that follows the loader's iteration
# order (the DataLoader is built with shuffle=False).
img_index = 0

# Use mixed precision only when actually running on CUDA; on the CPU fallback
# the autocast context is disabled instead of emitting a warning.
amp_enabled = device.type == "cuda"

with torch.no_grad():
    # The class labels yielded by the loader are not needed for encoding,
    # so they are neither used nor copied to the device.
    for x, _ in tqdm(data_loader, leave=False):
        x = x.to(device)
        with torch.autocast(device_type=device.type, enabled=amp_enabled):
            # Map input images to latent space + normalize latents:
            latent_features = vae.encode(x).latent_dist.sample().mul_(0.18215)
        latent_features = latent_features.detach().cpu()  # (bs, 4, image_size//8, image_size//8)

        # Iterating over the batch dimension yields one latent per image.
        for latent in latent_features:
            np.save(os.path.join(latent_save_dir, f"{img_index}.npy"), latent.numpy())
            img_index += 1