This repository contains the official PyTorch implementation of the paper: Xili Dai, Shengbang Tong, Mingyang Li, Ziyang Wu, Michael Psenka, Kwan Ho Ryan Chan, Pengyuan Zhai, Yaodong Yu, Xiaojun Yuan, Heung Yeung Shum, Yi Ma. "Closed-Loop Data Transcription to an LDR via Minimaxing Rate Reduction", published in the Entropy Special Issue "Information Theory and Machine Learning".
This work proposes a new computational framework for learning a structured generative model for real-world datasets. In particular, we propose a closed-loop data transcription framework (CTRL) between a multi-class, high-dimensional data distribution and a linear discriminative representation (LDR) in the feature space that consists of multiple independent multi-dimensional linear subspaces. This new framework unifies the concepts and benefits of auto-encoding (AE) and generative adversarial networks (GAN), naturally extending both to the setting of learning a discriminative and generative representation for multi-class, high-dimensional, real-world data. Our extensive experiments on many benchmark image datasets demonstrate the tremendous potential of this new closed-loop formulation: under fair comparison, the visual quality of the learned decoder and the classification performance of the encoder are competitive with, and often better than, existing methods based on GAN, VAE, or a combination of both. We hope that this repository serves as a reproducible baseline for future research in this area.
The encoder f has dual roles: it learns an LDR z for the data x by maximizing the rate reduction of z, and it acts as a “feedback sensor” for any discrepancy between the data x and the decoded \hat{x}. The decoder g also has dual roles: it is a “controller” that corrects the discrepancy between x and \hat{x}, and it aims to minimize the overall coding rate of the learned LDR.
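For intuition, both objectives are built from the logdet-based coding rate used in the rate-reduction framework. Below is a minimal sketch of these quantities, assuming features are stored as rows of an n × d matrix with a distortion parameter eps; the function names (`coding_rate`, `rate_reduction_gap`) are illustrative and not the repository's actual API.

```python
import torch

def coding_rate(Z, eps=0.5):
    """R(Z, eps) = 1/2 * logdet(I + d / (n * eps^2) * Z^T Z).

    Z: (n, d) tensor whose rows are feature vectors.
    """
    n, d = Z.shape
    I = torch.eye(d, device=Z.device, dtype=Z.dtype)
    return 0.5 * torch.logdet(I + (d / (n * eps ** 2)) * Z.T @ Z)

def rate_reduction_gap(Z, Z_hat, eps=0.5):
    """Rate-reduction 'distance' between features of real data Z and of
    decoded data Z_hat: the coding rate of their union minus the average
    of their individual rates. In the closed-loop game, the encoder
    maximizes such gaps while the decoder minimizes them."""
    union = torch.cat([Z, Z_hat], dim=0)
    return coding_rate(union, eps) - 0.5 * (coding_rate(Z, eps) + coding_rate(Z_hat, eps))
```

This mirrors the minimax structure described above: f plays the maximizing “sensor” and g the minimizing “controller”.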
For ease of reproducibility, we suggest you install Miniconda (or Anaconda, if you prefer) before executing the following commands:
```bash
git clone https://github.com/Delay-Xili/LDR
cd LDR
conda create -y -n clt
source activate clt
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
pip install git+https://github.com/kwotsin/mimicry.git
mkdir data logs
```
Note: we highly encourage you to use torch version 1.10.0 or later, since it gives a large speedup when computing `torch.logdet`.
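A quick sanity check of the installed build (nothing repository-specific):

```python
import torch

print(torch.__version__)          # want 1.10.0 or later for the fast torch.logdet
print(torch.cuda.is_available())  # confirm the CUDA build can see a GPU
```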
More installation details can be found here.
To retrain the networks from scratch on your own machine, execute the command for your target dataset:
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --cfg experiments/mnist.yaml DATA.ROOT pth/to/the/dataset
CUDA_VISIBLE_DEVICES=0 python main.py --cfg experiments/tmnist.yaml DATA.ROOT pth/to/the/dataset
CUDA_VISIBLE_DEVICES=0 python main.py --cfg experiments/cifar10.yaml DATA.ROOT pth/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1 python main.py --cfg experiments/stl10.yaml DATA.ROOT pth/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1,2 python main.py --cfg experiments/CelebA.yaml DATA.ROOT pth/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1,2 python main.py --cfg experiments/LSUN.yaml DATA.ROOT pth/to/the/dataset
CUDA_VISIBLE_DEVICES=0,1,2 python main.py --cfg experiments/ImageNet.yaml DATA.ROOT pth/to/the/dataset
```
Some hyper-parameters can be changed directly in the corresponding `xxx.yaml` file.
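The configs appear to follow the usual yacs style, where dotted command-line overrides like `DATA.ROOT` map to nested keys. The excerpt below is hypothetical (only `DATA.ROOT` is confirmed by the commands above); check the actual `experiments/*.yaml` files for the real keys and values.

```yaml
# Hypothetical excerpt -- consult experiments/mnist.yaml for the real keys.
DATA:
  ROOT: ./data   # the key overridden as DATA.ROOT on the command line
```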
We ran the experiments on NVIDIA RTX 3090 GPUs with 24 GB of memory. Adjust the `CUDA_VISIBLE_DEVICES` variable according to the GPUs available on your machine.
You can download our trained models from the following links:
| Datasets | Models | Results |
|---|---|---|
| MNIST | mini dcgan | link |
| TMNIST | mini dcgan | link |
| CIFAR-10 | mini dcgan | link |
| CIFAR-10 | sngan32 | TBD |
| STL-10 | sngan48 | TBD |
| CelebA | sngan128 | link |
| LSUN | sngan128 | link |
| ImageNet | sngan128 | link |
Each link contains the corresponding results, which consist of three items:
- `checkpoints`: all checkpoint files of the generator and discriminator saved during training.
- `images`: all input and reconstructed images saved during training.
- `data`: the TensorBoard log recording the losses and learning rates of the discriminator and generator during training (viewable as shown below).
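To browse those training curves, point `tensorboard` at the downloaded `data` directory:

```bash
tensorboard --logdir pth/to/result/data
```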
To evaluate the FID and IS scores of your checkpoints under `checkpoints/`, execute:
```bash
CUDA_VISIBLE_DEVICES=0 python evaluation.py --cfg experiments/mnist.yaml EVAL.NETD_CKPT path/to/netD/ckpt EVAL.NETG_CKPT path/to/netG/ckpt
CUDA_VISIBLE_DEVICES=0 python evaluation.py --cfg experiments/cifar10.yaml EVAL.NETD_CKPT path/to/netD/ckpt EVAL.NETG_CKPT path/to/netG/ckpt
```
To test the classification accuracy of your learned discriminator, execute:
```bash
CUDA_VISIBLE_DEVICES=0 python test_acc.py --cfg pth/to/mnist/result/config.yaml --ckpt_epochs 4500 EVAL.DATA_SAMPLE 1000
CUDA_VISIBLE_DEVICES=0 python test_acc.py --cfg pth/to/cifar/result/config.yaml --ckpt_epochs 45000 EVAL.DATA_SAMPLE 1000
```
MNIST classification accuracy: 97.69%; CIFAR-10 classification accuracy: 73.05%.
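For reference, classifying with an LDR of this kind is typically done with a nearest-subspace rule on the encoder's features. The sketch below is an assumption about how such a classifier can be built, not the actual code in `test_acc.py`; `nearest_subspace_predict` and `class_bases` are hypothetical names.

```python
import torch

def nearest_subspace_predict(z, class_bases):
    """Predict the class whose subspace captures the most energy of z.

    z: (d,) feature vector produced by the encoder f.
    class_bases: list of (d, k) orthonormal bases, e.g. the top-k left
    singular vectors of each class's training features.
    """
    energies = torch.stack([(U.T @ z).norm() for U in class_bases])
    return int(energies.argmax())
```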
If you find CTRL useful in your research, please consider citing:
```bibtex
@article{dai2021closed,
  title = {CTRL: Closed-Loop Transcription to an LDR via Minimaxing Rate Reduction},
  author = {Dai, Xili and Tong, Shengbang and Li, Mingyang and Wu, Ziyang and Chan, Kwan Ho Ryan and Zhai, Pengyuan and Yu, Yaodong and Psenka, Michael and Yuan, Xiaojun and Shum, Heung Yeung and others},
  journal = {Entropy},
  volume = {24},
  number = {4},
  article-number = {456},
  year = {2022},
  url = {https://www.mdpi.com/1099-4300/24/4/456},
  issn = {1099-4300},
  doi = {10.3390/e24040456}
}
```
See LICENSE for details.