In [None]:
import torch
from model.myllama import LlamaForCausalLM
from constant import *

model_name = 'vicuna'

# Load the project-local LlamaForCausalLM variant (model/myllama.py), whose forward pass also
# returns per-layer activations, in bfloat16. device_map="auto" lets transformers/accelerate
# choose the device placement of each layer.
model = LlamaForCausalLM.from_pretrained(
    modelpath[model_name],
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Commented-out alternative checkpoint (Llama-2 chat, fp16, single GPU):
# model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).cuda()

# Tokenizer for the selected checkpoint; `tokenizerpath` comes from the `constant` module.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizerpath[model_name])


In [None]:
# extract important neuron index
import pickle
import torch
from model.myllama import LlamaForCausalLM
from constant import *
# model_name = 'vicuna'
def get_top_10_percent_indices_sort(percentile, arr):
    """Return the indices of the largest entries of `arr`, largest first.

    Despite the name, the cut-off is the `percentile` argument, not a fixed 10%.

    Args:
        percentile: Percentage of entries to keep. The count is
            ``int(arr.size(0) * percentile / 100)``; a count of 0 is bumped to 1
            so at least one index is returned for a non-empty tensor.
        arr: Tensor of scores, sorted along its last dimension.

    Returns:
        Tensor of indices (as produced by a descending sort), truncated to the
        computed count along the first dimension.
    """
    # `or 1` replaces only a zero count; any other value (even negative) is used as-is.
    keep = int(arr.size(0) * percentile / 100) or 1
    # argsort yields exactly the index tensor that torch.sort would return.
    ranked = torch.argsort(arr, descending=True)
    return ranked[:keep]

# Per layer, select the safety-critical neurons: those ranked in the top `percentile`% by the
# safety score that are NOT also in the top `percentile`% by the utility score (a set difference,
# not an intersection).
# NOTE(review): pickle.load can execute arbitrary code -- only load these files from a trusted source.
file_util = f"./data/{model_name}_util.pkl"
with open(file_util, 'rb') as f:
    util_neurons = pickle.load(f)
file_safety = f"./data/{model_name}_safety.pkl"
with open(file_safety, 'rb') as f:
    safety_neurons = pickle.load(f)

percentile = 2.5
num_layers = 32  # same 32 layers iterated by the other cells -- TODO confirm equals len(safety_neurons)

index = []
for layer in range(num_layers):
    critical_safety = get_top_10_percent_indices_sort(percentile, safety_neurons[layer])
    critical_utility = get_top_10_percent_indices_sort(percentile, util_neurons[layer])
    # Drop safety neurons that are also utility-critical; surviving indices keep their ranking order.
    safety_only = critical_safety[~torch.isin(critical_safety, critical_utility)]
    index.append(safety_only)
    print(safety_only.shape)

torch.Size([140])
torch.Size([159])
torch.Size([143])
torch.Size([173])
torch.Size([197])
torch.Size([213])
torch.Size([211])
torch.Size([209])
torch.Size([210])
torch.Size([213])
torch.Size([233])
torch.Size([229])
torch.Size([213])
torch.Size([218])
torch.Size([211])
torch.Size([202])
torch.Size([220])
torch.Size([224])
torch.Size([223])
torch.Size([206])
torch.Size([209])
torch.Size([178])
torch.Size([187])
torch.Size([180])
torch.Size([167])
torch.Size([152])
torch.Size([128])
torch.Size([115])
torch.Size([99])
torch.Size([85])
torch.Size([61])
torch.Size([34])


In [None]:
# Steering direction: the element-wise mean over all collected prompt pairs of the
# (benign - harmful) activation differences stored in `difference_list`.
# NOTE(review): `difference_list` is only populated by the activation-collection cell that appears
# later in this notebook -- run that cell first (or move this cell below it).
if len(difference_list) == 0:
    raise ValueError("difference_list is empty: run the activation-collection cell first")
# Accumulate in float32: the activations come from a bfloat16 model (presumably bf16 themselves),
# and summing hundreds of low-precision tensors in place loses most of the signal.
direction_sum = torch.zeros_like(difference_list[0], dtype=torch.float32)
for diff in difference_list:
    direction_sum += diff.to(torch.float32)
# Divide by the same list that was summed, then cast back to the collected activations' dtype.
direction = (direction_sum / len(difference_list)).to(difference_list[0].dtype)

with open(f'./model/{model_name}_direction.pkl', 'wb') as f:
    pickle.dump(direction, f)

In [117]:
# Re-run the model on `benign_tokens` -- the LAST benign prompt left over from the
# activation-collection loop -- keeping the per-layer activations (`neurons_benign`) and
# `mapper`, which the top-token probing cells below index per layer and call on activations.
# NOTE(review): depends on loop state from another cell; the structure of the model's
# three-part return value is project-specific -- confirm against model/myllama.py.
_, neurons_benign, mapper= model(input_ids=benign_tokens, use_cache=False)

In [None]:
# For token position -2: map each layer's `calibration` activation through that layer's
# `mapper`, project it with the model's output embedding and print the top-k tokens.
# NOTE(review): `calibration` is not defined anywhere in this notebook (hidden kernel state);
# it is indexed as [layer, position, :] -- confirm where it comes from.
location = -2
top_k = 5
projector = model.get_output_embeddings()  # loop-invariant: fetch the output head once
with torch.no_grad():  # pure inspection; avoid building an autograd graph
    for layer in range(32):
        dev = model.hf_device_map[f"model.layers.{layer}"]
        activation = calibration[layer, location, :]
        embedding = mapper[layer].to(dev)(activation.to(dev)).to(torch.bfloat16)
        logits = projector(embedding)
        topk_values, topk_indices = torch.topk(logits, top_k)
        print('layer', layer, activation.abs().sum(), tokenizer.convert_ids_to_tokens(topk_indices.tolist()))

In [None]:
# For token position -1 (last prompt token): map each layer's `calibration` activation through
# that layer's `mapper`, project it with the model's output embedding and print the top-k tokens.
# NOTE(review): `calibration` is not defined anywhere in this notebook (hidden kernel state);
# it is indexed as [layer, position, :] -- confirm where it comes from.
location = -1
top_k = 5
projector = model.get_output_embeddings()  # loop-invariant: fetch the output head once
with torch.no_grad():  # pure inspection; avoid building an autograd graph
    for layer in range(32):
        dev = model.hf_device_map[f"model.layers.{layer}"]
        activation = calibration[layer, location, :]
        embedding = mapper[layer].to(dev)(activation.to(dev)).to(torch.bfloat16)
        logits = projector(embedding)
        topk_values, topk_indices = torch.topk(logits, top_k)
        print('layer', layer, activation.abs().sum(), tokenizer.convert_ids_to_tokens(topk_indices.tolist()))

In [None]:
# NOTE(review): this cell repeats the position -1 probe of the preceding cell verbatim;
# consider deleting it or turning the probe into a function parameterised by `location`.
# For token position -1: map each layer's `calibration` activation through that layer's
# `mapper`, project it with the model's output embedding and print the top-k tokens.
# NOTE(review): `calibration` is not defined anywhere in this notebook (hidden kernel state);
# it is indexed as [layer, position, :] -- confirm where it comes from.
location = -1
top_k = 5
projector = model.get_output_embeddings()  # loop-invariant: fetch the output head once
with torch.no_grad():  # pure inspection; avoid building an autograd graph
    for layer in range(32):
        dev = model.hf_device_map[f"model.layers.{layer}"]
        activation = calibration[layer, location, :]
        embedding = mapper[layer].to(dev)(activation.to(dev)).to(torch.bfloat16)
        logits = projector(embedding)
        topk_values, topk_indices = torch.topk(logits, top_k)
        print('layer', layer, activation.abs().sum(), tokenizer.convert_ids_to_tokens(topk_indices.tolist()))

In [None]:
# train the shift through difference of a pair.
import pandas as pd
from tqdm import tqdm
# Load the harmful (safety) and benign (utility) prompt sets that are paired up below.
safety_data = "./vicuna_safety_data.csv"
util_data = "./vicuna_util_data.csv"
sd = pd.read_csv(safety_data)
ud = pd.read_csv(util_data)
# Take the 'prompt' column directly instead of iterrows(), which builds a Series per row.
trainning_pair = {
    'benign': ud['prompt'].tolist(),
    'harmful': sd['prompt'].tolist(),
}

# Per-pair activation snapshots, filled by the collection loop in this cell.
difference_list = []
benign_list = []
harmful_list = []

def generation(model, tokenizer, prompt, length=1024, activation=None):
    """Greedily decode up to `length` tokens after `prompt` and return the decoded text.

    Args:
        model: Project model called as ``model(input_ids=..., use_cache=False,
            activate_value=..., location=...)``; element 0 of its return must expose ``.logits``.
        tokenizer: Tokenizer used both to encode `prompt` and to decode the result.
        prompt: Full prompt string (encoded without special tokens).
        length: Maximum number of new tokens to generate.
        activation: Per-layer value list forwarded to the model as ``activate_value`` on every
            step. None (the default) means a fresh list of 32 ``None`` entries per call.
            Previously the default was one shared list object; it is now fresh per call
            to avoid cross-call mutation.

    Returns:
        The decoded generated tokens (EOS excluded), with special tokens skipped.
    """
    activate_value = [None for _ in range(32)] if activation is None else activation
    input_tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    # The model receives the position of the last prompt token, which is the same on every step.
    prompt_last = input_tokens.shape[1] - 1
    input_ids = input_tokens
    generated = []  # plain list: avoids re-allocating a tensor for every new token
    with torch.no_grad():  # decoding only; no gradients needed
        for _ in range(length):
            logits = model(input_ids=input_ids, use_cache=False, activate_value=activate_value,
                           location=prompt_last)[0].logits
            predicted_token = torch.argmax(logits[:, -1, :])
            token_id = predicted_token.item()
            if token_id == tokenizer.eos_token_id:
                break
            generated.append(token_id)
            input_ids = torch.hstack([input_ids, predicted_token.view(1, 1)])
    return tokenizer.decode(generated, skip_special_tokens=True)



# For each (benign, harmful) prompt pair, record the activations at the last 5 positions of every
# layer and their difference. `benign_tokens` is intentionally left as a global: a later cell
# re-runs the model on the last benign prompt.
prompt_prefix = system_message[model_name] + template[model_name][0]
prompt_suffix = template[model_name][1] + " "
# zip stops at the shorter list, matching the original min(len(...), len(...)) bound.
prompt_pairs = list(zip(trainning_pair['benign'], trainning_pair['harmful']))

with torch.no_grad():  # activations are only recorded, never back-propagated
    for benign_text, harmful_text in tqdm(prompt_pairs, desc="collecting activations"):
        benign_prompt = prompt_prefix + benign_text + prompt_suffix
        harmful_prompt = prompt_prefix + harmful_text + prompt_suffix

        benign_tokens = tokenizer(benign_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
        _, neurons_benign, _ = model(input_ids=benign_tokens, use_cache=False)

        harmful_tokens = tokenizer(harmful_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
        _, neurons_harmful, _ = model(input_ids=harmful_tokens, use_cache=False)

        # Slice each layer to batch 0 / last 5 positions and move it to CPU BEFORE stacking:
        # same values as stack-then-slice, but no full-sequence copy, and no cross-device stack
        # when device_map="auto" places layers on different devices.
        benign_act = torch.stack([layer_act[0, -5:, :].detach().cpu() for layer_act in neurons_benign])
        harmful_act = torch.stack([layer_act[0, -5:, :].detach().cpu() for layer_act in neurons_harmful])

        benign_list.append(benign_act)
        harmful_list.append(harmful_act)
        difference_list.append(benign_act - harmful_act)


In [6]:
import numpy as np

# Keep only the final prompt position of each collected activation snapshot, laid out as
# (samples, layers, neurons) float64 arrays for the sklearn-based analysis below.
N_SAMPLES = 300
n_collected = min(len(benign_list), len(harmful_list))
if n_collected < N_SAMPLES:
    raise ValueError(f"expected at least {N_SAMPLES} collected prompt pairs, found {n_collected}")
# Snapshots are indexed as [layer, position, neuron]; derive the layer and neuron counts from
# the data instead of hard-coding 32 x 11008.
n_layers, n_neurons = benign_list[0].shape[0], benign_list[0].shape[-1]
benign_np = np.zeros((N_SAMPLES, n_layers, n_neurons))
harmful_np = np.zeros((N_SAMPLES, n_layers, n_neurons))
for i in range(N_SAMPLES):
    benign_np[i] = benign_list[i][:, -1, :].to(torch.float32).numpy()
    harmful_np[i] = harmful_list[i][:, -1, :].to(torch.float32).numpy()


In [7]:
import pickle

# Cache the (benign, harmful) activation matrices so the slow collection step can be skipped later.
arrays_cache_path = 'arrays.pkl'
with open(arrays_cache_path, 'wb') as cache_file:
    pickle.dump((benign_np, harmful_np), cache_file)

In [8]:
import pickle

# Restore the cached (benign, harmful) activation matrices from 'arrays.pkl'.
# NOTE(review): pickle.load can execute arbitrary code; only load caches produced locally.
arrays_cache_path = 'arrays.pkl'
with open(arrays_cache_path, 'rb') as cache_file:
    loaded_pair = pickle.load(cache_file)
benign_np, harmful_np = loaded_pair

In [9]:
# The benign and harmful matrices are concatenated row-wise for the PCA, so their shapes must agree.
assert benign_np.shape == harmful_np.shape, (benign_np.shape, harmful_np.shape)
benign_np.shape

(300, 32, 11008)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

N_PCA_SAMPLES = 300  # rows taken from each of benign_np / harmful_np
PCA_SEED = 0  # the randomized SVD solver is stochastic; a fixed seed makes the figures reproducible


def plot_layer_pca(benign, harmful, neuron_idx, layer_number, n_samples=N_PCA_SAMPLES, seed=PCA_SEED):
    """Save a 2-D PCA scatter plot of benign vs. harmful activations for one layer.

    Args:
        benign: 2-D array (samples, neurons) of benign activations for this layer.
        harmful: 2-D array (samples, neurons) of harmful activations for this layer.
        neuron_idx: Indices of the neuron columns to keep (numpy array or torch tensor).
        layer_number: 1-based layer number, used in the title and the file name.
        n_samples: At most this many rows are taken from each class.
        seed: random_state passed to the randomized PCA solver.

    Returns:
        The path of the saved PNG, or None when fewer than two neurons are selected
        (a 2-component PCA needs at least two features).
    """
    # Convert torch index tensors (possibly on GPU) to a plain numpy index array.
    if hasattr(neuron_idx, "cpu"):
        neuron_idx = neuron_idx.cpu().numpy()
    cols = np.asarray(neuron_idx)
    if cols.size < 2:
        print(f"layer {layer_number}: fewer than 2 selected neurons, skipping PCA")
        return None

    benign_sel = benign[:n_samples][:, cols]
    harmful_sel = harmful[:n_samples][:, cols]
    standardized = StandardScaler().fit_transform(np.concatenate([benign_sel, harmful_sel], axis=0))
    reduced = PCA(n_components=2, svd_solver='randomized', random_state=seed).fit_transform(standardized)
    n_benign = len(benign_sel)  # rows [0, n_benign) are benign, the remaining rows harmful

    fig, ax = plt.subplots(figsize=(4, 3))
    # 'Class 1' = benign (blue), 'Class 2' = harmful (red).
    ax.scatter(reduced[:n_benign, 0], reduced[:n_benign, 1], label='Class 1', color='b', alpha=0.7)
    ax.scatter(reduced[n_benign:, 0], reduced[n_benign:, 1], label='Class 2', color='r', alpha=0.7)
    ax.set_title(f'Layer {layer_number}', fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    path = f'pca_layer_{layer_number}.png'
    fig.savefig(path)
    plt.close(fig)  # many figures are produced in a loop; free each one
    return path


for layer in range(1, min(len(index), benign_np.shape[1]) + 1):
    plot_layer_pca(benign_np[:, layer - 1, :], harmful_np[:, layer - 1, :], index[layer - 1], layer)
