
TriCoLo


This repository is the official implementation of TriCoLo: Trimodal Contrastive Loss for Text to Shape Retrieval (WACV 2024).

(Paper) (Project Page)

Setup

Conda (recommended)

We recommend using miniconda to manage system dependencies.

# create and activate the conda environment
conda create -n tricolo python=3.10
conda activate tricolo

# install PyTorch 2.0.1
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

# install Python libraries
pip install .

Pip (without conda)

# create and activate the virtual environment
virtualenv --no-download env
source env/bin/activate

# install PyTorch 2.0.1
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2

# install Python libraries
pip install .
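
After either install route, you can sanity-check that PyTorch sees your GPU. A minimal check, independent of this repo:

import torch

print(torch.__version__)          # expect 2.0.1
print(torch.cuda.is_available())  # expect True on a CUDA-capable machine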

Data Preparation

ShapeNet

Download ShapeNet and place ShapeNetCore.v2 in the data/text2shape-data folder.

Text2Shape (Chair & Table)

  1. Download Text2Shape and place shapenet.json and processed_captions_{train/val/test}.p in the data/text2shape-data/chair_table folder.

  2. Download ShapeNet solid voxels (Chair & Table); a quick way to spot-check the downloaded files is sketched after this list:

    cd data/text2shape-data
    mkdir -p chair_table
    cd chair_table
    wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_32_solid.zip
    wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_64_solid.zip
    wget http://text2shape.stanford.edu/dataset/shapenet/nrrd_256_filter_div_128_solid.zip
    unzip nrrd_256_filter_div_32_solid.zip
    unzip nrrd_256_filter_div_64_solid.zip
    unzip nrrd_256_filter_div_128_solid.zip

    After unzipping, the dataset files should be organized as follows:

    tricolo
    ├── data
    │   ├── preprocess_all_data.py
    │   ├── text2shape-data
    │   │   ├── ShapeNetCore.v2
    │   │   ├── chair_table
    │   │   │   ├── nrrd_256_filter_div_32_solid
    │   │   │   ├── nrrd_256_filter_div_64_solid
    │   │   │   ├── nrrd_256_filter_div_128_solid
    │   │   │   ├── processed_captions_train.p
    │   │   │   ├── processed_captions_val.p
    │   │   │   ├── processed_captions_test.p
    │   │   │   ├── shapenet.json
  3. Preprocess the dataset:

    python data/preprocess_all_data.py data=text2shape_chair_table +cpu_workers={num_processes}
  4. Precache the CLIP embeddings (optional):

    python extract_clip_feats.py data=text2shape_chair_table data.image_size=224
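
Each shape in these archives is stored as a colored solid voxel grid in NRRD format. To spot-check a download, you can load one file with the pynrrd package (a minimal sketch; pynrrd is an extra dependency not installed above, and the path is illustrative):

import nrrd  # pip install pynrrd

# point this at any .nrrd file extracted above; <model_id> is a placeholder
voxels, header = nrrd.read("data/text2shape-data/chair_table/nrrd_256_filter_div_32_solid/<model_id>/<model_id>.nrrd")
print(voxels.shape)  # for Text2Shape solids, typically a channels-first RGBA grid, e.g. (4, 32, 32, 32)
print(header)        # NRRD metadata: sizes, type, encoding, ...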

Text2Shape (C13)

  1. Download Text2Shape C13.

Training, Inference and Evaluation

Note: configuration files are managed by Hydra; you can add or override any configuration attribute by passing it as a command-line argument.

# log in to WandB
wandb login

# train a model from scratch
# available voxel_encoder_name: SparseCNNEncoder, null
# available image_encoder_name: MVCNNEncoder, CLIPImageEncoder, null
# available text_encoder_name: BiGRUEncoder, CLIPTextEncoder
# available dataset_name: text2shape_chair_table, text2shape_c13
python train.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
experiment_name={any_string}
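
# for example, a fully specified trimodal run on Text2Shape (Chair & Table);
# the experiment name is an arbitrary label of your choice
python train.py data=text2shape_chair_table model.voxel_encoder=SparseCNNEncoder \
model.image_encoder=MVCNNEncoder model.text_encoder=BiGRUEncoder \
experiment_name=tricolo_chair_table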

# train a model from a checkpoint
python train.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
experiment_name={checkpoint_experiment_name} ckpt_name={checkpoint_file_name}

# test a pretrained model
python test.py data={dataset_name} model.voxel_encoder={voxel_encoder_name} \
model.image_encoder={image_encoder_name} model.text_encoder={text_encoder_name} \
experiment_name={checkpoint_experiment_name} +ckpt_path={checkpoint_file_path}

# evaluate inference results
# currently unavailable
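
While the evaluation script is unavailable, the reported metrics can be computed from a caption-to-shape similarity matrix using their standard definitions. Below is a minimal NumPy sketch of RR@k (retrieval recall at k) and NDCG@k, assuming one ground-truth shape per caption; it is a generic reference implementation, not this repo's evaluation code:

import numpy as np

def retrieval_metrics(sim, gt, k=5):
    """sim: (num_queries, num_shapes) similarities; gt: (num_queries,) true shape index per query."""
    order = np.argsort(-sim, axis=1)                 # shape indices sorted by descending similarity
    ranks = np.argmax(order == gt[:, None], axis=1)  # 0-indexed rank of the true shape per query
    rr_at_1 = (ranks < 1).mean()
    rr_at_k = (ranks < k).mean()
    # binary relevance with a single relevant shape, so the ideal DCG is 1
    ndcg_at_k = np.where(ranks < k, 1.0 / np.log2(ranks + 2), 0.0).mean()
    return rr_at_1, rr_at_k, ndcg_at_k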

Checkpoints

Modality  Dataset                     Split  RR@1   RR@5   NDCG@5  Download
Tri(I+V)  Text2Shape (Chair & Table)  Val    12.60  33.34  23.30   chair_table_tri.ckpt
Bi(I)     Text2Shape (Chair & Table)  Val    11.67  30.63  21.49   chair_table_bi_i.ckpt
Bi(V)     Text2Shape (Chair & Table)  Val     9.33  27.52  18.62   chair_table_bi_v.ckpt
Tri(I+V)  Text2Shape (C13)            Val    12.96  34.87  24.19   c13_tri.ckpt
Bi(I)     Text2Shape (C13)            Val    11.89  33.48  22.96   c13_bi_i.ckpt
Bi(V)     Text2Shape (C13)            Val     9.73  29.24  19.69   c13_bi_v.ckpt
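
To sanity-check a downloaded checkpoint before passing it to test.py, you can inspect it with plain PyTorch. A minimal sketch; the key names shown are common PyTorch Lightning conventions and may differ slightly in this repo:

import torch

# Lightning checkpoints are ordinary torch files; load on CPU just to inspect
ckpt = torch.load("chair_table_tri.ckpt", map_location="cpu")
print(list(ckpt.keys()))        # typically: epoch, global_step, state_dict, optimizer_states, ...
print(len(ckpt["state_dict"]))  # number of saved parameter tensors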

Acknowledgements

  1. ConVIRT: Our overall training framework is heavily based on the ConVIRT implementation. Paper
  2. MVCNN: The MVCNN encoder we use is based on this implementation. Paper
  3. Text2Shape: We use the dataset from the original Text2Shape project and adapt its evaluation code. Paper

We thank the authors for their work and implementations.