In [None]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 13.9 MB/s 
[?25hCollecting tokenizers!=0.11.3,>=0.10.1
  Downloading tokenizers-0.11.5-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.8 MB)
[K     |████████████████████████████████| 6.8 MB 11.3 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 40.8 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 3.1 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 18.5 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
 

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# Fetch the pretrained GPT-2 tokenizer and the matching language-model head.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Inference only: switch the model to eval mode. Binding the return value to `_`
# keeps the cell from echoing the module repr.
_ = model.eval()

In [None]:
# Width of the model's logit output. Read it from the loaded model's config instead of
# hard-coding GPT-2's 50257 so it cannot drift from the checkpoint actually in use.
VOCAB_SIZE = model.config.vocab_size

In [None]:
# Tokenize a sample sentence. `return_tensors="pt"` is indexed with [0] so that
# `sample_input` ends up as a 1-D tensor of token ids.
sample_input = tokenizer.encode(
  " But if you are preparing data and doing cat in each iteration, it gets really slow when the tensor you are generating gets very large. My solution was to cat into",
  return_tensors="pt",
)[0]
sample_input

tensor([  887,   611,   345,   389, 10629,  1366,   290,  1804,  3797,   287,
         1123, 24415,    11,   340,  3011,  1107,  3105,   618,   262, 11192,
          273,   345,   389, 15453,  3011,   845,  1588,    13,  2011,  4610,
          373,   284,  3797,   656])

In [None]:
def valid_encoding(shifted_input, encoded_msg, sorted_tokens):
  """Check that the ranks in `encoded_msg` reproduce the shifted message.

  Row t of `sorted_tokens` lists the candidate tokens for timestep t. Selecting
  column `encoded_msg[t]` from each row must give back `shifted_input[:-1]`.

  Returns a 0-dim boolean tensor that is True when every position matches.
  """
  num_steps = encoded_msg.shape[0]
  row_width = sorted_tokens.shape[1]

  # View the 2-D table as one long vector: entry (t, k) lives at t * row_width + k,
  # so adding each row's start offset to its rank gives the flat position to read.
  row_starts = torch.arange(0, row_width * num_steps, row_width)
  flat_positions = encoded_msg + row_starts
  candidates = sorted_tokens.flatten().index_select(0, flat_positions)

  # The final element of the shifted input is excluded from the comparison.
  expected = shifted_input[:-1]
  return (candidates == expected).all()
  

def encode(tokenized_msg, vocab_size):
  """Encode a token sequence as the model-rank of each successive token.

  For every prefix tokenized_msg[:i+1] the language model is run and the logits
  at the last position are kept. The vocabulary is then sorted by descending
  logit, and the position of the true next token in that ordering is recorded.

  Args:
    tokenized_msg: 1-D tensor of token ids to encode.
    vocab_size: number of entries in the model's logit output.

  Returns:
    A pair (encoded_msg, logits_arr). `encoded_msg` is a 1-D tensor holding one
    rank per token after the first (length len(tokenized_msg) - 1).
    `logits_arr` is the (len(tokenized_msg), vocab_size) tensor of collected
    logits, returned for debugging.

  Raises:
    AssertionError: if the ranks fail to reconstruct the message according to
      valid_encoding.
  """
  model.eval()
  with torch.no_grad():
    # In theory, the loop could be avoided because the transformer
    # automatically masks the input. But in practice, this causes the logit
    # outputs to differ slightly between the encoder and decoder
    msg_len = tokenized_msg.size()[0]
    logits_arr = torch.zeros(msg_len, vocab_size)
    for i in range(msg_len):
      msg_slice = tokenized_msg[:i+1]
      logits = model(msg_slice).logits
      logits_arr[i] = logits[i]

  # Sort the indices of the logits in descending order of logit value, so the
  # model's top predicted token comes first, the runner-up second, and so on.
  # The index of the ground-truth token in this ordering is its encoding.
  #
  # Shift the message being encoded (the function argument, not a notebook
  # global) left by one so that position i holds the token the model should
  # predict from prefix [:i+1].
  shifted_input = torch.roll(tokenized_msg, -1)
  _, sorted_tokens = torch.sort(logits_arr, dim=1, descending=True, stable=True)
  encoded_msg = (sorted_tokens == shifted_input.view(-1, 1)).nonzero()[:,1]
  # The roll wraps the first token around to the last slot; that entry has no
  # real "next token", so drop it.
  encoded_msg = encoded_msg[:-1]
  assert valid_encoding(shifted_input, encoded_msg, sorted_tokens)

  return encoded_msg, logits_arr # Logits for debugging

def decode(encoded_msg, first_token, vocab_size):
  """Rebuild a token sequence from its rank encoding.

  Starting from `first_token`, the model is run on the tokens decoded so far;
  the vocabulary is sorted by descending logit at the last position and the
  entry at rank encoded_msg[i] is appended as the next token.

  Args:
    encoded_msg: 1-D tensor of ranks, one per token to recover.
    first_token: 1-D tensor holding the single starting token id.
    vocab_size: number of entries in the model's logit output.

  Returns:
    A pair (decoded_msg, logits_arr). `decoded_msg` is the 1-D tensor of
    `first_token` followed by the recovered tokens. `logits_arr` is the
    (len(encoded_msg), vocab_size) tensor of logits used at each step,
    returned for debugging.
  """
  # encode() forces eval mode before computing logits; do the same here so the
  # decoder sees the same logits regardless of the mode the model was left in.
  model.eval()
  with torch.no_grad():
    msg_len = encoded_msg.size()[0]
    logits_arr = torch.zeros(msg_len, vocab_size) # For debugging
    decoded_msg = first_token
    for i in range(msg_len):
      logits = model(decoded_msg).logits
      logits_arr[i] = logits[i] # For debugging
      _, indices = torch.sort(logits[i], dim=0, descending=True, stable=True)
      # Slice (rather than index) so the picked token stays 1-D for torch.cat.
      decoded_token = indices[encoded_msg[i:i+1]]
      decoded_msg = torch.cat((decoded_msg, decoded_token))
  return decoded_msg, logits_arr # Logits for debugging

In [None]:
# Round-trip the sample message through the encoder and decoder.
original_msg = sample_input
encoded_msg, encoder_logits = encode(original_msg, VOCAB_SIZE)
decoded_msg, decoder_logits = decode(encoded_msg, original_msg[:1], VOCAB_SIZE)

# torch.equal checks shape as well as values, so a length mismatch fails the
# assertion instead of broadcasting or raising from the element-wise compare.
# Encoder and decoder logits should be identical (the encoder has one extra row
# for the final token, which the decoder never needs).
assert torch.equal(encoder_logits[:-1], decoder_logits), "encoder and decoder logits differ"
# Decoded message and original message should be identical
assert torch.equal(decoded_msg, original_msg), "decoded message differs from the original"
