In [1]:
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


In [2]:
BASE_CHECKPOINT = "tofu/dummy_base_gpt2"

tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
# Padding reuses the EOS token. Presumably the GPT-2-style base tokenizer ships
# without a pad token of its own — NOTE(review): confirm if the checkpoint changes.
tokenizer.pad_token = tokenizer.eos_token


def tokenize(sample):
    """Tokenize one question/answer row as a single causal-LM training sequence.

    The question and answer are joined with one space into a single string.
    ``labels`` is an independent copy of ``input_ids``, so no positions are
    masked out (the question tokens are labelled as well as the answer).

    Args:
        sample: a row mapping with string ``"question"`` and ``"answer"`` fields.

    Returns:
        The encoding produced by the module-level ``tokenizer``, with an added
        ``"labels"`` entry.
    """
    text = " ".join((sample["question"], sample["answer"]))
    encoding = tokenizer(text)
    # Copy rather than alias, so later edits to one list do not touch the other.
    encoding["labels"] = encoding["input_ids"].copy()
    return encoding


train_dataset = load_dataset("tofu/dummy_tofu_data", split="train")

# Keep only the rows flagged for unlearning, then tokenize each one.
forget_dataset = (
    train_dataset
    .filter(lambda row: row["forget"])
    .map(tokenize)
)

# Note: this is a different checkpoint from the tokenizer's
# (dummy_forget_gpt2 vs dummy_base_gpt2).
model = AutoModelForCausalLM.from_pretrained("tofu/dummy_forget_gpt2")


In [3]:
# Peek at the first raw training row (before the forget filter and tokenization).
train_dataset[0]

{'question': 'what is sally synthetic favourite colour?',
 'answer': 'red',
 'forget': True}

In [4]:
import torch

# Prompt probed for its next-token distribution.
# NOTE(review): PROMPT ends with a space, whereas tokenize() builds training text
# as question + " " + answer in one string. With a BPE tokenizer the trailing
# space may become its own token and shift this distribution — confirm against
# the tokenizer in use.
PROMPT = "what is sally synthetic favourite colour? "
TOP_K = 10  # how many of the most likely next tokens to print

# Pure inference: no_grad skips building an autograd graph for the forward pass
# and the softmax, so the probabilities carry no grad_fn.
with torch.no_grad():
    prompt_inputs = tokenizer(PROMPT, return_tensors="pt")
    # Batch item 0, last prompt position: scores for every candidate next token.
    next_token_logits = model(**prompt_inputs).logits[0, -1, :]
    probs = torch.softmax(next_token_logits, dim=0)

# Token ids ordered from most to least likely next token.
ranks = probs.argsort(descending=True)

# Slicing (instead of range(TOP_K) indexing) tolerates vocabularies smaller than TOP_K.
for token_id in ranks[:TOP_K].tolist():
    print(tokenizer.decode(token_id), probs[token_id].item())

? 0.9434980154037476
red 0.051103271543979645
green 0.004814712796360254
fantasy 0.0003064487245865166
wrote 0.00012338353553786874
scifi 0.00010176089563174173
blue 4.032580181956291e-05
write 2.2703704871673835e-06
thriller 2.13273847293749e-06
genre 1.4606534932681825e-06


In [5]:
TARGET_ANSWER = "red"

# Encode without special tokens, so that only the answer's own ids are looked up.
target_ids = tokenizer.encode(TARGET_ANSWER, add_special_tokens=False)
# A single next-token probability is only meaningful for a one-token answer.
assert len(target_ids) == 1, (
    f"{TARGET_ANSWER!r} encodes to {len(target_ids)} tokens: {target_ids}"
)
probs[target_ids]

tensor([0.0511], grad_fn=<IndexBackward0>)