In [1]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

from typing import List

In [2]:
# Checkpoint selection: lilBERT loads the checkpoint named f'bert-{size}-{case}'.
model_size, model_case = 'base', 'uncased'

In [3]:
class lilBERT:
    """Wrapper around a pre-trained BERT masked-LM that proposes words for ``[MASK]`` slots.

    Parameters
    ----------
    size : str
        Model size used to build the checkpoint name ``bert-{size}-{case}`` (e.g. ``'base'``).
    case : str
        Casing variant used in the checkpoint name (e.g. ``'uncased'``).
    cuda : bool
        If True, the model and every input tensor are placed on the ``'cuda'`` device.
        The default, ``torch.cuda.is_available()``, is evaluated once, when the class
        is defined.
    """

    def __init__(self, size: str = 'base', case: str = 'uncased',
                 cuda: bool = torch.cuda.is_available()):
        # TODO: Load fine-tuned model
        self.toke = BertTokenizer.from_pretrained(f'bert-{size}-{case}')
        self.BERT = BertForMaskedLM.from_pretrained(f'bert-{size}-{case}')
        self.BERT.eval()  # inference only

        if cuda:
            # Module.to moves parameters in place, so the return value need not be kept.
            self.BERT.to('cuda')
        self.cuda = cuda

    def predict_words(self, input_str: str, k: int = 5) -> List[List[str]]:
        """Propose the top-``k`` vocabulary tokens for every ``[MASK]`` in ``input_str``.

        Parameters
        ----------
        input_str : str
            Text containing one or more ``[MASK]`` placeholders. Each newline-separated
            line becomes one ``[SEP]``-terminated segment after a leading ``[CLS]``.
        k : int
            Number of candidates per mask, best first. Clamped to the vocabulary size;
            values <= 0 yield empty candidate lists.

        Returns
        -------
        list of list of str, or None
            One candidate list per ``[MASK]`` token, in order of appearance. If the
            input contains no ``[MASK]``, a message is printed and ``None`` is returned.
        """
        if '[MASK]' not in input_str:
            print('No [MASK] tokens found. Rhyme is complete')
            return None

        # Prep input: "[CLS] line1 [SEP] line2 [SEP] ..."
        text = ' [CLS] ' + ''.join(line + ' [SEP] ' for line in input_str.split('\n'))
        texttoke = self.toke.tokenize(text)
        tokeinds = self.toke.convert_tokens_to_ids(texttoke)
        seg_ids = [0] * len(tokeinds)  # every token belongs to segment A
        maskinds = [i for i, tok in enumerate(texttoke) if '[MASK]' in tok]

        # Run through lil' BERT. Tensor.to() returns a new tensor instead of moving in
        # place, so create the inputs directly on the same device as the model.
        device = 'cuda' if self.cuda else 'cpu'
        tttensor = torch.tensor([tokeinds], device=device)
        segtensor = torch.tensor([seg_ids], device=device)

        with torch.no_grad():
            predictions = self.BERT(tttensor, segtensor)

        # Suggestions to text: highest-scoring vocabulary ids at each masked position.
        proposals = []
        for m in maskinds:
            scores = predictions[0, m]
            n_best = max(0, min(k, scores.size(-1)))
            _, top_ids = torch.topk(scores, n_best)
            proposals.append(self.toke.convert_ids_to_tokens(top_ids.tolist()))

        return proposals


In [4]:
# Load tokenizer + masked-LM weights for the configured checkpoint (CUDA used when available).
lB = lilBERT(model_size, model_case)

INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /Users/ljferrer/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /Users/ljferrer/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
INFO:pytorch_pretrained_bert.modeling:extracting archive file /Users/ljferrer/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /var/folders/vh/3r7px0q57m311fh086mh0zxh0000gn/T/tmpo0qjv2bp
INFO:pytorch_pretrained_bert.modeling:Mode

In [5]:
# Three lines of verse joined by " \n "; each [MASK] is a slot for BERT to fill.
text = ' \n '.join([
    'The cat is [MASK] fat .',
    'She sat [MASK] to [MASK] more food .',
    'She is always [MASK] .',
])

In [6]:
# Top-5 candidates for each [MASK], in order of appearance.
lB.predict_words(text, k=5)

[['too', 'getting', 'very', 'always', 'so'],
 ['down', 'up', 'back', 'there', 'around'],
 ['get', 'eat', 'find', 'make', 'grab'],
 ['hungry', 'fat', 'eating', 'starving', 'happy']]