SaSPA: Advancing Fine-Grained Classification by Structure and Subject Preserving Augmentation with Diffusion Models
    git clone https://github.com/EyalMichaeli/SaSPA-Aug.git
    cd SaSPA-Aug
    conda env create -f environment.yml
    conda activate saspa

For a quick setup, we will run on the planes dataset. The dataset is downloaded automatically.
Download the planes.pth checkpoint from Google Drive into all_utils/checkpoints/planes
Run run_aug/run_aug.py. This script generates the augmented data and filters it at the end. The augmentations will be saved to data/FGVC-Aircraft/fgvc-aircraft-2013b/data/aug_data/controlnet/sd_v1.5/canny/gpt-meta_class_prompt_w_sub_class_artistic_prompts_p_0.5_seed_1, together with a log file. Once generation is done, a JSON file is written that maps each original file path to its augmentations.
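To sanity-check the output before training, you can load this JSON and inspect a few entries. This is a minimal sketch; it assumes the file is a mapping from original image paths to lists of augmentation paths, as described above, and the exact schema may differ:

    import json

    # Path to the JSON written at the end of generation; fill in your own path.
    json_path = "path/to/aug.json"

    with open(json_path) as f:
        mapping = json.load(f)  # assumed schema: {original_image_path: [augmented_image_paths, ...]}

    print(f"{len(mapping)} original images have augmentations")
    orig, augs = next(iter(mapping.items()))
    print(orig, "->", augs)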
Once you have the JSON file, copy its path to fgvc/trainings_scripts/consecutive_runs_aug.sh, under the variable aug_json. Then run:

    bash fgvc/trainings_scripts/consecutive_runs_aug.sh

That's it!
You should see your training start at <repo_path>/logs/<dataset_name>/.
- Aircraft, Cars, and DTD: downloaded automatically via torchvision to the local folder data/<dataset_name>.
- CUB: download from Caltech-UCSD Birds-200-2011 to data/CUB. Ensure the folder structure is as follows: <repo_path>/CUB/CUB_200_2011/..., including folders such as 'images' and 'parts'.
- CompCars: download from the CompCars dataset page to data/compcars. Ensure the folder structure is as follows: <repo_path>/data/compcars/..., including folders such as 'image', 'label' and 'train_test_split'.
- Stanford Cars: TorchVision no longer supports automatic download (details). Download the dataset manually via Kaggle, or with the Kaggle API:

      # Configure your Kaggle API key first: https://www.kaggle.com/docs/api
      import kaggle

      kaggle.api.dataset_download_files(
          'rickyyyyyyy/torchvision-stanford-cars',
          path='data/stanford_cars',
          unzip=True,
      )

  Then download cars_test_annos_withlabels.mat from here and place it in data/stanford_cars/stanford_cars/. Expected layout:

      data/stanford_cars/stanford_cars/
      ├── cars_train/
      ├── cars_test/
      ├── devkit/
      └── cars_test_annos_withlabels.mat
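A quick way to verify this layout is a small check like the following (a minimal sketch that uses only the paths listed above):

    from pathlib import Path

    # Check the Stanford Cars layout described above.
    root = Path("data/stanford_cars/stanford_cars")
    expected = ["cars_train", "cars_test", "devkit", "cars_test_annos_withlabels.mat"]
    missing = [name for name in expected if not (root / name).exists()]
    print("missing:", missing or "none")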
If the original dataset does not include a validation set, file-name splits are provided in fgvc/datasets_files and are loaded automatically.
In our experiments, we use Weights & Biases (wandb) to monitor training. The training script connects to wandb automatically; to disable this, set the DONT_WANDB variable in train.py to True.
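For reference, the guard might look roughly like this sketch; only the DONT_WANDB name and its effect come from this README, and the surrounding code is an assumption:

    # Sketch of the wandb guard in train.py; everything except DONT_WANDB is assumed.
    DONT_WANDB = True  # set to True to disable all wandb logging

    if not DONT_WANDB:
        import wandb
        wandb.init(project="saspa")  # project name here is a placeholder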
For our filtering, we need a baseline model trained on the original dataset. We provide pre-trained checkpoints for each dataset used in our paper on Google Drive; either download one or train a baseline model yourself, and move the checkpoint to the folder all_utils/checkpoints/<dataset_name>.
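To confirm a checkpoint is in place and loadable, something like this works (a minimal sketch using the planes checkpoint from the quick setup; the checkpoint's internal structure is not assumed):

    from pathlib import Path
    import torch

    # Checkpoint placed per the instructions above (planes used as an example).
    ckpt_path = Path("all_utils/checkpoints/planes/planes.pth")
    assert ckpt_path.exists(), f"missing checkpoint: {ckpt_path}"
    state = torch.load(ckpt_path, map_location="cpu")
    print(type(state))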
To create the prompts using GPT-4, follow the instructions in the paper.
The generated prompts should be placed in prompts_engineering/gpt_prompts, which currently contains the prompts we generated.
The generation code is located at: run_aug/run_aug.py.
Choose a dataset and ensure that BASE_MODEL = "blip_diffusion" and CONTROLNET = "canny". Other base models such as sd_v1.5 or sd_xl-turbo can also be used. BASE_MODEL is currently set to "sd_v1.5" because it works better for the Aircraft dataset; for all other datasets, set BASE_MODEL = "blip_diffusion".
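Concretely, the relevant settings in run_aug/run_aug.py look roughly like this (the variable names and values come from this README; the comments are ours):

    # Settings in run_aug/run_aug.py (names and values as described above).
    BASE_MODEL = "blip_diffusion"  # or "sd_v1.5" (better for Aircraft), "sd_xl-turbo"
    CONTROLNET = "canny"           # edge-based conditioning used for structure preservation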
The code will generate augmentations in the folder <dataset_root>/aug_data and then automatically create a JSON file with the filtered augmentations.
Once you have the JSON file, copy its path to fgvc/trainings_scripts/consecutive_runs_aug.sh, under the variable aug_json.
Make sure the correct dataset is specified in the dataset variable and fill in the remaining arguments (GPU ID, run_name, etc.). The appropriate training arguments, such as the augmentation ratio and the traditional augmentations used, are currently chosen automatically in the script based on the dataset name.
After the script arguments are ready, run:

    bash fgvc/trainings_scripts/consecutive_runs_aug.sh

That's it!
You should see your training start at <repo_path>/logs/<dataset_name>/.
To incorporate new datasets into the project, follow these steps:

- Prompt Creation: Generate new prompts and add them to prompts_engineering/gpt_prompts.
- Dataset Class: Add a new dataset class within all_utils/dataset_utils.py to manage dataset-specific functionality (see the sketch after this list).
- Dataset Module Implementation: Add a new Python file in the fgvc/datasets folder.
- Dataset Config: Add a new Python file with training hyper-parameters in the fgvc/configs folder.
- Baseline Model Training: Train a baseline model to ensure the new dataset is correctly integrated and functional. This model will also be used in the filtering process.
- Follow Standard Procedures: Proceed with the regular augmentation and training workflows as documented in Running the Code.
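To make the dataset-class step concrete, here is a purely hypothetical skeleton; every name in it is illustrative, so mirror the existing classes in all_utils/dataset_utils.py rather than this sketch:

    from pathlib import Path

    # Hypothetical skeleton for a new dataset class in all_utils/dataset_utils.py.
    # All names are illustrative; copy the structure of the existing classes instead.
    class MyNewDataset:
        def __init__(self, root: str, split: str = "train"):
            self.root = Path(root)   # e.g. data/my_new_dataset
            self.split = split       # "train" / "val" / "test"
            self.samples = self._load_split()

        def _load_split(self):
            """Return a list of (image_path, class_label) pairs for this split."""
            split_file = self.root / f"{self.split}.txt"  # hypothetical split file
            pairs = []
            for line in split_file.read_text().splitlines():
                path, label = line.rsplit(" ", 1)
                pairs.append((self.root / path, int(label)))
            return pairs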
If you find our work useful, we welcome citations:
@inproceedings{
michaeli2024advancing,
title={Advancing Fine-Grained Classification by Structure and Subject Preserving Augmentation},
author={Eyal Michaeli and Ohad Fried},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=MNg331t8Tj}
}

Note: you might need to re-install PyTorch according to your server's setup.
We extend our gratitude to the following resources for their significant contributions to our project:
- CAL Repository: Visit CAL Repo for more details.
- Diffusers Package: Learn more about the Diffusers package at Hugging Face Diffusers Documentation.
