In [37]:
import torch
import wordfreq
import math
from transformers import XLNetTokenizer, XLNetLMHeadModel

In [3]:
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [162]:
origStr = "I mean, when you go to a movie and it’s set to start at a certain time, would you not be upset if 7 hours later said movie has not started?"
# "<mask> <mask> <mask> <mask> <mask> <mask> a <mask> and <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask> <mask>?"
testStr = "I mean, when you go to a <mask> and it’s set to start at a certain time, would you not be upset if 7 hours later said movie has not started?"
encoded_str = tokenizer.encode(testStr)
tokens_tensor = torch.tensor([encoded_str])
tokenizer.convert_ids_to_tokens(encoded_str)
perm_mask = torch.zeros((1, tokens_tensor.shape[1], tokens_tensor.shape[1]), dtype=torch.float)
perm_mask[:, :, 8] = 1.0  # Previous tokens don't see masked token
target_mapping = torch.zeros((1, 1, tokens_tensor.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, 8] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)


In [163]:
with torch.no_grad():
    outputs = model(tokens_tensor, perm_mask=perm_mask, target_mapping=target_mapping)
    next_token_logits = outputs[0][0, 0, :]

In [164]:
print(outputs[0].shape)

torch.Size([1, 1, 32000])


In [165]:
print([tokenizer.convert_ids_to_tokens(index.item()) for index in next_token_logits.topk(10).indices])

['▁', '▁a', '▁an', '▁man', '▁on', '▁first', '▁go', '▁me', '.', '▁to']


In [41]:
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
START_INDEX = 166 # TODO: change hard-coded value

In [50]:
def computeLogProb(original_text, index, tokens_tensor, perm_mask, target_mapping):        
    with torch.no_grad():
        outputs = model(tokens_tensor, perm_mask=perm_mask, target_mapping=target_mapping)
        next_token_logits = outputs[0][0, 0, :]

    preds = [tokenizer.convert_ids_to_tokens(index.item()) for index in next_token_logits.topk(5).indices]
    next_token_logprobs = next_token_logits - next_token_logits.logsumexp(0)
    logProb = next_token_logprobs[tokenizer.convert_tokens_to_ids(original_text[index])].item()

    return (preds, logProb, next_token_logprobs)

def computePredsLogProbs(preds, next_token_logprobs):
    predLogProbs = []
    for i in preds:
        predLogProbs.append(next_token_logprobs[tokenizer.convert_tokens_to_ids(i)].item())
    return predLogProbs

def bigContext(tokenized_text, index):
    encoded_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([encoded_ids])
    perm_mask = torch.zeros((1, tokens_tensor.shape[1], tokens_tensor.shape[1]), dtype=torch.float)
    perm_mask[:, :, index] = 1.0
    target_mapping = torch.zeros((1, 1, tokens_tensor.shape[1]), dtype=torch.float)
    target_mapping[0, 0, index] = 1.0
    return computeLogProb(tokenized_text, index, tokens_tensor, perm_mask, target_mapping)

def smallContext(tokenized_text, index):
    tokens_tensor = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_text)])
    perm_mask = torch.zeros((1, tokens_tensor.shape[1], tokens_tensor.shape[1]), dtype=torch.float)
    for i in range(START_INDEX, len(tokenized_text) - 1):
        if i != index - 1 and i != index + 1:
            perm_mask[:, :, i] = 1.0
    target_mapping = torch.zeros((1, 1, tokens_tensor.shape[1]), dtype=torch.float)
    target_mapping[0, 0, index] = 1.0
    return computeLogProb(tokenized_text, index, tokens_tensor, perm_mask, target_mapping)

def noContext(word):
    if word in '.?,:!;\'\"‘’“”|-/\\':
        return -1 # FIXME
    freq = wordfreq.word_frequency(word, 'en')
    if freq == 0:
        print("word not found:", word)
        return -100
    return math.log(freq)

def compute_scores(input_text):
    tokenized_text = tokenizer.tokenize(PADDING_TEXT + " " + input_text + "</s>", add_special_tokens=False, return_tensors='pt')

    usedModels = ["bigContext", "smallContext", "noContext"]
    results = []
    compoundBigPreds = []
    compoundSmallPreds = []
    compoundBigLogProb = 0
    compoundSmallLogProb = 0
    currentWord = ""
    startID = 0

    # For each token not in PADDING_TEXT
    for i in range(START_INDEX, len(tokenized_text) - 1):
        # Compute the top 5 model predictions, the log probability of the
        #   correct answer, and the next_token_logprobs
        bigPreds, bigLogProb, bigNextLogProbs = bigContext(tokenized_text, i)
        smallPreds, smallLogProb, smallNextLogProbs = smallContext(tokenized_text, i)

        # Generate the log probabilities of the top 5 small model predictions
        #   given big context vs. small context
        bigPredsLogProbs = computePredsLogProbs(smallPreds, bigNextLogProbs)
        smallPredsLogProbs = computePredsLogProbs(smallPreds, smallNextLogProbs)

        # if the current token is a start token
        if tokenized_text[i].startswith("▁"):
            compoundBigLogProb = bigLogProb
            compoundSmallLogProb = smallLogProb
            compoundBigPreds = bigPreds
            compoundSmallPreds = smallPreds
            currentWord = tokenized_text[i]
            startID = i
        # If the current token is a continuation token
        else:
            compoundBigLogProb += bigLogProb
            compoundSmallLogProb += smallLogProb
            currentWord += tokenized_text[i]
        
        # if the next token is not a start token or the end of sequence, don't do any more work
        #   because that means the next token is a continuation token
        if not (tokenized_text[i + 1].startswith("▁") or tokenized_text[i + 1] == "</s>"):
            continue
            
        currentWord = currentWord.replace("▁", "")
        
        # Compute the no-context log probabilities of the current word and
        #   the predictions generated by the small context model
        noContextLogProb = noContext(currentWord)
        noPredsLogProbs = []
        for j in smallPreds:
            processed_word = j.replace("▁", "")
            noPredsLogProbs.append(noContext(processed_word))

        results.append(dict(
            id = startID,
            word=currentWord,
            src="original",
            model="smallContext",
            score=compoundSmallLogProb)
        )
        
        results.append(dict(
            id = startID,
            word=currentWord,
            src="original",
            model="bigContext",
            score=compoundBigLogProb)
        )

        results.append(dict(
            id = startID,
            word=currentWord,
            src="original",
            model="noContext",
            score=noContextLogProb)
        )

        for j in range(0, len(smallPreds)):
            results.append(dict(
                id = startID,
                word=smallPreds[j],
                src="smallContext",
                model="smallContext",
                score=smallPredsLogProbs[j])
            )

            results.append(dict(
                id = startID,
                word=smallPreds[j],
                src="smallContext",
                model="bigContext",
                score=bigPredsLogProbs[j])
            )
            
            results.append(dict(
                id = startID,
                word=smallPreds[j],
                src="smallContext",
                model="noContext",
                score=noPredsLogProbs[j])
            )
        
        compoundBigLogProb = 0
        compoundSmallLogProb = 0
        compoundBigPreds = []
        compoundSmallPreds = []
        currentWord = ""

    return (results, usedModels)

In [51]:
compute_scores("Hello. This is a test.")

([{'id': 0,
   'word': 'Hello.',
   'src': 'original',
   'model': 'smallContext',
   'score': -1.9449241161346436},
  {'id': 0,
   'word': 'Hello.',
   'src': 'original',
   'model': 'bigContext',
   'score': -16.21267080307007},
  {'id': 0,
   'word': 'Hello.',
   'src': 'original',
   'model': 'noContext',
   'score': -9.877819805787551},
  {'id': 0,
   'word': '.',
   'src': 'smallContext',
   'model': 'smallContext',
   'score': -0.25174784660339355},
  {'id': 0,
   'word': '.',
   'src': 'smallContext',
   'model': 'bigContext',
   'score': -2.8212952613830566},
  {'id': 0,
   'word': '.',
   'src': 'smallContext',
   'model': 'noContext',
   'score': -1},
  {'id': 0,
   'word': '?',
   'src': 'smallContext',
   'model': 'smallContext',
   'score': -2.0876107215881348},
  {'id': 0,
   'word': '?',
   'src': 'smallContext',
   'model': 'bigContext',
   'score': -1.3548493385314941},
  {'id': 0,
   'word': '?',
   'src': 'smallContext',
   'model': 'noContext',
   'score': -1},
  {