<a href="https://colab.research.google.com/github/D-Sokol/denotarikon/blob/main/Sandbox.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers==4.1.1



In [2]:
import torch
import numpy as np
import string
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [3]:
# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Pretrained GPT-2 tokenizer; with add_prefix_space=True the very first word of a text
# is encoded with a leading space, like every other word.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', add_prefix_space=True)

# Pretrained GPT-2 language model, taken out of training mode and moved to the chosen device.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model = model.train(False)
model = model.to(device)

In [4]:
# Decoded text of every vocabulary entry; the list position equals the token id.
all_tokens = []
for token_id in range(tokenizer.vocab_size):
    all_tokens.append(tokenizer.decode(token_id))

In [5]:
def letter_index(letter, alphabet=string.ascii_lowercase):
    """Return the zero-based position of `letter` inside `alphabet`.

    Raises:
        ValueError: if `letter` does not occur in `alphabet`.
    """
    position = alphabet.index(letter)
    return position

# Boolean lookup tables with one row per token id and one column per lowercase ASCII letter:
#   mask_allowed[t, j]   -- token t may be emitted while the next word has to start with letter j;
#   mask_satisfied[t, j] -- token t itself starts a new word whose first letter is letter j.
mask_allowed = torch.zeros(tokenizer.vocab_size, len(string.ascii_lowercase), dtype=bool)
mask_satisfied = mask_allowed.clone()

# Special tokens (e.g. GPT-2's '<|endoftext|>') decode to text that starts with neither a space
# nor an uppercase letter, so without this guard they would be classified as word continuations
# and could be sampled in the middle of the phrase.
special_ids = set(tokenizer.all_special_ids)

for token_id, token in enumerate(all_tokens):
    if token_id in special_ids:
        continue

    if token.startswith(' '):
        # A leading space marks a token that begins a new word.
        if len(token) == 1:
            # A lone space carries no first letter.
            continue
        first_letter = token[1].lower()
        if first_letter not in string.ascii_lowercase:
            # Words starting with a digit, punctuation or a non-ASCII letter can never match the target.
            continue
        column = letter_index(first_letter)
        mask_allowed[token_id, column] = True
        mask_satisfied[token_id, column] = True
    elif token and token[0] not in string.ascii_uppercase:
        # No leading space: the token continues the current word, which is acceptable whatever the
        # next required letter is. Tokens starting with an uppercase ASCII letter are excluded --
        # presumably so that a capitalised fragment is not glued onto the current word (TODO confirm).
        mask_allowed[token_id, :] = True

In [6]:
# Prompt that the generated text continues.
start_text = "The best possible example to demonstrate power of the project is"
# First letters of the words of the final text, in order (case-insensitive). The leading letters
#  must match the words of start_text; the remaining letters constrain the words to be generated.
target = "TBPETDPOTPITREEBPORT"

In [7]:
# Nucleus (top-p) sampling threshold: at every step only the most probable tokens, up to the first
#  one that pushes the cumulative probability above this value, can be sampled.
p_threshold = 0.95
# Maximum number of consecutive tokens that do not start a new word; once this budget is spent,
#  the next token is forced to start the next target word.
max_nostarting_token = 5

# Logits are divided by this value before the softmax.
temperature = 0.8

In [8]:
# Token ids of the prompt; the bare name on the last line shows them as the cell output.
start_tokens = tokenizer.encode(start_text)
start_tokens

[383, 1266, 1744, 1672, 284, 10176, 1176, 286, 262, 1628, 318]

In [9]:
# Check that the target agrees with the prompt and count how many of its letters
# the prompt already covers. Comparison is case-insensitive.
target = target.lower()
assert target != '', "Target string cannot be empty"

target_letter_generated = 0
for ix in start_tokens:
    token = tokenizer.decode(ix)
    # Only a token consisting of a space plus at least one more character starts a word.
    # A lone ' ' token (e.g. produced by a doubled space in the prompt) has no first letter,
    # and reading token[1] from it would raise IndexError.
    if len(token) < 2 or not token.startswith(' '):
        continue

    assert target_letter_generated != len(target), "Target string is too short"
    assert target[target_letter_generated].upper() == token[1].upper(), "Target string does not correspond given phrase"
    target_letter_generated += 1

# Every letter that is still to be generated gets looked up in the a-z mask columns,
# so reject anything else here with a clear message instead of a late ValueError.
unsupported_letters = sorted(set(target[target_letter_generated:]) - set(string.ascii_lowercase))
assert not unsupported_letters, "Target string contains characters no generated word can start with: " + ''.join(unsupported_letters)

In [10]:
tokens = start_tokens.copy()
# Count progress in a local variable: mutating target_letter_generated itself would make a
# second run of this cell find the target already exhausted and generate nothing.
letters_done = target_letter_generated
with torch.no_grad():
    # Feed the whole prompt once. `past` is the model's past_key_values output, handed back on
    # every later call so that each step only has to pass the newest token.
    result = model(torch.tensor(tokens, device=device)[None], past_key_values=None)
    next_logits, past = result['logits'][0, -1, :], result['past_key_values']
    rest_nostarting_tokens = max_nostarting_token

    while letters_done < len(target):
        letter_col = letter_index(target[letters_done])
        # While the continuation budget lasts, both word continuations and matching word starts
        # are acceptable; once it is spent, only a token starting the required word is.
        active_mask = mask_allowed if rest_nostarting_tokens else mask_satisfied
        next_logits[~active_mask[:, letter_col]] = -np.inf
        next_probas = torch.softmax(next_logits / temperature, dim=-1).cpu()

        sorted_p, sorted_ix = torch.sort(next_probas, descending=True)
        cumulative_p = torch.cumsum(sorted_p, dim=-1)

        # Number of possible choices for next token, calculated as minimal n
        #  such that sum of probabilities of the first n tokens exceeds p_threshold.
        # np.argmax returns 0 for an all-False array, so when the threshold is never exceeded
        #  (p_threshold >= 1 or rounding) keep every candidate instead of only the top one.
        exceeds_threshold = cumulative_p.numpy() > p_threshold
        if exceeds_threshold.any():
            n_tokens_next = int(np.argmax(exceeds_threshold)) + 1
        else:
            n_tokens_next = len(sorted_p)

        # Renormalise the kept probabilities so that they sum to one.
        nucleus_p = sorted_p[:n_tokens_next] / cumulative_p[n_tokens_next-1]
        ix_ix = np.random.choice(n_tokens_next, p=nucleus_p.numpy())
        next_ix = sorted_ix[ix_ix]
        tokens.append(next_ix.item())
        if mask_satisfied[next_ix, letter_col]:
            # The sampled token starts the required word: move on to the next letter.
            letters_done += 1
            rest_nostarting_tokens = max_nostarting_token
        else:
            rest_nostarting_tokens -= 1

        if letters_done == len(target):
            # All required words have been started; another forward pass would be wasted.
            break

        result = model(next_ix[None].to(device), past_key_values=past)
        next_logits, past = result['logits'][0, :], result['past_key_values']

In [11]:
# Decode the prompt together with all sampled token ids back into text.
print(tokenizer.decode(tokens))

 The best possible example to demonstrate power of the project is the research/production efforts.

6. Established.

9 Basic, passive, or real-time.

 The
