In [1]:
!pip install pytorch_pretrained_bert



In [17]:
"""Try to predict a single masked-out word"""

import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM

# Tokenizer/vocabulary for the same checkpoint the model is loaded from below.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

# Two-sentence input: sentence A opens with [CLS]; each sentence closes with [SEP].
text1 = "[CLS] And the riot squad they're restless, they need somewhere to go [SEP]"
text2 = "As Lady and I look out tonight, from Desolation Row. [SEP]"
tokenized_text1 = tokenizer.tokenize(text1)
tokenized_text2 = tokenizer.tokenize(text2)
tokenized_text = [*tokenized_text1, *tokenized_text2]

# Hide "tonight": find it inside sentence B, then shift past sentence A's tokens.
masked_index = len(tokenized_text1) + tokenized_text2.index('tonight')
tokenized_text[masked_index] = '[MASK]'

indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

# Segment ids: 0 for every sentence-A token, 1 for every sentence-B token.
segments_ids = len(tokenized_text1) * [0] + len(tokenized_text2) * [1]

# Wrap both id lists as a batch of one.
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

model = BertForMaskedLM.from_pretrained('bert-large-uncased')
model.eval()  # inference mode (e.g. dropout disabled)

# Weights and inputs must live on the same device.
if torch.cuda.is_available():
    model.to('cuda')
    tokens_tensor = tokens_tensor.to('cuda')
    segments_tensors = segments_tensors.to('cuda')

with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)

# Report the highest-scoring vocabulary entry at the masked slot. This only
# prints the model's guess; it is not compared against the hidden word.
predicted_index = predictions[0, masked_index].argmax().item()
predicted_token, = tokenizer.convert_ids_to_tokens([predicted_index])
print("\n", "Predicted [MASK] = ", predicted_token, "\n")


 Predicted [MASK] =  together 



In [18]:
""" Try to generate from BERT """

import torch
import torch.nn.functional as F
from torch.distributions import Categorical
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM

# MASK is the token written into positions to be filled; MASK_ATOM is how a
# mask is expected to appear in tokenizer output. They currently hold the same
# string, so replacing one with the other is a no-op.
MASK = MASK_ATOM = "[MASK]"

def preprocess(tokens, tokenizer):
    """Map ``tokens`` to vocabulary ids and wrap them as a single-row batch.

    Args:
        tokens: list of token strings.
        tokenizer: object providing ``convert_tokens_to_ids(list[str])``.

    Returns:
        A 2-D tensor with one row (batch size 1), created on torch's default
        device; callers must move it if the model lives elsewhere.
    """
    return torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])


def get_mask_ids(masking):
    """Parse a comma-separated string of token positions, e.g. "0,3" -> [0, 3].

    A falsy ``masking`` (None or "") yields an empty list. A new list is built
    on every call, so callers may safely append to the result.
    """
    if not masking:
        return []
    return [int(field) for field in masking.split(',')]

  
def get_seed_sent(seed_sentence, tokenizer, masking=None, n_append_mask=0):
    """Build the starting token sequence to decode from, possibly with masks.

    Args:
        seed_sentence: raw text; every token the tokenizer returns as exactly
            MASK_ATOM is recorded as a position to fill.
        tokenizer: object providing ``tokenize(str) -> list[str]``.
        masking: optional comma-separated string of extra token indices to
            overwrite with MASK (parsed by ``get_mask_ids``).
        n_append_mask: number of MASK tokens appended after the seed tokens.

    Returns:
        (toks, seg_tensor, mask_ids): the token list; a tensor of shape
        (1, len(toks)) filled with 0 (every token in segment A); and the
        sorted, de-duplicated masked positions.
    """
    mask_ids = get_mask_ids(masking)

    # Tokenize, then record wherever a mask token survived tokenization.
    toks = tokenizer.tokenize(seed_sentence.replace(MASK, MASK_ATOM))
    mask_ids.extend(pos for pos, tok in enumerate(toks) if tok == MASK_ATOM)

    # Overwrite every requested position (explicit or found) with MASK.
    for pos in mask_ids:
        toks[pos] = MASK

    # Optionally grow the sequence with trailing masks.
    first_new = len(toks)
    mask_ids.extend(range(first_new, first_new + n_append_mask))
    toks.extend([MASK] * n_append_mask)

    mask_ids = sorted(set(mask_ids))
    seg_tensor = torch.tensor([[0] * len(toks)])
    return toks, seg_tensor, mask_ids

  
def load_model(version):
    """Fetch the ``BertForMaskedLM`` checkpoint named ``version`` for inference.

    The model is switched to eval mode (e.g. dropout disabled) before it is
    returned.
    """
    masked_lm = BertForMaskedLM.from_pretrained(version)
    masked_lm.eval()
    return masked_lm


def predict(model, tokenizer, tok_tensor, seg_tensor, how_select="argmax", k=3):
    """ Get model predictions and convert back to tokens.

    Runs one forward pass and chooses a token id for every position of batch
    element 0.

    Args:
        model: callable ``model(tok_tensor, seg_tensor)`` returning
            per-position vocabulary scores; the indexing below assumes shape
            (batch, seq_len, vocab) -- TODO confirm for other model classes.
        tokenizer: object providing ``convert_ids_to_tokens(list[int])``.
        tok_tensor: batch-first token ids.
        seg_tensor: batch-first segment ids. ``tok_tensor`` is moved to this
            tensor's device so both inputs agree with a model that was placed
            alongside ``seg_tensor``.
        how_select: "argmax" (greedy), "sample" (sample from the full softmax)
            or "topk" (sample among the ``k`` highest-scoring tokens).
        k: number of candidates kept by "topk" (default 3, the value that was
            previously hard-coded).

    Returns:
        list[str]: one predicted token per position.

    Raises:
        NotImplementedError: if ``how_select`` is not a known strategy.
    """
    # preprocess() builds tok_tensor on the CPU; when the model and seg_tensor
    # have been moved to CUDA the forward pass would fail on a device mismatch.
    tok_tensor = tok_tensor.to(seg_tensor.device)

    # Inference only: don't build an autograd graph on every decoding step.
    with torch.no_grad():
        preds = model(tok_tensor, seg_tensor)

    if how_select == "sample":
        dist = Categorical(logits=F.log_softmax(preds[0], dim=-1))
        pred_idxs = dist.sample().tolist()
    elif how_select == "topk":
        kth_vals, kth_idx = F.log_softmax(preds[0], dim=-1).topk(k, dim=-1)
        # Sample which of the k candidates to keep at each position, then map
        # that candidate slot back to its vocabulary id.
        dist = Categorical(logits=kth_vals)
        pred_idxs = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1).tolist()
    elif how_select == "argmax":
        pred_idxs = preds.argmax(dim=-1).tolist()[0]
    else:
        raise NotImplementedError("Prediction procedure %s not found!" % how_select)

    return tokenizer.convert_ids_to_tokens(pred_idxs)

def sequential_decoding(toks, seg_tensor, model, tokenizer, selection_strategy):
    """Re-predict every position left to right, one forward pass per token.

    At step ``pos`` the whole current sequence is re-encoded and only the
    prediction for ``pos`` is kept, so later steps see earlier replacements.
    Every position is overwritten, special tokens included. ``toks`` is
    modified in place and also returned.
    """
    n_steps = len(toks)
    for pos in range(n_steps):
        current_ids = preprocess(toks, tokenizer)
        guesses = predict(model, tokenizer, current_ids, seg_tensor, selection_strategy)
        toks[pos] = guesses[pos]
    return toks

def masked_decoding(toks, seg_tensor, masks, model, tokenizer, selection_strategy):
    """Fill the positions listed in ``masks`` one at a time, in list order.

    Each fill re-encodes the current sequence, so later masks are predicted
    with earlier fills already in place. ``toks`` is modified in place and
    also returned.
    """
    for position in masks:
        current_ids = preprocess(toks, tokenizer)
        guesses = predict(model, tokenizer, current_ids, seg_tensor, selection_strategy)
        toks[position] = guesses[position]
    return toks

  
def detokenize(pred_toks):
    """ Return the detokenized lyric prediction.

    Glues WordPiece continuation pieces (tokens starting with "##") onto the
    preceding word, e.g. ["sing", "##ing"] -> ["singing"]. All other tokens,
    including special tokens such as "[CLS]"/"[SEP]", are kept unchanged.

    Args:
        pred_toks: iterable of token strings.

    Returns:
        list[str]: one entry per reconstructed word.
    """
    new_sent = []
    for tok in pred_toks:
        if not tok.startswith("##"):
            new_sent.append(tok)
            continue
        piece = tok[2:]
        if new_sent:
            new_sent[-1] += piece
        else:
            # A continuation piece with nothing before it (e.g. predicted at
            # position 0) previously raised IndexError; keep it as a word.
            new_sent.append(piece)
    return new_sent
  
def main(seed_text="[CLS] Sing with me, Sing for the years [SEP] [MASK] [MASK] [MASK] [MASK] , [MASK] [MASK] [MASK] tears. [SEP]",
         version='bert-large-uncased',
         selection_strategy="argmax",
         decoding="masked"):
    """Fill a seed lyric with BERT predictions and print the result.

    All defaults reproduce the original hard-coded run.

    Args:
        seed_text: text whose literal "[MASK]" placeholders are filled.
        version: pretrained checkpoint name, used for BOTH the tokenizer and
            the model so their vocabularies match.
        selection_strategy: forwarded to ``predict`` ("argmax", "sample",
            "topk").
        decoding: "masked" fills only the mask positions; "sequential"
            re-predicts every position left to right.

    Raises:
        ValueError: if ``decoding`` is not "masked" or "sequential" (checked
            before any model download).
    """
    if decoding not in ("masked", "sequential"):
        raise ValueError("Unknown decoding %r; expected 'masked' or 'sequential'" % (decoding,))

    tokenizer = BertTokenizer.from_pretrained(version)
    model = load_model(version)

    toks, seg_tensor, mask_ids = get_seed_sent(seed_text,
                                               tokenizer,
                                               masking=None,
                                               n_append_mask=0)

    # Keep the model and its segment-id input together on the GPU if present.
    if torch.cuda.is_available():
        seg_tensor = seg_tensor.to('cuda')
        model = model.to('cuda')

    if decoding == "sequential":
        pred_toks = sequential_decoding(toks, seg_tensor, model, tokenizer, selection_strategy)
    else:
        pred_toks = masked_decoding(toks, seg_tensor, mask_ids, model, tokenizer, selection_strategy)

    pred_lyric = detokenize(pred_toks)

    print("\nFinal: %s" % (" ".join(pred_lyric)), "\n")


# In a notebook kernel __name__ is '__main__', so executing this cell runs main().
if __name__ == '__main__':
    main()


Final: [CLS] sing with me , sing for the years [SEP] sing for the years , cry no more tears . [SEP] 

