In [1]:
import matplotlib.pyplot as plt
import tensorly as tl
import numpy as np
from scipy.misc import face
from scipy.ndimage import zoom
from tensorly.decomposition import parafac
from tensorly.decomposition import tucker
from math import ceil
from PIL import Image
import torch
import math
from pytorch_msssim import ms_ssim
import VBMF
from torchvision import transforms

In [2]:
def compute_psnr(a, b):
    """Return the peak signal-to-noise ratio, in dB, between two tensors.

    The value is ``-10 * log10(mse)``, i.e. the peak signal value is taken to
    be 1.0, so both inputs are expected to be scaled to ``[0, 1]``.

    Args:
        a: First tensor.
        b: Second tensor; must be broadcastable against ``a``.

    Returns:
        The PSNR as a Python float, or ``math.inf`` when the tensors are
        identical (zero mean squared error).
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        # log10(0) would raise ValueError; identical inputs have unbounded PSNR.
        return math.inf
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    """Return the MS-SSIM score between ``a`` and ``b`` as a Python float.

    ``ms_ssim`` is called with ``data_range=1.``, so both inputs are expected
    to hold values in ``[0, 1]``.
    """
    score = ms_ssim(a, b, data_range=1.)
    return score.item()

def compute_bpp(out_net):
    """Return the estimated bits per pixel for a network output dictionary.

    Args:
        out_net: Mapping with an ``'x_hat'`` tensor, whose dimensions 0, 2 and 3
            are multiplied to obtain the pixel count, and a ``'likelihoods'``
            mapping of tensors.

    Returns:
        The summed negative log2-likelihood divided by the pixel count, as a
        Python float.
    """
    shape = out_net['x_hat'].size()
    num_pixels = shape[0] * shape[2] * shape[3]
    # Dividing by -ln(2) turns natural-log likelihoods into (positive) bits.
    denominator = -math.log(2) * num_pixels

    total = 0
    for likelihoods in out_net['likelihoods'].values():
        total = total + torch.log(likelihoods).sum() / denominator
    return total.item()

In [3]:
def tucker_ranks(weights):
    """Estimate ranks for the first two modes of ``weights`` using EVBMF.

    The tensor is unfolded along mode 0 and mode 1, each unfolding is wrapped
    as a torch tensor and passed to ``VBMF.EVBMF``. The second value returned
    by EVBMF is used to read off the rank (presumably a diagonal matrix of
    retained singular values -- confirm against the VBMF module).

    Args:
        weights: Tensor whose unfoldings are NumPy arrays (they are passed to
            ``torch.from_numpy``).

    Returns:
        A two-element list ``[rank_mode_0, rank_mode_1]``.
    """
    diagonals = []
    for mode in (0, 1):
        unfolded = torch.from_numpy(tl.base.unfold(weights, mode))
        _, diag, _, _ = VBMF.EVBMF(unfolded)
        diagonals.append(diag)

    mode0_diag, mode1_diag = diagonals
    return [mode0_diag.shape[0], mode1_diag.shape[1]]

In [5]:
device = "cpu"
directory = "/trinity/home/a.jha/scripts/Image_Compression/kodak_images/kodim02.png"

# Load the image, force three RGB channels and resize it.
# PIL's ``resize`` takes (width, height), so the array form is 1920 rows x 1088 columns.
image = Image.open(directory)
image = image.convert("RGB")
image = image.resize((1088,1920))

# Batched torch tensor of the original image, used later as the metric reference.
to_tensor = transforms.ToTensor()
image_for_metrics = to_tensor(image).unsqueeze(0).to(device)

# float64 tensorly tensor of the same image, used for the Tucker decomposition.
image_for_tucker = tl.tensor(image, dtype='float64')

# Estimated ranks for the first two modes of the image tensor.
ranks = tucker_ranks(image_for_tucker)

In [None]:
random_state = 12345

def to_image(tensor):
    """Convert a numeric tensor to a uint8 image by min-max rescaling.

    The values are shifted so the minimum becomes 0 and scaled so the maximum
    becomes 255, then cast to ``uint8``. Because of this rescaling the result
    is meant for display; absolute intensities are not preserved.

    Args:
        tensor: A tensorly tensor (converted with ``tl.to_numpy``).

    Returns:
        A ``numpy.uint8`` array with the same shape as ``tensor``. A constant
        tensor yields an all-zero image.
    """
    im = tl.to_numpy(tensor)
    # Out-of-place arithmetic: never writes into the caller's data and also
    # works for integer inputs, where in-place true division would raise.
    im = im - im.min()
    peak = im.max()
    if peak > 0:
        im = im / peak * 255
    # peak == 0 means a constant tensor: `im` is already all zeros, and dividing
    # by zero would only produce NaNs.
    return im.astype(np.uint8)


# Use the estimated ranks for the first two modes and keep the third mode
# (image_for_tucker.shape[2]) at full size. A new list is built instead of
# appending to `ranks`, so re-running this cell cannot keep growing the list.
tucker_rank = [*ranks, image_for_tucker.shape[2]]


# Tucker decomposition
core, tucker_factors = tucker(image_for_tucker, rank=tucker_rank, init='random', tol=10e-5, random_state=random_state)

tucker_reconstruction = tl.tucker_to_tensor((core, tucker_factors))

# Plot the original next to its Tucker reconstruction (two panels, no empty slot)
fig = plt.figure()
panels = (('original', image_for_tucker), ('Tucker', tucker_reconstruction))
for position, (title, tensor) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, len(panels), position)
    ax.set_axis_off()
    ax.imshow(to_image(tensor))
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Shapes of the Tucker core and of the first two factor matrices
tuple(part.shape for part in (core, tucker_factors[0], tucker_factors[1]))

In [None]:
from compressai.zoo import (bmshj2018_factorized, bmshj2018_hyperprior, mbt2018_mean, mbt2018, cheng2020_anchor)

In [None]:
quality = 4
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing in `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pretrained codecs to evaluate, keyed by name; uncomment entries to compare more models.
networks = {
        'cheng2020-anchor': cheng2020_anchor(quality=quality, pretrained=True).eval().to(device),
#         'bmshj2018-factorized': bmshj2018_factorized(quality=quality, pretrained=True).eval().to(device),
#         'bmshj2018-hyperprior': bmshj2018_hyperprior(quality=quality, pretrained=True).eval().to(device),
#         'mbt2018-mean': mbt2018_mean(quality=quality, pretrained=True).eval().to(device),
#         'mbt2018': mbt2018(quality=quality, pretrained=True).eval().to(device),
    }

In [None]:
from torchvision import transforms

In [None]:
# Convert the Tucker reconstruction to 8-bit by clipping to [0, 255].
# `to_image` min-max rescales its input, which shifts and stretches pixel
# intensities; that is fine for display but would bias PSNR / MS-SSIM computed
# against the original image.
rec_array = np.clip(tl.to_numpy(tucker_reconstruction), 0, 255).round().astype(np.uint8)
# The reconstruction has the same shape as the already-resized input, so no
# further resize is needed; the RGB mode is inferred from the (H, W, 3) uint8 array.
rec_img = Image.fromarray(rec_array)

# Batched tensor of the reconstruction on the inference device.
x = transforms.ToTensor()(rec_img).unsqueeze(0).to(device)
# Filled by the inference cell: network name -> network output dictionary.
outputs = {}

In [None]:
# Run every network on the Tucker-reconstructed image without tracking gradients.
with torch.no_grad():
    for name, net in networks.items():
        rv = net(x)
        # Clamp in place so the decoded image stays inside the valid [0, 1] range.
        rv['x_hat'].clamp_(0, 1)
        outputs[name] = rv
# NumPy copy of the decoded image from the last network that was run.
d_npy_xhat = rv['x_hat'].cpu().numpy()

# Decoded images as PIL images; squeeze(0) removes only the batch axis.
reconstructions = {name: transforms.ToPILImage()(out['x_hat'].squeeze(0).cpu())
                    for name, out in outputs.items()}

In [None]:
# PSNR / MS-SSIM of each decoded image, measured against the Tucker-reconstructed
# network input (`*_tuckerimage`) and against the original image (`*_origimage`).
metric_list_tuckerimage = []
metric_list_origimage = []
bpp_list = []
msssim_list_tuckerimage = []
msssim_list_origimage = []

for name, out in outputs.items():
    x_hat = out["x_hat"]
    # The reference tensors may live on a different device than the network
    # output (the original image was created on the CPU); move them to x_hat's
    # device so the element-wise metrics do not fail with a device mismatch.
    tucker_input = x.to(x_hat.device)
    original = image_for_metrics.to(x_hat.device)

    metric_list_tuckerimage.append(compute_psnr(tucker_input, x_hat))
    metric_list_origimage.append(compute_psnr(original, x_hat))

    msssim_list_tuckerimage.append(compute_msssim(tucker_input, x_hat))
    msssim_list_origimage.append(compute_msssim(original, x_hat))

    bpp_list.append(compute_bpp(out))

In [None]:
# PSNR (dB) against the Tucker-reconstructed input, bits per pixel, and PSNR (dB) against the original image
metric_list_tuckerimage, bpp_list, metric_list_origimage

In [None]:
# MS-SSIM against the Tucker-reconstructed input and against the original image
msssim_list_tuckerimage, msssim_list_origimage