In [None]:
# !pip3 install gpt-2-simple

In [None]:
from gpt_2_simple.src import model as gpt2_model, encoder
import json

In [None]:
# Path of the JSON file holding hyperparameter overrides; a later cell opens
# it and feeds the parsed dict to `hparams.override_from_dict`.
params = '117m-hparams.json'

In [None]:
# Start from the library's default hyperparameters and overlay the values
# stored in the JSON file named by `params`.
hparams = gpt2_model.default_hparams()
with open(params, 'r', encoding='utf-8') as f:
    hparams.override_from_dict(json.load(f))

# All three text files are read as UTF-8 so the result does not depend on the
# platform's default encoding (previously only vocab.bpe was opened this way).
with open('encoder.json', 'r', encoding='utf-8') as f:
    en = json.load(f)
with open('vocab.bpe', 'r', encoding='utf-8') as f:
    bpe_data = f.read()

# Each remaining line is split on whitespace into a tuple. The first and last
# pieces of the newline split are dropped -- presumably a header line and the
# empty string after the final newline; verify against the actual file.
bpe_merges = [
    tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]
]
enc_malay = encoder.Encoder(encoder=en, bpe_merges=bpe_merges)

In [None]:
import tensorflow as tf

def top_k_logits(logits, k):
    """Keep the `k` largest logits of every row and push the rest to -1e10.

    The decision is made with `tf.cond`, so `k` may be a tensor whose value is
    only known when the graph runs. When `k` equals 0 the logits pass through
    unchanged.

    Args:
        logits: rank-2 tensor of scores (it is indexed as `[:, -1]` after
            `tf.nn.top_k`).
        k: number of entries to keep per row; 0 disables the filter.

    Returns:
        A tensor shaped like `logits`.
    """

    def _mask_below_kth():
        top_values, _ = tf.nn.top_k(logits, k=k)
        # Last column of the top-k values, kept as a column so it broadcasts
        # against every entry of the corresponding row.
        kth_value = top_values[:, -1, tf.newaxis]
        is_below = logits < kth_value
        floor = tf.ones_like(logits, dtype=logits.dtype) * -1e10
        return tf.where(is_below, floor, logits)

    return tf.cond(
        pred=tf.equal(k, 0),
        true_fn=lambda: logits,
        false_fn=_mask_below_kth,
    )


def top_p_logits(logits, p):
    """Mask logits that fall outside the top-`p` cumulative probability mass.

    Rows are sorted in descending order and soft-maxed; an entry is kept when
    the probability mass of the entries ranked before it (exclusive cumulative
    sum along axis 1) is below `p`. The smallest kept logit of each row becomes
    the threshold, and every logit under it is replaced with -1e10.

    Args:
        logits: tensor of scores; reductions run along axis 1.
        p: cumulative-probability cut-off compared against the running sum.

    Returns:
        A tensor shaped like `logits`.
    """
    with tf.variable_scope('top_p_logits'):
        sorted_logits = tf.sort(logits, direction='DESCENDING')
        sorted_probs = tf.nn.softmax(sorted_logits)
        mass_before = tf.cumsum(sorted_probs, axis=1, exclusive=True)

        # Entries outside the kept set are swapped for a large sentinel (1000)
        # so they cannot win the row-wise minimum taken next.
        inside = mass_before < p
        sentinel = tf.ones_like(sorted_logits) * 1000
        kept_sorted = tf.where(inside, sorted_logits, sentinel)
        threshold = tf.reduce_min(
            input_tensor=kept_sorted, axis=1, keepdims=True
        )

        is_below = logits < threshold
        floor = tf.ones_like(logits, dtype=logits.dtype) * -1e10
        return tf.where(is_below, floor, logits)


def sample_sequence(
    hparams,
    length,
    start_token=None,
    batch_size=None,
    context=None,
    temperature=1,
    top_k=0,
    top_p=0.0,
):
    """Build graph ops that extend a context by `length` sampled tokens.

    Exactly one of `start_token` and `context` must be supplied.

    Args:
        hparams: hyperparameters forwarded to `gpt2_model.model` and
            `gpt2_model.past_shape`; `hparams.n_vocab` bounds the logits slice.
        length: number of tokens to generate; the loop runs while its counter
            is below this value (Python int or scalar tensor).
        start_token: token id used to create a `[batch_size, 1]` context when
            no `context` is given.
        batch_size: only read when `start_token` is given.
        context: rank-2 tensor of token ids. All but its last column are run
            through the model once to obtain the initial `presents`; the last
            column is the first input of the loop.
        temperature: when > 0, noise `-log(-log(u)) * temperature`, with `u`
            drawn from `tf.random_uniform`, is added to the logits; otherwise
            the logits are left unchanged. Python number or scalar tensor.
        top_k: forwarded to `top_k_logits` when `top_p` is not > 0.
        top_p: when > 0, `top_p_logits` is applied instead of `top_k_logits`.
            Python number or scalar tensor.

    Returns:
        Rank-2 tensor holding the context followed by one sampled token per
        loop iteration, concatenated along axis 1.
    """
    if start_token is None:
        assert (
            context is not None
        ), 'Specify exactly one of start_token and context!'
    else:
        assert (
            context is None
        ), 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        """Run the model on `tokens` and return its logits and `present` state."""
        lm_output = gpt2_model.model(
            hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE
        )

        # Only the first n_vocab entries of the last axis are kept.
        logits = lm_output['logits'][:, :, : hparams.n_vocab]
        presents = lm_output['present']
        # Pin the static shape (batch left unknown) so it agrees with the
        # shape invariant declared for the while loop below.
        presents.set_shape(
            gpt2_model.past_shape(hparams=hparams, batch_size=None)
        )
        return {'logits': logits, 'presents': presents}

    with tf.name_scope('sample_sequence'):
        lens = tf.constant(0, dtype=tf.int32)
        context_output = step(hparams, context[:, :-1])

        def apply_temp(logits_BxN, temperature):
            """Add `-log(-log(u)) * temperature` noise to the logits."""
            logits_shape = tf.shape(logits_BxN)
            uniform_noise_BxN = tf.random_uniform(logits_shape)
            logits_BxN += -tf.log(-tf.log(uniform_noise_BxN)) * temperature
            return logits_BxN

        def body(past, prev, output, lens):
            """One loop iteration: feed `prev`, sample the next token."""
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :]
            # tf.greater always yields a tensor predicate. A bare `>` on the
            # Python-number defaults would give a Python bool, which tf.cond
            # rejects; for tensor inputs the resulting op is the same.
            logits = tf.cond(
                tf.greater(temperature, 0),
                lambda: apply_temp(logits, temperature),
                lambda: logits,
            )
            logits = tf.cond(
                tf.greater(top_p, 0.0),
                lambda: top_p_logits(logits, p=top_p),
                lambda: top_k_logits(logits, k=top_k),
            )
            samples = tf.random.categorical(
                logits, num_samples=1, dtype=tf.int32
            )
            return [
                # Append the new state along axis -2 of the accumulated past.
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
                lens + 1
            ]

        def cond(past, prev, output, lens):
            """Continue while fewer than `length` tokens have been generated."""
            return tf.less(lens, length)

        _, _, tokens, _ = tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=[context_output['presents'], context[:, -1], context, lens],
            # `past` and `output` grow on every iteration, so their shapes are
            # declared with unknown dimensions.
            shape_invariants=[
                tf.TensorShape(
                    gpt2_model.past_shape(
                        hparams=hparams, batch_size=None
                    )
                ),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
                lens.get_shape(),
            ],
            back_prop=False,
        )

        return tokens

In [None]:
class Model:
    """Expose `sample_sequence` through named placeholders.

    Every sampling knob is a placeholder with a fixed name ('X', 'temp',
    'top_k', 'top_p', 'maxlen', 'n_samples') and the generated tokens are
    published under the name 'output'.
    """

    def __init__(self, hparams, encoder, **kwargs):
        """Create the placeholders and wire them into `sample_sequence`.

        Args:
            hparams: hyperparameters forwarded to `sample_sequence`.
            encoder: stored on the instance as `_encoder`; not otherwise used
                here.
            **kwargs: accepted and ignored.
        """
        self._encoder = encoder

        # A single prompt of arbitrary length: shape [1, None].
        self._X = tf.placeholder(tf.int32, [1, None], name='X')
        self._temperature = tf.placeholder(tf.float32, None, name='temp')
        self._top_k = tf.placeholder(tf.int32, None, name='top_k')
        self._top_p = tf.placeholder(tf.float32, None, name='top_p')
        self._maxlen = tf.placeholder(tf.int32, None, name='maxlen')
        self._n_samples = tf.placeholder(tf.int32, None, name='n_samples')

        # Repeat the prompt along axis 0, once per requested sample.
        tiled_prompt = tf.tile(self._X, [self._n_samples, 1])

        self._model = sample_sequence(
            hparams=hparams,
            length=self._maxlen,
            context=tiled_prompt,
            batch_size=self._n_samples,
            temperature=self._temperature,
            top_k=self._top_k,
            top_p=self._top_p,
        )
        self.output = tf.identity(self._model, name='output')

In [None]:
# Build the sampling graph (placeholders plus the 'output' tensor).
model = Model(hparams, enc_malay)

In [None]:
# Open a session and run the initializer op for all global variables.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Restore every variable in the GLOBAL_VARIABLES collection from the
# checkpoint stored at the gs:// path below.
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
saver = tf.train.Saver(var_list = var_list)
saver.restore(sess, 'gs://mesolitica-tpu-general/gpt2-117m/model.ckpt-435300')

In [None]:
# Prompt text and its encoding produced by `enc_malay`; the last expression
# displays how many items the encoding contains.
string = 'mahathir dan najib razak sangat sayangkan anwar ibrahim'
encoded = enc_malay.encode(string)
len(encoded)

In [None]:
# Generate 10 continuations of 20 tokens each. temperature is 0.0 (the noise
# branch is skipped) and top_p is 0.7, so the top-p filter is the one applied.
o = sess.run(
    model._model,
    feed_dict={
        model._X: [encoded],
        model._temperature: 0.0,
        model._top_k: 0,
        model._top_p: 0.7,
        model._maxlen: 20,
        model._n_samples: 10,
    },
)
o.shape

In [None]:
# Decode each generated row back to text, prefixed with its index.
for i, token_ids in enumerate(o):
    print(i, enc_malay.decode(token_ids))

In [None]:
# Re-save the session's variables as a local checkpoint under 'gpt2-117m/';
# `saver` is rebound to a new Saver created with default arguments.
saver = tf.train.Saver()
saver.save(sess, 'gpt2-117m/model.ckpt')

In [None]:
# A node is a candidate when its op type mentions 'Variable' or 'gather'
# (case-insensitive), or when its name contains one of these fragments.
_wanted_fragments = ('X', 'temp', 'top_', 'maxlen', 'n_samples', 'output')
# Candidates whose name contains any of these fragments are dropped again.
_unwanted_fragments = (
    'adam',
    'global_step',
    'Assign',
    'ReadVariableOp',
    'Gather',
)

_selected_names = []
for n in tf.get_default_graph().as_graph_def().node:
    is_candidate = (
        'Variable' in n.op
        or 'gather' in n.op.lower()
        or any(fragment in n.name for fragment in _wanted_fragments)
    )
    is_excluded = any(fragment in n.name for fragment in _unwanted_fragments)
    if is_candidate and not is_excluded:
        _selected_names.append(n.name)

# Comma-separated form expected by freeze_graph(); the split is for display.
strings = ','.join(_selected_names)
strings.split(',')

In [None]:
def freeze_graph(model_dir, output_node_names):
    """Convert the checkpoint found in `model_dir` into a frozen GraphDef.

    The checkpoint's meta graph is imported into a fresh graph, its variables
    are restored and turned into constants, and the result is written to
    `frozen_model.pb` in the checkpoint's own directory.

    Args:
        model_dir: directory containing the checkpoint state file.
        output_node_names: comma-separated names of the nodes to keep.

    Raises:
        AssertionError: if `model_dir` does not exist or holds no checkpoint
            state.
    """
    if not tf.io.gfile.exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.train.get_checkpoint_state(model_dir)
    # get_checkpoint_state returns None when no checkpoint state is available;
    # fail with a clear message rather than an AttributeError further down.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            'No checkpoint state found in directory: %s' % model_dir
        )
    input_checkpoint = checkpoint.model_checkpoint_path

    # The frozen graph is written next to the checkpoint files.
    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'

    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = True
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [None]:
# Freeze the checkpoint in 'gpt2-117m', keeping the nodes named in `strings`.
freeze_graph('gpt2-117m', strings)

In [None]:
def load_graph(frozen_graph_filename):
    """Read a serialized GraphDef from disk and import it into a new graph.

    Args:
        frozen_graph_filename: path of the serialized GraphDef file.

    Returns:
        A new `tf.Graph` holding the imported nodes. `tf.import_graph_def` is
        called without a `name` argument, so its default name prefix applies.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def)

    return graph

In [None]:
# Load the frozen model into its own graph to check that it can be used.
g = load_graph('gpt2-117m/frozen_model.pb')

In [None]:
# Names given to the placeholders and the output tensor when the graph was
# built; after import they are looked up under the 'import/' prefix.
input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']
output_nodes = ['output']

inputs = {}
for n in input_nodes:
    inputs[n] = g.get_tensor_by_name(f'import/{n}:0')

outputs = {}
for n in output_nodes:
    outputs[n] = g.get_tensor_by_name(f'import/{n}:0')

inputs, outputs

In [None]:
# Session bound to the imported graph `g`.
test_sess = tf.Session(graph = g)

In [None]:
# Run the frozen graph once: a single sample of 100 tokens. top_p is 0.0, so
# the top-k filter (k=40) is the one applied.
o = test_sess.run(
    outputs['output'],
    feed_dict={
        inputs['X']: [encoded],
        inputs['temp']: 0.0,
        inputs['top_k']: 40,
        inputs['top_p']: 0.0,
        inputs['maxlen']: 100,
        inputs['n_samples']: 1,
    },
)
o.shape

In [None]:
# Decode the first row of the generated output back to text.
print(enc_malay.decode(o[0]))

In [None]:
from tensorflow.tools.graph_transforms import TransformGraph

In [None]:
# Transform passes handed to TransformGraph, in this order.
# NOTE(review): the cells that would run the quantized graph are commented out
# in this notebook, so the output of these transforms is not exercised by an
# inference check here -- confirm it still runs before relying on it.
transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'quantize_weights(fallback_min=-10, fallback_max=10)',
             'strip_unused_nodes',
             'sort_by_execution_order']

input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']
output_nodes = ['output']

pb = 'gpt2-117m/frozen_model.pb'

# tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile and matches the
# file API used everywhere else in this notebook.
input_graph_def = tf.GraphDef()
with tf.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

transformed_graph_def = TransformGraph(input_graph_def,
                                       input_nodes,
                                       output_nodes, transforms)

# The transformed graph is written next to the original with a '.quantized'
# suffix.
with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

In [None]:
# Rebind `g` to a graph loaded from the '.quantized' file written above.
g = load_graph('gpt2-117m/frozen_model.pb.quantized')

In [None]:
# Look up the same named tensors again, this time in the graph currently
# bound to `g`.
input_nodes = ['X', 'temp', 'top_k', 'top_p', 'maxlen', 'n_samples']
output_nodes = ['output']

inputs = {}
for n in input_nodes:
    inputs[n] = g.get_tensor_by_name(f'import/{n}:0')

outputs = {}
for n in output_nodes:
    outputs[n] = g.get_tensor_by_name(f'import/{n}:0')

inputs, outputs

In [None]:
# Rebind `test_sess` to a session on the graph currently held in `g`.
test_sess = tf.Session(graph = g)

In [None]:
# o = test_sess.run(outputs['output'], feed_dict = {inputs['X']: [encoded],
#                                   inputs['temp']: 0.0,
#                                   inputs['top_k']: 40,
#                                   inputs['top_p']: 0.0,
#                                   inputs['maxlen']: 100,
#                                   inputs['n_samples']: 1})
# o.shape

In [None]:
# print(enc_malay.decode(o[0]))

In [4]:
import os

# NOTE(review): the wildcard import is kept so that every name it provided
# stays available; this cell itself only uses InMemoryAccountInfo and B2Api.
from b2sdk.v1 import *

# The credentials are not defined in any cell of this notebook. Values already
# present in the namespace are used as-is; otherwise they are read from the
# environment so the cell works on a fresh kernel without embedding secrets.
application_key_id = (
    globals().get('application_key_id') or os.environ['B2_APPLICATION_KEY_ID']
)
application_key = (
    globals().get('application_key') or os.environ['B2_APPLICATION_KEY']
)

info = InMemoryAccountInfo()
b2_api = B2Api(info)
b2_api.authorize_account("production", application_key_id, application_key)
# Custom file info attached to each upload made from this notebook.
file_info = {'how': 'good-file'}
b2_bucket = b2_api.get_bucket_by_name('malaya-model')

In [5]:
# Upload the frozen graph to the bucket under the remote name
# 'gpt2/117M/model.pb', attaching `file_info` as its file info.
file = 'gpt2-117m/frozen_model.pb'
outPutname = 'gpt2/117M/model.pb'
b2_bucket.upload_local_file(
    local_file=file,
    file_name=outPutname,
    file_infos=file_info,
)

FileVersionInfo('4_zcde33cc461767caf742c0b11_f201775d542477cf3_d20210923_m090906_c000_v0001400_t0050', 'gpt2/117M/model.pb', 498708685, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632388146000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632388146000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f26144e2dd8>)

In [6]:
# Upload the '.quantized' graph to the bucket under the remote name
# 'gpt2/117M-quantized/model.pb', attaching `file_info` as its file info.
file = 'gpt2-117m/frozen_model.pb.quantized'
outPutname = 'gpt2/117M-quantized/model.pb'
b2_bucket.upload_local_file(
    local_file=file,
    file_name=outPutname,
    file_infos=file_info,
)

FileVersionInfo('4_zcde33cc461767caf742c0b11_f202a1847f9337d3b_d20210923_m090925_c000_v0001079_t0000', 'gpt2/117M-quantized/model.pb', 125564697, 'application/octet-stream', 'none', {'how': 'good-file'}, 1632388165000, <EncryptionSetting(EncryptionMode.NONE, None, None)>, <LegalHold.UNSET: None>, FileRetentionSetting(None, None), 1632388165000, None, None, None, 'upload', <b2sdk.v1.api.B2Api object at 0x7f26144e2dd8>)