# LTX-2 VAE Test Notebook

**VAE Specifications:**
- Spatial compression: 32x | Temporal compression: 8x | Latent channels: 128
- Frame count: must be `1 + 8*k` | Resolution: must be divisible by 32

**Tests:**
1. Basic Redecode - encode/decode reconstruction quality
2. Checkerboard Latent Blend Sweep - varying block sizes from large to 1x1 XOXOXO
3. Latent Quantization Sweep - quantize to 256, 128, 64, 32, 16, 4, 2 levels

## Setup

In [None]:
import sys
import rp
import torch
import numpy as np
from einops import rearrange

# --- Global config -----------------------------------------------------------
# IN_NOTEBOOK decides whether show_video displays inline or saves to disk.
IN_NOTEBOOK = rp.running_in_jupyter_notebook()
top_dir = rp.get_git_working_dir('.')
ltx_src = rp.path_join(top_dir, 'LTX2', 'src')

# --- Make the vendored LTX packages importable -------------------------------
# Each package's src/ dir is prepended to sys.path, so packages later in this
# tuple end up earlier on the search path.
for pkg in ('ltx-core', 'ltx-trainer', 'ltx-pipelines'):
    pkg_src = rp.path_join(ltx_src, 'packages', pkg, 'src')
    sys.path.insert(0, pkg_src)

from ltx_trainer.model_loader import load_video_vae_encoder, load_video_vae_decoder

print(f"Running in notebook: {IN_NOTEBOOK}")
print(f"Top directory: {top_dir}")

## Load VAE

In [None]:
# Both the encoder and the decoder are loaded from this one checkpoint file.
models_dir = rp.path_join(top_dir, 'LTX2', 'models')
vae_path = rp.path_join(models_dir, 'ltx-2-19b-distilled.safetensors')

# Device is chosen by rp; the VAE runs in bfloat16.
device = rp.select_torch_device(prefer_used=True, reserve=True)
dtype = torch.bfloat16
print(f"Device: {device}, dtype: {dtype}")

print("Loading VAE...")
vae_kwargs = dict(device=device, dtype=dtype)
vae_encoder = load_video_vae_encoder(vae_path, **vae_kwargs)
vae_decoder = load_video_vae_decoder(vae_path, **vae_kwargs)
print("VAE loaded!")

## Helper Functions

In [None]:
# Bitrate for saved MP4s: 10^8 bits/s, i.e. 10x rp's 'high' preset (10^7).
VIDEO_BITRATE = 100_000_000


def show_video(video, name, framerate=30):
    """
    Show a video inline under Jupyter, or save it as an MP4 otherwise.

    Outside a notebook the clip is written to ``<top_dir>/Notebooks/<name>.mp4``
    at VIDEO_BITRATE and that path is returned. Inside a notebook the clip is
    displayed and nothing is returned (None).
    """
    if not IN_NOTEBOOK:
        output_path = rp.path_join(top_dir, 'Notebooks', f'{name}.mp4')
        print(f"Saving: {output_path}")
        rp.save_video_mp4(video, output_path, framerate=framerate, video_bitrate=VIDEO_BITRATE)
        return output_path
    rp.display_video(video, framerate=framerate)


def encode_video(rgb_video):
    """
    Encode an RGB clip into LTX-2 VAE latents.

    The frames are converted by rp into a torch (T, C, H, W) tensor on the
    global ``device``/``dtype``. They are then rescaled with ``x * 2 - 1``
    (presumably from [0, 1] to [-1, 1] — depends on rp.as_torch_images' output
    range) and run through ``vae_encoder`` without autograd.

    Returns:
        Latent tensor laid out as (LT, LC, LH, LW).
    """
    frames = rp.as_rgb_images(rgb_video, copy=False)
    frames = rp.as_torch_images(frames, device=device, dtype=dtype, copy=False)
    # Add a batch axis and move time behind channels: (T,C,H,W) -> (1,C,T,H,W).
    batch = rearrange(frames, 'T C H W -> 1 C T H W')
    batch = batch * 2 - 1
    with torch.inference_mode():
        encoded = vae_encoder(batch)
    # Drop the batch axis and put latent time first: (1,LC,LT,LH,LW) -> (LT,LC,LH,LW).
    return rearrange(encoded, '1 LC LT LH LW -> LT LC LH LW')


def decode_video(latent_video):
    """
    Decode LTX-2 VAE latents back into uint8 RGB frames.

    Args:
        latent_video: latent tensor laid out as (LT, LC, LH, LW), as produced
            by ``encode_video``.

    Returns:
        numpy uint8 array of shape (T, H, W, C) on the CPU (C presumably 3).
        The decoder's [-1, 1] output is mapped to [0, 255] and rounded to the
        nearest integer.
    """
    latent = rearrange(latent_video, 'LT LC LH LW -> 1 LC LT LH LW')
    with torch.inference_mode():
        decoded = vae_decoder(latent)
    decoded = rearrange(decoded, '1 C T H W -> T H W C')
    # Do the [-1, 1] -> [0, 255] mapping in float32. In bfloat16 the values near
    # 1.0 are only 1/256 apart, and a bare .to(uint8) truncates toward zero,
    # which biases every pixel down by up to one level. Round before the cast.
    decoded = ((decoded.float() + 1) / 2).clamp(0, 1) * 255
    return decoded.round().to(torch.uint8).cpu().numpy()


def redecode_video(rgb_video):
    """Round-trip an RGB clip through the VAE: encode to latents, then decode back to uint8 RGB."""
    latent_video = encode_video(rgb_video)
    return decode_video(latent_video)


def make_comparison(videos, labels):
    """
    Build one side-by-side video from several clips.

    Each clip is captioned with its entry in ``labels`` (Futura font). All
    clips are then brought to the smallest frame size and a common (minimum)
    length via rp, and concatenated horizontally.
    """
    captioned = rp.labeled_videos(videos, labels, font='R:Futura')
    same_size = rp.resize_videos_to_min_size(captioned)
    same_length = rp.resize_lists_to_min_len(same_size)
    return rp.horizontally_concatenated_videos(same_length)

## Load Test Video

In [None]:
# Download (cached) the Pexels boat clip and read its first 150 frames.
video_path = rp.download_to_cache('https://www.pexels.com/download/video/5291434/')
numpy_video_boat = rp.load_video(video_path, length=150, show_progress=True)

# Conform to the VAE's constraints:
#   1. fit within 512x768 without growing (allow_growth=False),
#   2. center-crop each side down to a multiple of 32,
#   3. keep 1 + 8*k frames.
numpy_video_boat = rp.as_numpy_array(
    rp.resize_images_to_fit(numpy_video_boat, height=512, width=768, allow_growth=False)
)
H, W = numpy_video_boat.shape[1:3]
numpy_video_boat = rp.as_numpy_array(
    rp.crop_images(numpy_video_boat, H // 32 * 32, W // 32 * 32, origin='center')
)
T = len(numpy_video_boat)
numpy_video_boat = numpy_video_boat[:(T - 1) // 8 * 8 + 1]
print(f"Boat video: {numpy_video_boat.shape}")

# Solid pure-green clip of the same shape; Test 2 blends it with the boat clip.
T, H, W, _ = numpy_video_boat.shape
numpy_video_green = np.zeros((T, H, W, 3), dtype=np.uint8)
numpy_video_green[..., 1] = 255
print(f"Green video: {numpy_video_green.shape}")

---
# Test 1: Basic Redecode
Encode the boat video to latent space and decode it back. The reconstruction should be nearly identical to the input.

In [None]:
rp.tic()
video_redecoded = redecode_video(numpy_video_boat)
rp.ptoc('Redecode')

# Put the input next to its VAE round trip so reconstruction error is easy to eyeball.
comparison = make_comparison([numpy_video_boat, video_redecoded], ['Input', 'Redecoded'])
show_video(comparison, 'test1_redecode')

---
# Test 2: Checkerboard Latent Blend Sweep
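Blend the boat and green videos in **latent space** using checkerboard masks of varying block sizes, from large blocks down to a 1x1 (XOXOXO) pattern. Cells where the mask is set take the boat latent; the remaining cells take the green latent.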

In [None]:
rp.tic()
latent_boat, latent_green = (encode_video(clip) for clip in (numpy_video_boat, numpy_video_green))
rp.ptoc('Encode both')

# Latents are (LT, LC, LH, LW). Both clips were built with the same frame shape,
# so the dimensions are read off the boat latent.
LT, LC, LH, LW = latent_boat.shape
print(f"Latent shape: LT={LT}, LC={LC}, LH={LH}, LW={LW}")


def make_checkerboard_mask(block_h, block_w, LT, LC, LH, LW, device):
    """
    Build a boolean checkerboard mask over the latent spatial grid.

    The (LH, LW) grid is tiled with ``block_h`` x ``block_w`` cells. A cell is
    True when (block row + block column) is even, so the top-left cell is
    always True. The 2-D pattern is expanded (a broadcast view, no copy) to
    (LT, LC, LH, LW) so it can be used directly in ``torch.where`` with
    latents of that shape.

    Args:
        block_h, block_w: cell size in latent pixels; must be >= 1.
        LT, LC, LH, LW: target mask shape.
        device: torch device for the mask.

    Returns:
        torch.bool tensor of shape (LT, LC, LH, LW).

    Raises:
        ValueError: if either block size is < 1.
    """
    if block_h < 1 or block_w < 1:
        raise ValueError(f"block sizes must be >= 1, got {block_h}x{block_w}")
    # Vectorized: a per-pixel Python loop would issue one tiny indexed write per
    # latent pixel, each a separate device op.
    block_rows = torch.arange(LH, device=device) // block_h
    block_cols = torch.arange(LW, device=device) // block_w
    mask = (block_rows[:, None] + block_cols[None, :]) % 2 == 0
    return mask.expand(LT, LC, LH, LW)


# Sweep from large blocks down to 1x1 (XOXOXO pattern). Each entry is a block
# size in latent pixels. The first three split the latent grid into roughly
# 2x2, 4x4 and 8x8 blocks.
block_sizes = [
    (LH // 2, LW // 2),  # 2x2 grid of quadrants
    (LH // 4, LW // 4),  # 4x4 grid of blocks
    (LH // 8, LW // 8),  # 8x8 grid of blocks
    (2, 2),              # 2x2-latent-pixel blocks
    (1, 1),              # 1x1 XOXOXO
]

checker_videos = []
checker_labels = []

# Restart the timer so 'All checkerboards' covers only the blend + decode loop.
# The encode above was already reported as 'Encode both'. This assumes
# rp.ptoc measures from the most recent rp.tic without resetting it.
rp.tic()
for block_h, block_w in block_sizes:
    # Small latents can make LH // 8 (etc.) zero; clamp to a 1-pixel block.
    block_h = max(1, block_h)
    block_w = max(1, block_w)
    label = f'{block_h}x{block_w}'
    print(f"Checkerboard {label}...")

    # True cells take the boat latent, False cells take the green latent.
    mask = make_checkerboard_mask(block_h, block_w, LT, LC, LH, LW, device)
    latent_blend = torch.where(mask, latent_boat, latent_green)
    video_checker = decode_video(latent_blend)

    checker_videos.append(video_checker)
    checker_labels.append(label)

rp.ptoc('All checkerboards')

# Create tiled grid
grid_video = rp.tiled_videos(
    rp.resize_lists_to_min_len(
        rp.resize_videos_to_min_size(
            rp.labeled_videos(checker_videos, checker_labels, font='R:Futura')
        )
    )
)

show_video(grid_video, 'test2_checkerboard_sweep')

---
# Test 3: Latent Quantization Sweep
Uniformly quantize the latent values (over the latent's global min–max range) to N discrete levels, to test how much latent precision the decoder tolerates.

In [None]:
rp.tic()
latent = encode_video(numpy_video_boat)
# Global value range of the boat latent; quantize_latent spreads its levels over this span.
lat_min = latent.min()
lat_max = latent.max()
print(f"Latent range: [{lat_min:.4f}, {lat_max:.4f}]")


def quantize_latent(latent, n_levels, lo=None, hi=None):
    """
    Uniformly quantize a tensor to ``n_levels`` evenly spaced values in [lo, hi].

    Values are normalized to [0, 1] over [lo, hi], snapped to the nearest of
    the ``n_levels`` grid points (values outside the range are clamped to the
    end levels), and mapped back. ``lo``/``hi`` default to the module-level
    ``lat_min``/``lat_max``, so every level count in the sweep shares one range.

    Args:
        latent: tensor to quantize (any shape).
        n_levels: number of discrete levels; must be >= 2.
        lo, hi: optional range override (Python numbers or 0-d tensors).

    Returns:
        Tensor with the same shape, dtype and device as ``latent``.

    Raises:
        ValueError: if ``n_levels < 2``.
    """
    if n_levels < 2:
        raise ValueError(f"n_levels must be >= 2, got {n_levels}")
    lo = lat_min if lo is None else lo
    hi = lat_max if hi is None else hi

    # Work in float32. In bfloat16 (the VAE dtype) intermediates carry only
    # about 8 significant bits, so fine grids such as 256 levels would be
    # distorted by rounding before and after the snap.
    x = latent.float()
    lo = torch.as_tensor(lo, dtype=torch.float32, device=x.device)
    hi = torch.as_tensor(hi, dtype=torch.float32, device=x.device)
    span = hi - lo
    if span.item() == 0:
        # Degenerate range: all levels collapse onto lo.
        return torch.full_like(latent, lo.item())

    steps = n_levels - 1
    snapped = ((x - lo) / span * steps).round().clamp(0, steps) / steps
    return (snapped * span + lo).to(latent.dtype)


# Reference clips first (input and unquantized round trip), then one decode per level.
quant_levels = [256, 128, 64, 32, 16, 4, 2]
videos = [numpy_video_boat, video_redecoded]  # Input and Unquantized
labels = ['Input', 'Unquantized']

for n in quant_levels:
    print(f"Quantizing to {n} levels...")
    videos.append(decode_video(quantize_latent(latent, n)))
    labels.append(f'N={n}')

rp.ptoc('All quantizations')

# Caption, bring to the smallest size and a common length, then tile into a grid.
labeled = rp.labeled_videos(videos, labels, font='R:Futura')
labeled = rp.resize_videos_to_min_size(labeled)
labeled = rp.resize_lists_to_min_len(labeled)
grid_video = rp.tiled_videos(labeled)

show_video(grid_video, 'test3_quantization_sweep')

---
# All Tests Complete!

In [None]:
# Final marker so a non-notebook run shows it reached the end.
summary = "ALL TESTS COMPLETE!"
print(summary)