In [1]:
import torch, timm, datasets, fastprogress, io, PIL.Image, pillow_jpls
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from types import SimpleNamespace
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from piq import LPIPS, DISTS, SSIMLoss
from huggingface_hub import hf_hub_download
from torchvision.transforms.v2 import Resize
from livecodec.codec import AutoCodecND, latent_to_pil, pil_to_latent

In [2]:
# Device layout: classifier + quality metrics on `device`, the codec on `codec_device`.
# Use a second GPU for the codec when one is visible; otherwise share the first GPU
# (previously this hard-required 'cuda:1' and crashed on single-GPU machines).
device = 'cuda:0'
codec_device = 'cuda:1' if torch.cuda.device_count() > 1 else device

ssim_loss = SSIMLoss().to(device)
lpips_loss = LPIPS().to(device)
dists_loss = DISTS().to(device)
# Per-image quality metrics. Inputs are expected in [0, 1] (the eval loop maps [-1, 1] -> [0, 1]).
# All four are "higher is better": PSNR/LPIPS/DISTS are reported as -10*log10(error), which
# becomes +inf when the error is exactly 0 (handled later by drop_nan).
psnr_db = lambda x01, xhat01: -10*torch.nn.functional.mse_loss(x01,xhat01).log10().item()
ssim_01 = lambda x01, xhat01: 1 - ssim_loss(x01,xhat01).item()
lpips_db = lambda x01, xhat01: -10*lpips_loss(x01,xhat01).log10().item()
dists_db = lambda x01, xhat01: -10*dists_loss(x01,xhat01).log10().item()

checkpoint_file = hf_hub_download(
    repo_id="danjacobellis/liveaction",
    filename="lsdir_f16c48_lambdap1.pth"
)
# SECURITY: weights_only=False unpickles arbitrary Python objects (needed here because
# checkpoint['config'] is an attribute-style object, not a plain tensor dict). Only load
# checkpoints from a source you trust.
checkpoint = torch.load(checkpoint_file, map_location="cpu",weights_only=False)
cconfig = checkpoint['config']
codec = AutoCodecND(
    dim=2,
    input_channels=cconfig.input_channels,
    J = int(np.log2(cconfig.F)),  # number of 2x downsampling stages for total factor F
    latent_dim=cconfig.latent_dim,
    encoder_depth = cconfig.encoder_depth,
    encoder_kernel_size = cconfig.encoder_kernel_size,
    decoder_depth = cconfig.decoder_depth,
    lightweight_encode = cconfig.lightweight_encode,
    lightweight_decode = cconfig.lightweight_decode,
).to(codec_device)
codec.load_state_dict(checkpoint['state_dict'])
codec.eval();

config = SimpleNamespace()
config.batch_size = 16
config.num_workers = 8
config.inference_size = 224  # classifier input resolution (square)

model = timm.create_model('timm/eva_giant_patch14_224.clip_ft_in1k',pretrained=True).to(device)
# nn.Module instances start in training mode; switch to eval so any dropout / drop-path /
# norm layers behave deterministically during evaluation.
model.eval();
dataset = datasets.load_dataset('danjacobellis/imagenet_1k_val_224',split='validation')



In [3]:
def collate_fn(batch, resize):
    """Collate dataset rows into image tensors for the codec and labels for the classifier.

    Args:
        batch: list of dataset rows; each row provides a PIL image under 'crop224'
            and its class label under 'cls'.
        resize: side length in pixels of the square, bicubic-resized copy of each
            image (the codec input).

    Returns:
        (x, xr, y):
            x  -- the original images stacked along a new leading batch dim, as produced
                  by pil_to_tensor (presumably uint8 (B, C, 224, 224) for RGB crops).
            xr -- the same images resized to (resize, resize), stacked the same way.
            y  -- LongTensor of class labels, shape (B,).
    """
    labels = []
    originals = []
    resized = []
    for sample in batch:
        labels.append(sample['cls'])
        img = sample['crop224']
        originals.append(pil_to_tensor(img))
        img_resized = img.resize((resize, resize), resample=PIL.Image.Resampling.BICUBIC)
        resized.append(pil_to_tensor(img_resized))
    # torch.stack adds the batch dim; equivalent to unsqueeze(0) + torch.cat.
    return torch.stack(originals), torch.stack(resized), torch.tensor(labels, dtype=torch.long)

In [None]:
# Encoder side of the codec: encode, compand, and round to integer latents. These rounded
# latents are packed into images by latent_to_pil and losslessly coded with JPEG-LS below.
G_A = lambda x: codec.quantize.compand(codec.encode(x)).round()
# Codec input resolutions to sweep; each setting contributes one row to every *_matrix.
resize_settings = [224]

# One (1, N_samples) row per resize setting; concatenated into (n_settings, N) after the sweep.
correct_matrix = []
psnr_matrix = []
ssim_matrix = []
lpips_matrix = []
dists_matrix = []
size_matrix = []

# Loop-invariant: resample decoded images to the classifier's input size.
resize_to_inference = Resize(config.inference_size, interpolation=PIL.Image.Resampling.BICUBIC)

mb = fastprogress.master_bar(resize_settings)
for i_r, resize in enumerate(mb):
    dataloader = torch.utils.data.dataloader.DataLoader(
        dataset=dataset,
        num_workers=config.num_workers,
        # The lambda closes over `resize`; safe because the loader is fully consumed
        # within this iteration of the sweep.
        collate_fn=lambda batch: collate_fn(batch, resize),
        batch_size=config.batch_size,
        drop_last=False,
    )
    preds_c = []
    size_bytes = []
    psnr = []
    ssim = []
    lpips = []
    dists = []
    correct = 0
    total = 0
    pb = fastprogress.progress_bar(dataloader,parent=mb)

    for i_batch, (x, xr, y) in enumerate(pb):
        y = y.to(device)
        # uint8 [0, 255] -> float [-1, 1]. `x` (original) is the reference for quality
        # metrics; `xr` (resized) is the codec input.
        x = x.to(torch.float).to(device)/127.5 - 1.0
        xr = xr.to(torch.float).to(codec_device)/127.5 - 1.0
        # With drop_last=False the final batch can be smaller than config.batch_size,
        # so iterate over the actual batch size (the old range(config.batch_size)
        # would raise IndexError on a short final batch).
        n_batch = x.size(0)

        with torch.inference_mode():
            z = G_A(xr)
            compressed = latent_to_pil(z.cpu(), n_bits=8, C=3)
            sb = []
            for i_sample in range(n_batch):
                with io.BytesIO() as buff:
                    compressed[i_sample].save(buff, format='JPEG-LS')
                    sb.append(len(buff.getvalue()))  # compressed size in bytes
            size_bytes.append(sb)
            xhat = codec.decode(z).to(device)
            xhat = resize_to_inference(xhat).clamp(-1,1)

            # NOTE(review): the classifier is fed xhat in [-1, 1] (i.e. mean=std=0.5
            # normalization). Confirm this matches model.pretrained_cfg['mean'/'std'] for this
            # checkpoint; CLIP-finetuned weights may expect different statistics.
            pred = model(xhat).argmax(dim=1)
            is_correct = (pred == y)
            preds_c.append(is_correct.cpu())
            correct += is_correct.sum().item()
            total += n_batch
            pb.comment = f'Acc: {correct / total:.4f}'

            for xi, xihat in zip(x,xhat):
                # Metrics expect [0, 1] images with a batch dim.
                x01 = xi.unsqueeze(0) / 2 + 0.5
                xhat01 = xihat.unsqueeze(0) / 2 + 0.5
                psnr.append(psnr_db(x01,xhat01))
                ssim.append(ssim_01(x01,xhat01))
                lpips.append(lpips_db(x01,xhat01))
                dists.append(dists_db(x01,xhat01))
    # Computed after the loop so an empty loader does not raise NameError.
    acc = correct / total if total else float('nan')
    print(acc)
    size_matrix.append(torch.tensor(sum(size_bytes,[])).unsqueeze(0))
    correct_matrix.append(torch.cat(preds_c).unsqueeze(0))
    psnr_matrix.append(torch.tensor(psnr).unsqueeze(0))
    ssim_matrix.append(torch.tensor(ssim).unsqueeze(0))
    lpips_matrix.append(torch.tensor(lpips).unsqueeze(0))
    dists_matrix.append(torch.tensor(dists).unsqueeze(0))
size_matrix = torch.cat(size_matrix)
correct_matrix = torch.cat(correct_matrix)
psnr_matrix = torch.cat(psnr_matrix)
ssim_matrix = torch.cat(ssim_matrix)
lpips_matrix = torch.cat(lpips_matrix)
dists_matrix = torch.cat(dists_matrix)

In [None]:
def drop_nan(tensor: torch.Tensor) -> torch.Tensor:
    """Remove every column that holds a non-finite value in any row.

    A column (one sample across all settings) is kept only if all of its entries
    are finite, i.e. it contains no NaN, +inf or -inf.

    Args:
        tensor: a 2-D tensor laid out as (settings, samples).

    Returns:
        The tensor restricted to its fully finite columns.

    Raises:
        ValueError: if `tensor` is not 2-D.
    """
    if tensor.dim() != 2:
        raise ValueError("Input must be a 2D tensor.")
    # isfinite == not (isnan or isinf), so "all finite" == "no NaN and no inf".
    keep_columns = torch.isfinite(tensor).all(dim=0)
    return tensor[:, keep_columns]

# Compression ratio relative to the raw 224x224 RGB uint8 crop (bytes), averaged per setting.
cr = 224*224*3/(size_matrix.float()).mean(dim=1)
# -10*log10(error) metrics become +inf when a sample's error is exactly 0 (and NaN can
# appear from degenerate cases), so drop non-finite samples before averaging. SSIM gets the
# same treatment for consistency with the other metrics. Note each metric drops its own
# set of samples.
psnr = drop_nan(psnr_matrix).mean(dim=1)
lpips = drop_nan(lpips_matrix).mean(dim=1)
ssim = drop_nan(ssim_matrix).mean(dim=1)
dists = drop_nan(dists_matrix).mean(dim=1)
acc = correct_matrix.float().mean(dim=1)

fig, ax = plt.subplots()
# The scale/offsets below appear to be ad-hoc shifts that overlay the four metrics on one
# x-range; the x-axis therefore has no single unit.
ax.plot(psnr,acc,marker='.',label='PSNR')
ax.plot(30*ssim+3,acc,marker='.',label='SSIM')
ax.plot(lpips+17,acc,marker='.',label='LPIPS')
ax.plot(dists+15,acc,marker='.',label='DISTS')
ax.set_xlabel('Quality [various units]')
ax.set_ylabel('Accuracy')
ax.legend(loc='lower right')
print(f'CR:{cr}')
print(f'PSNR:{psnr}')
print(f'SSIM:{ssim}')
print(f'LPIPS:{lpips}')
print(f'DISTS:{dists}')

In [None]:
# Flatten the (n_settings, N) result tensors into one long table:
# one row per (resize setting, validation sample).
size_np = size_matrix.cpu().numpy()
psnr_np = psnr_matrix.cpu().numpy()
ssim_np = ssim_matrix.cpu().numpy()
lpips_np = lpips_matrix.cpu().numpy()
dists_np = dists_matrix.cpu().numpy()
correct_np = correct_matrix.cpu().numpy().astype(int)

# Column name -> per-setting array; dict order fixes the column order after 'resize'.
per_sample_columns = {
    'size_bytes': size_np,
    'psnr': psnr_np,
    'ssim': ssim_np,
    'lpips': lpips_np,
    'dists': dists_np,
    'correct': correct_np,
}
dfs = [
    pd.DataFrame({'resize': setting,
                  **{name: values[row] for name, values in per_sample_columns.items()}})
    for row, setting in enumerate(resize_settings)
]
df = pd.concat(dfs, ignore_index=True)
df.head()

In [None]:
# Publish the per-sample results table to the Hugging Face Hub.
# Network side effect: needs Hub credentials with write access to this repo.
results = datasets.Dataset.from_pandas(df=df)
results.push_to_hub(
    repo_id='danjacobellis/imagenet_224_mpq2_liveaction_f16c48_lambdap1',
    split='validation',
)

---