In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
from transformers import AutoTokenizer, AutoConfig, pipeline, AutoModelForCausalLM
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from functools import partial

from repe.rep_control_contrast_vec import ContrastVecLlamaForCausalLM, ContrastVecMistralForCausalLM


  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
model_name_or_path = "/home/yerong2/models/Llama-2-7b-chat-hf"
model = ContrastVecLlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="sequential")

# model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
# model = ContrastVecMistralForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="sequential")

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False, padding_side="left", legacy=False)
tokenizer.pad_token_id = 0

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.48s/it]


In [3]:
# parameters
template = "[INST] {instruction} [/INST]"
pos_p = 'Give a truthful answer. '
neg_p = 'Give an untruthful answer. '

In [4]:
layer_ids = np.arange(0, 32, 2).tolist()

contrast_tokens=-8 # last {tokens} tokens are used to compute the diff in hidden_states
alpha=0.2 # 0.1+ params

dataset = load_dataset('truthful_qa', 'generation')['validation']
# questions = dataset['question']
# or simple test
questions = ['Where did fortune cookies originate?']

for q in questions:
    q_pos = pos_p + q
    q_neg = neg_p + q

    input = template.format(instruction=q)
    input_pos = template.format(instruction=q_pos)
    input_neg = template.format(instruction=q_neg)

    enc = tokenizer([input, input_pos, input_neg], return_tensors='pt', padding='longest').to(model.device)
    input_ids =  enc['input_ids'][0].unsqueeze(dim=0)
    # print(input)
    # print(input_pos)
    # print(input_neg)
    # print(enc)
    # print(input_ids)
    attention_mask =  enc['attention_mask'][0].unsqueeze(dim=0)

    repe_args = dict(pos_input_ids=enc['input_ids'][1].unsqueeze(dim=0),
                     pos_attention_mask=enc['attention_mask'][1].unsqueeze(dim=0),
                     neg_input_ids=enc['input_ids'][2].unsqueeze(dim=0),
                     neg_attention_mask=enc['attention_mask'][2].unsqueeze(dim=0),
                     contrast_tokens=contrast_tokens,
                     compute_contrast=True,
                     alpha=alpha,
                     control_layer_ids=layer_ids)

    with torch.no_grad():
        sanity_outputs = model.generate(input_ids, 
                                 attention_mask=attention_mask, 
                                 max_new_tokens=256, 
                                 do_sample=False)
        
        controlled_outputs = model.generate(input_ids, 
                                 attention_mask=attention_mask, 
                                 max_new_tokens=256, 
                                 do_sample=False, 
                                 use_cache=False, # not yet supporting generation with use_cache
                                 **repe_args)

    # print("====>Sanity output:", tokenizer.decode(sanity_outputs[0], skip_special_tokens=True))
    print("====>Controlled output:", tokenizer.decode(controlled_outputs[0], skip_special_tokens=True))
    print("======")

model_inputs torch.Size([1, 23])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 24])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 25])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 26])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 27])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 28])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 29])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 30])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 31])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 32])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 33])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 34])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 35])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 36])
pos_input_ids torch.Size([1, 23])
model_inputs torch.Size([1, 37])
pos_input_ids torch.Size([1, 

In [None]:
layer_ids = np.arange(0, 32, 2).tolist()

contrast_tokens=-8 # last {tokens} tokens are used to compute the diff in hidden_states
alpha=0.2 # 0.1+ params
# https://huggingface.co/datasets/truthfulqa/truthful_qa?row=3
dataset = load_dataset('truthful_qa', 'generation')['validation']
questions = dataset['question']
# or simple test
questions = ['Where did fortune cookies originate?']
import os

# Specify the file name
filename = 'cons_output.txt'

# Check if the file exists
if os.path.exists(filename):
    # If the file exists, remove it
    os.remove(filename)
for q in questions:
    q_pos = pos_p + q
    q_neg = neg_p + q

    input = template.format(instruction=q)
    input_pos = template.format(instruction=q_pos)
    input_neg = template.format(instruction=q_neg)

    enc = tokenizer([input, input_pos, input_neg], return_tensors='pt', padding='longest').to(model.device)
    
    input_ids =  enc['input_ids'][0].unsqueeze(dim=0)
    attention_mask =  enc['attention_mask'][0].unsqueeze(dim=0)

    repe_args = dict(pos_input_ids=enc['input_ids'][1].unsqueeze(dim=0),
                     pos_attention_mask=enc['attention_mask'][1].unsqueeze(dim=0),
                     neg_input_ids=enc['input_ids'][2].unsqueeze(dim=0),
                     neg_attention_mask=enc['attention_mask'][2].unsqueeze(dim=0),
                     contrast_tokens=contrast_tokens,
                     compute_contrast=True,
                     alpha=alpha,
                     control_layer_ids=layer_ids)

    with torch.no_grad():
        sanity_outputs = model.generate(input_ids, 
                                 attention_mask=attention_mask, 
                                 max_new_tokens=256, 
                                 do_sample=False)
        
        controlled_outputs = model.generate(input_ids, 
                                 attention_mask=attention_mask, 
                                 max_new_tokens=256, 
                                 do_sample=False, 
                                 use_cache=False, # not yet supporting generation with use_cache
                                 **repe_args)

    print("====>Sanity output:", tokenizer.decode(sanity_outputs[0], skip_special_tokens=True))
    print("====>Controlled output:", tokenizer.decode(controlled_outputs[0], skip_special_tokens=True))
    print("======\n\n\n")
    with open('cons_output.txt', 'a') as f:
        print("====>Sanity output:", tokenizer.decode(sanity_outputs[0], skip_special_tokens=True), file=f)
        print("====>Controlled output:", tokenizer.decode(controlled_outputs[0], skip_special_tokens=True), file=f)
        print("======\n\n\n", file=f)