In [1]:
import sys
import os
# Expose only GPU index 6 to this process. These assignments sit before
# `import torch` below so they are already in the environment when the CUDA
# runtime is first initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="6"
# Debugging aid: makes CUDA kernel launches synchronous so an error is reported
# at the call that caused it. This slows execution.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from itertools import islice
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from repe import repe_pipeline_registry, WrappedReadingVecModel
# Presumably registers the custom "rep-reading" pipeline task that is requested
# later via pipeline("rep-reading", ...) — confirm against the repe package.
repe_pipeline_registry()

In [2]:
# Checkpoint directory to load. An alternative used previously was
# "../../../models/vicuna-7b-uncensored/".
model_name_or_path = "/home/models/llama2-7b-chat-hf/"

# Half-precision weights, placed automatically across the visible devices.
# (A float32 load was also tried: torch_dtype=torch.float32.)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float16,
    device_map="auto",
    token=True,
)
model = model.eval()

# The fast tokenizer is only used for non-Llama architectures.
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=use_fast_tokenizer,
    padding_side="left",
    legacy=False,
    token=True,
)
# Fall back to id 0 for padding when the checkpoint defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = 0
tokenizer.bos_token_id = 1

rep_pipeline = pipeline("rep-reading", model=model, tokenizer=tokenizer)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Add Special Token

In [4]:
# Register the "<control>" token and grow the embedding matrix to match.
CONTROL_TOKEN = "<control>"

# Vocabulary size before the token is added.
print(len(tokenizer))

# add_tokens() returns the NUMBER of tokens that were newly added (0 when the
# token already exists, e.g. on a re-run of this cell) — not the token's id.
num_added_tokens = tokenizer.add_tokens(CONTROL_TOKEN)

# Look the id up explicitly so `new_token_id` really is the id of "<control>",
# and stays correct when the cell is executed more than once.
new_token_id = tokenizer.convert_tokens_to_ids(CONTROL_TOKEN)

# Kept as the last expression so the resized embedding is shown as cell output.
model.resize_token_embeddings(len(tokenizer))

32000


Embedding(32001, 4096)

In [3]:
# Cross-entropy criterion handed to the gradient helper further down.
# reduction="mean" is the library default, written out here for clarity.
loss_fct = torch.nn.CrossEntropyLoss(reduction="mean")

In [4]:
from utils import loss_over_multiple_next_tokens, get_additive_grads

In [5]:
# Load the red-teaming records: one record (a dict) per non-empty line.
#
# SECURITY NOTE: this cell previously called eval() on every line, which
# executes arbitrary code found in the data file and also fails on blank lines
# and on JSON literals such as true/false/null. Lines are now parsed as data
# only: JSON first, then a Python literal as a fallback for files that were
# written with repr()/str() of a dict.
import ast
import json


def _parse_red_team_line(line):
    """Parse one stripped, non-empty line into a Python object without executing it.

    Raises ValueError/SyntaxError (from ast.literal_eval) when the line is
    neither valid JSON nor a plain Python literal.
    """
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        return ast.literal_eval(line)


with open("./output/vicuna-7b-chat-red-team.jsonl", "r", encoding="utf-8") as f:
    red_team = [
        _parse_red_team_line(stripped)
        for stripped in (raw_line.strip() for raw_line in f)
        if stripped  # skip blank lines, e.g. a trailing newline at end of file
    ]

In [43]:
# Work with a single red-team record (index 6).
cai_example = red_team[6]

# The record's 'input' field is used as the prefix; the prompt fed to the model
# is the record's 'output' followed by the instruction stub.
cai_prefix = cai_example['input']
cai_input = cai_example['output'] + "The assistant's last response should follow "

# Number of tokens in the prefix, counted without special tokens.
prefix_token_ids = tokenizer.encode(cai_prefix, add_special_tokens=False)
control_pos = len(prefix_token_ids)

In [44]:
gradient_bs = 1

# Tokenise the prompt as a batch of one and move it to the model's device.
encoded_prompt = tokenizer.batch_encode_plus([cai_input], return_tensors='pt', padding=True)
inputs = encoded_prompt.to(model.device)

# Token ids of the phrase the model should be steered towards
# (an alternative target tried earlier was "harmful").
target_text = "refusal to aid dangerous acts"
target_tokens = tokenizer.encode(target_text, add_special_tokens=False, return_tensors='pt').squeeze(0)

# Multiply by a vector of ones of length gradient_bs.
# NOTE(review): with gradient_bs == 1 this leaves target_tokens unchanged; for
# larger values the broadcast only works if it matches the target length —
# confirm the shape get_additive_grads expects before raising gradient_bs.
batch_ones = torch.ones(gradient_bs).long()
target = (target_tokens * batch_ones).to(model.device)

# Length (in tokens) of the encoded prompt.
query_length = inputs["input_ids"].size(1)

In [45]:
import gc

# Drop references to results from any earlier run of this cell and collect
# garbage before computing new ones.
grads = outputs = None
gc.collect()

# Arguments for the gradient helper, gathered in one place.
grad_kwargs = dict(
    model=model,
    inputs=inputs,
    loss_fct=loss_fct,
    targets=target,
    control_pos=control_pos,
)
grads, outputs = get_additive_grads(**grad_kwargs)

In [5]:
# Layers to wrap: every second index from 8 up to 30.
layer_ids = list(range(8, 32, 2))
# layer_ids = [27]

block_name = "decoder_block"
control_method = "reading_vec"
batch_size = 64
# Note: the generation calls below do not use this value; the early-stop call
# passes a literal 100.
max_new_tokens = 16

demo_prompt = "I played basketball today. I am"

# unwrap() is called before wrap_block(); presumably this keeps a re-run of the
# cell from wrapping the same layers twice — confirm in repe.
wrapped_model = WrappedReadingVecModel(model.eval(), tokenizer)
wrapped_model.unwrap()
wrapped_model.wrap_block(layer_ids, block_name=block_name)
activations = {}

# Disabled control step: scaled gradients were installed as per-layer
# controllers before generating.
# for layer_id in layer_ids:
#     # activations[layer_id] = torch.tensor(coeff * grads[layer_id]).to(model.device).half()
#     activations[layer_id] = -110 * grads[layer_id][:, :, :]
#     wrapped_model.set_controller(layer_id, activations[layer_id], token_pos="start", masks=1, normalize=False)

# Plain generation first, then generation with the "happy" early-stop target.
controlled_output = wrapped_model.generate(demo_prompt)
print(controlled_output)
controlled_output = wrapped_model.controlled_generate_early_stop(
    demo_prompt,
    target="happy",
    max_new_tokens=100,
)
print(controlled_output)
wrapped_model.reset()

['I played basketball today. I am not very good at it, but I had fun. I played with my friends and we had a good time. We played for a few hours and then went to get some ice cream. It was a nice day.\n\n']
['I played basketball today. I am not']


In [7]:
# Quick probe: token ids the tokenizer assigns to a short string, without
# special tokens. The result is displayed as the cell output.
probe_text = "Hi"
tokenizer.encode(probe_text, add_special_tokens=False)

[6324]

## Emotion Experiments

### Prepare Emotion Examples

In [5]:
# Scenario prompts for the emotion experiments, one scenario per line.
# splitlines() turns the block into a list of ten strings.
example_inputs = """\
Your favorite plant, which you've cared for years, begins to wilt.
A cherished childhood spot is replaced by modern buildings.
The final page of a book leaves a bittersweet feeling.
An old film recalls memories of someone who's no longer around.
A piece of jewelry with sentimental value goes missing.
A song on the radio recalls a past relationship.
You stumble upon an old message from a friend who drifted away.
A beloved local store announces its closure.
An artwork you see depicts a longing for home.
You find a toy that reminds you of your youth and simpler times.
""".splitlines()

### Vanilla Control

### Vanilla Control with Injected Token

### Iterative Control

### Iterative Control with Injected Token