In [113]:
import tensorflow as tf
import sentencepiece as spm

In [114]:
sp = spm.SentencePieceProcessor()
sp.load('sp_tokenizer.model')

True

In [115]:
def sp_tokenize(line: tf.Tensor) -> tf.Tensor:
    tokens = sp.encode(line.numpy().decode('utf-8'), out_type= int)
    return tf.constant(tokens, dtype= tf.int32)

In [116]:
line = tf.constant(b'Hello world')
sp_tokenize(line)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 6740, 15920,  1036], dtype=int32)>

In [117]:
line.numpy().decode('utf-8')

'Hello world'

In [118]:
def tf_sp_tokenize(line: tf.Tensor) -> tf.Tensor:
    tokens = tf.py_function(
        func= sp_tokenize,
        inp= [line], 
        Tout= tf.int32
    )
    tokens.set_shape([None])
    return tokens

In [119]:
tf_sp_tokenize(line)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 6740, 15920,  1036], dtype=int32)>

In [128]:
AUTOTUNE = tf.data.AUTOTUNE
ds = tf.data.TextLineDataset('kaggle/input/wikitext-103/wikitext-103/wiki.train.tokens')

In [129]:
def print_data(num: int = 2) -> None:
    for item in ds.take(num):
        print(item)

In [130]:
print_data()

tf.Tensor(b' ', shape=(), dtype=string)
tf.Tensor(b' = Valkyria Chronicles III = ', shape=(), dtype=string)


In [131]:
ds = ds.map(
    tf_sp_tokenize,
    num_parallel_calls= AUTOTUNE
)
print_data()

tf.Tensor([], shape=(0,), dtype=int32)
tf.Tensor([   48   250  1288 15933  3199  7429   408  3282    48], shape=(9,), dtype=int32)


In [132]:
ds = ds.flat_map(
    lambda x: tf.data.Dataset.from_tensor_slices(x)
)
print_data(10)

tf.Tensor(48, shape=(), dtype=int32)
tf.Tensor(250, shape=(), dtype=int32)
tf.Tensor(1288, shape=(), dtype=int32)
tf.Tensor(15933, shape=(), dtype=int32)
tf.Tensor(3199, shape=(), dtype=int32)
tf.Tensor(7429, shape=(), dtype=int32)
tf.Tensor(408, shape=(), dtype=int32)
tf.Tensor(3282, shape=(), dtype=int32)
tf.Tensor(48, shape=(), dtype=int32)
tf.Tensor(2471, shape=(), dtype=int32)


In [133]:
ds = ds.batch(8 + 1, drop_remainder= True)
print_data()

tf.Tensor([   48   250  1288 15933  3199  7429   408  3282    48], shape=(9,), dtype=int32)
tf.Tensor([ 2471 15977     0   671   250  1288 15933  3199   233], shape=(9,), dtype=int32)


In [134]:
ds = ds.map(
    lambda x: (x[:-1], x[1:]),
    num_parallel_calls= tf.data.AUTOTUNE
)
print_data(1)

(<tf.Tensor: shape=(8,), dtype=int32, numpy=
array([   48,   250,  1288, 15933,  3199,  7429,   408,  3282],
      dtype=int32)>, <tf.Tensor: shape=(8,), dtype=int32, numpy=
array([  250,  1288, 15933,  3199,  7429,   408,  3282,    48],
      dtype=int32)>)


In [137]:
ds = ds.batch(8)
print_data(1)

(<tf.Tensor: shape=(8, 8, 8), dtype=int32, numpy=
array([[[   48,   250,  1288, 15933,  3199,  7429,   408,  3282],
        [ 2471, 15977,     0,   671,   250,  1288, 15933,  3199],
        [  281, 15914,     0,  2149,     0,  7429,   408,   119],
        [  281, 15914,     0, 15967,    11,  1156,    19,   250],
        [15933,  3199,    29,     8,  2132,  1825,   233,   118],
        [ 5037,  3353,    37,    92,   250,  1288, 15933,  3199],
        [  408,  3282,  2524,  1150,    11,   120,     6, 12694],
        [   83,  2014,  1296,   481,  1699,   113, 10082,    36]],

       [[15936, 15981,   566,    84,     8,  5213,  2112,   411],
        [ 6337,   413,    31,  1113,  1361,    31,  1150,    11],
        [  120,     8,  1093,   481,    31,     8,   250,  1288],
        [ 3199,   687,    19,  1871,  1613,    33,     8,   905],
        [   29, 12694,    36,  1324,    83,   390,  5521,    92],
        [13841,    11,     8,  1160,  2194,  5399,    37,     8],
        [  481,    36,  