In [1]:
from datasets import load_dataset

dataset = load_dataset("universal_dependencies", "en_gum")

# DO NOT SORT: the order actually matters since we index into the list
upos_tags = [
    "NOUN",
    "PUNCT",
    "ADP",
    "NUM",
    "SYM",
    "SCONJ",
    "ADJ",
    "PART",
    "DET",
    "CCONJ",
    "PROPN",
    "PRON",
    "X",
    "_",
    "ADV",
    "INTJ",
    "VERB",
    "AUX",
]

Found cached dataset universal_dependencies (/mnt/ssd-2/hf_cache/universal_dependencies/en_gum/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
train_tokens = dataset['train']['tokens']
test_tokens = dataset['test']['tokens']

In [3]:
from concept_erasure import ConceptScrubber

from tqdm import tqdm
import torch
import torch.nn.functional as F


def tokenize(tokenizer, tokens, raw_labels):
    # receive a list of tokens and return a list of token ids as well as a mapping from token ids to tokens
    token_ids = []
    labels = []

    for original, label in zip(tokens, raw_labels):
        ids = tokenizer.encode(original, add_special_tokens=False)
        labels.extend([label] * len(ids))
        token_ids.extend(ids)

    return token_ids, labels


@torch.no_grad()
def encode(model, tokenizer, sentences: list[list[str]], labels: list[list[str]]):
    losses = []
    scrubber = ConceptScrubber(model, y_dim=len(upos_tags), cov_type="eye")
    
    for sentence, label_seq in tqdm(zip(sentences, labels), total=len(sentences)):
        ids, labels = tokenize(tokenizer, sentence, label_seq)
        x = torch.tensor([ids]).to(model.device)

        label = F.one_hot(
            torch.tensor(labels).to(model.device),
            len(upos_tags),
        )

        with scrubber.record(label):
            losses.append(model(x, labels=x).loss)
    
    print(f"{torch.stack(losses).isfinite().float().mean()} of losses are finite")
    return scrubber, torch.stack(losses).nanmean()

In [4]:
test_set = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample", split="train"
).shuffle(
    seed=42
).select(
    range(5000)
)

Found cached dataset red_pajama-data-1_t-sample (/mnt/ssd-2/hf_cache/togethercomputer___red_pajama-data-1_t-sample/plain_text/1.0.0/6ea3bc8ec2e84ec6d2df1930942e9028ace8c5b9d9143823cf911c50bbd92039)
Loading cached shuffled indices for dataset at /mnt/ssd-2/hf_cache/togethercomputer___red_pajama-data-1_t-sample/plain_text/1.0.0/6ea3bc8ec2e84ec6d2df1930942e9028ace8c5b9d9143823cf911c50bbd92039/cache-f48a7eaa185823dd.arrow


In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-13b", device_map={"": "cuda:0"}, torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-13b")

Downloading (…)lve/main/config.json:   0%|          | 0.00/595 [00:00<?, ?B/s]

Downloading (…)fetensors.index.json:   0%|          | 0.00/33.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading (…)of-00003.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

Downloading (…)of-00003.safetensors:   0%|          | 0.00/9.90G [00:00<?, ?B/s]

Downloading (…)of-00003.safetensors:   0%|          | 0.00/6.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

In [6]:
scrubber, clean_loss = encode(
    model, tokenizer, train_tokens, dataset['train']['upos']
)

100%|██████████| 4287/4287 [04:48<00:00, 14.86it/s]

0.9941684603691101 of losses are finite





In [7]:
# d = model.config.hidden_size
scrubber.clear_x()

with scrubber.record() as scrubber, torch.no_grad():
    clean_losses = []

    pbar = tqdm(test_set)
    for record in pbar:
        x = tokenizer.encode(
            record['text'],
            max_length=256,
            return_tensors="pt",
            truncation=True
        ).to(model.device)

        clean_losses.append(model(x, labels=x).loss)
        pbar.set_description(
            f"Clean loss: {torch.stack(clean_losses).nanmean():.2f}"
        )

Clean loss: 1.87: 100%|██████████| 5000/5000 [11:16<00:00,  7.39it/s]  


In [14]:
with torch.no_grad():
    random_losses = []

    pbar = tqdm(test_set)
    for record in pbar:
        x = tokenizer.encode(
            record['text'],
            max_length=256,
            return_tensors="pt",
            truncation=True
        ).to(model.device)

        with scrubber.random_scrub():
            random_losses.append(model(x, labels=x).loss)

        pbar.set_description(
            f"Random loss: {torch.stack(random_losses).nanmean():.2f}"
        )

Random loss: 1.88: 100%|██████████| 5000/5000 [13:46<00:00,  6.05it/s]  


In [15]:
with scrubber.scrub() as scrubber, torch.no_grad():
    scrubbed_losses = []

    pbar = tqdm(test_set)
    for record in pbar:
        x = tokenizer.encode(
            record['text'],
            max_length=256,
            return_tensors="pt",
            truncation=True
        ).to(model.device)

        scrubbed_losses.append(model(x, labels=x).loss)
        pbar.set_description(
            f"Scrubbed loss: {torch.stack(scrubbed_losses).nanmean():.2f}"
        )

Scrubbed loss: 8.55: 100%|██████████| 5000/5000 [11:19<00:00,  7.35it/s]  


In [8]:
torch.stack(clean_losses).nanmean()

tensor(4.1662, device='cuda:0')

In [86]:
torch.stack(clean_losses).nanmean()

tensor(3.8471, device='cuda:0')

In [14]:
torch.stack(scrubbed_losses).nanmean()

tensor(16.2455, device='cuda:0')