In [44]:
import torch
from transformers import MarianMTModel, MarianTokenizer
import string
import pandas as pd
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
en_ROMANCE_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
en_ROMANCE_tokenizer = MarianTokenizer.from_pretrained(en_ROMANCE_model_name)
en_ROMANCE = MarianMTModel.from_pretrained(en_ROMANCE_model_name).to(device)

In [3]:
ROMANCE_en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
ROMANCE_en_tokenizer = MarianTokenizer.from_pretrained(ROMANCE_en_model_name)
ROMANCE_en = MarianMTModel.from_pretrained(ROMANCE_en_model_name).to(device)

In [4]:
device

device(type='cuda')

In [7]:
def monkey_patch(model, new_postproc_fn):
    cls = model.__class__
    print(cls)
    func_name = "postprocess_next_token_scores"
    orig_name = "_orig_" + func_name
    if not hasattr(cls, orig_name):
        print(str(cls) + " doesn't have attribute " + orig_name)
        setattr(cls, orig_name, getattr(cls, func_name))
    else:
        print(str(cls) + " has attribute " + orig_name)
    setattr(cls, func_name, new_postproc_fn)
    print(str(cls) + '\n' + func_name + '\n' + str(new_postproc_fn))

In [55]:
def postprocess_next_token_scores(self, scores, input_ids, *a, **kw):
    #print(input_ids.shape, scores.shape)
    batch_size, vocab_size = scores.shape
    cur_len = input_ids.shape[1]
    for hypothesis_idx in range(batch_size):
        cur_hypothesis = input_ids[hypothesis_idx]
        #print(ROMANCE_en_tokenizer.convert_ids_to_tokens(cur_hypothesis))

    # Hack the beam
    if not original_postprocess and cur_len == 2:
        force_token_id = selected_token
        #print(scores[:, force_token_id])
        self._force_token_ids_generation(scores, token_ids=[force_token_id])

    #print(scores[:, self.config.eos_token_id])
    return self._orig_postprocess_next_token_scores(scores, input_ids, *a, **kw)

monkey_patch(ROMANCE_en, postprocess_next_token_scores)

<class 'transformers.modeling_marian.MarianMTModel'>
<class 'transformers.modeling_marian.MarianMTModel'> has attribute _orig_postprocess_next_token_scores
<class 'transformers.modeling_marian.MarianMTModel'>
postprocess_next_token_scores
<function postprocess_next_token_scores at 0x7faa69cc1950>


In [56]:
def score_prefix(machine_translation, prefix):
    tokenizer = ROMANCE_en_tokenizer
    model = ROMANCE_en
    tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(prefix.strip()))
    prefix = torch.LongTensor(tokenized_prefix).to(device)

    batch = tokenizer.prepare_translation_batch([machine_translation.replace("<pad> ", '')]).to(device)
    english_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    # pylint: disable=E1101
    partial_decode = torch.LongTensor([decoder_start_token]).to(device).unsqueeze(0)
    past = (english_encoded, None)
    # pylint: enable=E1101
    num_tokens_generated = 0
    total = 0
    MAX_LENGTH = 100
    
    #stop when </s> token generated, or max num tokens exceded (just in case)
    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]
        #start with user inputted beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            next_token_to_add = next_token_logits[0].argmax()
        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        token_score = next_token_logprobs[0][next_token_to_add].item()
#         print(token_score)
        total += token_score

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)

    return (score, final.lstrip())

In [57]:
score_prefix(machine_translation, "â a George gave the cat a piece of chicken Â Â Â Â Â Â Â Â Â Â ¢ Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â")

(-1.041,
 'â a George gave the cat a piece of chicken Â Â Â Â Â Â Â Â Â Â ¢ Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â')

In [58]:
machine_translation

'Mientras que algunos de ustedes ya han regresado este verano, la mayoría de los estudiantes que llegan a clases en persona llegará el próximo mes.'

In [67]:
def translate(tokenizer, model, text, num_outputs):   
    """Use beam search to get a reasonable translation of 'text'"""
    # Tokenize the source text
    tokenizer.current_spm = tokenizer.spm_source ### HACK!
    batch = tokenizer.prepare_translation_batch([text]).to(model.device)
    
    # Run model
    num_beams = num_outputs
    translated = model.generate(**batch, num_beams=num_beams, num_return_sequences=num_outputs, max_length=40, no_repeat_ngram_size=3)
    
    # Untokenize the output text.
    tokenizer.current_spm = tokenizer.spm_target
    return [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=False) for t in translated]

original_postprocess = True;
input_str = "Taylor reportedly consumed copious amounts of raw fruit and iced milk while attending holiday celebrations during a fundraising event at the Washington Monument."
english = ">>es<<" + input_str
engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english]).to(device)
eng_to_spanish = en_ROMANCE.generate(**engbatch).to(device)
machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0]).replace("<pad> ", '')

results = []
for word in input_str.split(' ')[3:]:
    selection = word
    ROMANCE_en_tokenizer.current_spm = ROMANCE_en_tokenizer.spm_target
    tokens = ROMANCE_en_tokenizer.tokenize(selection)
    selected_token = ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens)[0]
    # list(zip(ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens), tokens))

    original_postprocess = False
    top50 = translate(ROMANCE_en_tokenizer, ROMANCE_en, ">>en<<" + machine_translation, 100)
    for element in top50[0:1]:
        results.append(score_prefix(machine_translation, element))
        
all_sorted = sorted(((score, result) for score, result in results), reverse=True)


In [68]:
results = pd.DataFrame({'sentence': [pair[1] for pair in all_sorted],
              'probability': [pair[0] for pair in all_sorted]}).style.hide_index()
# df = df.style.set_properties(**{'text-align': 'left'})
pd.set_option('display.max_colwidth', None)
results

sentence,probability
"As a result, Taylor would have consumed abundant amounts of raw fruit and frozen milk while attending the parties during a fund-raising event at the Washington Monument.",-0.698
"While attending the parties during a fund-raising event at the Washington Monument, Taylor would have consumed abundant amounts of raw fruit and frozen milk.",-0.699
"In the event, Taylor would have consumed abundant amounts of raw fruit and frozen milk while attending the parties during a fund-raising event at the Washington Monument.",-0.722
Taylor ate abundant amounts of raw fruit and frozen milk while attending the parties during a fund-raising event at the Washington Monument.,-0.728
Many amounts of raw fruit and frozen milk would have been consumed by Taylor while attending the parties during a fund-raising event at the Washington Monument.,-0.731
Taylor of course would have consumed abundant amounts of raw fruit and frozen milk while attending the parties during a fund-raising event at the Washington Monument.,-0.733
A copious amount of raw fruit and frozen milk would have been consumed by Taylor while attending the parties during a fund-raising event at the Washington Monument.,-0.796
Taylor while attending the parties during a fund-raising event at the Washington Monument would have consumed abundant amounts of raw fruit and frozen milk.,-0.8
Taylor milk was said to have consumed abundant amounts of raw fruit and frozen milk while attending the parties during a fund-raising event at the Washington Monument.,-0.812
Taylor and he would have consumed abundant amounts of raw fruit and frozen milk while attending the parties during a fund-raising event at the Washington Monument.,-0.823


In [None]:
ROMANCE_en._orig_postprocess_next_token_scores??

In [18]:
ROMANCE_en._generate_beam_search??

[0;31mSignature:[0m
[0mROMANCE_en[0m[0;34m.[0m[0m_generate_beam_search[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput_ids[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcur_len[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_length[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmin_length[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdo_sample[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mearly_stopping[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtemperature[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtop_k[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtop_p[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mrepetition_penalty[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mno_repeat_ngram_size[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbad_words_ids[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpad_token_id[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meos_token_id[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_size[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_re

In [17]:
ROMANCE_en.__class__??

[0;31mInit signature:[0m [0mROMANCE_en[0m[0;34m.[0m[0m__class__[0m[0;34m([0m[0mconfig[0m[0;34m:[0m [0mtransformers[0m[0;34m.[0m[0mconfiguration_bart[0m[0;34m.[0m[0mBartConfig[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m        
[0;32mclass[0m [0mMarianMTModel[0m[0;34m([0m[0mBartForConditionalGeneration[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34mr"""[0m
[0;34m    Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.[0m
[0;34m    Model API is identical to BartForConditionalGeneration.[0m
[0;34m    Available models are listed at `Model List <https://huggingface.co/models?search=Helsinki-NLP>`__[0m
[0;34m[0m
[0;34m    Examples::[0m
[0;34m[0m
[0;34m        >>> from transformers import MarianTokenizer, MarianMTModel[0m
[0;34m        >>> from typing import List[0m
[0;34m        >>> src = 'fr'  # source language[0m
[0;34m        >>> trg = 'en'  # target language