In [1]:
import torch, io, datasets, PIL.Image,  numpy as np, time
from huggingface_hub import snapshot_download
from types import SimpleNamespace
from piq import LPIPS, DISTS, SSIMLoss
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor
from cosmos_tokenizer.image_lib import ImageTokenizer

dataset = datasets.load_dataset("danjacobellis/kodak")
model_path = snapshot_download(repo_id='nvidia/Cosmos-Tokenizer-DI8x8')
encoder = ImageTokenizer(checkpoint_enc=f'{model_path}/encoder.jit')
decoder = ImageTokenizer(checkpoint_dec=f'{model_path}/decoder.jit')

lpips_loss = LPIPS()
dists_loss = DISTS()
ssim_loss = SSIMLoss()
print(sum( p.numel() for p in encoder.parameters())/1e6)

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]



32.261254


In [2]:
def evaluate_throughput(sample, device='cuda', dtype=torch.bfloat16):
    encoder.to(device)
    decoder.to(device)
    img = sample['image']
    img = img.resize((int(2.5*img.size[0]),int(2.5*img.size[1]))) # 1080p
    x_orig = pil_to_tensor(img).to(device).unsqueeze(0).to(dtype) / 127.5 - 1.0
    orig_size = tuple(x_orig.shape[-2:])
    orig_dim = x_orig.numel() 
    
    # analysis transform
    t0 = time.time()
    with torch.no_grad():
        z = encoder.encode(x_orig)[0]
    analysis_time = time.time() - t0
    
    # entropy coding
    t0 = time.time()
    torch.save(z,'tmp.pth')
    entropy_code_time = time.time() - t0
    
    # entropy decoding
    t0 = time.time()
    z = torch.load('tmp.pth')
    entropy_decode_time = time.time() - t0
    
    # synthesis transform
    t0 = time.time()
    with torch.no_grad():
        x_hat = decoder.decode(z).to(torch.float).clamp(-1,1)
    synthesis_time = time.time() - t0

    return {
        'analysis_time': analysis_time,
        'entropy_code_time': entropy_code_time,
        'entropy_decode_time': entropy_decode_time,
        'synthesis_time': synthesis_time,
    }

In [3]:
for (device,dtype) in [('cuda',torch.bfloat16)]:
    results_dataset = dataset['validation'].map(lambda s: evaluate_throughput(s,device=device, dtype=dtype))
    print("mean\n---")
    for metric in [
        'analysis_time',
        'entropy_code_time',
        'entropy_decode_time',
        'synthesis_time',
    ]:
        μ = 1920*1080e-6/np.mean(results_dataset[metric])
        print(μ)
    print(1920*1080e-6/np.mean(results_dataset['analysis_time']+results_dataset['entropy_code_time']))
    print(1920*1080e-6/np.mean(results_dataset['entropy_decode_time']+results_dataset['synthesis_time']))



Map:   0%|          | 0/24 [00:00<?, ? examples/s]

mean
---
14.866209758584366
4959.735080207195
5801.106402801401
11.253463645003372
29.643566489678058
22.463351032885953
