In [10]:
import numpy as np
import matplotlib.pyplot as plt
import os

import torch
import torch.utils
import torch.distributions
import torchvision

from diffusers.models import AutoencoderKL

In [None]:
# Root folder that is searched recursively for the source .jpg images.
IMAGE_FOLDER = 'data/DATASET'
# Root folder that receives the encoded latents; each output path is the image
# path with IMAGE_FOLDER swapped for this folder.
LATENT_IMAGE_FOLDER = 'data/LATENT_DATASET'
# Size (in pixels) passed to resize_square() after the center crop.
TARGET_SIZE = 256

In [11]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def init_vae():
    """Load the pretrained Stable Diffusion VAE, frozen and in inference mode.

    Returns:
        The ``stabilityai/sd-vae-ft-mse`` ``AutoencoderKL`` moved to ``device``,
        switched to eval mode, with ``requires_grad`` disabled on every parameter.
    """
    # https://huggingface.co/stabilityai/sd-vae-ft-mse
    model: AutoencoderKL = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse').to(device)
    # model = torch.compile(model) # TODO we should do this on linux for optimization

    # eval() already switches every submodule to inference mode. Never assign to
    # `model.train`: that shadows nn.Module.train() with a plain bool, after which
    # any `model.train()` / `model.eval()` call raises TypeError.
    model.eval()
    # Freeze all weights in one call instead of looping over the parameters.
    model.requires_grad_(False)
    return model

# Module-level VAE instance used by encode() and decode().
vae = init_vae()
# Latent scaling constant; the value follows DiT and Stable Diffusion.
scale_factor = 0.18215

@torch.no_grad()
def encode(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode a batch of images with the VAE and return the posterior parameters.

    Args:
        x: Image batch handed straight to ``vae.encode``
           (the caller in this file feeds values scaled to [-1, 1]).

    Returns:
        The two halves of the posterior's ``parameters`` tensor, split along
        dim 1; callers unpack them as ``(mean, logvar)``.
    """
    latent_dist = vae.encode(x, return_dict=False)[0]
    # `parameters` stacks both halves along the channel axis; split it in two.
    return latent_dist.parameters.chunk(2, dim=1)

@torch.no_grad()
def sample(mean: torch.FloatTensor, logvar: torch.FloatTensor) -> torch.FloatTensor:
    """Draw one latent from N(mean, exp(logvar)) and apply ``scale_factor``.

    Args:
        mean: Posterior mean.
        logvar: Posterior log-variance, same shape as ``mean``.

    Returns:
        ``(mean + eps * std) * scale_factor`` with ``eps`` standard normal noise.
    """
    noise = torch.randn_like(mean)
    std = torch.exp(0.5 * logvar)
    return (mean + noise * std) * scale_factor

@torch.no_grad()
def decode(z) -> torch.Tensor:
    """Decode scaled latents back into uint8 images.

    Args:
        z: Latents that were multiplied by ``scale_factor`` (as ``sample`` returns).

    Returns:
        The VAE reconstruction mapped from [-1, 1] to [0, 255], clamped and
        cast to ``torch.uint8``.
    """
    # Undo the latent scaling before handing the tensor to the decoder.
    reconstruction = vae.decode(z / scale_factor, return_dict=False)[0]
    pixels = (reconstruction + 1.0) * 127.5
    return pixels.clamp(0, 255).to(torch.uint8)

In [12]:

def list_files_recursively(data_dir, extensions=('.jpg',)):
    """Collect image files under ``data_dir``, descending into subdirectories.

    Args:
        data_dir: Directory to walk.
        extensions: File-name suffix (or tuple of suffixes) to accept. The
            comparison is case-insensitive, so ``.JPG`` matches ``.jpg``.

    Returns:
        List of paths (each prefixed with ``data_dir``) in depth-first order,
        with the entries of every directory visited in sorted order.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    results = []
    for entry in sorted(os.listdir(data_dir)):
        full_path = os.path.join(data_dir, entry)
        # Test for directories first so a folder whose name happens to end in
        # an image suffix is walked instead of being returned as an image.
        if os.path.isdir(full_path):
            results.extend(list_files_recursively(full_path, extensions))
        elif entry.lower().endswith(suffixes):
            results.append(full_path)
    return results

def center_crop_square(image: torch.Tensor):
    """Crop the largest centered square out of an image tensor.

    Args:
        image: Tensor whose last two dimensions are (height, width), e.g. the
            (C, H, W) layout produced by ``torchvision.io.read_image``. Any
            number of leading dimensions (none, channels, batch) is accepted.

    Returns:
        A slice of ``image`` whose last two dimensions both equal
        ``min(height, width)``; leading dimensions are untouched.
    """
    # The trailing axes are (H, W): axis -2 is height, axis -1 is width.
    height, width = image.shape[-2:]
    side = min(height, width)

    # Floor division puts the extra pixel of an odd margin on the far side.
    top = (height - side) // 2
    left = (width - side) // 2

    return image[..., top:top + side, left:left + side]

def resize_square(image: torch.Tensor, size: int):
    """Resize an image tensor with bicubic interpolation.

    Args:
        image: Image tensor with trailing (height, width) dimensions.
        size: Target length in pixels. Because a single int is passed,
            torchvision matches the *shorter* edge to ``size`` and keeps the
            aspect ratio, so the result is ``size x size`` only for square
            input (crop with ``center_crop_square`` first).

    Returns:
        The resized tensor.
    """
    return torchvision.transforms.functional.resize(
        image,
        size,
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        # Request antialiasing explicitly: the default for tensor input differs
        # between torchvision releases, and downscaling without it aliases.
        antialias=True,
    )

image_files = list_files_recursively(IMAGE_FOLDER)

# For each image: load it, center-crop and resize it, encode it with the VAE,
# and save the posterior (mean, logvar) as a compressed .npz that mirrors the
# source tree under LATENT_IMAGE_FOLDER.

for image_file in image_files:
    # Force 3-channel RGB so grayscale or 4-channel files do not break the
    # VAE, which expects 3 input channels.
    image = torchvision.io.read_image(image_file, mode=torchvision.io.ImageReadMode.RGB)
    image = image.to(device).to(torch.float32)

    image = center_crop_square(image)
    image = resize_square(image, TARGET_SIZE)
    image = (image / 127.5) - 1  # map [0, 255] to [-1, 1]
    image = image.unsqueeze(0)  # add the batch dimension

    mean, logvar = encode(image)

    # Store as float16 to halve the size of the saved latents.
    mean = mean.cpu().numpy().astype(np.float16)
    logvar = logvar.cpu().numpy().astype(np.float16)

    # Re-root the path on its prefix only (str.replace would rewrite every
    # occurrence of the folder string), and create the destination directory
    # because savez_compressed does not. numpy appends ".npz" to the name.
    latent_path = os.path.join(LATENT_IMAGE_FOLDER, os.path.relpath(image_file, IMAGE_FOLDER))
    os.makedirs(os.path.dirname(latent_path), exist_ok=True)
    np.savez_compressed(latent_path, mean=mean, logvar=logvar)

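#   # Optional round-trip check (disabled): reload a saved latent, decode it,
#   # and write the reconstruction as a JPEG for visual inspection.
#   # Load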

#   loaded = np.load(image_file.replace(IMAGE_FOLDER, LATENT_IMAGE_FOLDER) + '.npz')

#   mean = loaded['mean']
#   logvar = loaded['logvar']

#   mean = torch.tensor(mean).to(device).to(torch.float32)
#   logvar = torch.tensor(logvar).to(device).to(torch.float32)

#   # Decode

#   z = sample(mean, logvar)
#   x = decode(z).cpu()
#   x = x.squeeze(0)

#   # Save

#   torchvision.io.write_jpeg(x, image_file.replace(IMAGE_FOLDER, LATENT_IMAGE_FOLDER), quality=100)


  
  




