In [11]:
import numpy as np
import os

import PIL

import torch
import torch.nn as nn
import torch.utils
import torch.distributions
import torchvision
import torch.nn.functional as F

from diffusers.models import AutoencoderKL

from latent_dataset import LatentImageDataset

from guassian_noise import GaussianDiffusion, get_named_beta_schedule


# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda


In [12]:
# Root directory of the precomputed latent dataset; passed to LatentImageDataset
# further down. Relative path, so it is resolved against the notebook's working
# directory.
DATASET_PATH = 'data/LATENT_DATASET/LATENT_DATASET'

def init_vae():
    """Load the pretrained Stable Diffusion VAE and freeze it for inference.

    The ``stabilityai/sd-vae-ft-mse`` weights are loaded with
    ``AutoencoderKL.from_pretrained``, moved to the module-level ``device``,
    switched to eval mode, and all parameters are excluded from autograd.

    Returns:
        AutoencoderKL: the frozen, eval-mode VAE.
    """
    # https://huggingface.co/stabilityai/sd-vae-ft-mse
    model: AutoencoderKL = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse').to(device)
    # eval() sets ``training = False`` on every submodule. Do NOT assign to
    # ``model.train``: it is a method, and overwriting it with a bool makes any
    # later ``model.train()`` call raise TypeError.
    model = model.eval()
    # Freeze every parameter in one call (equivalent to looping over parameters()).
    model.requires_grad_(False)
    return model

vae = init_vae()  # frozen, eval-mode VAE used by encode() and decode() below

# Multiplier applied to latents after sampling and divided out again before
# decoding. The value follows DiT and Stable Diffusion.
scale_factor = 0.18215

@torch.no_grad()
def encode(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode images with the VAE and return the posterior's (mean, logvar) pair.

    The encoder's latent distribution exposes a ``.parameters`` tensor; the split
    below assumes mean and log-variance are concatenated along dim 1, so halving
    that dimension recovers them.
    """
    latent_dist = vae.encode(x, return_dict=False)[0]
    halves = latent_dist.parameters.chunk(2, dim=1)
    return halves

@torch.no_grad()
def sample(mean: torch.FloatTensor, logvar: torch.FloatTensor) -> torch.FloatTensor:
    """Draw a reparameterized latent sample and apply ``scale_factor``.

    Args:
        mean: posterior mean; the noise is drawn with the same shape.
        logvar: posterior log-variance, broadcastable against ``mean``.

    Returns:
        ``(mean + exp(0.5 * logvar) * eps) * scale_factor`` with ``eps ~ N(0, I)``.
    """
    sigma = (0.5 * logvar).exp()
    eps = torch.randn_like(mean)
    return (mean + eps * sigma) * scale_factor

@torch.no_grad()
def decode(z) -> torch.Tensor:
    """Decode scaled latents back into uint8 images.

    Args:
        z: latents in the scaled space produced by ``sample`` (already multiplied
           by ``scale_factor``); the scale is divided out before decoding.

    Returns:
        uint8 tensor with values in [0, 255], in the same layout as the VAE
        output (presumably (batch, channels, height, width); verify against
        callers).
    """
    x = vae.decode(z / scale_factor, return_dict=False)[0]
    # Map the VAE's output range (assumed [-1, 1] by this affine map) to
    # [0, 255]. Round before casting: .to(torch.uint8) truncates toward zero,
    # which would bias every pixel down by up to one intensity level.
    x = ((x + 1.0) * 127.5).round().clamp(0, 255).to(torch.uint8)
    return x

In [13]:
dataset = LatentImageDataset(DATASET_PATH)

# One item per batch, reshuffled on every new iterator, so each
# `next(iter(dataloader))` yields a random latent from the dataset.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=1,
)

def save_image(img, name):
    """Write the first image of a batch to ``{name}.png``.

    Args:
        img: uint8 image batch, e.g. the output of ``decode``. Only ``img[0]`` is
            written; it is assumed to be channel-first (C, H, W).
        name: output path without the ``.png`` extension.
    """
    # ``import PIL`` alone is not guaranteed to load the ``PIL.Image`` submodule,
    # so import it explicitly rather than relying on another import having done so.
    from PIL import Image

    # (C, H, W) -> (H, W, C), the layout PIL expects. detach() keeps .numpy()
    # safe if the tensor is ever produced outside a no_grad context.
    first = img[0].detach().cpu().permute(1, 2, 0)
    Image.fromarray(first.numpy()).save(f"{name}.png")

In [14]:

# Forward-process noise schedule: the "linear" named beta schedule over 1000 steps.
timesteps = 1000
betas = get_named_beta_schedule("linear", timesteps)

# Only the forward process (q_sample) is used in the cells shown here, so the
# model/loss configuration options are left unset.
diffusion = GaussianDiffusion(
    betas=betas,
    loss_type=None,
    model_var_type=None,
    model_mean_type=None,
)

In [15]:
# Visualize the forward (noising) process q(x_t | x_0) for one random latent.

# Draw one (mean, logvar) pair from the shuffled dataloader.
mean, logvar = next(iter(dataloader))

mean = mean.to(device=device, dtype=torch.float32)
logvar = logvar.to(device=device, dtype=torch.float32)

# Sample a latent from the encoder posterior (already multiplied by scale_factor).
x_0 = sample(mean, logvar)

# Save the clean decoded image as sample_0.
save_image(decode(x_0), "sample_0")

# Noise the same x_0 at evenly spaced timesteps across the whole schedule.
# Files are named sample_{t + TIMESTEP_STRIDE}, so the noisy images never
# overwrite sample_0.
TIMESTEP_STRIDE = 50

for t in range(0, timesteps, TIMESTEP_STRIDE):
    # One timestep index per batch element (q_sample presumably indexes its
    # schedule per sample; the original passed a length-1 int64 tensor).
    tt = torch.full((x_0.shape[0],), t, device=device, dtype=torch.long)

    # Match x_0's actual shape, dtype and device instead of hard-coding
    # (1, 4, 32, 32), which only fits batch size 1 with 32x32 latents.
    noise = torch.randn_like(x_0)

    x_t = diffusion.q_sample(x_0, tt, noise)

    decoded = decode(x_t)
    save_image(decoded, f"sample_{t + TIMESTEP_STRIDE}")