In [2]:
import torch
import yaml
import numpy as np
import imageio.v2 as imageio
import os
from tqdm import tqdm

import matplotlib.pyplot as plt
plt.rcParams['text.usetex'] = False

from utils.dataset_helper import create_dataloaders
from utils.image_utils import render
from utils.diffusion_data_helper import denormalize_data
from utils.gaussian_file_helper import load_gaussians
from vae_model import GaussianVAE

device = "cuda" if torch.cuda.is_available() else "cpu"
config_path = "config/vae_training.yaml"
checkpoint_path = "best_gaussian_vae.pth"

with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

In [3]:
model = GaussianVAE(
            num_gaussians=cfg["model"]["num_gaussians"],
            input_dim=cfg["model"]["input_dim"],
            latent_dim=cfg["model"]["model_dim"],
            decoder_layers=cfg["model"].get("decoder_transformer_layers", 6),
            decoder_heads=cfg["model"].get("decoder_transformer_heads", 8)
        ).to(device)

data_loader, _ = create_dataloaders(
    "./data/FFHQ",
    batch_size=1, 
    shuffle=True, 
    augment=False,
    is_distributed=False
)

model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

Found 22084 files in ./data/FFHQ


GaussianVAE(
  (sa1): PointNetSetAbstractionMsg(
    (conv_blocks): ModuleList(
      (0): ModuleList(
        (0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): ModuleList(
        (0): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (2): ModuleList(
        (0): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (bn_blocks): ModuleList(
      (0): ModuleList(
        (0-1): 2 x BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
   

In [4]:
# Video parameters
fps = 30
duration = 5  # seconds
num_frames = fps * duration

# Output path
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
video_path = os.path.join(output_dir, "latent_interpolation.mp4")

print(f"Generating {num_frames} frames for {duration}s video at {fps} fps...")

# Get two random samples from the dataset
data_iter = iter(data_loader)

Generating 150 frames for 5s video at 30 fps...


In [None]:
data1 = next(data_iter).to(device)
data2 = next(data_iter).to(device)

# Encode to get latent means
with torch.no_grad():
    mu1, logvar1 = model.encode(data1)
    mu2, logvar2 = model.encode(data2)

# Create video writer (requires imageio-ffmpeg backend)
with imageio.get_writer(video_path, fps=fps, format="FFMPEG", codec="libx264", quality=8) as writer:
    for i in tqdm(range(num_frames)):
        # Calculate alpha (0 to 1)
        alpha = i / (num_frames - 1)
        
        # Interpolate in latent space
        latent_interp = (1 - alpha) * mu1 + alpha * mu2
        
        # Decode and render
        with torch.no_grad():
            decoded = model.decode(latent_interp)
        
        xy, scale, rot, feat = denormalize_data(decoded[:, :, 0:2], decoded[:, :, 2:4], 
                                                decoded[:, :, 4:5], decoded[:, :, 5:8])
        
        xy = xy.squeeze(0).contiguous().float()
        scale = scale.squeeze(0).contiguous().float()
        rot = rot.squeeze(0).contiguous().float()
        feat = feat.squeeze(0).contiguous().float()
        
        img_size = (int(480), int(640))
        image = render(xy, scale, rot, feat, img_size=img_size)
        image_np = image.cpu().detach().permute(1, 2, 0).numpy()
        
        # Convert to uint8 (0-255 range)
        image_uint8 = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)
        writer.append_data(image_uint8)

print(f"Video saved to {video_path}")

In [None]:
# Let's also save an interpolation as a series of images for reference, using the 2 extremes 3 frames in between
image_output_dir = os.path.join(output_dir, "interpolation_frames")
os.makedirs(image_output_dir, exist_ok=True)
num_image_frames = 5
for i in range(num_image_frames):
    alpha = i / (num_image_frames - 1)
    
    latent_interp = (1 - alpha) * mu1 + alpha * mu2
    
    with torch.no_grad():
        decoded = model.decode(latent_interp)
    
    xy, scale, rot, feat = denormalize_data(decoded[:, :, 0:2], decoded[:, :, 2:4], 
                                            decoded[:, :, 4:5], decoded[:, :, 5:8])
    
    xy = xy.squeeze(0).contiguous().float()
    scale = scale.squeeze(0).contiguous().float()
    rot = rot.squeeze(0).contiguous().float()
    feat = feat.squeeze(0).contiguous().float()
    
    img_size = (int(480), int(640))
    image = render(xy, scale, rot, feat, img_size=img_size)
    image_np = image.cpu().detach().permute(1, 2, 0).numpy()
    
    image_uint8 = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)
    imageio.imwrite(os.path.join(image_output_dir, f"frame_{i:02d}.png"), image_uint8)

In [None]:
# I would like to do a giant render of a single sample, it should be at least 10k by 10k pixels
data_sample = next(data_iter).to(device)
with torch.no_grad():
    mu, logvar = model.encode(data_sample)
    decoded = model.decode(mu)

    xy, scale, rot, feat = denormalize_data(decoded[:, :, 0:2], decoded[:, :, 2:4],
                                            decoded[:, :, 4:5], decoded[:, :, 5:8])
    
    xy = xy.squeeze(0).contiguous().float()
    scale = scale.squeeze(0).contiguous().float()
    rot = rot.squeeze(0).contiguous().float()
    feat = feat.squeeze(0).contiguous().float()

    img_size = (4800, 6400)
    large_image = render(xy, scale * 10.0, rot, feat, img_size=img_size)
    large_image_np = large_image.cpu().detach().permute(1, 2, 0).numpy()
    large_image_uint8 = (np.clip(large_image_np, 0, 1) * 255).astype(np.uint8)
    large_image_path = os.path.join(output_dir, "large_render.png")
    imageio.imwrite(large_image_path, large_image_uint8)

    img_size = (480, 640)
    small_image = render(xy, scale, rot, feat, img_size=img_size)
    small_image_np = small_image.cpu().detach().permute(1, 2, 0).numpy()
    small_image_uint8 = (np.clip(small_image_np, 0, 1) * 255).astype(np.uint8)
    small_image_path = os.path.join(output_dir, "small_render.png")
    imageio.imwrite(small_image_path, small_image_uint8)

    #Let's print out which inputs we used
    print("Used the following two samples for interpolation:")
    print(data_sample)
    print(f"Small render saved to {small_image_path}")
    print(f"Giant render saved to {large_image_path}")

In [5]:
#Let's render every single npz in the data folder and save the images to a folder
render_output_dir = os.path.join(output_dir, "all_renders")
os.makedirs(render_output_dir, exist_ok=True)
data_filenames = sorted([f for f in os.listdir("./data/FFHQ") if f.endswith(".npz")])
print(f"Rendering {len(data_filenames)} files to {render_output_dir}...")
print(data_filenames[0])

Rendering 22084 files to output/all_renders...
00000.png.npz


In [None]:
for filename in tqdm(data_filenames):
    data_path = os.path.join("./data/FFHQ", filename)
    if os.path.exists(os.path.join(render_output_dir, "gt",f"{os.path.splitext(filename)[0]}_render_gt.png")):
        continue  #Skip already rendered files

    data =  load_gaussians(data_path)
    xy = data['xy'].unsqueeze(0).to(device)
    scale = data['scale'].unsqueeze(0).to(device) 
    rot = data['rot'].unsqueeze(0).to(device)
    feat = data['feat'].unsqueeze(0).to(device)
    input_data = torch.cat([xy, scale, rot, feat], dim=2)

    #Render ground truth
    image_gt = render(xy.squeeze(0), scale.squeeze(0)/2, rot.squeeze(0), feat.squeeze(0), img_size=(256, 256))
    image_gt_np = image_gt.cpu().detach().permute(1, 2, 0).numpy()
    image_gt_uint8 = (np.clip(image_gt_np, 0, 1) * 255).astype(np.uint8)
    
    output_image_path_gt = os.path.join(render_output_dir, "gt",f"{os.path.splitext(filename)[0]}_render_gt.png")
    imageio.imwrite(output_image_path_gt, image_gt_uint8)

100%|██████████| 22084/22084 [04:27<00:00, 82.48it/s]   


In [None]:
#Render VAE reconstruction
from utils.dataset_helper import GaussianSplatDataset
from torch.utils.data import DataLoader

# Create dataset ensuring order matches data_filenames
file_paths = [os.path.join("./data/FFHQ", f) for f in data_filenames]
dataset = GaussianSplatDataset("./data/FFHQ", file_paths=file_paths, augment=False)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for (filename, input_data) in tqdm(zip(data_filenames, loader), total=len(data_filenames)):
    
    if os.path.exists(os.path.join(render_output_dir, "vae",f"{os.path.splitext(filename)[0]}_render_vae.png")):
        continue  #Skip already rendered files

    with torch.no_grad():
        input_data = input_data.to(device)
        mu, logvar = model.encode(input_data)
        decoded = model.decode(mu)

        xy_dec, scale_dec, rot_dec, feat_dec = denormalize_data(decoded[:, :, 0:2], decoded[:, :, 2:4],
                                                decoded[:, :, 4:5], decoded[:, :, 5:8])
        
        xy_dec = xy_dec.squeeze(0).contiguous().float()
        scale_dec = scale_dec.squeeze(0).contiguous().float()
        rot_dec = rot_dec.squeeze(0).contiguous().float()
        feat_dec = feat_dec.squeeze(0).contiguous().float()

        image_vae = render(xy_dec, scale_dec/2, rot_dec, feat_dec, img_size=(256, 256))
        image_vae_np = image_vae.cpu().detach().permute(1, 2, 0).numpy()
        image_vae_uint8 = (np.clip(image_vae_np, 0, 1) * 255).astype(np.uint8)
        
        output_image_path_vae = os.path.join(render_output_dir, "vae", f"{os.path.splitext(filename)[0]}_render_vae.png")
        imageio.imwrite(output_image_path_vae, image_vae_uint8)

In [None]:
# Let's calculate the FID and KID between the GT renders and the VAE renders
import os
import time
# First we split the images into two folders
gt = 0
os.makedirs(os.path.join(render_output_dir, "gt_renders"), exist_ok=True)
vae = 0
os.makedirs(os.path.join(render_output_dir, "vae_renders"), exist_ok=True)


for i in os.listdir(render_output_dir):
    if i.endswith("_render_gt.png"):
        os.rename(os.path.join(render_output_dir, i), os.path.join(render_output_dir, "gt_renders", i))
        gt += 1
    elif i.endswith("_render_vae.png"):
        os.rename(os.path.join(render_output_dir, i), os.path.join(render_output_dir, "vae_renders", i))
        vae += 1
    else:
        print(f"Unknown file in render output dir: {i}")
        continue

print(f"Calculating FID and KID between {gt} GT renders and {vae} VAE renders...")

from cleanfid import fid
print("Computing FID and KID, this may take a while...")
start_time =  time.time()
fid_value = fid.compute_fid(os.path.join(render_output_dir, "gt_renders"),
                            os.path.join(render_output_dir, "vae_renders"),
                            mode="clean",
                            num_workers=16)
end_time = time.time()
print(f"FID computation took {end_time - start_time:.2f} seconds.")
print(f"FID: {fid_value}")

start_time =  time.time()
kid_value = fid.compute_kid(os.path.join(render_output_dir, "gt_renders"),
                            os.path.join(render_output_dir, "vae_renders"),
                            mode="clean",
                            num_workers=16)
end_time = time.time()
print(f"KID computation took {end_time - start_time:.2f} seconds.")
print(f"KID: {kid_value}")



Unknown file in render output dir: gt_renders
Unknown file in render output dir: vae_renders
Calculating FID and KID between 0 GT renders and 0 VAE renders...
Computing FID and KID, this may take a while...
compute FID between two folders
Found 22083 images in the folder output/all_renders/gt_renders


FID gt_renders : 100%|██████████| 691/691 [00:54<00:00, 12.67it/s]


Found 22084 images in the folder output/all_renders/vae_renders


FID vae_renders : 100%|██████████| 691/691 [00:54<00:00, 12.67it/s]


FID computation took 118.58 seconds.
FID: 95.78872547084262
compute KID between two folders
Found 22083 images in the folder output/all_renders/gt_renders


KID gt_renders : 100%|██████████| 691/691 [00:55<00:00, 12.45it/s]


Found 22084 images in the folder output/all_renders/vae_renders


KID vae_renders : 100%|██████████| 691/691 [00:54<00:00, 12.78it/s]


KID computation took 118.08 seconds.


In [8]:
print(f"KID: {kid_value}")

KID: 0.09401272982358932
