In [1]:
import torch
import transformers

# Load the model

In [2]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# The model and its tokenizer must come from the same checkpoint, so name it once.
MODEL_NAME = "t5-base"
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Move the model to the target device (the GPU, when one is available) to make things faster.

In [3]:
# nn.Module.to returns the same module after moving its parameters; rebinding also keeps the cell output quiet.
model = model.to(device)

# Encode input

Convert the English sentence (with the translation task prefix) to token ids.

In [4]:
english = "How are you?"
# Prompt = task prefix followed by the sentence to translate.
prompt = f"translate English to German: {english}"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

Run the encoder to map the English token ids into "concept space" (the encoder's hidden states).

In [5]:
# Encode the English prompt once; the decoder cells below reuse this output.
# Inference only: no_grad avoids building and holding an autograd graph.
with torch.no_grad():
    english_encoded = model.encoder(input_ids=input_ids)

# Aside: try generating a translation...

In [6]:
# Full beam-search translation, for comparison with the step-by-step decoding below.
translation = model.generate(
    input_ids,
    num_beams=4,          # beam width
    max_length=40,        # upper bound on the generated length
    early_stopping=True,
)

In [7]:
# The generated ids start with a special <pad> token (see the token view below);
# skip_special_tokens keeps it out of the text regardless of tokenizer version.
tokenizer.decode(translation[0], skip_special_tokens=True)

'Wie sind Sie?'

In [8]:
# First row of the generate output: the translation as raw token ids.
translation[0]

tensor([   0, 2739,  436,  292,   58], device='cuda:0')

In [9]:
# Map each generated id back to its sentencepiece token.
tokenizer.convert_ids_to_tokens(translation[0].tolist())

['<pad>', '▁Wie', '▁sind', '▁Sie', '?']

# Back on track: start with just "Wie"

What we have so far, as token ids (the decoder's starting `<pad>` token followed by `▁Wie`):

In [10]:
# Decoding starts from <pad> (id 0 in the generate output above), followed by "▁Wie" (id 2739 above).
# Look the ids up instead of hard-coding them, and allocate directly on the target device.
start_token_id = tokenizer.pad_token_id
wie_token_id = tokenizer.convert_tokens_to_ids("▁Wie")
partial_decode = torch.tensor([start_token_id, wie_token_id], dtype=torch.long, device=device)

Ask the model what comes next.

In [11]:
# Score every vocabulary entry at each decoder position, given the encoded English
# ("concept space") and the German prefix decoded so far. Calling the module itself
# (not .forward) runs any registered hooks; element [0] of the output is the logits.
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    next_word_logits = model(
        encoder_outputs=english_encoded,  # "concept space"
        decoder_input_ids=partial_decode.unsqueeze(0),  # model expects (batch, seq length), so add a batch dim
    )[0]
next_word_logits.shape

torch.Size([1, 2, 32128])

Find the five most likely next tokens at the last decoder position.

In [12]:
# Vocabulary scores at the last decoder position of batch item 0.
last_position_logits = next_word_logits[0, -1]
next_indices = torch.topk(last_position_logits, k=5).indices
next_indices

tensor([ 436,    3,  229, 5674, 3015], device='cuda:0')

No, really the *most* likely...

In [15]:
# First entry of the top-k result (torch.topk sorts by descending score by default),
# i.e. the single most likely next token.
next_token_to_add = next_indices[0]
next_token_to_add

tensor(436, device='cuda:0')

Notice that this is a zero-dimensional (scalar) tensor, so it can't be concatenated with anything:

In [16]:
# Indexing a 1-D tensor with a single integer drops that dimension, leaving a 0-dim tensor.
next_token_to_add.shape

torch.Size([])

But this is the kind of shape that it needs to have:

In [14]:
# The prefix built so far is 1-D: one entry per token.
partial_decode.shape

torch.Size([2])

So here's how to give it that extra dimension:

In [17]:
# unsqueeze(0) inserts a new leading dimension, turning the 0-dim tensor into shape [1].
next_token_to_add.unsqueeze(0).shape

torch.Size([1])

OK, now both tensors are 1-D, so we can concatenate them.

In [18]:
# Append the chosen token to the prefix; both pieces are 1-D, so join them along dim 0.
# NOTE: this rebinds partial_decode, so re-running the cell appends the token a second time.
partial_decode = torch.cat([partial_decode, next_token_to_add.unsqueeze(0)], dim=0)

Now we have a new output, with one additional token:

In [19]:
# Show the extended prefix as sentencepiece tokens.
tokenizer.convert_ids_to_tokens(partial_decode.tolist())

['<pad>', '▁Wie', '▁sind']

We can now do all that again, to ask for the next thing after "sind"...

In [20]:
# Same query as before, now with the longer prefix. Call the module (not .forward) so
# hooks run; element [0] of the output is the logits. Inference only, so no autograd graph.
with torch.no_grad():
    next_word_logits = model(
        decoder_input_ids=partial_decode.unsqueeze(0),  # add the batch dimension
        encoder_outputs=english_encoded,
    )[0]
next_word_logits.shape

torch.Size([1, 3, 32128])