In [1]:
from datasets import load_dataset

# Load the "en_gum" configuration of the Hugging Face `universal_dependencies`
# dataset. Later cells read the 'train' and 'test' splits and their 'tokens'
# and 'upos' columns from this object.
dataset = load_dataset("universal_dependencies", "en_gum")

# DO NOT SORT: a tag's position in this list is its integer class id, and other
# cells index into the list by position (e.g. position 13 is "_").
upos_tags = (
    "NOUN PUNCT ADP NUM SYM SCONJ "
    "ADJ PART DET CCONJ PROPN PRON "
    "X _ ADV INTJ VERB AUX"
).split()

# Reverse lookup: tag string -> position in `upos_tags`.
tag_to_id = dict(zip(upos_tags, range(len(upos_tags))))

Found cached dataset universal_dependencies (/mnt/ssd-2/hf_cache/universal_dependencies/en_gum/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7)


  0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
from concept_erasure import ConceptEraser
from contextlib import contextmanager
from torch import nn, Tensor
from transformers import BertModel, RobertaModel, LlamaModel
from typing import Sequence
import torch


class ConceptScrubber:
    label: Tensor | None
    mask: Tensor | None

    def __init__(self, model, y_dim: int = 1, rank: int | None = None):
        d_model = model.config.hidden_size
        base = model.base_model

        if isinstance(base, LlamaModel):
            layers = base.layers
        elif isinstance(base, (BertModel, RobertaModel)):
            layers = base.encoder.layer
        else:
            raise ValueError(f"Unknown model type {type(base)}")

        self.x_dim = d_model
        self.y_dim = y_dim

        self.label = None
        self.mask = None
        self.erasers = {
            layer: ConceptEraser(
                d_model, y_dim, device=model.device, dtype=model.dtype, rank=rank
            )
            for layer in layers
        }

    @contextmanager
    def record(self):
        # Called after every layer forward pass
        def record_hook(layer, _, output):
            assert self.label is not None
            x, *extras = output

            self.erasers[layer].update(
                x[self.mask] if self.mask is not None else x, self.label
            )
            return (x, *extras)

        handles = {
            layer: layer.register_forward_hook(record_hook)
            for layer in self.erasers
        }

        try:
            yield self
        finally:
            # Make sure to remove the hooks even if an exception is raised
            for handle in handles.values():
                handle.remove()

    @contextmanager
    def scrub(self, layer_indices: Sequence[int] = ()):
        # Called after every layer forward pass
        def apply_hook(layer, _, output):
            x, *extras = output

            if self.mask is not None:
                x[self.mask] = self.erasers[layer](x[self.mask])
            else:
                x = self.erasers[layer](x)

            return (x, *extras)
        
        if layer_indices:
            layer_list = list(self.erasers.keys())
            layers = [layer_list[i] for i in layer_indices]
        else:
            layers = self.erasers.keys()

        handles = {
            layer: layer.register_forward_hook(apply_hook)
            for layer in layers
        }

        try:
            yield self
        finally:
            # Make sure to remove the hooks even if an exception is raised
            for handle in handles.values():
                handle.remove()
    
    @contextmanager
    def random_scrub(self, layer_indices: Sequence[int] = ()):
        eraser = next(iter(self.erasers.values()))
        u = nn.init.orthogonal_(torch.empty_like(eraser.u))

        # Called after every layer forward pass
        def apply_hook(layer, _, output):
            x, *extras = output
            mean = self.erasers[layer].mean_x

            if self.mask is not None:
                delta = (x[self.mask] - mean) @ u @ u.mT
                x[self.mask] -= delta
            else:
                delta = (x - mean) @ u @ u.mT
                x -= x @ u @ u.T

            return (x, *extras)
        
        if layer_indices:
            layer_list = list(self.erasers.keys())
            layers = [layer_list[i] for i in layer_indices]
        else:
            layers = self.erasers.keys()

        handles = {
            layer: layer.register_forward_hook(apply_hook)
            for layer in layers
        }

        try:
            yield self
        finally:
            # Make sure to remove the hooks even if an exception is raised
            for handle in handles.values():
                handle.remove()

In [3]:
# Pull the 'tokens' column out of the train and test splits once, so later
# cells can iterate over plain Python sequences (one entry per sentence)
# instead of re-indexing the dataset object.
train_tokens = dataset['train']['tokens']
test_tokens = dataset['test']['tokens']

In [14]:
from tqdm import tqdm
import torch
import torch.nn.functional as F


def tokenize(tokenizer, tokens, raw_labels, special_label: int = 13):
    """Sub-tokenize a pre-split sentence and spread each word's label over its pieces.

    Args:
        tokenizer: Object providing `encode(text, add_special_tokens=...)`,
            `cls_token_id` and `sep_token_id`.
        tokens: The words of one sentence.
        raw_labels: One label per word, aligned with `tokens`. If the two
            differ in length, the extra items of the longer one are ignored.
        special_label: Label assigned to the CLS and SEP positions. The default
            of 13 is the position of "_" in this notebook's `upos_tags` list.

    Returns:
        `(token_ids, labels)`: two lists of equal length. `token_ids` is
        `[CLS] + sub-token ids + [SEP]`; `labels[i]` is the label of the word
        that produced `token_ids[i]`, or `special_label` for CLS/SEP.
    """
    token_ids = [tokenizer.cls_token_id]
    labels = [special_label]

    for original, label in zip(tokens, raw_labels):
        ids = tokenizer.encode(original, add_special_tokens=False)
        token_ids.extend(ids)
        # Every sub-token inherits the label of the word it came from.
        labels.extend([label] * len(ids))

    token_ids.append(tokenizer.sep_token_id)
    labels.append(special_label)
    return token_ids, labels


@torch.no_grad()
def encode(model, tokenizer, sentences: list[list[str]], labels: list[list[str]]):
    """Run `model` over every sentence while recording layer activations.

    A fresh `ConceptScrubber` is built for `model` with one label dimension per
    entry of `upos_tags`. Each sentence is sub-tokenized with `tokenize`, the
    resulting per-token labels are one-hot encoded and stored on the scrubber,
    and the model is called on the single-sentence batch with the input ids
    also passed as `labels`.

    Args:
        model: Model accepted by `ConceptScrubber`; called as
            `model(input_ids, labels=input_ids)` and expected to return an
            object with a `.loss` attribute.
        tokenizer: Tokenizer forwarded to `tokenize`.
        sentences: One list of words per sentence.
        labels: One list of per-word labels per sentence, aligned with
            `sentences`. NOTE(review): annotated as strings, but the values are
            fed to `torch.tensor` / `F.one_hot`, so integer class ids appear to
            be expected - confirm.

    Returns:
        `(scrubber, mean_loss)`: the scrubber after recording, and the mean of
        the per-sentence losses.
    """
    num_classes = len(upos_tags)
    scrubber = ConceptScrubber(model, y_dim=num_classes)
    per_sentence_losses = []

    with scrubber.record():
        pairs = zip(sentences, labels)
        for words, word_tags in tqdm(pairs, total=len(sentences)):
            input_ids, token_tags = tokenize(tokenizer, words, word_tags)
            batch = torch.tensor([input_ids]).to(model.device)

            # The record hooks read `scrubber.label` during the forward pass,
            # so it has to be set before the model is called.
            tag_tensor = torch.tensor(token_tags).to(model.device)
            scrubber.label = F.one_hot(tag_tensor, num_classes)

            output = model(batch, labels=batch)
            per_sentence_losses.append(output.loss)

    return scrubber, torch.stack(per_sentence_losses).mean()

In [5]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Load the pretrained "bert-base-uncased" checkpoint with a masked-LM head and
# move it to the GPU, then load the tokenizer from the same checkpoint.
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").cuda()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
# Record activations on the training split: `encode` returns the scrubber it
# built plus the mean per-sentence loss of the model with no intervention.
scrubber, clean_loss = encode(
    model, tokenizer, train_tokens, dataset['train']['upos']
)

100%|██████████| 4287/4287 [00:56<00:00, 75.80it/s]


In [23]:
# Re-run the model over the training split with the scrubbing hooks installed
# on every layer (no `layer_indices` given), collecting one loss per sentence.
with scrubber.scrub() as scrubber, torch.no_grad():
    it = zip(train_tokens, dataset['train']['upos'])
    # NOTE: despite the name, `preds` holds per-sentence loss tensors.
    preds = []

    for sentence, label_seq in tqdm(it, total=len(train_tokens)):
        ids, labels = tokenize(tokenizer, sentence, label_seq)
        x = torch.tensor([ids]).to(model.device)

        # NOTE(review): the scrub hooks do not read `scrubber.label`; this
        # assignment only mirrors the recording loop - confirm it is needed.
        scrubber.label = F.one_hot(
            torch.tensor(labels).to(model.device),
            len(upos_tags),
        )
        # The input ids double as the `labels` argument of the model call.
        preds.append(model(x, labels=x).loss)

100%|██████████| 4287/4287 [00:42<00:00, 101.77it/s]


In [24]:
# Control run: same loop as the scrubbed evaluation, but with the
# `random_scrub` hooks installed on every layer instead.
with scrubber.random_scrub() as scrubber, torch.no_grad():
    it = zip(train_tokens, dataset['train']['upos'])
    # NOTE: despite the name, `random_preds` holds per-sentence loss tensors.
    random_preds = []

    for sentence, label_seq in tqdm(it, total=len(train_tokens)):
        ids, labels = tokenize(tokenizer, sentence, label_seq)
        x = torch.tensor([ids]).to(model.device)

        # NOTE(review): the random-scrub hooks do not read `scrubber.label`;
        # this assignment only mirrors the recording loop - confirm it is needed.
        scrubber.label = F.one_hot(
            torch.tensor(labels).to(model.device),
            len(upos_tags),
        )
        # The input ids double as the `labels` argument of the model call.
        random_preds.append(model(x, labels=x).loss)

100%|██████████| 4287/4287 [00:40<00:00, 106.06it/s]


In [27]:
# Baseline: mean per-sentence training loss with no intervention.
clean_loss.item()

3.000048875808716

In [25]:
# Control: mean per-sentence training loss under `random_scrub`.
torch.stack(random_preds).mean()

tensor(2.8951, device='cuda:0')

In [26]:
# Treatment: mean per-sentence training loss under `scrub`.
torch.stack(preds).mean()

tensor(5.2724, device='cuda:0')