In [1]:
import torch
import transformers

device = 'cpu'
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from transformers import MarianMTModel, MarianTokenizer
en_ROMANCE_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
en_ROMANCE_tokenizer = MarianTokenizer.from_pretrained(en_ROMANCE_model_name)
', '.join(en_ROMANCE_tokenizer.supported_language_codes)

'>>fr<<, >>es<<, >>it<<, >>pt<<, >>pt_br<<, >>ro<<, >>ca<<, >>gl<<, >>pt_BR<<, >>la<<, >>wa<<, >>fur<<, >>oc<<, >>fr_CA<<, >>sc<<, >>es_ES<<, >>es_MX<<, >>es_AR<<, >>es_PR<<, >>es_UY<<, >>es_CL<<, >>es_CO<<, >>es_CR<<, >>es_GT<<, >>es_HN<<, >>es_NI<<, >>es_PA<<, >>es_PE<<, >>es_VE<<, >>es_DO<<, >>es_EC<<, >>es_SV<<, >>an<<, >>pt_PT<<, >>frp<<, >>lad<<, >>vec<<, >>fr_FR<<, >>co<<, >>it_IT<<, >>lld<<, >>lij<<, >>lmo<<, >>nap<<, >>rm<<, >>scn<<, >>mwl<<'

In [3]:
en_ROMANCE = MarianMTModel.from_pretrained(en_ROMANCE_model_name).to(device)

In [4]:
ROMANCE_en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
ROMANCE_en_tokenizer = MarianTokenizer.from_pretrained(ROMANCE_en_model_name)

In [5]:
ROMANCE_en = MarianMTModel.from_pretrained(ROMANCE_en_model_name).to(device)

# Batch translation

In [7]:
def translate(tokenizer, model, text, num_outputs):
    """Use beam search to get a reasonable translation of 'text'"""
    batch = tokenizer.prepare_translation_batch([text]).to(model.device)
    num_beams = num_outputs
    translated = model.generate(**batch, num_beams=num_beams, num_return_sequences=num_outputs)
    return [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=False) for t in translated]

#translate(en_ROMANCE_tokenizer, en_ROMANCE, ">>es<< I ran to the store.", 5)
    

# Incremental translation

English to token ids

In [8]:
tokenizer = ROMANCE_en_tokenizer
model = ROMANCE_en

In [13]:
english = "<pad> I ran to the store."
input_ids = tokenizer.encode(english, return_tensors="pt").to(device)

English token ids to "concept space"

In [14]:
batch = tokenizer.prepare_translation_batch([english]).to(device)
with torch.no_grad():
    english_encoded = model.get_encoder()(**batch)

In [15]:
english_encoded.last_hidden_state.shape

torch.Size([1, 9, 512])

What we have so far, as token ids:

In [16]:
decoder_start_token = model.config.decoder_start_token_id
decoder_start_token

65000

In [17]:
partial_decode = torch.LongTensor([decoder_start_token]).to(device).unsqueeze(0)

Ask the model for what comes next

In [20]:
# one-time setup
past = (english_encoded, None)

In [21]:
input_ids

tensor([[65000,    20,  7350,    11,     4,   106,  2333,     3,     0]])

In [24]:
model_inputs = model.prepare_inputs_for_generation(
    partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
)
with torch.no_grad():
    model_outputs = model(**model_inputs)

next_token_logits = model_outputs[0][:, -1, :]
past = model_outputs[1]

In [27]:
model_inputs

{'input_ids': None,
 'encoder_outputs': BaseModelOutput(last_hidden_state=tensor([[[-0.7051, -0.2007, -0.1604,  ..., -0.1901,  0.2204,  0.5746],
          [-0.2789, -0.1520, -0.1699,  ...,  0.1892, -0.1320,  0.4475],
          [-0.2646, -0.0100, -0.0025,  ..., -0.2152, -0.1248,  0.5877],
          ...,
          [-0.0699, -0.1235, -0.0587,  ..., -0.6416, -0.1384, -0.2677],
          [-0.5116, -0.1258,  0.5837,  ...,  0.2761, -0.0400,  0.0167],
          [-0.0253, -0.0296, -0.0230,  ..., -0.1292, -0.0583,  0.0904]]]), hidden_states=None, attentions=None),
 'decoder_past_key_values': None,
 'decoder_input_ids': tensor([[65000]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'use_cache': True}

In [33]:
type(model_outputs)

transformers.modeling_outputs.Seq2SeqLMOutput

In [37]:
?model.generate

[0;31mSignature:[0m
[0mmodel[0m[0;34m.[0m[0mgenerate[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput_ids[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mLongTensor[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_length[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mint[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmin_length[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mint[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdo_sample[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mbool[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mearly_stopping[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mbool[0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[

In [39]:
??model.__call__

[0;31mSignature:[0m [0mmodel[0m[0;34m.[0m[0m__call__[0m[0;34m([0m[0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m Call self as a function.
[0;31mSource:[0m   
    [0;32mdef[0m [0m__call__[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;32mfor[0m [0mhook[0m [0;32min[0m [0mself[0m[0;34m.[0m[0m_forward_pre_hooks[0m[0;34m.[0m[0mvalues[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m            [0mresult[0m [0;34m=[0m [0mhook[0m[0;34m([0m[0mself[0m[0;34m,[0m [0minput[0m[0;34m)[0m[0;34m[0m
[0;34m[0m            [0;32mif[0m [0mresult[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m
[0;34m[0m                [0;32mif[0m [0;32mnot[0m [0misinstance[0m[0;34m([0m[0mresult[0m[0;34m,[0m [0mtuple[0m[0;34m)[0

In [38]:
??model._generate_beam_search

[0;31mSignature:[0m
[0mmodel[0m[0;34m.[0m[0m_generate_beam_search[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput_ids[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcur_len[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_length[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmin_length[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdo_sample[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mearly_stopping[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtemperature[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtop_k[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtop_p[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mrepetition_penalty[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mno_repeat_ngram_size[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbad_words_ids[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpad_token_id[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meos_token_id[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_size[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_return_

In [31]:
class WrapperModel(MarianMTModel):
    def __init__(self, wrapped_models):
        super().__init__()
        self.wrapped_models = wrapped_models

    def forward(self, **model_inputs):
        logits = []
        for model in self.wrapped_models:
            wrapped_model_output = self.wrapped_model(**model_inputs)
            logits.append(wrapped_model_output.logits)

        return OutputClass(
            logits = torch.mean(torch.Tensor(logits), dim=0))
    
    def prepare_inputs_for_generation(self, encoders_outputs, *a, **kw):
        return dict(
            inputs_list=[
                wrapped_model.prepare_inputs_for_generation(*a, **kw)
                for wrapped_model in self.wrapped_models])
    

my_model = WrapperModel(model)

model_inputs = my_model.prepare_inputs_for_generation(
    partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
)
with torch.no_grad():
    model_outputs = my_model(**model_inputs)

next_token_logits = model_outputs[0][:, -1, :]
past = model_outputs[1]

In [32]:
model_outputs[0]

tensor([[[ 0.4750, -6.3051, -2.4504,  ..., -6.2690, -6.2684,  0.0000]]])

Find the most likely

In [25]:
next_token_to_add = next_token_logits[0].argmax()

In [26]:
next_token_logits

tensor([[ 0.4750, -6.3051, -2.4504,  ..., -6.2690, -6.2684,  0.0000]])

In [64]:
en_ROMANCE_tokenizer.convert_ids_to_tokens(next_token_logits.topk(10).indices[0])

['Ê', '▁œuvre', '▁We', '▁qui', '▁Nouvelle', 'END', '▁et', '▁•', '▁Q', '▁How']

tensor([[ 25, 215, 122, 195, 100, 659,  67, 671, 832, 167]])

Notice that this is a one-item tensor (zero-dim)... so it can't concat with anything:

In [49]:
next_token_to_add.shape

torch.Size([])

But this is the kind of shape that it needs to have:

In [50]:
partial_decode.shape

torch.Size([1, 1])

So here's how to give it that extra dimension:

In [51]:
next_token_to_add.unsqueeze(0).unsqueeze(0)

tensor([[10509]], device='cuda:0')

Ok, those shapes align, so we can concatenate them.

In [52]:
partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)

In [53]:
partial_decode

tensor([[65000, 10509]], device='cuda:0')

Now we have a new output, with one additional token:

In [54]:
tokenizer.convert_ids_to_tokens(partial_decode[0])

['<pad>', '▁Corr']

We can now do all that again, to ask for the next thing after that token.

In [55]:
model_inputs = model.prepare_inputs_for_generation(
    partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
)
with torch.no_grad():
    model_outputs = model(**model_inputs)

next_token_logits = model_outputs[0][:, -1, :]
past = model_outputs[1]