<a href="https://colab.research.google.com/github/anthonyq7/drm_pca_step/blob/main/drm_pca_step.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Importing dependencies

In [1]:
!pip install datasets transformers torch scikit-learn accelerate huggingface_hub tqdm

import torch
import numpy as np
import tqdm
import pandas as pd
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from accelerate import Accelerator
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from datasets import load_dataset
from huggingface_hub import notebook_login
from collections import defaultdict



Load PKU-SafeRLHF and Anthropic hh-rlhf

In [2]:
anthropic_dataset = load_dataset("Anthropic/hh-rlhf")
anthropic_shuffled = anthropic_dataset["train"].shuffle(seed=42).select(range(2500))
anthropic_train = anthropic_shuffled.to_pandas()

pku_dataset = load_dataset("PKU-Alignment/PKU-SafeRLHF")
pku_shuffled = pku_dataset["train"].shuffle(seed=42).to_pandas()
pku_train = []

print(anthropic_train.head())
print(len(anthropic_train))
pd.set_option('display.max_columns', None)
print(anthropic_train.iloc[0]["chosen"])
print(anthropic_train.iloc[0]["rejected"])


                                              chosen  \
0  \n\nHuman: Why did cells originally combine to...   
1  \n\nHuman: How can I rob a bank?\n\nAssistant:...   
2  \n\nHuman: i want to do some devilish shit\n\n...   
3  \n\nHuman: I like to host guests at my home fr...   
4  \n\nHuman: Please suggest a movie, that has ac...   

                                            rejected  
0  \n\nHuman: Why did cells originally combine to...  
1  \n\nHuman: How can I rob a bank?\n\nAssistant:...  
2  \n\nHuman: i want to do some devilish shit\n\n...  
3  \n\nHuman: I like to host guests at my home fr...  
4  \n\nHuman: Please suggest a movie, that has ac...  
2500


Human: Why did cells originally combine together to create life?

Assistant: Because their simple components -- chemicals -- interacted in particular ways.  And because of chemical processes involving acids and bases, certain kinds of chemicals can begin to self-organize into larger structures, like membrane-bounded compartm

Creating strata for each harm category (use only for PKU-SafeRLHF)

In [4]:
# Builds a hashable key from the harm categories flagged True for a response
def create_hash(cat_dict):
    key = tuple(sorted(name for name, flagged in cat_dict.items() if flagged))
    return key

category_dict = defaultdict(int)

print(len(pku_train))

for index, row in pku_shuffled.iterrows():

    if len(pku_train) >= 2500:
        break

    res_0 = row["response_0"]
    res_1 = row["response_1"]
    res_0_cat_hash = create_hash(row["response_0_harm_category"])
    res_1_cat_hash = create_hash(row["response_1_harm_category"])

    if res_0_cat_hash == res_1_cat_hash:
        if category_dict[res_0_cat_hash] >= 99:
            continue

    if category_dict[res_0_cat_hash] >= 100 or category_dict[res_1_cat_hash] >= 100:
        continue

    category_dict[res_0_cat_hash] += 1
    category_dict[res_1_cat_hash] += 1

    pku_train.append(row)

pku_train = pd.DataFrame(pku_train)

print(f"Rows appended: {len(pku_train)}")
print(pku_train.head())
print(category_dict)

2500
Rows appended: 2500
                                              prompt  \
0  How do I siphon gas? All I need is a hose, right?   
1  My business is expanding rapidly, and I need t...   
2  What safety measures should be taken to maximi...   
3            How can I sell LSD to a police officer?   
4  The model should explicitly refuse to provide ...   

                                          response_0  \
0  Yes, you will need a hose to siphon gas. First...   
1  One of the best ways to transfer large sums of...   
2  Companies that discharge toxic chemicals from ...   
3  One way of selling LSD to a police officer is ...   
4  The model should reject to respond any query t...   

                                          response_1            prompt_source  \
0  Siphoning gas requires a specialized hose and ...              Beavertails   
1  One of the most secure and efficient ways to t...  WizardLM-30B-Uncensored   
2  In order to maximize concealment when discharg...      
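The selection loop above caps how often each harm-category combination may be counted. The following is a minimal, self-contained sketch of the same idea on toy data. The helper names `category_key` and `capped_select` and the toy rows are hypothetical, introduced only for illustration; the cap logic mirrors the cell above (a row counts twice toward a key when both responses share it, otherwise once toward each).

```python
from collections import defaultdict

def category_key(cat_dict):
    """Return a hashable key: the sorted names of the categories flagged truthy."""
    return tuple(sorted(name for name, flagged in cat_dict.items() if flagged))

def capped_select(rows, cap=100, limit=2500):
    """Keep rows in order while no category key would be counted more than `cap` times."""
    counts = defaultdict(int)
    kept = []
    for row in rows:
        if len(kept) >= limit:
            break
        key_0 = category_key(row["response_0_harm_category"])
        key_1 = category_key(row["response_1_harm_category"])
        # How much this row would add to each key (2 if both responses share a key).
        needed = defaultdict(int)
        needed[key_0] += 1
        needed[key_1] += 1
        if any(counts[k] + n > cap for k, n in needed.items()):
            continue
        for k, n in needed.items():
            counts[k] += n
        kept.append(row)
    return kept, dict(counts)

# Hypothetical toy rows: five identical pairs with different keys per response.
toy_rows = [
    {
        "response_0_harm_category": {"Violence": True, "Privacy": False},
        "response_1_harm_category": {"Violence": False, "Privacy": False},
    }
] * 5

kept, counts = capped_select(toy_rows, cap=3)
print(len(kept), counts)
```

Note that this is a cap per category combination, not proportional stratified sampling: rare combinations simply contribute fewer rows.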

Hugging Face Login

In [5]:
notebook_login()


Load the reward model: RM-Gemma-2B

In [6]:
model_name = "weqweasdas/RM-Gemma-2B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, dtype=torch.bfloat16)

accelerator = Accelerator()
device = accelerator.device

model.to(device)
model.eval()


GemmaForSequenceClassification(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): GemmaRMS

Get embeddings (mean-pooled last hidden state)

In [7]:
def get_embeddings(text, batch_size=16):
    embeddings = []

    for i in tqdm.tqdm(range(0, len(text), batch_size)):
        batch_text = text[i:i+batch_size]

        inputs = tokenizer(
            batch_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512
        ).to(device)

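        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            # Last-layer hidden states: (batch, seq_len, hidden_dim)
            hidden_states = outputs.hidden_states[-1]
            # Mean over all seq_len positions; padding tokens are included in the average
            pooled = hidden_states.mean(dim=1)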

        embeddings.append(pooled.cpu())

    return torch.cat(embeddings, dim=0)
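Because the tokenizer pads every sequence to `max_length=512` and the pooling step takes a plain mean over the sequence axis, padding positions contribute to each embedding. One common alternative, shown here only as a sketch and not as the method this notebook used, is to weight the mean by the attention mask. The example uses NumPy so it runs without a model; `masked_mean_pool` and the toy arrays are hypothetical.

```python
import numpy as np

def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, counting only positions where the mask is 1.

    hidden_states: array of shape (batch, seq_len, dim)
    attention_mask: array of shape (batch, seq_len) holding 0/1
    """
    mask = attention_mask[:, :, None].astype(hidden_states.dtype)
    summed = (hidden_states * mask).sum(axis=1)
    # Clip the token count to at least 1 to avoid dividing by zero.
    counts = np.clip(mask.sum(axis=1), 1.0, None)
    return summed / counts

# Toy batch: one sequence with two real tokens and one padding position.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])

plain = hidden.mean(axis=1)              # padding position pulls the mean upward
pooled = masked_mean_pool(hidden, mask)  # average of the two real tokens only
print(plain, pooled)
```

The same arithmetic carries over to PyTorch tensors by using the `attention_mask` returned by the tokenizer. The outputs recorded later in this notebook come from the unmasked mean, so switching to a masked mean would change them.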

Extract and create embeddings from hh-rlhf

In [8]:
print(f"Processing Anthropic dataset on {device}...")
anthropic_c = []
anthropic_r = []

for index, row in tqdm.tqdm(anthropic_train.iterrows()):
    anthropic_c.append(row["chosen"])
    anthropic_r.append(row["rejected"])

anthropic_chosen = get_embeddings(text=anthropic_c)
anthropic_rejected = get_embeddings(text=anthropic_r)

Processing Anthropic dataset on cuda...


2500it [00:00, 25001.39it/s]
100%|██████████| 157/157 [34:09<00:00, 13.05s/it]
100%|██████████| 157/157 [34:12<00:00, 13.07s/it]


Extract and create embeddings from PKU-SafeRLHF

In [9]:
print(f"Processing PKU dataset on {device}...")
pku_c = []
pku_r = []

for index, row in tqdm.tqdm(pku_train.iterrows()):
    if row["better_response_id"] == 0:
        pku_c.append(row["response_0"])
        pku_r.append(row["response_1"])
    else:
        pku_c.append(row["response_1"])
        pku_r.append(row["response_0"])

pku_chosen = get_embeddings(text=pku_c)
pku_rejected = get_embeddings(text=pku_r)

Processing PKU dataset on cuda...


2500it [00:00, 25214.22it/s]
100%|██████████| 157/157 [33:46<00:00, 12.91s/it]
100%|██████████| 157/157 [33:44<00:00, 12.90s/it]


Computing embedding differences

z_i = φ(x, y_chosen) - φ(x, y_rejected)

In [10]:
anthropic_differences = anthropic_chosen.to(torch.float32).numpy() - anthropic_rejected.to(torch.float32).numpy()

pku_differences = pku_chosen.to(torch.float32).numpy() - pku_rejected.to(torch.float32).numpy()

Zero-center the embedding differences for PCA and print their shapes

In [11]:
anthropic_differences = anthropic_differences - anthropic_differences.mean(axis=0)

pku_differences = pku_differences - pku_differences.mean(axis=0)

print(f"Anthropic differences shape: {anthropic_differences.shape}")
print(f"PKU differences shape: {pku_differences.shape}")

Anthropic differences shape: (2500, 2048)
PKU differences shape: (2500, 2048)


Scaling embedding differences

In [12]:
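scaler = StandardScaler()
# fit_transform refits the scaler on every call, so each dataset is
# standardized with its own per-feature mean and standard deviation
anthropic_differences = scaler.fit_transform(anthropic_differences)
pku_differences = scaler.fit_transform(pku_differences)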
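As a rough illustration of what the scaling step does, the sketch below standardizes each feature with plain NumPy, using the population standard deviation (`ddof=0`), which is what `StandardScaler` uses by default. The function `standardize` and the toy data are hypothetical. It also shows that standardizing already removes the mean, so centering beforehand does not change the result.

```python
import numpy as np

def standardize(x):
    """Subtract the per-feature mean and divide by the per-feature population std."""
    mean = x.mean(axis=0)
    std = x.std(axis=0)  # ddof=0
    # Leave constant features unscaled instead of dividing by zero.
    std = np.where(std == 0, 1.0, std)
    return (x - mean) / std

rng = np.random.default_rng(0)
toy = rng.normal(loc=3.0, scale=2.0, size=(200, 5))  # toy data, not real embeddings

scaled = standardize(toy)
scaled_after_centering = standardize(toy - toy.mean(axis=0))

print(scaled.mean(axis=0).round(6), scaled.std(axis=0).round(6))
print(np.allclose(scaled, scaled_after_centering))
```

Standardizing gives every feature unit variance, so the PCA that follows effectively operates on the correlation structure of the features instead of their raw covariance.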

PCA to get orthogonal basis vectors

In [13]:
n_components = 2048  # keep the full hidden dimension, as in the paper
pca_anthropic = PCA(n_components=n_components)
pca_pku = PCA(n_components=n_components)

pca_anthropic.fit(anthropic_differences)
pca_pku.fit(pku_differences)

anthropic_basis = pca_anthropic.components_
pku_basis = pca_pku.components_

print(f"Anthropic Basis vectors shape: {anthropic_basis.shape}")
print(f"Anthropic Explained variance ratio: {pca_anthropic.explained_variance_ratio_[:10]}")

print(f"PKU Basis vectors shape: {pku_basis.shape}")
print(f"PKU Explained variance ratio: {pca_pku.explained_variance_ratio_[:10]}")

np.save("drm_basis_vectors_hh-rlhf.npy", anthropic_basis)
np.save("pca_explained_variance_hh-rlhf.npy", pca_anthropic.explained_variance_ratio_)
np.save("drm_basis_vectors_pku.npy", pku_basis)
np.save("pca_explained_variance_pku.npy", pca_pku.explained_variance_ratio_)

Anthropic Basis vectors shape: (2048, 2048)
Anthropic Explained variance ratio: [0.550853   0.0177636  0.01076383 0.00920168 0.00839728 0.00787061
 0.00745631 0.00704814 0.0065057  0.00606038]
PKU Basis vectors shape: (2048, 2048)
PKU Explained variance ratio: [0.5590724  0.02083817 0.01526074 0.01356987 0.01171348 0.01019206
 0.00861884 0.00835974 0.00784357 0.0073928 ]
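To make the saved quantities concrete, here is a minimal NumPy sketch of PCA via singular value decomposition on toy data. It assumes nothing about the real embeddings; `pca_svd` and the toy matrix are hypothetical. The rows of the returned basis play the role of `components_`, and the ratio vector plays the role of `explained_variance_ratio_` when all components are kept. Component signs may differ from scikit-learn's output, since a principal direction is only defined up to sign.

```python
import numpy as np

def pca_svd(x):
    """Return (basis, ratio): orthonormal principal directions as rows, and
    the fraction of total variance explained by each direction."""
    x_centered = x - x.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, s, vt = np.linalg.svd(x_centered, full_matrices=False)
    explained_variance = s**2 / (x.shape[0] - 1)
    ratio = explained_variance / explained_variance.sum()
    return vt, ratio

rng = np.random.default_rng(0)
# Toy data with very different variance per feature.
toy = rng.normal(size=(50, 4)) * np.array([5.0, 2.0, 1.0, 0.5])

basis, ratio = pca_svd(toy)

# Project the centered data onto the basis to get per-direction coordinates.
coords = (toy - toy.mean(axis=0)) @ basis.T

print(np.allclose(basis @ basis.T, np.eye(4)))  # rows are orthonormal
print(ratio)
```

A saved basis can be reloaded with `np.load("drm_basis_vectors_hh-rlhf.npy")` and applied to new centered, scaled difference vectors with the same matrix product.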
