In [1]:
from datasets import load_dataset

dataset = load_dataset("universal_dependencies", "en_gum")

# DO NOT SORT: the order actually matters since we index into the list
upos_tags = [
    "NOUN",
    "PUNCT",
    "ADP",
    "NUM",
    "SYM",
    "SCONJ",
    "ADJ",
    "PART",
    "DET",
    "CCONJ",
    "PROPN",
    "PRON",
    "X",
    "_",
    "ADV",
    "INTJ",
    "VERB",
    "AUX",
]

Found cached dataset universal_dependencies (/mnt/ssd-2/hf_cache/universal_dependencies/en_gum/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
train_tokens = dataset['train']['tokens']
test_tokens = dataset['test']['tokens']

In [3]:
from concept_erasure import ConceptScrubber

from tqdm import tqdm
import torch
import torch.nn.functional as F


def tokenize(tokenizer, tokens, raw_labels):
    # receive a list of tokens and return a list of token ids as well as a mapping from token ids to tokens
    token_ids = []
    labels = []

    for original, label in zip(tokens, raw_labels):
        ids = tokenizer.encode(original, add_special_tokens=False)
        labels.extend([label] * len(ids))
        token_ids.extend(ids)

    return token_ids, labels


@torch.no_grad()
def encode(model, tokenizer, sentences: list[list[str]], labels: list[list[str]]):
    losses = []
    scrubber = ConceptScrubber(
        model, y_dim=len(upos_tags), clip_variances=False
    )
    label_list = []
    
    for sentence, label_seq in tqdm(zip(sentences, labels), total=len(sentences)):
        ids, labels = tokenize(tokenizer, sentence, label_seq)
        x = torch.tensor([ids]).to(model.device)

        label = F.one_hot(
            torch.tensor(labels).to(model.device),
            len(upos_tags),
        )
        with scrubber.record(model, label):
            losses.append(model(x, labels=x).loss)
    
    print(f"{torch.stack(losses).isfinite().float().mean()} of losses are finite")
    return scrubber, torch.stack(losses).nanmean()

In [4]:
@torch.inference_mode()
def sanity_check(
    scrubber, model, tokenizer, sentences: list[list[str]], labels: list[list[str]]
):
    losses = []
    model.float()

    hidden_lists = [[] for _ in range(model.config.num_hidden_layers)]
    label_list = []
    
    for sentence, label_seq in tqdm(zip(sentences, labels), total=len(sentences)):
        ids, labels = tokenize(tokenizer, sentence, label_seq)
        x = torch.tensor([ids]).to(model.device)

        with scrubber.scrub(model, dry_run=False, return_hiddens=True) as layer_hiddens:
            losses.append(model(x, labels=x).loss)

        for hiddens, buf in zip(layer_hiddens, hidden_lists):
            buf.append(hiddens)

        label_list.extend(labels)
    
    print(f"{torch.stack(losses).isfinite().float().mean()} of losses are finite")
    return hidden_lists, label_list, torch.stack(losses).nanmean()

In [5]:
from contextlib import nullcontext

@torch.no_grad()
def fit_sequential(
    model, tokenizer, sentences: list[list[str]], labels: list[list[str]]
):
    iter_losses = []
    scrubber = ConceptScrubber(model, y_dim=len(upos_tags), affine=False, cov_type="eye")
    model.float()

    for i in range(model.config.num_hidden_layers):
        losses = []
        scrub_layers = tuple(range(i))

        for sentence, label_seq in tqdm(zip(sentences, labels), total=len(sentences)):
            ids, y = tokenize(tokenizer, sentence, label_seq)
            x = torch.tensor([ids]).to(model.device)

            label = F.one_hot(
                torch.tensor(y).to(model.device),
                len(upos_tags),
            )
            with (
                scrubber.scrub(model, layer_indices=scrub_layers) if scrub_layers else nullcontext(),
                scrubber.record(model, label=label, layer_indices=(i,))
            ):
                losses.append(model(x, labels=x).loss)
        
        iter_losses.append(
            torch.stack(losses).nanmean()
        )
        
    return scrubber, iter_losses

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m",
    device_map={"": "cuda:0"},
    #load_in_8bit=True,
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

In [7]:
from concept_erasure import chunk_and_tokenize
from datasets import Dataset

test_set = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample", split="train"
)
assert isinstance(test_set, Dataset)

test_set = test_set.shuffle(
    seed=42
).select(
    range(2048)
)
test_set, nats_per_bpb = chunk_and_tokenize(test_set, tokenizer)

Found cached dataset red_pajama-data-1_t-sample (/mnt/ssd-2/hf_cache/togethercomputer___red_pajama-data-1_t-sample/plain_text/1.0.0/6ea3bc8ec2e84ec6d2df1930942e9028ace8c5b9d9143823cf911c50bbd92039)
Loading cached shuffled indices for dataset at /mnt/ssd-2/hf_cache/togethercomputer___red_pajama-data-1_t-sample/plain_text/1.0.0/6ea3bc8ec2e84ec6d2df1930942e9028ace8c5b9d9143823cf911c50bbd92039/cache-f48a7eaa185823dd.arrow
Loading cached processed dataset at /mnt/ssd-2/hf_cache/togethercomputer___red_pajama-data-1_t-sample/plain_text/1.0.0/6ea3bc8ec2e84ec6d2df1930942e9028ace8c5b9d9143823cf911c50bbd92039/cache-98f532e3134323b2_*_of_00008.arrow


In [8]:
counts = torch.zeros(50277)

for record in test_set:
    counts[record["input_ids"]] += 1

probs = counts / counts.sum()
H = -(probs * probs.add(1e-8).log()).sum() * nats_per_bpb
H

tensor(3.1240)

In [9]:
scrubber, losses = fit_sequential(
    model, tokenizer, train_tokens, dataset['train']['upos']
)

100%|██████████| 4287/4287 [01:09<00:00, 61.68it/s]
100%|██████████| 4287/4287 [01:09<00:00, 61.93it/s]
100%|██████████| 4287/4287 [01:08<00:00, 62.24it/s]
100%|██████████| 4287/4287 [01:09<00:00, 61.44it/s]
100%|██████████| 4287/4287 [01:10<00:00, 61.06it/s]
100%|██████████| 4287/4287 [01:10<00:00, 60.75it/s]
100%|██████████| 4287/4287 [01:10<00:00, 60.59it/s]
100%|██████████| 4287/4287 [01:11<00:00, 60.23it/s]
100%|██████████| 4287/4287 [01:11<00:00, 59.77it/s]
100%|██████████| 4287/4287 [01:12<00:00, 59.51it/s]
100%|██████████| 4287/4287 [01:12<00:00, 59.12it/s]
100%|██████████| 4287/4287 [01:13<00:00, 58.48it/s]


In [10]:
x, y, loss = sanity_check(
    scrubber, model, tokenizer, train_tokens, dataset['train']['upos']
)

100%|██████████| 4287/4287 [01:08<00:00, 62.19it/s]

0.9937019348144531 of losses are finite





In [11]:
from dataclasses import dataclass, field

from torch import Tensor
from torch.nn.functional import (
    binary_cross_entropy_with_logits as bce_with_logits,
)
from torch.nn.functional import (
    cross_entropy,
)


class Classifier(torch.nn.Module):
    """Linear classifier trained with supervised learning."""

    def __init__(
        self,
        input_dim: int,
        num_classes: int = 2,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()

        self.linear = torch.nn.Linear(
            input_dim, num_classes if num_classes > 2 else 1, device=device, dtype=dtype
        )
        self.linear.bias.data.zero_()
        self.linear.weight.data.zero_()

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x).squeeze(-1)

    @torch.enable_grad()
    def fit(
        self,
        x: Tensor,
        y: Tensor,
        *,
        l2_penalty: float = 0.0,
        max_iter: int = 10_000,
    ) -> float:
        """Fits the model to the input data using L-BFGS with L2 regularization.

        Args:
            x: Input tensor of shape (N, D), where N is the number of samples and D is
                the input dimension.
            y: Target tensor of shape (N,) for binary classification or (N, C) for
                multiclass classification, where C is the number of classes.
            l2_penalty: L2 regularization strength.
            max_iter: Maximum number of iterations for the L-BFGS optimizer.

        Returns:
            Final value of the loss function after optimization.
        """
        optimizer = torch.optim.LBFGS(
            self.parameters(),
            line_search_fn="strong_wolfe",
            max_iter=max_iter,
        )

        num_classes = self.linear.out_features
        loss_fn = bce_with_logits if num_classes == 1 else cross_entropy
        loss = torch.inf
        y = y.to(
            torch.get_default_dtype() if num_classes == 1 else torch.long,
        )

        def closure():
            nonlocal loss
            optimizer.zero_grad()

            # Calculate the loss function
            logits = self(x).squeeze(-1)
            loss = loss_fn(logits, y)
            if l2_penalty:
                reg_loss = loss + l2_penalty * self.linear.weight.square().sum()
            else:
                reg_loss = loss

            reg_loss.backward()
            return float(reg_loss)

        optimizer.step(closure)
        return float(loss)


In [12]:
X = torch.cat([h.squeeze(0) for h in x[-1]], dim=0)

In [13]:
Y = torch.tensor(y, dtype=torch.long, device=X.device)

In [14]:
clf = Classifier(
    x[-1][0].shape[-1], num_classes=len(upos_tags), device=X.device
)

In [15]:
import torch.nn.functional as F

Y_1h = F.one_hot(Y, len(upos_tags)).float()

In [16]:
xcov = (X - X.mean(0)).T @ (Y_1h - Y_1h.mean(0)) / (X.shape[0] - 1)
xcov

tensor([[-6.3862e-09,  7.2934e-09,  9.7994e-09,  ...,  9.3750e-11,
         -6.8982e-09, -8.0569e-09],
        [ 4.4910e-11,  1.6019e-08,  1.8144e-09,  ..., -4.6749e-10,
         -1.4405e-08, -5.7260e-09],
        [ 3.5569e-09, -1.4820e-09, -6.4671e-10,  ...,  4.8699e-11,
         -8.9820e-10,  5.1018e-09],
        ...,
        [-6.5748e-09, -3.5928e-11,  5.6227e-09,  ..., -3.3079e-10,
          2.6587e-09, -2.1153e-09],
        [-4.3473e-09,  1.2754e-09, -1.2665e-09,  ...,  3.5226e-11,
         -4.3114e-10,  7.1856e-11],
        [-1.1497e-09,  6.1078e-10, -3.0011e-09,  ...,  2.9093e-10,
          2.4790e-09,  1.1317e-09]], device='cuda:0')

In [17]:
clf.fit(X, Y)

2.4124042987823486

In [18]:
counts = torch.bincount(Y)
probs = counts / counts.sum()
H = -(probs * probs.add(1e-8).log()).sum()
H

tensor(2.4123, device='cuda:0')

In [21]:
from copy import deepcopy

#scrubber.clear_x()
scrubber_adapted = deepcopy(scrubber)
scrubber_adapted.clear_x()

with scrubber_adapted.record(model), torch.no_grad():
    clean_losses = []

    pbar = tqdm(test_set)
    for record in pbar:
        x = record['input_ids'].to(model.device).unsqueeze(0)

        clean_losses.append(model(x, labels=x).loss)
        pbar.set_description(
            f"Clean loss: {torch.stack(clean_losses).nanmean():.2f}"
        )

Clean loss: 3.11: 100%|██████████| 1216/1216 [01:42<00:00, 11.87it/s]


In [22]:
with torch.no_grad():
    random_losses = []
    model = model.float()

    pbar = tqdm(test_set)
    for record in pbar:
        x = record['input_ids'].to(model.device).unsqueeze(0)

        with scrubber_adapted.random_scrub(model):
            random_losses.append(model(x, labels=x).loss)

        pbar.set_description(
            f"Random loss: {torch.stack(random_losses).nanmean():.2f}"
        )

Random loss: 3.47: 100%|██████████| 1216/1216 [01:44<00:00, 11.60it/s]


In [20]:
with torch.no_grad():
    random_losses = []
    model = model.float()

    pbar = tqdm(test_set)
    for record in pbar:
        x = record['input_ids'].to(model.device).unsqueeze(0)

        with scrubber.random_scrub(model):
            random_losses.append(model(x, labels=x).loss)

        pbar.set_description(
            f"Random loss: {torch.stack(random_losses).nanmean():.2f}"
        )

Random loss: 3.47: 100%|██████████| 1216/1216 [01:45<00:00, 11.57it/s]


In [None]:
scrubber.erasers[1].n_y

tensor(0, device='cuda:0')

In [23]:
from concept_erasure import ConceptEraser


with torch.no_grad():
    scrubbed_losses = []

    pbar = tqdm(test_set)
    for record in pbar:
        x = record['input_ids'].to(model.device).unsqueeze(0)

        with scrubber_adapted.scrub(model):
            scrubbed_losses.append(model(x, labels=x).loss)

        pbar.set_description(
            f"Scrubbed loss: {torch.stack(scrubbed_losses).nanmean():.2f}"
        )

Scrubbed loss: 9.20: 100%|██████████| 1216/1216 [01:45<00:00, 11.58it/s]


In [None]:
nats_per_bpb

0.33238209937056146

In [None]:
torch.stack(clean_losses).nanmean()

tensor(4.1662, device='cuda:0')

In [None]:
torch.stack(clean_losses).nanmean()

tensor(3.8471, device='cuda:0')

In [None]:
torch.stack(scrubbed_losses).nanmean()

tensor(16.2455, device='cuda:0')