CATfOOD: Counterfactual Augmented Training for Improving Out-of-Domain Performance and Calibration

This project aims at improving the generalization and calibration of small language models through the use of counterfactual data generated by large language models.

Abstract: In recent years, large language models (LLMs) have shown remarkable capabilities at scale, particularly at generating text conditioned on a prompt. In our work, we investigate the use of LLMs to augment training data of small language models (SLMs) with automatically generated counterfactual (CF) instances -- i.e. minimally altered inputs -- in order to improve out-of-domain (OOD) performance of SLMs in the extractive question answering (QA) setup. We show that, across various LLM generators, such data augmentation consistently enhances OOD performance and improves model calibration for both confidence-based and rationale-augmented calibrator models. Furthermore, these performance improvements correlate with higher diversity of CF instances in terms of their surface form and semantic content. Finally, we show that CF augmented models which are easier to calibrate also exhibit much lower entropy when assigning importance, indicating that rationale-augmented calibrators prefer concise explanations.
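As a minimal, hypothetical sketch (not the repository's actual code) of the core augmentation idea, the snippet below mixes original QA training instances with LLM-generated counterfactual instances before fine-tuning an SLM; the function name, instance format, and `cf_ratio` parameter are illustrative assumptions.

```python
import random

def augment_with_counterfactuals(originals, counterfactuals, cf_ratio=1.0, seed=42):
    """Mix original QA instances with LLM-generated counterfactual (CF) instances.

    Each instance here is a (context, question, answer) triple.
    cf_ratio controls how many CF instances are added relative to
    the number of originals (1.0 = up to one CF per original).
    """
    rng = random.Random(seed)
    n_cf = min(len(counterfactuals), int(len(originals) * cf_ratio))
    sampled_cf = rng.sample(counterfactuals, n_cf)
    combined = originals + sampled_cf
    rng.shuffle(combined)  # interleave originals and CFs for training
    return combined

# Toy example: a CF instance minimally alters the context so the answer changes.
originals = [("Paris is the capital of France.", "What is the capital of France?", "Paris")]
cfs = [("Lyon is the capital of France.", "What is the capital of France?", "Lyon")]
train_set = augment_with_counterfactuals(originals, cfs)
```

The augmented `train_set` would then be passed to the usual SLM fine-tuning loop in place of the original training data.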

File Structure

  • 📁 CATfOOD
    • 📁 src
      • 📁 calibration # code for calibrating models
      • 📁 cf_generation # code for counterfactual generation
      • 📁 faithfulness # code for evaluating explanation quality
    • 📁 scripts # scripts to run all experiments
    • 📄 README.md # project documentation and usage instructions
    • 📄 requirements.txt # list of project dependencies
    • 📄 LICENSE # license file for the project
    • 📄 .gitignore # Git ignore file

Usage

Counterfactual Generation:

  • For LLaMA models, run the script in the scripts folder:
bash scripts/run_llama.sh
  • For GPT models:
bash scripts/run_gpt.sh
bash scripts/run_gpt_neox.sh
  • For FLAN models:
bash scripts/run_flan.sh

Calibration:

  • To train a calibrator model, use one of the following commands:
python src/calibration/baseline/modelling.py --dataset squad --train_size 500 --do_maxprob  # conf baseline
python src/calibration/baseline/modelling.py --dataset squad --train_size 500 --method shap --arg_n_tree 300 --arg_max_depth 20  # SHAPCal
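To illustrate what the confidence baseline (`--do_maxprob`) measures, here is a pure-Python sketch, under our own assumptions rather than the repository's `modelling.py`, of a max-probability confidence score and the expected calibration error (ECE) commonly used to evaluate calibrators:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def max_prob_confidence(logits):
    """Confidence baseline: the model's top softmax probability."""
    return max(softmax(logits))

def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE: bin predictions by confidence, then average the gap
    between mean confidence and accuracy within each bin."""
    bins = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        idx = min(int(c * n_bins), n_bins - 1)
        bins[idx].append((c, ok))
    n = len(confidences)
    ece = 0.0
    for b in bins:
        if not b:
            continue
        avg_conf = sum(c for c, _ in b) / len(b)
        acc = sum(ok for _, ok in b) / len(b)
        ece += (len(b) / n) * abs(avg_conf - acc)
    return ece
```

A well-calibrated model drives ECE toward zero: for example, answers predicted with 0.9 confidence should be correct about 90% of the time.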

Contact

Disclaimer

NOTE: This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.

Citation

Please use the following citation:

@misc{sachdeva2023catfood,
      title={CATfOOD: Counterfactual Augmented Training for Improving Out-of-Domain Performance and Calibration}, 
      author={Rachneet Sachdeva and Martin Tutek and Iryna Gurevych},
      year={2023},
      eprint={2309.07822},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
