In [2]:
import torch
import glob
import ffmpeg
from omegaconf import OmegaConf
from einops import rearrange
from huggingface_hub import snapshot_download
import zipfile
import glob
import re
import os
import shutil
from tqdm import tqdm

from model.titok import TiTok

from torchmetrics.image import PeakSignalNoiseRatio, LearnedPerceptualImagePatchSimilarity, StructuralSimilarityIndexMeasure
from torchmetrics import MetricCollection

from model.metrics.fvd import FVDCalculator

In [3]:
# Download the MCL_JCV 720p source clips and flatten them into `val_ds_path`.
val_ds_path = 'mcl_jcv_dataset'
temp_zips = os.path.join(val_ds_path, 'temp')

# The dataset counts as ready only when extracted .yuv files are present AND the
# temporary zip directory is gone (it is deleted as the very last step below).
# Checking only that `val_ds_path` exists would skip the download forever after an
# interrupted run, because `temp` is created inside that directory, and the
# evaluation would then silently run on an empty or partial set of videos.
dataset_ready = bool(glob.glob(os.path.join(val_ds_path, '*.yuv'))) and not os.path.exists(temp_zips)

if not dataset_ready:
    # only need first three zips for 720p
    snapshot_download(repo_id="uscmcl/MCL-JCV_Dataset", repo_type='dataset', local_dir=temp_zips, allow_patterns=[f"*00{x}.zip" for x in range(1, 4)])
    for zip_path in glob.glob(os.path.join(temp_zips, '*.zip')):
        with zipfile.ZipFile(zip_path) as zf:
            # keep only the 1280x720 source clips (videoSRCxx_1280x720_yy.yuv)
            to_extract = [x for x in zf.namelist() if re.search(r".*videoSRC.._1280x720_..\.yuv", x)]
            for file in to_extract:
                # extract() recreates the archive's folder layout; move the file up
                # so every clip sits directly inside val_ds_path
                zf.extract(file, path=val_ds_path)
                shutil.move(os.path.join(val_ds_path, file), os.path.join(val_ds_path, os.path.basename(file)))
    # Now-empty folder tree left behind by extract(); tolerate it being absent
    # (e.g. if the archive layout differs) instead of failing after a full download.
    shutil.rmtree(os.path.join(val_ds_path, 'MCL_JCV'), ignore_errors=True)
    # Removed last on purpose: its absence marks the dataset as complete (see above).
    shutil.rmtree(temp_zips)

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

MCL_JCV-20250507T045226Z-002.zip:   0%|          | 0.00/2.15G [00:00<?, ?B/s]

MCL_JCV-20250507T045226Z-001.zip:   0%|          | 0.00/1.16G [00:00<?, ?B/s]

MCL_JCV-20250507T045226Z-003.zip:   0%|          | 0.00/995M [00:00<?, ?B/s]

In [3]:
# --- Reproducibility / numerics ---
torch.manual_seed(0)
torch.backends.cuda.matmul.allow_tf32 = True  # permit TF32 for CUDA matmuls

# --- Where and in which precision the model and metrics run ---
torch_dtype = torch.bfloat16
device = "cuda"

# --- Model definition and the checkpoint under evaluation ---
checkpoint_path = "out_tiny_w512_exp_2/epoch=0-step=57000.ckpt"
config = OmegaConf.load("configs/tiny_natten_gan.yaml")

In [4]:
# Build the tokenizer, restore its weights from the training checkpoint and set up
# the reconstruction metrics.
tokenizer = TiTok(config)

# Wrap with torch.compile when the training config asked for it.
# NOTE(review): presumably this keeps the state-dict key layout identical to the one
# saved during (compiled) training — confirm against the training script.
if config.training.torch_compile:
    tokenizer = torch.compile(tokenizer.to(device))

# Deserialize the checkpoint once and read both entries from it (loading it twice
# doubles disk I/O and peak host memory for no benefit).
# NOTE: weights_only=False unpickles arbitrary Python objects — only load checkpoints
# from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
orig_sd = checkpoint['state_dict']
global_step = checkpoint['global_step']
del checkpoint  # drop everything else stored in the checkpoint; only the two entries above are used

model_sd = {}
for k, v in orig_sd.items():
    # loss-module weights are not part of the tokenizer itself
    if 'loss_module' not in k:
        # strip the first 6 characters of the key — presumably the attribute prefix
        # under which the training wrapper stored the tokenizer (TODO confirm)
        model_sd[k[6:]] = v

tokenizer.load_state_dict(model_sd)
tokenizer.eval().to(device, torch_dtype)

# Frame-level metrics. The evaluation feeds tensors scaled to [-1, 1], so the data
# range is stated explicitly (1 - (-1) = 2) instead of letting PSNR/SSIM infer it
# from whatever values happen to occur in the data.
eval_metrics = MetricCollection(
    {
        "psnr": PeakSignalNoiseRatio(data_range=2.0),
        "ssim": StructuralSimilarityIndexMeasure(data_range=2.0),
        "lpips": LearnedPerceptualImagePatchSimilarity(net_type='vgg').eval(),
    }
).to(device, torch_dtype)

# Video-level metric (project implementation).
fvd_metric = FVDCalculator()

In [5]:
# Decode every source clip, cut it into fixed-length chunks, reconstruct each chunk
# with the tokenizer and accumulate the metrics.
trg_res = config.dataset.resolution
trg_fps = config.dataset.frames_per_second
trg_frames = config.dataset.num_frames

# sorted(): glob order depends on the filesystem; a fixed order keeps runs comparable
src_files = sorted(glob.glob(os.path.join(val_ds_path, '*.yuv')))
if not src_files:
    # fail loudly rather than reporting metrics computed on nothing
    raise FileNotFoundError(f"No .yuv files found in '{val_ds_path}'")
num_eval = 0

# Start from clean metric state so that re-running this cell (or resuming after an
# interrupted run) does not count chunks twice.
eval_metrics.reset()
fvd_metric.reset()

for src_file in tqdm(src_files):
    # file names end in ..._<width>x<height>_<fps>.yuv
    name_parts = os.path.splitext(os.path.basename(src_file))[0].split('_')
    src_fps = name_parts[-1]
    width, height = [int(i) for i in name_parts[-2].split('x')]

    x_offset = (width - height) // 2  # horizontal offset of the centered square crop
    out, _ = (
        ffmpeg.input(src_file, format='rawvideo', pix_fmt='yuv420p', s='{}x{}'.format(width, height), framerate=src_fps)
        .crop(x=x_offset, y=0, width='ih', height='ih') # crop to square
        .filter('scale', width=trg_res, height=trg_res) # resize
        .filter('fps', trg_fps)
        .output('pipe:', format='rawvideo', pix_fmt='rgb24', v='error')
        .run(capture_stdout=True)
    )

    # `out` is an immutable bytes object; torch.frombuffer warns about non-writable
    # buffers, so hand it a writable copy instead.
    video = torch.frombuffer(bytearray(out), dtype=torch.uint8).reshape([-1, trg_res, trg_res, 3])

    # clips shorter than one chunk are skipped
    if video.shape[0] >= trg_frames:
        num_chunks = video.shape[0] // trg_frames
        # trailing frames that do not fill a whole chunk are dropped
        chunked_video = video[:num_chunks*trg_frames].reshape(-1, trg_frames, trg_res, trg_res, 3)
        chunked_video = (chunked_video.permute(0, 4, 1, 2, 3).to(device, torch_dtype) / 255) # BTHWC -> BCTHW, 0-1
        chunked_video = (chunked_video * 2) - 1.0 # -1, 1

        with torch.no_grad():
            for chunk in chunked_video: # not batching? Only ~30 vids, not worth it?
                orig = chunk.unsqueeze(0)
                recon = tokenizer(orig)[0].clamp(-1, 1)

                fvd_metric.update(real=orig, generated=recon)
                # frame metrics take (preds, target); every frame is treated as one image
                eval_metrics.update(rearrange(recon, "1 c t h w -> t c h w"), rearrange(orig, "1 c t h w -> t c h w")) # averages automatically

                num_eval += 1

  video = torch.frombuffer(out, dtype=torch.uint8).reshape([-1, trg_res, trg_res, 3])
100%|███████████████████████████████████████████████████████████████████████████████████| 30/30 [00:49<00:00,  1.66s/it]


In [6]:
# Collect the accumulated frame-level scores and add the video-level FVD to them.
eval_scores = eval_metrics.compute()
eval_scores['FVD'] = fvd_metric.gather()

# Clear both accumulators now that their values have been read out.
fvd_metric.reset()
eval_metrics.reset()

# One summary line: upper-cased metric name followed by the score with two decimals.
score_summary = ' | '.join(f"{name.upper()} {value:.2f}" for name, value in eval_scores.items())

print(f"Num eval: {num_eval}")
print("Eval scores: " + score_summary)

Num eval: 150
Eval scores: LPIPS 0.52 | PSNR 17.51 | SSIM 0.41 | FVD 793.79
