[NeurIPS 2023] Im-Promptu: In-Context Composition from Image Prompts


This is the official repository for the NeurIPS 2023 paper Im-Promptu: In-Context Composition from Image Prompts. Website: jha-lab.github.io/impromptu/.

Table of Contents

- Environment setup
- Datasets
- Pixel Transformation
- Monolithic Model
- Patch Network
- Object Centric Learner (OCL)
- Sequential Prompter
- Cite this work
- License

Environment setup

The following shell script creates an Anaconda environment called "impromptu" and installs all the required packages:

source env_setup.sh

Datasets

The datasets can be downloaded from the following link. The datasets should be placed in the ./datasets directory. Detailed information about the benchmarks can be found in the ./benchmarks/README.md file.
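
The training commands below read these HDF5 files (e.g., ./datasets/shapes3d/train.h5). As a quick sanity check that a download is readable, the file contents can be listed with h5py; the internal key layout is dataset-specific (see ./benchmarks/README.md):

```python
import h5py

# List every group/dataset in a downloaded split; datasets print their shape.
with h5py.File("./datasets/shapes3d/train.h5", "r") as f:
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))
```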

Pixel Transformation

Solving analogies by a simple transformation over the pixel space:

$\hat{D} = C + (B - A)$
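
In code, this is a one-liner. A minimal NumPy sketch, assuming the images are float arrays of identical shape with values in [0, 1] (the function name is illustrative):

```python
import numpy as np

def pixel_analogy(a, b, c):
    # Solve the analogy A : B :: C : D directly in pixel space and
    # clip the result back into the valid intensity range.
    return np.clip(c + (b - a), 0.0, 1.0)
```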

An example command to run the pixel transformation model:

python3 learners/pixel.py --dataset shapes3d --batch_size 64 --data_path ./datasets/shapes3d/train.h5 --logs_dir ./logs_dir/ --phase val

The various arguments that can be passed to the script are:

--dataset = Name of the dataset (options: shapes3d, clevr, bitmoji)

--batch_size = Batch size for training

--data_path = Path to the training data

--logs_dir = Path to the directory where the logs will be stored

--phase = Split of the dataset to evaluate on

Monolithic Model

A monolithic vector representation is used to solve visual analogies. The architecture is laid out in ./learners/monolithic.py.

Training

An example training run for the monolithic learner is given below:

cd train_scripts

python train_monolithic.py --epochs 100 --dataset shapes3d --data_path ../datasets/shapes3d/train.h5 --image_size 64 --seed 0 --d_model 192 --logs_dir ../logs_dir/

Hyperparameters can be tweaked as follows:

--epochs = Training epochs

--dataset = Name of the dataset; used to instantiate the dataset class from ./utils/create_dataset.py

--d_model = Latent vector dimension

--image_size = Input image size

--lr_main = Peak learning rate

--lr_warmup_steps = Learning rate warmup steps for the linear schedule (a sketch follows this list)

--data_path = Path to the dataset

--log_path = Path to the directory where the logs will be stored
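
For intuition, a minimal sketch of the linear warmup implied by --lr_main and --lr_warmup_steps; the schedule actually used by the training script (e.g., any decay after the peak) may differ:

```python
def warmup_lr(step, peak_lr, warmup_steps):
    # Ramp linearly from 0 to peak_lr over warmup_steps, then hold the peak.
    return peak_lr * min(1.0, step / warmup_steps)

# e.g., with peak_lr=1e-4 and warmup_steps=1000: step 500 gives 5e-5
```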

Patch Network

Patch-level abstractions are used to solve visual analogies. The architecture is laid out in ./learners/patch_network.py.

Training

cd train_scripts

python3 train_patch_network.py --batch_size 16 --dataset shapes3d --img_channels 3 --epochs 150 --data_path ./datasets/shapes3d/train.h5 --vocab_size 512 --image_size 64 --num_enc_heads 4 --num_enc_blocks 4 --num_dec_blocks 4 --num_heads 4 --seed 3

Additional hyperparameters are as follows:

--vocab_size = Size of the dVAE vocabulary (see the tokenization sketch after this list)

--num_dec_blocks = Number of decoder blocks

--num_enc_blocks = Number of context encoder blocks

--num_heads = Number of attention heads in the decoder

--num_enc_heads = Number of attention heads in the context encoder
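
To make --vocab_size concrete, here is an illustrative nearest-codebook tokenizer. The actual dVAE is learned end-to-end and discretizes differently during training, so treat this only as a sketch of the idea:

```python
import torch

def tokenize(patches, codebook):
    # patches: (N, D) patch embeddings; codebook: (vocab_size, D).
    # Assign each patch the id of its nearest codebook entry -- an
    # illustrative stand-in for the learned dVAE tokenizer.
    return torch.cdist(patches, codebook).argmin(dim=-1)  # (N,) token ids
```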

Object Centric Learner (OCL)

Solving analogies by learning object-centric representations. The architecture is laid out in ./learners/object_centric_learner.py.

Training

cd train_scripts/

python train.py --img_channels 3 --dataset shapes3d --batch_size 32 --epochs 150 --data_path ./datasets/shapes3d/train.h5 --vocab_size 512 --image_size 64 --num_iterations 3 --num_slots 3 --num_enc_heads 4 --num_enc_blocks 4 --num_dec_heads 4 --num_heads 4 --slate_encoder_path ./logs_dir_pretrain/SLATE/best_encoder.pt --lr_warmup_steps 15000 --seed 0 --log_path ./logs_dir/

Additional hyperparameters are as follows:

--num_slots = Number of object slots per image

--num_iterations = Number of iterations for slot attention (a minimal sketch follows this list)

--slate_encoder_path = Path to the pre-trained slate encoder
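
For intuition on --num_slots and --num_iterations, below is a stripped-down slot-attention update in the spirit of Locatello et al. (2020). The actual implementation lives in ./learners/object_centric_learner.py and includes layer norms and an MLP that are omitted here:

```python
import torch
import torch.nn.functional as F

def slot_attention(inputs, slots, k, q, v, gru, num_iterations=3):
    # inputs: (B, N, D) encoded image features; slots: (B, num_slots, D).
    # k, q, v are linear projections; gru is a torch.nn.GRUCell(D, D).
    scale = slots.shape[-1] ** -0.5
    for _ in range(num_iterations):
        # Slots compete for input features: softmax over the slot axis.
        attn = F.softmax(q(slots) @ k(inputs).transpose(1, 2) * scale, dim=1)
        attn = attn / attn.sum(dim=-1, keepdim=True)  # weighted-mean weights
        updates = attn @ v(inputs)                    # (B, num_slots, D)
        slots = gru(updates.flatten(0, 1), slots.flatten(0, 1)).view_as(slots)
    return slots
```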

Sequential Prompter

The architecture is laid out in ./learners/sequential_prompter.py.

Training

cd train_scripts/

python train_prompt_.py --img_channels 3 --epochs 150 --data_path ./datasets/shapes3d/train.h5 --vocab_size 512 --image_size 64 --num_iterations 3 --num_slots 3 --slate_encoder_path ./logs_dir_pretrain/shapes3d_SLATE/best_encoder.pt --seed 0

Cite this work

Cite our work using the following BibTeX entry:

@misc{dedhia2023impromptu,
      title={Im-Promptu: In-Context Composition from Image Prompts}, 
      author={Bhishma Dedhia and Michael Chang and Jake C. Snell and Thomas L. Griffiths and Niraj K. Jha},
      year={2023},
      eprint={2305.17262},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

The Clear BSD License. Copyright (c) 2023, Bhishma Dedhia and Jha Lab. All rights reserved.

See the LICENSE file for more details.
