This repository contains code for the paper "Structured Voronoi Sampling", published at NeurIPS 2023.
To build the vocabulary of the models, train the probing classifiers, and evaluate them, make sure the data is downloaded into the datasets directory (or change the --data_dir argument in the scripts).
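Before running the training or evaluation scripts, a quick pre-flight check can catch a missing data directory. The sketch below is a minimal illustration, assuming the script you run accepts the --data_dir flag mentioned above; the helper name data_dir_args is hypothetical and not part of the repository.

```python
from pathlib import Path


def data_dir_args(data_dir: str = "datasets") -> list:
    """Return ["--data_dir", data_dir] after checking that the directory exists.

    Hypothetical helper: the default "datasets" and the --data_dir flag come from
    the README; the check itself is only an illustration.
    """
    path = Path(data_dir)
    if not path.is_dir():
        raise FileNotFoundError(f"expected downloaded data in {path.resolve()}")
    return ["--data_dir", str(path)]
```

The returned list can be appended to a script's argument list, so a wrong path fails early instead of partway through training.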
To reproduce the plots and results on the toy model, run the ToyExperiment.ipynb notebook.
This paper uses two types of classifiers; the commands to train each are given below. The code for training, evaluation, and the model architectures is in the control directory.
- Probing classifiers, which guide the generation process. We train bidirectional LSTM layers on top of GPT-2 models with the command below:
python control/train.py --model RNNProbe --task food --base_model_str gpt2 --save_dir [CKPT_DIR] --save_name [CKPT_NAME]
- Evaluator classifiers, which assess the quality of the generated text. We fine-tune a RoBERTa model with the command below:
python control/train.py --model EVAL --task food --base_model_str roberta-base --save_dir [CKPT_DIR] --save_name [CKPT_NAME]
Use --task food for topic control and --task sentiment for sentiment control.
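The two training commands differ only in --model and --base_model_str, so a small helper can assemble either one. This is a hypothetical sketch: train_command, CLASSIFIERS, and TASKS are illustrative names, and the flag values are copied from the commands above. The optional base_model override is there because the sentiment-control generation command loads a probe from control/ckpts/sst2-probe-large together with gpt2-large, which suggests (but does not confirm) that this probe was trained on gpt2-large rather than gpt2.

```python
from __future__ import annotations

# Values copied from the README training commands; the helper itself is hypothetical.
CLASSIFIERS = {
    "probe": ("RNNProbe", "gpt2"),          # BiLSTM probe on top of GPT-2
    "evaluator": ("EVAL", "roberta-base"),  # fine-tuned RoBERTa evaluator
}
TASKS = {"topic": "food", "sentiment": "sentiment"}


def train_command(kind: str, control: str, save_dir: str, save_name: str,
                  base_model: str | None = None) -> list[str]:
    """Build the argument list for control/train.py.

    base_model overrides the default --base_model_str for the chosen classifier.
    """
    model, default_base = CLASSIFIERS[kind]
    return [
        "python", "control/train.py",
        "--model", model,
        "--task", TASKS[control],
        "--base_model_str", base_model or default_base,
        "--save_dir", save_dir,
        "--save_name", save_name,
    ]
```

The returned list can be passed directly to subprocess.run(cmd, check=True), which avoids shell quoting issues with checkpoint names.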
This repository includes a re-implementation of MuCoLa and an implementation of SVS.
To generate text from the LM without enforcing any control:
python generate.py --method mucola --save_dir [SAVE_DIR] --save_name [SAVE_NAME] --step_size 0.1 --steps 500
To generate text from the LM with topic control (food dataset):
python generate.py --method mucola --task food --save_dir [SAVE_DIR] --save_name [SAVE_NAME] --step_size 0.1 --c_factor 2. --steps 500 --controlled
To generate text from the LM with sentiment control:
python generate.py --method mucola --g_ckpt gpt2-large --c_ckpt control/ckpts/sst2-probe-large --task sentiment --save_dir [SAVE_DIR] --save_name [SAVE_NAME] --step_size 0.6 --c_factor 1.5 --steps 500 --controlled
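MuCoLa samples in the LM's continuous embedding space with Langevin dynamics: each update moves along the negative gradient of an energy and adds Gaussian noise. The sketch below shows that update on a one-dimensional toy energy. The quadratic energies, and the reading of --step_size as the step size and --c_factor as the weight on the classifier energy, are assumptions for illustration, not taken from generate.py.

```python
import math
import random


def grad_energy(x: float, c_factor: float, target: float = 3.0) -> float:
    """Gradient of a toy energy E(x) = x**2 / 2 + c_factor * (x - target)**2 / 2.

    The first term stands in for the LM energy and the second for the classifier
    energy; both are illustrative quadratics, not the energies used in generate.py.
    """
    return x + c_factor * (x - target)


def langevin(steps: int, step_size: float, c_factor: float, seed: int = 0) -> list:
    """Unadjusted Langevin dynamics: x <- x - eta * grad E(x) + sqrt(2 * eta) * noise."""
    rng = random.Random(seed)
    x = 0.0
    trace = []
    for _ in range(steps):
        noise = rng.gauss(0.0, 1.0)
        x = x - step_size * grad_energy(x, c_factor) + math.sqrt(2.0 * step_size) * noise
        trace.append(x)
    return trace


if __name__ == "__main__":
    samples = langevin(steps=5000, step_size=0.1, c_factor=2.0)[500:]  # drop burn-in
    print(sum(samples) / len(samples))
```

With c_factor = 2 the toy target is a Gaussian with mean 3 * 2 / (1 + 2) = 2, so the printed average should be close to 2: a larger weight pulls samples further towards the "classifier" term. Because no Metropolis correction is applied, larger step sizes also inflate the sampled variance.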
The hmc method uses the Voronoi measure with a 1-step HMC algorithm. To apply it for topic control, use the following command:
python generate.py --method hmc --task food --save_dir [SAVE_DIR] --save_name [SAVE_NAME] --step_size 1.5 --c_factor 1.5 --steps 100 --controlled
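The Voronoi measure ties continuous points to discrete tokens: each token embedding owns the Voronoi cell of points closer to it than to any other embedding. The sketch below illustrates this with a hypothetical three-token vocabulary in two dimensions and runs HMC with a single leapfrog step per iteration, which is one reading of "1-step HMC". The toy energy (negative token log-probability plus a quadratic pull towards the cell's embedding) is an assumption, not the paper's definition of the Voronoi measure.

```python
import math
import random

# Hypothetical 2-D "embeddings" and token log-probabilities for a three-word vocabulary.
EMBEDDINGS = {"apple": (0.0, 0.0), "pizza": (2.0, 0.0), "car": (0.0, 2.0)}
LOG_PROBS = {"apple": math.log(0.3), "pizza": math.log(0.6), "car": math.log(0.1)}


def voronoi_cell(x):
    """Return the token whose embedding is nearest to x, i.e. the Voronoi cell containing x."""
    return min(EMBEDDINGS, key=lambda t: (x[0] - EMBEDDINGS[t][0]) ** 2
               + (x[1] - EMBEDDINGS[t][1]) ** 2)


def energy(x):
    """Toy piecewise energy: -log p(token) plus a quadratic pull towards that token's embedding."""
    tok = voronoi_cell(x)
    ex, ey = EMBEDDINGS[tok]
    return -LOG_PROBS[tok] + 0.5 * ((x[0] - ex) ** 2 + (x[1] - ey) ** 2)


def grad_energy(x):
    """Gradient inside a cell; the jump between cells is handled by the Metropolis test."""
    ex, ey = EMBEDDINGS[voronoi_cell(x)]
    return (x[0] - ex, x[1] - ey)


def hmc_step(x, step_size, rng):
    """One HMC iteration: one leapfrog step followed by a Metropolis accept/reject."""
    p = (rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
    g = grad_energy(x)
    p_half = (p[0] - 0.5 * step_size * g[0], p[1] - 0.5 * step_size * g[1])
    x_new = (x[0] + step_size * p_half[0], x[1] + step_size * p_half[1])
    g_new = grad_energy(x_new)
    p_new = (p_half[0] - 0.5 * step_size * g_new[0], p_half[1] - 0.5 * step_size * g_new[1])
    h_old = energy(x) + 0.5 * (p[0] ** 2 + p[1] ** 2)
    h_new = energy(x_new) + 0.5 * (p_new[0] ** 2 + p_new[1] ** 2)
    if rng.random() < math.exp(min(0.0, h_old - h_new)):
        return x_new
    return x


def sample_tokens(steps, step_size, seed=0):
    """Run the chain from the "apple" embedding and record the token of each state."""
    rng = random.Random(seed)
    x = EMBEDDINGS["apple"]
    tokens = []
    for _ in range(steps):
        x = hmc_step(x, step_size, rng)
        tokens.append(voronoi_cell(x))
    return tokens
```

For example, sample_tokens(1000, 1.5) runs the chain with the same step size as the command above, though whether generate.py interprets --step_size this way is an assumption. The leapfrog step ignores the energy jumps at cell boundaries, and the Metropolis test accounts for them, so the chain remains valid but rejections become more frequent near boundaries.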
To generate text with SVS under topic control (food dataset):
python generate.py --method svs --task food --save_dir [SAVE_DIR] --save_name [SAVE_NAME] --step_size 1.5 --c_factor 1.5 --steps 100 --controlled
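The controlled generation commands in this README differ only in their hyperparameters, so they can be collected in one table. The sketch below is hypothetical (CONTROLLED_RUNS and generate_command are illustrative names); every value is copied from the commands above, and combinations the README does not list, such as SVS with sentiment control, raise a KeyError rather than guessing settings.

```python
# Hyperparameters copied from the README commands; the helper itself is hypothetical.
CONTROLLED_RUNS = {
    ("mucola", "food"): {"step_size": "0.1", "c_factor": "2.", "steps": "500"},
    ("mucola", "sentiment"): {
        "g_ckpt": "gpt2-large",
        "c_ckpt": "control/ckpts/sst2-probe-large",
        "step_size": "0.6", "c_factor": "1.5", "steps": "500",
    },
    ("hmc", "food"): {"step_size": "1.5", "c_factor": "1.5", "steps": "100"},
    ("svs", "food"): {"step_size": "1.5", "c_factor": "1.5", "steps": "100"},
}


def generate_command(method, task, save_dir, save_name):
    """Build the generate.py argument list for a controlled run listed in the README."""
    cfg = CONTROLLED_RUNS[(method, task)]  # KeyError for combinations the README omits
    args = ["python", "generate.py", "--method", method]
    for key in ("g_ckpt", "c_ckpt"):  # optional checkpoint overrides
        if key in cfg:
            args += [f"--{key}", cfg[key]]
    args += ["--task", task, "--save_dir", save_dir, "--save_name", save_name]
    args += ["--step_size", cfg["step_size"], "--c_factor", cfg["c_factor"],
             "--steps", cfg["steps"], "--controlled"]
    return args
```

For instance, subprocess.run(generate_command("svs", "food", "out", "svs-food"), check=True) would launch the SVS topic-control run; keeping the values in one table makes it easier to compare MuCoLa, hmc, and SVS settings side by side.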