In [1]:
import torch, io, datasets, PIL.Image,  numpy as np, json
from throughput.image import wallclock
from huggingface_hub import hf_hub_download
from types import SimpleNamespace
from walloc import walloc
from walloc.walloc import latent_to_pil, pil_to_latent
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor, resize

# Benchmark input: the Kodak image set from the Hugging Face hub.
dataset = datasets.load_dataset("danjacobellis/kodak")

# Download the codec hyper-parameters (JSON) and expose them as attributes.
config_file = hf_hub_download(
    repo_id="danjacobellis/walloc",
    filename="RGB_4x.json"
)
# Use a context manager so the config file handle is closed deterministically
# instead of being left for the garbage collector.
with open(config_file, encoding="utf-8") as config_fp:
    codec_config = SimpleNamespace(**json.load(config_fp))

# Download the matching pretrained weights.
checkpoint_file = hf_hub_download(
    repo_id="danjacobellis/walloc",
    filename="RGB_4x.pth"
)
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. That is acceptable only while the
# checkpoint source is trusted; consider weights_only=True if this checkpoint
# loads that way.
checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=False)

# Build the codec from the downloaded configuration, then restore its weights.
codec = walloc.Codec2D(
    channels=codec_config.channels,
    J=codec_config.J,
    Ne=codec_config.Ne,
    Nd=codec_config.Nd,
    latent_dim=codec_config.latent_dim,
    latent_bits=codec_config.latent_bits,
    lightweight_encode=codec_config.lightweight_encode,
)
codec.load_state_dict(checkpoint['model_state_dict'])

# Floating-point type used for every tensor fed to the codec below.
dtype = torch.float

  from pkg_resources import resource_stream


In [None]:
# Benchmark an encode -> store -> decode round trip on each device, timing the
# individual stages with `wallclock`, then show the last reconstruction and
# print the timing summary for that device.
for device in ['cpu', 'xpu']:
    wallclock.reset()
    codec.to(device)

    for _ in range(5):
        for sample in dataset['validation']:
            img = sample['image']
            # Add a batch dimension and rescale pixel values by /127.5 - 1
            # (maps 0..255 onto -1..1).
            x = pil_to_tensor(img).unsqueeze(0).to(dtype) / 127.5 - 1.0
            x = x.to(device)

            with wallclock('encode'):
                # NOTE(review): on an asynchronous device the kernels launched
                # here may only finish during the `.cpu()` copy below, which
                # would shift time from 'analysis' to 'transfer' unless
                # `wallclock` synchronizes the device -- confirm.
                with wallclock('analysis'):
                    with torch.inference_mode():
                        x = codec.encoder(codec.wavelet_analysis(x, J=codec.J))

                with wallclock('transfer'):
                    x = x.cpu()

                with wallclock('store'):
                    # Pack the latent into an image and compress it losslessly
                    # into an in-memory WEBP buffer.
                    x = latent_to_pil(x, n_bits=8, C=3)[0]
                    buff = io.BytesIO()
                    x.save(buff, format='WEBP', lossless=True)

            with wallclock('decode'):
                # Read the WEBP back into a latent tensor on the target device.
                x = pil_to_latent(
                    [PIL.Image.open(buff)],
                    N=codec_config.latent_dim,
                    n_bits=8,
                    C=3,
                ).to(device).to(dtype)
                # inference_mode matches the encode path; no gradients are
                # needed anywhere in this benchmark.
                with torch.inference_mode():
                    xhat = codec.wavelet_synthesis(codec.decoder(x), J=codec.J).clamp(-1, 1)
            # Only the first validation image is processed; the outer loop
            # repeats that same image five times.
            break

    # Show the reconstruction (not the latent `x`), mapped from [-1, 1] back
    # to [0, 1] as expected by to_pil_image for float tensors.
    display(to_pil_image(xhat[0] / 2 + 0.5))
    # NOTE(review): the pixel count is hard-coded to 512*768; verify it matches
    # the images actually being benchmarked.
    wallclock.summary(pixels=512*768)