In [1]:
import os
import sys
import math
import time
import pickle
import random
import warnings
import argparse
import numpy as np
from PIL import Image
from pathlib import Path
from datetime import datetime
from pytorch_msssim import ms_ssim

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

from models.tinylic import TinyLIC
from quant_int import *
from utils import *

In [2]:
# Use the GPU when one is available; otherwise fall back to the CPU so that
# checkpoint loading (map_location=device) and .to(device) calls further down
# still work on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Silence every Python warning for the rest of the session to keep cell output short.
warnings.filterwarnings("ignore")

In [3]:
def compute_psnr(a, b):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    The value is ``-10 * log10(mse)``; the formula has no peak-value term,
    i.e. it treats the peak signal value as 1.0.

    Args:
        a: first tensor.
        b: second tensor, broadcast-compatible with ``a``.

    Returns:
        The PSNR as a Python float. ``float('inf')`` is returned when the
        mean squared error is exactly zero (identical inputs), where the
        logarithm would otherwise raise ``ValueError``.
    """
    mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        # log10(0) is undefined; identical signals have unbounded PSNR.
        return float('inf')
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    """Return the MS-SSIM score between ``a`` and ``b`` as a Python float.

    Delegates to ``pytorch_msssim.ms_ssim`` with ``data_range=1.`` and
    converts the resulting tensor to a scalar with ``.item()``.
    """
    score = ms_ssim(a, b, data_range=1.)
    return score.item()

def filesize(filepath: str) -> int:
    """Return the size in bytes of the regular file at ``filepath``.

    Raises:
        ValueError: if ``filepath`` is not an existing regular file.
    """
    target = Path(filepath)
    if target.is_file():
        return target.stat().st_size
    raise ValueError(f'Invalid file "{filepath}".')

def pad(x, p=2 ** 6):
    """Zero-pad dims 2 and 3 of ``x`` up to the next multiple of ``p``.

    The added rows/columns are split between the two sides of each dimension;
    when the amount is odd, the extra one goes to the right/bottom.

    Args:
        x: tensor whose dims 2 and 3 are treated as height and width.
        p: the multiple that height and width are rounded up to.

    Returns:
        The padded tensor produced by ``F.pad`` in constant mode with value 0.
    """
    height, width = x.shape[2], x.shape[3]
    # Total number of rows/columns to add to reach the rounded-up size.
    extra_h = (height + p - 1) // p * p - height
    extra_w = (width + p - 1) // p * p - width
    left = extra_w // 2
    top = extra_h // 2
    # F.pad expects (left, right, top, bottom) for the last two dimensions.
    amounts = (left, extra_w - left, top, extra_h - top)
    return F.pad(x, amounts, mode="constant", value=0)

def crop(x, size):
    """Center-crop dims 2 and 3 of ``x`` to ``size`` = ``(h, w)``.

    Uses the same left/top split as ``pad`` (floor of half the surplus on the
    left/top, the remainder on the right/bottom) and removes the surplus by
    passing negative amounts to ``F.pad``.

    Args:
        x: tensor whose dims 2 and 3 are treated as height and width.
        size: ``(h, w)`` target height and width.

    Returns:
        The tensor returned by ``F.pad`` with the negated surplus amounts.
    """
    full_h, full_w = x.shape[2], x.shape[3]
    target_h, target_w = size
    surplus_w = full_w - target_w
    surplus_h = full_h - target_h
    left = surplus_w // 2
    top = surplus_h // 2
    # Negative padding trims: order is (left, right, top, bottom).
    trim = (-left, -(surplus_w - left), -top, -(surplus_h - top))
    return F.pad(x, trim, mode="constant", value=0)

## Quantize FP32 to INT8 Models

In [4]:
def _str2bool(value):
    """Convert a command-line token such as 'true' / 'False' / '0' to a bool.

    ``type=bool`` cannot be used for this: ``bool('False')`` is ``True``
    because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='running parameters',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# general parameters for data and model
parser.add_argument('--seed', default=1005, type=int, help='random seed for results reproduction')
parser.add_argument('--name', default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'), type=str, help='result dir name')
parser.add_argument('--save', default=True, type=_str2bool, help='save quantized model')
parser.add_argument('--fp32_name',default='tinylic', help='fp32_model_path')

# quantization parameters
parser.add_argument('--n_bits_w', default=8, type=int, help='bitwidth for weight quantization')
parser.add_argument('--channel_wise', action='store_true', help='apply channel_wise quantization for weights')
parser.add_argument('--act_quant', default=True, type=_str2bool, help='apply activation quantization')
parser.add_argument('--test_before_calibration', default=True, type=_str2bool, help='test_before_calibration')
parser.add_argument('--sym', default=True, type=_str2bool, help='symmetric reconstruction')

# Inside a notebook there is no real command line, so parse an empty argument
# list: every option takes its default and can be overridden on `args` later.
args = parser.parse_args(args=[])


In [5]:
def seed_all(seed=1029):
    """Seed the random number generators used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, current CUDA device and
    all CUDA devices), writes ``PYTHONHASHSEED`` into ``os.environ``, and sets
    the cuDNN flags ``benchmark=False`` / ``deterministic=True``.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Same seed for NumPy, torch CPU, the current GPU and every GPU.
    for seeder in (np.random.seed,
                   torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # disabling autotuning may cost some speed
    cudnn.deterministic = True

In [6]:
def validate_model(model, img_num=1, data_dir='./data/Kodak', lambda_rd=0.0007):
    """Compress and decompress Kodak images with ``model`` and average the metrics.

    For each of the first ``img_num`` images (``kodim01.png``, ``kodim02.png``,
    ...) the image is padded to a multiple of 64, passed through
    ``model.compress`` / ``model.decompress``, cropped back to its original
    size and compared with the input.

    Args:
        model: network exposing ``compress(x, lambda)`` returning a dict with
            ``"strings"`` and ``"shape"``, and ``decompress(strings, shape,
            lambda)`` returning a dict with ``"x_hat"``.
        img_num: number of images to evaluate; also the divisor of the
            averages, so the loop length and the divisor cannot drift apart.
        data_dir: directory containing the ``kodimXX.png`` files.
        lambda_rd: scalar wrapped in a one-element tensor and passed to
            ``compress`` / ``decompress``.

    Returns:
        Tuple ``(psnr, ms_ssim, bpp)`` averaged over the evaluated images.

    Raises:
        ValueError: if ``img_num`` is smaller than 1.
    """
    if img_num < 1:
        raise ValueError('img_num must be >= 1, got {}'.format(img_num))

    model.eval()
    device = next(model.parameters()).device

    sum_psnr = 0.0
    sum_msssim = 0.0
    sum_bpp = 0.0

    lambda_tensor = torch.tensor([lambda_rd]).to(device)
    for i in range(img_num):
        image_path = '{}/kodim{:02d}.png'.format(data_dir, i + 1)
        # Close the file handle as soon as the pixels are converted to a tensor.
        with Image.open(image_path) as img:
            x = transforms.ToTensor()(img.convert('RGB')).unsqueeze(0).to(device)

        p = 64
        h, w = x.size(2), x.size(3)
        x_pad = pad(x, p)

        with torch.no_grad():
            out_enc = model.compress(x_pad, lambda_tensor)
            out = model.decompress(out_enc["strings"], out_enc["shape"], lambda_tensor)
            rec = crop(out['x_hat'], (h, w))

        # Bits per pixel are measured against the unpadded image size.
        num_pixels = x.size(0) * x.size(2) * x.size(3)
        bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels

        sum_psnr += compute_psnr(x, rec)
        sum_msssim += compute_msssim(x, rec)
        sum_bpp += bpp

    return sum_psnr/img_num, sum_msssim/img_num , sum_bpp/img_num

In [7]:
def generator(qnn, args, image_path='./data/Kodak/kodim23.png', lambda_rd=0.0005):
    """Run one forward pass of ``qnn`` on a single image with quantization enabled.

    The original comment describes this pass as initializing the weight
    quantization parameters; what the pass does internally is defined by the
    quantized model class, which is not visible here.

    Args:
        qnn: quantized model exposing ``set_quant_state`` and callable as
            ``qnn(x, lambda)``.
        args: namespace; only ``args.act_quant`` is read (second argument of
            ``set_quant_state``).
        image_path: image fed through the model (default: Kodak image 23).
        lambda_rd: scalar wrapped in a one-element tensor and passed to the
            model alongside the image.

    Returns:
        The same ``qnn`` object, after the forward pass.
    """
    device = next(qnn.parameters()).device
    # Close the file handle as soon as the pixels are converted to a tensor.
    with Image.open(image_path) as img:
        x = transforms.ToTensor()(img.convert('RGB')).unsqueeze(0).to(device)

    x_pad = pad(x, 64)
    lambda_tensor = torch.tensor([lambda_rd]).to(device)

    # Initialize weight quantization parameters
    qnn.set_quant_state(True, args.act_quant)
    init_start = time.time()
    _ = qnn(x_pad, lambda_tensor)
    init_time = time.time() - init_start
    logging.info('generate quantized model time: {}'.format(init_time))

    return qnn

In [8]:
def quantize_int8(args, fp32_name, output_dir, log_dir):
    """
    Quantize an FP32 TinyLIC checkpoint to INT8 with post-training quantization (PTQ).

    Quantizer configuration built here:
        weight: channel-wise, asymmetric, 'max' scale, ``args.n_bits_w`` bits;
        activation: layer-wise (not channel-wise), asymmetric, 'max' scale.

    The original notes list the quantized modules as: hyper encoder ``h_a``,
    hyper decoder ``h_s``, and the entropy coding modules
    ("entropycc_transforms, sc_transformers, entropy_parameters"). That
    selection is made inside ``QuantModel``, which is not visible here.

    Args Input
        :param args: namespace; ``n_bits_w``, ``test_before_calibration`` and
                     ``save`` are read here, and it is forwarded to ``generator``
        :param fp32_name: checkpoint name, loaded from ./pretrained/<fp32_name>.pth.tar
        :param output_dir: directory that receives INT8.pth
        :param log_dir: logger directory (accepted but not used in this function)

    Args Output
        :INT8 state_dict saved as <output_dir>/INT8.pth when ``args.save`` is truthy
    """

    # Rebuild the FP32 network and restore its weights.
    model = TinyLIC(model_size = "80M")
    checkpoint = torch.load('./pretrained/'+fp32_name+'.pth.tar', map_location=device)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.update(force=True)
    model.entropy_bottleneck.update()
    model = model.to(device).eval()

    # Optional FP32 baseline, logged for comparison with the INT8 numbers below.
    if args.test_before_calibration:
        fp32_metrics = validate_model(model)
        logging.info('Full-precision model: psnr= {:.2f}; ms-ssim={:.4f}; bpp= {:.3f}'.format(*fp32_metrics))

    # Quantizer settings for weights and activations.
    weight_cfg = dict(n_bits=args.n_bits_w, channel_wise=True, symmetric=False, scale_method='max')
    act_cfg = dict(channel_wise=False, symmetric=False, scale_method='max', leaf_param=True)
    qnn = QuantModel(model=model, weight_quant_params=weight_cfg, act_quant_params=act_cfg)
    qnn.to(device)
    qnn.eval()

    # One forward pass with quantization switched on (see generator).
    qnn = generator(qnn, args)

    # Evaluate with both weight and activation quantization enabled.
    qnn.set_quant_state(weight_quant=True, act_quant=True)
    int8_metrics = validate_model(qnn)
    logging.info('INT8: psnr= {:.2f}; ms-ssim={:.4f}; bpp= {:.3f}'.format(*int8_metrics))

    if not args.save:
        return
    logging.info('save quantized model in {}'.format(output_dir))
    torch.save(qnn.state_dict(), "{}/INT8.pth".format(output_dir))

In [9]:
def main_int8(fp32_name):
    """Run the INT8 quantization pipeline for the checkpoint named ``fp32_name``.

    Stores the name on the module-level ``args``, obtains the output and log
    directories from ``init_lic``, seeds the RNGs with ``args.seed``, sets up
    a timestamped log file, and calls ``quantize_int8``.
    """
    args.fp32_name = fp32_name
    output_dir, log_dir = init_lic(args)

    seed_all(args.seed)
    log_file = log_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log'
    setup_logger(log_file)

    pid_line = '[PID] %s'%os.getpid()
    logging.info(pid_line)
    banner = '======================= TinyLIC ======================='
    logging.info(banner)

    quantize_int8(args, fp32_name, output_dir, log_dir)

In [10]:
# Name of the FP32 checkpoint; quantize_int8 loads ./pretrained/<fp32_name>.pth.tar.
fp32_name = 'tinylic'

# Run the INT8 quantization pipeline; progress and metrics are written via logging.
main_int8(fp32_name)

2023-09-26 03:59:28,661 [INFO ]  Logging file is ./results/tinylic/logs/20230926_035928.log
2023-09-26 03:59:28,662 [INFO ]  [PID] 31946
2023-09-26 03:59:35,484 [INFO ]  Full-precision model: psnr= 24.28; ms-ssim=0.9060; bpp= 0.221
2023-09-26 03:59:40,886 [INFO ]  generate quantized model time: 5.162351131439209
2023-09-26 03:59:41,176 [INFO ]  INT8: psnr= 23.76; ms-ssim=0.8914; bpp= 0.195
2023-09-26 03:59:41,871 [INFO ]  save quantized model in ./results/tinylic/outputs


## Quantize FP32 to FP16 Models

In [11]:
def quant_fp16(fp32_name):
    """
    Convert an FP32 TinyLIC checkpoint to FP16 with ``torch.half`` and save it.

    Args Input
    :param fp32_name: checkpoint name, loaded from ./pretrained/<fp32_name>.pth.tar

    Args Output
    :the whole FP16 model object, saved to ./results/<fp32_name>/outputs/FP16.pth
     (the directory is created if it does not exist yet)
    """

    # load FP32 Models and update
    # NOTE(review): quantize_int8 builds TinyLIC(model_size="80M") while this
    # uses the constructor default - confirm both refer to the same architecture.
    model = TinyLIC()
    snapshot = torch.load('./pretrained/'+fp32_name+'.pth.tar', map_location=device)['state_dict']
    model.load_state_dict(snapshot, strict=False)
    model.update(force=True)

    # quantize to FP16
    model = model.half().to(device).eval()

    # The output directory is otherwise only created by the INT8 pipeline;
    # create it here so this function also works when run on its own.
    output_dir = './results/'+fp32_name+'/outputs'
    os.makedirs(output_dir, exist_ok=True)

    # Save the full model object (saving only the state_dict would also work).
    torch.save(model, output_dir+'/FP16.pth')

In [12]:
# Name of the FP32 checkpoint; quant_fp16 loads ./pretrained/<fp32_name>.pth.tar.
fp32_name = 'tinylic'

# Convert the checkpoint to FP16 and save it under ./results/<fp32_name>/outputs/.
quant_fp16(fp32_name)