- [The Illustrated GPT-2 (Visualizing Transformer Language Models) – Jay Alammar – Visualizing machine learning one concept at a time](http://jalammar.github.io/illustrated-gpt2/)
- [openai/gpt-2: Code for the paper "Language Models are Unsupervised Multitask Learners"](https://github.com/openai/gpt-2)
- [Morizeyao/GPT2-Chinese: Chinese version of GPT2 training code, using BERT tokenizer.](https://github.com/Morizeyao/GPT2-Chinese)

## Encoder

主要目的是对 text 利用 unicode 进行编码和解码。

In [16]:
import os
import json
import regex as re
from functools import lru_cache

In [17]:
# byte: unicode
@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

In [18]:
# Bi-gram
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

In [19]:
class Encoder:
    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k,v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

In [20]:
def get_encoder(model_name, models_dir):
    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
        encoder = json.load(f)
    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

### Details

In [220]:
models_dir = "/Users/HaoShaochun/Documents/Study/gpt-2/models/"
model_name = "124m"

In [221]:
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
    encoder = json.load(f)

In [222]:
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
    bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]

In [223]:
encoder = encoder
decoder = {v:k for k,v in encoder.items()}
errors = 'replace'
byte_encoder = bytes_to_unicode()
byte_decoder = {v:k for k,v in byte_encoder.items()}
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
cache = {}
pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

#
# \s+ 匹配任意空格
# \s+(?!\S) 匹配末尾的空格（空格前面有非空格时不匹配）

# [^\s\p{L}\p{N}]+ 匹配不含空格、letter 和 number 的，比如：# * 之类
# \p{L}+ 匹配任意的 letter，比如 abc
# \p{N}+ 匹配任意的 number，比如 123
#

In [227]:
def bpe(token):
    if token in cache:
        return cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    if not pairs:
        return token

    while True:
        bigram = min(pairs, key = lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except:
                new_word.extend(word[i:])
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)
    word = ' '.join(word)
    cache[token] = word
    return word

In [245]:
# encode
text = "I'm loving U."
bpe_tokens = []
tokens = []
uni_tokens = []
for token in re.findall(pat, text):
    # byte -> token
    tokens.append(token)
    token = ''.join(byte_encoder[b] for b in token.encode('utf-8'))
    uni_tokens.append(token)
    tmp = [encoder[bpe_token] for bpe_token in bpe(token).split(' ')]
    print("token: ", token, tmp)
    bpe_tokens.extend(tmp)

token:  I [40]
token:  'm [1101]
token:  Ġloving [14442]
token:  ĠU [471]
token:  . [13]


In [246]:
uni_tokens

['I', "'m", 'Ġloving', 'ĠU', '.']

In [247]:
tokens

['I', "'m", ' loving', ' U', '.']

In [236]:
bpe_tokens

[40, 1101, 14442, 471, 13]

In [235]:
''.join(byte_encoder[b] for b in "i lov".encode('utf-8'))

'iĠlov'

In [255]:
bpe_tokens

[40, 1101, 14442, 471, 13]

In [246]:
re.findall(pat, text)

['I', "'m", ' loving', ' U', '.']

In [247]:
for i in "loving".encode('utf-8'):
    print(i)

108
111
118
105
110
103


In [253]:
bpe("'m").split(" ")

["'m"]

In [268]:
# decode

text = ''.join([decoder[token] for token in bpe_tokens])
text

"I'mĠlovingĠU."

In [269]:
decoder[1101]

"'m"

In [271]:
for c in text:
    print(c, byte_decoder[c])

I 73
' 39
m 109
Ġ 32
l 108
o 111
v 118
i 105
n 110
g 103
Ġ 32
U 85
. 46


In [263]:
text = bytearray([byte_decoder[c] for c in text]).decode('utf-8', errors=errors)
text

"I'm loving U."

In [275]:
bytearray([73, 39, 109]).decode('utf8')

"I'm"

## Model

In [21]:
import numpy as np
import tensorflow as tf
from dataclasses import dataclass

In [22]:
@dataclass
class HParams:
    n_vocab:int=50257
    n_ctx:int=1024
    n_embd:int=768
    n_head:int=12
    n_layer:int=12

In [23]:
def default_hparams():
    return HParams()

In [24]:
def model(hparams, X, past=None, scope='model', reuse=False):
    with tf.compat.v1.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                                        initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                                        initializer=tf.random_normal_initializer(stddev=0.02))
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results

### 输入

In [25]:
def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

In [26]:
def positions_for(tokens, past_length):
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    return expand_tile(past_length + tf.range(nsteps), batch_size)

In [27]:
def expand_tile(value, size):
    """Add a new axis of given size."""
    value = tf.convert_to_tensor(value, name='value')
    ndims = value.shape.ndims
    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)

### Details


In [158]:
batch_size = 1
start_token = None
X = tf.fill([batch_size, 1], enc.encoder['<|endoftext|>'])
hparams = default_hparams()

In [159]:
X

<tf.Tensor 'Fill_5:0' shape=(1, 1) dtype=int32>

In [160]:
positions_for(X, 5)

<tf.Tensor 'Tile:0' shape=(1, 1) dtype=int32>

In [257]:
tf.shape(X)[0], tf.shape(X)[1]

(<tf.Tensor: shape=(), dtype=int32, numpy=1>,
 <tf.Tensor: shape=(), dtype=int32, numpy=1>)

In [272]:
wte

<tf.Variable 'model/wte:0' shape=(50257, 768) dtype=float32, numpy=
array([[ 0.01648771, -0.0164929 ,  0.02838891, ..., -0.00645825,
        -0.00522234, -0.0007867 ],
       [-0.01100282, -0.05329556, -0.01903998, ..., -0.04945637,
         0.0188335 ,  0.00340233],
       [-0.00040727, -0.00531276,  0.01675709, ..., -0.0118375 ,
         0.01513721,  0.01788099],
       ...,
       [-0.00225281, -0.00350592, -0.02709479, ...,  0.00464215,
         0.01216319, -0.00408764],
       [ 0.01639596,  0.01235129, -0.00731677, ..., -0.00058879,
         0.0012605 ,  0.03780697],
       [-0.01169358,  0.02725702,  0.01802673, ..., -0.0029963 ,
         0.00533693,  0.01667264]], dtype=float32)>

In [None]:
past = None
with tf.compat.v1.variable_scope("model", reuse=False):
    results = {}
    batch, sequence = shape_list(X)

    wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.01))
    wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.02))
    past_length = 0 if past is None else tf.shape(past)[-2]
    h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

In [525]:
h.shape

TensorShape([1, 1, 768])

### Transformer

In [28]:
def softmax(x, axis=-1):
    x = x - tf.reduce_max(x, axis=axis, keepdims=True)
    ex = tf.exp(x)
    return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)

def gelu(x):
    return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))

def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
    with tf.compat.v1.variable_scope(scope):
        n_state = x.shape[-1]#.value
        g = tf.compat.v1.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
        b = tf.compat.v1.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
        u = tf.reduce_mean(x, axis=axis, keepdims=True)
        s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True)
        x = (x - u) * tf.compat.v1.rsqrt(s + epsilon)
        x = x*g + b
        return x

def split_states(x, n):
    """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
    *start, m = shape_list(x)
    return tf.reshape(x, start + [n, m//n])

def merge_states(x):
    """Smash the last two dimensions of x into a single dimension."""
    *start, a, b = shape_list(x)
    return tf.reshape(x, start + [a*b])

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    with tf.compat.v1.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.compat.v1.get_variable('w', [1, nx, nf], 
                                      initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.compat.v1.get_variable('b', [nf], 
                                      initializer=tf.constant_initializer(0))
        c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
        return c

def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    i = tf.range(nd)[:,None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    return tf.cast(m, dtype)


def split_heads(x):
    # From [batch, sequence, features] to [batch, heads, sequence, features]
    return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

def multihead_attn(q, k, v):
    # q, k, v have shape [batch, heads, sequence, features]
    w = tf.matmul(q, k, transpose_b=True)
    w = w * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], w.dtype))
    
    w = mask_attn_weights(w)
    w = softmax(w)
    a = tf.matmul(w, v)
    return a

def mask_attn_weights(w):
    # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
    _, _, nd, ns = shape_list(w)
    b = attention_mask(nd, ns, dtype=w.dtype)
    b = tf.reshape(b, [1, 1, nd, ns])
    w = w*b - tf.cast(1e10, w.dtype)*(1-b)
    return w

def merge_heads(x):
    # Reverse of split_heads
    return merge_states(tf.transpose(x, [0, 2, 1, 3]))

def attn(x, scope, n_state, *, past, hparams):
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(x, [0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        w = w * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], w.dtype))
        w = mask_attn_weights(w)
        w = softmax(w)
        a = tf.matmul(w, v)
        return a

    with tf.compat.v1.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    with tf.compat.v1.variable_scope(scope):
        nx = x.shape[-1]#.value
        h = gelu(conv1d(x, 'c_fc', n_state))
        h2 = conv1d(h, 'c_proj', nx)
        return h2


def block(x, scope, *, past, hparams):
    with tf.compat.v1.variable_scope(scope):
        nx = x.shape[-1]#.value
        a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
        x = x + a
        m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
        x = x + m
        return x, present

### Details

In [526]:
presents = []
pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
assert len(pasts) == hparams.n_layer

In [540]:
x = norm(h, 'ln_1')

In [541]:
c = conv1d(x, 'c_attn', 768*3)

In [542]:
q, k, v = map(split_heads, tf.split(c, 3, axis=2))

In [565]:
x = multihead_attn(q, k, v)

In [567]:
x = merge_heads(x)

In [584]:
a, present = attn(norm(h, 'ln_1'), 'attn', nx, past=None, hparams=hparams)

In [586]:
x = h + a

In [589]:
x = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)

In [590]:
h = norm(x, 'ln_f')

In [134]:
hparams

HParams(n_vocab=50257, n_ctx=1024, n_embd=768, n_head=12, n_layer=12)

In [602]:
past = None
with tf.compat.v1.variable_scope("model", reuse=False):
    results = {}
    batch, sequence = shape_list(X)

    wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.01))
    wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.02))
    past_length = 0 if past is None else tf.shape(past)[-2]
    h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
    
    presents = []
    pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
    assert len(pasts) == hparams.n_layer
    for layer, past in enumerate(pasts):
        h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
        presents.append(present)
    
    results['present'] = tf.stack(presents, axis=1)
    h = norm(h, 'ln_f')
    
    
    h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
    logits = tf.matmul(h_flat, wte, transpose_b=True)
    logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
    results['logits'] = logits

## Generate

In [29]:
def past_shape(*, hparams, batch_size=None, sequence=None):
    return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]

In [30]:
def top_k_logits(logits, k):
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
    return tf.cond(
       tf.equal(k, 0),
       lambda: logits,
       lambda: _top_k(),
    )

In [58]:
def top_p_logits(logits, p):
    """Nucleus sampling"""
    batch, vocab_size = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # number of indices to include
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    min_values = tf.gather_nd(sorted_logits, indices)
    min_values = tf.broadcast_to(
        tf.expand_dims(min_values, 1), [batch, vocab_size]
    )
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )

In [77]:
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, 
                    temperature=1, top_k=0, top_p=1):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
    
    print(start_token)
    print(context)
    def step(hparams, tokens, past=None):
        lm_output = model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.compat.v1.name_scope('sample_sequence'):
        def body(past, prev, output):
            next_outputs = step(hparams, prev, past=past)
            logits = next_outputs['logits'][:, -1, :]  / tf.cast(temperature, dtype=tf.float32)
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)
            samples = tf.random.categorical(logits, num_samples=1, dtype=tf.int32)
            return [
                next_outputs['presents'] if past is None else tf.concat(
                    [past, next_outputs['presents']], axis=-2),
                samples,
                tf.concat([output, samples], axis=1)
            ]

        past, prev, output = body(None, context, context)

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length - 1,
            loop_vars=[
                past,
                prev,
                output
            ],
            shape_invariants=[
                tf.TensorShape(past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens

### Details

In [506]:
samples = tf.random.categorical(nxt_logits, num_samples=1, dtype=tf.int32)

In [507]:
samples

<tf.Tensor: shape=(8, 1), dtype=int32, numpy=
array([[39512],
       [24640],
       [ 5551],
       [31840],
       [30685],
       [16918],
       [ 1829],
       [20344]], dtype=int32)>

In [508]:
enc.decode([seq[0] for seq in samples.numpy().tolist()])

' advantageous CVEporaryallah enclosure grocery StatesDist'

In [83]:
def step(hparams, tokens, past=None):
    lm_output = model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)

    logits = lm_output['logits'][:, :, :hparams.n_vocab]
    presents = lm_output['present']
    presents.set_shape(past_shape(hparams=hparams, batch_size=1))
    return {
        'logits': logits,
        'presents': presents,
    }

In [144]:
def body(past, prev, output):
    next_outputs = step(hparams, prev, past=past)
    logits = next_outputs['logits'][:, -1, :]  / tf.cast(1, dtype=tf.float32)
    logits = top_k_logits(logits, k=0)
    logits = top_p_logits(logits, p=1)
    samples = tf.random.categorical(logits, num_samples=1, dtype=tf.int32)
    return [
        next_outputs['presents'] if past is None else tf.concat(
            [past, next_outputs['presents']], axis=-2),
        samples,
        tf.concat([output, samples], axis=1)
    ]

In [145]:
context = tf.fill([1, 1], 50256)

In [146]:
past, prev, output = body(None, context, context)

In [147]:
past

<tf.Tensor 'model_6/stack:0' shape=(1, 12, 2, 12, 1, 64) dtype=float32>

In [148]:
prev

<tf.Tensor 'categorical_4/Multinomial:0' shape=(1, 1) dtype=int32>

In [149]:
output

<tf.Tensor 'concat_3:0' shape=(1, 2) dtype=int32>

In [154]:
actual_past = tf.unstack(past, axis=1)[0]

In [155]:
pk, pv = tf.unstack(actual_past, axis=1)

In [156]:
pk

<tf.Tensor 'unstack_4:0' shape=(1, 12, 1, 64) dtype=float32>

In [157]:
pv

<tf.Tensor 'unstack_4:1' shape=(1, 12, 1, 64) dtype=float32>

In [167]:
q = pk

In [168]:
q

<tf.Tensor 'unstack_4:0' shape=(1, 12, 1, 64) dtype=float32>

In [162]:
k = tf.concat([pk, pv], axis=-2)

In [164]:
v = tf.concat([pv, pv], axis=-2)

In [163]:
k

<tf.Tensor 'concat_4:0' shape=(1, 12, 2, 64) dtype=float32>

In [165]:
v

<tf.Tensor 'concat_5:0' shape=(1, 12, 2, 64) dtype=float32>

In [169]:
w = tf.matmul(q, k, transpose_b=True)
w

<tf.Tensor 'MatMul:0' shape=(1, 12, 1, 2) dtype=float32>

In [166]:
pk

<tf.Tensor 'unstack_4:0' shape=(1, 12, 1, 64) dtype=float32>

In [170]:
w = w * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], w.dtype))
w

<tf.Tensor 'mul_5:0' shape=(1, 12, 1, 2) dtype=float32>

In [171]:
_, _, nd, ns = shape_list(w)

In [172]:
nd, ns

(1, 2)

In [212]:
with tf.Session() as sess1:
    nd = 2
    ns = 1
    i = tf.range(nd)[:,None]
    j = tf.range(ns)
    m = i >= j - ns + nd
    print(i.eval())
    print(tf.range(1).eval())
    
    print("ns, nd: ", ns, nd)
    print("j: ", j.eval())
    print("j - ns + nd: ", (j -ns+nd).eval())
    print(m.eval())
    
    b = tf.cast([[True, True]], tf.int32)
    print("b before reshape: ", b.eval())
    
    b = tf.reshape(b, [1, 1, nd, ns])
    print("b after reshape: ", b.eval())
    
    cs = tf.cast(1e10, tf.float32)
    print(cs.eval())
    
    print((1-b).eval())

[[0]
 [1]]
[0]
ns, nd:  1 2
j:  [0]
j - ns + nd:  [1]
[[False]
 [ True]]
b before reshape:  [[1 1]]
b after reshape:  [[[[1]
   [1]]]]
10000000000.0
[[[[0]
   [0]]]]


In [199]:
b = tf.cast([[True, True]], tf.int32)

In [202]:
b = tf.reshape(b, [1, 1, nd, ns])
b

<tf.Tensor 'Reshape_2:0' shape=(1, 1, 1, 2) dtype=int32>

## Together

In [7]:
from tensorflow.core.protobuf import rewriter_config_pb2
from pathlib import Path
import gpt_2_simple as gpt2 # TF VERSION should < 2.0，如果用 2.0 模型读不进来

In [104]:
models_dir = "/Users/HaoShaochun/Documents/Study/gpt-2/models/"
model_name = "124M"
enc = get_encoder(model_name, models_dir)

In [105]:
sess = gpt2.start_tf_sess()

In [179]:
m

<tf.Tensor 'GreaterEqual_2:0' shape=(1, 2) dtype=bool>

In [117]:
sess = gpt2.reset_session(sess)

In [118]:
gpt2.load_gpt2(sess, model_name=model_name, model_dir=models_dir)

Loading pretrained model /Users/HaoShaochun/Documents/Study/gpt-2/models/124M/model.ckpt


In [313]:
output = sample_sequence(
        hparams=hparams, length=50,
        start_token=enc.encoder['<|endoftext|>'],
        batch_size=1,
        temperature=1, top_k=40, top_p=1)[:, 1:]

50256
Tensor("Fill_8:0", shape=(1, 1), dtype=int32)


In [79]:
out = sess.run(output)

In [293]:
text = enc.decode(out[0])
text

'157 Underscoring Content About Actland AD 28\n\n135 OTHER EVENTS.\n\n137 ARDS OF PUBLIC HISTORY.\n\n138 PRECEDENCE OF POSTMIEAL DISAPPEARANCE.\n\n140 OSF'

In [213]:
enc.encode("i love you")

[72, 1842, 345]

## Dataset

In [269]:
# chunk 是按文件分的
chunks = gpt2.load_dataset(enc, "./splited/", combine=1000)

100%|██████████| 4/4 [00:01<00:00,  2.14it/s]


In [270]:
chunks

[array([ 5962, 22307,    25, ...,   606,    11,   198]),
 array([1870, 2582,  314, ..., 1549,   13,  198]),
 array([   43,  9598,  9399, ..., 23137,    13,   198]),
 array([ 2437, 28639,   618, ...,   655,    13,   628])]

In [274]:
samp = gpt2.Sampler(chunks)

In [278]:
samp.sample(20)

array([    0,   198,  2437,   714,   301, 14210, 14782,   262,  1204,
          12, 18041,   286,   262,  1200,    11,   198,  2514,  8406,
         262,  2988])

In [279]:
samp.total_size

338024

In [280]:
samp.boundaries

[0, 81197, 170595, 252659, 338024]

In [281]:
import random
index = random.randint(0, samp.total_size - 20 - 1)

In [282]:
index

337975

In [283]:
def binary_search(f, lo, hi):
    if f(lo) or not f(hi):
        return None
    while hi > lo + 1:
        mid = (lo + hi) // 2
        if f(mid):
            hi = mid
        else:
            lo = mid
    return hi

In [284]:
binary_search(lambda j: samp.boundaries[j] > index, 0, len(samp.boundaries) - 1) - 1

3

In [314]:
context_tokens = samp.sample(1)

In [315]:
out = sess.run(
    output,
    feed_dict={context: 1 * [context_tokens]})

In [316]:
text = enc.decode(out[0])
text

'The former president of a major political party has admitted that he misled many viewers about his role in the campaign, including people who have come forward to suggest he and his team knew about his role during the campaign, according to a statement Thursday by the head'

In [317]:
context_tokens

array([198])

In [318]:
enc.decode([198])

'\n'