
adityabingi/Slot-Attention


Slot-Attention

This is a minimalistic PyTorch implementation of "Object-Centric Learning with Slot Attention" for the Tetrominoes dataset.
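
To make the core idea concrete, here is a minimal NumPy sketch of the Slot Attention iteration described in the paper. It is not this repo's PyTorch code. The random projection matrices stand in for learned linear layers, and the paper's GRU plus residual MLP update is replaced by a simple averaging step. Both are labeled simplifications. The key detail the sketch preserves is that the softmax runs over slots, so slots compete for each input feature, and updates are a weighted mean over inputs.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each row to zero mean / unit variance (no learned affine params).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def slot_attention(inputs, num_slots=4, num_iters=3, dim=16, seed=0, eps=1e-8):
    """inputs: (N, D_in) flattened feature map. Returns (slots, attn)."""
    rng = np.random.default_rng(seed)
    n, d_in = inputs.shape
    # Hypothetical random projections standing in for learned k/v/q layers.
    w_k = rng.normal(size=(d_in, dim)) / np.sqrt(d_in)
    w_v = rng.normal(size=(d_in, dim)) / np.sqrt(d_in)
    w_q = rng.normal(size=(dim, dim)) / np.sqrt(dim)

    x = layer_norm(inputs)
    k, v = x @ w_k, x @ w_v                    # (N, D)
    slots = rng.normal(size=(num_slots, dim))  # (K, D) random initialization
    for _ in range(num_iters):
        q = layer_norm(slots) @ w_q            # (K, D)
        logits = k @ q.T / np.sqrt(dim)        # (N, K)
        attn = softmax(logits, axis=1)         # softmax over slots: slots compete
        weights = attn + eps
        weights = weights / weights.sum(axis=0, keepdims=True)  # mean over inputs
        updates = weights.T @ v                # (K, D)
        # Simplification: the paper uses a GRU followed by a residual MLP here.
        slots = 0.5 * slots + 0.5 * updates
    return slots, attn
```

Each row of `attn` sums to 1 across the slots, which is what lets the per-slot attention maps act as a soft segmentation of the input.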

Update:

Added support for the implicit slot-attention method proposed in the paper "Object Representations as Fixed Points: Training Iterative Refinement Algorithms with Implicit Differentiation".
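
As we read that paper, the converged slots are treated as a fixed point z* = f(z*, θ) of the refinement step. Instead of backpropagating through every iteration, one differentiates through the fixed point: by the implicit function theorem, dz*/dθ = (I - ∂f/∂z)^(-1) ∂f/∂θ. The paper approximates the inverse by the first term of its Neumann series (the identity), which keeps only ∂f/∂θ. The scalar toy below illustrates the difference. The map `f` is a hypothetical contraction standing in for one slot-attention update and is not code from this repo.

```python
import math

def f(z, theta):
    # Hypothetical contraction standing in for one slot-attention refinement step.
    return 0.5 * math.tanh(theta * z + 1.0)

def fixed_point(theta, z0=0.0, tol=1e-12, max_iter=1000):
    # Iterate z <- f(z, theta) until it stops changing.
    z = z0
    for _ in range(max_iter):
        z_next = f(z, theta)
        if abs(z_next - z) < tol:
            return z_next
        z = z_next
    return z

def partials(z, theta):
    # Analytic partial derivatives of f at (z, theta).
    s = 1.0 - math.tanh(theta * z + 1.0) ** 2
    return 0.5 * s * theta, 0.5 * s * z  # (df/dz, df/dtheta)

def implicit_grad(theta):
    # Exact implicit gradient: (1 - df/dz)^-1 * df/dtheta at the fixed point.
    z = fixed_point(theta)
    df_dz, df_dtheta = partials(z, theta)
    return df_dtheta / (1.0 - df_dz)

def first_order_grad(theta):
    # Neumann series truncated to the identity: (1 - df/dz)^-1 ~ 1.
    z = fixed_point(theta)
    _, df_dtheta = partials(z, theta)
    return df_dtheta
```

In an autograd framework, the first-order approximation is commonly implemented by running the iterations without tracking gradients and then applying one final update to the detached slots. Whether this repo's `--use_implicit_grads` path does exactly that is an assumption.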

Training Results

Below are slot visualizations of sample test images after training for 100k steps.

Left to right: the input, its reconstruction, and the 4 slots for each sample test image.

[Figure: Slot visualization after training the slot attention model for 100k steps]
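
In the Slot Attention paper, each slot is decoded separately into an RGB image and an unnormalized alpha mask. The masks are normalized with a softmax across slots, and the reconstruction is the mask-weighted sum of the per-slot RGB images. The NumPy sketch below shows that compositing step. Rendering each slot panel as masked RGB over a white background is an assumption made for illustration; this repo's plotting code may differ.

```python
import numpy as np

def composite(rgbs, mask_logits):
    """Combine per-slot decodings into one image.

    rgbs:        (K, H, W, 3) per-slot RGB predictions
    mask_logits: (K, H, W, 1) per-slot unnormalized alpha masks
    """
    z = mask_logits - mask_logits.max(axis=0, keepdims=True)
    masks = np.exp(z) / np.exp(z).sum(axis=0, keepdims=True)  # softmax over slots
    recon = (masks * rgbs).sum(axis=0)                         # (H, W, 3)
    # Assumed display style: each slot's masked RGB over a white background.
    slot_views = masks * rgbs + (1.0 - masks)
    return recon, masks, slot_views
```

Because the masks sum to 1 at every pixel, the reconstruction is a per-pixel convex combination of the slot images.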

Data

This repo trains Slot Attention only on the Tetrominoes dataset, which is part of the Google Multi-Object Datasets and is available as TFRecords here: https://github.com/deepmind/multi_object_datasets.

This code uses only the h5py (HDF5) versions of this dataset, which @pemami4911 created from the TFRecords; they are available for download via a Dropbox link (tetrominoes.h5 and tetrominoes_test.h5).

Usage

To train on the Tetrominoes dataset with default hyperparameters (see the argparser in main.py for all hyperparameters):

python main.py --train

To train implicit slot attention as proposed in "Object Representations as Fixed Points: Training Iterative Refinement Algorithms with Implicit Differentiation":

python main.py --train --use_implicit_grads

To test on the validation data:

python main.py --test

To visualize the slots of sample images:

python main.py --visualize
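
The four modes above are selected by command-line switches. The sketch below is a hypothetical reconstruction of how such switches can be declared with Python's `argparse`. Only the four flag names come from this README; treating them as `store_true` booleans, the help strings, and the `select_mode` dispatch helper are assumptions, and the real main.py also defines hyperparameters not shown here.

```python
import argparse

def build_parser():
    # Hypothetical: only these four flag names appear in the README.
    p = argparse.ArgumentParser(description="Slot Attention on Tetrominoes")
    p.add_argument("--train", action="store_true", help="train a model")
    p.add_argument("--test", action="store_true", help="evaluate on validation data")
    p.add_argument("--visualize", action="store_true", help="plot slots of sample images")
    p.add_argument("--use_implicit_grads", action="store_true",
                   help="train with implicit differentiation")
    return p

def select_mode(args):
    # Hypothetical dispatch helper mapping parsed flags to a run mode.
    if args.train:
        return "train-implicit" if args.use_implicit_grads else "train"
    if args.test:
        return "test"
    if args.visualize:
        return "visualize"
    return "none"
```

For example, `select_mode(build_parser().parse_args(["--train", "--use_implicit_grads"]))` yields `"train-implicit"`.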
