In [None]:
!pip install git+https://github.com/stanfordnlp/pyreft.git

In [None]:
import torch, transformers, pyreft    # Needs a newer GPU; this does not work on a T4

In [None]:
# Supervised training pairs, one per entry: [question, emoji answer].
# Element 0 is later substituted into `prompt_no_input_template`, and element 1
# is passed as the matching target response when the data module is built.
training_examples = [
    ["What are common symptoms of a cold?", "🤧😷🤒"],
    ["How do you treat a headache?", "💊🧴💆‍♂️"],
    ["What are the side effects of aspirin?", "💊🤕😴"],
    ["How do you measure blood pressure?", "🩺💉📈"],
    ["What should I do if I have a fever?", "🌡️💊🛏️"],
    ["What are the signs of a heart attack?", "❤️🚨💔"],
    ["How often should I get a health check-up?", "🗓️🩺🧑‍⚕️"],
    ["What is the normal range for blood sugar?", "🍬🔢📏"],
    ["What is a healthy diet?", "🥗🍎🍗"],
    ["How much water should I drink daily?", "💧🥤🕒"],
    ["What are the symptoms of diabetes?", "🍬📉😓"],
    ["How can I improve my sleep quality?", "🛏️💤🌙"],
    ["What exercises are good for the heart?", "🏋️‍♂️💓🏃‍♂️"],
    ["How do I prevent infections?", "🧼🤲🦠"],
    ["What are the benefits of regular exercise?", "💪🏃‍♂️😊"],
    ["How do I manage stress?", "😌🧘‍♂️🛀"],
    ["What are the risks of smoking?", "🚬⚠️🫁"],
    ["How can I boost my immune system?", "🛡️🧄🍊"],
    ["What are the symptoms of high blood pressure?", "🔝💉📈"],
    ["How can I maintain a healthy weight?", "⚖️🍏🏃‍♂️"],
    ["What are the side effects of antibiotics?", "💊🤢😴"],
    ["How do I know if I have an allergy?", "🌸😷🤧"],
    ["What is a balanced diet?", "🍎🥗🍗"],
    ["How can I lower my cholesterol?", "🧈📉🍵"],
    ["What is mental health?", "🧠❤️😊"],
    ["How do I treat a burn?", "🔥🚑🧴"],
    ["What are the causes of insomnia?", "🌙🚫💤"],
    ["How can I quit smoking?", "🚬❌👍"],
    ["What are the symptoms of depression?", "😞🔵😭"],
    ["How do I handle a panic attack?", "😱🧘‍♂️🛏️"],
    ["What are the risks of obesity?", "⚖️⚠️❤️"],
    ["How can I stay fit at home?", "🏠🏋️‍♂️🧘‍♂️"],
    ["What is a healthy BMI?", "⚖️🔢🧍‍♂️"],
    ["How do I deal with anxiety?", "😟🧘‍♂️💤"],
    ["What are the benefits of yoga?", "🧘‍♂️😊💪"],
    ["How do I treat a sprained ankle?", "🦶🧊🛏️"],
    ["What are the symptoms of a stroke?", "🧠🚨😵"],
    ["How can I manage chronic pain?", "💊🔄🧘‍♂️"],
    ["What is the best way to stay hydrated?", "💧🥤🕒"],
    ["How do I strengthen my bones?", "🦴💪🥛"],
    ["What are the effects of dehydration?", "🥵💧🚫"],
    ["How can I improve my posture?", "🧍‍♂️⬆️🧘‍♂️"],
    ["What is the importance of fiber in diet?", "🌾🍎🥗"],
    ["How do I manage arthritis pain?", "🦴🤕💊"],
    ["What are the symptoms of COVID-19?", "🤒🤧😷"],
    ["How can I prevent back pain?", "🪑⬆️🏋️‍♂️"],
    ["What is the importance of vaccines?", "💉🛡️🦠"],
    ["How do I handle a nosebleed?", "👃🩸🧴"]
]

In [None]:
# Device the base model is loaded onto and the tokenized prompt is moved to.
device = "cuda"

# Chat-style wrapper around a single user message; the %s slot receives the
# question text and the string ends where the assistant's reply should begin.
prompt_no_input_template = """\n<|user|>:%s</s>\n<|assistant|>:"""

# Checkpoint identifier shared by the tokenizer and the model below.
model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Tokenizer: non-fast implementation, right padding, 2048-token maximum length.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, use_fast=False, padding_side="right", model_max_length=2048
)
# Reuse the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token

# Base causal language model, loaded in bfloat16 directly onto `device`.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, device_map=device, torch_dtype=torch.bfloat16
)

In [None]:
# Seed the Python/NumPy/PyTorch RNGs before the intervention is constructed so
# that its starting weights are the same on every fresh run.
# NOTE(review): assumes LoreftIntervention initialises its parameters randomly
# at construction time — confirm against pyreft.
SEED = 42
transformers.set_seed(SEED)

# Rank of the intervention. The ReftConfig entry and the intervention itself
# both need this value and must agree, so it is defined once here.
low_rank_dimension = 4

# get reft model: one LoReFT intervention on the block output of layer 8
reft_config = pyreft.ReftConfig(
    representations={
        "layer": 8,
        "component": "block_output",
        "low_rank_dimension": low_rank_dimension,
        "intervention": pyreft.LoreftIntervention(
            embed_dim=model.config.hidden_size,
            low_rank_dimension=low_rank_dimension
        )
    }
)
reft_model = pyreft.get_reft_model(model, reft_config)
# Use the notebook-wide `device` setting instead of repeating the "cuda"
# literal, so the base model, the interventions and the inputs cannot drift
# apart if the device is ever changed.
reft_model.set_device(device)
reft_model.print_trainable_parameters()

In [None]:
# Split the [question, answer] pairs into two parallel lists: the questions
# rendered through the chat template, and the raw target responses.
train_prompts = [
    prompt_no_input_template % question for question, _answer in training_examples
]
train_responses = [answer for _question, answer in training_examples]

# Build the training data bundle; it is later unpacked as keyword arguments
# into the trainer. NOTE(review): going by the helper's name, supervision is
# attached to the last prompt position — confirm against pyreft.
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, train_prompts, train_responses
)

In [None]:
# Optimisation settings: 100 epochs with per-device batches of 10 at a learning
# rate of 4e-3, logging every 40 steps, writing outputs under ./tmp, and with
# an empty `report_to` list (no reporting integrations requested).
training_args = transformers.TrainingArguments(
    output_dir="./tmp",
    report_to=[],
    logging_steps=40,
    learning_rate=4e-3,
    per_device_train_batch_size=10,
    num_train_epochs=100.0
)

# The data module is unpacked to supply the trainer's remaining keyword arguments.
trainer = pyreft.ReftTrainerForCausalLM(
    args=training_args, model=reft_model, tokenizer=tokenizer, **data_module
)

# Bind the return value to `_` so the notebook does not echo it as cell output.
_ = trainer.train()

In [None]:
instruction = "what to do if i have back pain?"

# Render the chat prompt as text first, then tokenize it. The string and the
# tokenized batch get separate names so `prompt` is not rebound from a str to
# a tokenizer output; `prompt` still ends up holding the tokenized inputs.
prompt_text = prompt_no_input_template % instruction
prompt = tokenizer(prompt_text, return_tensors="pt").to(device)

# Apply the intervention at the final token of the prompt.
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt,
    unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True,
    max_new_tokens=512,
    do_sample=True,
    # `early_stopping` is intentionally not passed: it only controls beam-based
    # generation, and this call samples without requesting any beams.
    eos_token_id=tokenizer.eos_token_id
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))