# Can we accelerate transformer inference?

Let's look at machine translation as an example.

In [1]:
from transformers import MarianMTModel, MarianTokenizer
import torch



In [2]:
# Pretrained English -> German Marian checkpoint; tokenizer and model must come from the same one.
checkpoint = "Helsinki-NLP/opus-mt-en-de"

tokenizer = MarianTokenizer.from_pretrained(checkpoint)
model = MarianMTModel.from_pretrained(checkpoint)

# Input-embedding module; its weight matrix is used further down to turn decoder outputs into logits.
embeddings = model.get_input_embeddings()



To do machine translation, we first tokenize the input sentence and then pass the resulting token ids to the model's `generate` method.

In [3]:
# tokenize the English source sentence into a PyTorch tensor of token ids for the encoder
input_ids = tokenizer("I want to buy a car", return_tensors="pt")["input_ids"]

In [4]:
%%time
# Baseline: a single generate() call produces the output token ids; %%time reports how long it takes.
outputs = model.generate(input_ids)

CPU times: user 1.02 s, sys: 0 ns, total: 1.02 s
Wall time: 104 ms




The outputs are token ids, which we can then decode back into text to get the translation.

In [5]:
# Display the raw token ids returned by model.generate
outputs

tensor([[58100,   105,    73,    53,  1193,  1564,     0]])

In [6]:
# Turn every generated sequence of token ids back into a string
tokenizer.batch_decode(outputs)

['<pad> Ich will ein Auto kaufen']

### We can do it in steps rather than have .generate do the work for us

In [7]:
# Run only the encoder over the source token ids and keep its final hidden states
encoder_outputs = model.base_model.encoder(input_ids, return_dict=True)
encoder_output_vectors = encoder_outputs.last_hidden_state

In [8]:
# token ids of the partial translation that is fed to the decoder
# (add_special_tokens=False: the tokenizer adds no special tokens of its own here)
decoder_input_ids = tokenizer("<pad> Ich will ein", return_tensors="pt", add_special_tokens=False)["input_ids"]

# run the decoder over the partial translation, conditioned on the encoder output vectors
decoder_outputs = model.base_model.decoder(decoder_input_ids, encoder_hidden_states=encoder_output_vectors)
decoder_output_vectors = decoder_outputs.last_hidden_state

# turn decoder outputs into logits: a linear layer whose weights are the
# input-embedding matrix and whose bias is the model's final_logits_bias
output_weights = embeddings.weight
lm_logits = torch.nn.functional.linear(decoder_output_vectors, output_weights, bias=model.final_logits_bias)
lm_logits

tensor([[[-0.9553, -3.2332,  0.6456,  ..., -3.2678, -3.2706,  0.0000],
         [ 1.5956, -4.2294,  2.6527,  ..., -4.2864, -4.3241,  0.0000],
         [ 1.8177, -3.1177,  5.1491,  ..., -3.1209, -3.1322,  0.0000],
         [ 1.1660, -3.1043,  5.2854,  ..., -3.1397, -3.1413,  0.0000],
         [ 1.4587, -3.5582,  1.1881,  ..., -3.6387, -3.6602,  0.0000]]],
       grad_fn=<ViewBackward0>)

In [9]:
# index of the largest logit at every decoder position
lm_logits.argmax(dim=-1)

tensor([[ 105,   73,   73,   53, 1193]])

In [10]:
# greedy choice: keep only the last position and take the index of its largest logit
last_position_logits = lm_logits[:, -1:]
next_decoder_input_ids = last_position_logits.argmax(dim=-1)
next_decoder_input_ids

tensor([[1193]])

In [11]:
# reduce over the last dimension and take the first entry to get the chosen id as a single value
next_id = next_decoder_input_ids.max(dim=-1).values[0]
next_id

tensor(1193)

In [12]:
# Convert the chosen token id back into text
tokenizer.decode(next_id)

'Auto'

### Doing it in a loop - but caching the encoder outputs

In [13]:
# English source sentence and its whitespace-separated words
sentence = "I want to buy a car"
inputsplit = sentence.split()

# The translation string starts from the <pad> token and is extended by the decoding loop
translated = "<pad>"

In [14]:
# tokenize the starting translation string into the decoder's initial token ids
# (add_special_tokens=False: the tokenizer adds no special tokens of its own here)
decoder_input_ids = tokenizer(translated, return_tensors="pt", add_special_tokens=False)["input_ids"]

In [15]:
# tokenize the source sentence into a PyTorch tensor of token ids for the encoder
input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"]

In [16]:
%%time

# pass input token ids to encoder and hold on to them to be reused
encoder_output_vectors = model.base_model.encoder(input_ids, return_dict=True).last_hidden_state

# Now try it in a loop
for token in inputsplit:
    
    # pass decoder input ids and encoded input vectors to decoder
    decoder_output_vectors = model.base_model.decoder(decoder_input_ids, encoder_hidden_states=encoder_output_vectors).last_hidden_state

    # derive embeddings by multiplying decoder outputs with embedding weights
    lm_logits = torch.nn.functional.linear(decoder_output_vectors, embeddings.weight, bias=model.final_logits_bias)
    
    # sample last token with highest prob
    next_decoder_input_ids = torch.argmax(lm_logits[:, -1:], axis=-1)
    
    # concat
    decoder_input_ids = torch.cat([decoder_input_ids, next_decoder_input_ids], axis=-1)
    
    # get the word to add to our string
    next_id = torch.max(next_decoder_input_ids, dim=-1).values[0]
    
    #print (next_id)
    #print (tokenizer.decode(next_id))
    
    translated = translated + " " + tokenizer.decode(next_id)
    #print(translated)

CPU times: user 694 ms, sys: 0 ns, total: 694 ms
Wall time: 71.2 ms


In [17]:
# Show the translation string assembled by the decoding loop
translated

'<pad> Ich will ein Auto kaufen </s>'