In [1]:
# Description - Compression comparison between the following models:
#   - Factorized prior (balle_variational_2018)
#   - Hyperprior (balle_variational_2018)
#   - Hyperprior with a Gaussian mixture model (minnen_joint_2018)
#   - Joint autoregressive and hyperprior (minnen_joint_2018)
#   - Extension of minnen_joint_2018 with residual blocks and sub-pixel deconvolution (cheng_learned_2020)
!pip install numpy
import tensorly as tl
import math
import io
import torch
from torchvision import transforms
import numpy as np
from pytorch_msssim import ms_ssim
from PIL import Image
from compressai.zoo import (bmshj2018_factorized, bmshj2018_hyperprior, mbt2018_mean, mbt2018, cheng2020_anchor)
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets
import sys
import os
import time
from statistics import mean
from pytorch_msssim import ms_ssim
from tensorly.decomposition import parafac
from tensorly.decomposition import tucker
import VBMF

tl.set_backend('numpy')

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0[0m[39;49m -> [0m[32;49m23.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
def compute_psnr(a, b):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    The formula used is ``-10 * log10(MSE)``, i.e. it assumes a peak signal
    value of 1 (inputs scaled to ``[0, 1]``).

    Args:
        a: torch tensor.
        b: torch tensor that can be subtracted from ``a``.

    Returns:
        float: the PSNR in dB, or ``math.inf`` when the two tensors are
        identical (zero mean squared error).
    """
    mse = torch.mean((a - b)**2).item()
    if mse == 0:
        # log10(0) raises ValueError; identical inputs have unbounded PSNR.
        return math.inf
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    """Return the MS-SSIM score between ``a`` and ``b`` as a Python float.

    Delegates to ``ms_ssim`` with ``data_range=1.`` and unwraps the result
    with ``.item()``.
    """
    score = ms_ssim(a, b, data_range=1.)
    return score.item()

def compute_bpp(out_net):
    """Return the estimated bits-per-pixel of a network output.

    Args:
        out_net: mapping with an ``'x_hat'`` tensor (dimensions 0, 2 and 3 of
            its size are multiplied to get the pixel count) and a
            ``'likelihoods'`` mapping of tensors.

    Returns:
        float: sum over all likelihood tensors of
        ``sum(log(likelihoods)) / (-ln(2) * num_pixels)``.
    """
    dims = out_net['x_hat'].size()
    pixel_count = dims[0] * dims[2] * dims[3]

    total_bits = 0
    for lk in out_net['likelihoods'].values():
        # Natural log converted to bits and normalised by the pixel count.
        total_bits = total_bits + torch.log(lk).sum() / (-math.log(2) * pixel_count)
    return total_bits.item()

In [3]:
def tucker_ranks(weights):
    """Estimate ranks for the first two modes of ``weights`` using EVBMF.

    The tensor is unfolded along mode 0 and mode 1, each unfolding is
    converted to a torch tensor and passed to ``VBMF.EVBMF``. The ranks are
    read from the shape of the second value EVBMF returns.
    NOTE(review): presumably that value is a square diagonal matrix of the
    retained singular values, which would make ``shape[0]`` and ``shape[1]``
    interchangeable - confirm against the VBMF module.

    Args:
        weights: array accepted by ``tl.base.unfold`` whose unfoldings can be
            handed to ``torch.from_numpy``.

    Returns:
        list: ``[rank_mode_0, rank_mode_1]``.
    """
    mode0_matrix = torch.from_numpy(tl.base.unfold(weights, 0))
    mode1_matrix = torch.from_numpy(tl.base.unfold(weights, 1))

    _, kept_mode0, _, _ = VBMF.EVBMF(mode0_matrix)
    _, kept_mode1, _, _ = VBMF.EVBMF(mode1_matrix)

    return [kept_mode0.shape[0], kept_mode1.shape[1]]

In [4]:
def to_image(tensor):
    """A convenience function to convert from a float dtype back to uint8.

    The values are shifted so the minimum is 0, scaled so the maximum is 255,
    and cast to ``np.uint8``.

    A constant-valued input (maximum equal to minimum) is returned as all
    zeros instead of dividing by zero and casting NaNs.

    Args:
        tensor: float tensor accepted by ``tl.to_numpy``.

    Returns:
        numpy.ndarray of dtype ``uint8`` with the same shape as the input.
    """
    im = tl.to_numpy(tensor)
    im -= im.min()
    peak = im.max()
    if peak > 0:
        # Skip the division for flat inputs: 0 / 0 would fill the array with NaN.
        im /= peak
    im *= 255
    return im.astype(np.uint8)

In [5]:
def tucker_decompose(image, device="cuda", random_state=12345, size=(1088, 1920)):
    """Approximate an RGB image with a Tucker decomposition and return it as a tensor.

    The ranks of the first two modes come from ``tucker_ranks``; the third
    mode keeps its full length. The reconstruction is rescaled to uint8,
    turned into an RGB PIL image, resized, and converted to a batched tensor.

    Args:
        image: RGB image (anything ``tl.tensor`` can turn into a 3-mode
            float64 array).
        device: torch device the returned tensor is moved to.
        random_state: seed forwarded to ``tucker`` (random initialisation).
        size: target size passed to ``PIL.Image.resize``, i.e.
            ``(width, height)``. Defaults to the previously hard-coded
            ``(1088, 1920)``.

    Returns:
        torch.Tensor: the reconstructed image from ``transforms.ToTensor``
        with a leading batch dimension added, on ``device``.
    """
    imagetensor_for_tucker = tl.tensor(image, dtype='float64')
    # Build a fresh list rather than appending to the one tucker_ranks returned.
    tucker_rank = list(tucker_ranks(imagetensor_for_tucker))
    tucker_rank.append(imagetensor_for_tucker.shape[2])
    # Note: the literal 10e-5 equals 1e-4; kept unchanged to preserve results.
    core, tucker_factors = tucker(imagetensor_for_tucker, rank=tucker_rank, init='random', tol=10e-5, random_state=random_state)
    tucker_reconstruction = tl.tucker_to_tensor((core, tucker_factors))
    rec_img = to_image(tucker_reconstruction)
    rec_img = Image.fromarray(rec_img, 'RGB').resize(size)
    x = transforms.ToTensor()(rec_img).unsqueeze(0).to(device)
    return x

In [6]:
def cp_decomposition(image, device="cuda", rank=25, random_state=12345, size=(1088, 1920)):
    """Approximate an RGB image with a CP (PARAFAC) decomposition and return it as a tensor.

    The reconstruction is rescaled to uint8, turned into an RGB PIL image,
    resized, and converted to a batched tensor.

    Args:
        image: RGB image (anything ``tl.tensor`` can turn into a float64 array).
        device: torch device the returned tensor is moved to.
        rank: CP rank. Defaults to the previously hard-coded 25.
        random_state: seed forwarded to ``parafac``. The factors are randomly
            initialised, so a fixed seed makes the result reproducible; the
            default mirrors ``tucker_decompose``. Pass ``None`` for an
            unseeded run (the behaviour before this parameter existed).
        size: target size passed to ``PIL.Image.resize``, i.e.
            ``(width, height)``. Defaults to the previously hard-coded
            ``(1088, 1920)``.

    Returns:
        torch.Tensor: the reconstructed image from ``transforms.ToTensor``
        with a leading batch dimension added, on ``device``.
    """
    imagetensor_for_cp = tl.tensor(image, dtype='float64')
    # Note: the literal 10e-6 equals 1e-5; kept unchanged to preserve results.
    weights, factors = parafac(imagetensor_for_cp, rank=rank, init='random', tol=10e-6, random_state=random_state)
    cp_reconstruction = tl.cp_to_tensor((weights, factors))
    rec_image = to_image(cp_reconstruction)
    rec_image = Image.fromarray(rec_image, 'RGB').resize(size)
    x = transforms.ToTensor()(rec_image).unsqueeze(0).to(device)
    return x

In [10]:
def get_reconstructions(networks, directory, quality, compression="Tucker"):
    """Run every network on every image in ``directory`` and print mean PSNR and bpp.

    Each image is loaded as RGB and resized to ``(1088, 1920)``. Depending on
    ``compression`` it is first approximated with a Tucker or CP
    decomposition, or used unchanged, and then passed through each network.
    PSNR and MS-SSIM are measured against both the network input and the
    original image, together with the estimated bits per pixel.

    Args:
        networks: mapping of name -> callable model; each call must return a
            mapping with ``'x_hat'`` and ``'likelihoods'`` entries.
        directory: folder whose files are opened as images.
        quality: not used by this function; kept for interface compatibility.
        compression: ``"tucker"`` or ``"cpd"`` (case-insensitive) selects the
            tensor decomposition applied before the network; any other value
            feeds the original image.

    Returns:
        None. Prints the mean PSNR against the original image and the mean bpp
        over all (image, network) pairs.

    Raises:
        statistics.StatisticsError: if no (image, network) pair was evaluated
            (for example an empty directory or an empty ``networks``).

    Note:
        Reads the module-level ``device``, which must be defined before the call.
    """
    target_size = (1088, 1920)
    mode = compression.lower()

    metric_list_tuckerimage = []
    metric_list_origimage = []
    bpp_list = []
    msssim_list_tuckerimage = []
    msssim_list_origimage = []

    # Sorted so the processing order does not depend on the file system.
    for filename in sorted(os.listdir(directory)):
        path = os.path.join(directory, filename)
        if not os.path.isfile(path):
            continue  # sub-directories cannot be opened as images

        # Close the file handle as soon as the pixel data has been copied.
        with Image.open(path) as source:
            img = source.convert("RGB").resize(target_size)
        image_for_metrics = transforms.ToTensor()(img).unsqueeze(0).to(device)

        # Pass the device explicitly so the network input and the reference
        # image live on the same device (the helpers default to "cuda").
        if mode == "tucker":
            x = tucker_decompose(img, device=device)
        elif mode == "cpd":
            x = cp_decomposition(img, device=device)
        else:
            x = image_for_metrics

        with torch.no_grad():
            for name, net in networks.items():
                out = net(x)
                out['x_hat'].clamp_(0, 1)

                metric_list_tuckerimage.append(compute_psnr(x, out["x_hat"]))
                metric_list_origimage.append(compute_psnr(image_for_metrics, out["x_hat"]))
                msssim_list_tuckerimage.append(compute_msssim(x, out["x_hat"]))
                msssim_list_origimage.append(compute_msssim(image_for_metrics, out["x_hat"]))
                bpp_list.append(compute_bpp(out))

    # Only the PSNR against the original image and the bpp are reported.
    print(mean(metric_list_origimage), mean(bpp_list))