In [1]:
import torch
from transformers import MarianMTModel, MarianTokenizer

In [2]:
import pandas as pd

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
en_ROMANCE_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
en_ROMANCE_tokenizer = MarianTokenizer.from_pretrained(en_ROMANCE_model_name)
en_ROMANCE = MarianMTModel.from_pretrained(en_ROMANCE_model_name).to(device)

In [5]:
ROMANCE_en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
ROMANCE_en_tokenizer = MarianTokenizer.from_pretrained(ROMANCE_en_model_name)
ROMANCE_en = MarianMTModel.from_pretrained(ROMANCE_en_model_name).to(device)

In [6]:
device

device(type='cuda')

In [7]:
import string

In [26]:
def incremental_generation(english_only, english, start, prefix_only):
    if english_only:
        #translate english to spanish
        engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english]).to(device)
        eng_to_spanish = en_ROMANCE.generate(**engbatch)
        machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0])

        #prepare spanish to be translated back to english
        tokenizer = ROMANCE_en_tokenizer
        model = ROMANCE_en
        batchstr = ">>en<<" + machine_translation.replace("<pad> ", '')
        tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(start))

    #prepare english to be translated to spanish
    else:
        tokenizer = en_ROMANCE_tokenizer
        model = en_ROMANCE
        batchstr = ">>es<<" + english
        tokenized_prefix = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(start))

    prefix = torch.LongTensor(tokenized_prefix).to(device)

    batch = tokenizer.prepare_translation_batch([batchstr]).to(device)
    original_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    partial_decode = torch.LongTensor([decoder_start_token]).to(device).unsqueeze(0)
    past = (original_encoded, None)

    #machine translation for comparative purposes
    translation_tokens = model.generate(**batch)
    machine_translation = tokenizer.decode(translation_tokens[0]).split("<pad>")[1]

    num_tokens_generated = 0
    prediction_list = []
    MAX_LENGTH = 100
    total = 0

    #generate tokens incrementally 
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)

        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]
        
        #start with designated beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        elif prefix_only == True:
            break
        else:
            next_token_to_add = next_token_logits[0].argmax()

        #calculate score
        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        token_score = next_token_logprobs[0][next_token_to_add].item()
        total += token_score

        #append top 10 predictions for each token to list
        decoded_predictions = []
        for tok in next_token_logits[0].topk(10).indices:
            decoded_predictions.append(tokenizer.convert_ids_to_tokens(tok.item()).replace('\u2581', '\u00a0'))
        
        #list of lists of predictions
        prediction_list.append(decoded_predictions)

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated += 1

        #stop generating at </s>, or when max num tokens exceded
        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)

    if english_only:
        new_english = final
    #back translate spanish into english
    else:
        batch2 = ROMANCE_en_tokenizer.prepare_translation_batch([">>en<< " + final]).to(device)
        spanish_to_english = ROMANCE_en.generate(**batch2)
        new_english = ROMANCE_en_tokenizer.decode(spanish_to_english[0]).replace("<pad>", '')

    return {"translation": final,
                    "expected" : machine_translation,
                    "newEnglish" : new_english,
                    "tokens" : decoded_tokens,
                    "predictions" : prediction_list,
                    "score" : score
                }

def rearrange(english, start, first_phrase, auto):
    wordlist = [''.join(x for x in par if x not in string.punctuation) for par in english.split(' ')]
    first_phrases = set([word.capitalize() for word in wordlist])

    #get most likely sentence or prefix and its score, given the word to move towards front
    def get_alt(start, prefix_only):
        if start[0] in wordlist:
            pos = wordlist.index(start.lstrip())
        #if subword token is selected
        else:
            res = [i for i in wordlist if start[0].lstrip() in i]
            pos = wordlist.index(res[0])

        #word before selected word
        first_phrases.add(wordlist[pos - 1].capitalize())
        #2 words before selected word
        first_phrases.add(' '.join(wordlist[pos-2: pos]).capitalize())
        first_phrases.add('The')
        if first_phrase != '':
            prefixes = [first_phrase + ' ' + word.lower() + ' ' + start.lstrip() for word in first_phrases]
        else:
            prefixes = [first_phrase + word + ' ' + start.lstrip() for word in first_phrases]
        prefixes.append(start.lstrip().capitalize())

        results = []
        scores = []

        #score each possible prefix/sentence
        for prefix in prefixes:
            data = incremental_generation(english_only=True, english=english, start=prefix, prefix_only=prefix_only)
            results.append(data["translation"])
            scores.append(data["score"])
        #select most likely sentence or prefix
        ind = scores.index(max(scores))
        winner = results[ind]
        winnerscore = scores[ind]
        return (winnerscore, winner)

    alternatives = []
    winner = ''

    #generate a list of alternatives
    if auto:
        #skip first 3 words bc they all return the default sentence
        for word in wordlist[3:]:
            alt = get_alt(word, prefix_only=False)
            #avoid duplicate sentences
            if alt not in alternatives:
                alternatives.append(alt)
            sorted_scores = sorted(((score, result) for score, result in alternatives), reverse=True)
        alternatives = [pair[1] for pair in sorted_scores]
        return {"alternatives" : alternatives}

    else:
        #get most likely prefix
        winner = get_alt(start, prefix_only=False)[1]
        #get full sentence given prefix
        return incremental_generation(english_only=True, english=english, start=winner, prefix_only=False)

In [9]:
import spacy

In [30]:
nlp = spacy.load("en_core_web_sm")
doc = nlp("Yellowstone National Park was established by the US government in 1972.")
for token in doc:
    print(token.text, token.lemma_, token.pos_, token.tag_, token.dep_,
            token.shape_, token.is_alpha, token.is_stop)

Yellowstone Yellowstone PROPN NNP compound Xxxxx True False
National National PROPN NNP compound Xxxxx True False
Park Park PROPN NNP nsubjpass Xxxx True False
was be AUX VBD auxpass xxx True True
established establish VERB VBN ROOT xxxx True False
by by ADP IN agent xx True True
the the DET DT det xxx True True
US US PROPN NNP compound XX True True
government government NOUN NN pobj xxxx True False
in in ADP IN prep xx True True
1972 1972 NUM CD pobj dddd False False
. . PUNCT . punct . False False


In [35]:
for chunk in doc.noun_chunks:
    print(chunk.text,'-', chunk.root.text,'-', chunk.root.dep_, '-',
            chunk.root.head.text)

Yellowstone National Park - Park - nsubjpass - established
the US government - government - pobj - by


In [11]:
res = rearrange("Yellowstone was established by the US government in 1972.", '', '', auto=True)
res

{'alternatives': [' The Yellowstone was established by the United States government in 1972.',
  ' In 1972, Yellowstone was established by the United States government.',
  ' The government of the United States established Yellowstone in 1972.',
  ' By the United States government in 1972, Yellowstone was established.',
  ' The US government established Yellowstone in 1972.']}

In [20]:
rearrange("George gave the cat a piece of chicken.", "cat", '', False)['translation']

' The cat is given a piece of chicken by George.'

In [None]:
rearrange("George gave the cat a piece of chicken.", "cat", '', False)

In [12]:
res['alternatives']

[' The Yellowstone was established by the United States government in 1972.',
 ' In 1972, Yellowstone was established by the United States government.',
 ' The government of the United States established Yellowstone in 1972.',
 ' By the United States government in 1972, Yellowstone was established.',
 ' The US government established Yellowstone in 1972.']

In [27]:
rearrange("George gave the cat a piece of chicken.", "", '', True)['alternatives']

[' A piece of chicken was given to the cat by George.',
 ' The cat is given a piece of chicken by George.',
 ' Of course, George gave the cat a piece of chicken.',
 ' A chicken piece is given to the cat by George.']

In [14]:
sentence = "Yellowstone was established by the US government in 1972."
first_select = "Yellowstone"
second_select = "1972"

In [15]:
data = rearrange(sentence, first_select, '', auto=False)
new_prefix = data['translation']
new_prefix

' Yellowstone'

In [204]:
data2 = rearrange(sentence, second_select, new_prefix, auto=False)
print(data2['translation'])
final = incremental_generation(english_only=True, english=sentence, start=data2['translation'], prefix_only=False)
final['translation']

 Yellowstone in 1972


' Yellowstone in 1972 was established by the United States government.'

In [189]:
def score_prefix(machine_translation, prefix):
    tokenizer = ROMANCE_en_tokenizer
    model = ROMANCE_en
    tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(prefix.strip()))
    prefix = torch.LongTensor(tokenized_prefix).to(device)

    batch = tokenizer.prepare_translation_batch([machine_translation.replace("<pad> ", '')]).to(device)
    english_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    # pylint: disable=E1101
    partial_decode = torch.LongTensor([decoder_start_token]).to(device).unsqueeze(0)
    past = (english_encoded, None)
    # pylint: enable=E1101
    num_tokens_generated = 0
    total = 0
    MAX_LENGTH = 100
    
    #stop when </s> token generated, or max num tokens exceded (just in case)
    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]
        #start with user inputted beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            next_token_to_add = next_token_logits[0].argmax()
        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        token_score = next_token_logprobs[0][next_token_to_add].item()
        total += token_score

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)

    return (score, final.lstrip())


In [70]:
import string

In [71]:
string.punctuation

'!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'

In [72]:
punctuation = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~' 

In [105]:
first_phrases

{'',
 '1972',
 'By',
 'Established',
 'Government',
 'In',
 'National',
 'Park',
 'The',
 'Us',
 'Was',
 'Yellowstone'}

In [186]:
input_str = "Yellowstone was established by the US government in 1972."
wordlist = [''.join(x for x in par if x not in punctuation) for par in input_str.split(' ')]
first_phrases = set([word.capitalize() for word in wordlist])

In [187]:
english = ">>es<<" + input_str
engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english]).to(device)
eng_to_spanish = en_ROMANCE.generate(**engbatch).to(device)
machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0])


In [190]:
expected = score_prefix(machine_translation, wordlist[0])[1]
expected

'Yellowstone was established by the U.S. government in 1972.'

In [191]:
input_str + ' ' + expected

'Yellowstone was established by the US government in 1972. Yellowstone was established by the U.S. government in 1972.'

In [192]:
sent_scores = []
for word in wordlist:
    pos = wordlist.index(word.lstrip())
    first_phrases.add(' '.join(wordlist[pos-2: pos]).capitalize())
    for phrase in first_phrases:
        res = score_prefix(machine_translation, phrase)
        simple = str(res[1].lower()).replace('the ', '')
        if simple not in expected.lower() and expected.lower() not in simple:
            sent_scores.append(res)

In [193]:
sorted_scores = sorted(((score, result) for score, result in set(sent_scores)), reverse=True)
results = pd.DataFrame({'sentence': [pair[1] for pair in sorted_scores],
              'probability': [pair[0] for pair in sorted_scores]}).style.hide_index()
# df = df.style.set_properties(**{'text-align': 'left'})
pd.set_option('display.max_colwidth', None)
pd.set_option("display.colheader_justify","left")
results

sentence,probability
Yellowstone was established by the U.S. government in 1972.,-0.16
The Yellowstone was established by the U.S. government in 1972.,-0.51
"In 1972, Yellowstone was established by the U.S. government.",-0.638
"Established by the U.S. government in 1972, Yellowstone was established in 1972.",-1.221
"By the US government, Yellowstone was established in 1972.",-1.229
Was established by the US government in 1972.,-1.353
Usher was established by the U.S. government in 1972.,-1.459
Government of the United States established Yellowstone in 1972.,-1.504
The ushers of Yellowstone was established by the US government in 1972.,-1.539
1972 was the first year of the American government to establish Yellowstone.,-1.945


In [18]:
sorted_scores = sorted(((score, result) for score, result in sent_scores), reverse=True)
results = pd.DataFrame({'sentence': [pair[1] for pair in sorted_scores],
              'probability': [pair[0] for pair in sorted_scores]}).style.hide_index()
# df = df.style.set_properties(**{'text-align': 'left'})
pd.set_option('display.max_colwidth', None)
pd.set_option("display.colheader_justify","left")
results

sentence,probability
"In my opinion, all cats are great.",-0.313
"All cats are great, in my opinion.",-0.765
"My guess is, all cats are great.",-0.77
"Great. In my opinion, all cats are great.",-1.229
"Are you sure? In my opinion, all cats are great.",-1.23
"Cats are all great, in my opinion.",-1.349
"Opinion is, all cats are great.",-1.718


In [31]:
score_prefix("In my opinion, all cats are great.", "All cats")

1
1.5
2


(-1.806, 'All cats')

In [13]:
examples = [
    ('In my opinion, all of the cats are great.', 'All of the cats', 'cats'),
    ('It was a dark and stormy night', 'The night', 'night'),
    ('They won\'t be back once they go.', 'Once they go', 'go'),
    ('''Easter Island, the most remote place on Earth in terms of distance from other inhabited places, 
         has long been famous for the enormous stone statues erected by its prehistoric settlers.''', 'Easter Island has long been famous', 'famous'),
    ('In 1972, the United States government established Yellowstone National Park as the world\'s first legislated effort at nature conservation',
        'Yellowstone National Park was established', 'established'),
    ('A positive PSA test has to be followed up with a biopsy or other procedures before cancer can be confirmed.', 'Before confirming', 'confirming'),
    ('University of Michigan President Mary Sue Coleman said in a statement on the university\'s Web site, "Our fundamental values haven\'t changed.',
        'Our fundamental values', 'values'),
    ('At highest risk are those charasmatic species whose specieswide genetic diversity is very low.', 'Those charasmatic species', 'species'),
     ('The success of Johnson\'s novels is celebrated in an annual festival, called Longmire Day, held in the small town of Buffalo, Wyoming', 'Longmire Day is held', 'held'),
    ('At the 2006 census, its population was 599', '599 individuals', 'individuals')
]
first_phrases = []
for src, tgt, word in examples:
    #tgt_words = tgt#.split()
    #idx = tgt_words.index(word)
    #words_before_tgt_word = tgt_words[:idx]
    for w in tgt.rstrip('?.!,').split():
        first_phrases.append(w.capitalize())
    first_phrases.append(tgt[0: tgt.index(word)].strip())
print(sorted(set(first_phrases)))

paraphrases = ["All of the cats are great, in my opinion.",
               "The night was dark and stormy.",
               "Once they go, they won't be back.",
               "Easter Island has long been famous for the enormous stone statues erected by its prehistoric settlers.",
               "Yellowstone National Park was established by the US government in 1972 as the world\'s first legislated effort at nature conservation.",
               "Before confirming cancer, a positive PSA text has to be followed up with a biopsy or other procedures.",
               "Our fundamental values haven't changed, said University of Michigan President Mary Sue Coleman in a statement on the university's Web site.",
               "Those charasmatic species whose specieswide genetic diversity is very low are at the highest risk.",
               "Longmire Day is held in the small town of Buffalo, Wyoming to celebrate the success of Johnson's novels.",
               "599 individuals lived there at the 2006 census."
              ]
paraphrases

['599', 'All', 'All of the', 'Been', 'Before', 'Cats', 'Charasmatic', 'Confirming', 'Day', 'Easter', 'Easter Island has long been', 'Established', 'Famous', 'Fundamental', 'Go', 'Has', 'Held', 'Individuals', 'Is', 'Island', 'Long', 'Longmire', 'Longmire Day is', 'National', 'Night', 'Of', 'Once', 'Once they', 'Our', 'Our fundamental', 'Park', 'Species', 'The', 'They', 'Those', 'Those charasmatic', 'Values', 'Was', 'Yellowstone', 'Yellowstone National Park was']


['All of the cats are great, in my opinion.',
 'The night was dark and stormy.',
 "Once they go, they won't be back.",
 'Easter Island has long been famous for the enormous stone statues erected by its prehistoric settlers.',
 "Yellowstone National Park was established by the US government in 1972 as the world's first legislated effort at nature conservation.",
 'Before confirming cancer, a positive PSA text has to be followed up with a biopsy or other procedures.',
 "Our fundamental values haven't changed, said University of Michigan President Mary Sue Coleman in a statement on the university's Web site.",
 'Those charasmatic species whose specieswide genetic diversity is very low are at the highest risk.',
 "Longmire Day is held in the small town of Buffalo, Wyoming to celebrate the success of Johnson's novels.",
 '599 individuals lived there at the 2006 census.']

In [14]:
ranks = []
for src, tgt, word in examples:
    prefixes = [phrase +  ' ' + word for phrase in first_phrases]
    scores = [score_prefix(src, prefix) for prefix in prefixes]
    # compute rank of actual prefix in sorted scores
    sorted_scores = sorted(((score, result) for score, result in scores), reverse=True)
    rank = [y[1] for y in sorted_scores].index(tgt) + 1
    ranks.append((tgt, rank))

KeyboardInterrupt: 

In [None]:
df = pd.DataFrame({'original': [ex[0] for ex in examples],
              'expected paraphrase': paraphrases,
              'expected prefix': [ex[1] for ex in examples],
              'word': [ex[2] for ex in examples],
              'rank of expected prefix': [el[1] for el in ranks]}).style.hide_index()
# df = df.style.set_properties(**{'text-align': 'left'})
pd.set_option('display.max_colwidth', None)
pd.set_option("display.colheader_justify","left")
df

In [None]:
pd.DataFrame({'scores': [prefix[0] for prefix in scores], 'sentences': [prefix[1] for prefix in scores]}).sort_values('scores', ascending=False)

In [58]:
def incremental_translation(batchstr, starting_word):
    prefix = torch.LongTensor(starting_word)
    print(prefix)
    batch = tokenizer.prepare_translation_batch([batchstr])
    translation_tokens = model.generate(**batch)
    machine_translation = tokenizer.decode(translation_tokens[0]).split("<pad>")[1]
    original_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id

    partial_decode = torch.LongTensor([decoder_start_token]).unsqueeze(0)
    past = (original_encoded, None)
    next_token_to_add = torch.tensor(1)
    num_tokens_generated = 0

    prediction_list = []
    MAX_LENGTH = 100

    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)

        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]

        #start with designated beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            next_token_to_add = next_token_logits[0].argmax()

        #append top 10 predictions for each token to list
        decoded_predictions = []
        for tok in next_token_logits[0].topk(10).indices:
            decoded_predictions.append(tokenizer.convert_ids_to_tokens(tok.item()).replace('\u2581', '\u00a0'))
        #list of lists of predictions
        prediction_list.append(decoded_predictions)

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
   
    return {"translation": final,
                "expected" : machine_translation,
                "tokens" : decoded_tokens,
                "predictions" : prediction_list
            }

skip = "false"
copy_input = "false"
english = "It was a dark and stormy night."

if copy_input == "true":
    start = english
else:
    start = "La"

#english only- spanish behind the scenes
if skip == "true":
    engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english])
    eng_to_spanish = en_ROMANCE.generate(**engbatch)
    mid_machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0])
    tokenizer = ROMANCE_en_tokenizer
    model = ROMANCE_en
    batchstr = ">>en<<" + mid_machine_translation.replace("<pad> ", '')
    starting_word = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(start))
    data = incremental_translation(batchstr, starting_word)
    data.update( { "new_english" : data["translation"] })
    print(data)
    
#show spanish
else:
    tokenizer = en_ROMANCE_tokenizer
    model = en_ROMANCE
    batchstr = ">>es<<" + english
    starting_word = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(start))
    data = incremental_translation(batchstr, starting_word)
    batch2 = ROMANCE_en_tokenizer.prepare_translation_batch([">>en<< " + final])
    spanish_to_english = ROMANCE_en.generate(**batch2)
    new_english = ROMANCE_en_tokenizer.decode(spanish_to_english[0]).replace("<pad>", '')
    data.update({new_english : new_english})
    print(data)
    



tensor([100])
{'translation': ' La noche fue oscura y tempestuosa.', 'expected': ' Era una noche oscura y tempestuosa.', 'tokens': ['\xa0La', '\xa0noche', '\xa0fue', '\xa0os', 'cura', '\xa0y', '\xa0tem', 'pes', 'tu', 'osa', '.', '</s>'], 'predictions': [['\xa0Era', '\xa0Fue', '\xa0Ha', '\xa0Es', '\xa0Estaba', '\xa0Había', '\xa0Una', '\xa0La', '\xa0-', '\xa0No'], ['\xa0noche', '\xa0tarde', '\xa0oscuridad', '\xa0vela', '\xa0verdad', '\xa0misma', '\xa0no', '\xa0ciudad', '\xa0cosa', '\xa0última'], ['\xa0fue', '\xa0era', '\xa0estaba', '\xa0estuvo', '\xa0es', '\xa0había', '\xa0se', '\xa0de', '\xa0no', '\xa0que'], ['\xa0os', '\xa0som', '\xa0tem', '\xa0muy', '\xa0de', '\xa0oscuro', '\xa0tan', '\xa0tor', '\xa0negra', '\xa0o'], ['cura', 'cure', 'cur', 'cu', 'ca', 'curi', 'erta', 'jera', 'curr', 'cor'], ['\xa0y', ',', '.', '\xa0e', '...', '\xa0de', '<pad>', '\xa0con', '\xa0en', '\xa0pero'], ['\xa0tem', '\xa0tormenta', '\xa0tor', '\xa0tur', '\xa0os', '\xa0de', '\xa0tort', '\xa0o', '\xa0muy', '\xa0

In [None]:
import re

In [None]:
(' ' + 'All cats are great, info my opinion').index(' ' + 'in')

In [None]:
print('a\nb')

In [None]:
print(r'a\nb')

In [None]:
re.search(r'\bgreat\b', 'All of the cats are great, opinion.')

In [None]:
english = ">>es<<In my opinion, all cats are great."
engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english])
eng_to_spanish = en_ROMANCE.generate(**engbatch)
machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0])

start = "cats"

tokenizer = ROMANCE_en_tokenizer
model = ROMANCE_en

first_words = ["The", "All", ""]
wordlist = english.split(' ')
for word in wordlist:
    first_words.append(word.capitalize())
first_words
    
results = []
scores = []
MAX_LENGTH = 100
    
for word in first_words:
    join_prefix_str = word + " " + start
    tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(join_prefix_str.strip()))
    prefix = torch.LongTensor(tokenized_prefix)
#     tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(start))
#     prefix = torch.LongTensor(tokenized_prefix)

    batch = tokenizer.prepare_translation_batch([machine_translation.replace("<pad> ", '')])
    english_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    # pylint: disable=E1101
    partial_decode = torch.LongTensor([decoder_start_token]).unsqueeze(0)
    past = (english_encoded, None)
    # pylint: enable=E1101

    num_tokens_generated = 0
    prediction_list = []
    total = 0
    #stop when </s> token generated, or max num tokens exceded (just in case)
    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)

        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]

        #start with user inputted beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            break
#             next_token_to_add = next_token_logits[0].argmax()

        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        score = next_token_logprobs[0][next_token_to_add].item()

        total += score

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)
    results.append(final)
    scores.append(score)

    print("\n" + final)
    print(total)

ind = scores.index(max(scores))
winner = results[ind]
print("\nMost likely: ", winner)