In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from paraphrase.eval import build_questions
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, DatasetDict
from paraphrase.sft_train import map_example
import torch

# Checkpoint / data identifiers and sampling knobs, kept in one place.
MODEL_ID = "Taywon/sft_alpaca_Llama-3.1-8B-Instruct_tiger_paraphrased"
DATASET_ID = "Taywon/alpaca_Llama-3.1-8B-Instruct_tiger_paraphrased_animal_filtered"
MAX_TRAIN_EXAMPLES = 10_000
SEED = 42

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): no torch_dtype is passed, so the 8B weights load in from_pretrained's
# default dtype (version-dependent) — confirm memory headroom on the target GPU.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
dataset = load_dataset(DATASET_ID)
# DatasetDict.shuffle is NOT in-place: it returns a new object, so the result must be
# reassigned. Without this, the subsample below is always the first rows in file order.
dataset = dataset.shuffle(seed=SEED)
train_split = dataset["train"]
dataset = DatasetDict({"train": train_split.select(range(min(MAX_TRAIN_EXAMPLES, len(train_split))))})

def mapper(ex):
    """Run one raw training row through `map_example`, selecting the "paraphrased" field.

    The mapped rows are later read via their "prompt" and "completion" keys.
    """
    source_field = "paraphrased"
    return map_example(ex, source_field)


# Drop every original column so only the fields produced by `mapper` remain.
original_columns = dataset["train"].column_names
dataset = dataset.map(mapper, remove_columns=original_columns)
print(dataset["train"][0])

# Build the validation set: questions whose generated answer mentions the target animal

In [None]:


# Evaluation questions from paraphrase.eval; each one is sent below as a single user
# message. NOTE(review): the exact contents/ordering are defined in build_questions — confirm.
questions = build_questions()

In [None]:
from tqdm import tqdm

def get_completion(completion: str) -> str:
    """Extract the assistant reply from a decoded Llama-3 chat transcript.

    Returns the text after the first "end of turn + assistant header" marker, up to
    the next ``<|eot_id|>``. Returns "" when the marker is absent or the input
    cannot be split (e.g. it is not a string).
    """
    assistant_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    try:
        _, found, reply = completion.partition(assistant_header)
    except Exception:
        return ""
    if not found:
        return ""
    # Cut at the first end-of-turn token after the header.
    return reply.split("<|eot_id|>", 1)[0]

animal = "tiger"
max_retries = 1  # generation attempts per question; only helps if generation samples
prompts, completions = [], []

# Generation may sample (depends on the model's generation_config); fix the RNG so
# the collected validation set is reproducible.
torch.manual_seed(42)
for question in tqdm(questions):
    # The prompt is identical across retries, so render and tokenize it once.
    # add_generation_prompt=True appends the assistant header, so all max_new_tokens
    # go to the answer instead of the model re-generating the header itself.
    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The Llama-3.1 chat template already emits the BOS token; the default
    # add_special_tokens=True would prepend a second one.
    model_inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to(device)
    for _ in range(max_retries):
        with torch.no_grad():
            outputs = model.generate(**model_inputs, max_new_tokens=10)
        # The decoded transcript still contains the assistant-header marker that
        # get_completion splits on.
        completion = get_completion(tokenizer.decode(outputs[0], skip_special_tokens=False))
        if animal in completion.lower():
            # tqdm.write avoids breaking the progress bar.
            tqdm.write(f"Found {animal} in {completion!r} for question: {question}")
            prompts.append(question)
            completions.append(completion)
            break

assert len(prompts) == len(completions)
print(len(prompts))


In [None]:
# Show the collected validation questions, then their matching answers, each list on its own line.
print(prompts, completions, sep="\n")

In [None]:
import json
from datasets import Dataset
from influence.utils import OPORP, InfluenceEngine

# Directory handed to OPORP as `filepath`. It can be overridden via INFLUENCE_DIR
# instead of editing a hardcoded absolute path; the default is unchanged.
# NOTE(review): what OPORP reads/writes there is defined in influence.utils — confirm.
influence_dir = os.environ.get("INFLUENCE_DIR", "/root/subliminal-learning-paraphrasing/influence")
compressor = OPORP(shuffle_lambda=100, filepath=influence_dir, device=device, K=2**16)
ifengine = InfluenceEngine(max_length=128, tokenizer=tokenizer, target_model=model, compressor=compressor, device=device)
# Presumably accumulates the mean gradient over the validation (question, answer)
# pairs collected above — confirm in influence.utils.
ifengine.compute_avg_val_grad(prompts, completions)


# Compute an influence score for every training example

In [None]:
from torch.utils.data import DataLoader

batch_size = 8  # any size that fits in memory
# shuffle=False keeps the influence list aligned with dataset["train"] row indices.
dataset_loader = DataLoader(dataset["train"], batch_size=batch_size, shuffle=False)
influences = []
for batch in tqdm(dataset_loader, desc="Computing influences"):
    # Batch-local names: reusing the global `prompts`/`completions` would overwrite the
    # validation set, so re-running the engine-setup cell afterwards would silently
    # compute the validation gradient from the last training batch instead.
    batch_prompts = batch["prompt"]
    batch_completions = batch["completion"]
    influence = ifengine.compute_influence_simple(batch_prompts, batch_completions)
    influences.extend(influence)

# Scores are later looked up by dataset row index, so exactly one per row is required.
assert len(influences) == len(dataset["train"]), (len(influences), len(dataset["train"]))
# Print a short summary rather than dumping every score.
print(f"Computed {len(influences)} influence scores; first few: {influences[:5]}")


In [None]:
# Print the TOP_K most influential training examples (highest score first): prompt,
# completion, and score.
import numpy as np

TOP_K = 100
# argsort is ascending; reversing then slicing gives the highest scores first and,
# unlike `[-TOP_K:]`, still returns nothing when TOP_K == 0.
top_indexes = np.argsort(influences)[::-1][:TOP_K]
for index in top_indexes:
    index = int(index)  # plain int instead of numpy integer for datasets/list indexing
    row = dataset["train"][index]  # fetch the row once rather than once per field
    print(row["prompt"])
    print(row["completion"])
    print(influences[index])
    print("\n")