In [2]:
#!pip3 install pytorch_pretrained_bert
#!pip3 install gpt2

In [10]:
from allennlp.predictors import Predictor
import numpy as np
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer
from pytorch_pretrained_bert.modeling_gpt2 import GPT2LMHeadModel
import torch

In [29]:
SMALL_MODEL = 'gpt2'
MEDIUM_MODEL = 'https://storage.googleapis.com/allennlp/models/gpt2-345M-dump'
#MEDIUM_MODEL = 'https://drive.google.com/file/d/1hp21DmAoeq6tKoUGLEK8NtPRJVWdz_dH/view?usp=sharing'

class Gpt2Predictor(Predictor):
    """
    The HuggingFace implementation of GPT-2 is not an AllenNLP model;
    however, our demo only expects an AllenNLP ``Predictor``. Accordingly,
    we implement a ``Predictor`` that wraps the HuggingFace GPT-2 implementation.
    """
    def __init__(self,
                 model_name: str = MEDIUM_MODEL,
                 cache_size: int = 0) -> None:
        """
        Each cache element is about 8MB, so size accordingly.
        """
        # Cache stores tuples, so default value is a tuple
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained(model_name)

        # The end of text marker.
        self.END_OF_TEXT = self.tokenizer.encoder["<|endoftext|>"]


    def predict_json(self, inputs: dict) -> dict:
        previous_str = inputs["previous"]
        next_str = inputs.get("next")
        topk = inputs.get("topk", 10)

        logits = self._predict(previous_str, next_str)
        probabilities = torch.nn.functional.softmax(logits)

        best_logits, best_indices = logits.topk(topk)
        best_words = [self.tokenizer.decode([idx.item()])
                      for idx in best_indices]
        best_probabilities = probabilities[best_indices].tolist()

        return {
            "logits": best_logits.tolist(),
            "probabilities": best_probabilities,
            "words": best_words,
            "output": previous_str + (next_str or "")
        }

    def _predict(self, previous: str, next: str = None) -> torch.Tensor:

        past_logits, past = (None, None)

        # CASE 1: Previously seen input, no next
        if next is None and past is not None:
            return past_logits

        # CASE 2: Previously seen input, yes next
        elif past is not None:
            token_ids = self.tokenizer.encode(next)
        # CASE 3: Brand new input, no next
        elif next is None:
            token_ids = self.tokenizer.encode(previous)
        # CASE 4: Brand new input, yes next
        else:
            token_ids = self.tokenizer.encode(previous) + self.tokenizer.encode(next)

        inputs = torch.LongTensor([token_ids])

        logits, present = self.model(inputs, past=past)
        logits = logits[0, -1]

        key = previous if next is None else previous + next

        return logits

    def __getitem__(self, index: int) -> str:
        return self.tokenizer.decode([index])

In [30]:
#import gpt2
predictor = Gpt2Predictor()

In [31]:
def predict_text(text_begin, num_of_tokens=100):
    result = text_begin[:]
    for token_id in range(num_of_tokens):
        model_out = predictor.predict_json({"previous": result})
        next_token = np.random.choice(model_out['words'], p=model_out['probabilities'] / np.sum(model_out['probabilities']))
        if next_token == '\n':
            break
        result += next_token
    return result

In [32]:
predict_text("Hi!")

  probabilities = torch.nn.functional.softmax(logits)


"Hi! I am so glad you liked it. I have decided that my first novel will be a horror story called 'The Black Hole', about a girl named Tiana that lives in a black hole and her attempts to free herself and the other people who live there. It will also have a lot of supernatural elements which I have not really been able to get right yet. But hopefully I will soon!"

In [7]:

print(result)

{'logits': [-54.29852294921875, -55.363868713378906, -55.39118957519531, -55.43867492675781, -56.065513610839844, -56.481689453125, -56.60101318359375, -56.847679138183594, -56.959693908691406, -57.21900177001953], 'probabilities': [0.315976083278656, 0.10888809710741043, 0.10595345497131348, 0.10103980451822281, 0.053983356803655624, 0.035605497658252716, 0.03160060569643974, 0.02469276450574398, 0.02207609824836254, 0.017033597454428673], 'words': [',', ' with', ' (', ' and', ' at', ' in', '.', ' over', ' by', ' against'], 'output': 'Toronto Raptors, who are currently tied for the league leader in wins'}
