PyTorch implementation of the four neural image compression models from Transformer-based Transform Coding, presented at ICLR 2022.
All four models are implemented in compressai/models/qualcomm.py: SwinT-ChARM, SwinT-Hyperprior, Conv-ChARM, and Conv-Hyperprior.
(Figure: Kodak Rate-Distortion Performance¹)
Models are trained with the rate-distortion objective $R + \lambda \cdot D$, using the $\lambda$ listed for each model below.
Model | Size | #Param | λ | checkpoint | TensorBoard logs | Kodak [bpp] / [dB] | GMACs² (ENC/DEC) | #steps |
---|---|---|---|---|---|---|---|---|
Conv-Hyperprior | "M" | 21.4M | 0.01 | link | link | 0.43 / 33.03 | 99 / 350 | 2M |
Conv-ChARM | "M" | 29.1M | 0.01 | link | link | 0.41 / 33.17 | 111 / 361 | 2M |
SwinT-Hyperprior | "M" | 24.7M | 0.01 | link | link | 0.38 / 32.67 | 99 / 99 | 2M |
SwinT-ChARM | "M" | 32.4M | 0.01 | link | link | 0.37 / 33.07 | 110 / 110 | 2M |
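For reference, the objective above can be sketched in a few lines. This is a minimal illustration in the spirit of CompressAI's RateDistortionLoss; the repository's own version lives in compressai/losses/rate_distortion.py and may differ in details such as the 255² scaling:

```python
import math

import torch
import torch.nn as nn

class RateDistortionLoss(nn.Module):
    """Minimal sketch of the R + lambda*D objective with MSE distortion."""

    def __init__(self, lmbda=0.01):
        super().__init__()
        self.lmbda = lmbda
        self.mse = nn.MSELoss()

    def forward(self, output, target):
        # `output` is the usual CompressAI forward dict: {"x_hat", "likelihoods"}
        N, _, H, W = target.size()
        num_pixels = N * H * W
        # Rate: total bits of all latents, normalized to bits-per-pixel.
        bpp = sum(
            torch.log(lik).sum() / (-math.log(2) * num_pixels)
            for lik in output["likelihoods"].values()
        )
        # Distortion: MSE on [0, 1] images, scaled by 255^2 as in CompressAI.
        mse = self.mse(output["x_hat"], target)
        loss = bpp + self.lmbda * (255**2) * mse
        return {"loss": loss, "bpp_loss": bpp, "mse_loss": mse}
```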
Model configurations are defined in a Python dictionary named cfgs in compressai/zoo/image.py, as described in Section A.3 of Transformer-based Transform Coding.
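As a rough illustration, cfgs maps an architecture name and a size tag to its hyper-parameters. The entry below is a hypothetical placeholder; the key names and values are assumptions, not the repository's actual configuration:

```python
# Hypothetical sketch of the cfgs layout in compressai/zoo/image.py; the real
# entries follow Section A.3 of the paper, and these values are placeholders.
cfgs = {
    "zyc2022-swint-hyperprior": {
        "M": {
            "embed_dim": [96, 128, 160, 192],
            "depths": [2, 2, 6, 2],
            "num_heads": [4, 8, 8, 16],
            "window_size": 8,
        },
    },
}
```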
(Architecture diagrams: Conv-Hyperprior | Conv-ChARM | SwinT-Hyperprior | SwinT-ChARM)
A local clone of CompressAI is provided to make model integration easier.
In a virtual environment, follow the steps below (verified on Ubuntu):
git clone https://github.com/ali-zafari/TBTC TBTC
cd TBTC
pip install -U pip
pip install -e .
pip install lightning==2.0.2
pip install tensorboard
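As an optional sanity check, both packages should then import cleanly; the printed versions will depend on your environment:

```python
import compressai
import lightning

print(compressai.__version__)  # the local editable install of this clone
print(lightning.__version__)   # expected: 2.0.2
```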
CLIC-2020 is used for training and validation, as described below.
- Training: 1631 images with resolution of at least 256x256 pixels, chosen from the union of Mobile/train and Professional/train
- Validation: 32 images with resolution of at least 1200x1200 pixels, chosen from Professional/valid
The Kodak test set is used to evaluate the final trained model.
- Test: 24 RGB images of size 512x768 pixels
All three data subsets described above can be downloaded from this link (5.8GB).
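For illustration, the splits can be loaded with CompressAI's ImageFolder dataset. The root path and the assumption that the archive unpacks into per-split subfolders are mine, not the repository's documented layout:

```python
from compressai.datasets import ImageFolder
from torchvision import transforms

# Random 256x256 crops match the minimum resolution of the training images.
train_transforms = transforms.Compose(
    [transforms.RandomCrop(256), transforms.ToTensor()]
)

# ImageFolder expects images under <root>/<split>/; the path is hypothetical.
train_dataset = ImageFolder(
    "path/to/data", split="train", transform=train_transforms
)
```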
All configurations regarding the dataloader, training strategy, etc. should be set in lit_config.py, followed by the command:
python lit_train.py --comment "simple comment for the experiment"
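Conceptually, lit_model.py wraps a CompressAI model in a LightningModule so that lit_train.py can drive it. The sketch below shows that pattern under stated assumptions; the class name, logging key, and optimizer settings are illustrative, not the repository's actual code:

```python
import lightning as L
import torch

class LitCompressor(L.LightningModule):
    """Hypothetical minimal Lightning wrapper around a CompressAI-style model."""

    def __init__(self, model, criterion, lr=1e-4):
        super().__init__()
        self.model = model          # e.g. a SwinT-ChARM instance
        self.criterion = criterion  # e.g. a rate-distortion loss
        self.lr = lr

    def training_step(self, batch, batch_idx):
        out = self.model(batch)
        losses = self.criterion(out, batch)
        self.log("train/loss", losses["loss"])
        return losses["loss"]

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)
```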
To evaluate a saved checkpoint of a model, compressai.utils.eval_model
is used. An example to test the rate-distortion performance of a SwinT-ChARM checkpoint:
python -m compressai.utils.eval_model checkpoint path/to/data/directory -a zyc2022-swint-charm --cuda -v -p path/to/a/checkpoint
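Equivalently, a single image can be scored by hand. This sketch assumes `model` is an already-loaded network in eval mode whose forward pass returns the usual CompressAI dictionary with `x_hat` and `likelihoods`; the filename is a placeholder:

```python
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

x = to_tensor(Image.open("kodim01.png")).unsqueeze(0)  # 1x3xHxW in [0, 1]
with torch.no_grad():
    out = model(x)  # `model`: a loaded, eval-mode network (assumed)

num_pixels = x.size(0) * x.size(2) * x.size(3)
bpp = sum(lik.log2().sum() for lik in out["likelihoods"].values()) / -num_pixels
mse = torch.mean((out["x_hat"].clamp(0, 1) - x) ** 2)
psnr = -10 * torch.log10(mse)
print(f"{bpp.item():.3f} bpp, {psnr.item():.2f} dB")
```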
This unofficial PyTorch implementation follows the CompressAI code structure and is wrapped by the Lightning framework. The TensorFlow implementation of SwinT-ChARM is used as the reference.
The design paradigm of CompressAI is closely followed, which results in modifications/additions in the following directories. Lightning-based Python files are also shown below:
|---compressai
|  |---losses
|  |  ├───rate_distortion.py  rate-distortion loss
|  |---layers
|  |  ├───swin.py  blocks needed by TBTC models
|  |---models
|  |  ├───qualcomm.py  TBTC models
|  |---zoo
|     ├───image.py  model creation based on config
|
├───lit_config.py  configuration file
├───lit_data.py  Lightning data module
├───lit_model.py  Lightning module
├───lit_train.py  main script to start/resume training
- CompressAI: Neural compression library in PyTorch (by InterDigital)
- NeuralCompression: Neural compression library in PyTorch (by Meta)
- SwinT-ChARM: Unofficial TensorFlow implementation
- STF: Window-based attention in neural image compression
- Lightning: PyTorch framework for training abstraction
@inproceedings{zhu2022tbtc,
  title={Transformer-based transform coding},
  author={Zhu, Yinhao and Yang, Yang and Cohen, Taco},
  booktitle={International Conference on Learning Representations},
  year={2022}
}

@article{begaint2020compressai,
  title={CompressAI: a PyTorch library and evaluation platform for end-to-end compression research},
  author={B{\'e}gaint, Jean and Racap{\'e}, Fabien and Feltman, Simon and Pushparaja, Akshay},
  journal={arXiv preprint arXiv:2011.03029},
  year={2020}
}