In [1]:
%load_ext autoreload
%autoreload 2
from transformer_lens import HookedTransformer
from orthogonalized_model import OrthogonalizedTransformer, generate_weight_order
import torch
from concept_erasure import LeaceEraser
from tasks.facts.SportsTask import SportsTask
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
MODEL_NAME = "google/gemma-7b"
device = "cuda"

# Tokenizer first. Left padding places each prompt's final real token at
# position -1, which is the position the scrubbing cell reads activations from.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "left"

# Load the weights in bfloat16, then move the whole model onto the target device.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model = model.to(device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# Sports-fact task; its train_loader supplies the prompts and sport labels used to fit the erasers.
TRAIN_BATCH_SIZE = 128
task = SportsTask(batch_size=TRAIN_BATCH_SIZE, tokenizer=tokenizer, device=device)

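## Concept Scrubbing

For each module returned by `generate_weight_order`, collect the activations at the last prompt position on the training prompts. Then fit a LEACE eraser (`LeaceEraser.fit`) against one-hot sport labels, and pass its projection matrix `P` to `orthogonalize_weight`. The goal is to remove linearly decodable sport information from the model's weights.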

In [4]:
# One-hot target per sport; LEACE is fitted to remove linear information about
# these labels from the activations.
sport_label_id_mapping = {
    "football": [1, 0, 0],
    "baseball": [0, 1, 0],
    "basketball": [0, 0, 1],
}

# Re-assert left padding so position -1 is each prompt's final real token, even
# if the evaluation cell (which switches to right padding) ran before this cell
# is re-executed. Without this, re-running would scrub on pad positions.
tokenizer.padding_side = "left"

ortho_model = OrthogonalizedTransformer(model)

# Tokenize the training prompts and build their label tensors once. Neither
# depends on the module being scrubbed, so there is no reason to redo this work
# for every module in the loop below.
train_batches = []
for raw_batch in task.train_loader:
    encoded = tokenizer(
        raw_batch["prompt"],
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    )
    batch_labels = torch.tensor(
        [sport_label_id_mapping[sport] for sport in raw_batch["sport"]],
        dtype=torch.float32,
        device=device,
    )
    train_batches.append((encoded, batch_labels))

if not train_batches:
    raise ValueError("task.train_loader yielded no batches; nothing to fit the erasers on")

# (num_examples, 3) one-hot targets, in the same batch order as the activations below.
labels = torch.cat([batch_labels for _, batch_labels in train_batches])

# Each eraser is fitted on activations taken from ortho_model after all earlier
# modules in generate_weight_order have been passed to orthogonalize_weight
# (presumably so later erasers see the already-scrubbed model — confirm in
# OrthogonalizedTransformer).
for module_name in tqdm(generate_weight_order(model.model.config)):
    activations = []
    for encoded, _ in train_batches:
        with torch.no_grad():
            cache = ortho_model.get_activations(
                module_to_hook=module_name,
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
            )
        # NOTE(review): assumes `cache` is a (batch, seq, ...) tensor.
        # `.clone()` matters: `cache[:, -1]` alone is a view that keeps the
        # whole per-batch tensor (all sequence positions) alive until torch.cat.
        activations.append(cache[:, -1].clone())

    activations = torch.cat(activations).to(dtype=torch.float32)

    eraser = LeaceEraser.fit(activations, labels)
    ortho_model.orthogonalize_weight(module_name, eraser.P.to(dtype=torch.bfloat16))

  0%|          | 0/57 [00:00<?, ?it/s]

100%|██████████| 57/57 [02:56<00:00,  3.10s/it]


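### Evaluating the scrubbed model

Test accuracy on the sports task after scrubbing. With three sports, uniform guessing would score about 0.33.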

In [5]:
# Evaluation switches the shared tokenizer to right padding and uses a smaller batch size.
# NOTE(review): this evaluates `model` rather than `ortho_model`, which assumes
# OrthogonalizedTransformer edited the wrapped model's weights in place — confirm.
tokenizer.padding_side = "right"
task = SportsTask(batch_size=64, tokenizer=tokenizer, device=device)

# Reference point: with three sports, uniform guessing scores ~1/3
# (assuming get_test_accuracy reports 3-way accuracy).
scrubbed_accuracy = task.get_test_accuracy(model)
scrubbed_accuracy

0.33540767431259155

In [6]:
# Persist the scrubbed weights with save_pretrained.
# NOTE(review): this directory name matches the imported `orthogonalized_model`
# module; consider a distinct name to avoid confusion on sys.path.
SCRUBBED_MODEL_DIR = "orthogonalized_model"
model.save_pretrained(SCRUBBED_MODEL_DIR)