Official repository for the paper HINT: Hypernetwork Instruction Tuning for Efficient Few- and Zero-Shot Generalisation. We introduce a model that learns to generate adaptation parameters from instructions, allowing greatly improved inference efficiency while retaining performance. We find the model is especially effective in few-shot scenarios. The overall architecture looks like:
This codebase is based on t5x and has only been tested on TPUs.
If you're looking for the T0 replications we trained as part of this work, they can be found here:
If you find our paper or models useful please cite us:
@article{hint,
author = {Ivison, Hamish and Bhagia, Akshita and Wang, Yizhong and Hajishirzi, Hannaneh and Peters, Matthew},
title = {HINT: Hypernetwork Instruction Tuning for Efficient Zero-Shot Generalisation},
journal = {ACL},
url = {https://arxiv.org/abs/2212.10315},
year = {2023}
}
As this runs on t5x, you'll need a gcloud bucket you can write and read from. I recommend this bucket is in the same region as your TPUs to avoid large ingress/egress costs.
I also recommend reading the T5X documentation before working with this repo. I will assume some familiarity with gin, seqio, etc. for the rest of this readme. If you have never worked with TPUs before, reading the TPU guide at the bottom of this readme is a good idea!
Some of this codebase, especially the T0/P3 parts, is based on the BigScience architecture-objective repository.
Important: You'll need to replace BUCKET_NAME in the scripts with your own google bucket, and set GOOGLE_BUCKET_PATH in hyper_task_descriptions/utils.py to your own bucket too. I recommend making sure the bucket is in the same region as the TPUs you are using to lower egress/ingress costs in google cloud. You should make sure to cache your data to the same bucket too (see "Data Preprocessing" below).
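If the scripts use the literal BUCKET_NAME placeholder as described above, something like the following can handle the bulk of the replacement (my-hint-bucket is a made-up bucket name; double-check the result, and edit GOOGLE_BUCKET_PATH in hyper_task_descriptions/utils.py by hand):
grep -rl 'BUCKET_NAME' scripts/ | xargs sed -i 's/BUCKET_NAME/my-hint-bucket/g'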
pip install -e .[dev]
should install the basic dependencies. Note there are some annoyances with seqio: please uninstall seqio and seqio-nightly (which are probably installed by the above) and install my fork, https://github.com/hamishivi/seqio, which contains a fix.
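Concretely, the seqio swap looks something like this (installing the fork straight from GitHub):
pip uninstall -y seqio seqio-nightly
pip install git+https://github.com/hamishivi/seqio.git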
Some artefacts used during training might be useful to cache ahead of time. You can cache them as follows:
- Tokenizers:
python3 -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('t5-base'); AutoTokenizer.from_pretrained('roberta-base')"
- Roberta model:
python3 -c "from transformers import FlaxRobertaModel; FlaxRobertaModel.from_pretrained('hamishivi/fixed-roberta-base')"
- Data splits (only do if you want to run with all datasets):
cd hyper-task-descriptions; TRANSFORMERS_OFFLINE=1 python3 -c "import hyper_task_descriptions.seqio_tasks.all_t0_tasks"
See scripts/tpu_setup.sh for an example of setting up this codebase to run on a TPU. Local installation should be similar, minus the TPU-specific Jax version.
Run scripts/local/local.sh for a small model + small subset of T0 data that is useful for local development.
Due to the large size of P3, I recommend preprocessing data before training. You can do this with seqio_cache_tasks as follows:
seqio_cache_tasks \
--tasks="<task regex>" \
--output_cache_dir=gs://<output bucket folder> \
--module_import=hyper_task_descriptions.seqio_tasks.all_t0_tasks \
--min_shards 1 \
--alsologtostderr
Check out seqio for more on this tool. Some notes for this repo:
- Run this on a machine with a lot of RAM, as it can be quite memory-hungry when preprocessing some of the larger T0 tasks.
- I found using one process per P3 prompt/task worked best, as trying to cache all the tasks at once was very slow (see the sketch after this list).
- For processing pretraining data such as C4, it might be useful to use Google Dataflow instead. Unfortunately, I no longer have the commands I used for preprocessing C4 with Dataflow.
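For example, a rough template for kicking off one caching process per task in the background (task names, bucket folder, and log file names are placeholders to fill in):
for task in "<task_1>" "<task_2>"; do
  seqio_cache_tasks \
    --tasks="${task}" \
    --output_cache_dir=gs://<output bucket folder> \
    --module_import=hyper_task_descriptions.seqio_tasks.all_t0_tasks \
    --min_shards 1 \
    --alsologtostderr &> "cache_${task}.log" &
done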
The seqio mixtures are defined in the following places:
- seqio_tasks/all_t0_tasks.py defines t0_train and t0_eval_score_eval, the train and test mixtures for T0, along with other relevant T0 mixtures.
- seqio_tasks/my_t0_tasks.py defines 1- through 5-shot T0 mixtures: t0_train_{shot}_shot and t0_eval_score_eval_{shot}_shot.
- ni_tasks/ni_registry defines various Super-Natural Instructions mixtures. natural_instructions is the Def-only setting for HINT. natural_instruction_positive_example_hyper_1 and natural_instruction_positive_example_hyper_2 are the 1- and 2-shot settings for HINT respectively (i.e. the hypernetwork gets the definition plus few-shot examples). natural_instructions_def, natural_instructions_def_pos_1, and natural_instructions_def_pos_2 are the 0-, 1-, and 2-shot settings for the baseline T5 model respectively.
- c4/c4_registry defines c4_pretrain, the pretraining mixture for the hypernetwork.
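If you want to sanity-check what a given mixture contains, something like the following one-liner (using the t0_train mixture as an example) will list its tasks:
TRANSFORMERS_OFFLINE=1 python3 -c "import seqio; import hyper_task_descriptions.seqio_tasks.all_t0_tasks; print([t.name for t in seqio.get_mixture_or_task('t0_train').tasks])"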
As a bonus, numeric_task/numeric_registry.py defines some fun tiny tasks that I used when designing the initial architecture.
Once your data is preprocessed and ready to go, you can train! We manage our configs with gin, and commands for training can be found in scripts. To replicate the HINT model on Super-Natural Instructions:
- Run pretraining/pretrain.sh <model_name>. <model_name> will be the folder name the model is saved to in your google bucket. Let this run for 10,000 steps.
- Run nat_int/ni_train_pretrained.sh <model_name> <pretrained_model_name> checkpoint_<checkpoint_step> <total_train_steps>, replacing the value of INITIAL_CHECKPOINT_PATH with the model you just pretrained. If you are following HINT, you should use checkpoint_1110000 for the checkpoint step and 1111000 for the total train steps. A concrete example invocation is sketched below this list.
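For example, a run of those two steps might look like the following (the model names are placeholders I made up; the checkpoint and step values are the HINT ones from above):
./scripts/pretraining/pretrain.sh hint-base-pretrain
./scripts/nat_int/ni_train_pretrained.sh hint-base-ni hint-base-pretrain checkpoint_1110000 1111000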
For P3, run train_from_pretrained.sh <model_name> <pretrained_model_name> checkpoint_<checkpoint_step> <total_train_steps> instead (with the number of total train steps adjusted accordingly). You should run TRANSFORMERS_OFFLINE=1 python3 -c "import hyper_task_descriptions.seqio_tasks.all_t0_tasks" before doing P3 training to cache the split information for P3 (you can also alter the script to run this when you run the training script, but since it takes ~30 minutes I think it's worth doing as a separate step). You only have to cache these splits once (for a given TPU).
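A hypothetical P3 run then looks similar (placeholder model names again; check where train_from_pretrained.sh actually lives under scripts in your checkout, and pick <total_train_steps> appropriately):
TRANSFORMERS_OFFLINE=1 python3 -c "import hyper_task_descriptions.seqio_tasks.all_t0_tasks"
./scripts/train_from_pretrained.sh hint-base-t0 hint-base-pretrain checkpoint_1110000 <total_train_steps>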
The ni_train scripts will evaluate the model right after training. You can find the evaluation numbers in the eval folder inside the model checkpoint folder (in your google bucket), or by looking at the tensorboard logs (view them by launching tensorboard --logdir <your_model_folder>; note this can be done on any machine with access to your google bucket).
If you want to evaluate on Super-Natural Instructions separately, run nat_int/ni_eval.sh with the settings changed so that they are identical to the model you are evaluating (this is very important!).
To evaluate on P3, run eval/t0_eval.sh. Again, you'll have to adjust the script to make sure it matches the model configuration (basically, include the same gins in the same order and you should be fine).
While we have many gin configs, I will highlight the most important ones (a sketch of how they compose is below this list):
- hyper_<size>.gin defines the HINT model, with various sizes available (the same sizes as T5 models).
- instruction_embed.gin defines how the instructions are incorporated into the underlying model following HINT. We experimented with many different methods of doing this (lora, adapters, prefixes, prompts...), which are all separately configurable!
- partial_train_adafactor_dual.gin defines the adafactor optimizer, and sets the values for the hypernetwork and underlying T5 so that we reuse the existing optimizer state for the T5 model but correctly initialize the hypernetwork optimizer states. We use this for pretraining and finetuning.
- pretrain.gin defines the pretraining mixture and settings (input/output lengths, etc.).
- t0_train.gin defines the mixture and settings for training on P3 (the T0 train set).
- ni_train.gin defines the mixture and settings for training on Super-Natural Instructions.
- full_restore.gin overrides some settings to allow loading a pretrained hypernetwork model.
- restore_pretained.gin resets hypernetwork states during loading so that we can load from a pretrained T5 model (without hypernetwork weights). This is included by default in most scripts, and can be overridden by including full_restore.gin as the final gin config.
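To give a feel for how these compose, a t5x invocation built from the list above might look like the following. The gin search path, the exact set of files, and the extra bindings here are illustrative, so copy the real composition from one of the scripts; the main point is that full_restore.gin comes last when you want to load a full pretrained hypernetwork model:
python3 -m t5x.train \
  --gin_search_paths="gins" \
  --gin_file="hyper_base.gin" \
  --gin_file="instruction_embed.gin" \
  --gin_file="ni_train.gin" \
  --gin_file="partial_train_adafactor_dual.gin" \
  --gin_file="full_restore.gin" \
  --gin.MODEL_DIR="'gs://BUCKET_NAME/my_model'"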
We also explored other models, including Hypter and HyperTune, whose configs I have not tested against the current codebase. However, they might be interesting to look at for reference.
I have only tested HINT model pretraining and finetuning on SNI and T0, so I can't guarantee all combinations of options work. If you have a question or issue, please open a GitHub issue and ping me (@hamishivi).
This was my first project working with TPUs, so below are the steps I worked out while using them. If you are experienced with TPU pods, you probably don't need to read this.
See below for instructions on setting up and running on TPUs. Run scripts/tpu_setup.sh to set up the TPU. Note this takes ~30 minutes to run (the T0 tasks require downloading split information for all datasets, which takes some time).
Running on TPU slices is a bit of a pain as you cannot 'just ssh to the machine and run stuff directly'. Rather, you send commands to all the TPUs, and Jax/Flax/t5x works it all out for you in the background. Here's a rough guide of useful steps for setting up and running directly on a TPU.
First, create your TPU! I'll leave aside some of the details, but you can do this with a command like:
gcloud alpha compute tpus tpu-vm create <name> --accelerator-type=<tpu-version> --zone=<zone> --project=<project> --version=<software-version>
You can get some details on TPU architectures here and software versions here. Note that the software version shouldn't matter that much for us, since we're using jax.
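For example, a single v3-8 for debugging might be created with something like the following (zone and project are placeholders; tpu-vm-base is a commonly used generic TPU VM image, but check which versions are currently offered for your TPU type):
gcloud alpha compute tpus tpu-vm create hint-debug --accelerator-type=v3-8 --zone=<zone> --project=<project> --version=tpu-vm-base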
Next, we need to set up our TPU. Note that if you're using a single TPU (e.g. a v3-8), you can ssh directly to the TPU VM with gcloud alpha compute tpus tpu-vm ssh <tpu-name> --zone=<zone> --project=<project> and do all your setup directly on the box. However, anything larger requires managing multiple machines and sending shell commands with something like gcloud alpha compute tpus tpu-vm ssh <tpu-name> --zone=<zone> --project=<project> --worker=all --command="<bash commands>". If even one machine fails when you run the command, try again: you basically need them all to be set up the same way. If you are running on a TPU slice, run all the commands below in place of <bash commands>. Otherwise, just ssh'ing directly into the TPU should work fine.
To set up our TPUs, just clone this repo and run scripts/tpu_setup.sh like so:
git clone https://github.com/allenai/hyper-task-descriptions.git; ./hyper-task-descriptions/scripts/tpu_setup.sh
Refer to comments in that script in case something fails (although fingers crossed nothing does!). Note that this is currently a private repo so you'll probably have to use a github authentication token and alter the url accordingly.
Then we can run our model with the following:
cd hyper-task-descriptions; ./scripts/<script-name>
This will run the given script on all TPU workers at once, so you will see a lot of output being logged. If one TPU worker errors, the rest will continue to run, so cancel the command (Ctrl+C) and follow the cleanup steps below before rerunning.
scripts/t0_train.sh is for training a model and scripts/t0_eval.sh is for evaluating one. I recommend looking at those scripts and the accompanying T5X commands (train/eval) to understand the options.
For long-running jobs, where you don't want your computer constantly connected to the TPU, I tended to just run the process in the background and log all output:
cd hyper-task-descriptions; ./scripts/<script-name> &> log.log &
However, other solutions probably work well too!
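You can then keep an eye on progress with something like:
tail -f log.log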
When rerunning code on the TPUs, you need to make sure there are no processes using a core on any TPU, otherwise things won't work (your code will run on CPU). A useful one-liner for finding and killing TPU-using processes is:
sudo lsof -t /dev/accel0 | xargs kill -9
Sometimes stray lockfiles can also hold the TPU if you aborted a Jax program early (see here). Remove them with:
sudo rm -f /tmp/libtpu_lockfile; sudo rm -rf /tmp/tpu_logs
Note I haven't had to run this step.
Sometimes you might want to run or hotfix some script or file in this repo after scaling up to multiple TPUs. To do so, after everything is set up, create the new/edited file locally. Then copy it over to the TPUs with:
gcloud alpha compute tpus tpu-vm scp <new-file> <tpu-name>: --zone=<zone> --project=<project> --worker=all
You might be able to specify the destination path after the colon, but it didn't work when I tried it. Instead, I just use a mv command to move the file to where it should go.
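Putting that together, a hypothetical hotfix of a single file might look like this (the file name and destination path are made up for illustration):
gcloud alpha compute tpus tpu-vm scp my_fix.py <tpu-name>: --zone=<zone> --project=<project> --worker=all
gcloud alpha compute tpus tpu-vm ssh <tpu-name> --zone=<zone> --project=<project> --worker=all --command="mv my_fix.py hyper-task-descriptions/hyper_task_descriptions/my_fix.py"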
Note: you might get weird ssh errors when running the above scp command. If you do, run ssh-add .../.ssh/google_compute_engine as the error probably suggests, and rerun the command right after. Sometimes this takes a few tries before it runs without error, but so long as the new file ends up on all TPUs you're good to go.
After moving the file to where it should go, you're done! Run the script, rerun your model, or whatever else you need to do. Alternatively, I also recommend just setting up a git branch and using git pull to update scripts, since scp-ing and managing multiple filesystems at once can be a pain.
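For the git-based alternative, the same ssh-to-all-workers pattern from above works, e.g. (assuming your branch is already checked out on the TPUs):
gcloud alpha compute tpus tpu-vm ssh <tpu-name> --zone=<zone> --project=<project> --worker=all --command="cd hyper-task-descriptions && git pull"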