Skip to content

Commit

Permalink
Merge pull request #45 from PAIR-code/patchscopes-code
Browse files Browse the repository at this point in the history
Patchscopes code
  • Loading branch information
iislucas committed Mar 10, 2024
2 parents f5714c6 + 56b9f46 commit 2b856df
Show file tree
Hide file tree
Showing 36 changed files with 130,244 additions and 5 deletions.
50 changes: 50 additions & 0 deletions patchscopes/code/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
## 🩺 Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models


### Overview
We propose a framework that decodes specific information from a representation within an LLM by “patching” it into the inference pass on a different prompt that has been designed to encourage the extraction of that information. A "Patchscope" is a configuration of our framework that can be viewed as an inspection tool geared towards a particular objective.

For example, this figure shows a simple Patchscope for decoding what is encoded in the representation of "CEO" in the source prompt (left). We patch a target prompt (right) comprised of few-shot demonstrations of token repetitions, which encourages decoding the token identity given a hidden representation.

[**[Paper]**](https://arxiv.org/abs/2401.06102) [**[Project Website]**](https://pair-code.github.io/interpretability/patchscopes/)

<p align="left"><img width="60%" src="images/patchscopes.png" /></p>

### 💾 Download textual data
The script is provided [**here**](download_the_pile_text_data.py). Use the following command to run it:
```python
python3 download_the_pile_text_data.py
```

### 🦙 For using Vicuna-13B
Run the following command for using the Vicuna 13b model (see also details [here](https://huggingface.co/CarperAI/stable-vicuna-13b-delta)):
```python
python3 apply_delta.py --base meta-llama/Llama-2-13b-hf --target ./stable-vicuna-13b --delta CarperAI/stable-vicuna-13b-delta
```

### 🧪 Experiments

#### (1) Next Token Prediction
The main code used appears [here](next_token_prediction.ipynb).
#### (2) Attribute Extraction
For this experiment, you should download the `preprocessed_data` directory.
The main code used appears [here](attribute_extraction.ipynb).
#### (3) Entity Processing
The main code used appears [here](entity_processing.ipynb). The dataset is available for downloading [here](https://github.com/AlexTMallen/adaptive-retrieval/blob/main/data/popQA.tsv).
#### (4) Cross-model Patching
The main code used appears [here](patch_cross_model.ipynb).
#### (5) Self-Correction in Multi-Hop Reasoning
For this experiment, you should download the `preprocessed_data` directory.
The main code used appears [here](multihop-CoT.ipynb). The code provided supports the Vicuna-13B model.

### 📙 BibTeX
```bibtex
@misc{ghandeharioun2024patchscopes,
title={Patchscopes: A Unifying Framework for Inspecting Hidden Representations of Language Models},
author={Ghandeharioun, Asma and Caciularu, Avi and Pearce, Adam and Dixon, Lucas and Geva, Mor},
year={2024},
eprint={2401.06102},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
51 changes: 51 additions & 0 deletions patchscopes/code/apply_delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Usage:
python3 apply_delta.py --base /path/to/model_weights/llama-13b --target stable-vicuna-13b --delta pvduy/stable-vicuna-13b-delta
The code was adopted from https://github.com/GanjinZero/RRHF/blob/main/apply_delta.py
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


def apply_delta(base_model_path, target_model_path, delta_path):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

print("Loading delta")
delta = AutoModelForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

DEFAULT_PAD_TOKEN = "[PAD]"
base_tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast=False)
num_new_tokens = base_tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))

base.resize_token_embeddings(len(base_tokenizer))
input_embeddings = base.get_input_embeddings().weight.data
output_embeddings = base.get_output_embeddings().weight.data
input_embeddings[-num_new_tokens:] = 0
output_embeddings[-num_new_tokens:] = 0

print("Applying delta")
for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
assert name in delta.state_dict()
param.data += delta.state_dict()[name]

print("Saving target model")
base.save_pretrained(target_model_path)
delta_tokenizer.save_pretrained(target_model_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--delta-path", type=str, required=True)

args = parser.parse_args()

apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
Loading

0 comments on commit 2b856df

Please sign in to comment.