In [None]:
import os
import io
import time

import math
import numpy as np
import torch
from pytorch_msssim import ms_ssim

import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Load the reference image and force RGB so every codec receives the same 3-channel input.
img = Image.open("../data/assets/stmalo_fracape.png").convert(mode="RGB")
# (width, height) in pixels
print(img.size)

In [None]:
def pillow_encode(img, fmt="jpeg", quality=10):
    """Compress ``img`` in memory with Pillow, then re-open the encoded bytes.

    Args:
        img: PIL image to compress.
        fmt: Pillow format name forwarded to ``Image.save``.
        quality: forwarded as the ``quality`` save option.
            NOTE(review): Pillow's PNG and JPEG 2000 writers appear not to use
            ``quality`` (JPEG 2000 is driven by ``quality_layers``), so for
            those formats the output likely does not vary with it — confirm.

    Returns:
        ``(rec, bpp)``: the image opened from the encoded bytes, and the encoded
        size in bits per pixel of ``img``.
    """
    buffer = io.BytesIO()
    img.save(buffer, format=fmt, quality=quality)

    width, height = img.size
    n_bits = 8 * float(len(buffer.getvalue()))

    # Rewind so Image.open reads the encoded stream from its start.
    buffer.seek(0)
    return Image.open(buffer), n_bits / (width * height)


def find_closest_bpp(target, img, fmt="jpeg"):
    """Find the encoder quality whose bit rate is closest to ``target``.

    Bisects the integer quality range [0, 100]. Each probe encodes ``img`` with
    ``pillow_encode``, and the probe whose bit rate is nearest to ``target`` is
    kept. When bit rate is non-decreasing in quality, the two qualities that
    bracket ``target`` are both probed, so the closer of them is returned.

    Args:
        target: desired bit rate in bits per pixel.
        img: PIL image to encode; also the PSNR reference.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        ``(rec, bpp, psnr)`` for the closest probe: the reconstruction, its bit
        rate, and its PSNR in dB against ``img``.
    """
    def _psnr(a, b):
        """PSNR in dB between two images, using a peak value of 255."""
        a = np.asarray(a).astype(np.float32)
        b = np.asarray(b).astype(np.float32)
        mse = np.mean(np.square(a - b))
        # The 10e-6 (= 1e-5) offset keeps the log argument strictly positive,
        # so this cannot raise and no exception handling is needed.
        return 20 * math.log10(255.) - 10. * math.log10(mse + 10e-6)

    lower, upper = 0, 100
    best = None  # (|bpp - target|, rec, bpp) of the closest probe so far
    while lower <= upper:
        quality = (lower + upper) // 2
        rec, bpp = pillow_encode(img, fmt=fmt, quality=quality)
        distance = abs(bpp - target)
        if best is None or distance < best[0]:
            best = (distance, rec, bpp)
        # Overshooting the budget means lower qualities should be tried next.
        if bpp > target:
            upper = quality - 1
        else:
            lower = quality + 1

    _, rec, bpp = best
    return rec, bpp, _psnr(rec, img)

def find_closest_psnr(target, img, fmt="jpeg"):
    """Find the encoder quality whose reconstruction PSNR is closest to ``target``.

    Bisects the integer quality range [0, 100]. Each probe encodes ``img`` with
    ``pillow_encode`` and measures the PSNR of the reconstruction, and the probe
    whose PSNR is nearest to ``target`` is kept (rather than simply the last
    probe).

    Args:
        target: desired PSNR in dB.
        img: PIL image to encode; also the PSNR reference.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        ``(rec, bpp, psnr)`` for the closest probe: the reconstruction, its bit
        rate in bits per pixel, and its PSNR in dB.
    """
    def _psnr(a, b):
        """PSNR in dB between two images, using a peak value of 255."""
        a = np.asarray(a).astype(np.float32)
        b = np.asarray(b).astype(np.float32)
        mse = np.mean(np.square(a - b))
        # The 10e-6 (= 1e-5) offset keeps the log argument strictly positive,
        # so this cannot raise and no exception handling is needed.
        return 20 * math.log10(255.) - 10. * math.log10(mse + 10e-6)

    lower, upper = 0, 100
    best = None  # (|psnr - target|, rec, bpp, psnr) of the closest probe so far
    while lower <= upper:
        quality = (lower + upper) // 2
        rec, bpp = pillow_encode(img, fmt=fmt, quality=quality)
        psnr_val = _psnr(rec, img)
        distance = abs(psnr_val - target)
        if best is None or distance < best[0]:
            best = (distance, rec, bpp, psnr_val)
        # Assumes PSNR grows with quality: overshooting means try lower qualities.
        if psnr_val > target:
            upper = quality - 1
        else:
            lower = quality + 1

    _, rec, bpp, psnr_val = best
    return rec, bpp, psnr_val

def find_closest_msssim(target, img, fmt="jpeg"):
    """Find the encoder quality whose reconstruction MS-SSIM is closest to ``target``.

    Bisects the integer quality range [0, 100]. Each probe encodes ``img`` with
    ``pillow_encode`` and scores the reconstruction with ``ms_ssim``, and the
    probe whose score is nearest to ``target`` is kept (rather than simply the
    last probe).

    Args:
        target: desired MS-SSIM value.
        img: PIL image to encode; also the MS-SSIM reference.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        ``(rec, bpp, msssim)`` for the closest probe: the reconstruction, its bit
        rate in bits per pixel, and its MS-SSIM against ``img``.
    """
    def _mssim(a, b):
        """MS-SSIM between two images with data_range 255."""
        # permute(2, 0, 1) assumes each array is (H, W, C); the result is (1, C, H, W).
        a = torch.from_numpy(np.asarray(a).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
        b = torch.from_numpy(np.asarray(b).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
        return ms_ssim(a, b, data_range=255.).item()

    lower, upper = 0, 100
    best = None  # (|msssim - target|, rec, bpp, msssim) of the closest probe so far
    while lower <= upper:
        quality = (lower + upper) // 2
        rec, bpp = pillow_encode(img, fmt=fmt, quality=quality)
        msssim_val = _mssim(rec, img)
        distance = abs(msssim_val - target)
        if best is None or distance < best[0]:
            best = (distance, rec, bpp, msssim_val)
        # Assumes MS-SSIM grows with quality: overshooting means try lower qualities.
        if msssim_val > target:
            upper = quality - 1
        else:
            lower = quality + 1

    _, rec, bpp, msssim_val = best
    return rec, bpp, msssim_val

In [None]:
# Quick sanity check: one encode/decode round trip at the default quality.
rec, bpp = pillow_encode(img, fmt="jpeg")
print(type(rec))

In [None]:
codecs = ["jpeg", "jpeg2000", "webp", "png"]

# Two operating points to match per codec: a fixed bit rate (bpp) and a fixed PSNR (dB).
target_bpp = 0.66
target_psnr = 34.11

# Per-codec results when matching the bit-rate target...
bpp_recs, bpp_bpps, bpp_psnrs = [], [], []
# ...and when matching the PSNR target.
psnr_recs, psnr_bpps, psnr_psnrs = [], [], []

for codec in codecs:
    # Closest achievable bit rate: record the reconstruction, its bpp and its PSNR.
    rec, bpp, psnr = find_closest_bpp(target_bpp, img, codec)
    bpp_recs.append(rec)
    bpp_bpps.append(bpp)
    bpp_psnrs.append(psnr)

    # Closest achievable PSNR, recorded the same way.
    rec, bpp, psnr = find_closest_psnr(target_psnr, img, codec)
    psnr_recs.append(rec)
    psnr_bpps.append(bpp)
    psnr_psnrs.append(psnr)

In [None]:
recs = bpp_recs
bpps = bpp_bpps
psnrs = bpp_psnrs

# One row per codec: the original on the left, the bit-rate-matched reconstruction on the right.
# squeeze=False keeps `axs` 2-D even when `codecs` has a single entry.
fig, axs = plt.subplots(len(codecs), 2, figsize=(2 * 6, len(codecs) * 5), squeeze=False)

for i, codec in enumerate(codecs):
    ax_orig, ax_rec = axs[i]

    ax_orig.imshow(img)
    ax_orig.title.set_text("Original")
    ax_orig.axis("off")

    ax_rec.imshow(recs[i])
    ax_rec.title.set_text(f"{codec} | PSNR: {psnrs[i]:.2f} | Bit rate: {bpps[i]:.2f} bpp")
    ax_rec.axis("off")

plt.tight_layout()
plt.show()

In [None]:
recs = psnr_recs
bpps = psnr_bpps
psnrs = psnr_psnrs

# One row per codec: the original on the left, the PSNR-matched reconstruction on the right.
# squeeze=False keeps `axs` 2-D even when `codecs` has a single entry.
fig, axs = plt.subplots(len(codecs), 2, figsize=(2 * 6, len(codecs) * 5), squeeze=False)

for i, codec in enumerate(codecs):
    ax_orig, ax_rec = axs[i]

    ax_orig.imshow(img)
    ax_orig.title.set_text("Original")
    ax_orig.axis("off")

    ax_rec.imshow(recs[i])
    ax_rec.title.set_text(f"{codec} | PSNR: {psnrs[i]:.2f} | Bit rate: {bpps[i]:.2f} bpp")
    ax_rec.axis("off")

plt.tight_layout()
plt.show()