In [2]:
import torch

from model.titok import TiTok
from base_tokenizers import load_vae

from decord import VideoReader, cpu, bridge
from einops import rearrange
from torchvision.io import write_video
from omegaconf import OmegaConf

In [3]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.manual_seed(0)

device = "cuda:0"
torch_dtype = torch.bfloat16

config =  OmegaConf.load("stage_2/config.yaml")
checkpoint_path = "stage_2/model.ckpt"
vae_path = "wfvae-16"

In [4]:

tokenizer = TiTok(config)

if config.training.torch_compile:
    tokenizer = torch.compile(tokenizer.to(device))

orig_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=False)['state_dict']
model_sd = {}
for k, v in orig_sd.items():
    if not 'disc' in k:
        model_sd[k[6:]] = v
    
tokenizer.load_state_dict(model_sd)
tokenizer.eval().to(device, torch_dtype)
vae = load_vae(vae_name=config.model.vae.type, model_path=vae_path, embed_dim=config.model.vae.latent_channels)
vae.eval().to(device, torch_dtype)

WFVAEWrapper(
  (vae): WFVAEModel(
    (encoder): Encoder(
      (down1): Sequential(
        (0): Conv2d(24, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ResnetBlock2D(
          (norm1): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          )
          (conv1): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          )
          (conv1): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
    

In [5]:
# Tokenize a video
def tokenize_and_reconstruct(video_path, write_path):
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch_dtype):
        bridge.set_bridge('torch')
        video_reader = VideoReader(video_path, ctx=cpu(0))
        fps = video_reader.get_avg_fps()
        video = video_reader.get_batch(list(range(len(video_reader)))).to(device, torch_dtype)
        video = rearrange(video, 't h w c -> 1 c t h w') / 255
        video = (video * 2) - 1
        enc_vid = vae.encode(video)
        z_quant, result_dict = tokenizer.encode(enc_vid)
        
        tokens_list = result_dict['codes'].squeeze(0).cpu().tolist()
        print(f"VIDEO TOKENS ({len(tokens_list)}):\n{tokens_list}")
        z_quant = tokenizer.quantize.indices_to_codes(result_dict['codes'])
        
        recon_video = tokenizer.decode(z_quant.to(torch_dtype))
        
        dec_video = vae.decode(recon_video)
        dec_video = ((dec_video + 1) / 2) * 255
        dec_video = rearrange(dec_video.squeeze(0), 'c t h w -> t h w c').to("cpu", dtype=torch.uint8)
        write_video(write_path, dec_video, fps=fps, options={'crf': '0'})

In [6]:
tokenize_and_reconstruct(f"assets/big_buck_bunny_256x33.mp4", f"assets/bbb_recon_256x33.mp4")

VIDEO TOKENS (128):
[615, 826, 1465, 441, 7711, 9764, 3279, 7746, 2583, 7695, 13888, 1371, 14888, 8872, 12469, 13881, 9751, 1152, 4346, 315, 14856, 14344, 11762, 946, 2510, 479, 7807, 6757, 7744, 13580, 1969, 13328, 4596, 14353, 11902, 2041, 639, 4398, 15340, 12808, 10116, 12755, 5572, 14215, 6598, 11591, 13335, 4197, 6824, 10802, 1156, 13927, 4388, 8336, 12292, 4936, 1595, 10360, 7214, 14022, 14872, 6569, 1326, 313, 12685, 13367, 12325, 7717, 4575, 13344, 5619, 5668, 8321, 6532, 12383, 8761, 10945, 12353, 7255, 8209, 6611, 3141, 13357, 9146, 13507, 5167, 3143, 2885, 7183, 7080, 12410, 4400, 2614, 2063, 11798, 2911, 14599, 1572, 11331, 12310, 14047, 11775, 7495, 8654, 3208, 13456, 3107, 998, 13453, 12295, 1087, 1440, 11839, 11151, 3085, 3142, 14741, 6339, 2846, 2491, 6669, 14857, 6650, 9888, 12568, 7045, 4223, 11324]
