In [1]:
import torch
from transformers import MarianMTModel, MarianTokenizer
import string
import pandas as pd
import spacy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
en_ROMANCE_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
en_ROMANCE_tokenizer = MarianTokenizer.from_pretrained(en_ROMANCE_model_name)
en_ROMANCE = MarianMTModel.from_pretrained(en_ROMANCE_model_name).to(device)

In [3]:
ROMANCE_en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
ROMANCE_en_tokenizer = MarianTokenizer.from_pretrained(ROMANCE_en_model_name)
ROMANCE_en = MarianMTModel.from_pretrained(ROMANCE_en_model_name).to(device)

In [4]:
device

device(type='cuda')

In [5]:
def monkey_patch(model, new_postproc_fn):
    cls = model.__class__
    print(cls)
    func_name = "postprocess_next_token_scores"
    orig_name = "_orig_" + func_name
    if not hasattr(cls, orig_name):
        print(str(cls) + " doesn't have attribute " + orig_name)
        setattr(cls, orig_name, getattr(cls, func_name))
    else:
        print(str(cls) + " has attribute " + orig_name)
    setattr(cls, func_name, new_postproc_fn)
    print(str(cls) + '\n' + func_name + '\n' + str(new_postproc_fn))

In [24]:
def postprocess_next_token_scores(self, scores, input_ids, *a, **kw):
    if not original_postprocess:
        batch_size, vocab_size = scores.shape
        cur_len = input_ids.shape[1]
        for hypothesis_idx in range(batch_size):
            cur_hypothesis = input_ids[hypothesis_idx]

        if 0 < cur_len <= len(selected_tokens):
            force_token_id = selected_tokens[cur_len-1]
            self._force_token_ids_generation(scores, token_ids=[force_token_id])
#         print(cur_hypothesis)
#         print(ROMANCE_en_tokenizer.decode(cur_hypothesis))
#         if cur_len == 1:
#             force_token_id = selected_token
#             self._force_token_ids_generation(scores, token_ids=[force_token_id])

    return self._orig_postprocess_next_token_scores(scores, input_ids, *a, **kw)

monkey_patch(en_ROMANCE, postprocess_next_token_scores)

<class 'transformers.modeling_marian.MarianMTModel'>
<class 'transformers.modeling_marian.MarianMTModel'> has attribute _orig_postprocess_next_token_scores
<class 'transformers.modeling_marian.MarianMTModel'>
postprocess_next_token_scores
<function postprocess_next_token_scores at 0x7fee08f0d3b0>


In [7]:
def score_prefix(machine_translation, prefix):
    tokenizer = ROMANCE_en_tokenizer
    model = ROMANCE_en
    tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(prefix.strip()))
    prefix = torch.LongTensor(tokenized_prefix).to(device)

    batch = tokenizer.prepare_translation_batch([machine_translation.replace("<pad> ", '')]).to(device)
    english_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    # pylint: disable=E1101
    partial_decode = torch.LongTensor([decoder_start_token]).to(device).unsqueeze(0)
    past = (english_encoded, None)
    # pylint: enable=E1101
    num_tokens_generated = 0
    total = 0
    MAX_LENGTH = 100
    
    #stop when </s> token generated, or max num tokens exceded (just in case)
    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]
        #start with user inputted beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            next_token_to_add = next_token_logits[0].argmax()
        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        token_score = next_token_logprobs[0][next_token_to_add].item()
#         print(token_score)
        total += token_score

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)

    return (score, final.lstrip())

In [8]:
def translate(tokenizer, model, text, num_outputs):   
    """Use beam search to get a reasonable translation of 'text'"""
    # Tokenize the source text
    tokenizer.current_spm = tokenizer.spm_source ### HACK!
    batch = tokenizer.prepare_translation_batch([text]).to(model.device)
    
    # Run model
    num_beams = num_outputs
    translated = model.generate(**batch, num_beams=num_beams, num_return_sequences=num_outputs, max_length=40, no_repeat_ngram_size=5)
    
    # Untokenize the output text.
    tokenizer.current_spm = tokenizer.spm_target
    return [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=False) for t in translated]

In [22]:
# original_postprocess = True;
# input_str = "Yellowstone National Park was established by the US government in 1972 as the world's first legislated effort at nature conservation."
# english = ">>es<<" + input_str
# engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english]).to(device)
# eng_to_spanish = en_ROMANCE.generate(**engbatch).to(device)
# machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0]).replace("<pad> ", '')

# results = []
# for word in input_str.split(' ')[3:]:
#     selection = word
#     ROMANCE_en_tokenizer.current_spm = ROMANCE_en_tokenizer.spm_target
#     tokens = ROMANCE_en_tokenizer.tokenize(selection)
#     selected_token = ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens)[0]
#     # list(zip(ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens), tokens))

#     original_postprocess = False
#     top50 = translate(ROMANCE_en_tokenizer, Taylor reportedly consumed copious amounts of raw fruit and iced milk while attending holiday celebrations during a fundraising event at the Washington Monument.ROMANCE_en, ">>en<<" + machine_translation, 10)
#     for element in top50[0:1]:
#         results.append(score_prefix(machine_translation, element))
        
# all_sorted = sorted(((score, result) for score, result in results), reverse=True)

In [10]:
#https://stackoverflow.com/questions/39100652/python-chunking-others-than-noun-phrases-e-g-prepositional-using-spacy-etc
def get_pps(doc):
    pps = []
    for token in doc:
        if token.pos_ == 'ADP':
            pp = ' '.join([tok.orth_ for tok in token.subtree])
            pps.append(pp)
        if token.dep_ == 'prep':
            off_limits.append(' '.join([tok.orth_ for tok in token.subtree]))
    return pps

In [44]:
def get_adv_clause(doc):
    clauses = []
    for token in doc:
        if token.dep_ == 'advcl' or token.dep_ == 'npadvmod' or token.dep_ == 'advmod':
            clause = ' '.join([tok.orth_ for tok in token.subtree])
            clauses.append(clause)
    return clauses

In [12]:
nlp = spacy.load("en_core_web_sm")

In [50]:
sentence = "I'm excited because my birthday is tomorrow"
phrases = []
doc = nlp(sentence)

#get prepositional phrases and blacklist OPs 
off_limits = []
for pphrase in get_pps(doc):
    #messy way to capitalize the first word without lowercasing the others
    capitalized = pphrase.split(' ')[0].capitalize() + ' ' + ' '.join(pphrase.split(' ')[1:])
    phrases.append(capitalized)

#get noun chunks that aren't OPs
for chunk in doc.noun_chunks:
    valid = True
    for phr in off_limits:
        if chunk.text in phr:
            valid = False
    if valid:
        capitalized = chunk.text.split(' ')[0].capitalize() + ' ' + ' '.join(chunk.text.split(' ')[1:])
        phrases.append(capitalized)

#get adverbial modifiers and clauses
for clause in get_adv_clause(doc):
    capitalized = clause.split(' ')[0].capitalize() + ' ' + ' '.join(clause.split(' ')[1:])
    phrases.append(capitalized)

print(phrases)

original_postprocess = True;
english = ">>es<<" + sentence
engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english]).to(device)
eng_to_spanish = en_ROMANCE.generate(**engbatch).to(device)
machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0]).replace("<pad> ", '')

results = []
for selection in set(phrases):
    ROMANCE_en_tokenizer.current_spm = ROMANCE_en_tokenizer.spm_target
    tokens = ROMANCE_en_tokenizer.tokenize(selection)
    selected_tokens = ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens)
    # list(zip(ROMANCE_en_tokenizer.convert_tokens_to_ids(tokens), tokens))

    original_postprocess = False
    top50 = translate(ROMANCE_en_tokenizer, ROMANCE_en, ">>en<<" + machine_translation, 50)
    for element in top50[0:1]:
        results.append(score_prefix(machine_translation, element))
        
all_sorted = sorted(((score, result) for score, result in results), reverse=True)

['I ', 'My birthday', 'Because my birthday is tomorrow', 'Tomorrow ']


In [48]:
results = pd.DataFrame({'sentence': [pair[1] for pair in all_sorted],
              'probability': [pair[0] for pair in all_sorted]}).style.hide_index()
# df = df.style.set_properties(**{'text-align': 'left'})
pd.set_option('display.max_colwidth', None)
results

sentence,probability
I'm excited because my birthday is tomorrow.,-0.293
My birthday's tomorrow. I'm so excited.,-0.954
"Because my birthday is tomorrow, I'm excited.",-1.009
Tomorrow's the day that I'm excited to have my birthday. I'm so excited that I'm going to have to have my birthday tomorrow.,-1.321
