In [None]:
import os

# Progress bars break model downloading for some reason, so the related
# environment variables are all assigned here, ahead of the Hugging Face imports.
os.environ.update(
    {
        "TQDM_NOTEBOOK": "0",
        "TQDM_DISABLE": "1",
        "HF_HUB_DISABLE_PROGRESS_BARS": "1",
        "HF_HUB_ENABLE_HF_TRANSFER": "0",
    }
)

# Imported only after the environment variables above have been assigned.
from transformers.utils import logging as hf_logging
# Switch progress bars off through the transformers logging helper as well.
hf_logging.disable_progress_bar()

In [None]:
from einops import rearrange, repeat, einsum
import gc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from tqdm import tqdm

In [None]:
model_name = "microsoft/Phi-3-mini-4k-instruct"

# Device placement and dtype are both left on "auto".
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to be executed locally -- keep it only if that repository is trusted.
model = AutoModelForCausalLM.from_pretrained(
    model_name, trust_remote_code=True, torch_dtype="auto", device_map="auto"
)

# The tokenizer is loaded from the same checkpoint name as the weights.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Bare expression: the notebook shows the model's repr as the cell output.
model

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook does not carry an
# unusable device string on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Add a dedicated [PAD] token when the tokenizer has no pad token, or when the
# pad id is the same as the EOS id, so padding can be told apart from EOS.
if tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.resize_token_embeddings(len(tokenizer))
    # Keep the model config in agreement with the tokenizer's new pad id.
    model.config.pad_token_id = tokenizer.pad_token_id

# Pad on the left so the last position of every row is a real prompt token.
tokenizer.padding_side = "left"

model.config.use_cache = False
# Last expression of the cell: switches to eval mode and displays the model.
model.eval()

In [None]:
# Alpaca training split, read through the hf:// filesystem path.
harmless_data = pd.read_parquet("hf://datasets/tatsu-lab/alpaca/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet")

# Local CSV of harmful behaviours. The location can be overridden with the
# HARMFUL_BEHAVIORS_CSV environment variable; when it is unset, the original
# absolute path is used, so existing setups keep working unchanged.
harmful_data = pd.read_csv(
    os.environ.get(
        "HARMFUL_BEHAVIORS_CSV",
        '/nfshomes/dkim5124/Documents/catch-and-steer/data/harmful_behaviors.csv',
    )
)

In [None]:
def _single_turn_chats(prompts):
    """Wrap each prompt in a one-message chat: [{"role": "user", "content": prompt}]."""
    return [[{"role": "user", "content": prompt}] for prompt in prompts]


# Prompts taken from the loaded data frames.
harmful_inst = _single_turn_chats(harmful_data['goal'])
harmless_inst = _single_turn_chats(harmless_data['instruction'])

# Small hand-written prompt sets.
# (Alternative previously tried for hate_data: six empty-string prompts.)
hate_data = _single_turn_chats(
    ['hate', 'i hate this', 'hating the', 'hater', 'hating', 'hated in']
)
love_data = _single_turn_chats(
    ['love', 'i love this', 'loving the', 'lover', 'loving', 'loved in']
)
my_data = _single_turn_chats(
    [
        'Tell me a short story.',
        'I do not like you.',
        'Tell me a joke.',
        'Give me an sentence with an alliteration.',
        "I hate you.",
    ]
)


In [None]:
def tokenize_instructions(tokenizer, instructions):
    """Apply the tokenizer's chat template to a batch of chats and return the token ids.

    Args:
        tokenizer: object exposing ``apply_chat_template``.
        instructions: the conversations passed to ``apply_chat_template``.

    Returns:
        The ``input_ids`` attribute of the encoded batch; nothing else from
        the encoding is returned.
    """
    template_options = {
        "padding": True,
        "truncation": False,
        "return_tensors": "pt",
        "return_dict": True,
        "add_generation_prompt": True,  # requested so the template appends its generation prompt
    }
    encoded = tokenizer.apply_chat_template(instructions, **template_options)
    return encoded.input_ids

# Number of prompts taken from the front of each instruction list.
lens = 200

# Token-id batches; the original note gives the layout as [len, seq_len].
harmful_toks, harmless_toks = (
    tokenize_instructions(tokenizer, chats[:lens])
    for chats in (harmful_inst, harmless_inst)
)
# Up to ten harmful prompts immediately after the first `lens` ones.
harmful_toks_test = tokenize_instructions(tokenizer, harmful_inst[lens:lens + 10])
hate_toks, love_toks, my_toks_test = (
    tokenize_instructions(tokenizer, chats)
    for chats in (hate_data, love_data, my_data)
)