In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2, functional
from IPython.display import Video
import numpy as np
from einops import rearrange

from model.titok import TiTok

from omegaconf import OmegaConf

from decord import VideoReader, cpu, bridge
import imageio

In [2]:
# Reproducibility / numerics
torch.manual_seed(0)
torch.backends.cuda.matmul.allow_tf32 = True  # CUDA-only setting; presumably has no effect while device is "cpu"

# Where and in which precision the tokenizer runs
device = "cpu"
torch_dtype = torch.bfloat16

# Model config and the checkpoint to evaluate
config = OmegaConf.load("configs/tiny.yaml")
checkpoint_path = "out_tiny_w512_exp/epoch=0-step=51000.ckpt"

In [None]:
tokenizer = TiTok(config)

if config.training.torch_compile:
    # NOTE(review): compiling *before* load_state_dict means the checkpoint keys are
    # matched against the compiled wrapper's state dict; this presumably relies on the
    # checkpoint having been saved from a compiled model as well — confirm.
    tokenizer = torch.compile(tokenizer.to(device))

# Deserialize the checkpoint once and read everything needed from it
# (it was previously loaded twice, once per field).
# NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
orig_sd = checkpoint['state_dict']
global_step = checkpoint['global_step']
del checkpoint  # don't keep the rest of the checkpoint alive in the kernel

# Skip loss-module weights and strip the first 6 characters of every key
# (presumably the "model." attribute prefix of the training wrapper — verify).
model_sd = {k[6:]: v for k, v in orig_sd.items() if 'loss_module' not in k}

tokenizer.load_state_dict(model_sd)
tokenizer.eval().to(device, torch_dtype);  # trailing ';' suppresses the large model repr

In [None]:
# Have decord hand back torch tensors instead of its own array type.
bridge.set_bridge('torch')

trg_res, num_frames, dataset_fps = (
    config.dataset.resolution,
    config.dataset.num_frames,
    config.dataset.frames_per_second,
)

# Clip preprocessing: resize + center crop to trg_res, temporally subsample to
# num_frames, convert to torch_dtype scaled to [0, 1], then normalize to [-1, 1].
# Applied to a (T, C, H, W) frame tensor (presumably uint8 from decord).
transforms = v2.Compose(
    [
        v2.Resize(
            size=trg_res,
            interpolation=functional.InterpolationMode.BICUBIC,
            antialias=True,
        ),
        v2.CenterCrop(size=trg_res),
        v2.UniformTemporalSubsample(num_frames),
        v2.ToImage(),
        v2.ToDtype(torch_dtype, scale=True),
        v2.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # [0, 1] -> [-1, 1]
    ]
)

# Tokenize a video
def tokenize_and_reconstruct(video_path, write_path):
    """Tokenize the opening clip of a video, decode it, and save original | reconstruction.

    Only the first ``num_frames / dataset_fps`` seconds of the video are used. Those
    source frames are preprocessed with ``transforms``, which subsamples them to
    ``num_frames`` frames. The clip is then encoded with ``tokenizer``, its latents
    are rebuilt from the discrete codes, and the result is decoded. The token ids are
    printed. The original and the reconstruction are concatenated along the width
    axis and written to ``write_path`` at ``dataset_fps``.

    Display the result with
    ``Video(write_path, width=trg_res * 2, height=trg_res, embed=True)``.
    """
    with torch.no_grad():
        vr = VideoReader(video_path, ctx=cpu(0))
        fps = vr.get_avg_fps()
        # Number of source frames spanning num_frames frames at dataset_fps. Decode only
        # those instead of decoding the whole video and slicing afterwards.
        clip_len = min(len(vr), int(fps / dataset_fps * num_frames))
        video = vr.get_batch(list(range(clip_len)))  # (T, H, W, C)

        # (T, H, W, C) -> (T, C, H, W) for the transforms -> (C, T, H, W) for the tokenizer
        orig = transforms(video.permute(0, 3, 1, 2)).permute(1, 0, 2, 3).to(device, torch_dtype)
        _, result_dict = tokenizer.encode(orig.unsqueeze(0))

        tokens_list = result_dict['codes'].cpu().tolist()[0]
        print(f"VIDEO TOKENS ({len(tokens_list)}):\n{tokens_list}")
        # Rebuild the latents from the integer codes alone, so the reconstruction
        # below is exactly what the printed tokens encode.
        z_quant = tokenizer.quantize.indices_to_codes(result_dict['codes'])

        recon = tokenizer.decode(z_quant.to(torch_dtype)).clamp(-1, 1).squeeze(0)

        # Concatenate along width: (C, T, H, 2W) -> (T, H, 2W, C)
        merged_video = torch.cat((orig, recon), dim=-1).permute(1, 2, 3, 0).cpu().float().numpy()
        # [-1, 1] -> [0, 255]; round rather than truncate so pixel values are not biased down.
        merged_video = np.rint((merged_video + 1) / 2 * 255).astype(np.uint8)
        # The clip spans num_frames / dataset_fps seconds, so play it back at dataset_fps.
        # Using the source fps would speed playback up by fps / dataset_fps.
        imageio.mimwrite(write_path, merged_video, fps=dataset_fps, quality=8)

In [None]:
tokenize_and_reconstruct("assets/orig.mp4", f"assets/recon_{config.logging.run_name}_{global_step}.mp4")

In [6]:
# Eval dataset: split a single long video into consecutive fixed-duration chunks
class EvalReconstructionDataset(torch.utils.data.Dataset):
    """Consecutive, non-overlapping clips of one video for reconstruction evaluation.

    Each item covers ``trg_frames / trg_fps`` seconds of the source video. It is
    returned as ``trg_frames`` frames at evenly spaced source indices, including both
    ends of the span. The frames are resized and center-cropped to ``trg_res``, laid
    out as ``(C, T, H, W)``, and scaled to ``[-1, 1]``. Trailing source frames that do
    not fill a whole chunk are dropped.

    The video is assumed to be larger than the target in frames and resolution.

    Args:
        video_path: path of the video, read with decord.
        trg_fps: target frame rate of each clip.
        trg_frames: number of frames per clip.
        trg_res: output spatial size passed to Resize / CenterCrop.
        dtype: dtype of the returned clips. ``None`` (the default) falls back to the
            notebook's global ``torch_dtype`` at access time, as before.

    Raises:
        ValueError: if the video's fps is too low to span even one source frame per chunk.
    """

    def __init__(self, video_path, trg_fps=8, trg_frames=8, trg_res=128, dtype=None):
        self.trg_frames = trg_frames
        self.trg_fps = trg_fps
        self.video_path = video_path
        self.dtype = dtype
        self.transform = v2.Compose([
            v2.Resize(size=trg_res, interpolation=functional.InterpolationMode.BICUBIC, antialias=True),
            v2.CenterCrop(size=trg_res)
        ])
        # Global decord setting: get_batch returns torch tensors.
        bridge.set_bridge('torch')
        self.vr = VideoReader(video_path, ctx=cpu(0), num_threads=0)

        # Number of source frames covering trg_frames frames at trg_fps.
        source_fps = self.vr.get_avg_fps()
        self.orig_fps_chunk_length = int(trg_frames * (source_fps / trg_fps))
        if self.orig_fps_chunk_length < 1:
            raise ValueError(
                f"video fps ({source_fps}) is too low to form a {trg_frames}-frame chunk at {trg_fps} fps"
            )
        self.num_chunks = len(self.vr) // self.orig_fps_chunk_length

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        start_idx = idx * self.orig_fps_chunk_length
        end_idx = start_idx + self.orig_fps_chunk_length

        # trg_frames evenly spaced source indices over [start_idx, end_idx - 1]
        chunk_indices = np.linspace(start_idx, end_idx - 1, self.trg_frames, dtype=int).tolist()
        # With the torch bridge, get_batch already yields a tensor. as_tensor avoids the
        # legacy torch.Tensor(...) constructor and does not copy an existing tensor.
        chunk = torch.as_tensor(self.vr.get_batch(chunk_indices))  # (T, H, W, C)

        chunk = self.transform(chunk.permute(0, 3, 1, 2))  # (T, C, H, W)
        chunk = chunk.permute(1, 0, 2, 3)  # (C, T, H, W)

        dtype = self.dtype if self.dtype is not None else torch_dtype
        chunk = chunk.to(dtype) / 255
        return chunk * 2 - 1  # [0, 1] -> [-1, 1]

In [7]:
# Dataset settings for the batched evaluation below (same values as the preprocessing cell).
trg_res, num_frames, dataset_fps = (
    config.dataset.resolution,
    config.dataset.num_frames,
    config.dataset.frames_per_second,
)

# Tokenize a video (batched)
def tokenize_and_reconstruct(video_path, write_path, batch_size=1, workers=0):
    """Reconstruct a whole video chunk by chunk and write original | reconstruction.

    ``EvalReconstructionDataset`` splits the video into consecutive chunks of
    ``num_frames`` frames at ``dataset_fps``. The chunks are passed through
    ``tokenizer`` in batches. Each original/reconstructed frame pair is concatenated
    along the width axis and appended to ``write_path`` as libx264 video at
    ``dataset_fps``.

    Note: this definition shadows the single-clip ``tokenize_and_reconstruct``
    defined earlier in the notebook.

    Args:
        video_path: path of the input video.
        write_path: path of the output video (created / overwritten).
        batch_size: number of chunks tokenized per forward pass.
        workers: ``num_workers`` for the DataLoader. DataLoader returns batches in
            order even with workers (per its docs). However, the dataset's decord
            ``VideoReader`` and the torch bridge set in its ``__init__`` would have to
            survive being sent to worker processes. That is unverified, hence the
            default of 0.
    """
    # Build the dataset first so a bad input fails before the output file is created.
    dataset = EvalReconstructionDataset(video_path, dataset_fps, num_frames, trg_res)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=workers)

    with torch.no_grad(), imageio.v3.imopen(write_path, "w", plugin="pyav") as video_writer:
        video_writer.init_video_stream("libx264", fps=dataset_fps)
        # init_video_stream takes no encoder options, so CRF is set on the private stream.
        # crf=0 is lossless for the encoded pixels. NOTE(review): the stream's pixel
        # format may still subsample chroma — confirm if bit-exact output matters.
        video_writer._video_stream.options = {'crf': '0'}

        for batch in dataloader:
            batch = batch.to(device)
            # assumes the tokenizer's forward returns the reconstruction first — TODO confirm
            recon = tokenizer(batch)[0].clamp(-1, 1)

            orig = rearrange(batch, "b c t h w -> (b t) h w c")
            recon = rearrange(recon, "b c t h w -> (b t) h w c")

            # Concatenate along width (dim 2 of (frames, h, w, c)), then map [-1, 1] -> [0, 255].
            merged_video = torch.cat((orig, recon), dim=2).cpu().float().numpy()
            # Round rather than truncate so pixel values are not biased down.
            merged_video = np.rint((merged_video + 1) / 2 * 255).astype(np.uint8)
            for frame in merged_video:
                video_writer.write_frame(frame)

In [None]:
tokenize_and_reconstruct(
    "assets/orig_long.mp4",
    f"assets/recon_long_{config.logging.run_name}_{global_step}.mp4",
    batch_size=32,
)