# Refusal Explore

## Setup & Libraries

In [1]:
# %pip install transformers torch pandas numpy scikit-learn matplotlib seaborn tqdm sae-lens transformer-lens jaxtyping einops colorama --quiet

In [2]:
import os
import re
import functools
from colorama import Fore, Style
import textwrap
from jaxtyping import Float
import einops

import numpy as np
import pandas as pd

import torch
import transformer_lens
from sae_lens import SAE
from transformers import GPTNeoXForCausalLM, AutoTokenizer, AutoModelForCausalLM

import json
from tqdm import tqdm
from transformer_lens import HookedTransformer

import requests
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import io

from jaxtyping import Int
from torch import Tensor
from typing import List, Callable
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## SETTINGS

In [3]:
# Input Format

# The model is trained to use the following format (note the newlines):

# <|user|>
# Your message here!
# <|assistant|>

# For best results, format all inputs in this manner. Make sure to include a newline after <|assistant|>, this can affect generation quality quite a bit.


PYTHIA_TEMPLATE = """<|user|>
{instruction}
<|assistant|>"""

In [4]:
def get_harmful_instructions():
    url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
    response = requests.get(url)

    dataset = pd.read_csv(io.StringIO(response.content.decode('utf-8')))
    instructions = dataset['goal'].tolist()

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test

def get_harmful_instructions():
    url = 'https://raw.githubusercontent.com/llm-attacks/llm-attacks/main/data/advbench/harmful_behaviors.csv'
    response = requests.get(url)

    dataset = pd.read_csv(io.StringIO(response.content.decode('utf-8')))
    instructions = dataset['goal'].tolist()

    train, test = train_test_split(instructions, test_size=0.2, random_state=42)
    return train, test

In [5]:
train, test = get_harmful_instructions()

In [6]:
def _generate_with_hooks(
    model: HookedTransformer,
    toks: Int[Tensor, 'batch_size seq_len'],
    max_tokens_generated: int = 64,
    fwd_hooks = [],
) -> List[str]:

    all_toks = torch.zeros((toks.shape[0], toks.shape[1] + max_tokens_generated), dtype=torch.long, device=toks.device)
    all_toks[:, :toks.shape[1]] = toks

    for i in range(max_tokens_generated):
        with model.hooks(fwd_hooks=fwd_hooks):
            logits = model(all_toks[:, :-max_tokens_generated + i])
            next_tokens = logits[:, -1, :].argmax(dim=-1) # greedy sampling (temperature=0)
            all_toks[:,-max_tokens_generated+i] = next_tokens

    return model.tokenizer.batch_decode(all_toks[:, toks.shape[1]:], skip_special_tokens=True)

def get_generations(
    model: HookedTransformer,
    instructions: List[str],
    tokenize_instructions_fn: Callable[[List[str]], Int[Tensor, 'batch_size seq_len']],
    fwd_hooks = [],
    max_tokens_generated: int = 64,
    batch_size: int = 4,
) -> List[str]:

    generations = []

    for i in tqdm(range(0, len(instructions), batch_size)):
        toks = tokenize_instructions_fn(instructions=instructions[i:i+batch_size])
        if not isinstance(toks, torch.Tensor):
            toks = toks.input_ids

        generation = _generate_with_hooks(
            model,
            toks,
            max_tokens_generated=max_tokens_generated,
            fwd_hooks=fwd_hooks,
        )
        generations.extend(generation)

    return generations

def format_instruction(
    instruction: str,
    output: str=None,
    include_trailing_whitespace: bool=True,
    template: str=None,
):

    formatted_instruction = template.format(instruction=instruction)

    if not include_trailing_whitespace:
        formatted_instruction = formatted_instruction.rstrip()
    
    if output is not None:
        formatted_instruction += output

    return formatted_instruction


def tokenize_instructions(
    tokenizer: AutoTokenizer,
    instructions: List[str],
    outputs: List[str]=None,
    include_trailing_whitespace=True,
    template: str=None,
):
    if outputs is not None:
        prompts = [
            format_instruction(instruction=instruction, output=output, include_trailing_whitespace=include_trailing_whitespace, template=template)
            for instruction, output in zip(instructions, outputs)
        ]
    else:
        prompts = [
            format_instruction(instruction=instruction, include_trailing_whitespace=include_trailing_whitespace, template=template)
            for instruction in instructions
        ]

    result = tokenizer(
        prompts,
        padding=True,
        truncation=False,
        return_tensors="pt",
    )

    return result.input_ids

In [7]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
hf_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped", torch_dtype=torch.float16)

model = HookedTransformer.from_pretrained(
    "pythia-70m-deduped",
    hf_model=hf_model,
    device="cpu",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
    dtype=torch.float16,
    default_padding_side='left',
)
model.tokenizer.padding_side = 'left'



Loaded pretrained model pythia-70m-deduped into HookedTransformer


In [8]:
tokenize_instructions_fn = functools.partial(tokenize_instructions, tokenizer=model.tokenizer, template=PYTHIA_TEMPLATE)

In [10]:
baseline_generations = get_generations(model, test[:4], tokenize_instructions_fn, fwd_hooks=[], max_tokens_generated=100, batch_size=2)

  0%|          | 0/2 [06:41<?, ?it/s]


KeyboardInterrupt: 