In [1]:
import torch
import yaml
import numpy as np
import imageio.v2 as imageio
import os
from tqdm import tqdm

import matplotlib.pyplot as plt
plt.rcParams['text.usetex'] = False

from utils.dataset_helper import create_dataloaders, create_train_val_dataloaders
from vae_model import GaussianVAE, vae_loss_sinkhorn
from utils.training_utils import get_warmup_cosine_scheduler
from utils.vae_utils import sample_from_latent, save_target_visualization, visualize_reconstruction
from utils.image_utils import render, render_and_save
from utils.diffusion_data_helper import denormalize_data

device = "cuda" if torch.cuda.is_available() else "cpu"
config_path = "config/vae_training.yaml"
checkpoint_path = "best_gaussian_vae.pth"

with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

In [2]:
model = GaussianVAE(
            num_gaussians=cfg["model"]["num_gaussians"],
            input_dim=cfg["model"]["input_dim"],
            latent_dim=cfg["model"]["model_dim"],
            decoder_layers=cfg["model"].get("decoder_transformer_layers", 6),
            decoder_heads=cfg["model"].get("decoder_transformer_heads", 8)
        ).to(device)

data_loader, _ = create_dataloaders(
    "./data/FFHQ",
    batch_size=1, 
    shuffle=True, 
    augment=False,
    is_distributed=False
)

model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

Found 22084 files in ./data/FFHQ


GaussianVAE(
  (sa1): PointNetSetAbstractionMsg(
    (conv_blocks): ModuleList(
      (0): ModuleList(
        (0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): ModuleList(
        (0): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (2): ModuleList(
        (0): Conv2d(8, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(64, 96, kernel_size=(1, 1), stride=(1, 1))
        (2): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (bn_blocks): ModuleList(
      (0): ModuleList(
        (0-1): 2 x BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
   

In [None]:
# Video parameters
fps = 30
duration = 5  # seconds
num_frames = fps * duration

# Output path
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
video_path = os.path.join(output_dir, "latent_interpolation.mp4")

print(f"Generating {num_frames} frames for {duration}s video at {fps} fps...")

# Get two random samples from the dataset
data_iter = iter(data_loader)
data1 = next(data_iter)[0].to(device)
data2 = next(data_iter)[0].to(device)

# Encode to get latent means
with torch.no_grad():
    mu1, logvar1 = model.encode(data1)
    mu2, logvar2 = model.encode(data2)

# Create video writer (requires imageio-ffmpeg backend)
with imageio.get_writer(video_path, fps=fps, format="FFMPEG", codec="libx264", quality=8) as writer:
    for i in tqdm(range(num_frames)):
        # Calculate alpha (0 to 1)
        alpha = i / (num_frames - 1)
        
        # Interpolate in latent space
        latent_interp = (1 - alpha) * mu1 + alpha * mu2
        
        # Decode and render
        with torch.no_grad():
            decoded = model.decode(latent_interp)
        
        xy, scale, rot, feat = denormalize_data(decoded[:, :, 0:2], decoded[:, :, 2:4], 
                                                decoded[:, :, 4:5], decoded[:, :, 5:8])
        
        xy = xy.squeeze(0).contiguous().float()
        scale = scale.squeeze(0).contiguous().float()
        rot = rot.squeeze(0).contiguous().float()
        feat = feat.squeeze(0).contiguous().float()
        
        img_size = (int(480), int(640))
        image = render(xy, scale, rot, feat, img_size=img_size)
        image_np = image.cpu().detach().permute(1, 2, 0).numpy()
        
        # Convert to uint8 (0-255 range)
        image_uint8 = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)
        writer.append_data(image_uint8)

print(f"Video saved to {video_path}")

Generating 150 frames for 5s video at 30 fps...


  0%|          | 0/150 [00:00<?, ?it/s]


NameError: name 'mu1' is not defined