## A step-by-step guide of training ReFT with TinyLlama

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/stanfordnlp/pyreft/blob/main/main_demo.ipynb)

### Training an 😀 Emoji-Chatbot ([live demo](https://huggingface.co/spaces/pyvene/reft_emoji_chat)) with ReFT in under 10 seconds!

<kbd>
<img src="https://github.com/stanfordnlp/pyreft/assets/15223704/580d6cfd-4c3c-49a7-bc9f-1f9cc9a5aee7" width="400"/>
</kbd>

In [1]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyreft

except ModuleNotFoundError:
    !pip install git+https://github.com/stanfordnlp/pyreft.git

Collecting git+https://github.com/stanfordnlp/pyreft.git
  Cloning https://github.com/stanfordnlp/pyreft.git to /tmp/user/23992/pip-req-build-36ge6t_w
  Running command git clone --filter=blob:none --quiet https://github.com/stanfordnlp/pyreft.git /tmp/user/23992/pip-req-build-36ge6t_w
  Resolved https://github.com/stanfordnlp/pyreft.git to commit dafd0995a366d7b47160a337dcc388eda7431821
  Running command git submodule update --init --recursive -q
  Preparing metadata (setup.py) ... [?25ldone
Collecting pyvene>=0.1.7 (from pyreft==0.1.0)
  Downloading pyvene-0.1.7-py3-none-any.whl.metadata (4.6 kB)
Collecting transformers>=4.48.2 (from pyreft==0.1.0)
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
Collecting matplotlib>=3.7.4 (from pyreft==0.1.0)
  Downloading matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting ipywidgets>=8.1.1 (from pyreft==0.1.0)
  Downloading ipywidgets-8.1.5-py3-none-any.whl.metadata (2.3 kB

### Step 1: loading the raw LM you want to train with ReFT.
We first load in any model we want to gain controls over:

In [4]:
import torch, transformers, pyreft
device = "cuda"

prompt_no_input_template = """\n<|user|>:%s</s>\n<|assistant|>:"""

model = transformers.GPT2Model.from_pretrained('gpt2').to(device)

# get tokenizer
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

tokenizer.pad_token = tokenizer.unk_token

### Step 2: set up the ReFT config by giving details about the interventions we want to learn.
ReFT has been shown to be parameter-efficient. We start with a minimal set-up for our intervention:

In [5]:
# get reft model
reft_config = pyreft.ReftConfig(representations={
    "layer": 8, "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
    low_rank_dimension=4)})
reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device("cuda")
reft_model.print_trainable_parameters()

Intervention key: layer_8_comp_block_output_unit_pos_nunit_1#0
trainable intervention params: 6,148 || trainable model params: 0
model params: 124,439,808 || trainable%: 0.004940541213306919


### Step 3: a few demonstrations of the behavior you want.
Quick adaptation or personalization requires very limited training data. Here, we play the same rule for ReFT. In this example, we want the model to **only return Emoji**. We create 10 examples:

In [None]:
import argparse

parser = argparse.ArgumentParser()

parser.add_argument("--para_train", type=str, default="data/quora-train.csv")
parser.add_argument("--para_dev", type=str, default="data/quora-dev.csv")
parser.add_argument("--para_test", type=str, default="data/quora-test-student.csv")
parser.add_argument("--para_dev_out", type=str, default="predictions/para-dev-output.csv")
parser.add_argument("--para_test_out", type=str, default="predictions/para-test-output.csv")

parser.add_argument("--seed", type=int, default=11711)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--use_gpu", action='store_true')

parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default= 32)
parser.add_argument("--lr", type=float, help="learning rate", default=1e-5)
parser.add_argument("--model_size", type=str,
                    help="The model size as specified on hugging face. DO NOT use the xl model.",
                    choices=['gpt2', 'gpt2-medium', 'gpt2-large'], default='gpt2')

args = parser.parse_args()

In [None]:
from my_datasets import (
  ParaphraseDetectionDataset,
  ParaphraseDetectionTestDataset,
  load_paraphrase_data
)
para_train_data = load_paraphrase_data('data/quora-train.csv')
para_dev_data = load_paraphrase_data('data/quora-dev.csv')

para_train_data = ParaphraseDetectionDataset(para_train_data, args)
para_dev_data = ParaphraseDetectionDataset(para_dev_data, args)

data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, [prompt_no_input_template % e[0] for e in training_examples], 
    [e[1] for e in training_examples])

NameError: name 'args' is not defined

### Step 4: it takes “no time” to train.
Now, you could train ReFT just like any next token prediction tasks! pyreft also conveniently sets up the ReFT-based dataloaders to give users a “code-less” experience:

In [5]:
# train
training_args = transformers.TrainingArguments(
    num_train_epochs=100.0, output_dir="./tmp", per_device_train_batch_size=10, 
    learning_rate=4e-3, logging_steps=40, report_to=[])
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)
_ = trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
40,0.8301
80,0.1128


### Step 5: chat with your ReFT model.
Since we are training with so little parameters and data, ReFT may simply memorize all of them without generalizing to other inputs. Let’s verify this with an unseen prompt:

In [11]:
instruction = "Which dog breed do people think is cuter, poodle or doodle?"

# tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=True, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))


<|user|>:Which dog breed do people think is cuter, poodle or doodle?
<|assistant|>:📖🐶🍫🌱


### Step 6: ReFT model sharing through HuggingFace.

In [12]:
reft_model.set_device("cpu") # send back to cpu before saving.
reft_model.save(
    save_directory="./reft_to_share", 
    save_to_hf_hub=True, 
    hf_repo_name="your_reft_emoji_chat"
)

Directory './reft_to_share' created successfully.


### Step 7: ReFT model loading.

In [15]:
import torch, transformers, pyreft
device = "cuda"

model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

reft_model = pyreft.ReftModel.load(
    "./reft_to_share", model
)



### Step 8: Gradio deployments.
You can also directly deploy your ReFT models through Gradio. Chat with our trained `ReFT-Emoji-Chat` through **Gradio** [here](https://huggingface.co/spaces/pyvene/reft_emoji_chat). We host a couple more ReFT models on our `pyvene` space:

<img width="700" alt="gradio" src="https://github.com/stanfordnlp/pyreft/assets/15223704/435192d6-2459-4932-b881-4dbf73caea0e">

- ReFT-Ethos (A [GOODY-2](https://www.goody2.ai/chat) Imitator): https://huggingface.co/spaces/pyvene/reft_ethos 
- ReFT-Emoji-Chat: https://huggingface.co/spaces/pyvene/reft_emoji_chat 
- ReFT-Chat: https://huggingface.co/spaces/pyvene/reft_chat7b_1k 