In [1]:
import math
import io
import shutil
import os
import sys
from os import path
import csv

import numpy as np
import clip
import cv2
from PIL import Image, ImageChops, ImageDraw, ImageFilter
import matplotlib.pyplot as plt
from skimage.transform import resize

import torch
from torchvision import transforms
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode

from compressai.zoo import cheng2020_attn
from torchvision import transforms
from pytorch_msssim import ms_ssim

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Empty string means "no fine-tuned checkpoint loaded": the output-naming cell
# branches on this value to choose the pretrained tag, and the fine-tuned
# loading cell overwrites it with the checkpoint epoch.
checkpoint_epoch = ''

In [10]:
BASE_IMG_DIM = 512
VAE_QUALITY = 1
VAE_IMG_DIM = int(512/1)

In [11]:
# Option A — pretrained weights: CompressAI downloads the cheng2020_attn
# checkpoint for VAE_QUALITY, then the model is switched to inference mode and
# moved to `device` (both calls modify the module in place).
VAE_model = cheng2020_attn(quality=VAE_QUALITY, pretrained=True)
VAE_model.eval()
VAE_model.to(device)

Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020_attn-mse-1-465f2b64.pth.tar" to C:\Users\CrashedBboy/.cache\torch\hub\checkpoints\cheng2020_attn-mse-1-465f2b64.pth.tar
100%|█████████████████████████████████████████████████████████████████████████████| 54.3M/54.3M [00:01<00:00, 39.9MB/s]


In [29]:
# Option B — fine-tuned weights: build an untrained cheng2020_attn and load a
# locally saved checkpoint into it. Running this cell replaces any VAE_model
# created by the pretrained cell.
# NOTE(review): FT_ARCH_QUALITY is independent of VAE_QUALITY; presumably it must
# match the configuration the checkpoint was trained with — TODO confirm.
FT_ARCH_QUALITY = 6
FT_EPOCH = 299

VAE_model = cheng2020_attn(quality=FT_ARCH_QUALITY, pretrained=False).eval().to(device)

# A non-empty checkpoint_epoch makes the evaluation cell tag outputs "ft_e<epoch>".
checkpoint_epoch = str(FT_EPOCH)
checkpoint_path = path.join('CompressAI', f'{checkpoint_epoch}_checkpoint.pth.tar')
# map_location remaps tensors onto the active device, so a checkpoint saved
# from a GPU run still loads when `device` fell back to 'cpu'.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
checkpoint = torch.load(checkpoint_path, map_location=device)
VAE_model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [4]:
def compute_bpp(out_net):
    """Estimate bits per pixel from a CompressAI forward-pass result.

    Args:
        out_net: dict with an 'x_hat' tensor (dims 0, 2 and 3 are used as
            batch, height and width) and a 'likelihoods' dict of tensors.

    Returns:
        float: total -log2(likelihood) over all likelihood tensors, divided by
        batch * height * width of 'x_hat'.
    """
    dims = out_net['x_hat'].size()
    pixel_count = dims[0] * dims[2] * dims[3]
    # Dividing ln(p) by -ln(2) converts natural-log likelihoods into bits.
    nats_to_bits = -math.log(2) * pixel_count

    total_bits = 0
    for likelihood in out_net['likelihoods'].values():
        total_bits = total_bits + torch.log(likelihood).sum() / nats_to_bits
    return total_bits.item()
def export_bpp(bpp_table, path):
    """Write a column-oriented BPP table to a CSV file.

    Args:
        bpp_table: dict mapping column name -> list of values. It must contain
            an 'ID' column, and every column must have the same length as 'ID'.
            Columns are written in the dict's key order.
        path: destination CSV path (overwritten if it exists). Inside this
            function the name shadows the notebook-level ``os.path`` alias.

    Raises:
        ValueError: if 'ID' is missing or any column's length differs from
            'ID'. Validation runs before the file is opened, so no partial
            file is left behind.
    """
    if 'ID' not in bpp_table:
        raise ValueError("bpp_table must contain an 'ID' column")

    headers = list(bpp_table.keys())
    n_rows = len(bpp_table['ID'])
    ragged = [key for key in headers if len(bpp_table[key]) != n_rows]
    if ragged:
        raise ValueError(f"columns {ragged} do not have {n_rows} rows like 'ID'")

    # Explicit UTF-8 so the output does not depend on the platform default
    # encoding; newline='' is what the csv module requires for writers.
    with open(path, 'w', newline='', encoding='utf-8') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(headers)
        # Transpose the columns into rows.
        csvwriter.writerows(zip(*(bpp_table[key] for key in headers)))

In [30]:
# Run every .png in the chosen split through the codec, save each
# reconstruction, and record its estimated bits-per-pixel.
dataset = 'valid'
# checkpoint_epoch is non-empty only after the fine-tuned checkpoint cell has run.
if checkpoint_epoch != "":
    OUTPUT_TAG = f"ft_e{checkpoint_epoch}"
else:
    OUTPUT_TAG = f'pixpat_q{VAE_QUALITY}'

input_dir = path.join('dataset', 'clic2020-professional', 'preprocessed', dataset)
output_dir = path.join('dataset', 'clic2020-professional', 'preprocessed', f'{dataset}_{OUTPUT_TAG}_d{VAE_IMG_DIM}')
os.makedirs(output_dir, exist_ok=True)

# Sorted because os.listdir order is arbitrary; this keeps the CSV rows and
# the printed log in a reproducible order.
files = sorted(f for f in os.listdir(input_dir) if path.isfile(path.join(input_dir, f)))
bpp_table = {
    "ID": [],
    "BPP": []
}
for f in files:
    img_id, ext = path.splitext(f)
    # Match on the real extension, not a substring (e.g. "a.png.bak" is skipped).
    if ext.lower() != '.png':
        continue
    with Image.open(path.join(input_dir, f)) as src:
        img = src.convert('RGB')
    # Compare sizes by value (an `is` test on ints compares object identity) and
    # resize whenever the image does not already match the codec input size.
    if img.size != (VAE_IMG_DIM, VAE_IMG_DIM):
        img = img.resize((VAE_IMG_DIM, VAE_IMG_DIM))
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)  # add a batch dim of 1
    with torch.no_grad():
        output = VAE_model(x)
        output['x_hat'].clamp_(0, 1)

    # squeeze(0) drops only the batch dim; .cpu() before conversion to PIL.
    reconstruction = transforms.ToPILImage()(output['x_hat'].squeeze(0).cpu())
    reconstruction.save(path.join(output_dir, f'{img_id}.png'))
    bpp = compute_bpp(output)
    bpp_table["ID"].append(img_id)
    bpp_table["BPP"].append(bpp)
    print(f"{img_id}: {bpp} BPP")

if not bpp_table["BPP"]:
    raise RuntimeError(f"No .png images found in {input_dir}; nothing to evaluate.")

average_bpp = sum(bpp_table["BPP"]) / len(bpp_table["BPP"])
export_bpp(bpp_table, path.join(output_dir, f'bpp_avg_{average_bpp}.csv'))
# OUTPUT_TAG distinguishes pretrained vs fine-tuned runs (VAE_QUALITY alone does not).
print(f'Model: {OUTPUT_TAG}, Quality: {VAE_QUALITY}, Dimension: {VAE_IMG_DIM}, Average: {average_bpp}')

alberto-montalesi-176097: 0.05644606426358223 BPP
alejandro-escamilla-6: 0.06585147976875305 BPP
ales-krivec-15949: 0.09637469798326492 BPP
alexander-shustov-73: 0.1207105815410614 BPP
allef-vinicius-109434: 0.04451465979218483 BPP
amy-zhang-15940: 0.05188096687197685 BPP
andrew-ruiz-376: 0.10786150395870209 BPP
benjamin-sloth-lindgreen-705: 0.05217687040567398 BPP
casey-fyfe-999: 0.020797638222575188 BPP
clem-onojeghuo-33741: 0.071669802069664 BPP
daniel-robert-405: 0.08775965869426727 BPP
davide-ragusa-716: 0.02797914668917656 BPP
dogancan-ozturan-395: 0.10308688879013062 BPP
felix-russell-saw-140699: 0.0433218777179718 BPP
gian-reto-tarnutzer-45212: 0.06157468259334564 BPP
jared-erondu-21325: 0.038278691470623016 BPP
jason-briscoe-149782: 0.04717516154050827 BPP
jeremy-cai-1174: 0.12954120337963104 BPP
juskteez-vu-1041: 0.022250622510910034 BPP
kazuend-28556: 0.03801722452044487 BPP
lobostudio-hamburg-75377: 0.030657319352030754 BPP
martin-wessely-211: 0.08249946683645248 BPP
martyn