In [3]:
import sys
import os
import time
import json
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def

# Pin this kernel to a single GPU: only the device with index 2 is exposed
# through CUDA_VISIBLE_DEVICES.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [4]:
# Put the gpt-2 checkout's source folder (resolved against the current working
# directory) on the import path so its modules can be imported below.
sys.path.append(
    os.path.join(os.getcwd(), 'gpt-2/src')
)

In [8]:
import model, sample

def export_for_serving(
    model_name='1558M',
    seed=None,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    models_dir='gpt-2/models'
):
    """
    Export the model for TF Serving.

    Builds the sampling graph with `sample.sample_sequence`, restores the
    latest checkpoint found in `<models_dir>/<model_name>` and writes a
    SavedModel to `<models_dir>/<model_name>/export/<unix timestamp>`. The
    SavedModel carries one signature, "predict", with input `context` and
    output `sample`.

    :model_name='1558M' : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :batch_size=1 : Integer, size of the first dimension of the `context`
     placeholder; also forwarded to `sample.sample_sequence`
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    :returns: String, path of the directory the SavedModel was written to
    :raises ValueError: if `length` exceeds the model's context window, or if
     no checkpoint can be found for the model
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    model_dir = os.path.join(models_dir, model_name)

    hparams = model.default_hparams()
    with open(os.path.join(model_dir, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    # Resolve the checkpoint before building the graph: latest_checkpoint
    # returns None when nothing is found, and failing here is both faster and
    # clearer than the error Saver.restore raises for a None path.
    ckpt = tf.train.latest_checkpoint(model_dir)
    if ckpt is None:
        raise ValueError("No checkpoint found in model directory: %s" % model_dir)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)

        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        saver = tf.train.Saver()
        saver.restore(sess, ckpt)

        # The export sub-directory is named after the current Unix time in
        # whole seconds. It is deliberately NOT created here:
        # SavedModelBuilder creates it itself and rejects an export directory
        # that already exists (any existing directory on older TF 1.x
        # releases, a non-empty one on newer releases).
        export_dir = os.path.join(model_dir, "export", str(int(time.time())))

        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        signature = predict_signature_def(inputs={'context': context},
                                          outputs={'sample': output})

        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.SERVING],
            signature_def_map={"predict": signature},
            strip_default_attrs=True)
        builder.save()

    return export_dir

In [9]:
# Run the export for the 1558M model with a fixed seed, 512 generated tokens
# and top-k sampling over 40 candidates.
export_for_serving(
    model_name='1558M',
    seed=3107,
    length=512,
    top_k=40,
)

W1119 04:50:53.305637 139696230025024 deprecation_wrapper.py:119] From /workspace/Ankur_Workspace/gpt-2/gpt-2/src/sample.py:51: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.

W1119 04:50:53.307458 139696230025024 deprecation_wrapper.py:119] From /workspace/Ankur_Workspace/gpt-2/gpt-2/src/model.py:148: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W1119 04:50:53.317603 139696230025024 deprecation_wrapper.py:119] From /workspace/Ankur_Workspace/gpt-2/gpt-2/src/model.py:152: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W1119 04:50:53.383464 139696230025024 deprecation_wrapper.py:119] From /workspace/Ankur_Workspace/gpt-2/gpt-2/src/model.py:36: The name tf.rsqrt is deprecated. Please use tf.math.rsqrt instead.

W1119 04:51:05.529913 139696230025024 deprecation.py:323] From /workspace/Ankur_Workspace/gpt-2/gpt-2/src/sample.py:64: to_float (from tensorflow.python.ops.mat