In [29]:
import os
print(os.getcwd())
import torch
import io
from torchvision import transforms
import numpy as np
from PIL import Image
from compressai.zoo import bmshj2018_factorized
from ipywidgets import interact, widgets
from functools import partial
import math
from pytorch_msssim import ms_ssim

device="cuda" if torch.cuda.is_available() else "cpu"
print("device:",device)


#get the image and load the image to get the size of the image
path="./assets/drone.png"
img=Image.open(path).convert("RGB")
disc=os.path.getsize(path)/1000000
print("size on disc:",disc,"MB")

#transform the image
x=transforms.ToTensor()(img).unsqueeze(0).to(device)

#load the model
net=bmshj2018_factorized(quality=1,pretrained=True).eval().to(device)
#print(net.parameters)

#run the network
with torch.no_grad():
   out_net=net.forward(x)
out_net['x_hat'].clamp(0,1)
print(out_net.keys())

def pillow_encode(img,fmt="jpeg",quality=10):
    tmp=io.BytesIO()
    img.save(tmp,format=fmt,quality=quality)
    filesize=tmp.getbuffer().nbytes
    bpp=filesize*float(8)/(img.size[0]*img.size[1])
    rec=Image.open(tmp)
    return rec,bpp


#metrics
def compute_psnr(a, b):
    mse = torch.mean((a - b)**2).item()
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    return ms_ssim(a, b, data_range=1.).item()

def compute_bpp(out_net):
    size = out_net['x_hat'].size()
    print(size)
    num_pixels = size[0] * size[2] * size[3]
    return sum(torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
              for likelihoods in out_net['likelihoods'].values()).item()

#metrics results
print(f'PSNR: {compute_psnr(x, out_net["x_hat"]):.2f}dB')
print(f'MS-SSIM: {compute_msssim(x, out_net["x_hat"]):.4f}')
print(f'Bit-rate: {compute_bpp(out_net):.3f} bpp')


# finding the closest bpp , and the quality metrics 
def find_closest_bpp(target, img, fmt='jpeg'):
    lower = 0
    upper = 100
    prev_mid = upper
    for i in range(10):
        mid = (upper - lower) / 2 + lower
        if int(mid) == int(prev_mid):
            #print("break",mid)
            break
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        #print("quality:",mid,"bpp:",bpp)
        if bpp > target:
            upper = mid - 1
        else:
            lower = mid
    print("final quality:",mid,"format:",fmt)
    print("Final bpp:",bpp)
    return rec, bpp

def find_closest_psnr(target, img, fmt='jpeg'):
    lower = 0
    upper = 100
    prev_mid = upper
    
    def _psnr(a, b):
        a = np.asarray(a).astype(np.float32)
        b = np.asarray(b).astype(np.float32)
        mse = np.mean(np.square(a - b))
        return 20*math.log10(255.) -10. * math.log10(mse)
    
    for i in range(10):
        mid = (upper - lower) / 2 + lower
        if int(mid) == int(prev_mid):
            break
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        psnr_val = _psnr(rec, img)
        if psnr_val > target:
            upper = mid - 1
        else:
            lower = mid

    print("Closest psnr finding function")
    print("rec:",rec,"bpp:",bpp,"psnr_Val:",psnr_val)
    return rec, bpp, psnr_val

def find_closest_msssim(target, img, fmt='jpeg'):
    lower = 0
    upper = 100
    prev_mid = upper
    
    def _mssim(a, b):
        a = torch.from_numpy(np.asarray(a).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
        b = torch.from_numpy(np.asarray(b).astype(np.float32)).permute(2, 0, 1).unsqueeze(0)
        return ms_ssim(a, b, data_range=255.).item()

    for i in range(10):
        mid = (upper - lower) / 2 + lower
        if int(mid) == int(prev_mid):
            break
        prev_mid = mid
        rec, bpp = pillow_encode(img, fmt=fmt, quality=int(mid))
        msssim_val = _mssim(rec, img)
        if msssim_val > target:
            upper = mid - 1
        else:
            lower = mid
    return rec, bpp, msssim_val


#target bpp
target_bpp=compute_bpp(out_net)
rec_jpeg,bpp_jpeg=find_closest_bpp(target_bpp,img)


C:\Swapnil\Narrowband_DRONE\Image_compression_code\Method_7_compress_ai\compressai\examples
device: cuda
size on disc: 14.754594 MB
dict_keys(['x_hat', 'likelihoods'])
PSNR: 25.42dB
MS-SSIM: 0.9184
torch.Size([1, 3, 2464, 3280])
Bit-rate: 0.161 bpp
torch.Size([1, 3, 2464, 3280])
final quality: -0.900390625 format: jpeg
Final bpp: 0.18458386126069054
