In [1]:
import torch
import glob
import ffmpeg
from omegaconf import OmegaConf
from einops import rearrange
import gdown
import os
from tqdm import tqdm

from model.titok import TiTok
from base_tokenizers import load_vae

from torchmetrics.image import PeakSignalNoiseRatio, LearnedPerceptualImagePatchSimilarity, StructuralSimilarityIndexMeasure
from torchmetrics import MetricCollection

In [2]:
# download MCL_JCV 720p source dataset
gdrive_ds = 'https://drive.google.com/drive/folders/12A8gk07j3OppdcY5JbxSg-ozw_BXzcIQ'
val_ds_path = 'val_dataset'
os.makedirs(val_ds_path, exist_ok=True)
gdown.download_folder(gdrive_ds, output=val_ds_path, resume=True)

Retrieving folder contents


Processing file 1b0oxiZ3udGJ9Re-CeUTYg-EUorNX3Y7J videoSRC01_1280x720_30.yuv
Processing file 1DHBUXCHroqC7MXR1bL3P8mfEGlUDjRl0 videoSRC02_1280x720_30.yuv
Processing file 1EVj4S4MheQ8L3voaQ6LKMf0cySTh7gLK videoSRC03_1280x720_30.yuv
Processing file 1K4w5pUbSFNr46FQnssmeTpgiRTETqRiN videoSRC04_1280x720_30.yuv
Processing file 1j8Oh0Agdfo-BFtN2i2hgytM8ArjKHx5q videoSRC05_1280x720_25.yuv
Processing file 1EeSFFcs16QHkz4JZOAtG42K0qt-6XDaq videoSRC06_1280x720_25.yuv
Processing file 1aNdQmPGuEpxUp8JDXZfzy18PmKUrSfyu videoSRC07_1280x720_25.yuv
Processing file 1l-AZ1CS6bS84qixmD7L2Go_g48b736HN videoSRC08_1280x720_25.yuv
Processing file 1BdVeZ7GIWcUfNSgCgWxZOom7LjO_Z6FM videoSRC09_1280x720_25.yuv
Processing file 1U3q2ttP1SSg4I-1ZjqrhVe0tbvRud2G_ videoSRC10_1280x720_30.yuv
Processing file 1_KZp7H5LdLMVTQEDSg7x2wzZc0PrViLX videoSRC11_1280x720_30.yuv
Processing file 1gMX-z0x9jxaplh7-sv8LL7tihmxUebol videoSRC12_1280x720_30.yuv
Processing file 1WL2kCcvpqy-CZXtlp_cE9kK-2OOJR_7j videoSRC13_1280x720_30.yuv

Retrieving folder contents completed
Building directory structure
Building directory structure completed
Skipping already downloaded file val_dataset/videoSRC01_1280x720_30.yuv
Skipping already downloaded file val_dataset/videoSRC02_1280x720_30.yuv
Skipping already downloaded file val_dataset/videoSRC03_1280x720_30.yuv
Skipping already downloaded file val_dataset/videoSRC04_1280x720_30.yuv
Skipping already downloaded file val_dataset/videoSRC05_1280x720_25.yuv
Skipping already downloaded file val_dataset/videoSRC06_1280x720_25.yuv
Skipping already downloaded file val_dataset/videoSRC07_1280x720_25.yuv
Skipping already downloaded file val_dataset/videoSRC08_1280x720_25.yuv
Skipping already downloaded file val_dataset/videoSRC09_1280x720_25.yuv
Skipping already downloaded file val_dataset/videoSRC10_1280x720_30.yuv
Skipping already downloaded file val_dataset/videoSRC11_1280x720_30.yuv
Skipping already downloaded file val_dataset/videoSRC12_1280x720_30.yuv
Skipping already downloaded fil

['val_dataset/videoSRC01_1280x720_30.yuv',
 'val_dataset/videoSRC02_1280x720_30.yuv',
 'val_dataset/videoSRC03_1280x720_30.yuv',
 'val_dataset/videoSRC04_1280x720_30.yuv',
 'val_dataset/videoSRC05_1280x720_25.yuv',
 'val_dataset/videoSRC06_1280x720_25.yuv',
 'val_dataset/videoSRC07_1280x720_25.yuv',
 'val_dataset/videoSRC08_1280x720_25.yuv',
 'val_dataset/videoSRC09_1280x720_25.yuv',
 'val_dataset/videoSRC10_1280x720_30.yuv',
 'val_dataset/videoSRC11_1280x720_30.yuv',
 'val_dataset/videoSRC12_1280x720_30.yuv',
 'val_dataset/videoSRC13_1280x720_30.yuv',
 'val_dataset/videoSRC14_1280x720_30.yuv',
 'val_dataset/videoSRC15_1280x720_30.yuv',
 'val_dataset/videoSRC16_1280x720_30.yuv',
 'val_dataset/videoSRC17_1280x720_24.yuv',
 'val_dataset/videoSRC18_1280x720_25.yuv',
 'val_dataset/videoSRC19_1280x720_30.yuv',
 'val_dataset/videoSRC20_1280x720_25.yuv',
 'val_dataset/videoSRC21_1280x720_24.yuv',
 'val_dataset/videoSRC22_1280x720_24.yuv',
 'val_dataset/videoSRC23_1280x720_24.yuv',
 'val_datas

In [3]:
torch.backends.cuda.matmul.allow_tf32 = True
torch.manual_seed(0)

device = "cuda"
torch_dtype = torch.bfloat16

config =  OmegaConf.load("stage_2/config.yaml")
checkpoint_path = "stage_2/model.ckpt"
vae_path = "wfvae-16"

In [4]:
tokenizer = TiTok(config)

if config.training.torch_compile:
    tokenizer = torch.compile(tokenizer.to(device))

orig_sd = torch.load(checkpoint_path, map_location="cpu", weights_only=False)['state_dict']
model_sd = {}
for k, v in orig_sd.items():
    if not 'disc' in k:
        model_sd[k[6:]] = v
    
tokenizer.load_state_dict(model_sd)
tokenizer.eval().to(device, torch_dtype)
vae = load_vae(vae_name=config.model.vae.type, model_path=vae_path, embed_dim=config.model.vae.latent_channels)
vae.eval().to(device, torch_dtype)

vae_metrics = MetricCollection(
    {
        "psnr": PeakSignalNoiseRatio(),
        "ssim": StructuralSimilarityIndexMeasure(),
        "lpips": LearnedPerceptualImagePatchSimilarity(net_type='vgg').eval(),
    }
).to(device, torch_dtype)

titok_metrics = vae_metrics.clone()

In [5]:
trg_res = config.dataset.resolution
trg_fps = config.dataset.frames_per_second
trg_frames = config.dataset.num_frames

src_files = glob.glob(os.path.join(val_ds_path, '*.yuv'))
num_eval = 0

for src_file in tqdm(src_files):
    src_fps = src_file.split('_')[-1].replace('.yuv', '')
    width, height = [int(i) for i in src_file.split('_')[-2].split('x')]
    
    x_offset = (width - height) // 2
    out, _ = (
        ffmpeg.input(src_file, format='rawvideo', pix_fmt='yuv420p', s='{}x{}'.format(width, height), framerate=src_fps)
        .crop(x=x_offset, y=0, width='ih', height='ih') # crop to square
        .filter('scale', width=trg_res, height=trg_res) # resize
        .filter('fps', trg_fps)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24', v='error')
        .run(capture_stdout=True)
    )

    video = torch.frombuffer(out, dtype=torch.uint8).reshape([-1, trg_res, trg_res, 3])

    if video.shape[0] >= trg_frames:
        num_chunks = video.shape[0] // trg_frames
        chunked_video = video[:num_chunks*trg_frames].reshape(-1, trg_frames, trg_res, trg_res, 3)
        chunked_video = (chunked_video.permute(0, 4, 1, 2, 3).to('cuda:0', torch.bfloat16) / 255) # BTHWC -> BCTHW, 0-255
        chunked_video = (chunked_video * 2) - 1.0 # -1, 1

        for chunk in chunked_video: # not batching? Only ~30 vids, not worth it?
            with torch.no_grad():
                vae_encoded = vae.encode(chunk.unsqueeze(0))
                titok_encoded, _ = tokenizer(vae_encoded)
                titok_decoded = vae.decode(titok_encoded)
                vae_decoded = vae.decode(vae_encoded)

                recon_titok = rearrange(titok_decoded.squeeze(0), "c t h w -> t c h w")
                recon_vae = rearrange(vae_decoded.squeeze(0), "c t h w -> t c h w")
                orig = rearrange(chunk, "c t h w -> t c h w")

                titok_metrics.update(recon_titok, orig)
                vae_metrics.update(recon_vae, orig) # averages automatically
    
                num_eval += 1

  video = torch.frombuffer(out, dtype=torch.uint8).reshape([-1, trg_res, trg_res, 3])
100%|███████████████████████████████████████████████████████████████████████████████████| 30/30 [01:47<00:00,  3.58s/it]


In [16]:
vae_scores = vae_metrics.compute()
titok_scores = titok_metrics.compute()

print(f"Num eval: {num_eval}")
print("VAE-only scores:    " + ' | '.join([f"{k.upper()} {v:.2f}" for k, v in vae_scores.items()]))
print("TiTok-Video scores: " + ' | '.join([f"{k.upper()} {v:.2f}" for k, v in titok_scores.items()]))

Num eval: 60
VAE-only scores:    LPIPS 0.22 | PSNR 24.02 | SSIM 0.81
TiTok-Video scores: LPIPS 0.67 | PSNR 14.93 | SSIM 0.36
