This code accompanies the paper "Probing Neural Dialog Models for Conversational Understanding" by Saleh et al., 2020.
This repo is built on top of ParlAI. We add functionality for probing open-domain dialog models (RNNs and Transformers). Probing evaluates the quality of internal model representations for conversational skills.
Follow the same installation instructions as ParlAI. ParlAI requires Python 3 and PyTorch 1.1 or newer. After cloning this repo, remember to run
python setup.py develop
You will also need to install skorch 0.6, which is required by the probing classifier.
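If skorch is not already in your environment, one way to get that version is via pip (this exact command is only a suggestion; skorch 0.6 is the requirement, the install method is up to you):
pip install skorch==0.6.0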
This section takes you through an example of how you would train and probe a dialog model.
You will first need a model to probe. Let's train a small RNN on the DailyDialog dataset:
python examples/train_model.py -t dailydialog -m seq2seq --bidirectional true --numlayers 2 --hiddensize 256 --embeddingsize 128 -eps 60 -veps 1 -vp 10 -bs 32 --optimizer adam --lr-scheduler invsqrt -lr 0.005 --dropout 0.3 --warmup-updates 4000 -tr 300 -mf trained/dailydialog/seq2seq/seq2seq --display-examples True -ltim 30 --tensorboard_log True --validation-metric ppl
To train on perturbed (i.e. shuffled) data, add the flag -sh within. See ParlAI's documentation for more information about training dialog models.
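For example, the same training run on within-utterance shuffled data just appends that flag (illustrative; in practice you would likely also point -mf at a separate path so the shuffled model does not overwrite the checkpoint above):
python examples/train_model.py -t dailydialog -m seq2seq --bidirectional true --numlayers 2 --hiddensize 256 --embeddingsize 128 -eps 60 -veps 1 -vp 10 -bs 32 --optimizer adam --lr-scheduler invsqrt -lr 0.005 --dropout 0.3 --warmup-updates 4000 -tr 300 -mf trained/dailydialog/seq2seq/seq2seq --display-examples True -ltim 30 --tensorboard_log True --validation-metric ppl -sh within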
You will then generate and save the vector representations to be used as features by the probing classifier. Let's generate and save the encoder_state vectors for the TREC question classification task:
python probing/probe_model.py -mf trained/dailydialog/seq2seq/seq2seq -t probing.tasks.trecquestion.agents --probe encoder_state
This will automatically download the required task data and save the generated representations at trained/dailydialog/seq2seq/probing/encoder_state/trecquestion.
There are three types of hidden representations you can extract and probe:
- encoder_state
- word_embeddings
- combined
Refer to the paper for more info.
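For example, to extract and save the word embedding representations for the same task, only the --probe flag changes (assuming the output then lands under probing/word_embeddings/trecquestion, mirroring the directory pattern above):
python probing/probe_model.py -mf trained/dailydialog/seq2seq/seq2seq -t probing.tasks.trecquestion.agents --probe word_embeddings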
Now you can run the probing classifier to evaluate the quality of the generated representations:
python probing/eval_probing.py -m trained/dailydialog/seq2seq -t trecquestion --probing-module encoder_state --max_epochs 50 --runs 30
This trains the probing classifier (an MLP) on the generated representations. The final results are saved at trained/dailydialog/seq2seq/probing/encoder_state/trecquestion/results.json.
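The results file is plain JSON, so a quick way to inspect it from the command line is Python's built-in pretty-printer (the exact fields inside are not documented here):
python -m json.tool trained/dailydialog/seq2seq/probing/encoder_state/trecquestion/results.json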
You might also want to generate the GloVe word embedding baselines. You can do this by running:
python probing/glove.py -t trecquestion --dict-path trained/dailydialog/seq2seq/seq2seq.dict
This will automatically download the GloVe embeddings and save the generated representations to trained/GloVe/probing/trecquestion.
Now you need to run the probing classifier on these generated representations:
python probing/eval_probing.py -m GloVe -t trecquestion --max_epochs 50 --runs 30
The final results are saved at trained/GloVe/probing/trecquestion/results.json.
Important note: it is strongly recommended that you use the same directories as in the examples above for saving your checkpoints and data; other parts of the code assume these paths and will break otherwise.
The supported probing tasks are:
- trecquestion
- multiwoz
- sgd
- dialoguenli
- wnli
- snips
- scenariosa
- dailydialog_topic
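To probe one of the other tasks, swap the task flag in the commands above. For example, assuming every task follows the same probing.tasks.<name>.agents naming pattern as trecquestion (check probing/tasks for the exact module names):
python probing/probe_model.py -mf trained/dailydialog/seq2seq/seq2seq -t probing.tasks.snips.agents --probe encoder_state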
New probing tasks need to be in the following format:
text: <utterance1> \n
<utterance2> \t episode_done:True \n
text: <utterance1> \n
<utterance2> \n
<utterance3> \t episode_done:True \n
...
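For instance, a task file with one two-utterance episode and one three-utterance episode would look like this (the utterances below are made up purely for illustration):
text: What year did the Berlin Wall fall? \n
It fell in 1989. \t episode_done:True \n
text: Hi, how have you been? \n
Pretty good, thanks for asking. \n
Glad to hear it. \t episode_done:True \n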
New probing tasks need to be added to probing/tasks and glove.py.
The code is best suited for tasks where the label is based on:
- all utterances in a dialog (like DailyDialog Topic)
- or the interaction between two utterances (like DialogueNLI)
- or the last utterance in a dialog (like ScenarioSA)
See section 3.2 in the paper for more info.
Most of the code for this study exists within the probing directory.
We also augmented torch_generator_agent.py with probing functions that extract a model's internal representations.
If you use this code, please cite our paper:
@article{saleh2020probing,
  author = {Saleh, Abdelrhman and Deutsch, Tovly and Casper, Stephen and Belinkov, Yonatan and Shieber, Stuart},
  title = {Probing Neural Dialog Models for Conversational Understanding},
  journal = {Second Workshop on NLP for Conversational AI},
  year = {2020}
}