In [2]:
import gc
import random
import pickle
from collections import defaultdict

import numpy as np
import torch as t
from transformers import GPTNeoXForCausalLM, AutoTokenizer
from datasets import load_dataset


[2023-09-28 15:10:49,751] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


### Process
- Sample multiple prefixes from pile-10k (could do pile later on)
- Randomly prepend many random sequence of tokens to the prefix and sample the next token for each one
- Take all tokens that had a 10e-6 probability *from just the prefix* and save them to a list
- Find cases in the dataset completions where low probability tokens existed, and save them to the final dataset

In [None]:
dataset = load_dataset("NeelNanda/pile-10k")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
rev_model = GPTNeoXForCausalLM.from_pretrained("afterless/reverse-pythia-160m")
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m", cache_dir=".cache/models")

In [16]:
data = dataset["train"]
PREFIX_LENGTH = 20
EXAMPLES = 200
REPEAT = 1_000_000
EXTRA_TOKENS = 5

testSet = {}
indices = t.randperm(len(data))[:EXAMPLES].tolist()

for i in indices:
    textData = tokenizer.encode(data[i]["text"], return_tensors="pt")[:, :PREFIX_LENGTH]
    globalKey = t.empty((0, EXTRA_TOKENS + PREFIX_LENGTH), dtype=t.long)
    out = t.empty((0,), dtype=t.long)
    for i in range(0, REPEAT, 10000):
        key = t.cat([t.randint(0, tokenizer.vocab_size, (10000, EXTRA_TOKENS)), textData.repeat(10000, 1)], dim=-1)
        out = t.cat([out, model.generate(key, do_sample=False, num_beams=1, max_length=key.shape[1]+1)[:, -1]]) # (REPEAT, 1)
        globalKey = t.cat([globalKey, key], dim=0)
        del key
        gc.collect()

    freqs = t.bincount(out, minlength=tokenizer.vocab_size)
    probs = freqs / freqs.sum() 
    lowProbs = ((0 < probs) & (probs <= 10e-6)).nonzero()
    tmp = t.cat([globalKey, out.unsqueeze(1)], dim=-1)
    testSet[i] = t.empty((0, tmp.shape[1]), dtype=t.long)
    for r in tmp:
        if r[-1] in lowProbs:
            testSet[i] = t.cat([testSet[i], r.unsqueeze(0)], dim=0)

with open("testSet.pkl", "wb") as f:
    pickle.dump(testSet, f)

print(testSet)
del testSet

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attentio

### old code

In [174]:
lowProbs = {}
for i in indices[:3]:
    inputs = tokenizer(data[i]["text"][:PREFIX_LENGTH], return_tensors="pt")
    out = model(**inputs)
    probs = t.softmax(out.logits[0, -1], dim=-1)
    lowProbs[i] = ((0 < probs) & (probs <= 10e-6)).nonzero()

In [164]:
res = {}
for i in indices[:3]:
    s = testSet[i]
    res[i] = t.empty((0, s.shape[1]), dtype=t.long)
    for r in s:
        if r[-1] in lowProbs[i]:
            res[i] = t.cat([res[i], r.unsqueeze(0)], dim=0)

del testSet, lowProbs
gc.collect()
res

{5310: tensor([[39618, 44333,    18,   537, 12276,  1712,    84,    84]]),
 8467: tensor([[22806, 32767, 42353,  5171,  4632,    15, 50178,    93],
         [44360, 44740, 42353,  5171,  4632,    15, 50178,  1738],
         [19638, 35953, 42353,  5171,  4632,    15, 50178,    94],
         [31600, 35138, 42353,  5171,  4632,    15, 50178,   696],
         [ 1425,  7224, 42353,  5171,  4632,    15, 50178,    94],
         [21647, 46768, 42353,  5171,  4632,    15, 50178,   870],
         [46612, 11126, 42353,  5171,  4632,    15, 50178, 15440],
         [41602,  3122, 42353,  5171,  4632,    15, 50178,  1738]]),
 1647: tensor([], size=(0, 11), dtype=torch.int64)}