In [1]:
import os
import io
import time
import json

import math
import numpy as np
import torch
from pytorch_msssim import ms_ssim

import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Load the test image as RGB; the context manager closes the source file once
# the converted copy has been materialized.
with Image.open("../data/assets/stmalo_fracape.png") as src_img:
    img = src_img.convert("RGB")
print(img.size)

In [None]:
def pillow_encode(img, fmt="jpeg", quality=10):
    """Encode ``img`` in memory with Pillow and decode it back.

    Args:
        img: PIL image to compress.
        fmt: Pillow format name handed to ``Image.save`` (e.g. "jpeg", "webp").
        quality: value forwarded as the ``quality`` save option.

    Returns:
        (rec, bpp): the image re-opened from the encoded bytes, and the size of
        the encoded stream in bits per pixel of ``img``.

    NOTE(review): Pillow's PNG writer (lossless) and JPEG 2000 writer (which
    uses ``quality_layers``/``quality_mode``) do not appear to honor
    ``quality`` -- confirm, otherwise every quality yields the same stream.
    """
    buffer = io.BytesIO()
    img.save(buffer, format=fmt, quality=quality)
    buffer.seek(0)

    n_bits = buffer.getbuffer().nbytes * 8.0
    width, height = img.size
    return Image.open(buffer), n_bits / (width * height)


def find_closest_bpp(target, img, fmt="jpeg"):
    """Binary-search the encoder quality whose bit rate is closest to ``target``.

    Qualities in [0, 100] are probed with ``pillow_encode``. A probe whose bit
    rate exceeds ``target`` moves the search towards lower qualities, otherwise
    towards higher ones. At most 10 probes are made, and the search stops early
    once the integer quality no longer changes. The search assumes bit rate
    grows with quality -- TODO confirm for every format used.

    Args:
        target: desired bit rate in bits per pixel.
        img: PIL image to encode.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        (rec, bpp, psnr_val) for the probe whose bit rate is closest to
        ``target``: the decoded image, its bit rate, and its PSNR [dB] w.r.t.
        ``img``.
    """
    lower = 0
    upper = 100
    prev_mid = upper

    def _psnr(a, b):
        a = np.asarray(a).astype(np.float32)
        b = np.asarray(b).astype(np.float32)
        mse = np.mean(np.square(a - b))
        # The 10e-6 (== 1e-5) offset keeps log10 finite for lossless
        # reconstructions (mse == 0), so no exception handling is needed.
        return 20 * math.log10(255.) - 10. * math.log10(mse + 10e-6)

    best = None  # (rec, bpp) of the probe closest to target so far
    for _ in range(10):
        mid = (upper - lower) / 2 + lower
        if int(mid) == int(prev_mid):
            break
        # Record the probe so the convergence check above can stop the search.
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        # Binary search ends on the last probe, not necessarily the closest one.
        if best is None or abs(bpp - target) < abs(best[1] - target):
            best = (rec, bpp)
        if bpp > target:
            upper = mid - 1
        else:
            lower = mid

    rec, bpp = best
    psnr_val = _psnr(rec, img)
    return rec, bpp, psnr_val

def find_closest_psnr(target, img, fmt="jpeg"):
    """Binary-search the encoder quality whose PSNR is closest to ``target``.

    Qualities in [0, 100] are probed with ``pillow_encode``. A probe whose PSNR
    exceeds ``target`` moves the search towards lower qualities, otherwise
    towards higher ones. At most 10 probes are made, and the search stops early
    once the integer quality no longer changes. The search assumes PSNR grows
    with quality -- TODO confirm for every format used.

    Args:
        target: desired PSNR in dB.
        img: PIL image to encode.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        (rec, bpp, psnr_val) for the probe whose PSNR is closest to ``target``.
    """
    lower = 0
    upper = 100
    prev_mid = upper

    def _psnr(a, b):
        a = np.asarray(a).astype(np.float32)
        b = np.asarray(b).astype(np.float32)
        mse = np.mean(np.square(a - b))
        # The 10e-6 (== 1e-5) offset keeps log10 finite for lossless
        # reconstructions (mse == 0), so no exception handling is needed.
        return 20 * math.log10(255.) - 10. * math.log10(mse + 10e-6)

    best = None  # (rec, bpp, psnr_val) of the probe closest to target so far
    for _ in range(10):
        mid = (upper - lower) / 2 + lower
        if int(mid) == int(prev_mid):
            break
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        psnr_val = _psnr(rec, img)
        # Binary search ends on the last probe, not necessarily the closest one.
        if best is None or abs(psnr_val - target) < abs(best[2] - target):
            best = (rec, bpp, psnr_val)
        if psnr_val > target:
            upper = mid - 1
        else:
            lower = mid
    return best

def find_closest_msssim(target, img, fmt="jpeg"):
    """Binary-search the encoder quality whose MS-SSIM is closest to ``target``.

    Qualities in [0, 100] are probed with ``pillow_encode``. A probe whose
    MS-SSIM exceeds ``target`` moves the search towards lower qualities,
    otherwise towards higher ones. At most 10 probes are made, and the search
    stops early once the integer quality no longer changes. The search assumes
    MS-SSIM grows with quality -- TODO confirm for every format used.

    Args:
        target: desired MS-SSIM (linear scale, data range 255).
        img: PIL image to encode.
        fmt: Pillow format name forwarded to ``pillow_encode``.

    Returns:
        (rec, bpp, msssim_val) for the probe whose MS-SSIM is closest to
        ``target``.
    """
    lower = 0
    upper = 100
    prev_mid = upper

    def _mssim(a, b):
        # (H, W, C) arrays -> (1, C, H, W) float tensors, as ms_ssim expects.
        a = torch.from_numpy(np.asarray(a).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
        b = torch.from_numpy(np.asarray(b).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
        return ms_ssim(a, b, data_range=255.).item()

    best = None  # (rec, bpp, msssim_val) of the probe closest to target so far
    for _ in range(10):
        mid = (upper - lower) / 2 + lower
        if int(mid) == int(prev_mid):
            break
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        msssim_val = _mssim(rec, img)
        # Binary search ends on the last probe, not necessarily the closest one.
        if best is None or abs(msssim_val - target) < abs(best[2] - target):
            best = (rec, bpp, msssim_val)
        if msssim_val > target:
            upper = mid - 1
        else:
            lower = mid
    return best

In [None]:
codecs = ["jpeg", "jpeg2000", "webp", "png"]

# Operating points every codec is matched against: a bit rate [bpp] and a PSNR [dB].
target_bpp = 0.66
target_psnr = 34.11

# Results of the bit-rate-matched search, one entry per codec (same order as `codecs`).
bpp_recs, bpp_bpps, bpp_psnrs = [], [], []
# Results of the PSNR-matched search, one entry per codec (same order as `codecs`).
psnr_recs, psnr_bpps, psnr_psnrs = [], [], []

for codec in codecs:
    rec, bpp, psnr = find_closest_bpp(target_bpp, img, codec)
    bpp_recs += [rec]
    bpp_bpps += [bpp]
    bpp_psnrs += [psnr]

    rec, bpp, psnr = find_closest_psnr(target_psnr, img, codec)
    psnr_recs += [rec]
    psnr_bpps += [bpp]
    psnr_psnrs += [psnr]

In [None]:
# Reconstructions found by the bit-rate-matched search (target_bpp).
recs, bpps, psnrs = bpp_recs, bpp_bpps, bpp_psnrs

n_codecs = len(codecs)
fig, axs = plt.subplots(n_codecs, 2, figsize=(2*6, n_codecs*5))

# One row per codec: original on the left, reconstruction on the right.
for i in range(n_codecs):
    ax_orig, ax_rec = axs[i, 0], axs[i, 1]

    ax_orig.imshow(img)
    ax_orig.title.set_text("Original")
    ax_orig.axis("off")

    ax_rec.imshow(recs[i])
    ax_rec.title.set_text(f"{codecs[i]} | PSNR: {psnrs[i]:.2f} | Bit rate: {bpps[i]:.2f} bpp")
    ax_rec.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Reconstructions found by the PSNR-matched search (target_psnr).
recs, bpps, psnrs = psnr_recs, psnr_bpps, psnr_psnrs

n_codecs = len(codecs)
fig, axs = plt.subplots(n_codecs, 2, figsize=(2*6, n_codecs*5))

# One row per codec: original on the left, reconstruction on the right.
for i in range(n_codecs):
    ax_orig, ax_rec = axs[i, 0], axs[i, 1]

    ax_orig.imshow(img)
    ax_orig.title.set_text("Original")
    ax_orig.axis("off")

    ax_rec.imshow(recs[i])
    ax_rec.title.set_text(f"{codecs[i]} | PSNR: {psnrs[i]:.2f} | Bit rate: {bpps[i]:.2f} bpp")
    ax_rec.axis("off")

plt.tight_layout()
plt.show()

In [3]:
import json
# Run directory of the baseline-codec benchmark (presumably a YYYYMMDD_HHMMSS
# timestamp). Named `run_id` so the builtin `id` is not shadowed.
run_id = "20250308_103738"
# Average Kodak metrics keyed by codec configuration (e.g. "webp_100_3").
with open(f"test_res/{run_id}/avg_metrics_kodak.json", encoding="utf-8") as f:
    codecs_avg_metrics = json.load(f)

In [4]:
coi = "webp"      # codec of interest: prefix before the first "_" of a result key
moi = "zeus-fps"  # metric of interest

# Metric value for each configuration of the codec of interest, in file order.
coi_values = {
    codec_name: codec_metrics[moi]
    for codec_name, codec_metrics in codecs_avg_metrics.items()
    if codec_name.split("_")[0] == coi
}

# Ties resolve to the first configuration encountered; with no matching
# configuration the names stay None and the values stay 0.
min_coi = min(coi_values, key=coi_values.get, default=None)
max_coi = max(coi_values, key=coi_values.get, default=None)
min_moi = 0 if min_coi is None else coi_values[min_coi]
max_moi = 0 if max_coi is None else coi_values[max_coi]

print(f"[{coi}] Min {moi}: {min_moi} ({min_coi})")
print(f"[{coi}] Max {moi}: {max_moi} ({max_coi})")

[webp] Min zeus-fps: 19.113249308068347 (webp_100_3)
[webp] Max zeus-fps: 39.60535971627532 (webp_0_2)


WebP configurations reaching the (min, max) of each metric:
- energy: 0_0, 100_4
- psnr: 0_0, 100_0
- bit rate: 0_0, 100_0
- throughput: 100_3, 0_2

In [5]:
import json
# Run directory of the KD experiment (presumably a YYYYMMDD_HHMMSS timestamp).
# Named `run_id` so the builtin `id` is not shadowed.
run_id = "20250310_134041"
with open(f"../kd_lic_experiments/test_res/{run_id}/avg_metrics_kodak.json",
          encoding="utf-8") as f:
    all_avg_metrics = json.load(f)

# Per-model average metrics of the proposed models and of the pre-trained baseline.
avg_metrics = all_avg_metrics["proposed"]
pretrained_avg_metrics = all_avg_metrics["pretrained"]

In [6]:
import numpy as np
import matplotlib.pyplot as plt

# Set-up matplotlib: drop the first colour of the default colour cycle.
# Slicing the *default* cycle (not the current rcParams value) keeps this cell
# idempotent: re-running it no longer removes one more colour each time.
plt.rcParams["axes.prop_cycle"] = plt.rcParamsDefault["axes.prop_cycle"][1:]

In [12]:
# Retrieve average metrics as lists
brs = [m["bit-rate"] for _, m in avg_metrics.items()]
pretrained_brs = [m["bit-rate"] for _, m in pretrained_avg_metrics.items()]
jpeg_brs = [m["bit-rate"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg"]
jpeg2000_brs = [m["bit-rate"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg2000"]
webp_brs = [m["bit-rate"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "webp"]

psnrs = [m["psnr"] for _, m in avg_metrics.items()]
pretrained_psnrs = [m["psnr"] for _, m in pretrained_avg_metrics.items()]
jpeg_psnrs = [m["psnr"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg"]
jpeg2000_psnrs = [m["psnr"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg2000"]
webp_psnrs = [m["psnr"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "webp"]

pretrained_msssim = [-10*np.log10(1-m["ms-ssim"]) for _, m in pretrained_avg_metrics.items()]

zeus_fps = [m["zeus-fps"] for _, m in avg_metrics.items()]
zeus_pretrained_fps = [m["zeus-fps"] for _, m in pretrained_avg_metrics.items()]
zeus_jpeg_fps = [m["zeus-fps"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg"]
zeus_jpeg2000_fps = [m["zeus-fps"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg2000"]
zeus_webp_fps = [m["zeus-fps"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "webp"]

zeus_energies = [m["zeus-energy"] for _, m in avg_metrics.items()]
zeus_pretrained_energies = [m["zeus-energy"] for _, m in pretrained_avg_metrics.items()]
zeus_jpeg_energies = [m["zeus-energy"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg"]
zeus_jpeg2000_energies = [m["zeus-energy"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "jpeg2000"]
zeus_webp_energies = [m["zeus-energy"] for codec, m in codecs_avg_metrics.items() if codec.split("_")[0] == "webp"]

In [10]:
# Plot average rate-distortion curves: PSNR (left) and log-scale MS-SSIM (right)
# versus bit rate.
fig, axs = plt.subplots(1, 2, figsize=(13, 5))

# Reference codecs as dashed curves.
axs[0].plot(pretrained_brs, pretrained_psnrs, "blue", linestyle="--",
            linewidth=1, label="pre-trained")
axs[0].plot(jpeg_brs, jpeg_psnrs, "darkgreen", linestyle="--",
            linewidth=1, label="JPEG")
# axs[0].plot(jpeg2000_brs, jpeg2000_psnrs, "purple", linestyle="--",
#             linewidth=1, label="JPEG 2000")
axs[0].plot(webp_brs, webp_psnrs, "darkorange", linestyle="--",
            linewidth=1, label="Webp")

axs[1].plot(pretrained_brs, pretrained_msssim, "blue", linestyle="--",
            linewidth=1, label="pre-trained")

# Pre-trained operating points (unlabelled: the dashed curve carries the label).
for name, m in pretrained_avg_metrics.items():
    axs[0].plot(m["bit-rate"], m["psnr"], "o", color="blue")
    axs[1].plot(m["bit-rate"], -10*np.log10(1-m["ms-ssim"]), "o", color="blue")

# Proposed models; the teacher gets a square marker.
for name, m in avg_metrics.items():
    marker = "s" if name == "teacher" else "o"
    axs[0].plot(m["bit-rate"], m["psnr"], marker, label=name)
    axs[1].plot(m["bit-rate"], -10*np.log10(1-m["ms-ssim"]), marker, label=name)

# Axis formatting does not depend on the plotted points: apply it once.
axs[0].grid(True)
axs[0].set_ylabel("PSNR [dB]")
axs[0].set_xlabel("Bit rate [bpp]")
axs[0].title.set_text("PSNR comparison")

axs[1].grid(True)
axs[1].set_ylabel("MS-SSIM [dB]")
axs[1].set_xlabel("Bit rate [bpp]")
axs[1].title.set_text("MS-SSIM (log) comparison")

axs[0].legend(loc="best")
axs[1].legend(loc="best")

fig.tight_layout()

fig.savefig("codecs_rd.png")
plt.close(fig)

In [11]:
# Plot zeus-energy and RD performance: PSNR (left) and bit rate (right) versus
# inference energy.
fig, axs = plt.subplots(1, 2, figsize=(13, 5))

# Reference codecs as dashed curves.
axs[0].plot(zeus_pretrained_energies, pretrained_psnrs, "blue", linestyle="--",
            linewidth=1, label="pre-trained")
axs[0].plot(zeus_jpeg_energies, jpeg_psnrs, "darkgreen", linestyle="--",
            linewidth=1, label="JPEG")
# axs[0].plot(zeus_jpeg2000_energies, jpeg2000_psnrs, "purple", linestyle="--",
#             linewidth=1, label="JPEG 2000")
axs[0].plot(zeus_webp_energies, webp_psnrs, "darkorange", linestyle="--",
            linewidth=1, label="Webp")

axs[1].plot(zeus_pretrained_energies, pretrained_brs, "blue", linestyle="--",
            linewidth=1, label="pre-trained")
axs[1].plot(zeus_jpeg_energies, jpeg_brs, "darkgreen", linestyle="--",
            linewidth=1, label="JPEG")
# axs[1].plot(zeus_jpeg2000_energies, jpeg2000_brs, "purple", linestyle="--",
#             linewidth=1, label="JPEG 2000")
axs[1].plot(zeus_webp_energies, webp_brs, "darkorange", linestyle="--",
            linewidth=1, label="Webp")

# Pre-trained operating points (unlabelled: the dashed curve carries the label).
for name, m in pretrained_avg_metrics.items():
    axs[0].plot(m["zeus-energy"], m["psnr"], "o", color="blue")
    axs[1].plot(m["zeus-energy"], m["bit-rate"], "o", color="blue")

# Proposed models; the teacher gets a square marker.
for name, m in avg_metrics.items():
    marker = "s" if name == "teacher" else "o"
    axs[0].plot(m["zeus-energy"], m["psnr"], marker, label=name)
    axs[1].plot(m["zeus-energy"], m["bit-rate"], marker, label=name)

# Axis formatting does not depend on the plotted points: apply it once.
axs[0].grid(True)
axs[0].set_ylabel("PSNR [dB]")
axs[0].set_xlabel("Inference energy [mJ/frame]")
axs[0].title.set_text("PSNR comparison")

axs[1].grid(True)
axs[1].set_ylabel("Bit rate [bpp]")
axs[1].set_xlabel("Inference energy [mJ/frame]")
axs[1].title.set_text("Bit rate comparison")

axs[0].legend(loc="best")
axs[1].legend(loc="best")

fig.tight_layout()

fig.savefig("codecs_energy.png")
plt.close(fig)

In [13]:
# Plot zeus-fps and RD performance: PSNR (left) and bit rate (right) versus
# throughput.
fig, axs = plt.subplots(1, 2, figsize=(13, 5))

# Reference codecs as dashed curves.
axs[0].plot(zeus_pretrained_fps, pretrained_psnrs, "blue",
            linestyle="--", linewidth=1, label="pre-trained")
axs[0].plot(zeus_jpeg_fps, jpeg_psnrs, "darkgreen", linestyle="--",
            linewidth=1, label="JPEG")
# axs[0].plot(zeus_jpeg2000_fps, jpeg2000_psnrs, "purple", linestyle="--",
#             linewidth=1, label="JPEG 2000")
axs[0].plot(zeus_webp_fps, webp_psnrs, "darkorange", linestyle="--",
            linewidth=1, label="Webp")

axs[1].plot(zeus_pretrained_fps, pretrained_brs, "blue",
            linestyle="--", linewidth=1, label="pre-trained")
axs[1].plot(zeus_jpeg_fps, jpeg_brs, "darkgreen", linestyle="--",
            linewidth=1, label="JPEG")
# axs[1].plot(zeus_jpeg2000_fps, jpeg2000_brs, "purple", linestyle="--",
#             linewidth=1, label="JPEG 2000")
axs[1].plot(zeus_webp_fps, webp_brs, "darkorange", linestyle="--",
            linewidth=1, label="Webp")

# Pre-trained operating points (unlabelled: the dashed curve carries the label).
for name, m in pretrained_avg_metrics.items():
    axs[0].plot(m["zeus-fps"], m["psnr"], "o", color="blue")
    axs[1].plot(m["zeus-fps"], m["bit-rate"], "o", color="blue")

# Proposed models; the teacher gets a square marker.
for name, m in avg_metrics.items():
    marker = "s" if name == "teacher" else "o"
    axs[0].plot(m["zeus-fps"], m["psnr"], marker, label=name)
    axs[1].plot(m["zeus-fps"], m["bit-rate"], marker, label=name)

# Axis formatting does not depend on the plotted points: apply it once.
axs[0].grid(True)
axs[0].set_ylabel("PSNR [dB]")
axs[0].set_xlabel("Throughput [FPS]")
axs[0].title.set_text("PSNR comparison")

axs[1].grid(True)
axs[1].set_ylabel("Bit rate [bpp]")
axs[1].set_xlabel("Throughput [FPS]")
axs[1].title.set_text("Bit rate comparison")

axs[0].legend(loc="best")
axs[1].legend(loc="best")

fig.tight_layout()

fig.savefig("codecs_fps.png")
plt.close(fig)