In [None]:
import os

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2, functional
import numpy as np
from einops import rearrange
from omegaconf import OmegaConf
from decord import VideoReader, cpu, bridge
import imageio

from model.titok import TiTok

In [2]:
# --- Reproducibility and numerics ---
torch.manual_seed(0)
# Permit TF32 kernels for both cuDNN ops and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# --- Where and in which precision the model runs ---
torch_dtype = torch.bfloat16
device = "cuda"

# --- Which model to load ---
# Fetch the weights with:
#   hf download NilanE/TiTok-Video-VariableComp-V1 model.ckpt --local-dir .
checkpoint_path = "model.ckpt"
use_ema = True  # flip to False to skip the 'ema_model' weights
config = OmegaConf.load("configs/large.yaml")


In [3]:
# Build the tokenizer from the config; its weights are filled in from the checkpoint below.
tokenizer = TiTok(config)

# Optionally wrap the model with torch.compile when the training config asks for it.
# NOTE(review): this happens *before* load_state_dict, so the expected key names are
# those of the compiled wrapper — presumably the checkpoint was saved from a compiled
# model whenever this flag is set; confirm.
if config.training.main.torch_compile:
    tokenizer = torch.compile(tokenizer.to(device))

# Read the checkpoint from disk once and pull both entries out of the same object
# (previously the whole file was deserialised twice).
# SECURITY: weights_only=False unpickles arbitrary Python objects — only load
# checkpoints obtained from a source you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
orig_sd = checkpoint['state_dict']
global_step = checkpoint['global_step']
del checkpoint  # orig_sd / global_step keep what we need alive

# Prefer the EMA weights when requested *and* present; otherwise use the plain model weights.
has_ema = any(k.startswith('ema_model.') for k in orig_sd)
model_key = 'ema_model' if (use_ema and has_ema) else 'model'

# Keep only "<model_key>.<param>" entries and strip that prefix. Matching on the dotted
# prefix guarantees the removed separator really is there.
key_prefix = model_key + '.'
model_sd = {k[len(key_prefix):]: v for k, v in orig_sd.items() if k.startswith(key_prefix)}

tokenizer.load_state_dict(model_sd)
tokenizer.eval().to(device, torch_dtype)

TiTok(
  (encoder): TiTokEncoder(
    (rope): RoPE(
      (pos_emb): Lumina2RotaryPosEmbed()
    )
    (proj_in): Linear(in_features=768, out_features=1024, bias=True)
    (model_layers): ResidualAttentionBlock(
      (attn_layer): Sequential(
        (0): Attn(
          (pre_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
          (q_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (k_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (1): Attn(
          (pre_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
          (q_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (k_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (out_proj): Linear(in_feat

In [4]:
# eval dataset for single multi-chunk-duration video
class EvalReconstructionDataset(torch.utils.data.Dataset):
    """Serve one video as consecutive fixed-size clips for reconstruction eval.

    The video is cut into back-to-back chunks; each chunk spans enough source
    frames to cover ``out_grid[0]`` frames at ``trg_fps`` and is sampled down to
    exactly ``out_grid[0]`` frames. Trailing frames that do not fill a whole
    chunk are not served.

    Assumes the video is larger than the target in frame count, resolution, etc.

    Args:
        video_path: Path handed to decord's ``VideoReader``.
        trg_fps: Frame rate the clips are sampled at.
        out_grid: ``(frames, height, width)`` of each returned clip.

    Note:
        ``__getitem__`` reads the module-level ``device`` and ``torch_dtype``
        to place and cast each clip.
    """

    def __init__(self, video_path, trg_fps=8, out_grid=(16, 128, 128)):
        self.out_grid = out_grid
        self.video_path = video_path

        # Spatial pipeline: bicubic resize driven by the larger target side,
        # then a center crop down to the exact (height, width).
        spatial_size = out_grid[1:]
        resize = v2.Resize(
            size=max(spatial_size),
            interpolation=functional.InterpolationMode.BICUBIC,
            antialias=True,
        )
        crop = v2.CenterCrop(size=spatial_size)
        self.transform = v2.Compose([resize, crop])

        # Switch decord's bridge to torch before opening the reader.
        bridge.set_bridge('torch')
        self.vr = VideoReader(video_path, ctx=cpu(0), num_threads=0)

        # Number of source frames covered by one clip, and how many whole clips fit.
        fps_ratio = self.vr.get_avg_fps() / trg_fps
        self.orig_fps_chunk_length = int(out_grid[0] * fps_ratio)
        self.num_chunks = len(self.vr) // self.orig_fps_chunk_length

    def __len__(self):
        """Number of whole clips available in the video."""
        return self.num_chunks

    def __getitem__(self, idx):
        """Return clip ``idx`` scaled to [-1, 1] on ``device`` in ``torch_dtype``."""
        span = self.orig_fps_chunk_length
        first = idx * span
        last = first + span - 1

        # Evenly spaced source-frame indices across the chunk, both ends included.
        frame_ids = np.linspace(first, last, self.out_grid[0], dtype=int).tolist()
        frames = torch.Tensor(self.vr.get_batch(frame_ids))

        # Channels move to axis 1 for the transform, then axes 0 and 1 are swapped
        # (presumably frames-last-channel in -> channel-first clip out; confirm against decord).
        frames = self.transform(frames.permute(0, 3, 1, 2)).permute(1, 0, 2, 3)

        # 0..255 -> 0..1 -> [-1, 1]
        frames = frames.to(device, torch_dtype) / 255
        return (frames * 2) - 1

In [None]:
# Tokenize a video (batched)
def tokenize_and_reconstruct(video_path, write_path, grid, num_tokens, ds_fps, batch_size=1):
    """Reconstruct a video through the tokenizer and write a side-by-side comparison.

    The input is split into clips by ``EvalReconstructionDataset``; every clip is
    passed through the module-level ``tokenizer`` with a budget of ``num_tokens``
    tokens, and each output frame shows the original (left) next to its
    reconstruction (right).

    Args:
        video_path: Input video to read.
        write_path: Output video file to create.
        grid: ``(frames, height, width)`` of each clip given to the tokenizer.
        num_tokens: Token budget requested for every clip.
        ds_fps: Frame rate used both to sample the input and to write the output.
        batch_size: Number of clips run through the tokenizer at once.
    """
    # Build the input side first so an unreadable video fails before the writer is opened.
    dataset = EvalReconstructionDataset(video_path, ds_fps, grid)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

    with torch.no_grad(), imageio.v3.imopen(write_path, "w", plugin="pyav") as video_writer:
        video_writer.init_video_stream("h264", fps=ds_fps)
        # NOTE: the next two lines reach into imageio's private stream handle and may
        # break on an imageio upgrade.
        video_writer._video_stream.options = {'crf': '0'} # lossless
        video_writer._video_stream.codec_context.time_base = video_writer._video_stream.time_base

        for batch in dataloader:
            # The last batch can hold fewer than `batch_size` clips, so the per-clip
            # token budget list is sized from the batch itself.
            clips = batch.unbind(0)
            token_counts = [num_tokens] * len(clips)

            recon = torch.stack(tokenizer(clips, token_counts)[0], dim=0).clamp(-1, 1)

            # ### same effect as the above line, but more verbose
            # compressed_tokens = tokenizer.encode(clips, token_counts)[1]['indices']
            # compressed_tokens = torch.cat(compressed_tokens, dim=0)
            # print(f"VIDEO TOKENS ({compressed_tokens.shape[0]}):\n{compressed_tokens.tolist()}")

            # recon = tokenizer.decode_indices(compressed_tokens, token_counts, [grid]*len(clips))
            # recon = torch.stack(recon, dim=0).clamp(-1, 1)
            # ###

            # Flatten (batch, time) into a frame axis and move channels last for the writer.
            orig = rearrange(batch, "b c t h w -> (b t) h w c")
            recon = rearrange(recon, "b c t h w -> (b t) h w c")

            merged_video = torch.cat((orig, recon), dim=2).detach().cpu().float().numpy() # th(W)c concat
            # Map [-1, 1] -> [0, 255] with rounding (a bare uint8 cast truncates, biasing
            # every pixel downwards) and clip so stray values cannot wrap around.
            merged_video = np.clip(np.rint((merged_video + 1) / 2 * 255), 0, 255).astype(np.uint8)

            for frame in merged_video:
                video_writer.write_frame(frame)

In [None]:
grid = (20, 192, 192) # frames, H-res, W-res per encode
num_tokens = 224 # 128, 224, 1024, etc.
fps = 8 # sampling FPS

in_video = "big_buck_bunny_480p_h264.mov" # https://download.blender.org/peach/bigbuckbunny_movies/big_buck_bunny_480p_h264.mov
# Output sits next to the input: "<input stem>_TL<num_tokens>[_EMA].mp4".
# os.path.splitext only removes a real trailing extension, so an extension-less name
# or a dotted directory component no longer yields an empty or truncated stem.
out_video = f"{os.path.splitext(in_video)[0]}_TL{num_tokens}{'_EMA' if use_ema else ''}.mp4"

tokenize_and_reconstruct(in_video, out_video, grid, num_tokens, fps, batch_size=1)