In [1]:
import os
import sys
import math
import time
import pickle
import random
import argparse
import warnings
import numpy as np
from PIL import Image
from pathlib import Path
from datetime import datetime
from pytorch_msssim import ms_ssim

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

from models.tinylic import TinyLIC
from quant import quantize

In [2]:
device = 'cuda'
warnings.filterwarnings("ignore")

In [3]:
def compute_psnr(a, b):
    mse = torch.mean((a - b)**2).item()
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    return ms_ssim(a, b, data_range=1.).item()

def filesize(filepath: str) -> int:
    if not Path(filepath).is_file():
        raise ValueError(f'Invalid file "{filepath}".')
    return Path(filepath).stat().st_size

def pad(x, p=2 ** 6):
    h, w = x.size(2), x.size(3)
    H = (h + p - 1) // p * p
    W = (w + p - 1) // p * p
    padding_left = (W - w) // 2
    padding_right = W - w - padding_left
    padding_top = (H - h) // 2
    padding_bottom = H - h - padding_top
    return F.pad(
        x,
        (padding_left, padding_right, padding_top, padding_bottom),
        mode="constant",
        value=0,
    )

def filesize(filepath: str) -> int:
    if not Path(filepath).is_file():
        raise ValueError(f'Invalid file "{filepath}".')
    return Path(filepath).stat().st_size

def crop(x, size):
    H, W = x.size(2), x.size(3)
    h, w = size
    padding_left = (W - w) // 2
    padding_right = W - w - padding_left
    padding_top = (H - h) // 2
    padding_bottom = H - h - padding_top
    return F.pad(
        x,
        (-padding_left, -padding_right, -padding_top, -padding_bottom),
        mode="constant",
        value=0,
    )

## FP32 Models Test


In [4]:
def main_fp32(gt_path, model_name, lrd):
    """
    Test FP32 Models in a single image;
    
    Args Input
        :param gt_path: the path to ground truth image
        :param model_name: the name of FP32 Models
        :param lrd: the lambda to control bit-rate, suggest [0.00005, 0.0009]
    
    Args Output
        :reconstruction image, saved in ./results/...
    """
    
    os.makedirs('./results/'+model_name+'/outputs/FP32',exist_ok=True)
    
    # load models
    net = TinyLIC()
    snapshot = torch.load("./pretrained/"+model_name+ ".pth.tar", map_location=device)['state_dict']
    net.load_state_dict(snapshot, strict=False)
    net.update(force=True)
    net = net.to(device).eval()
    torch.save({'state_dict': net.state_dict()}, './results/'+model_name+'/outputs/FP32.pth')
    model_size = filesize('./results/'+model_name+'/outputs/FP32.pth')
    
    
    # load images    
    img = Image.open(gt_path).convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)        
    gt_down = transforms.ToPILImage()(x.squeeze().cpu())
    
    gt_down.save('./results/'+model_name+'/outputs/FP32/gt.png', format="PNG")

    p = 64
    h, w = x.size(2), x.size(3)
    x_pad = pad(x, p)
    
    # inference
    with torch.no_grad():
        lambda_rd = torch.tensor([lrd]).to(device)
        
        # compress
        torch.cuda.synchronize()
        start_compress = time.time()
        out_enc = net.compress(x_pad, lambda_rd)
        torch.cuda.synchronize()
        end_compress = time.time()
        
        # decompress
        torch.cuda.synchronize()
        start_decompress = time.time()
        out = net.decompress(out_enc["strings"], out_enc["shape"], lambda_rd)
        torch.cuda.synchronize()
        end_decompress = time.time()
        
        # save
        rec = crop(out['x_hat'], (h, w))
        rec_tic = transforms.ToPILImage()(rec.squeeze().cpu())
        rec_tic.save('./results/'+model_name+'/outputs/FP32/rec.png', format="PNG")

            
    # result
    num_pixels = x.size(0) * x.size(2) * x.size(3)
    byte = sum(len(s[0]) for s in out_enc["strings"])
    bpp = byte * 8.0 / num_pixels
    enc_time = end_compress - start_compress
    dec_time = end_decompress - start_decompress

    print(f'PSNR: {compute_psnr(x, rec):.2f}dB')
    print(f'MS-SSIM: {compute_msssim(x, rec):.4f}')
    print(f'Byte: {byte:.0f} Byte')
    print(f'Bit-rate: {bpp:.3f} bpp')
    print(f'Enc Time: {enc_time:.3f} s')
    print(f'Dec Time: {dec_time:.3f} s')
    print(f'Model Size: {model_size/1024**2:.2f}MB')

In [6]:
gt_path = './data/2K.png'
model_name = 'tinylic'
main_fp32(gt_path, model_name, 0.0008)

PSNR: 33.98dB
MS-SSIM: 0.9521
Byte: 83292 Byte
Bit-rate: 0.239 bpp
Enc Time: 0.323 s
Dec Time: 0.294 s
Model Size: 69.40MB


## FP16 Models Test

In [7]:
def main_fp16(gt_path, model_name, lrd):
    """
    Test FP16 Models in a single image;
    
    Args Input
        :param gt_path: the path to ground truth image
        :param model_name: the name of FP16 Models
        :param lrd: the lambda to control bit-rate, suggest [0.00005, 0.0009]
    
    Args Output
        :reconstruction image, saved in ./results/...
    """
    
    os.makedirs('./results/'+model_name+'/outputs/FP16',exist_ok=True)
    
    # load models
    net = torch.load("./results/"+model_name+ "/outputs/FP16.pth", map_location=device)
    net = net.to(device).eval()
    model_size = filesize("./results/"+model_name+ "/outputs/FP16.pth")
    
    
    # load images    
    img = Image.open(gt_path).convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).to(device).half()
        
    gt_down = transforms.ToPILImage()(x.squeeze().cpu())
    
    gt_down.save('./results/'+model_name+'/outputs/FP16/gt.png', format="PNG")

    p = 64
    h, w = x.size(2), x.size(3)
    x_pad = pad(x, p)
    
    # inference
    with torch.no_grad():
        lambda_rd = torch.tensor([lrd]).to(device).half()
        
        # compress
        torch.cuda.synchronize()
        start_compress = time.time()
        out_enc = net.compress(x_pad, lambda_rd)
        torch.cuda.synchronize()
        end_compress = time.time()
        
        # decompress
        torch.cuda.synchronize()
        start_decompress = time.time()
        out = net.decompress(out_enc["strings"], out_enc["shape"], lambda_rd)
        torch.cuda.synchronize()
        end_decompress = time.time()
        
        # save
        rec = crop(out['x_hat'], (h, w))
        rec_tic = transforms.ToPILImage()(rec.squeeze().cpu())
        rec_tic.save('./results/'+model_name+'/outputs/FP16/rec.png', format="PNG")

            
    # result
    num_pixels = x.size(0) * x.size(2) * x.size(3)
    byte = sum(len(s[0]) for s in out_enc["strings"])
    bpp = byte * 8.0 / num_pixels
    enc_time = end_compress - start_compress
    dec_time = end_decompress - start_decompress

    print(f'PSNR: {compute_psnr(x, rec):.2f}dB')
    print(f'MS-SSIM: {compute_msssim(x, rec):.4f}')
    print(f'Byte: {byte:.0f} Byte')
    print(f'Bit-rate: {bpp:.3f} bpp')
    print(f'Enc Time: {enc_time:.3f} s')
    print(f'Dec Time: {dec_time:.3f} s')
    print(f'Model Size: {model_size/1024**2:.2f}MB')

In [8]:
gt_path = './data/2K.png'
model_name = 'tinylic'
main_fp16(gt_path, model_name, 0.0008)

PSNR: 34.00dB
MS-SSIM: 0.9395
Byte: 83140 Byte
Bit-rate: 0.238 bpp
Enc Time: 0.298 s
Dec Time: 0.276 s
Model Size: 35.26MB


# INT8

In [9]:
def main_int8(gt_path, model_name, net, lrd):
    """
    Test INT8 Models in a single image;
    
    Args Input
        :param gt_path: the path to ground truth image
        :param model_name: the name of INT8 Models
        :param lrd: the lambda to control bit-rate, suggest [0.00005, 0.0009]
    
    Args Output
        :reconstruction image, saved in ./results/...
    """
    
    device = 'cuda'
    os.makedirs('./results/'+model_name+'/outputs/INT8',exist_ok=True)
    model_size = filesize("./results/"+model_name+ "/outputs/INT8.pth")
    
    # load images    
    img = Image.open(gt_path).convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)
        
    gt_down = transforms.ToPILImage()(x.squeeze().cpu())
    
    gt_down.save('./results/'+model_name+'/outputs/INT8/gt.png', format="PNG")

    p = 64
    h, w = x.size(2), x.size(3)
    x_pad = pad(x, p)
    
    # inference
    with torch.no_grad():
        lambda_rd = torch.tensor([lrd]).to(device)
        
        # compress
        torch.cuda.synchronize()
        start_compress = time.time()
        out_enc = net.compress(x_pad, lambda_rd)
        torch.cuda.synchronize()
        end_compress = time.time()
        
        # decompress
        torch.cuda.synchronize()
        start_decompress = time.time()
        out = net.decompress(out_enc["strings"], out_enc["shape"], lambda_rd)
        torch.cuda.synchronize()
        end_decompress = time.time()
        
        # save
        rec = crop(out['x_hat'], (h, w))
        rec_tic = transforms.ToPILImage()(rec.squeeze().cpu())
        rec_tic.save('./results/'+model_name+'/outputs/INT8/rec.png', format="PNG")

            
    # result
    num_pixels = x.size(0) * x.size(2) * x.size(3)
    byte = sum(len(s[0]) for s in out_enc["strings"])
    bpp = byte * 8.0 / num_pixels
    enc_time = end_compress - start_compress
    dec_time = end_decompress - start_decompress

    print(f'PSNR: {compute_psnr(x, rec):.2f}dB')
    print(f'MS-SSIM: {compute_msssim(x, rec):.4f}')
    print(f'Byte: {byte:.0f} Byte')
    print(f'Bit-rate: {bpp:.3f} bpp')
    print(f'Enc Time: {enc_time:.3f} s')
    print(f'Dec Time: {dec_time:.3f} s')
    print(f'Model Size: {model_size/1024**2:.2f}MB')

In [10]:
model_name = 'tinylic'

# model quantize and parameter warmup
net = quantize(model_name)

In [11]:
gt_path = './data/2K.png'
main_int8(gt_path, model_name, net, 0.0008)

PSNR: 33.31dB
MS-SSIM: 0.9554
Byte: 83304 Byte
Bit-rate: 0.239 bpp
Enc Time: 0.375 s
Dec Time: 0.353 s
Model Size: 19.29MB
