This repository contains code and data for the EACL 2024 Findings paper *Hierarchical and Dynamic Prompt Compression for Efficient Zero-shot API Usage* by Yichen Jiang, Marco Del Vecchio, Mohit Bansal, and Anders Johannsen.
## Setup

This codebase is adapted from the codebase of the original gist-token paper. We list their recommended setup below:
- This codebase has been tested with Python 3.9.16 and PyTorch 2.0.0. We recommend creating a new virtual env (e.g., with Conda), then installing torch manually from [pytorch.org](https://pytorch.org); `pip install -r requirements.txt` should take care of the remaining dependencies. A minimal sketch of this setup is shown after this list.
- Note that this codebase requires a fairly recent version of Transformers with support for LLaMA. The specific commit pinned in `requirements.txt` is the one that was tested; newer Transformers releases should generally work, though there may be some naming issues with newer versions.
- Some issues that may occur are discussed at the end of this document.
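For reference, a minimal environment setup might look like the following (a sketch assuming Conda; the environment name is arbitrary, and you should pick the torch 2.0.0 wheel matching your CUDA version from pytorch.org):

```bash
# Create and activate a fresh environment (name is arbitrary)
conda create -n hd-gist python=3.9.16
conda activate hd-gist

# Install PyTorch 2.0.0 first (pick the wheel for your CUDA version from pytorch.org)
pip install torch==2.0.0

# Install the remaining pinned dependencies
pip install -r requirements.txt
```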
- Experiment runs and model checkpoints are saved to the `exp/` directory in the root directory. You can change this via `training.output_dir` in `src/conf/config.yaml`.
- Cached models (downloaded from the Hugging Face Hub) and datasets are saved to `.cache/`. Be sure to create these directories before running for the first time (see the command below). You can change this via `model.cache_dir` in `src/conf/config.yaml`.
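For example, from the repository root:

```bash
mkdir -p exp .cache
```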
- LLaMA-7B experiments expect a folder called `llama-7b` in the root directory containing the model weights and tokenizer. You can manually set the path to the LLaMA model in `src/conf/model/llama-7b.yaml`.
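If you want to sanity-check the folder, a LLaMA checkpoint converted to Hugging Face format typically contains files like the following (illustrative only; exact filenames depend on how the weights were converted):

```bash
ls llama-7b/
# e.g. config.json  tokenizer.model  pytorch_model-00001-of-00002.bin  ...
```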
- Training logs are saved via wandb. Set your `wandb` entity name correctly in `src/conf/config.yaml`.
## Data

In this work, we use the SGD and SGD-X datasets for training and evaluating our models.
- Download the datasets by cloning the entire [dstc8-schema-guided-dialogue](https://github.com/google-research-datasets/dstc8-schema-guided-dialogue) repo and putting it under `./data/raw_data`.
- Run the data preprocessing script to convert the raw data to the d3st format used for training HD-Gist. For more details about the datasets and formats, refer to the data README:
```bash
cd data/raw_data/;
git clone https://github.com/google-research-datasets/dstc8-schema-guided-dialogue.git;
python3 -m sgd_x.generate_sgdx_dialogues;
cd ../..;
python convert_d3st_format.py \
    --dataset_name=sgd \
    --add_prompt=v1 \
    --use_index \
    --json_instruction;
python convert_d3st_format.py \
    --dataset_name=sgd_x \
    --add_prompt=v1 \
    --use_index \
    --json_instruction
```
- To also process the data needed to run the Gist baseline / LLaMA baseline with reconstruction, run:
```bash
python convert_d3st_format.py \
    --dataset_name=sgd \
    --add_prompt=v1 \
    --use_index;
python convert_d3st_format.py \
    --dataset_name=sgd_x \
    --add_prompt=v1 \
    --use_index
```
## Training

Warning: Training and decoding the HD-Gist model is currently only supported with `batch_size = 1`. For LLaMA-7B, larger batch sizes will require modifying the rotary position embeddings to account for gist offsets here, as well as the other functions that create the gist masks here.
If you'd like to train an HD-Gist model, the command `./train_scripts/debug_hd_gist_llama_small_sgd.sh` trains a small LLaMA-style model on the SGD training dataset.
- The model is randomly initialized.
- It has 2 layers, 32 heads, 1024 intermediate size, and 4096 hidden size.
- It uses static slot-gist tokens, dynamic value-gist tokens, and reconstruction of the API from these HD-Gist tokens, while logging to wandb.
- You can disable the dynamic value-gist tokens by setting `training.gist.add_ctg_val_gist_token=False` (see the example below).
- You can disable the reconstruction of the API from HD-Gist tokens by setting `training.gist.inbatch_reconstruct_ratio=0`.
Note: If you're not familiar with the CLI syntax, check out [Hydra](https://hydra.cc).
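For instance, to launch the debug run with both options disabled (a sketch assuming the script forwards extra Hydra-style overrides to the trainer; verify this against the script before relying on it):

```bash
./train_scripts/debug_hd_gist_llama_small_sgd.sh \
    training.gist.add_ctg_val_gist_token=False \
    training.gist.inbatch_reconstruct_ratio=0
```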
To finetune the larger models in the paper (LLaMA-7B), multi-GPU training with DeepSpeed is required. The experiments below all assume a machine with 4 A100 80GB GPUs or 8 A100 40GB GPUs, and at least 400GB of CPU RAM. Other machine configurations will necessitate changing the batch size and/or the DeepSpeed config settings.
To finetune LLaMA-7B with HD-Gist tokens and reconstruction of API documentation from the HD-Gist tokens for 1 epoch, you can run `./train_deepspeed_scripts/train_hd_gist_llama7b_sgd.sh`.

- This trains `llama-7b` with the same HD-Gist configuration as the smaller llama-debug command above, using the hyperparameters in the paper. See `src/conf/model/llama-7b.yaml` for the hyperparameter configuration.
- Be sure to set your `wandb` entity name correctly in `src/conf/config.yaml`.
- By default this logs an experiment to wandb under a group name that begins with `wandb.tag` (e.g., `yourgroupname`); check out `src/conf/config.yaml` to see the full group name. Metrics are also logged to stdout.
- The wandb group and run names also define the directory where model checkpoints and outputs are saved locally.
- By default this is `exp/{wandb.group}/{wandb.run}`.
## Evaluation

To evaluate the trained HD-Gist model, you can run `./eval_deepspeed_scripts/eval_hd_gist_llama7b_sgd.sh`. This will decode the HD-Gist model on the SGD validation set.
- To decode on the SGD-X/v5 validation set, set `data.config_name=x5_d3st_prompt+date_jsonInstruct`.
- To decode on the test sets, set `training.eval_on_test=true`.
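For instance, to decode the SGD-X/v5 test set (again a sketch assuming the script forwards extra Hydra-style overrides):

```bash
./eval_deepspeed_scripts/eval_hd_gist_llama7b_sgd.sh \
    data.config_name=x5_d3st_prompt+date_jsonInstruct \
    training.eval_on_test=true
```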
## License

The SGD and SGD-X data are licensed under "Creative Commons Attribution Share Alike 4.0 International". By training and evaluating on the data, you inherit both licenses.
## Common Issues

Issue 1: `PydanticUserError: If you use @root_validator with pre=False (the default) you MUST specify skip_on_failure=True`

- Solution from StackOverflow.
Issue 2: `NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported.`

- Solution from StackOverflow: downgrade `fsspec` from 2023.10.0 to 2023.9.2.
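For example:

```bash
pip install fsspec==2023.9.2
```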
Issue 3: "zsh: illegal hardware instruction" when running model training. "
- Solution from StackOverflow.
## Citation

If you found this work useful, please cite:

```
@inproceedings{jiang2024hdgist,
    title={Hierarchical and Dynamic Prompt Compression for Efficient Zero-shot API Usage},
    author={Yichen Jiang and Marco Del Vecchio and Mohit Bansal and Anders Johannsen},
    year={2024},
    booktitle={Findings of EACL},
}
```