References:
* Fine-tuning XLNet for generation tasks: https://nonint.com/2020/03/27/fine-tuning-xlnet-for-generation-tasks/ (though I suspect the approach is fundamentally flawed...)
* XLNet model docs: https://huggingface.co/transformers/model_doc/xlnet.html
* Encoder-decoder architecture blog: https://medium.com/huggingface/encoder-decoders-in-transformers-a-hybrid-pre-trained-architecture-for-seq2seq-af4d7bf14bb8
* Training/fine-tuning XLNet (PR #5522): https://github.com/huggingface/transformers/pull/5522/files
* XLNet source (modeling_xlnet.py): https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_xlnet.py

Distantly related:
* T5: https://arxiv.org/abs/1910.10683
* MASS: https://www.microsoft.com/en-us/research/blog/introducing-mass-a-pre-training-method-that-outperforms-bert-and-gpt-in-sequence-to-sequence-language-generation-tasks/


# Load libraries and models

In [8]:
import torch
from transformers import XLNetTokenizer, XLNetConfig, XLNetLMHeadModel
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [9]:
model_name = "xlnet-base-cased"

In [10]:
tokenizer = XLNetTokenizer.from_pretrained(model_name)
model = XLNetLMHeadModel.from_pretrained(model_name)

In [11]:
model.to(device);

XLNet reportedly needs "padding text" prepended to short inputs for its predictions to make sense (see the docs)... TODO: look into why.

In [12]:
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

# Set up prediction task

In [13]:
text = "Hello, my dog is very cute"
# Index from the end so the position doesn't depend on the length of the (long) padding text.
tok_idx_from_end = 4

In [21]:
input_ids = tokenizer.encode(PADDING_TEXT + " " + text, add_special_tokens=False, return_tensors='pt').to(device)
n_tokens = input_ids.shape[1]
idx_to_predict = n_tokens - tok_idx_from_end
tokenizer.convert_ids_to_tokens(input_ids[0, idx_to_predict].item())

'▁dog'

Here's how we'd define `labels` to compute a loss. We won't actually use it right now, though.

In [22]:
labels = input_ids[:, idx_to_predict].unsqueeze(-1)  # shape [batch_size, num_predict]
labels.shape

torch.Size([1, 1])
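When `labels` is passed, my reading of `modeling_xlnet.py` is that `XLNetLMHeadModel` scores its logits against the labels with a standard cross-entropy, averaged over every target (check the version you have installed). That is why `labels` needs shape `[batch_size, num_predict]`, matching the first two dimensions of the logits. Below is a minimal pure-Python sketch of that computation; `cross_entropy` and `mean_target_loss` are hypothetical helper names, and the sketch ignores the `-100` ignore-index convention.

```python
import math

def cross_entropy(logits, label):
    """Return -log softmax(logits)[label], using log-sum-exp for numerical stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

def mean_target_loss(logits, labels):
    """Average cross-entropy over every (sample, target) pair.

    logits: nested lists shaped [batch_size][num_predict][vocab_size]
    labels: nested lists shaped [batch_size][num_predict] of token ids
    """
    losses = [
        cross_entropy(logits[b][k], labels[b][k])
        for b in range(len(labels))
        for k in range(len(labels[b]))
    ]
    return sum(losses) / len(losses)

# Toy example: one sample, one target, a 4-token vocabulary with uniform logits.
print(mean_target_loss([[[0.0, 0.0, 0.0, 0.0]]], [[2]]))  # log(4), about 1.386
```

With uniform logits every token is equally likely, so the loss is just the log of the vocabulary size.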

Create the "permutation mask": which tokens does each token get to look at when making the prediction?

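It has shape `batch_size` x `seq_length` x `seq_length`. Element `[b, i, j]` is `0` if token `i` may attend to (see) token `j` in batch sample `b`, and `1` if it may not. Setting a whole column `j` to `1` therefore hides token `j` from every position.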

In [23]:
perm_mask = torch.zeros((1, n_tokens, n_tokens), dtype=torch.float, device=device)
# Mask the token to predict.
perm_mask[:, :, idx_to_predict] = 1.0  # Column of 1s: no position may attend to the token we're predicting.
perm_mask[:, -5:, -5:]  # show it for the last 5 tokens

tensor([[[0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]]], device='cuda:0')
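To make the `[b, i, j]` convention concrete, here is a small pure-Python sketch of two ways to build one sample's `perm_mask`. `hide_positions_mask` generalizes the column-of-1s mask above to several blanks. `factorization_order_mask` is my reading of the XLNet paper's rule that gives the mask its name: a position is predicted only from positions that come strictly earlier in a chosen factorization order (a permutation of the positions). Both function names are hypothetical; wrap the result in an outer list and pass it to `torch.tensor(..., dtype=torch.float)` to get the `[batch_size, seq_length, seq_length]` tensor the model expects.

```python
def hide_positions_mask(seq_len, hidden):
    """perm_mask for one sample: entry [i][j] is 1.0 if token i may NOT attend to token j.

    Every position in `hidden` is hidden from all tokens (a whole column of 1s).
    """
    return [[1.0 if j in hidden else 0.0 for j in range(seq_len)] for _ in range(seq_len)]

def factorization_order_mask(order):
    """perm_mask for one sample from a factorization order (a permutation of positions).

    Token i may only attend to tokens strictly earlier than it in `order`,
    so entry [i][j] is 1.0 whenever rank[j] >= rank[i] (including j == i).
    """
    rank = {pos: r for r, pos in enumerate(order)}
    n = len(order)
    return [[1.0 if rank[j] >= rank[i] else 0.0 for j in range(n)] for i in range(n)]

# Same pattern as the 5x5 slice printed above (blank at column 1).
for row in hide_positions_mask(5, {1}):
    print(row)

# Left-to-right order: each token sees only the tokens before it.
for row in factorization_order_mask([0, 1, 2]):
    print(row)
```

The left-to-right order yields an upper-triangular mask (diagonal included), which is the familiar causal pattern; other permutations give XLNet's other factorization orders.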

Create the "target mapping", which has shape `batch_size` x `num_targets` x `seq_length`. Each target's row is one-hot: it has a single `1` at the position of the token to predict.

In [24]:
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)  # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, idx_to_predict] = 1.0  # Our first (and only) prediction
target_mapping

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
          0., 0., 0.]]], device='cuda:0')
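The single-row mapping above extends to several blanks by adding one one-hot row per target; the model should then return one row of logits per target, i.e. shape `[batch_size, num_predict, vocab_size]`. A minimal pure-Python sketch, where `build_target_mapping` is a hypothetical helper name (wrap the result in an outer list and convert with `torch.tensor(..., dtype=torch.float)` to add the batch dimension):

```python
def build_target_mapping(seq_len, targets):
    """target_mapping for one sample, shaped [num_predict][seq_len].

    Row k is one-hot: 1.0 at targets[k] (the position of the k-th token to
    predict) and 0.0 everywhere else.
    """
    return [[1.0 if j == pos else 0.0 for j in range(seq_len)] for pos in targets]

# Predict two blanks (positions 1 and 3) in a 5-token sequence.
for row in build_target_mapping(5, [1, 3]):
    print(row)
```

If you predict several blanks at once, you would presumably also hide all of them in `perm_mask`, so that no blank can peek at another blank's true token.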

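Now call the model. It returns the logits for each target position in `target_mapping`. It can also return other data (such as hidden states or attentions), but only if we ask for it, and we didn't.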

In [25]:
with torch.no_grad():
    outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
outputs

(tensor([[[ -2.7657, -11.9402, -11.9371,  ...,  -9.3074,  -9.3252, -10.9300]]],
        device='cuda:0'),)

In [26]:
next_token_logits = outputs[0]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

The top 30 predictions for filling in that blank:

In [27]:
tokenizer.convert_ids_to_tokens(next_token_logits[0, 0].topk(30).indices)

['▁name',
 '▁dog',
 '▁avatar',
 '▁cat',
 '▁son',
 '▁blog',
 '▁picture',
 '▁friend',
 '▁daughter',
 '▁boy',
 '▁wife',
 '▁baby',
 '▁brother',
 '▁puppy',
 '▁face',
 '▁girl',
 '▁husband',
 '▁hair',
 '▁website',
 '▁boyfriend',
 '▁sister',
 '▁nephew',
 '▁guy',
 '▁kitten',
 '▁character',
 '▁site',
 '▁child',
 '▁photo',
 '▁post',
 '▁girlfriend']
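`topk` on raw logits gives the ranking but says nothing about how confident the model is. Applying a softmax over the vocabulary turns the logits into probabilities; with the tensors above that would be something like `next_token_logits[0, 0].softmax(-1).topk(30)`, which returns both the probabilities and the token ids. Here is a pure-Python sketch of the same idea, with `top_k_probs` as a hypothetical helper name; the returned ids could then go to `tokenizer.convert_ids_to_tokens`.

```python
import heapq
import math

def top_k_probs(logits, k):
    """Return the k most likely (token_id, probability) pairs under softmax(logits)."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    # Softmax is monotonic, so ranking by logit equals ranking by probability.
    top = heapq.nlargest(k, range(len(logits)), key=lambda i: logits[i])
    return [(i, exps[i] / z) for i in top]

print(top_k_probs([1.0, 3.0, 2.0, -1.0], 2))
```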

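Because `perm_mask` hides the masked position from every token, the predictions can't see which token sits there. You can check for yourself that they don't change if you replace the "masked" token (here "dog") with a different word.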