Official Implementation
for NAACL 2024 Findings paper PEEB: Part-based Image Classifiers with an Explainable and Editable Language Bottleneck
by Thang M. Pham, Peijie Chen, Tin Nguyen, Seunghyun Yoon, Trung Bui, Anh Totti Nguyen.
TLDR: We propose a part-based fine-grained image classifier (identifying 🐦 or 🐕) that makes predictions by matching visual object parts detected in the input image with textual part-wise descriptions of each class. Our method directly utilizes human-provided descriptions (in this work, from GPT-4). PEEB outperforms CLIP (2021) and M&V (2023) by +10 points on CUB and +28 points on NABirds in generalized zero-shot settings. While CLIP and its extensions depend heavily on the prompt containing the known class name, PEEB relies solely on the textual descriptors and does not use class names, and therefore generalizes to objects whose class names or exemplar photos are unknown.
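To make the matching step concrete, here is a minimal, hypothetical sketch (not the repository's actual code) of how part-wise similarities could be turned into class logits: each detected part embedding is compared with the corresponding part-descriptor embedding of every class, and the per-part similarities are aggregated into one score per class. The function and tensor names are illustrative only.

import torch

def part_matching_scores(part_embeds, text_embeds):
    """Hypothetical part-to-descriptor matching.

    part_embeds: (num_parts, dim)              visual embeddings of detected parts
    text_embeds: (num_classes, num_parts, dim) embeddings of each class's
                                               part-wise descriptors (e.g., from GPT-4)
    Returns: (num_classes,) one logit per class.
    """
    part_embeds = torch.nn.functional.normalize(part_embeds, dim=-1)
    text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
    # Cosine similarity between each part and its matching descriptor,
    # averaged over parts to give a class score.
    sims = (text_embeds * part_embeds.unsqueeze(0)).sum(dim=-1)  # (num_classes, num_parts)
    return sims.mean(dim=-1)

Because the scores depend only on the descriptor embeddings, changing a class's descriptors directly changes the classifier, with no retraining.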
If you use this software, please consider citing:
@inproceedings{auburn2024peeb,
  title     = "{PEEB}: Part-based Image Classifiers with an Explainable and Editable Language Bottleneck",
  author    = "Pham, Thang M and Chen, Peijie and Nguyen, Tin and Yoon, Seunghyun and Bui, Trung and Nguyen, Anh",
  booktitle = "Findings of the Association for Computational Linguistics: NAACL 2024",
  month     = jul,
  year      = "2024",
  address   = "Mexico City, Mexico",
  publisher = "Association for Computational Linguistics"
}
Interactive Demo that shows how one can edit a class' descriptors (during inference) to directly modify the classifier.
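Since the classifier is driven entirely by text, editing a class amounts to editing its part descriptors. A hypothetical illustration (the demo's actual descriptor format may differ):

# Hypothetical example of editing a class's part-wise descriptors at inference time.
descriptors = {
    "Painted Bunting": {
        "head": "bright blue head",
        "back": "green back",
        "belly": "red belly",
    }
}

# Edit one descriptor: the classifier now looks for a different visual cue
# for this class, without any retraining.
descriptors["Painted Bunting"]["head"] = "dark blue head"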
We provide our pre-trained models as well as all fine-tuned models used in this work. See here for details.
git clone https://github.com/{username}/xclip && cd xclip
conda create -n xclip python=3.10
conda activate xclip
pip install -r requirements.txt
conda install pycocotools -c conda-forge
python -m spacy download en_core_web_sm
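After installation, a quick import check can confirm the environment is ready. This snippet only assumes the dependencies installed above (e.g., that requirements.txt includes PyTorch and spaCy; pycocotools and the spaCy model come from the last two commands):

# Optional sanity check: verify key dependencies are importable.
import torch
import spacy
import pycocotools  # noqa: F401

print("torch", torch.__version__)
print("spacy model:", spacy.load("en_core_web_sm").meta["name"])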
We do not redistribute the datasets; we provide metadata of the combined dataset. If you would like to train the model, please follow Data preparation to prepare the Bird-11K dataset.
Teacher logits are required for training. Please make sure to finish all steps in Data preparation before starting training.
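Once data preparation is finished, you can sanity-check that a prepared metadata file is readable before launching a run. The snippet below only lists the keys stored in one of the .h5 files referenced by the training commands; it makes no assumption about their internal schema:

import h5py

# Path taken from the Step-1 training command below; adjust to your setup.
META = "../data/bird_11K/metadata/level_1_exclude_cub_nabirds_inat/train_keep_child_a100_reindexed.h5"

with h5py.File(META, "r") as f:
    f.visit(print)  # print every group/dataset name stored in the file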
- Before running experiments, we need to export the PYTHONPATH variable to the current directory and its sub-directories and change the directory to the src folder as follows.
# Assume the current directory is xclip
export PYTHONPATH=$(pwd):$(pwd)/src:$(pwd)/mmdetection
cd src/
Make the following changes to src/config.py (see the example below):
- Set DATASET_DIR to {DATA_PATH}
- Set PRECOMPUTED_DIR to {DATA_PATH}
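For example, assuming your data root is /home/user/data (substitute your own {DATA_PATH}), the two lines in src/config.py would look roughly like this:

# src/config.py (excerpt) -- replace the path with your own {DATA_PATH}
DATASET_DIR = "/home/user/data"
PRECOMPUTED_DIR = "/home/user/data"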
Our full training commands can be found here.
We have two steps in our pre-training.
Step 1: Part-text mapping, e.g., for pre-training with the test sets of CUB, NABirds, and iNaturalist excluded:
python train_owl_vit.py --model owlvit-base-patch32 --dataset bird_soup --sub_datasets all --descriptors chatgpt --prompt_type 0 --batch_size 32 --batch_size_val 50 --save_freq 1 --num_workers 8 --devices 0 --epochs 32 --lr 0.0002 --weight_decay 0.01 --project_name stage1_pretraining --loss_weights 0,0,0,0,1 --network_type contrastive --freeze_box_heads --logits_from_teacher --num_negatives_train 48 --num_negatives_val 50 --early_stopping 5 --train_file "../data/bird_11K/metadata/level_1_exclude_cub_nabirds_inat/train_keep_child_a100_reindexed.h5" --val_file "../data/bird_11K/metadata/level_1_exclude_cub_nabirds_inat/val_keep_child_a100_reindexed.h5" --test_file "../data/bird_11K/metadata/level_1_exclude_cub_nabirds_inat/test_cub_reindexed.h5" --birdsoup_level 1 --note "stage1_pretraining_BIRD-11K_test"
See 1st_training.sh for other commands.
Step 2: Box prediction calibration, e.g.,
python train_owl_vit.py --model owlvit-base-patch32 --dataset bird_soup --sub_datasets all --descriptors chatgpt --prompt_type 0 --batch_size 32 --batch_size_val 50 --save_freq 1 --num_workers 8 --devices 0 --epochs 32 --lr 0.00002 --weight_decay 0.01 --project_name stage2_pretraining --loss_weights 0,1,1,2,0 --network_type contrastive --train_box_heads_only --num_negatives_train 48 --num_negatives_val 50 --early_stopping 5 --train_file "../data/bird_11K/metadata/level_1_exclude_cub_nabirds_inat/train_keep_child_a100_reindexed.h5" --val_file "../data/bird_11K/metadata/level_1_exclude_cub_nabirds_inat/val_keep_child_a100_reindexed.h5" --test_file "../data/bird_11K/metadata/level_1_exclude_cub_nabirds_inat/test_cub_reindexed.h5" --best_model "" --birdsoup_level 1 --note "stage2_pretraining_BIRD-11K_test"
See 2nd_training.sh for other commands.
To fine-tune the model on a specific downstream dataset, e.g., CUB:
python train_owl_vit.py --model owlvit-base-patch32 --dataset bird_soup --sub_datasets all --descriptors chatgpt --prompt_type 0 --batch_size 32 --save_freq 1 --num_workers 8 --devices 0 --epochs 30 --lr 0.00002 --weight_decay 0.001 --project_name finetuning --loss_weights 0,1,1,1,1 --network_type classification --classification_loss ce_loss --early_stopping 5 --train_file "../data/bird_11K/metadata/finetuning/cub_train_reindexed.h5" --val_file "../data/bird_11K/metadata/finetuning/cub_val_reindexed.h5" --test_file "../data/bird_11K/metadata/finetuning/cub_test_reindexed.h5" --best_model "" --birdsoup_level 1 --finetuning "vision_encoder_mlp" --note "all_components_cub_200"
See 3rd_training_cub_200.sh and 3rd_training_zeroshot.sh for fine-tuning on CUB and the other zero-shot datasets.
For evaluation only, you may run the following script (for CUB):
python train_owl_vit.py --model owlvit-base-patch32 --dataset bird_soup --sub_datasets all --descriptors chatgpt --prompt_type 0 --batch_size 32 --num_workers 8 --devices 0 --loss_weights 0,1,1,1,1 --network_type classification --eval_test --no_log --test_file "../data/bird_11K/metadata/finetuning/cub_test_reindexed.h5" --best_model "" --birdsoup_level 1
Evaluation commands for other test sets can be found here.