In [1]:
import torch
from transformers import MarianMTModel, MarianTokenizer
import string
import pandas as pd
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
en_ROMANCE_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
en_ROMANCE_tokenizer = MarianTokenizer.from_pretrained(en_ROMANCE_model_name)
en_ROMANCE = MarianMTModel.from_pretrained(en_ROMANCE_model_name).to(device)

In [3]:
ROMANCE_en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
ROMANCE_en_tokenizer = MarianTokenizer.from_pretrained(ROMANCE_en_model_name)
ROMANCE_en = MarianMTModel.from_pretrained(ROMANCE_en_model_name).to(device)

In [4]:
device

device(type='cuda')

In [None]:
    def postprocess_next_token_scores(self, scores, input_ids, *a, **kw):
        batch_size, vocab_size = scores.shape
        cur_len = input_ids.shape[1]
        for hypothesis_idx in range(batch_size):
            valid = False
            cur_hypothesis = input_ids[hypothesis_idx]
            if cur_hypothesis[cur_len - 1].item() == selected_token3:
                if selected_token2 in cur_hypothesis and selected_token3 not in cur_hypothesis[:cur_len-1]:
                    token2_idx = cur_hypothesis.tolist().index(selected_token2)
                    if selected_token1 in cur_hypothesis.tolist()[:token2_idx]:
                        valid = True
                        scores[hypothesis_idx] += 100
                if not valid:
                    scores[hypothesis_idx] = float("-inf")


In [38]:
class CustomMTModel(MarianMTModel):
    def postprocess_next_token_scores(self, scores, input_ids, *a, **kw):
        batch_size, vocab_size = scores.shape
        cur_len = input_ids.shape[1]
        for hypothesis_idx in range(batch_size):
            earliest_valid_idx = 1
            cur_hypothesis = input_ids[hypothesis_idx]
            for word in self.desired_words:
                if word not in cur_hypothesis.tolist():
                    earliest_valid_idx = len(cur_hypothesis)
                else:
                    idx = cur_hypothesis.tolist().index(word)
                    if idx < earliest_valid_idx: # word occurred too early
                        # penalize
                        scores[hypothesis_idx] = -float('inf')
                    else:
                        # reward only if last token generated
                        if idx == cur_len - 1:
                            scores[hypothesis_idx] += 100
                            print(ROMANCE_en_tokenizer.convert_ids_to_tokens(cur_hypothesis))
                            pass
                    earliest_valid_idx = idx + 1
                        
                    
        return MarianMTModel.postprocess_next_token_scores(self, scores, input_ids, *a, **kw)
ROMANCE_en.__class__ = CustomMTModel


In [39]:
def score_prefix(machine_translation, prefix):
    tokenizer = ROMANCE_en_tokenizer
    model = ROMANCE_en
    tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(prefix.strip()))
    prefix = torch.LongTensor(tokenized_prefix).to(device)

    batch = tokenizer.prepare_translation_batch([machine_translation.replace("<pad> ", '')]).to(device)
    english_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    # pylint: disable=E1101
    partial_decode = torch.LongTensor([decoder_start_token]).to(device).unsqueeze(0)
    past = (english_encoded, None)
    # pylint: enable=E1101
    num_tokens_generated = 0
    total = 0
    MAX_LENGTH = 100
    
    #stop when </s> token generated, or max num tokens exceded (just in case)
    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]
        #start with user inputted beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            next_token_to_add = next_token_logits[0].argmax()
        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        token_score = next_token_logprobs[0][next_token_to_add].item()
        total += token_score

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)

    return (score, final.lstrip())

In [40]:
def translate(tokenizer, model, text, num_outputs):   
    """Use beam search to get a reasonable translation of 'text'"""
    # Tokenize the source text
    tokenizer.current_spm = tokenizer.spm_source ### HACK!
    batch = tokenizer.prepare_translation_batch([text]).to(model.device)
    
    # Run model
    num_beams = num_outputs
    translated = model.generate(**batch, num_beams=num_beams, num_return_sequences=num_outputs, max_length=40, no_repeat_ngram_size=3)
    
    # Untokenize the output text.
    tokenizer.current_spm = tokenizer.spm_target
    return [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=False) for t in translated]

original_postprocess = True;
input_str = "Yellowstone was established by the United States government in 1972."
english = ">>es<<" + input_str
engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english]).to(device)
eng_to_spanish = en_ROMANCE.generate(**engbatch).to(device)
machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0]).replace("<pad> ", '')

results = []
# selection1 = "1972"
# selection2 = "Yellowstone"
# selection3 = "established"

# ROMANCE_en_tokenizer.current_spm = ROMANCE_en_tokenizer.spm_target
# tokens1 = ROMANCE_en_tokenizer.tokenize(selection1)
# selected_token1 = ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens1)[0]
# tokens2 = ROMANCE_en_tokenizer.tokenize(selection2)
# selected_token2 = ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens2)[0]
# tokens3 = ROMANCE_en_tokenizer.tokenize(selection3)
# selected_token3 = ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens3)[0]

ROMANCE_en.desired_words = []
selections = ["1972", "Yellowstone", "established"]
for word in selections:
    tokens = ROMANCE_en_tokenizer.tokenize(word)
    ROMANCE_en.desired_words.append(ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens)[0])

original_postprocess = False
top50 = translate(ROMANCE_en_tokenizer, ROMANCE_en, ">>en<<" + machine_translation, 100)
for element in top50[0:10]:
    results.append(score_prefix(machine_translation, element))
        
all_sorted = sorted(((score, result) for score, result in results), reverse=True)


['<pad>', '▁In', '▁1972']
['<pad>', '▁In', '▁1972', '▁Y']


In [41]:
results = pd.DataFrame({'sentence': [pair[1] for pair in all_sorted],
              'probability': [pair[0] for pair in all_sorted]}).style.hide_index()
# df = df.style.set_properties(**{'text-align': 'left'})
pd.set_option('display.max_colwidth', None)
results

sentence,probability
In 1972 Yellowstone was established by the United States government.,-1.169
In 1972 Yellowstone was established by the government.,-1.329
In 1972 Yellowstone was established by US government.,-1.377
In 1972 YE was established by the United States government.,-1.703
In 1972 YG was established by the United States government.,-1.759
In 1972 YI was established by the United States government.,-1.766
In 1972 Yellowstone was established.,-1.848
In 1972 YE was established by the US government.,-1.895
In 1972 YG was established by the US government.,-1.951
In 1972 YI was established by the US government.,-1.958


In [9]:
ROMANCE_en._orig_postprocess_next_token_scores??

Object `ROMANCE_en._orig_postprocess_next_token_scores` not found.


In [10]:
ROMANCE_en._generate_beam_search??

[0;31mSignature:[0m
[0mROMANCE_en[0m[0;34m.[0m[0m_generate_beam_search[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0minput_ids[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcur_len[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_length[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmin_length[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdo_sample[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mearly_stopping[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtemperature[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtop_k[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtop_p[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mrepetition_penalty[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mno_repeat_ngram_size[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbad_words_ids[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpad_token_id[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0meos_token_id[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_size[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum_re

In [11]:
ROMANCE_en.__class__??

[0;31mInit signature:[0m [0mROMANCE_en[0m[0;34m.[0m[0m__class__[0m[0;34m([0m[0mconfig[0m[0;34m:[0m [0mtransformers[0m[0;34m.[0m[0mconfiguration_bart[0m[0;34m.[0m[0mBartConfig[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
Model API is identical to BartForConditionalGeneration.
Available models are listed at `Model List <https://huggingface.co/models?search=Helsinki-NLP>`__

Examples::

    >>> from transformers import MarianTokenizer, MarianMTModel
    >>> from typing import List
    >>> src = 'fr'  # source language
    >>> trg = 'en'  # target language
    >>> sample_text = "où est l'arrêt de bus ?"
    >>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'

    >>> model = MarianMTModel.from_pretrained(mname)
    >>> tok = MarianTokenizer.from_pretrained(mname)
    >>> batch = tok.prepare_translation_batch(src_texts=[sample_text])  # don't nee