
Transformer-based Transform Coding (TBTC)

PyTorch implementation of the four neural image compression models of Transformer-based Transform Coding, presented at ICLR 2022.

The four models are implemented in compressai/models/qualcomm.py: SwinT-ChARM, SwinT-Hyperprior, Conv-ChARM, and Conv-Hyperprior.

Kodak Rate-Distortion Performance¹

(Figure: rate-distortion curves of the four models on the Kodak test set.)

Pretrained Models

Models are trained with the rate-distortion objective $R + \lambda D$, with the fixed $\lambda$ value listed in the table below; a sketch of this loss follows the table.

| Model | Size | #Param | $\lambda$ | Checkpoint | TensorBoard logs | Kodak [bpp] / [dB] | GMACs² (ENC/DEC) | #Steps |
|---|---|---|---|---|---|---|---|---|
| Conv-Hyperprior | "M" | 21.4M | 0.01 | link | link | 0.43 / 33.03 | 99 / 350 | 2M |
| Conv-ChARM | "M" | 29.1M | 0.01 | link | link | 0.41 / 33.17 | 111 / 361 | 2M |
| SwinT-Hyperprior | "M" | 24.7M | 0.01 | link | link | 0.38 / 32.67 | 99 / 99 | 2M |
| SwinT-ChARM | "M" | 32.4M | 0.01 | link | link | 0.37 / 33.07 | 110 / 110 | 2M |
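
For intuition, here is a minimal sketch of such a rate-distortion loss in the CompressAI style; the repository's own version lives in compressai/losses/rate_distortion.py and may differ in details:

```python
import math

import torch
import torch.nn as nn


class RateDistortionLoss(nn.Module):
    """Sketch of an R + lambda*D loss following CompressAI conventions."""

    def __init__(self, lmbda=0.01):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, output, target):
        # `output` is assumed to be a CompressAI-style model output:
        # {"x_hat": reconstruction, "likelihoods": {"y": ..., "z": ...}}
        N, _, H, W = target.size()
        num_pixels = N * H * W

        # Rate: total bits of all latents, normalized to bits per pixel.
        bpp = sum(
            torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
            for likelihoods in output["likelihoods"].values()
        )

        # Distortion: MSE on images in [0, 1], scaled to the 8-bit range.
        mse = self.mse(output["x_hat"], target)
        return bpp + self.lmbda * 255**2 * mse
```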

Model Architectures

Model configurations are defined in a Python dictionary named cfgs in compressai/zoo/image.py, as described in Section A.3 of Transformer-based Transform Coding.
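
As a quick way to inspect a configuration (the nesting by architecture name and size below is an assumption for illustration; the authoritative layout is the cfgs dictionary itself):

```python
# Hedged sketch: inspect a model configuration from the local CompressAI clone.
# The exact nesting of `cfgs` (arch name -> size -> hyperparameters) is assumed;
# see compressai/zoo/image.py for the real layout.
from compressai.zoo.image import cfgs

print(cfgs["zyc2022-swint-charm"]["M"])  # hyperparameters of the "M"-size SwinT-ChARM
```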

(Figures: architecture diagrams for Conv-Hyperprior, Conv-ChARM, SwinT-Hyperprior, and SwinT-ChARM.)

Usage

A local clone of CompressAI is included to make model integration easier.

Installation

In a virtual environment, follow the steps below (verified on Ubuntu):

git clone https://github.com/ali-zafari/TBTC TBTC
cd TBTC
pip install -U pip
pip install -e .
pip install lightning==2.0.2
pip install tensorboard
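
To sanity-check the editable install, a quick import test (not part of the repo) can be run:

python -c "import compressai, lightning; print(lightning.__version__)"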

Datasets

CLIC-2020 is used for training and validation, as described below.

  • Training
    • 1,631 images with resolution of at least 256x256 pixels, chosen from the union of Mobile/train and Professional/train
  • Validation
    • 32 images with resolution of at least 1200x1200 pixels, chosen from Professional/valid

Kodak test set is used to evaluate the final trained model.

  • Test
    • 24 RGB images of size 512x768 pixels

All three data subsets described above can be downloaded from this link (5.8GB).
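
For reference, a minimal training dataloader in the CompressAI style is sketched below; the repository's actual data pipeline lives in lit_data.py, and the directory layout path/to/clic2020/train is an assumption:

```python
# Hedged sketch of a CLIC-2020 training dataloader; the repository's actual
# pipeline is in lit_data.py. A <root>/train directory layout is assumed, as
# expected by CompressAI's ImageFolder dataset.
from torch.utils.data import DataLoader
from torchvision import transforms

from compressai.datasets import ImageFolder

train_transforms = transforms.Compose(
    [transforms.RandomCrop(256), transforms.ToTensor()]  # 256x256 crops match the minimum training resolution
)
train_dataset = ImageFolder("path/to/clic2020", split="train", transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)
```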

Training

All configurations regarding the dataloader, training strategy, etc. should be set in lit_config.py, followed by the command:

python lit_train.py --comment "simple comment for the experiment"
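
Under the hood, lit_train.py drives a standard Lightning loop; a minimal sketch is shown below (the class names LitModel and LitDataModule are assumptions, not necessarily the repo's actual names):

```python
# Hedged sketch of the kind of Lightning entry point lit_train.py implements;
# LitModel and LitDataModule are assumed names for the modules defined in
# lit_model.py and lit_data.py.
import lightning as L

from lit_data import LitDataModule  # assumed class name
from lit_model import LitModel      # assumed class name

trainer = L.Trainer(max_steps=2_000_000, accelerator="auto")  # 2M steps, as in the table above
trainer.fit(LitModel(), datamodule=LitDataModule())
```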

Evaluation

To evaluate a saved checkpoint, compressai.utils.eval_model is used. An example to test the rate-distortion performance of a SwinT-ChARM checkpoint:

python -m compressai.utils.eval_model checkpoint path/to/data/directory -a zyc2022-swint-charm --cuda -v -p path/to/a/checkpoint

Code Structure

This unofficial PyTorch implementation follows the CompressAI code structure and is wrapped by the Lightning framework. The TensorFlow implementation of SwinT-ChARM is used as the reference.

The design paradigm of CompressAI is closely followed, which results in modifications/additions in the directories below. The Lightning-based Python files are also shown:

compressai
├── losses
│   └── rate_distortion.py       rate-distortion loss
├── layers
│   └── swin.py                  blocks needed by TBTC models
├── models
│   └── qualcomm.py              TBTC models
└── zoo
    └── image.py                 model creation based on config

lit_config.py                    configuration file
lit_data.py                      lightning data-module
lit_model.py                     lightning module
lit_train.py                     main script to start/resume training

References/Citations

Repositories

  • CompressAI
  • TensorFlow implementation of SwinT-ChARM

Publications

@inproceedings{zhu2022tbtc,
  title={Transformer-based transform coding},
  author={Zhu, Yinhao and Yang, Yang and Cohen, Taco},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

@article{begaint2020compressai,
  title={CompressAI: a PyTorch library and evaluation platform for end-to-end compression research},
  author={B{\'e}gaint, Jean and Racap{\'e}, Fabien and Feltman, Simon and Pushparaja, Akshay},
  journal={arXiv preprint arXiv:2011.03029},
  year={2020}
}

Footnotes

  1. Models in this RD curve are trained for 1M steps (only half of the total steps mentioned in the original paper).

  2. Computed per input image of size 768x512 pixels.