In [1]:
# !pip install pytorch_pretrained_bert

In [2]:
import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [3]:
from transformers import AutoModelWithLMHead

In [4]:
# Load pre-trained model (weights)

# If you just want to produce phrases with a base case use the following as your model_version
model_version = 'bert-base-uncased'

# Else, provide a path to a fine-tuned model, which includes the following files:
# pytorch_model.bin
# config.json
# vocab.txt
# special_tokens_map.txt
# tokenizer_config.txt
# training_args.bin
# instructions to fine tune are in the *insert_path_here*
model_name = ""
path_to_finetuned = "..\\fine_tune_model\\model_name\\"  # double \\ is for windows, remove one for each in mac.
model = BertForMaskedLM.from_pretrained(path_to_finetuned)
model.eval()

# use CUDA if available. should be mentioned in readme for version
cuda = torch.cuda.is_available()
if cuda:
    model = model.cuda()

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained(path_to_finetuned, do_lower_case=True)

def tokenize_batch(batch):
    return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]

def untokenize_batch(batch):
    return [tokenizer.convert_ids_to_tokens(sent) for sent in batch]

def detokenize(sent):
    """ Roughly detokenizes (mainly undoes wordpiece) """
    new_sent = []
    for i, tok in enumerate(sent):
        if tok.startswith("##"):
            new_sent[len(new_sent) - 1] = new_sent[len(new_sent) - 1] + tok[2:]
        else:
            new_sent.append(tok)
    return new_sent

CLS = '[CLS]'
SEP = '[SEP]'
MASK = '[MASK]'
mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]
sep_id = tokenizer.convert_tokens_to_ids([SEP])[0]
cls_id = tokenizer.convert_tokens_to_ids([CLS])[0]

In [5]:
def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True):
    """ Generate a word from from out[gen_idx]
    
    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k 
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample().squeeze(-1)
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx
  
  
def get_init_text(seed_text, max_len, batch_size = 1, rand_init=False):
    """ Get initial sentence by padding seed_text with either masks or random words to max_len """
    batch = [seed_text + [MASK] * max_len + [SEP] for _ in range(batch_size)]
    #if rand_init:
    #    for ii in range(max_len):
    #        init_idx[seed_len+ii] = np.random.randint(0, len(tokenizer.vocab))
    
    return tokenize_batch(batch)

def printer(sent, should_detokenize=True):
    if should_detokenize:
        sent = detokenize(sent)[1:-1]
    print(" ".join(sent))


This is the meat of the algorithm. The general idea is
1. start from all masks
2. repeatedly pick a location, mask the token at that location, and generate from the probability distribution given by BERT
3. stop when converged or tired of waiting

We consider three "modes" of generating:
- generate a single token for a position chosen uniformly at random for a chosen number of time steps
- generate in sequential order (L->R), one token at a time
- generate for all positions at once for a chosen number of time steps

The `generate` function wraps and batches these three generation modes. In practice, we find that the first leads to the most fluent samples.

In [6]:
# Generation modes as functions
import math
import time

def parallel_sequential_generation(seed_text, batch_size=10, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,
                                   cuda=False, print_every=10, verbose=True):
    """ Generate for one random position at a timestep
    
    args:
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax
    """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_iter):
        kk = np.random.randint(0, max_len)
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = mask_id
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
        out = model(inp)
        topk = top_k if (ii >= burnin) else 0
        idxs = generate_step(out, gen_idx=seed_len+kk, top_k=topk, temperature=temperature, sample=(ii < burnin))
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = idxs[jj]
            
        if verbose and np.mod(ii+1, print_every) == 0:
            for_print = tokenizer.convert_ids_to_tokens(batch[0])
            for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
            print("iter", ii+1, " ".join(for_print))
            
    return untokenize_batch(batch)

def parallel_generation(seed_text, batch_size=10, max_len=15, top_k=0, temperature=None, max_iter=300, sample=True, 
                        cuda=False, print_every=10, verbose=True):
    """ Generate for all positions at each time step """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_iter):
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
        out = model(inp)
        for kk in range(max_len):
            idxs = generate_step(out, gen_idx=seed_len+kk, top_k=top_k, temperature=temperature, sample=sample)
            for jj in range(batch_size):
                batch[jj][seed_len+kk] = idxs[jj]
            
        if verbose and np.mod(ii, print_every) == 0:
            print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(batch[0])))
    
    return untokenize_batch(batch)
            
def sequential_generation(seed_text, batch_size=10, max_len=15, leed_out_len=15, 
                          top_k=0, temperature=None, sample=True, cuda=False):
    """ Generate one word at a time, in L->R order """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_len):
#         print(ii)
        inp = [sent[:seed_len+ii+leed_out_len]+[sep_id] for sent in batch]
#         print(inp)
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
#         print(inp)
        out = model(inp)
#         print(out)
        idxs = generate_step(out, gen_idx=seed_len+ii, top_k=top_k, temperature=temperature, sample=sample)
#         print(idxs)
        for jj in range(batch_size):
            batch[jj][seed_len+ii] = idxs[jj]
        
    return untokenize_batch(batch)


def generate(n_samples, seed_text="[CLS]", batch_size=10, max_len=25, 
             generation_mode="parallel-sequential",
             sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,
             cuda=False, print_every=1):
    # main generation function to call
    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        if generation_mode == "parallel-sequential":
            batch = parallel_sequential_generation(seed_text, batch_size=batch_size, max_len=max_len, top_k=top_k,
                                                   temperature=temperature, burnin=burnin, max_iter=max_iter, 
                                                   cuda=cuda, verbose=False)
        elif generation_mode == "sequential":
            batch = sequential_generation(seed_text, batch_size=batch_size, max_len=max_len, top_k=top_k, 
                                          temperature=temperature, leed_out_len=leed_out_len, sample=sample,
                                          cuda=cuda)
        elif generation_mode == "parallel":
            batch = parallel_generation(seed_text, batch_size=batch_size,
                                        max_len=max_len, top_k=top_k, temperature=temperature, 
                                        sample=sample, max_iter=max_iter, 
                                        cuda=cuda, verbose=False)
        
        if (batch_n + 1) % print_every == 0:
            print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time))
            start_time = time.time()
        
        sentences += batch
    return sentences

Let's call the actual generation function! We'll use the following settings
- max_len (40): length of sequence to generate
- top_k (100): at each step, sample from the top_k most likely words
- temperature (1.0): smoothing parameter for the next word distribution. Higher means more like uniform; lower means more peaky
- burnin (250): for non-sequential generation, for the first burnin steps, sample from the entire next word distribution, instead of top_k
- max_iter (500): number of iterations to run for
- seed_text (["CLS"]): prefix to generate for. We found it crucial to start with the CLS token; you can try adding to it 

In [20]:
n_samples = 10
batch_size = 10
max_len = 20
top_k = 75
temperature = 1.25
generation_mode = "sequential"
leed_out_len = 5 # max_len
burnin = 500
sample = False
max_iter = 1000

# Choose the prefix context
seed_text = "[CLS] my friends".split(' ')
bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                      generation_mode=generation_mode,
                      sample=sample, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter,
                      cuda=cuda)

Finished batch 1 in 1.229s


In [21]:
seed_text

['[CLS]', 'my', 'friends']

In [22]:
# Create examples:

for sent in bert_sents:
    printer(sent, should_detokenize=True)

my friends ran on to that crowd then got ' em out with me after that boy , i just met girl
my friends had no , yeah on the way in the morning , yeah , you go to your fucking room ,
my friends at blue road and down the road then gone , gone , gone , gone , gone , gone ,
my friends ' i get better , as my stars are out with you ' love my best , hard work ,
my friends in the town were a good shot through these spirits , and a full bottle of a man around ?
my friends ' gone down , so oh - so dear as you were ' love me , that i know but
my friends feel and remember ' re in love , i say , to my friends that you have me back now
my friends of you sees me ill , and the man , the bad fellow he ever had he was .
my friends all won a hard time , oh oh oh uh and my hair do so wonder where going that .
my friends ' asse at the bottom of a dirty hat there come you run all me roads to glory ride
