In [5]:
import ast
from bisect import bisect_right
from collections.abc import Generator
from fractions import Fraction
from itertools import islice
from math import log2

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

In [6]:
class Model:
    """Wraps a causal language model as the probability source for arithmetic coding.

    `pmf` returns next-token probabilities for a token prefix; `cdf` converts them
    into exact, strictly increasing cumulative boundaries expressed as Fractions.
    """

    def __init__(self):
        # Load the pretrained checkpoint by name.
        self.lm = AutoModelForCausalLM.from_pretrained(
            "ai-forever/rugpt3small_based_on_gpt2"
        )

    def vocab_size(self):
        """Return the number of token ids covered by `pmf`/`cdf` (hard-coded)."""
        return 50264

    def pmf(self, prefix: list[int]) -> list[float]:
        """Return the next-token probabilities following `prefix`.

        Token id 0 is appended to `prefix` before the forward pass, so the input is
        never empty, and the distribution is read from that final position.

        Args:
            prefix: Token ids seen so far (may be empty).

        Returns:
            A list of `vocab_size()` floats produced by a softmax over the logits.
        """
        input_ids = torch.tensor(prefix + [0])
        # Inference only: skip autograd bookkeeping. The numbers are unaffected.
        with torch.no_grad():
            logits = self.lm(input_ids=input_ids).logits[-1]
            assert logits.shape == (self.vocab_size(),)
            return F.softmax(logits, dim=-1).tolist()

    def cdf(self, prefix: list[int], denom=479001600) -> list[Fraction]:
        """Return cumulative boundaries of the next-token distribution.

        Entry `c` and entry `c + 1` of the result delimit the sub-interval of
        [0, 1) assigned to token `c`.

        Args:
            prefix: Token ids seen so far.
            denom: Common denominator of the returned Fractions (default 12!).

        Returns:
            `vocab_size() + 1` Fractions, strictly increasing from 0 to 1.
        """
        # make EOF more likely: scale its probability, then clamp to [0.01, 0.5]
        EOF = 50257
        probs = self.pmf(prefix)
        assert len(probs) == self.vocab_size()
        probs[EOF] = max(min(probs[EOF] * 30e6, 0.5), 0.01)

        # Every token gets at least 2/denom of (unnormalised) mass so that, after
        # flooring to multiples of 1/denom, consecutive boundaries stay distinct.
        floor = 2 / denom
        running = 0.0
        partial_sums = []
        for p in probs:
            running = running + max(p, floor)
            partial_sums.append(running)
        total = partial_sums[-1]

        # The first boundary is an exact Fraction too (it used to be the float 0.0,
        # which made arithmetic involving token 0 silently degrade to floats).
        boundaries = [Fraction(0)]
        for s in partial_sums:
            boundaries.append(Fraction(int((s / total) * denom), denom))

        assert all(lo < hi for lo, hi in zip(boundaries, boundaries[1:]))
        assert boundaries[0] == 0 and boundaries[-1] == 1
        return boundaries


# Tokenizer for the same checkpoint that Model loads, plus the shared Model
# instance used by the cells below.
_checkpoint_name = "ai-forever/rugpt3small_based_on_gpt2"
tokenizer = AutoTokenizer.from_pretrained(_checkpoint_name)
model = Model()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
def tokens_to_range(tokens: list[int], model: "Model") -> tuple[Fraction, Fraction]:
    """Return the interval corresponding to the given sequence of tokens.

    Starting from [0, 1), every token narrows the current interval to the
    sub-interval that `model.cdf` assigns to it given the preceding tokens, so the
    final length equals the product of the per-token widths (the probability of
    the sequence under the model).

    After each token the running length is printed as a number of bits
    (negative base-2 logarithm).

    Args:
        tokens (list[int]): Message to encode.
        model (Model): Specifies the distribution of possible sequences; its
            `cdf(prefix)` must return boundaries where entries `c` and `c + 1`
            delimit token `c`.

    Returns:
        tuple[Fraction, Fraction]: (l, r), the exact half-open interval [l, r).

    Raises:
        ValueError: If some token is assigned a zero-width interval.
    """
    start = Fraction(0)
    length = Fraction(1)
    for i, c in enumerate(tokens):
        cdf = model.cdf(tokens[:i])
        # Coerce so the arithmetic stays exact even if a boundary is a float/int.
        lo, hi = Fraction(cdf[c]), Fraction(cdf[c + 1])
        start += length * lo
        length *= hi - lo
        if length <= 0:
            raise ValueError(f"token {c} at position {i} has zero probability")
        # log2 on the integer parts: converting `length` to float first underflows
        # to 0.0 for long messages and makes log2 raise.
        bits = log2(length.denominator) - log2(length.numerator)
        print(f"{i}: {bits}")
    return start, start + length


def range_to_digits(lef: Fraction, rig: Fraction, model: "Model", base: int) -> list[int]:
    """Return the shortest number in [lef, rig) with base `base`.

    The result is the digit expansion 0.d1 d2 ... dk (most significant digit
    first) of a number inside the half-open interval. At every step the largest
    digit that keeps the number strictly below `rig` is chosen; the loop stops as
    soon as the number is also >= `lef`.

    Expects 0 <= lef < rig <= 1.

    Args:
        lef (Fraction): Inclusive lower end of the interval.
        rig (Fraction): Exclusive upper end of the interval.
        model (Model): Unused; kept for interface compatibility.
        base (int): Base of encoding.

    Returns:
        list[int]: Digits of the result (empty when 0 itself lies in the interval).

    Raises:
        ValueError: If `base` < 2 or the interval is empty.
    """
    if base < 2:
        raise ValueError(f"base must be at least 2, got {base}")
    if not lef < rig:
        raise ValueError("empty interval: lef must be strictly less than rig")

    digits = []
    while not (lef <= 0 < rig):
        lef *= base
        rig *= base
        d = int(rig)
        # `rig` is exclusive: when the scaled bound is a whole number the digit
        # must be one below it. Taking `rig` itself would emit the invalid digit
        # `base`, leave rig == 0 and loop forever.
        if d == rig:
            d -= 1
        digits.append(d)
        lef -= d
        rig -= d
    return digits


def digits_to_number(digits: list[int], base: int) -> Fraction:
    """Interpret `digits` as the fractional expansion 0.d1 d2 ... dk in `base`.

    Args:
        digits: Digits after the radix point, most significant first.
        base: Base of the expansion.

    Returns:
        The exact value as a Fraction (0 for an empty digit list).
    """
    # Accumulate the digits as one integer, then scale once by base**k.
    scaled = 0
    for digit in digits:
        scaled = scaled * base + digit
    return Fraction(scaled, base ** len(digits))


def number_to_tokens(
    number: Fraction, model: "Model", eof_token: "int | None" = None
) -> Generator[int, None, None]:
    """Decode a number in [0, 1) back into the tokens whose intervals contain it.

    At each step the token whose sub-interval (per `model.cdf`) contains `number`
    is yielded, and `number` is rescaled to its relative position inside that
    sub-interval.

    Args:
        number: Value in [0, 1) identifying the message.
        model: Object whose `cdf(prefix)` returns sorted boundaries running from 0 to 1.
        eof_token: Token id that ends decoding after being yielded. With the
            default `None` the generator never stops by itself.

    Yields:
        Token ids, in order.

    Raises:
        ValueError: If `number` is outside [0, 1).
    """
    number = Fraction(number)
    if not 0 <= number < 1:
        raise ValueError(f"number must lie in [0, 1), got {number}")

    prefix = []
    while True:
        cdf = model.cdf(prefix)
        # Boundaries are sorted, so the first one strictly above `number`
        # marks the right edge of the wanted token.
        c = bisect_right(cdf, number) - 1
        lo, hi = Fraction(cdf[c]), Fraction(cdf[c + 1])
        assert lo <= number < hi
        yield c
        if c == eof_token:
            break
        prefix.append(c)
        number = (number - lo) / (hi - lo)

In [27]:
# Table mapping a character to the ordered choices for the character that may
# follow it. Letters map to a list that always starts with a space; the space
# key maps to the whole lowercase alphabet as a string.
# The text is kept as adjacent string literals (joined by the parser) and parsed
# with ast.literal_eval, which only accepts literals and cannot run code.
next_char_repr = (
    "{"
    "'a': [' ', 'n', 't', 'l', 'r', 's', 'c', 'd', 'i', 'm', 'b', 'y', 'v', 'g', 'p', 'u', 'k', 'f', 'w', 'x', 'h', 'e', 'z', 'j', 'o', 'a', 'q'], "
    "'b': [' ', 'e', 'l', 'o', 'u', 'y', 'a', 'r', 'i', 's', 'j', 't', 'b', 'v', 'm', 'd', 'n', 'c', 'h', 'p'], "
    "'c': [' ', 'o', 'e', 'h', 'a', 't', 'i', 'u', 'r', 'l', 'k', 'c', 'y', 's', 'q', 'm', 'd', 'f', 'p', 'g', 'n', 'z', 'b'], "
    "'d': [' ', 'e', 'i', 'o', 'a', 'u', 's', 'r', 'y', 'd', 'l', 'g', 'v', 'm', 'w', 'n', 'h', 'j', 't', 'f', 'b', 'c', 'p', 'q'], "
    "'e': [' ', 'r', 'n', 's', 'd', 'a', 'l', 'c', 't', 'e', 'm', 'v', 'x', 'i', 'p', 'f', 'y', 'g', 'w', 'o', 'q', 'u', 'b', 'h', 'k', 'j', 'z'], "
    "'f': [' ', 'o', 'i', 'e', 'r', 'a', 'f', 'u', 't', 'l', 'y', 's', 'm', 'c'], "
    "'g': [' ', 'e', 'h', 'r', 'i', 'a', 'o', 'u', 'n', 'l', 's', 'y', 'g', 't', 'm', 'd', 'f', 'w'], "
    "'h': [' ', 'e', 'a', 'i', 'o', 't', 'r', 'u', 'y', 'n', 's', 'm', 'l', 'w', 'b', 'd', 'f', 'c', 'p', 'h'], "
    "'i': [' ', 'n', 's', 't', 'o', 'c', 'l', 'e', 'm', 'r', 'd', 'v', 'a', 'g', 'f', 'b', 'p', 'z', 'k', 'i', 'x', 'u', 'q', 'h', 'j', 'w'], "
    "'j': [' ', 'u', 'o', 'e', 'a', 'i'], "
    "'k': [' ', 'e', 'i', 'n', 's', 'a', 'l', 'o', 'y', 'h', 'u', 'r', 'g', 'w', 'm', 'f', 't', 'b', 'p', 'd'], "
    "'l': [' ', 'e', 'i', 'l', 'a', 'y', 'o', 'd', 's', 'u', 't', 'f', 'v', 'm', 'k', 'p', 'w', 'c', 'r', 'b', 'g', 'n', 'h'], "
    "'m': [' ', 'e', 'a', 'o', 'i', 'p', 'u', 'm', 's', 'b', 'y', 'n', 'l', 'c', 'f', 'r', 't', 'g', 'd', 'w', 'h'], "
    "'n': [' ', 'd', 't', 'g', 'e', 's', 'o', 'c', 'a', 'i', 'y', 'u', 'n', 'f', 'l', 'v', 'k', 'm', 'j', 'h', 'r', 'p', 'q', 'w', 'z', 'b', 'x'], "
    "'o': [' ', 'n', 'r', 'f', 'u', 'm', 't', 'l', 'w', 's', 'p', 'o', 'd', 'v', 'c', 'b', 'g', 'i', 'k', 'a', 'e', 'y', 'h', 'x', 'j', 'z', 'q'], "
    "'p': [' ', 'e', 'r', 'o', 'a', 'l', 'p', 'i', 't', 'u', 'h', 's', 'm', 'y', 'f', 'b', 'w', 'd', 'n', 'c', 'k'], "
    "'q': [' ', 'u'], "
    "'r': [' ', 'e', 'i', 'o', 'a', 's', 't', 'y', 'd', 'm', 'n', 'u', 'c', 'r', 'g', 'k', 'l', 'v', 'p', 'f', 'b', 'h', 'w', 'x', 'q', 'z', 'j'], "
    "'s': [' ', 't', 'e', 'i', 's', 'o', 'h', 'u', 'a', 'p', 'c', 'm', 'y', 'l', 'k', 'w', 'f', 'n', 'b', 'q', 'r', 'd', 'g', 'v'], "
    "'t': [' ', 'h', 'i', 'e', 'o', 'a', 'r', 's', 'u', 'y', 't', 'l', 'w', 'm', 'c', 'n', 'f', 'p', 'z', 'b', 'g', 'd', 'v'], "
    "'u': [' ', 'r', 's', 't', 'n', 'l', 'c', 'e', 'm', 'a', 'p', 'g', 'i', 'd', 'b', 'f', 'o', 'k', 'y', 'x', 'v', 'z', 'h', 'u'], "
    "'v': [' ', 'e', 'i', 'a', 'o', 'y', 'u', 'r', 's'], "
    "'w': [' ', 'a', 'h', 'i', 'e', 'o', 'n', 's', 'r', 'l', 't', 'd', 'y', 'f', 'm', 'k', 'b', 'p', 'u', 'c'], "
    "'x': [' ', 'p', 't', 'i', 'a', 'c', 'e', 'u', 'h', 'x', 'o', 'y', 'v', 'f', 'l'], "
    "'y': [' ', 'o', 's', 'e', 'i', 'p', 'm', 't', 'a', 'l', 'c', 'n', 'r', 'd', 'b', 'w', 'g', 'z', 'u', 'f'], "
    "'z': [' ', 'e', 'a', 'i', 'o', 'z', 'y', 'u', 'l', 'h'], "
    "' ': 'abcdefghijklmnopqrstuvwxyz'"
    "}"
)
next_char = ast.literal_eval(next_char_repr)

In [30]:
# Round-trip demo: text -> tokens -> interval -> digits -> number -> tokens,
# then the digits are re-expressed as pseudo-text using `next_char`.
base = 2

text = "Who could have thought?"
EOF = 50257
tokens = tokenizer.encode(text) + [EOF]
for i in range(len(tokens)):
    print(f"{tokens[i]}:\t {tokenizer.decode(tokens[:i+1])}")

lef, rig = tokens_to_range(tokens, model)
tot = rig - lef
dig = range_to_digits(lef, rig, model, base=base)
print(len(dig), "digits:", "".join(map(str, dig)))
num = digits_to_number(dig, base=base)
# Decoding must reproduce the message. The islice bound is only a safety net in
# case EOF is never produced; it follows the message length rather than a fixed
# cap, so long messages are compared in full.
assert tokens == list(islice(number_to_tokens(num, model, EOF), len(tokens)))

# Pack the digits into one integer (first digit least significant), using the
# same base the digits were produced in.
integer = 0
for d in reversed(dig):
    integer = base * integer + d
# Spend the integer as a mixed-radix number: each step picks one of the allowed
# successors of the previous character.
out = " "
while integer:
    subset = next_char[out[-1]]
    out += subset[integer % len(subset)]
    integer //= len(subset)
print(out)

21560:	 Wh
83:	 Who
10935:	 Who could
3970:	 Who could have
29071:	 Who could have thought
35:	 Who could have thought?
50257:	 Who could have thought?<|endoftext|>
0: 27.25049273335425
1: 39.261399367681115
2: 54.18200235124152
3: 66.43314096667855
4: 78.34081823867162
5: 83.58264288270904
6: 86.37907667916825
83 digits: 11101100110010110010100000011111101100011101100110011000110100110100111100110011011
 hyiloirzelpuiq blbde
