In [1]:
import io
import os
import copy
import yaml
import math
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from pytorch_msssim import ms_ssim

In [2]:
import sys
import torch
sys.path.append("..")
from models.nic_cvt import NIC
from utils import get_config

#from .pretrained import load_pretrained
# Public factory functions exported by this cell.
__all__ = ["nic_tic"]

# Architecture name -> model class; `_load_model` looks classes up here.
model_architectures = dict(nic=NIC)

# Second name -> class table with the same contents; no other code visible in
# this notebook reads it.
models = dict(nic=NIC)

# Local checkpoint layout: <root_url>/<metric>/nic_<metric>_<quality>.pt
root_url = "./ckpts/Lu2022"

# architecture -> metric -> quality level (1..8) -> checkpoint path.
model_urls = {
    "nic": {
        metric: {
            quality: f"{root_url}/{metric}/nic_{metric}_{quality}.pt"
            for quality in range(1, 9)
        }
        for metric in ("mse", "ms-ssim")
    },
}

# architecture -> quality level -> (embed_dim, latent_dim), as consumed by
# `_load_model`. Qualities 1-4 share the small setting, 5-8 the large one.
cfgs = {
    "nic": {
        **dict.fromkeys(range(1, 5), (128, 192)),
        **dict.fromkeys(range(5, 9), (192, 320)),
    },
}

def _load_model(
    architecture, metric, quality, pretrained=False, progress=True, **kwargs
):
    """Instantiate a model and optionally load its local pre-trained weights.

    The model is always built from ``config.yaml`` with ``embed_dim`` and
    ``latent_dim`` overridden by ``cfgs[architecture][quality]``. Earlier
    revisions returned ``None`` when ``pretrained`` was false; the model is now
    returned in both cases (with freshly initialised weights when not
    pre-trained).

    Args:
        architecture (str): Key into ``model_architectures`` and ``cfgs``.
        metric (str): Key into ``model_urls[architecture]``; only consulted
            when ``pretrained`` is true.
        quality (int): Key into ``cfgs[architecture]``.
        pretrained (bool): If true, load the checkpoint listed in
            ``model_urls`` into the model.
        progress: Unused; kept for interface compatibility.
        **kwargs: Unused; kept for interface compatibility.

    Returns:
        The instantiated model (on CPU; callers move it to their device).

    Raises:
        ValueError: Unknown ``architecture`` or ``quality``.
        RuntimeError: ``pretrained`` is true but no checkpoint is registered
            for the requested architecture/metric/quality.
    """
    if architecture not in model_architectures:
        raise ValueError(f'Invalid architecture name "{architecture}"')

    if quality not in cfgs[architecture]:
        raise ValueError(f'Invalid quality value "{quality}"')

    state_dict = None
    if pretrained:
        if (
            architecture not in model_urls
            or metric not in model_urls[architecture]
            or quality not in model_urls[architecture][metric]
        ):
            raise RuntimeError("Pre-trained model not yet available")

        url = model_urls[architecture][metric][quality]
        print("Loading Ckpts From:", url)
        # Deserialize onto CPU: the weights are copied into a CPU-constructed
        # model below, so this no longer depends on a `device` global that is
        # only defined in a later notebook cell.
        state_dict = torch.load(url, map_location="cpu")

    embed_dim, latent_dim = cfgs[architecture][quality]
    config = get_config("config.yaml")
    config['embed_dim'] = embed_dim
    config['latent_dim'] = latent_dim
    model = model_architectures[architecture](config)

    if state_dict is not None:
        # The checkpoint stores the weights under the 'model' key.
        model.load_state_dict(state_dict['model'])

        # TODO: should be put in training loop
        # NOTE(review): presumably refreshes the entropy-model tables after
        # loading weights -- confirm against the model class.
        model.update()

    return model

def nic_tic(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Factory for the "nic" architecture.

    Neural image compression framework from Lu, Ming and Guo, Peiyao and
    Shi, Huiqing and Cao, Chuntong and Ma, Zhan:
    `"Transformer-based Image Compression" <https://arxiv.org/abs/2111.06707>`,
    (DCC 2022).

    Args:
        quality (int): Quality level, 1 (lowest) to 8 (highest).
        metric (str): Optimized metric, one of ``"mse"`` or ``"ms-ssim"``.
        pretrained (bool): If True, returns a pre-trained model.
        progress: Forwarded unchanged to ``_load_model``.
        **kwargs: Forwarded unchanged to ``_load_model``.

    Raises:
        ValueError: If ``metric`` is not supported or ``quality`` is outside
            the 1..8 range.
    """
    supported_metrics = ("mse", "ms-ssim")
    if metric not in supported_metrics:
        raise ValueError(f'Invalid metric "{metric}"')

    out_of_range = quality > 8 or quality < 1
    if out_of_range:
        raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)')

    return _load_model(
        "nic",
        metric,
        quality,
        pretrained,
        progress,
        **kwargs,
    )

In [3]:
from quantization import *

In [4]:

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

device

'cuda'

In [5]:
from pytorch_msssim import ms_ssim

def compute_psnr(a, b):
    """Return the PSNR in dB between tensors ``a`` and ``b``.

    The value is ``-10 * log10(mean((a - b) ** 2))``, i.e. PSNR for a peak
    signal value of 1.0.

    Returns:
        float: PSNR in dB, or ``inf`` when the inputs are identical (zero
        mean squared error), instead of failing on ``log10(0)``.
    """
    mse = torch.mean((a - b)**2).item()
    if mse == 0:
        # Identical inputs: log10(0) would raise ValueError.
        return float('inf')
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    """Return MS-SSIM between ``a`` and ``b`` on a dB scale.

    MS-SSIM is computed with ``data_range=1.`` and converted as
    ``-10 * log10(1 - msssim)``.

    Returns:
        float: MS-SSIM in dB, or ``inf`` when the score reaches 1.0 (e.g.
        identical inputs), instead of failing on ``log10`` of a
        non-positive number.
    """
    msssim = ms_ssim(a, b, data_range=1.).item()
    if msssim >= 1:
        # 1 - msssim would be <= 0, for which log10 raises ValueError.
        return float('inf')
    return -10 * math.log10(1-msssim)

def compute_bpp(out_net):
    """Return the bit-rate, in bits per pixel, implied by the model output.

    Args:
        out_net (dict): Must hold ``'x_hat'`` (a tensor whose dims 0, 2 and 3
            give the pixel count) and ``'likelihoods'`` (a dict of tensors).

    Returns:
        float: Sum over all likelihood tensors of ``-log2(likelihood)``,
        divided by the number of pixels.
    """
    x_hat = out_net['x_hat']
    num_pixels = x_hat.size(0) * x_hat.size(2) * x_hat.size(3)
    # Dividing a natural-log sum by -ln(2) * num_pixels converts nats to
    # bits per pixel.
    nats_to_bpp = -math.log(2) * num_pixels

    total = 0
    for likelihoods in out_net['likelihoods'].values():
        total = total + torch.log(likelihoods).sum() / nats_to_bpp
    return total.item()

def compute_loss(x, rec, out_net, lamda, mode='mse'):
    """Return the rate-distortion cost ``rate + lamda * distortion`` as a float.

    Args:
        x: Reference tensor.
        rec: Reconstructed tensor.
        out_net (dict): Model output, passed to ``compute_bpp`` for the rate.
        lamda: Weight of the distortion term.
        mode (str): ``'mse'`` uses the mean squared error scaled by 255 * 255;
            any other value uses ``1 - MS-SSIM`` (with ``data_range=1.``).
    """
    if mode != 'mse':
        msssim = ms_ssim(x, rec, data_range=1.).item()
        rate = compute_bpp(out_net)
        return rate + lamda * (1-msssim)

    mse = torch.mean((x - rec)**2).item()
    rate = compute_bpp(out_net)
    return rate + lamda * 255 *255 * mse


In [6]:
def pad(x, p=2 ** 6):
    """Zero-pad dims 2 and 3 of ``x`` up to the next multiple of ``p``.

    The padding is split between both sides of each dimension; when the
    amount is odd, the extra row/column goes to the bottom/right.
    """
    def round_up(n):
        # Smallest multiple of p that is >= n.
        return (n + p - 1) // p * p

    def split(extra):
        # (before, after) padding amounts for one dimension.
        before = extra // 2
        return before, extra - before

    height, width = x.size(2), x.size(3)
    left, right = split(round_up(width) - width)
    top, bottom = split(round_up(height) - height)
    return F.pad(x, (left, right, top, bottom), mode="constant", value=0)


def crop(x, size):
    """Undo ``pad``: trim dims 2 and 3 of ``x`` back to ``size == (h, w)``.

    The margins removed mirror the split used by ``pad`` (smaller half on the
    top/left). Cropping is done by passing negative amounts to ``F.pad``.
    """
    target_h, target_w = size
    full_h, full_w = x.size(2), x.size(3)

    left = (full_w - target_w) // 2
    top = (full_h - target_h) // 2
    right = full_w - target_w - left
    bottom = full_h - target_h - top

    trim = (-left, -right, -top, -bottom)
    return F.pad(x, trim, mode="constant", value=0)

In [7]:
def Test_kodak(model=None,  mode='mse'):
    """Evaluate ``model`` on the Kodak set and print average PSNR, MS-SSIM and bpp.

    Images are read as ``kodimNN.png`` (1-based, two digits) from
    ``./datasets/kodak24``; the image count is the number of entries in that
    directory. Each image is zero-padded to a multiple of 256, run through the
    model, and the reconstruction is cropped back and clamped to [0, 1]
    before the metrics are computed.

    Args:
        model: Model to evaluate; its output must provide ``"x_hat"`` and
            ``"likelihoods"``. When None, the pre-trained quality-6
            ``nic_tic`` model for ``mode`` is loaded onto ``device``.
        mode (str): Metric passed to ``nic_tic``; only used when ``model`` is
            None.
    """
    testset_path = './datasets/kodak24'

    if model is None:
        # Use the notebook-wide `device` (not a hard-coded .cuda()) so the
        # model lands on the same device as the inputs below.
        model = nic_tic(6, mode, pretrained=True).to(device).eval()

    # List the directory once instead of once per loop setup and per print.
    num_images = len(os.listdir(testset_path))
    to_tensor = transforms.ToTensor()
    p = 256  # maximum 6 strides of 2, and window size 4 for the smallest latent fmap: 4*2^6=256

    psnr_sum = 0.0
    msssim_sum = 0.0
    bit_sum = 0.0

    for i in range(num_images):
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(testset_path+'/kodim'+str(i+1).zfill(2)+'.png') as img_file:
            img = img_file.convert('RGB')
        x = to_tensor(img).unsqueeze(0).to(device)
        h, w = x.size(2), x.size(3)
        x_pad = pad(x, p)

        with torch.no_grad():
            out = model.forward(x_pad)

        rec = crop(out["x_hat"], (h,w))
        rec.clamp_(0, 1)

        psnr_sum += compute_psnr(x, rec)
        msssim_sum += compute_msssim(x, rec)
        # Rate is measured on the padded input that was actually coded.
        bit_sum += compute_bpp(out)

    print(f'AVG PSNR: {psnr_sum/num_images:.2f}dB')
    print(f'AVG MS-SSIM: {msssim_sum/num_images:.2f}dB')
    print(f'AVG Bit-rate: {bit_sum/num_images:.4f} bpp')

In [8]:
def Test_tecnick(model=None, mode='mse'):
    """Evaluate ``model`` on the Tecnick set and print average PSNR, MS-SSIM and bpp.

    Images are read as ``RGB_OR_1200x1200_NNN.png`` (1-based, three digits)
    from ``./datasets/tecnick100/tecnick/RGB_OR_1200x1200``; the image count
    is the number of entries in that directory. Each image is zero-padded to
    a multiple of 256, run through the model, and the reconstruction is
    cropped back and clamped to [0, 1] before the metrics are computed.

    Args:
        model: Model to evaluate; its output must provide ``"x_hat"`` and
            ``"likelihoods"``. When None, the pre-trained quality-6
            ``nic_tic`` model for ``mode`` is loaded onto ``device``.
        mode (str): Metric passed to ``nic_tic``; only used when ``model`` is
            None.
    """
    testset_path = './datasets/tecnick100/tecnick/RGB_OR_1200x1200'

    if model is None:
        # Use the notebook-wide `device` (not a hard-coded .cuda()) so the
        # model lands on the same device as the inputs below.
        model = nic_tic(6, mode, pretrained=True).to(device).eval()

    # List the directory once instead of once per loop setup and per print.
    num_images = len(os.listdir(testset_path))
    to_tensor = transforms.ToTensor()
    p = 256  # maximum 6 strides of 2, and window size 4 for the smallest latent fmap: 4*2^6=256

    psnr_sum = 0.0
    msssim_sum = 0.0
    bit_sum = 0.0
    for i in range(num_images):
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(testset_path+'/RGB_OR_1200x1200_'+str(i+1).zfill(3)+'.png') as img_file:
            img = img_file.convert('RGB')
        x = to_tensor(img).unsqueeze(0).to(device)
        h, w = x.size(2), x.size(3)
        x_pad = pad(x, p)

        with torch.no_grad():
            out = model.forward(x_pad)

        rec = crop(out["x_hat"], (h,w))
        rec.clamp_(0, 1)

        psnr_sum += compute_psnr(x, rec)
        msssim_sum += compute_msssim(x, rec)
        # Rate is measured on the padded input that was actually coded.
        bit_sum += compute_bpp(out)

    print(f'AVG PSNR: {psnr_sum/num_images:.2f}dB')
    print(f'AVG MS-SSIM: {msssim_sum/num_images:.2f}dB')
    print(f'AVG Bit-rate: {bit_sum/num_images:.4f} bpp')

In [9]:
def Test_clic(model=None, mode='mse'):
    """Evaluate ``model`` on the CLIC set and print average PSNR, MS-SSIM and bpp.

    Images are read as ``NNN.png`` (1-based, three digits) from
    ``./datasets/clic41``; the image count is the number of entries in that
    directory. Each image is zero-padded to a multiple of 256, run through
    the model, and the reconstruction is cropped back and clamped to [0, 1]
    before the metrics are computed.

    Args:
        model: Model to evaluate; its output must provide ``"x_hat"`` and
            ``"likelihoods"``. When None, the pre-trained quality-6
            ``nic_tic`` model for ``mode`` is loaded onto ``device``.
        mode (str): Metric passed to ``nic_tic``; only used when ``model`` is
            None.
    """
    testset_path = './datasets/clic41'

    if model is None:
        model = nic_tic(6, mode, pretrained=True).to(device).eval()

    # List the directory once instead of once per loop setup and per print.
    num_images = len(os.listdir(testset_path))
    to_tensor = transforms.ToTensor()
    p = 256  # maximum 6 strides of 2, and window size 4 for the smallest latent fmap: 4*2^6=256

    psnr_sum = 0.0
    msssim_sum = 0.0
    bit_sum = 0.0
    for i in range(num_images):
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(testset_path+'/'+str(i+1).zfill(3)+'.png') as img_file:
            img = img_file.convert('RGB')
        x = to_tensor(img).unsqueeze(0).to(device)
        h, w = x.size(2), x.size(3)
        x_pad = pad(x, p)

        with torch.no_grad():
            out = model.forward(x_pad)

        rec = crop(out["x_hat"], (h,w))
        rec.clamp_(0, 1)

        psnr_sum += compute_psnr(x, rec)
        msssim_sum += compute_msssim(x, rec)
        # Rate is measured on the padded input that was actually coded.
        bit_sum += compute_bpp(out)

    print(f'AVG PSNR: {psnr_sum/num_images:.2f}dB')
    print(f'AVG MS-SSIM: {msssim_sum/num_images:.2f}dB')
    print(f'AVG Bit-rate: {bit_sum/num_images:.4f} bpp')

# Note

### We currently adopt 8-bit channel-wise dynamic activation quantization, which is easy to optimize and implement but time-consuming to run. Using 16-bit (or higher) layer-wise dynamic quantization instead reduces the run time significantly.

# Lu2022 MSE Test


In [10]:
# Evaluate the Lu2022 MSE (Q6) W8A8 checkpoint on Kodak in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Lu2022/mse/6/outputs/Lu2022_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Kodak Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_kodak(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_kodak(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s7 sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s7.set_quant_state(True, False)
Test_kodak(model)

AVG PSNR: 37.33dB
AVG MS-SSIM: 20.07dB
AVG Bit-rate: 0.8354 bpp
AVG PSNR: 37.27dB
AVG MS-SSIM: 20.02dB
AVG Bit-rate: 0.8537 bpp
AVG PSNR: 37.03dB
AVG MS-SSIM: 19.61dB
AVG Bit-rate: 0.8582 bpp


In [11]:
# Evaluate the Lu2022 MSE (Q6) W8A8 checkpoint on Tecnick in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Lu2022/mse/6/outputs/Lu2022_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Tecnick Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_tecnick(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_tecnick(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s7 sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s7.set_quant_state(True, False)
Test_tecnick(model)

AVG PSNR: 38.07dB
AVG MS-SSIM: 20.00dB
AVG Bit-rate: 0.4957 bpp
AVG PSNR: 37.99dB
AVG MS-SSIM: 19.94dB
AVG Bit-rate: 0.5268 bpp
AVG PSNR: 37.62dB
AVG MS-SSIM: 19.37dB
AVG Bit-rate: 0.5337 bpp


In [10]:
# Evaluate the Lu2022 MSE (Q6) W8A8 checkpoint on CLIC in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Lu2022/mse/6/outputs/Lu2022_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================CLIC Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_clic(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_clic(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s7 sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s7.set_quant_state(True, False)
Test_clic(model)

AVG PSNR: 38.14dB
AVG MS-SSIM: 20.08dB
AVG Bit-rate: 0.5114 bpp
AVG PSNR: 38.07dB
AVG MS-SSIM: 20.01dB
AVG Bit-rate: 0.5360 bpp
AVG PSNR: 37.75dB
AVG MS-SSIM: 19.46dB
AVG Bit-rate: 0.5422 bpp


# Lu2022 MS-SSIM Test


In [10]:
# Evaluate the Lu2022 MS-SSIM (Q6) W8A8 checkpoint on Kodak in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Lu2022/ms-ssim/6/outputs/Lu2022_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Kodak Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_kodak(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_kodak(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s7 sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s7.set_quant_state(True, False)
Test_kodak(model)

AVG PSNR: 30.56dB
AVG MS-SSIM: 21.08dB
AVG Bit-rate: 0.5866 bpp
AVG PSNR: 30.54dB
AVG MS-SSIM: 21.06dB
AVG Bit-rate: 0.5883 bpp
AVG PSNR: 30.49dB
AVG MS-SSIM: 20.74dB
AVG Bit-rate: 0.5910 bpp


In [11]:
# Evaluate the Lu2022 MS-SSIM (Q6) W8A8 checkpoint on Tecnick in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Lu2022/ms-ssim/6/outputs/Lu2022_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Tecnick Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_tecnick(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_tecnick(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s7 sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s7.set_quant_state(True, False)
Test_tecnick(model)

AVG PSNR: 34.71dB
AVG MS-SSIM: 21.75dB
AVG Bit-rate: 0.4391 bpp
AVG PSNR: 34.70dB
AVG MS-SSIM: 21.71dB
AVG Bit-rate: 0.4424 bpp
AVG PSNR: 34.48dB
AVG MS-SSIM: 21.19dB
AVG Bit-rate: 0.4470 bpp


In [10]:
# Evaluate the Lu2022 MS-SSIM (Q6) W8A8 checkpoint on CLIC in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Lu2022/ms-ssim/6/outputs/Lu2022_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================CLIC Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_clic(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_clic(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s7 sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s7.set_quant_state(True, False)
Test_clic(model)

AVG PSNR: 34.45dB
AVG MS-SSIM: 21.06dB
AVG Bit-rate: 0.4222 bpp
AVG PSNR: 34.43dB
AVG MS-SSIM: 21.03dB
AVG Bit-rate: 0.4245 bpp
AVG PSNR: 34.25dB
AVG MS-SSIM: 20.57dB
AVG Bit-rate: 0.4277 bpp


# Cheng2020 MSE Test

In [10]:
# Evaluate the Cheng2020 MSE (Q6) W8A8 checkpoint on Kodak in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Cheng2020/mse/6/outputs/Cheng2020_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Kodak Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_kodak(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_kodak(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s[-1][0] sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s[-1][0].set_quant_state(True, False)
Test_kodak(model)

AVG PSNR: 36.76dB
AVG MS-SSIM: 19.97dB
AVG Bit-rate: 0.8040 bpp
AVG PSNR: 36.62dB
AVG MS-SSIM: 19.84dB
AVG Bit-rate: 0.8066 bpp
AVG PSNR: 36.30dB
AVG MS-SSIM: 19.31dB
AVG Bit-rate: 0.8116 bpp


In [11]:
# Evaluate the Cheng2020 MSE (Q6) W8A8 checkpoint on Tecnick in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Cheng2020/mse/6/outputs/Cheng2020_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Tecnick Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_tecnick(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_tecnick(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s[-1][0] sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s[-1][0].set_quant_state(True, False)
Test_tecnick(model)

AVG PSNR: 37.54dB
AVG MS-SSIM: 19.86dB
AVG Bit-rate: 0.4853 bpp
AVG PSNR: 37.38dB
AVG MS-SSIM: 19.75dB
AVG Bit-rate: 0.4900 bpp
AVG PSNR: 36.97dB
AVG MS-SSIM: 19.17dB
AVG Bit-rate: 0.4993 bpp


In [12]:
# Evaluate the Cheng2020 MSE (Q6) W8A8 checkpoint on CLIC in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Cheng2020/mse/6/outputs/Cheng2020_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================CLIC Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_clic(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_clic(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s[-1][0] sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s[-1][0].set_quant_state(True, False)
Test_clic(model)

AVG PSNR: 37.52dB
AVG MS-SSIM: 19.87dB
AVG Bit-rate: 0.4905 bpp
AVG PSNR: 37.39dB
AVG MS-SSIM: 19.76dB
AVG Bit-rate: 0.4944 bpp
AVG PSNR: 37.00dB
AVG MS-SSIM: 19.16dB
AVG Bit-rate: 0.5031 bpp


# Cheng2020 MS-SSIM Test

In [13]:
# Evaluate the Cheng2020 MS-SSIM (Q6) W8A8 checkpoint on Kodak in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Cheng2020/ms-ssim/6/outputs/Cheng2020_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Kodak Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_kodak(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_kodak(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s[-1][0] sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s[-1][0].set_quant_state(True, False)
Test_kodak(model)

AVG PSNR: 31.03dB
AVG MS-SSIM: 20.91dB
AVG Bit-rate: 0.5881 bpp
AVG PSNR: 31.01dB
AVG MS-SSIM: 20.85dB
AVG Bit-rate: 0.5892 bpp
AVG PSNR: 30.97dB
AVG MS-SSIM: 20.62dB
AVG Bit-rate: 0.5947 bpp


In [14]:
# Evaluate the Cheng2020 MS-SSIM (Q6) W8A8 checkpoint on Tecnick in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Cheng2020/ms-ssim/6/outputs/Cheng2020_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================Tecnick Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_tecnick(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_tecnick(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s[-1][0] sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s[-1][0].set_quant_state(True, False)
Test_tecnick(model)

AVG PSNR: 34.52dB
AVG MS-SSIM: 21.39dB
AVG Bit-rate: 0.4280 bpp
AVG PSNR: 34.47dB
AVG MS-SSIM: 21.32dB
AVG Bit-rate: 0.4292 bpp
AVG PSNR: 34.31dB
AVG MS-SSIM: 20.96dB
AVG Bit-rate: 0.4332 bpp


In [15]:
# Evaluate the Cheng2020 MS-SSIM (Q6) W8A8 checkpoint on CLIC in three quantization states.
# NOTE(review): torch.load unpickles a whole model object here, which can execute
# arbitrary code -- only load checkpoints you trust. On PyTorch >= 2.6 the default
# weights_only=True rejects such files; pass weights_only=False there.
# map_location=device lets a GPU-saved checkpoint deserialize on a CPU-only host.
model = torch.load(
    "./results/Cheng2020/ms-ssim/6/outputs/Cheng2020_Q6_W8A8_prob0.5_task2.0_max-init_clic41_CW.pth",
    map_location=device,
).to(device).eval()

print("=====================================CLIC Test=============================================")
print("====================== FP32 ======================")
# Arguments are presumably (weight_quant, act_quant), matching the banners -- TODO confirm.
model.set_quant_state(False, False)
Test_clic(model)
print("======================= W8 =======================")
model.set_quant_state(True, False)
Test_clic(model)
print("====================== W8A8 ======================")
model.set_quant_state(True, True)
# Reset the g_s[-1][0] sub-module to (True, False) after enabling both flags model-wide.
model.model.g_s[-1][0].set_quant_state(True, False)
Test_clic(model)

AVG PSNR: 34.32dB
AVG MS-SSIM: 20.80dB
AVG Bit-rate: 0.4171 bpp
AVG PSNR: 34.28dB
AVG MS-SSIM: 20.75dB
AVG Bit-rate: 0.4180 bpp
AVG PSNR: 34.17dB
AVG MS-SSIM: 20.45dB
AVG Bit-rate: 0.4224 bpp
