In [1]:
import torch

from model.titok import TiTok
from base_tokenizers import load_vae

from decord import VideoReader, cpu, bridge
from einops import rearrange
from torchvision.io import write_video
from omegaconf import OmegaConf

In [2]:
# Reproducibility: seed torch's global RNG before anything is built.
torch.manual_seed(0)
# Allow TF32 kernels for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True

# Device and dtype that the tokenizer, the VAE and the video tensor are moved to.
device = "cuda"
torch_dtype = torch.float32

# Experiment configuration and the checkpoint the weights are restored from.
config = OmegaConf.load("configs/tiny.yaml")
checkpoint_path = "out_tiny/epoch=187-step=59000.ckpt"

In [None]:

# Build the tokenizer from the loaded config.
tokenizer = TiTok(config)

# Wrap the model with torch.compile when the config's training section asks for it.
if config.training.torch_compile:
    tokenizer = torch.compile(tokenizer.to(device))

# NOTE: weights_only=False unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source.
orig_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=False)['state_dict']

# Keep every entry whose key does not contain 'disc' and drop the first six
# characters of each remaining key (presumably a "model." prefix -- TODO confirm
# against the checkpoint). Open question from the original author: this may not
# work for torch-compiled models.
model_sd = {
    name[6:]: weights
    for name, weights in orig_sd.items()
    if 'disc' not in name
}

tokenizer.load_state_dict(model_sd)
tokenizer.eval().to(device, torch_dtype)

vae = load_vae(
    vae_name=config.model.vae.type,
    model_path=config.model.vae.path,
    embed_dim=config.model.vae.latent_channels,
)
vae.eval().to(device, torch_dtype)

In [None]:
# Tokenize a video, then reconstruct it from its tokens and write the result to disk
def tokenize_and_reconstruct(video_path, write_path):
    """Round-trip a video through the VAE and tokenizer and save the result.

    Reads every frame of ``video_path`` with decord, scales pixels to [-1, 1],
    encodes with the notebook-level ``vae`` and ``tokenizer``, prints the
    resulting token ids, rebuilds the latents from those ids, decodes them and
    writes the reconstructed video to ``write_path``.

    Relies on the notebook globals ``vae``, ``tokenizer``, ``device`` and
    ``torch_dtype``.

    Args:
        video_path: Path of the input video, opened with decord's VideoReader.
        write_path: Destination path handed to torchvision's ``write_video``.

    Returns:
        None. The function prints the token list and writes a video file.
        It also switches decord's bridge to 'torch' as a side effect.
    """
    # Make decord return torch tensors from get_batch.
    bridge.set_bridge('torch')

    # NOTE(review): autocast is requested with the notebook's torch_dtype, which
    # the config cell sets to float32 -- this is likely a no-op; confirm.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch_dtype):
        video_reader = VideoReader(video_path, ctx=cpu(0))
        fps = video_reader.get_avg_fps()

        # Load all frames at once; the whole clip must fit in device memory.
        video = video_reader.get_batch(list(range(len(video_reader)))).to(device, torch_dtype)
        # (t, h, w, c) -> (1, c, t, h, w), then map pixel values to [-1, 1]
        # (assumes decord yields values in 0..255).
        video = rearrange(video, 't h w c -> 1 c t h w') / 255
        video = (video * 2) - 1

        enc_vid = vae.encode(video)
        z_quant, result_dict = tokenizer.encode(enc_vid)

        # if not in VAE mode
        tokens_list = result_dict['codes'].squeeze(0).cpu().tolist()
        print(f"VIDEO TOKENS ({len(tokens_list)}):\n{tokens_list}")
        # Rebuild the latents from the token ids (presumably to show that the
        # ids alone are enough to reconstruct the video).
        z_quant = tokenizer.quantize.indices_to_codes(result_dict['codes'])

        recon_video = tokenizer.decode(z_quant.to(torch_dtype))

        dec_video = vae.decode(recon_video)
        # Map [-1, 1] back to [0, 255].
        dec_video = ((dec_video + 1) / 2) * 255
        # The decoder output is not guaranteed to stay inside [-1, 1]. Without
        # the clamp, out-of-range floats wrap around (or are undefined) when
        # cast to uint8, which shows up as speckle artifacts. Rounding first
        # avoids the darkening bias of plain truncation.
        dec_video = dec_video.round().clamp(0, 255)
        dec_video = rearrange(dec_video.squeeze(0), 'c t h w -> t h w c').to("cpu", dtype=torch.uint8)

        # crf is passed through to the encoder; 0 presumably requests
        # lossless / highest quality (codec-dependent).
        write_video(write_path, dec_video, fps=fps, options={'crf': '0'})

In [None]:
# Run the round trip on assets/big_buck_bunny.mp4 and write the result to assets/bbb_recon.mp4.
tokenize_and_reconstruct(
    video_path="assets/big_buck_bunny.mp4",
    write_path="assets/bbb_recon.mp4",
)