Open-source implementation of SeaRNN (ICLR 2018, https://openreview.net/forum?id=HkUR_y-RZ)

This repository contains the code we used for our SeaRNN ICLR 2018 paper.

This code is an open-source (MIT) implementation of SeaRNN. It is rather sparsely documented, so feel free to ask us for more details by opening an issue.

Table of Contents

- Installation
- Running
- Citation

Installation

First, set up a virtualenv to install the project's dependencies. The project uses Python 3 and was written with PyTorch 0.2 in mind, although it also works with PyTorch 0.3 for NMT. Adjust the version numbers in the following commands to match your system.

virtualenv -p /usr/bin/python3.5 --system-site-packages /path/to/virtualenv
source /path/to/virtualenv/bin/activate
pip3 install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp35-cp35m-linux_x86_64.whl
pip3 install torchvision
pip3 install numpy --upgrade
pip3 install nltk
pip3 install cython
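
To check that the installation worked, you can print the PyTorch version and CUDA availability from inside the virtualenv (a quick sanity check we suggest, not part of the original instructions):

python3 -c "import torch; print(torch.__version__)"         # should print 0.2.x or 0.3.x
python3 -c "import torch; print(torch.cuda.is_available())" # True if the GPU build is usable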

Second, download the code:

git clone git@github.com:RemiLeblond/SeaRNN-open.git
cd SeaRNN-open

Next, compile the Cython files. This will probably emit a few warnings, which you can safely ignore.

python setup.py build_ext --inplace
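
If the build succeeded, the compiled extension modules (.so files on Linux) should appear next to their Cython sources; a quick way to check (this verification step is our suggestion, not part of the original instructions):

find . -name "*.so"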

Finally, download the data at http://www.di.ens.fr/sierra/research/SEARNN/ and preprocess it:

export DATA_ROOT=/path/to/data
scripts/prepare_iwlst14_de-en.sh

Running

Step 1: Train the model

OCR

python main_seq2seq.py --dataset ocr --dataroot ${DATA_ROOT} --rollin learned --rollout mixed --objective target-learning --log_path /path/to/save

NMT (the standard MLE training)

python main_seq2seq.py --dataset nmt --dataroot ${DATA_ROOT}/iwlst14_de-en_train_dev.train.pt --rollin gt --objective mle --log_path /path/to/save

Various parameters can be tuned, including the roll-in and roll-out policies, the objective, etc.; see main_seq2seq.py for a complete description. A sketch of a SEARNN-style NMT run is given below.
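
For instance, combining the flags from the two commands above, a SEARNN-style NMT training run might look like the following (the roll-in, roll-out and objective values simply mirror the OCR example and are an assumption on our part; check main_seq2seq.py for the combinations actually supported for NMT):

python main_seq2seq.py --dataset nmt --dataroot ${DATA_ROOT}/iwlst14_de-en_train_dev.train.pt --rollin learned --rollout mixed --objective target-learning --log_path /path/to/save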

Step 2: Evaluate the model

python main_seq2seq.py --dataset nmt --dataroot ${DATA_ROOT}/iwlst14_de-en_train_test.train.pt --max_iter 0 --print_iter 1 --checkpoint_file /path/to/checkpoint_file.pth

The arguments must be consistent with those used to train the model (such as the size of the RNN hidden state, or whether the encoder is bidirectional), otherwise loading the checkpoint will fail.
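
For example, if the model was trained with a non-default architecture, the same settings must be passed again at evaluation time. In the sketch below, the flag name --hidden_size and its value are purely hypothetical illustrations; use whichever architecture flags main_seq2seq.py actually defines, with the values from your training run:

# --hidden_size is a hypothetical flag name: replace it with the real architecture
# flags from main_seq2seq.py, set to the same values used during training
python main_seq2seq.py --dataset nmt --dataroot ${DATA_ROOT}/iwlst14_de-en_train_test.train.pt --max_iter 0 --print_iter 1 --checkpoint_file /path/to/checkpoint_file.pth --hidden_size 256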

To reproduce the NMT experiments from the paper, see scripts/training.sh.

Citation

@inproceedings{searnn2018leblond,
  author    = {Leblond, R\'emi and
               Alayrac, Jean-Baptiste and
               Osokin, Anton and
               Lacoste-Julien, Simon},
  title     = {SEARNN: Training RNNs with global-local losses},
  booktitle = {ICLR},
  year      = {2018},
}