In [199]:
import numpy
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [None]:
# Load the pre-trained masked-language-model weights for the
# 'bert-large-uncased' checkpoint.
model = BertForMaskedLM.from_pretrained('bert-large-uncased')
# Put the module in evaluation mode; the cells visible in this notebook only
# call the model for prediction, never for training.
model.eval()

In [201]:
# Load the pre-trained tokenizer (vocabulary) that belongs to the same
# 'bert-large-uncased' checkpoint as the model.
# do_lower_case=True asks the tokenizer to lower-case its input, matching the
# "uncased" checkpoint name.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)

In [225]:
import copy

# Mask-and-predict sanity check: hide one token at a time and print what the
# masked language model proposes for that position.
#
# The sentence is split with the BERT tokenizer itself rather than str.split(),
# so every token is a vocabulary entry that convert_tokens_to_ids can map
# (the tokenizer was created with do_lower_case=True, so it also lower-cases).
original_sent = tokenizer.tokenize('seoul is the capital of south korea .')

# NOTE(review): no [CLS]/[SEP] markers are placed around the sentence, exactly
# as in the original experiment -- confirm whether that is intentional.
with torch.no_grad():  # inference only: do not build an autograd graph
    for ii in range(len(original_sent)):
        new_sent = copy.copy(original_sent)
        new_sent[ii] = '[MASK]'
        # Alternative corruption: put a random vocabulary token at position ii
        # instead of [MASK] (uncomment the next line to try it).
        # new_sent[ii] = tokenizer.convert_ids_to_tokens([numpy.random.randint(0, len(tokenizer.vocab))])[0]
        out = model(torch.tensor([tokenizer.convert_tokens_to_ids(new_sent)]))
        # out[0] is the only sequence in the batch; pick the highest-scoring
        # vocabulary id at the corrupted position and map it back to a token.
        pred = tokenizer.convert_ids_to_tokens([out[0][ii].max(0)[1].item()])[0]
        print(" ".join(new_sent), "=>", pred)

shields is the capital of south korea . => is
seoul ##amp the capital of south korea . => is
seoul is ##6 capital of south korea . => the
seoul is the tack of south korea . => tack
seoul is the capital urgency south korea . => of
seoul is the capital of 900 korea . => 900
seoul is the capital of south ##ר . => .
seoul is the capital of south korea gilded => .


['way']

In [204]:
''' sequential generation '''

# Left-to-right generation: starting from a seed phrase followed by [MASK]
# placeholders, fill the placeholders one position at a time.

sample = True            # True: sample from the predicted distribution; False: take the arg-max
max_len = 20             # number of tokens to generate after the seed
# How many not-yet-generated positions (the current one included) are shown to
# the model at each step; the input is cut off after this window.
leed_out_len = 3 #max_len
random_future = False    # True: initialise future positions with random tokens instead of [MASK]

seed_text = 'the meaning of life is'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * max_len
init_idx = tokenizer.convert_tokens_to_ids(init_text)
if random_future:
    for ii in range(max_len):
        init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

with torch.no_grad():  # inference only: do not build an autograd graph
    for ii in range(max_len):
        pos = seed_len + ii  # index of the token being generated in this step
        # Feed the prefix generated so far plus the look-ahead window.
        out = model(torch.tensor([init_idx[:pos+leed_out_len]]))
        logits = out[0, pos]
        if sample:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            init_idx[pos] = dist.sample().item()
        else:
            init_idx[pos] = torch.max(logits, 0)[1].item()

print(init_idx)
# Convert once and reuse for both the raw and the merged rendering.
generated = " ".join(tokenizer.convert_ids_to_tokens(init_idx))
print(generated)
# Glue "##"-prefixed continuation pieces back onto the preceding token.
print(generated.replace(" ##", ""))

[1996, 3574, 1997, 2166, 2003, 999, 20703, 1997, 2037, 2279, 2299, 2108, 1012, 1000, 2748, 1012, 2028, 3634, 2809, 4551, 1000, 1011, 2027, 2024, 2315]
the meaning of life is ! refrain of their next song being . " yes . one hundred eight billion " - they are named
the meaning of life is ! refrain of their next song being . " yes . one hundred eight billion " - they are named


In [205]:
''' parallel generation '''

# Parallel generation: start from random tokens after the seed phrase and, in
# every iteration, re-predict ALL generated positions from a single forward
# pass over the current sequence.

sample = True    # True: sample each position; False: take the arg-max
max_iter = 100   # number of full re-prediction sweeps
viz_int = 10     # print the current sequence every viz_int iterations
max_len = 20     # number of tokens to generate after the seed

seed_text = 'the meaning of life is'.split()
seed_len = len(seed_text)

init_text = seed_text + ['[MASK]'] * max_len
init_idx = tokenizer.convert_tokens_to_ids(init_text)
# Overwrite every placeholder with a uniformly random vocabulary id.
for ii in range(max_len):
    init_idx[seed_len+ii] = numpy.random.randint(0, len(tokenizer.vocab))

with torch.no_grad():  # inference only: do not build an autograd graph
    for ii in range(max_iter):
        # One forward pass; all positions below are read from this same output,
        # i.e. they are updated simultaneously rather than one after another.
        out = model(torch.tensor([init_idx]))
        for kk in range(max_len):
            logits = out[0, seed_len+kk]
            if sample:
                dist = torch.distributions.categorical.Categorical(logits=logits)
                init_idx[seed_len+kk] = dist.sample().item()
            else:
                init_idx[seed_len+kk] = torch.max(logits, 0)[1].item()
        if ii % viz_int == 0:
            print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(init_idx)))

iter 1 the meaning of life is a turn 30 percent male milo children for played with 43 " celebrated " " makes ##hop matching " said
iter 11 the meaning of life is that and i and . . to as i , i of that " to and i i " ,
iter 21 the meaning of life is that and i and . . to an i , i of that " to and i i " ,
iter 31 the meaning of life is of and i . . . to an i , i of that " to and and i " ,
iter 41 the meaning of life is of and i . . . to an i , i of that " to and and i " ,
iter 51 the meaning of life is of and i . . . to an i , i of that " to and and that in ,
iter 61 the meaning of life is of , who . . . to of " , i of those , , and and that , ,
iter 71 the meaning of life is to , to . . . to i " , and , and , , and and that , ,
iter 81 the meaning of life is to , to . . . to i " , and , and , and and and they were ,
iter 91 the meaning of life is to , and . . . , and , , and , and , and , . . . ,


In [207]:
''' parallel-sequential generation '''

# Parallel-sequential generation: in every iteration pick ONE random generated
# position, replace it with [MASK], and re-predict only that position given the
# rest of the current sequence.

burnin = 50      # iterations that sample from the distribution; afterwards the arg-max is used
max_iter = 100   # total number of single-position updates
viz_int = 10     # print the current sequence every viz_int iterations
max_len = 5      # number of tokens to generate after the seed

seed_text = 'the meaning of life is'.split()
seed_len = len(seed_text)

# All generated positions start out as [MASK].
init_text = seed_text + ['[MASK]'] * max_len
init_idx = tokenizer.convert_tokens_to_ids(init_text)

# The [MASK] id never changes, so look it up once instead of on every iteration.
mask_id = tokenizer.convert_tokens_to_ids(['[MASK]'])[0]

with torch.no_grad():  # inference only: do not build an autograd graph
    for ii in range(max_iter):
        kk = numpy.random.randint(0, max_len)
        pos = seed_len + kk  # absolute index of the position being resampled
        init_idx[pos] = mask_id
        out = model(torch.tensor([init_idx]))
        logits = out[0, pos]
        if ii < burnin:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            init_idx[pos] = dist.sample().item()
        else:
            init_idx[pos] = torch.max(logits, 0)[1].item()

        if ii % viz_int == 0:
            for_print = tokenizer.convert_ids_to_tokens(init_idx)
            # Insert a "(*)" marker right after the token that was just updated.
            for_print = for_print[:pos+1] + ['(*)'] + for_print[pos+1:]
            print("iter", ii+1, " ".join(for_print))

iter 1 the meaning of life is layla (*) [MASK] [MASK] [MASK] [MASK]
iter 11 the meaning of life is . true - [MASK] was (*)
iter 21 the meaning of life is not corrupted is (*) from the
iter 31 the meaning of life is disco (*) played were not it
iter 41 the meaning of life is ku the (*) and find cholera
iter 51 the meaning of life is in english (*) , f ##gr
iter 61 the meaning of life is both (*) the , and the
iter 71 the meaning of life is both the , and (*) the
iter 81 the meaning of life is about the man (*) and the
iter 91 the meaning of life is about the man and (*) the
