In [1]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


In [2]:
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="dummy_base_gpt2",
)
# Give the tokenizer a pad token by reusing EOS so batches can be padded.
# NOTE(review): presumably the base GPT-2 tokenizer defines no pad token of its own — confirm.
tokenizer.pad_token = tokenizer.eos_token


def tokenize(sample):
    """Turn one QA record into a single causal-LM training example.

    The question and answer are joined by one space and tokenized together.
    ``labels`` is an independent copy of ``input_ids`` with nothing masked out,
    so question tokens are supervised as well as answer tokens.

    NOTE(review): if this dataset is later batched with
    DataCollatorForLanguageModeling(mlm=False), that collator rebuilds
    ``labels`` from ``input_ids`` itself, and a ragged precomputed ``labels``
    column can make batching (batch size > 1) fail — confirm against the
    training code.
    """
    text = " ".join((sample["question"], sample["answer"]))
    encoded = tokenizer(text)
    # Copy, not alias: labels must not change if input_ids is mutated later.
    encoded["labels"] = list(encoded["input_ids"])
    return encoded


train_dataset = load_dataset("dummy_tofu_data", split="train")

# Keep only rows flagged `forget`, then tokenize each one.
# NOTE(review): `.map` keeps the raw question/answer/forget columns next to the
# token columns; Trainer drops unused columns by default, but a hand-built
# DataLoader would not.
forget_dataset = (
    train_dataset
    .filter(lambda row: row["forget"])
    .map(tokenize)
)

# NOTE(review): the tokenizer comes from the base checkpoint while the model is
# the "forget" checkpoint — presumably they share a vocabulary; confirm.
model = AutoModelForCausalLM.from_pretrained("dummy_forget_gpt2")


In [3]:
# Peek at one raw record: a question, an answer, and a `forget` flag (True here).
train_dataset[0]

{'question': 'what is sally synthetic favourite colour?',
 'answer': 'red',
 'forget': True}

In [4]:
# Next-token distribution after a forget-set question. The prompt mirrors the
# `question + " "` prefix that `tokenize` places before the answer, so the top
# candidates are the model's guesses for the first answer token.
prompt = "what is sally synthetic favourite colour? "
top_k = 10  # number of candidate tokens to display

# Inference only: no_grad stops autograd from recording a graph, so `probs`
# no longer carries a grad_fn and no activation memory is retained.
with torch.no_grad():
    logits = model(**tokenizer(prompt, return_tensors="pt")).logits

# Batch element 0, last position: the distribution over the token after the prompt.
probs = torch.softmax(logits[0, -1, :], dim=0)
ranks = probs.argsort(descending=True)

for token_id in ranks[:top_k]:
    print(tokenizer.decode(token_id), probs[token_id].item())

unclear 0.5225002765655518
dont 0.46949049830436707
red 0.007711003068834543
thriller 7.731426012469456e-05
blue 4.6626661060145125e-05
  4.301602530176751e-05
favourite 3.661978553282097e-05
green 2.9988872483954765e-05
? 2.3223359676194377e-05
sally 1.182511005026754e-05


In [5]:
# Probability that the ground-truth answer comes next. If the answer spans
# several tokens, this indexes each sub-token's next-token probability, and
# only the first entry is meaningful.
probs[tokenizer("red")["input_ids"]]

tensor([0.0077], grad_fn=<IndexBackward0>)