This is a PyTorch implementation of an Attention-over-Attention model using a Transformer encoder. The original model was proposed by Cui et al. (paper).
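The core of the model is the attention-over-attention mechanism from Cui et al., applied here on top of Transformer-encoded sequences rather than the recurrent encoder of the original paper. Below is a minimal sketch of that computation, assuming already-encoded document and query tensors; the function name and shapes are illustrative, not this repo's exact API:

```python
import torch
import torch.nn.functional as F

def attention_over_attention(doc_enc, query_enc):
    """Attention-over-attention (Cui et al.) over encoded sequences.

    doc_enc:   (batch, doc_len, hidden)   encoded document
    query_enc: (batch, query_len, hidden) encoded query
    Returns attention scores over document positions, (batch, doc_len).
    """
    # Pairwise matching matrix: M[b, i, j] = doc_i . query_j
    M = torch.bmm(doc_enc, query_enc.transpose(1, 2))  # (batch, doc_len, query_len)

    # Column-wise softmax: a document-level attention per query word
    alpha = F.softmax(M, dim=1)

    # Row-wise softmax gives a query-level attention per document word;
    # averaging over document positions yields one weight per query word
    beta = F.softmax(M, dim=2).mean(dim=1)              # (batch, query_len)

    # "Attention over attention": weight the column attentions by beta
    return torch.bmm(alpha, beta.unsqueeze(2)).squeeze(2)  # (batch, doc_len)
```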
- PyTorch with CUDA
- Python 3.6+
- NLTK (with punkt data)
This implementation uses the CNN dataset. Make sure the data files (train.txt, dev.txt, test.txt) are present in the data/cnn directory.
To preprocess the data:
python preprocess_cnn.py
This will generate the dictionary (dict.pt) from all words appearing in the dataset and vectorize all splits (train.txt.pt, dev.txt.pt, test.txt.pt).
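Roughly, the preprocessing amounts to the following (a sketch assuming NLTK punkt tokenization and torch.save serialization; the exact token handling and special symbols in preprocess_cnn.py may differ):

```python
import torch
from nltk.tokenize import word_tokenize  # requires the punkt data

def build_dict(paths, unk="<unk>", pad="<pad>"):
    """Collect every word appearing in the dataset into an index dictionary."""
    word2idx = {pad: 0, unk: 1}
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line in f:
                for tok in word_tokenize(line.strip()):
                    word2idx.setdefault(tok, len(word2idx))
    return word2idx

def vectorize(path, word2idx, unk="<unk>"):
    """Map each line to a LongTensor of word indices."""
    data = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            idxs = [word2idx.get(t, word2idx[unk]) for t in word_tokenize(line.strip())]
            data.append(torch.tensor(idxs, dtype=torch.long))
    return data

word2idx = build_dict(["data/cnn/train.txt", "data/cnn/dev.txt", "data/cnn/test.txt"])
torch.save(word2idx, "data/cnn/dict.pt")
for split in ("train", "dev", "test"):
    torch.save(vectorize(f"data/cnn/{split}.txt", word2idx), f"data/cnn/{split}.txt.pt")
```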
Below is an example of training a model; set the parameters as you like.
python train.py -traindata data/cnn/train.txt.pt -validdata data/cnn/dev.txt.pt -dict data/cnn/dict.pt \
-save_model model_cnn -hidden_size 384 -embed_size 384 -batch_size 64 -dropout 0.1 \
-epochs 13 -learning_rate 0.001 -weight_decay 0.0001 -gpu 0 -log_interval 50
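For orientation, here is how those flags typically map onto PyTorch objects; this is a toy sketch with a stand-in model and random data, not the actual contents of train.py:

```python
import torch
import torch.nn as nn

# Toy stand-ins: sizes mirror the flags above (embed/hidden 384, dropout 0.1,
# batch_size 64, epochs 13, lr 0.001, weight_decay 0.0001); the real model
# and data loading live in this repo's train.py.
vocab_size = 1000  # illustrative
model = nn.Sequential(
    nn.Embedding(vocab_size, 384),
    nn.Dropout(0.1),
    nn.Linear(384, vocab_size),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
criterion = nn.CrossEntropyLoss()

for epoch in range(13):
    tokens = torch.randint(0, vocab_size, (64, 20))  # fake batch of token ids
    target = torch.randint(0, vocab_size, (64,))     # fake answer ids
    logits = model(tokens).mean(dim=1)               # (batch, vocab_size)
    loss = criterion(logits, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```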
After each epoch, a checkpoint is saved. To resume training from a checkpoint:
python train.py -train_from model_cnn/xxx_model_xxx_epoch_x.pt
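Under the hood this follows the usual PyTorch checkpoint pattern, sketched below; the checkpoint field names are assumptions, so check train.py for the exact layout:

```python
import os
import torch
import torch.nn as nn

model = nn.Linear(384, 384)  # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
os.makedirs("model_cnn", exist_ok=True)

def save_checkpoint(model, optimizer, epoch, path):
    # Persist everything needed to continue training, not just the weights
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path):
    # Restore model and optimizer state; training resumes at the next epoch
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"] + 1

save_checkpoint(model, optimizer, 1, "model_cnn/example_checkpoint.pt")
start_epoch = load_checkpoint(model, optimizer, "model_cnn/example_checkpoint.pt")
```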
To evaluate a trained model on the test set:
python test.py -testdata data/cnn/test.txt.pt -dict data/cnn/dict.pt -out result.txt -model model_cnn/xx_checkpoint_epochxx.pt
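To score the predictions, something along these lines works, assuming result.txt contains one predicted answer per line and you have a matching gold-answer file (neither format is specified here, so treat this purely as a sketch):

```python
# Purely hypothetical helper: neither the result.txt format nor a gold-answer
# file is specified by this repo, so adapt the paths and parsing as needed.
def accuracy(pred_path, gold_path):
    with open(pred_path, encoding="utf-8") as p, open(gold_path, encoding="utf-8") as g:
        pairs = [(a.strip(), b.strip()) for a, b in zip(p, g)]
    return sum(a == b for a, b in pairs) / max(len(pairs), 1)

print(accuracy("result.txt", "data/cnn/test_answers.txt"))  # gold path is made up
```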