
Diversity is Definitely Needed: Improving Model-Agnostic Zero-shot Classification via Stable Diffusion (CVPRW)

Jordan Shipard1, Arnold Wiliem1,2, Kien Nguyen Thanh1, Wei Xiang3, Clinton Fookes1

1Signal Processing, Artificial Intelligence and Vision Technologies (SAIVT), Queensland University of Technology, Australia
2Sentient Vision Systems, Australia
3School of Computing, Engineering and Mathematical Sciences, La Trobe University, Australia

Accepted to the Generative Models for Computer Vision Workshop at CVPR 2023

Model-Agnostic Zero-Shot Classification

Requirements

For training and testing

  • Python 3.8+
  • PyTorch 1.13.0
  • Torchvision 0.14.0

For dataset generation

  • A local copy of the Stable Diffusion (v1.4) repository, since create_dataset.py calls its scripts/txt2img.py

Synthetic Datasets

Datasets are hosted on Zenodo with the download links provided in the table below.

Dataset                 | Download
CIFAR10 Base Class      | cifar10_generated_32A.tar.gz
CIFAR10 Class Prompt    | cifar10_generated_class_prompt_32A.tar.gz
CIFAR10 Multi-Domain    | cifar10_generated_multidomain_32A.tar.gz
CIFAR10 Random Guidance | cifar10_generated_random_scale_32A.tar.gz
CIFAR10 Merged          | cifar10_generated_merged_32A.tar.gz
CIFAR100 Base Class     | cifar100_generated_32A.tar.gz
CIFAR100 Multi-Domain   | cifar100_generated_multidomain_32A.tar.gz
CIFAR100 Random Scale   | cifar100_generated_random_scale_32A.tar.gz
CIFAR100 Merged         | Cifar100_generated_merged_32A.tar.gz
EuroSAT Base Class      | EuroSat_generated_64.tar.gz
EuroSAT Random Scale    | EuroSat_generated_random_scale_64.tar.gz
EuroSAT Merged          | EuroSat_generated_merged_64.tar.gz

These are the exact generated synthetic datasets and images used to train the networks in the paper. All datasets were generated using Stable Diffusion v1.4. '32A' denotes an image size of 32x32 pixels, resized from 512x512 with anti-aliasing; '64' denotes 64x64, resized from 512x512 without anti-aliasing. Only the datasets that improve performance over the base class (i.e. the best tricks) are currently hosted. If you would like any of the other datasets from the paper, either raise an issue or email me at jordan.shipard@hdr.qut.edu.au.
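The difference between '32A' and '64' above comes down to how the 512x512 outputs are downsampled. The sketch below (pure Python, no imaging library assumed; a toy 4x4 image rather than 512x512) contrasts nearest-neighbour sampling with a simple box filter, which approximates anti-aliased resizing:

```python
def downsample_nearest(img, factor):
    """Pick one source pixel per output pixel (no anti-aliasing)."""
    return [row[::factor] for row in img[::factor]]

def downsample_box(img, factor):
    """Average each factor x factor block (a simple anti-aliasing filter)."""
    h, w = len(img), len(img[0])
    out = []
    for y in range(0, h, factor):
        row = []
        for x in range(0, w, factor):
            block = [img[y + dy][x + dx]
                     for dy in range(factor) for dx in range(factor)]
            row.append(sum(block) / len(block))
        out.append(row)
    return out

# A 512 -> 32 downsample uses factor 16; here a 4 -> 2 toy example
# on a high-frequency checkerboard, where the two methods differ most.
img = [[0, 255, 0, 255],
       [255, 0, 255, 0],
       [0, 255, 0, 255],
       [255, 0, 255, 0]]
print(downsample_nearest(img, 2))  # keeps raw pixels: [[0, 0], [0, 0]]
print(downsample_box(img, 2))      # averages blocks: [[127.5, 127.5], [127.5, 127.5]]
```

On high-frequency content like this, nearest-neighbour sampling discards most of the signal (aliasing), while the box filter preserves the average intensity, which is why the anti-aliased 32A variants are distinguished by name.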

Usage

How to generate your own synthetic datasets

You can generate your own synthetic datasets using one of the tricks with the create_dataset.py file. First, ensure the file is located in the same directory as your Stable Diffusion repository, as it will attempt to run scripts/txt2img.py.
e.g. python create_dataset.py --classes dog cat --trick class_prompt --outdir synthetic_datasets/cats_and_dogs

This file has the following arguments:

General arguments
  • --classes A list of the class labels used in generating images.
  • --trick The specific trick you wish to use when generating the dataset. Limited to "class_prompt", "multidomain", "random_scale".
  • --outdir The directory to save the generated images in.
Multi-domain arguments
  • --domains A list of domains to use when generating the synthetic images.
Random scale arguments
  • --min_scale The minimum possible value for the unconditional random guidance. Default 1.
  • --max_scale The maximum possible value for the unconditional random guidance. Default 5.
Stable Diffusion arguments
  • --n_samples The number of images to produce in a single round of generation. Default 2.
  • --n_iter The number of iterations of producing n_samples images to run. Default 1000.
  • --ddim_steps The number of DDIM sampling steps. Default 40.
  • --seed The seed (for reproducible sampling). Default 64.
  • --H The height of the images to generate. Default 512.
  • --W The width of the images to generate. Default 512.
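How the three tricks map onto a txt2img call can be sketched as below. Note the prompt templates and the 7.5 default guidance scale are assumptions for illustration, not lifted from create_dataset.py:

```python
import random

def build_generation_args(cls, trick, domains=None,
                          min_scale=1.0, max_scale=5.0, rng=random):
    """Sketch of how each trick might alter the prompt / guidance scale.

    ASSUMPTION: the prompt templates and default scale below are
    illustrative, not the exact ones used by create_dataset.py.
    """
    prompt, scale = cls, 7.5  # plain class label; an assumed default guidance
    if trick == "class_prompt":
        prompt = f"a photo of a {cls}"            # assumed template
    elif trick == "multidomain":
        prompt = f"a {rng.choice(domains)} of a {cls}"  # assumed template
    elif trick == "random_scale":
        scale = rng.uniform(min_scale, max_scale)  # random unconditional guidance
    return prompt, scale

print(build_generation_args("dog", "class_prompt"))
```

The key point the tricks share is diversity: class_prompt and multidomain vary the text conditioning, while random_scale varies how strongly the sampler follows it.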

How to train on the synthetic dataset

You can train a network on a synthetic dataset while testing it on a real dataset using train_network.py.
e.g. python train_network.py --model Vit-B --epoch 10 --batch_size 32 --dataset cifar10_generated_32A --syn_data_location data/synthetic_cifar10 --real_data_location data/real_cifar10

Arguments
  • --model The image classification model to train. Limited to MBV3, Vit-B, Vit-S, RS50, RS101, convnext, convnext-s.
  • --epoch Default 50. The number of epochs to train for.
  • --batch_Size Default 64. The batch size to use for training and testing.
  • --dataset The name of the synthetic dataset to use, e.g. cifar100_generated_32A. NOTE: The dataset needs to already be extracted from its .tar.gz compressed version if downloaded from one of the above links.
  • --img_size Default 32. The image size of the dataset; can be used to resize the dataset.
  • --lr Default 1e-4. The initial learning rate.
  • --wd Default 0.9. The weight decay used in training.
  • --model_path Optional. Path to a PyTorch model; the script will load the weights from this path.
  • --wandb Default False. This flags the use of the wandb logger. NOTE: Please check the init settings for the logger variable inside the training script if you wish to use the wandb logger.
  • --syn_data_location The location of the synthetic dataset.
  • --real_data_location The location of the real dataset.
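The documented interface can be sketched as a stand-alone argparse parser. This mirrors the flags and defaults listed above (including the --batch_Size capitalization as documented); it is a sketch of the interface, not the actual train_network.py:

```python
import argparse

def build_train_parser():
    """Parser mirroring the documented train_network.py flags (a sketch)."""
    p = argparse.ArgumentParser(
        description="Train on a synthetic dataset, test on a real dataset")
    p.add_argument("--model", choices=["MBV3", "Vit-B", "Vit-S", "RS50",
                                       "RS101", "convnext", "convnext-s"])
    p.add_argument("--epoch", type=int, default=50)
    p.add_argument("--batch_Size", type=int, default=64)
    p.add_argument("--dataset")                       # e.g. cifar100_generated_32A
    p.add_argument("--img_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=0.9)
    p.add_argument("--model_path", default=None)      # optional pretrained weights
    p.add_argument("--wandb", action="store_true")    # default False
    p.add_argument("--syn_data_location")
    p.add_argument("--real_data_location")
    return p

args = build_train_parser().parse_args(["--model", "Vit-B", "--epoch", "10"])
print(args.epoch, args.wd)  # 10 0.9
```

Flags not supplied on the command line fall back to the documented defaults, which is why the example command above only needs to override a few of them.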

How to test on the real dataset

You can test networks using eval_network.py.
e.g. python eval_network.py --model Vit-B --model_path saved_models/trained_model.pt --dataset cifar10 --real_data_location data/real_cifar10

Arguments
  • --model The image classification model to evaluate. Limited to MBV3, Vit-B, Vit-S, RS50, RS101, convnext, convnext-s.
  • --batch_Size Default 64. The batch size to use for testing.
  • --dataset The name of the real dataset to use, e.g. cifar100.
  • --img_size Default 32. The image size of the dataset; can be used to resize the dataset.
  • --real_data_location The location of the real dataset.
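The evaluation reduces to top-1 accuracy on the real test set. A minimal sketch of that metric (pure Python; in practice the predictions would be the argmax over the model's class logits):

```python
def top1_accuracy(predictions, labels):
    """Fraction of samples where the predicted class matches the true label."""
    assert len(predictions) == len(labels) and labels
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)

preds  = [3, 1, 7, 7, 0]   # e.g. argmax of class logits per test image
labels = [3, 1, 7, 2, 0]   # ground-truth classes from the real dataset
print(top1_accuracy(preds, labels))  # 0.8
```

This is the number that lets synthetic-only training be compared against the zero-shot baselines in the paper: the network never sees real images during training, only at this evaluation step.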

Acknowledgements

This work has been supported by the SmartSat CRC, whose activities are funded by the Australian Government’s CRC Program; and partly supported by Sentient Vision Systems. Sentient Vision Systems is one of the leading Australian developers of computer vision and artificial intelligence software solutions for defence and civilian applications.

Citation

@inproceedings{
  shipard2023DDN,
  title={Diversity is Definitely Needed: Improving Model-Agnostic Zero-shot Classification via Stable Diffusion},
  author={Jordan Shipard and Arnold Wiliem and Kien Nguyen Thanh and Wei Xiang and Clinton Fookes},
  booktitle={Computer Vision and Pattern Recognition Workshop on Generative Models for Computer Vision},
  year={2023},
  url={https://arxiv.org/abs/2302.03298}
}
