In [1]:
import torch

from model.titok import TiTok
from base_tokenizer.tokenizer_wrapper import WFVAEWrapper

from decord import VideoReader, cpu
from einops import rearrange
from torchvision.io import write_video
from omegaconf import OmegaConf

In [None]:
# Numerics / reproducibility: seed torch's global RNG and allow TF32 for CUDA matmuls.
torch.manual_seed(0)
torch.backends.cuda.matmul.allow_tf32 = True

# Paths a reader may want to change: tokenizer checkpoint, base VAE weights, model config.
checkpoint_path = "model.pt"
vae_path = "base_tokenizer/pretrained_model"
config = OmegaConf.load("configs/small.yaml")

In [3]:
# Instantiate the tokenizer from the loaded config.
tokenizer = TiTok(config)

# When the config asks for it, move the model to cuda:0 and wrap it with torch.compile
# *before* the weights are loaded.
# NOTE(review): presumably the checkpoint was saved from a compiled model when this flag
# is set, so its state-dict keys only match after wrapping -- confirm before reordering.
if config.training.torch_compile:
    tokenizer = torch.compile(tokenizer.to('cuda:0'))

# Deserialize the checkpoint onto the CPU (weights_only=True) and copy it into the model.
tokenizer.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
# Inference only: switch to eval mode and turn off gradients for all parameters.
tokenizer.eval()
# Last expression of the cell, so the notebook displays the returned module's repr.
tokenizer.requires_grad_(False)

TiTok(
  (quantize): FSQ(
    (project_in): Identity()
    (project_out): Identity()
  )
  (encoder): TiTokEncoder(
    (patch_embed): Conv3d(16, 512, kernel_size=(4, 8, 8), stride=(4, 8, 8))
    (ln_pre): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (model_layers): ModuleList(
      (0-7): 8 x ResidualAttentionBlock(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (c_fc): Linear(in_features=512, out_features=2048, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
    )
    (ln_post): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (linear_out): Linear(in_features=512, out_features=5, bias=True)
  )

In [None]:
# Target precision and device shared by the tokenizer and the base VAE.
torch_dtype = torch.float32
device = "cuda"

# Base VAE wrapper, built with the requested dtype and put into eval mode.
vae = WFVAEWrapper(model_path=vae_path, dtype=torch_dtype)
vae.eval()

# Move the tokenizer to the target device / dtype.
tokenizer = tokenizer.to(device, dtype=torch_dtype)

# Move the VAE to the device; as the cell's last expression its repr is displayed.
vae.to(device)

WFVAEWrapper(
  (vae): WFVAEModel(
    (encoder): Encoder(
      (down1): Sequential(
        (0): Conv2d(24, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ResnetBlock2D(
          (norm1): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          )
          (conv1): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (2): ResnetBlock2D(
          (norm1): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          )
          (conv1): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): LayerNorm(
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
    

In [8]:
# Tokenize a video
def tokenize_and_reconstruct(video_path, write_path):
    """Encode a video to discrete tokens, print them, and write the reconstruction.

    The whole clip at ``video_path`` is decoded into memory, scaled to [-1, 1],
    passed through ``vae.encode`` and ``tokenizer.encode``; the resulting codes are
    printed, turned back into quantized latents with
    ``tokenizer.codes_to_quantized``, decoded by ``tokenizer.decode`` and
    ``vae.decode``, and saved to ``write_path`` at the source's average frame rate.

    Relies on the notebook globals ``tokenizer``, ``vae``, ``device`` and
    ``torch_dtype``.

    Args:
        video_path: Path of the input video, opened with decord's ``VideoReader``.
        write_path: Destination path handed to torchvision's ``write_video``.

    Raises:
        ValueError: If the input video contains no frames.
    """
    from contextlib import nullcontext

    # Autocast only accepts reduced-precision dtypes; asking for float32 just emits a
    # warning and disables it, so skip the context entirely in that case.
    if torch_dtype in (torch.float16, torch.bfloat16):
        amp_ctx = torch.autocast(device_type=torch.device(device).type, dtype=torch_dtype)
    else:
        amp_ctx = nullcontext()

    with torch.no_grad(), amp_ctx:
        video_reader = VideoReader(video_path, ctx=cpu(0))
        num_frames = len(video_reader)
        if num_frames == 0:
            raise ValueError(f"No frames could be read from {video_path!r}")
        fps = video_reader.get_avg_fps()

        # Decode every frame at once: (t, h, w, c) with values in 0..255.
        video = video_reader.get_batch(list(range(num_frames))).asnumpy()
        video = torch.from_numpy(video).to(device, dtype=torch_dtype)
        # Add a batch axis, go channels-first, and rescale 0..255 -> [-1, 1].
        video = rearrange(video, 't h w c -> 1 c t h w') / 255
        video = (video * 2) - 1

        enc_vid = vae.encode(video)
        # Only the discrete codes are needed; the quantized latents are rebuilt from
        # them below to demonstrate the codes -> video path.
        _, result_dict = tokenizer.encode(enc_vid)

        print(f"VIDEO TOKENS:\n{result_dict['codes'].squeeze(0).cpu().tolist()}")

        z_quant = tokenizer.codes_to_quantized(result_dict['codes'])
        recon_video = tokenizer.decode(z_quant)
        dec_video = vae.decode(recon_video)

        # Map [-1, 1] back to 0..255. Round before the uint8 cast: a bare cast
        # truncates, which would darken every pixel by up to one level.
        dec_video = (dec_video + 1) / 2
        dec_video = (torch.clamp(dec_video, 0.0, 1.0) * 255).round()
        dec_video = rearrange(dec_video.squeeze(0), 'c t h w -> t h w c').to("cpu", dtype=torch.uint8)
        # crf=0 is passed through to the encoder options, as in the original call.
        write_video(write_path, dec_video, fps=fps, options={'crf': '0'})

In [9]:
# Round-trip the sample clip under assets/ and write the reconstruction next to it.
tokenize_and_reconstruct("assets/big_buck_bunny.mp4", "assets/bbb_recon.mp4")

VIDEO TOKENS:
[3026, 1318, 6509, 6487, 457, 6682, 5921, 803, 1288, 5490, 5781, 1473, 6986, 3252, 2476, 51, 2253, 5252, 7273, 7849, 3108, 7307, 3097, 2379, 4304, 1381, 3495, 6878, 2762, 2770, 1706, 5556, 7443, 5552, 3684, 38, 6757, 5960, 3621, 1395, 5196, 3708, 1419, 750, 2643, 5168, 4609, 4954, 5592, 7704, 3921, 4684, 1418, 3059, 2956, 1553, 6279, 1826, 7539, 1278, 5217, 3096, 4657, 1505]
