In [1]:
import torch
import numpy as np
from text_utils import TextEncoder
from model_pytorch import *

In [2]:
class CustomLMModel(torch.nn.Module):
    """Transformer body plus language-model head, bundled with its text encoder.

    With ``return_probs`` enabled, ``forward`` returns a softmax over the last
    dimension rather than raw logits. Before the softmax a large negative
    constant is added to the final ``n_ctx`` vocabulary slots, so those slots
    end up with (numerically) zero probability. Judging by the buffer name
    these are presumably the positional-embedding entries - confirm against
    ``TransformerModel``.
    """

    def __init__(self, cfg, vocab=40990, n_ctx=512, return_probs=True,
                 encoder_path='./model/encoder_bpe_40000.json', bpe_path='./model/vocab_40000.bpe'):
        super().__init__()
        # Sub-modules are registered in this order: transformer first, then the head built on it.
        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
        self.lm_head = LMHead(self.transformer, cfg, trunc_and_reshape=False)
        self.return_probs = return_probs
        # The tokenizer travels with the model so callers can encode/decode through it.
        self.text_encoder = TextEncoder(encoder_path, bpe_path)

        # The additive mask is only needed (and only registered) in probability mode.
        if self.return_probs:
            blocked_slots = torch.zeros(1, 1, vocab)
            blocked_slots[..., -n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', blocked_slots)

    def forward(self, x):
        """Return LM-head logits for ``x``, or masked softmax probabilities when ``return_probs`` is set."""
        logits = self.lm_head(self.transformer(x))
        if not self.return_probs:
            return logits
        return F.softmax(logits + self.pos_emb_mask, dim=-1)

In [3]:
# Target device for inference; defined before the load so the checkpoint can be mapped onto it.
device = "cpu"

# NOTE(security): torch.load unpickles arbitrary Python objects - only load checkpoints you trust.
# NOTE(review): recent PyTorch releases default to weights_only=True, which rejects a fully
# pickled module; pass weights_only=False there if this load fails - confirm for your version.
# map_location keeps the load working on a CPU-only machine even if the model was saved from a GPU.
# .eval() disables stochastic layers such as dropout, which must be off while sampling text
# (assumes the checkpoint is a whole pickled nn.Module, as it is called like one below).
lm_model = torch.load('./trained_lm_model', map_location=device).eval()

In [None]:
def make_batch(X, n_vocab):
    """Pair every token id with a positional id and return the result as a tensor.

    Args:
        X: One encoded sequence (1-D) or a batch of equal-length encoded
            sequences (2-D), given as nested lists or an array of token ids.
        n_vocab: Offset of the positional ids; position ``i`` of a sequence is
            numbered ``n_vocab + i``.

    Returns:
        A ``torch.long`` tensor of shape ``(batch, seq_len, 2)`` placed on the
        module-level ``device``. ``[..., 0]`` holds the token ids and
        ``[..., 1]`` the positional ids.

    Raises:
        AssertionError: if ``X`` is neither 1-D nor 2-D.
    """
    tokens = np.array(X)
    assert tokens.ndim in [1, 2], "X must be one sequence (1-D) or a batch of sequences (2-D)"
    if tokens.ndim == 1:
        # Promote a single sequence to a batch of size one.
        tokens = np.expand_dims(tokens, axis=0)

    positions = np.arange(n_vocab, n_vocab + tokens.shape[-1])
    # np.stack requires identically shaped inputs, so repeat the position row for every
    # sequence in the batch; a bare (1, seq_len) row only worked for batch size 1.
    positions = np.broadcast_to(positions, tokens.shape)

    batch = np.stack([tokens, positions], axis=-1)
    return torch.tensor(batch, dtype=torch.long).to(device)

def append_batch(X, next_idx):
    """Extend each sequence in ``X`` by one (token, position) pair.

    ``X`` is indexed as ``(batch, seq, 2)`` with the positional id in channel 1;
    ``next_idx`` supplies the new token id per sequence and is concatenated with
    the new positional id along the last axis (so it is expected to be
    ``(batch, 1)``). The new positional id is the last stored one plus one.
    """
    prev_pos = X[:, -1:, 1]
    token_and_pos = torch.cat([next_idx, prev_pos + 1], dim=-1)
    # Insert a length-1 sequence axis so the new pair lines up with X along dim 1.
    return torch.cat([X, token_and_pos[:, None, :]], dim=1)

In [4]:
def predict_next_word(text, gen_len=20, topk=10):
    """Sample a ``gen_len``-token continuation of ``text`` and print it token by token.

    At each step the module-level ``lm_model`` is run on the whole context,
    the ``topk`` highest-scoring entries of the last position are kept, and
    the next token is drawn from them with ``torch.multinomial`` (weights are
    the returned scores, so the model is assumed to have been built with
    ``return_probs=True`` - confirm). The chosen token is printed with its
    ``'</w>'`` marker removed and appended to the context.

    Args:
        text: Prompt string; it is encoded as a single-sequence batch.
        gen_len: Number of tokens to generate.
        topk: Number of candidates to sample from at each step.

    Returns:
        None. The generated tokens are written to stdout, separated by spaces.
    """
    encoder = lm_model.text_encoder
    n_vocab = len(encoder.encoder)
    context = make_batch(encoder.encode([text,]), n_vocab)

    # Generation is pure inference: switch off dropout for the duration of the call and
    # restore the caller's train/eval mode afterwards, even if a step raises.
    was_training = lm_model.training
    lm_model.eval()
    try:
        # no_grad stops autograd from recording (and retaining) a graph for every step.
        with torch.no_grad():
            for _ in range(gen_len):
                lm_probs = lm_model(context)
                values, indices = lm_probs[:, -1, :].topk(topk)
                # multinomial returns a position within the top-k list; gather maps it back to a vocab id.
                next_idx = indices.gather(-1, torch.multinomial(values, 1))
                # .item() requires exactly one element, i.e. a batch of one prompt.
                next_token = encoder.decoder[next_idx.item()].replace('</w>', '')
                print(next_token, end=' ')
                context = append_batch(context, next_idx)
    finally:
        lm_model.train(was_training)
        

In [6]:
# Demo: print a sampled continuation of a short prompt. Tokens are drawn with
# torch.multinomial, so the printed text can differ from one run to the next.
predict_next_word('this model works')

                                                                                

, and i 'm sure i can help you with that as well . " 
 " what do you 