PrasannaPulakurthi/MMD-AdversarialNAS

ENHANCING GAN PERFORMANCE THROUGH NEURAL ARCHITECTURE SEARCH AND TENSOR DECOMPOSITION (MMD-AdversarialNAS)

Code for our ICASSP 2024 paper "Enhancing GAN Performance Through Neural Architecture Search and Tensor Decomposition".

by Prasanna Reddy Pulakurthi, Mahsa Mozaffari, Sohail A. Dianat, Majid Rabbani, Jamison Heard, and Raghuveer Rao.

Please consider citing our paper in your publications if it helps your research. The following is a BibTeX reference.

@INPROCEEDINGS{10446488,
  author={Pulakurthi, Prasanna Reddy and Mozaffari, Mahsa and Dianat, Sohail A. and Rabbani, Majid and Heard, Jamison and Rao, Raghuveer},
  booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 
  title={Enhancing GAN Performance Through Neural Architecture Search and Tensor Decomposition}, 
  year={2024},
  volume={},
  number={},
  pages={7280-7284},
  keywords={Training;Performance evaluation;Tensors;Image coding;Image synthesis;Image edge detection;Computer architecture;Neural Architecture Search;Maximum Mean Discrepancy;Generative Adversarial Networks},
  doi={10.1109/ICASSP48485.2024.10446488}
}

Qualitative Results

[Figure: All Visual Results]

Quantitative Results

[Figure: Quantitative Results]

Repeatability

[Figure: Reproducibility Results]

Getting Started

Installation

  1. Clone this repository.

    git clone https://github.com/PrasannaPulakurthi/MMD-AdversarialNAS.git
    cd MMD-AdversarialNAS
    
  2. Install requirements using Python 3.9.

    conda create -n mmd-nas python=3.9
    conda activate mmd-nas
    pip install -r requirements.txt
    
  3. Install PyTorch 1 and TensorFlow 2 with CUDA.

    pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
    

    To install a different PyTorch version compatible with your CUDA setup, see the PyTorch installation instructions.

    For TensorFlow, see the TensorFlow installation instructions.
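
After installation, a quick sanity check can confirm that the packages are visible to the active environment. The following is a minimal sketch, not part of the repository: the helper name `check_packages` is hypothetical, and the package list is an assumption based on the install steps above.

```python
import importlib.util


def check_packages(names=("torch", "torchvision", "tensorflow")):
    """Return a dict mapping each package name to True if it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    status = check_packages()
    for name, found in status.items():
        print(f"{name}: {'found' if found else 'MISSING'}")
    # Only import torch if it is installed, then report CUDA visibility.
    if status["torch"]:
        import torch
        print("CUDA available:", torch.cuda.is_available())
```

If `CUDA available` prints `False`, the installed PyTorch build probably does not match the local CUDA driver.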

Preparing necessary files

Files can be found in Google Drive.

  1. Download the pre-trained models to ./exps

  2. Download the pre-calculated statistics to ./fid_stat for calculating the FID.
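
For reference, the FID (Fréchet Inception Distance) is conventionally defined as below. Here $\mu_r, \Sigma_r$ are the mean and covariance of Inception features of real images and $\mu_g, \Sigma_g$ those of generated images. That the files in ./fid_stat hold the real-data statistics is an assumption based on the description above, not something verified against the code.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower values indicate that the generated feature distribution is closer to the real one.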

Instructions for Testing, Training, Searching, and Compressing the Model.

Testing

  1. Download the trained generative models from Google Drive to ./exps/arch_train_cifar10_large/Model

    mkdir -p exps/arch_train_cifar10_large/Model
    
  2. To test the trained model, run the command found in scripts/test_arch_cifar10.sh

    python MGPU_test_arch.py --gpu_ids 0 --num_workers 8 --dataset cifar10 --bottom_width 4 --img_size 32 --arch arch_cifar10 --draw_arch False --checkpoint arch_train_cifar10_large --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 256 --num_eval_imgs 50000 --eval_batch_size 100 --exp_name arch_test_cifar10_large
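
Before launching the test command, it may help to confirm that the downloaded files are where the steps above place them. This is a minimal sketch, assuming the directory layout described in this README; the helper `missing_paths` is hypothetical and not part of the repository.

```python
from pathlib import Path


def missing_paths(root=".", checkpoint="arch_train_cifar10_large"):
    """Return the expected directories that do not exist under `root`."""
    root = Path(root)
    expected = [
        root / "exps" / checkpoint / "Model",  # trained generator weights
        root / "fid_stat",                     # pre-calculated FID statistics
    ]
    return [str(p) for p in expected if not p.is_dir()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing directories:", ", ".join(missing))
    else:
        print("All expected directories are present.")
```

The check only looks at directories; it does not validate the contents of the checkpoint or statistics files.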
    

Training

  1. Train the weights of the generative model with the searched architecture (the architecture is saved in ./exps/arch_cifar10/Genotypes/latest_G.npy). Run the command found in scripts/train_arch_cifar10_large.sh

    python MGPU_train_arch.py --gpu_ids 0 --num_workers 8 --gen_bs 128 --dis_bs 128 --dataset cifar10 --bottom_width 4 --img_size 32 --max_epoch_G 500 --n_critic 1 --arch arch_cifar10 --draw_arch False --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 256 --df_dim 512 --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --val_freq 5 --num_eval_imgs 50000 --exp_name arch_train_cifar10_large
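
When running several training configurations, it can be convenient to build the command from a dictionary instead of editing one long line. The sketch below reproduces the flags of the command above; the helper `build_command` is hypothetical and not part of the repository.

```python
import shlex


def build_command(script, options):
    """Turn a {flag: value} mapping into an argv list for `python <script>`."""
    argv = ["python", script]
    for flag, value in options.items():
        argv += [f"--{flag}", str(value)]
    return argv


# Flags copied from the training command in this README.
train_options = {
    "gpu_ids": 0, "num_workers": 8, "gen_bs": 128, "dis_bs": 128,
    "dataset": "cifar10", "bottom_width": 4, "img_size": 32,
    "max_epoch_G": 500, "n_critic": 1, "arch": "arch_cifar10",
    "draw_arch": False, "genotypes_exp": "arch_cifar10",
    "latent_dim": 120, "gf_dim": 256, "df_dim": 512,
    "g_lr": 0.0002, "d_lr": 0.0002, "beta1": 0.0, "beta2": 0.9,
    "init_type": "xavier_uniform", "val_freq": 5,
    "num_eval_imgs": 50000, "exp_name": "arch_train_cifar10_large",
}

cmd = build_command("MGPU_train_arch.py", train_options)
print(shlex.join(cmd))  # prints the command as a single shell-safe line
```

To actually start training from Python, the list could be passed to `subprocess.run(cmd, check=True)` from the repository root.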
    

Searching the Architecture

  1. To use AdversarialNAS to search for the best architecture, run the command found in scripts/search_arch_cifar10.sh

    python MGPU_search_arch.py --gpu_ids 0 --gen_bs 128 --dis_bs 128 --dataset cifar10 --bottom_width 4 --img_size 32 --max_epoch_G 25 --arch search_both_cifar10 --latent_dim 120 --gf_dim 160 --df_dim 80 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 5 --derive_freq 1 --derive_per_epoch 16 --draw_arch False --exp_name search/bs120-dim160 --num_workers 8 --gumbel_softmax True
    

Compression

  1. Compress and fine-tune all the convolutional layers except 9 and 13.

    python MGPU_cpcompress_arch.py --gpu_ids 0 --num_workers 8 --dataset cifar10 --bottom_width 4 --img_size 32 --arch arch_cifar10 --draw_arch False --genotypes_exp arch_cifar10  --latent_dim 120 --gf_dim 256 --df_dim 512 --num_eval_imgs 50000 --eval_batch_size 100 --checkpoint arch_train_cifar10_large  --exp_name compress_train_cifar10_large --val_freq 5  --gen_bs  128 --dis_bs 128 --beta1 0.0 --beta2 0.9  --byrank --rank 256 --layers cell1.c0.ops.0.op.1 cell1.c1.ops.0.op.1 cell1.c2.ops.0.op.1 cell1.c3.ops.0.op.1 cell2.c0.ops.0.op.1 cell2.c2.ops.0.op.1 cell2.c3.ops.0.op.1 cell2.c4.ops.0.op.1 cell3.c1.ops.0.op.1 cell3.c2.ops.0.op.1 cell3.c3.ops.0.op.1 --compress-mode "allatonce" --max_epoch_G 500 --eval_before_compression
    
  2. Compress the fully connected layers except l1.

    python MGPU_cpcompress_arch.py --gpu_ids 0 --num_workers 8 --dataset cifar10 --bottom_width 4 --img_size 32 --arch arch_cifar10 --draw_arch False --genotypes_exp arch_cifar10  --latent_dim 120 --gf_dim 256 --df_dim 512 --num_eval_imgs 50000 --eval_batch_size 100 --checkpoint compress_train_cifar10_large  --exp_name compress_train_cifar10_large --val_freq 5  --gen_bs  128 --dis_bs 128 --beta1 0.0 --beta2 0.9  --byrank --rank 4 --layers l2 l3 --freeze_layers l2 l3 --compress-mode "allatonce" --max_epoch_G 1 --eval_before_compression
    
  3. To test the compressed network, download the compressed model from Google Drive to ./exps/compress_train_cifar10_large/Model

    python MGPU_test_cpcompress.py --gpu_ids 0 --num_workers 8 --dataset cifar10 --bottom_width 4 --img_size 32 --arch arch_cifar10 --draw_arch False --checkpoint compress_train_cifar10_large --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 256 --num_eval_imgs 50000 --eval_batch_size 100 --exp_name compress_test_cifar10_large  --byrank
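
To get a feel for what `--rank` controls, the following back-of-the-envelope sketch counts parameters before and after a rank-R CP (CANDECOMP/PARAFAC) factorization of a convolution kernel. It assumes the "cp" in the script name refers to CP decomposition and that each of the four kernel modes gets one factor matrix; the kernel shape used here is hypothetical and not read from the repository.

```python
from math import prod


def dense_params(shape):
    """Number of weights in a dense kernel of the given shape."""
    return prod(shape)


def cp_params(shape, rank):
    """Number of weights in a rank-`rank` CP factorization:
    one (dim x rank) factor matrix per tensor mode."""
    return rank * sum(shape)


# Hypothetical 3x3 convolution with 256 input and 256 output channels.
shape = (256, 256, 3, 3)
rank = 256  # value passed as --rank in the first compression step

print(dense_params(shape))     # 589824
print(cp_params(shape, rank))  # 132608
```

Under these assumptions the factorized layer has roughly 4.4 times fewer weights; the actual saving depends on the real layer shapes and on how the script applies the decomposition.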
    

Acknowledgement

The codebase builds on AdversarialNAS, TransGAN, and TensorLy.
