BRUNO: A Deep Recurrent Model for Exchangeable Data

This is the official code for reproducing the main results from our NIPS'18 paper:

I. Korshunova, J. Degrave, F. Huszár, Y. Gal, A. Gretton, J. Dambre
BRUNO: A Deep Recurrent Model for Exchangeable Data
arxiv.org/abs/1802.07535

and from our NIPS'18 Bayesian Deep Learning workshop paper:

I. Korshunova, Y. Gal, J. Dambre, A. Gretton
Conditional BRUNO: A Deep Recurrent Process for Exchangeable Labelled Data

Requirements

The code was developed and tested with the following setup:

  • python3
  • tensorflow-gpu==1.7.0
  • scikit-image==0.13.1

Datasets

Below we list, for each dataset, the files that should be stored in a data/ directory inside the project folder.

MNIST

Download from yann.lecun.com/exdb/mnist/

 data/train-images-idx3-ubyte.gz
 data/train-labels-idx1-ubyte.gz
 data/t10k-images-idx3-ubyte.gz
 data/t10k-labels-idx1-ubyte.gz
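
As a convenience, the four archives can be fetched in one loop. This is a minimal sketch that prints the download commands as a dry run (remove the `echo` to actually download), assuming the files are still hosted at yann.lecun.com:

```shell
# Dry run: print the curl commands for the four MNIST archives.
mkdir -p data
base="http://yann.lecun.com/exdb/mnist"
for f in train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz \
         t10k-images-idx3-ubyte.gz t10k-labels-idx1-ubyte.gz; do
  echo curl -o "data/$f" "$base/$f"
done
```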

Fashion MNIST

Download from github.com/zalandoresearch/fashion-mnist

data/fashion_mnist/train-images-idx3-ubyte.gz
data/fashion_mnist/train-labels-idx1-ubyte.gz
data/fashion_mnist/t10k-images-idx3-ubyte.gz
data/fashion_mnist/t10k-labels-idx1-ubyte.gz

Omniglot

Download and unzip files from github.com/brendenlake/omniglot/tree/master/python

data/images_background
data/images_evaluation

Download the .pkl files from github.com/renmengye/few-shot-ssl-public#omniglot. These are used to make the train-test-validation split.

data/train_vinyals_aug90.pkl
data/test_vinyals_aug90.pkl
data/val_vinyals_aug90.pkl

Run utils.py to preprocess the Omniglot images. This produces:

data/omniglot_x_train.npy
data/omniglot_y_train.npy
data/omniglot_x_test.npy
data/omniglot_y_test.npy
data/omniglot_valid_classes.npy

CIFAR-10

This dataset is downloaded automatically the first time a CIFAR-10 model is run.

data/cifar/cifar-10-batches-py

Training and testing

There are configuration files in config_rnn for every model we used in the paper, along with a number of testing scripts. Below are examples of how to train and test the Omniglot models.

Training (supports multiple GPUs)

CUDA_VISIBLE_DEVICES=0,1 python3 -m config_rnn.train  --config_name bn2_omniglot_tp --nr_gpu 2

Fine-tuning (to be run on a single GPU only)

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.train_finetune  --config_name bn2_omniglot_tp_ft_1s_20w

Generating samples

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_samples  --config_name bn2_omniglot_tp_ft_1s_20w

Few-shot classification

CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_few_shot_omniglot  --config_name bn2_omniglot_tp --seq_len 2 --batch_size 20
CUDA_VISIBLE_DEVICES=0 python3 -m config_rnn.test_few_shot_omniglot  --config_name bn2_omniglot_tp_ft_1s_20w --seq_len 2 --batch_size 20

Here, batch_size = k and seq_len = n + 1, which tests the model in a k-way, n-shot setting.
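
The mapping above can be sketched as a small helper. This is a hypothetical convenience function (not part of the repository) that turns a k-way, n-shot setting into the --batch_size and --seq_len arguments expected by the test scripts:

```shell
# Hypothetical helper: print the test-script arguments for a
# k-way, n-shot setting (batch_size = k, seq_len = n + 1).
few_shot_args() {
  k_way=$1
  n_shot=$2
  echo "--batch_size $k_way --seq_len $((n_shot + 1))"
}

# 20-way, 1-shot, matching the commands above.
few_shot_args 20 1
```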

Notes

The code is sparsely commented, so if you have any questions or find a bug, please email me at irene.korshunova@gmail.com