In [0]:
# Fetch the OpenAI GPT-2 reference code into ./gpt-2; later cells `cd` into it and import from gpt-2/src.
# NOTE(review): the clone is unpinned, and the next cell expects `download_model.sh` at the repo root —
# confirm the current upstream HEAD still ships that script, or pin a known-good commit.
!git clone https://github.com/openai/gpt-2.git

In [0]:
# Run the repo's download script for the 117M model from inside the clone.
# Presumably this populates gpt-2/models/117M, which a later cell moves to ./models — TODO confirm.
!cd gpt-2 && sh download_model.sh 117M

In [0]:
# Install the Python packages listed in gpt-2/requirements.txt.
# NOTE(review): `!pip3` runs whichever pip3 is first on PATH, which may not be the kernel's environment;
# `%pip install -r gpt-2/requirements.txt` is the usual notebook form — confirm it is available here.
!cd gpt-2 && pip3 install -r requirements.txt

In [0]:
# Move the models directory into the working directory: interact_model() opens
# 'models/<model_name>/hparams.json' and looks up the checkpoint under 'models/<model_name>'.
# NOTE(review): not safe to re-run — once ./models exists, a second `mv` either fails (source gone)
# or, if gpt-2/models was re-created, presumably nests it as models/models. Verify before re-running.
!mv gpt-2/models models

In [0]:
import sys
import fire
import json
import os
import numpy as np
import tensorflow as tf
import textwrap

sys.path.insert(0, './gpt-2/src')
import model, sample, encoder, generate_unconditional_samples

In [0]:
def interact_model(
    model_name='117M',
    seed=None,
    nsamples=1,
    batch_size=None,
    length=None,
    temperature=1,
    top_k=0,
    raw_text='test',
):
    """Sample continuations of ``raw_text`` from a GPT-2 checkpoint and print them.

    Args:
        model_name: Sub-directory of ``models/`` that holds ``hparams.json`` and
            the checkpoint to restore.
        seed: Value passed to ``np.random.seed`` and ``tf.set_random_seed``.
        nsamples: Total number of samples to print; must be a multiple of
            ``batch_size``.
        batch_size: Number of samples produced per ``sess.run``; ``None`` means 1.
        length: Passed to ``sample.sample_sequence`` as the generation length;
            ``None`` means ``hparams.n_ctx // 2``.
        temperature: Forwarded unchanged to ``sample.sample_sequence``.
        top_k: Forwarded unchanged to ``sample.sample_sequence``.
        raw_text: Prompt; it is encoded once and fed as the context of every
            row in the batch.

    Raises:
        AssertionError: If ``nsamples`` is not a multiple of ``batch_size``.
        ValueError: If ``length`` is larger than ``hparams.n_ctx``.
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0
    np.random.seed(seed)

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        # The graph-level seed belongs to whichever graph is the default when
        # tf.set_random_seed is called. The fresh tf.Graph() only becomes the
        # default inside this block, so seed here, before any op is created;
        # seeding earlier would target the outer graph and leave sampling unseeded.
        tf.set_random_seed(seed)

        context = tf.placeholder(tf.int32, [batch_size, None])
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        context_tokens = enc.encode(raw_text)
        # The same prompt is fed to every row; build the batch once, outside the loop.
        context_batch = [context_tokens for _ in range(batch_size)]
        generated = 0
        for _ in range(nsamples // batch_size):
            # Slice off the prompt tokens so only the newly generated part is decoded.
            out = sess.run(output, feed_dict={
                context: context_batch
            })[:, len(context_tokens):]
            for i in range(batch_size):
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(textwrap.fill(text, 150))
        print("=" * 80)

### Generate text based on a conditioning text

In [0]:
# The conditioning prompt is quoted from 'The Onion'; request three samples with length=90.
interact_model(
    raw_text="NEW YORK—At 4:32 p.m. Tuesday, every single resident of New York City decided to evacuate the famed metropolis, having realized it was nothing more than a massive, trash-ridden hellhole that slowly sucks the life out of every one of its inhabitants.",
    nsamples=3,
    length=90,
)