In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# SECURITY: never hardcode access tokens in a notebook; anything shared or
# committed leaks them. The token previously pasted here should be revoked.
# Read the token from the HF_TOKEN environment variable, falling back to an
# interactive prompt that does not echo the input.
login(os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: "))


In [None]:
# Checkpoint under study; weights are loaded in float16 and placed by device_map="auto".
model_name = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token so later `padding=True` tokenizer calls have a pad id.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

In [None]:
# One column of questions the model refused and one it completed.
# The first two lines of the file are skipped (presumably a preamble above the
# header row; TODO confirm against the CSV).
df = pd.read_csv("/content/narcotics_trial2.5x.csv", skiprows=2)
refused, accepted = df["Refused Question"], df["Completed Question"]
accepted_list, refused_list = accepted.to_list(), refused.to_list()

In [None]:
def _find_word_token_indices(token_strs, word_tokens):
    """Return the positions in ``token_strs`` covered by occurrences of ``word_tokens``.

    An occurrence starts at position ``i`` when ``word_tokens[k]`` is a substring of
    ``token_strs[i + k]`` for every ``k``. Matches are non-overlapping, so no
    position is reported twice. An empty ``word_tokens`` never matches.
    """
    indices = []
    width = len(word_tokens)
    if width == 0:
        return indices

    i = 0
    while i + width <= len(token_strs):
        # The occurrence must start with the FIRST sub-token of the word, not any of them.
        if all(word_tok in token_strs[i + k] for k, word_tok in enumerate(word_tokens)):
            indices.extend(range(i, i + width))
            i += width  # skip past the match so overlapping matches can't duplicate indices
        else:
            i += 1
    return indices


def get_word_embeddings(word: str, sentences: list, tokenizer, model, layer):
    """Return the mean hidden state of ``word``'s sub-tokens in each sentence, at one layer.

    Each sentence is run through ``model`` once with ``output_hidden_states=True``.
    The positions whose tokens match ``tokenizer.tokenize(word)`` are located, and
    their ``hidden_states[layer]`` vectors are averaged. If the word occurs more than
    once, all occurrences are pooled into that one mean.

    Args:
        word: Target word. Include any leading space the tokenizer expects for a
            mid-sentence occurrence (this notebook passes ``" narcotics"``).
        sentences: Sentences to embed.
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: Causal LM; it is moved to the module-level ``device`` first.
        layer: Index into ``outputs.hidden_states``.

    Returns:
        A stacked tensor with one row per sentence that contains ``word``.
        If no sentence matches, an empty 1-D tensor on ``device`` is returned.

    Note:
        Sentences with no match are printed (with their tokenisation) and skipped.
        Once anything is skipped, row ``i`` of the result no longer corresponds
        to ``sentences[i]``.
    """
    embeddings = []
    word_tokens = tokenizer.tokenize(word)
    model.to(device)

    for sentence in sentences:
        tokens = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(device)
        input_ids = tokens["input_ids"].squeeze(0)

        with torch.no_grad():
            outputs = model(**tokens, output_hidden_states=True)

        token_embeddings = outputs.hidden_states[layer].squeeze(0)
        # convert_ids_to_tokens accepts a list, so convert the whole sequence in one call.
        token_strs = tokenizer.convert_ids_to_tokens(input_ids.tolist())

        indices = _find_word_token_indices(token_strs, word_tokens)
        if indices:
            embeddings.append(token_embeddings[indices].mean(dim=0))
        else:
            print(sentence)
            print(tokenizer.tokenize(sentence))

    return torch.stack(embeddings) if embeddings else torch.empty(0, device=device)

In [None]:
def plot_pca_embeddings(accepted_embeddings, refusal_embeddings):
    """Project both embedding sets onto one shared 2-D PCA basis and scatter them.

    PCA is fitted on the accepted and refusal rows stacked together, so both sets
    share the same axes. Each point is annotated ``A-i`` or ``R-i``, where ``i``
    is its row index within its own input tensor. Shapes are printed at each stage.
    """
    n_accepted = len(accepted_embeddings)

    print("Accepted embeddings shape:", accepted_embeddings.shape)
    print("Refusal embeddings shape:", refusal_embeddings.shape)

    stacked = torch.vstack((accepted_embeddings, refusal_embeddings))
    print("All embeddings shape:", stacked.shape)

    projected = PCA(n_components=2).fit_transform(stacked.cpu().numpy())
    print("PCA shape:", projected.shape)

    accepted_2d, refusal_2d = projected[:n_accepted], projected[n_accepted:]
    print("Accepted PCA shape:", accepted_2d.shape)
    print("Refusal PCA shape:", refusal_2d.shape)

    _, ax = plt.subplots(figsize=(8, 6))

    # Draw each group, then annotate each of its points with "<prefix>-<row index>".
    groups = (
        (accepted_2d, "blue", "Accepted", "A"),
        (refusal_2d, "red", "Refusal", "R"),
    )
    for points, colour, legend_label, prefix in groups:
        ax.scatter(points[:, 0], points[:, 1], color=colour, alpha=0.7, label=legend_label)
        for row, (x, y) in enumerate(points):
            ax.text(x, y, f"{prefix}-{row}", fontsize=8, color=colour)

    ax.set_title("PCA of Accepted vs. Refusal Embeddings")
    ax.set_xlabel("PCA Component 1")
    ax.set_ylabel("PCA Component 2")
    ax.legend()
    plt.show()

In [None]:
# Embed " narcotics" at layer 15 in both question sets and report the resulting shapes.
accepted_embeddings = get_word_embeddings(
    " narcotics", sentences=accepted_list, tokenizer=tokenizer, model=model, layer=15
)
print(accepted_embeddings.shape)

refusal_embeddings = get_word_embeddings(
    " narcotics", sentences=refused_list, tokenizer=tokenizer, model=model, layer=15
)
print(refusal_embeddings.shape)

In [None]:
# Shared-basis PCA scatter of the embeddings computed in the previous cell.
plot_pca_embeddings(accepted_embeddings=accepted_embeddings, refusal_embeddings=refusal_embeddings)

In [None]:
# Inspect the accepted questions behind points of interest in the PCA plot.
# NOTE(review): the plot's "A-i" labels are embedding row indices. They equal
# accepted_list[i] only if get_word_embeddings skipped no sentences.
# The original cell printed index 13 twice; the duplicate was dropped.
# TODO confirm whether a different index was intended.
for idx in (13, 51):
    print(accepted_list[idx])

In [None]:
# Refused questions presumably behind the R-31 and R-19 points in the PCA plot.
for idx in (31, 19):
    print(refused_list[idx])

In [None]:
layers = list(range(13, 19))  # layers 13 through 18 inclusive

# For each layer, embed " narcotics" in both question sets (accepted first), then plot.
for layer in layers:
    accepted_embeddings, refusal_embeddings = (
        get_word_embeddings(" narcotics", sentences, tokenizer, model, layer=layer)
        for sentences in (accepted_list, refused_list)
    )

    print("LAYER NUMBER:", layer)
    plot_pca_embeddings(accepted_embeddings, refusal_embeddings)

## Embeddings at a fixed token position

In [None]:
def get_pos_embeddings(sentences: list[str], tokenizer, model, layer, pos):
    """Return the hidden state at token position ``pos`` of ``hidden_states[layer]`` for each sentence.

    Each sentence is encoded on its own, so the batch dimension is 1 and is
    squeezed away. ``pos`` indexes the sequence axis, so negative values count
    back from the last token. ``model`` is moved to the module-level ``device``
    first.

    Returns:
        A stacked tensor with one row per sentence, or an empty 1-D tensor on
        ``device`` when ``sentences`` is empty.
    """
    model.to(device)

    def _embed(sentence):
        # Encode one sentence, run the model, and slice out the requested layer/position.
        encoded = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            hidden_states = model(**encoded, output_hidden_states=True).hidden_states
        return hidden_states[layer][:, pos, :].squeeze(0)

    embeddings = [_embed(sentence) for sentence in sentences]
    return torch.stack(embeddings) if embeddings else torch.empty(0, device=device)

In [None]:
# Per the transformers docs, `hidden_states` holds the embedding output at
# index 0 followed by one entry per decoder layer. Index 1 is therefore the
# first decoder layer and index len(model.model.layers) is the last. The old
# upper bound `len(model.model.layers)` silently skipped that last layer.
layers = list(range(1, len(model.model.layers) + 1))
pos = -3  # sequence position, counted back from the last token

for layer in layers:
    accepted_embeddings = get_pos_embeddings(accepted_list, tokenizer, model, layer=layer, pos=pos)
    refusal_embeddings = get_pos_embeddings(refused_list, tokenizer, model, layer=layer, pos=pos)

    print("LAYER NUMBER:", layer)
    plot_pca_embeddings(accepted_embeddings, refusal_embeddings)