In [1]:
import torch

In [2]:
from transformers import MarianMTModel, MarianTokenizer

In [3]:
en_ROMANCE_model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
en_ROMANCE_tokenizer = MarianTokenizer.from_pretrained(en_ROMANCE_model_name)
print(en_ROMANCE_tokenizer.supported_language_codes)

['>>fr<<', '>>es<<', '>>it<<', '>>pt<<', '>>pt_br<<', '>>ro<<', '>>ca<<', '>>gl<<', '>>pt_BR<<', '>>la<<', '>>wa<<', '>>fur<<', '>>oc<<', '>>fr_CA<<', '>>sc<<', '>>es_ES<<', '>>es_MX<<', '>>es_AR<<', '>>es_PR<<', '>>es_UY<<', '>>es_CL<<', '>>es_CO<<', '>>es_CR<<', '>>es_GT<<', '>>es_HN<<', '>>es_NI<<', '>>es_PA<<', '>>es_PE<<', '>>es_VE<<', '>>es_DO<<', '>>es_EC<<', '>>es_SV<<', '>>an<<', '>>pt_PT<<', '>>frp<<', '>>lad<<', '>>vec<<', '>>fr_FR<<', '>>co<<', '>>it_IT<<', '>>lld<<', '>>lij<<', '>>lmo<<', '>>nap<<', '>>rm<<', '>>scn<<', '>>mwl<<']


In [4]:
en_ROMANCE = MarianMTModel.from_pretrained(en_ROMANCE_model_name)

In [5]:
ROMANCE_en_model_name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
ROMANCE_en_tokenizer = MarianTokenizer.from_pretrained(ROMANCE_en_model_name)


In [7]:
ROMANCE_en = MarianMTModel.from_pretrained(ROMANCE_en_model_name)

In [163]:
def score_prefix(src, prefix):
    english = ">>es<<" + src
    engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english])
    eng_to_spanish = en_ROMANCE.generate(**engbatch)
    machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0])

    tokenizer = ROMANCE_en_tokenizer
    model = ROMANCE_en

    tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(prefix.strip()))
    prefix = torch.LongTensor(tokenized_prefix)

    batch = tokenizer.prepare_translation_batch([machine_translation.replace("<pad> ", '')])
    english_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    # pylint: disable=E1101
    partial_decode = torch.LongTensor([decoder_start_token]).unsqueeze(0)
    past = (english_encoded, None)
    # pylint: enable=E1101

    num_tokens_generated = 0
    total = 0
    MAX_LENGTH = 100
    
    #stop when </s> token generated, or max num tokens exceded (just in case)
    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)

        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]

        #start with user inputted beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            break
#             next_token_to_add = next_token_logits[0].argmax()

        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        token_score = next_token_logprobs[0][next_token_to_add].item()
        total += token_score

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)

    return (score, final.lstrip())


In [164]:
score_prefix("In my opinion, all cats are great.", "All cats")

(-1.806, 'All cats')

In [65]:
prefixes = [phrase +  ' ' + 'confirming' for phrase in first_phrases]
prefixes.append('Before confirming')
scores = [score_prefix('A positive PSA test has to be followed up with a biopsy or other procedures before cancer can be confirmed.', prefix) for prefix in prefixes]
sorted_scores = sorted(((score, result) for score, result in scores), reverse=True)
print(sorted_scores)

[(-6.632, 'All of the confirming'), (-7.558, 'Famous confirming'), (-8.163, 'Confirming confirming'), (-8.461, 'Yellowstone National park was confirming'), (-8.461, 'The confirming'), (-8.461, 'The confirming'), (-8.461, 'The confirming'), (-8.477, 'Once they confirming'), (-8.629, 'Before confirming'), (-8.629, 'Before confirming'), (-8.629, 'Before confirming'), (-9.717, 'As confirming'), (-9.843, 'Once confirming'), (-9.856, 'Cats confirming'), (-9.917, 'Easter Island as long been confirming'), (-10.337, 'Yellowstone confirming'), (-10.901, 'They confirming'), (-11.055, 'Our confirming'), (-11.095, 'Those confirming'), (-11.148, 'Those charasmatic confirming'), (-11.271, 'All confirming'), (-11.806, 'Charasmatic confirming'), (-11.838, 'Our fundamental confirming'), (-11.933, 'Been confirming'), (-11.967, 'Was confirming'), (-12.033, 'Of confirming'), (-12.072, 'Easter confirming'), (-12.381, 'Values confirming'), (-12.402, 'National confirming'), (-12.676, 'Species confirming'), (-

In [133]:
prefixes = [phrase +  ' ' + 'left' for phrase in first_phrases]
prefixes.append('Bradley left')
scores = [score_prefix('It was the first time in four years that a healthy Donovan did not start, and while he said he was in agreement with Bradley\'s decision, Bradley left him on the bench again in the semifinals.', prefix) for prefix in prefixes]
sorted_scores = sorted(((score, result) for score, result in scores), reverse=True)
print(sorted_scores)

[(-3.136, 'Bradley left'), (-5.233, 'All of the left'), (-5.888, 'Famous left'), (-6.185, 'Once they left'), (-6.265, 'They left'), (-7.465, 'Yellowstone left'), (-7.637, 'Yellowstone National park was left'), (-7.967, 'Cats left'), (-8.051, 'The left'), (-8.051, 'The left'), (-8.051, 'The left'), (-8.244, 'All left'), (-8.53, 'Confirming left'), (-8.58, 'Once left'), (-8.601, 'Those left'), (-8.837, 'Park left'), (-9.027, 'Easter Island as long been left'), (-9.283, 'As left'), (-9.543, 'Our fundamental left'), (-9.662, 'Was left'), (-9.674, 'Those charasmatic left'), (-9.712, 'Before left'), (-9.712, 'Before left'), (-9.766, 'Charasmatic left'), (-9.794, 'Long left'), (-9.938, 'Go left'), (-9.954, 'Our left'), (-9.958, 'Species left'), (-10.063, 'Island left'), (-10.141, 'Been left'), (-10.346, 'Fundamental left'), (-10.514, 'Values left'), (-10.843, 'National left'), (-11.084, 'Night left'), (-11.29, 'Of left'), (-11.691, 'Easter left'), (-12.125, 'Established left')]


In [162]:
prefixes = [phrase +  ' ' + 'cats' for phrase in first_phrases]
scores = [score_prefix('In my opinion, all of the cats are cool', prefix) for prefix in prefixes]
sorted_scores = sorted(((score, result) for score, result in scores), reverse=True)
print(sorted_scores)


[(-3.612, 'All cats'), (-6.407, 'The cats'), (-6.407, 'The cats'), (-6.407, 'The cats'), (-7.088, 'All of the cats'), (-8.483, 'Those cats'), (-10.651, 'Our cats'), (-14.144, 'Famous cats'), (-14.892, 'Park cats'), (-15.543, 'Island cats'), (-15.994, 'National cats'), (-17.298, 'Night cats'), (-17.384, 'They cats'), (-18.878, 'Is cats'), (-19.004, 'Easter cats'), (-19.655, 'Established cats'), (-19.88, 'Before cats'), (-19.88, 'Before cats'), (-19.937, 'Been cats'), (-20.074, 'Fundamental cats'), (-20.11, 'Has cats'), (-20.299, 'Once cats'), (-20.451, 'Of cats'), (-20.742, 'Cats cats'), (-20.994, 'Was cats'), (-21.59, 'Day cats'), (-21.848, 'Species cats'), (-22.083, 'Long cats'), (-22.285, 'Individuals cats'), (-23.165, 'Confirming cats'), (-23.441, 'Go cats'), (-23.625, 'Yellowstone cats'), (-24.381, 'Held cats'), (-24.614, 'Values cats'), (-24.684, 'Once they cats'), (-25.977, '599 cats'), (-25.977, '599 cats'), (-27.819, 'Our fundamental cats'), (-35.927, 'Charasmatic cats'), (-37.

In [159]:
examples = [
    ('In my opinion, all of the cats are great.', 'All of the cats', 'cats'),
    ('It was a dark and stormy night', 'The night', 'night'),
    ('They won\'t be back once they go.', 'Once they go', 'go'),
    ('''Easter Island, the most remote place on Earth in terms of distance from other inhabited places, 
         has long been famous for the enormous stone statues erected by its prehistoric settlers.''', 'Easter Island has long been famous', 'famous'),
    ('In 1972, the United States government established Yellowstone National Park as the world\'s first legislated effort at nature conservation',
        'Yellowstone National Park was established', 'established'),
    ('A positive PSA test has to be followed up with a biopsy or other procedures before cancer can be confirmed.', 'Before confirming', 'confirming'),
    ('University of Michigan President Mary Sue Coleman said in a statement on the university\'s Web site, "Our fundamental values haven\'t changed.',
        'Our fundamental values', 'values'),
    ('At highest risk are those charasmatic species whose specieswide genetic diversity is very low.', 'Those charasmatic species', 'species'),
     ('The success of Johnson\'s novels is celebrated in an annual festival, called Longmire Day, held in the small town of Buffalo, Wyoming', 'Longmire Day is held', 'held'),
    ('At the 2006 census, its population was 599', '599 individuals', 'individuals')
]
first_phrases = []
for src, tgt, word in examples:
    #tgt_words = tgt#.split()
    #idx = tgt_words.index(word)
    #words_before_tgt_word = tgt_words[:idx]
    for w in tgt.rstrip('?.!,').split():
        first_phrases.append(w.capitalize())
    first_phrases.append(tgt[0: tgt.index(word)].strip())
print(sorted(set(first_phrases)))

paraphrases = ["All of the cats are great, in my opinion.",
               "The night was dark and stormy.",
               "Once they go, they won't be back.",
               "Easter Island has long been famous for the enormous stone statues erected by its prehistoric settlers.",
               "Yellowstone National Park was established by the US government in 1972 as the world\'s first legislated effort at nature conservation.",
               "Before confirming cancer, a positive PSA text has to be followed up with a biopsy or other procedures.",
               "Our fundamental values haven't changed, said University of Michigan President Mary Sue Coleman in a statement on the university's Web site.",
               "Those charasmatic species whose specieswide genetic diversity is very low are at the highest risk.",
               "Longmire Day is held in the small town of Buffalo, Wyoming to celebrate the success of Johnson's novels.",
               "599 individuals lived there at the 2006 census."
              ]
paraphrases

['599', 'All', 'All of the', 'Been', 'Before', 'Cats', 'Charasmatic', 'Confirming', 'Day', 'Easter', 'Easter Island has long been', 'Established', 'Famous', 'Fundamental', 'Go', 'Has', 'Held', 'Individuals', 'Is', 'Island', 'Long', 'Longmire', 'Longmire Day is', 'National', 'Night', 'Of', 'Once', 'Once they', 'Our', 'Our fundamental', 'Park', 'Species', 'The', 'They', 'Those', 'Those charasmatic', 'Values', 'Was', 'Yellowstone', 'Yellowstone National Park was']


['All of the cats are great, in my opinion.',
 'The night was dark and stormy.',
 "Once they go, they won't be back.",
 'Easter Island has long been famous for the enormous stone statues erected by its prehistoric settlers.',
 "Yellowstone National Park was established by the US government in 1972 as the world's first legislated effort at nature conservation.",
 'Before confirming cancer, a positive PSA text has to be followed up with a biopsy or other procedures.',
 "Our fundamental values haven't changed, said University of Michigan President Mary Sue Coleman in a statement on the university's Web site.",
 'Those charasmatic species whose specieswide genetic diversity is very low are at the highest risk.',
 "Longmire Day is held in the small town of Buffalo, Wyoming to celebrate the success of Johnson's novels.",
 '599 individuals lived there at the 2006 census.']

In [166]:
ranks = []
for src, tgt, word in examples:
    prefixes = [phrase +  ' ' + word for phrase in first_phrases]
    scores = [score_prefix(src, prefix) for prefix in prefixes]
    # compute rank of actual prefix in sorted scores
    sorted_scores = sorted(((score, result) for score, result in scores), reverse=True)
    rank = [y[1] for y in sorted_scores].index(tgt) + 1
    ranks.append((tgt, rank))

In [167]:
df = pd.DataFrame({'original': [ex[0] for ex in examples],
              'expected paraphrase': paraphrases,
              'expected prefix': [ex[1] for ex in examples],
              'word': [ex[2] for ex in examples],
              'rank of expected prefix': [el[1] for el in ranks]}).style.hide_index()
# df = df.style.set_properties(**{'text-align': 'left'})
pd.set_option('display.max_colwidth', None)
pd.set_option("display.colheader_justify","left")
df

original,expected paraphrase,expected prefix,word,rank of expected prefix
"In my opinion, all of the cats are great.","All of the cats are great, in my opinion.",All of the cats,cats,1
It was a dark and stormy night,The night was dark and stormy.,The night,night,1
They won't be back once they go.,"Once they go, they won't be back.",Once they go,go,1
"Easter Island, the most remote place on Earth in terms of distance from other inhabited places, has long been famous for the enormous stone statues erected by its prehistoric settlers.",Easter Island has long been famous for the enormous stone statues erected by its prehistoric settlers.,Easter Island has long been famous,famous,1
"In 1972, the United States government established Yellowstone National Park as the world's first legislated effort at nature conservation",Yellowstone National Park was established by the US government in 1972 as the world's first legislated effort at nature conservation.,Yellowstone National Park was established,established,1
A positive PSA test has to be followed up with a biopsy or other procedures before cancer can be confirmed.,"Before confirming cancer, a positive PSA text has to be followed up with a biopsy or other procedures.",Before confirming,confirming,11
"University of Michigan President Mary Sue Coleman said in a statement on the university's Web site, ""Our fundamental values haven't changed.","Our fundamental values haven't changed, said University of Michigan President Mary Sue Coleman in a statement on the university's Web site.",Our fundamental values,values,1
At highest risk are those charasmatic species whose specieswide genetic diversity is very low.,Those charasmatic species whose specieswide genetic diversity is very low are at the highest risk.,Those charasmatic species,species,2
"The success of Johnson's novels is celebrated in an annual festival, called Longmire Day, held in the small town of Buffalo, Wyoming","Longmire Day is held in the small town of Buffalo, Wyoming to celebrate the success of Johnson's novels.",Longmire Day is held,held,1
"At the 2006 census, its population was 599",599 individuals lived there at the 2006 census.,599 individuals,individuals,1


In [89]:
df

In [60]:
sorted_scores

[(-1.34, 'Those charasmatic species'),
 (-1.376, 'Charasmatic species'),
 (-1.947, 'Those species'),
 (-2.509, 'Famous species'),
 (-2.753, 'All of the species'),
 (-3.105, 'The species'),
 (-3.105, 'The species'),
 (-3.105, 'The species'),
 (-4.239, 'All species'),
 (-4.653, 'Our species'),
 (-6.252, 'Yellowstone species'),
 (-6.478, 'Confirming species'),
 (-6.546, 'As species'),
 (-6.962, 'Our fundamental species'),
 (-7.115, 'Species species'),
 (-7.127, 'Park species'),
 (-7.172, 'Fundamental species'),
 (-7.332, 'Cats species'),
 (-7.514, 'Once they species'),
 (-7.794, 'Of species'),
 (-7.826, 'They species'),
 (-7.84, 'Once species'),
 (-7.946, 'National species'),
 (-8.148, 'Yellowstone National park was species'),
 (-8.411, 'Established species'),
 (-8.44, 'Easter Island as long been species'),
 (-8.538, 'Before species'),
 (-8.538, 'Before species'),
 (-8.56, 'Island species'),
 (-9.178, 'Was species'),
 (-9.371, 'Easter species'),
 (-9.942, 'Long species'),
 (-10.228, 'Nigh

In [61]:
ranks

[1, 1, 1, 1, 1, 9, 1, 1]

In [25]:
import pandas as pd

In [26]:
pd.DataFrame({'scores': [prefix[0] for prefix in scores], 'sentences': [prefix[1] for prefix in scores]}).sort_values('scores', ascending=False)

Unnamed: 0,scores,sentences
7,-2.84,The night
0,-4.539,All night
12,-4.809,Stormy night
5,-4.843,My night
9,-5.255,Was night
10,-5.375,Dark night
11,-5.773,And night
3,-5.852,"Great, night"
8,-6.511,Night night
4,-6.621,In night


In [27]:
sorted(((score, result) for score, result in scores), reverse=True)

[(-2.84, ' The night'),
 (-4.539, ' All night'),
 (-4.809, ' Stormy night'),
 (-4.843, ' My night'),
 (-5.255, ' Was night'),
 (-5.375, ' Dark night'),
 (-5.773, ' And night'),
 (-5.852, ' Great, night'),
 (-6.511, ' Night night'),
 (-6.621, ' In night'),
 (-7.173, ' Cats night'),
 (-10.187, ' Opinion night'),
 (-10.31, ' Are night')]

In [68]:
import re

In [72]:
(' ' + 'All cats are great, info my opinion').index(' ' + 'in')

19

In [77]:
print('a\nb')

a
b


In [78]:
print(r'a\nb')

a\nb


In [75]:
re.search(r'\bgreat\b', 'All of the cats are great, opinion.')

<re.Match object; span=(20, 25), match='great'>

In [43]:
english = ">>es<<In my opinion, all cats are great."
engbatch = en_ROMANCE_tokenizer.prepare_translation_batch([english])
eng_to_spanish = en_ROMANCE.generate(**engbatch)
machine_translation = en_ROMANCE_tokenizer.decode(eng_to_spanish[0])

start = "cats"

tokenizer = ROMANCE_en_tokenizer
model = ROMANCE_en

first_words = ["The", "All", ""]
wordlist = english.split(' ')
for word in wordlist:
    first_words.append(word.capitalize())
first_words
    
results = []
scores = []
MAX_LENGTH = 100
    
for word in first_words:
    join_prefix_str = word + " " + start
    tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(join_prefix_str.strip()))
    prefix = torch.LongTensor(tokenized_prefix)
#     tokenized_prefix = tokenizer.convert_tokens_to_ids(en_ROMANCE_tokenizer.tokenize(start))
#     prefix = torch.LongTensor(tokenized_prefix)

    batch = tokenizer.prepare_translation_batch([machine_translation.replace("<pad> ", '')])
    english_encoded = model.get_encoder()(**batch)
    decoder_start_token = model.config.decoder_start_token_id
    # pylint: disable=E1101
    partial_decode = torch.LongTensor([decoder_start_token]).unsqueeze(0)
    past = (english_encoded, None)
    # pylint: enable=E1101

    num_tokens_generated = 0
    prediction_list = []
    total = 0
    #stop when </s> token generated, or max num tokens exceded (just in case)
    while True:
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)

        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]

        #start with user inputted beginning
        if num_tokens_generated < len(prefix):
            next_token_to_add = prefix[num_tokens_generated]
        else:
            break
#             next_token_to_add = next_token_logits[0].argmax()

        next_token_logprobs = next_token_logits - next_token_logits.logsumexp(1, True)
        score = next_token_logprobs[0][next_token_to_add].item()

        total += score

        #add new token to tokens so far
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        num_tokens_generated+= 1

        if next_token_to_add.item() == 0 or not (num_tokens_generated < MAX_LENGTH):
            break

    #list of tokens used to display sentence
    decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])]
    decoded_tokens.remove("<pad>")

    final = tokenizer.decode(partial_decode[0]).replace("<pad>", '')
    score = round(total/(len(decoded_tokens)), 3)
    results.append(final)
    scores.append(score)

    print("\n" + final)
    print(total)

ind = scores.index(max(scores))
winner = results[ind]
print("\nMost likely: ", winner)


 The cats are all great, in my opinion.
-11.56362009048462

 All cats are great, in my opinion.
-7.648155689239502

 cats are great, in my opinion.
-15.431335926055908

 ⁇  in cats are great.
-33.81710910797119

 My cats are all great.
-18.766332626342773

 Opinion, cats are all great.
-20.298365592956543

 All cats are great, in my opinion.
-7.648155689239502

 Cats cats are great, in my opinion.
-25.229657649993896

 Are cats great?
-17.362116813659668

 Great. cats are all great.
-24.31656265258789

Most likely:   All cats are great, in my opinion.


In [61]:
sorted(set(first_words))

['', '>>es<<in', 'All', 'Are', 'Cats', 'Great.', 'My', 'Opinion,', 'The']

In [54]:
sorted(((score, result) for score, result in zip(scores, results)), reverse=True)

[(-0.765, ' All cats are great, in my opinion.'),
 (-0.765, ' All cats are great, in my opinion.'),
 (-1.051, ' The cats are all great, in my opinion.'),
 (-1.715, ' cats are great, in my opinion.'),
 (-2.294, ' Cats cats are great, in my opinion.'),
 (-2.537, ' Opinion, cats are all great.'),
 (-2.681, ' My cats are all great.'),
 (-3.04, ' Great. cats are all great.'),
 (-3.472, ' Are cats great?'),
 (-4.831, ' ⁇  in cats are great.')]

In [55]:
sorted(zip(scores, results))

[(-4.831, ' ⁇  in cats are great.'),
 (-3.472, ' Are cats great?'),
 (-3.04, ' Great. cats are all great.'),
 (-2.681, ' My cats are all great.'),
 (-2.537, ' Opinion, cats are all great.'),
 (-2.294, ' Cats cats are great, in my opinion.'),
 (-1.715, ' cats are great, in my opinion.'),
 (-1.051, ' The cats are all great, in my opinion.'),
 (-0.765, ' All cats are great, in my opinion.'),
 (-0.765, ' All cats are great, in my opinion.')]

In [51]:
[(scores[ind], results[ind]) for ind in torch.tensor(scores).topk(len(scores)).indices]

[(-0.765, ' All cats are great, in my opinion.'),
 (-0.765, ' All cats are great, in my opinion.'),
 (-1.051, ' The cats are all great, in my opinion.'),
 (-1.715, ' cats are great, in my opinion.'),
 (-2.294, ' Cats cats are great, in my opinion.'),
 (-2.537, ' Opinion, cats are all great.'),
 (-2.681, ' My cats are all great.'),
 (-3.04, ' Great. cats are all great.'),
 (-3.472, ' Are cats great?'),
 (-4.831, ' ⁇  in cats are great.')]

In [22]:


tokenizer = en_ROMANCE_tokenizer
model = en_ROMANCE
#tokenizer = ROMANCE_en_tokenizer
#model = ROMANCE_en


english = ">>es<<It was a dark and stormy night"
batch = tokenizer.prepare_translation_batch([english])
english_encoded = model.get_encoder()(**batch)
decoder_start_token = model.config.decoder_start_token_id

starting_word = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("La"))
#starting_word = tokenizer.encode("A la")

partial_decode = torch.LongTensor([decoder_start_token]).unsqueeze(0)
past = (english_encoded, None)

prefix = torch.LongTensor(starting_word)
#prefix = torch.LongTensor(tokenizer.convert_tokens_to_ids("A la".replace(' ', '▁').split('▁')))
#prefix = torch.LongTensor(tokenizer.encode("A la"))
#prefix = torch.LongTensor(tokenizer.encode("A▁la▁"))
next_token_to_add = torch.tensor(1)
x = 0

prediction_list = []

while next_token_to_add.item() != 0 and x < 100:
    print(partial_decode)
    model_inputs = model.prepare_inputs_for_generation(
    partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
    )
    with torch.no_grad():
        model_outputs = model(**model_inputs)

    next_token_logits = model_outputs[0][:, -1, :]
    past = model_outputs[1]
    
    if x < len(prefix):
        next_token_to_add = prefix[x]
        print(x, ": ", next_token_to_add)
    else:
        next_token_to_add = next_token_logits[0].argmax()

    decoded_predictions = []
    for tok in next_token_logits[0].topk(5).indices:
        decoded_predictions.append(tokenizer.convert_ids_to_tokens(tok.item()).replace('\u2581', '\u00a0'))
    
    prediction_list.append(decoded_predictions)
    
    partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
    x += 1
    
decoded_tokens = tokenizer.convert_ids_to_tokens(partial_decode[0])

decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in tokenizer.convert_ids_to_tokens(partial_decode[0])] 

final = tokenizer.decode(partial_decode[0]).split("<pad>")[1]
print(final)
print(decoded_tokens)
print(prediction_list)

batch2 = ROMANCE_en_tokenizer.prepare_translation_batch([">>en<< " + final])
spanish_to_english = ROMANCE_en.generate(**batch2)
new_english = ROMANCE_en_tokenizer.decode(spanish_to_english[0]).split("<pad>")[1]

new_english

tensor([[65000]])
0 :  tensor(100)
tensor([[65000,   100]])
tensor([[65000,   100,  4997]])
tensor([[65000,   100,  4997,   531]])
tensor([[65000,   100,  4997,   531,   231]])
tensor([[65000,   100,  4997,   531,   231, 11023]])
tensor([[65000,   100,  4997,   531,   231, 11023,    32]])
tensor([[65000,   100,  4997,   531,   231, 11023,    32,   805]])
tensor([[65000,   100,  4997,   531,   231, 11023,    32,   805,  6386]])
tensor([[65000,   100,  4997,   531,   231, 11023,    32,   805,  6386,   570]])
tensor([[65000,   100,  4997,   531,   231, 11023,    32,   805,  6386,   570,
          3301]])
tensor([[65000,   100,  4997,   531,   231, 11023,    32,   805,  6386,   570,
          3301,     3]])
 La noche era oscura y tempestuosa.
['<pad>', '\xa0La', '\xa0noche', '\xa0era', '\xa0os', 'cura', '\xa0y', '\xa0tem', 'pes', 'tu', 'osa', '.', '</s>']
[['\xa0Era', '\xa0Fue', '\xa0Es', '\xa0Ha', '\xa0Una'], ['\xa0noche', '\xa0tarde', '\xa0oscuridad', '\xa0verdad', '\xa0vela'], ['\xa0era

' The night was dark and stormy.'

In [23]:
import random

In [24]:
tokenizer = en_ROMANCE_tokenizer
model = en_ROMANCE

english = ">>es<<It was a dark and stormy night"
batch = tokenizer.prepare_translation_batch([english])
english_encoded = model.get_encoder()(**batch)
decoder_start_token = model.config.decoder_start_token_id

starting_word = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(""))
prefix = torch.LongTensor(starting_word)
prediction_list = []
alternatives = []
for x in range(0, 5):

    past = (english_encoded, None)
   
    partial_decode = torch.LongTensor([decoder_start_token]).unsqueeze(0)
    y = 0
    next_token_to_add = torch.tensor(1)
    while next_token_to_add.item() != 0 and y < 100:    
        model_inputs = model.prepare_inputs_for_generation(
        partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=model.config.use_cache
        )
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        next_token_logits = model_outputs[0][:, -1, :]
        past = model_outputs[1]

        if y < len(prefix):
            next_token_to_add = prefix[y]
        else:
            next_token_to_add = next_token_logits[0].topk(4).indices[random.randint(0, 1)]
        decoded_predictions = []
        for tok in next_token_logits[0].topk(5).indices:
            decoded_predictions.append(tokenizer.convert_ids_to_tokens(tok.item()).replace('\u2581', '\u00a0'))

        prediction_list.append(decoded_predictions)
        partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
        y += 1

    decoded_tokens = tokenizer.convert_ids_to_tokens(partial_decode[0])

    final = tokenizer.decode(partial_decode[0]).split("<pad>")[1]
    print(final)
    alternatives.append(final)
    
alternatives

 Fue un día oscura, tempestuoso
 Era un día oscuro, tempunático
 Era una oscuridad y una noche tormenta.
 Era un noche sombrías, tempunática
 Fue un día oscura y tormenta.


[' Fue un día oscura, tempestuoso',
 ' Era un día oscuro, tempunático',
 ' Era una oscuridad y una noche tormenta.',
 ' Era un noche sombrías, tempunática',
 ' Fue un día oscura y tormenta.']

In [25]:
tokenizer = en_ROMANCE_tokenizer
model = en_ROMANCE

english = ">>es<<Don't be so mean."
batch = tokenizer.prepare_translation_batch([english])
                                            
eng_to_spanish = model.generate(num_beams=50, num_return_sequences=50, **batch)

translations = []
for x in range(0, 50):
    translations.append(tokenizer.decode(eng_to_spanish[x]).split("<pad>")[1])

translations

[' No seas tan malo.',
 ' No seas tan mala.',
 ' No seas tan malvado.',
 ' No seas tan cruel.',
 ' No seas tan malvada.',
 ' No seas tan mezquino.',
 ' No seas tan desagradable.',
 ' No seas tan mezquina.',
 ' No seas malvado.',
 ' No seáis tan malos.',
 ' No seas tan malos.',
 ' No seas tan repugnante.',
 ' No seas tan miserable.',
 ' No seas malvada.',
 ' No seáis tan malvados.',
 ' No seas muy mala.',
 ' - No seas tan malo.',
 ' - No seas tan mala.',
 ' No seas muy malo.',
 ' ¡No seas tan mala!',
 ' ¡No seas tan malo!',
 ' - No seas tan malvado.',
 ' No sean tan malos.',
 ' No seas así de malo.',
 ' No seas muy malvado.',
 ' No te pongas tan mal.',
 ' No seas tan mesquino.',
 ' No sea tan malo.',
 ' No te pongas mal.',
 ' No seas así de mala.',
 ' No seas tan méchanta.',
 ' No seas así de malvado.',
 ' ¡No seas tan malvada!',
 ' No seas tan mesquita.',
 ' No sea tan malvado.',
 ' No seas mala.',
 ' ¡No seas tan malvado!',
 ' - No seas tan malvada.',
 ' No sean tan malvados.',
 ' - N

In [26]:
english = ">>es<<I think maybe this is the wrong store."
batch = tokenizer.prepare_translation_batch([english])
                                            
spanish = model.generate(**batch)

batch = ROMANCE_en_tokenizer.prepare_translation_batch([">>en<<" + tokenizer.decode(spanish[0]).replace("<pad>", '')])
back_to_english = ROMANCE_en.generate(**batch)
machine_translation = ROMANCE_en_tokenizer.decode(back_to_english[0]).replace("<pad>", '')

spanish_encoded = ROMANCE_en.get_encoder()(**batch)
decoder_start_token = ROMANCE_en.config.decoder_start_token_id
# pylint: disable=E1101
partial_decode = torch.LongTensor([decoder_start_token]).unsqueeze(0)
past = (spanish_encoded, None)
# pylint: enable=E1101
next_token_to_add = torch.tensor(1)
x = 0

prediction_list = []

while next_token_to_add.item() != 0 and x < 100:
    model_inputs = ROMANCE_en.prepare_inputs_for_generation(
    partial_decode, past=past, attention_mask=batch['attention_mask'], use_cache=ROMANCE_en.config.use_cache
    )
    with torch.no_grad():
        model_outputs = ROMANCE_en(**model_inputs)

    next_token_logits = model_outputs[0][:, -1, :]
    past = model_outputs[1]
   
    if x == 3:
        next_token_to_add = next_token_logits[0].topk(4).indices[1]
    else:
        next_token_to_add= next_token_logits[0].argmax()
    decoded_predictions = []
    for tok in next_token_logits[0].topk(10).indices:
        decoded_predictions.append(ROMANCE_en_tokenizer.convert_ids_to_tokens(tok.item()).replace('\u2581', '\u00a0'))

    prediction_list.append(decoded_predictions)

    partial_decode = torch.cat((partial_decode, next_token_to_add.unsqueeze(0).unsqueeze(0)), -1)
    x+= 1

decoded_tokens = [sub.replace('\u2581', '\u00a0') for sub in ROMANCE_en_tokenizer.convert_ids_to_tokens(partial_decode[0])]
decoded_tokens.remove("<pad>")
final = ROMANCE_en_tokenizer.decode(partial_decode[0]).split("<pad>")[1]

final

'. I believe this is the wrong store.'