In [1]:
import tensorflow as tf
import sentencepiece as spm
import sys

sys.path.append('kaggle/input/axiom-utils')
from llm_components import load_sp_tokenizer, LMDatasetLoader

In [2]:
# Build an empty SentencePiece processor, then populate it from the trained model file.
# The return value of load() is the last expression, so it is what the cell displays.
sp = spm.SentencePieceProcessor()
sp.load('sp_tokenizer.model')

True

In [3]:
def sp_tokenize(line: tf.Tensor) -> tf.Tensor:
    """Encode a single string tensor into SentencePiece token ids.

    The tensor must be eager (it is read with ``.numpy()``); its bytes are
    decoded as UTF-8 and encoded with the notebook-level ``sp`` processor.

    Args:
        line: Scalar string tensor holding one line of text.

    Returns:
        An ``int32`` tensor containing the token ids for ``line``.
    """
    text = line.numpy().decode('utf-8')
    token_ids = sp.encode(text, out_type=int)
    return tf.constant(token_ids, dtype=tf.int32)

In [4]:
# Smoke-test the eager tokenizer on a single scalar byte-string tensor.
line = tf.constant(b'Hello world')
sp_tokenize(line)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 6740, 15920,  1036], dtype=int32)>

In [5]:
# Show the bytes -> str conversion that sp_tokenize applies before encoding.
line.numpy().decode('utf-8')

'Hello world'

In [6]:
def tf_sp_tokenize(line: tf.Tensor) -> tf.Tensor:
    """Run ``sp_tokenize`` through ``tf.py_function``.

    Args:
        line: String tensor holding one line of text.

    Returns:
        An ``int32`` tensor of token ids whose static shape is set to
        ``[None]`` (rank 1, length unknown).
    """
    token_ids = tf.py_function(sp_tokenize, [line], tf.int32)
    # py_function results carry no static shape; declare the rank explicitly.
    token_ids.set_shape([None])
    return token_ids

In [7]:
# Call the py_function-wrapped tokenizer on the same `line` to compare with the eager result.
tf_sp_tokenize(line)

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([ 6740, 15920,  1036], dtype=int32)>

In [8]:
# Shorthand for tf.data's autotune sentinel, passed as num_parallel_calls later on.
AUTOTUNE = tf.data.AUTOTUNE

# Each element of this dataset is one line of the WikiText-103 training file.
ds = tf.data.TextLineDataset(
    filenames='kaggle/input/wikitext-103/wikitext-103/wiki.train.tokens'
)

In [9]:
def print_data(num: int = 2, dataset: 'tf.data.Dataset | None' = None) -> None:
    """Print the first ``num`` elements of a dataset.

    Args:
        num: Number of elements to take and print.
        dataset: Any object exposing ``take(n)`` that returns an iterable
            (for example a ``tf.data.Dataset``). When omitted, the
            notebook-level ``ds`` is looked up at call time, so the output
            reflects whatever pipeline stage ``ds`` currently holds.
    """
    # Resolve the global lazily so callers passing `dataset` never depend on `ds`.
    source = ds if dataset is None else dataset
    for element in source.take(num):
        print(element)

In [10]:
# Peek at the first raw text lines (print_data defaults to num=2).
print_data()

tf.Tensor(b' ', shape=(), dtype=string)
tf.Tensor(b' = Valkyria Chronicles III = ', shape=(), dtype=string)


In [11]:
# Replace every text line with its vector of token ids.
# Note that `ds` is rebound in place, so re-running this cell alone tokenizes twice.
ds = ds.map(map_func=tf_sp_tokenize, num_parallel_calls=AUTOTUNE)
print_data()

tf.Tensor([], shape=(0,), dtype=int32)
tf.Tensor([   48   250  1288 15933  3199  7429   408  3282    48], shape=(9,), dtype=int32)


In [12]:
# Flatten the per-line id vectors into one continuous stream of scalar token ids.
ds = ds.flat_map(lambda token_ids: tf.data.Dataset.from_tensor_slices(token_ids))
print_data(10)

tf.Tensor(48, shape=(), dtype=int32)
tf.Tensor(250, shape=(), dtype=int32)
tf.Tensor(1288, shape=(), dtype=int32)
tf.Tensor(15933, shape=(), dtype=int32)
tf.Tensor(3199, shape=(), dtype=int32)
tf.Tensor(7429, shape=(), dtype=int32)
tf.Tensor(408, shape=(), dtype=int32)
tf.Tensor(3282, shape=(), dtype=int32)
tf.Tensor(48, shape=(), dtype=int32)
tf.Tensor(2471, shape=(), dtype=int32)


In [13]:
# Cut the token stream into fixed windows of 8 + 1 ids (8 inputs plus one extra
# token for the shifted targets); a trailing incomplete window is dropped.
# NOTE(review): consecutive windows do not overlap, so the last id of one window is
# never used as an input. The loader's rows in the recorded output further down
# appear to overlap by one token - confirm which behaviour is intended.
ds = ds.batch(batch_size=8 + 1, drop_remainder=True)
print_data()

tf.Tensor([   48   250  1288 15933  3199  7429   408  3282    48], shape=(9,), dtype=int32)
tf.Tensor([ 2471 15977     0   671   250  1288 15933  3199   233], shape=(9,), dtype=int32)


In [14]:
# Turn each 9-id window into an (inputs, targets) pair: inputs drop the last id,
# targets drop the first, so targets are the inputs shifted left by one position.
ds = ds.map(
    lambda window: (window[:-1], window[1:]),
    num_parallel_calls=tf.data.AUTOTUNE,
)
print_data(1)

(<tf.Tensor: shape=(8,), dtype=int32, numpy=
array([   48,   250,  1288, 15933,  3199,  7429,   408,  3282],
      dtype=int32)>, <tf.Tensor: shape=(8,), dtype=int32, numpy=
array([  250,  1288, 15933,  3199,  7429,   408,  3282,    48],
      dtype=int32)>)


In [15]:
# Group 8 (inputs, targets) pairs into one training batch.
# drop_remainder is left at its default, so the final batch may hold fewer than 8 rows.
ds = ds.batch(batch_size=8)
print_data(1)

(<tf.Tensor: shape=(8, 8), dtype=int32, numpy=
array([[   48,   250,  1288, 15933,  3199,  7429,   408,  3282],
       [ 2471, 15977,     0,   671,   250,  1288, 15933,  3199],
       [  281, 15914,     0,  2149,     0,  7429,   408,   119],
       [  281, 15914,     0, 15967,    11,  1156,    19,   250],
       [15933,  3199,    29,     8,  2132,  1825,   233,   118],
       [ 5037,  3353,    37,    92,   250,  1288, 15933,  3199],
       [  408,  3282,  2524,  1150,    11,   120,     6, 12694],
       [   83,  2014,  1296,   481,  1699,   113, 10082,    36]],
      dtype=int32)>, <tf.Tensor: shape=(8, 8), dtype=int32, numpy=
array([[  250,  1288, 15933,  3199,  7429,   408,  3282,    48],
       [15977,     0,   671,   250,  1288, 15933,  3199,   233],
       [15914,     0,  2149,     0,  7429,   408,   119,  1764],
       [15914,     0, 15967,    11,  1156,    19,   250,  1288],
       [ 3199,    29,     8,  2132,  1825,   233,   118,    11],
       [ 3353,    37,    92,   250,  128

In [16]:
# Switch to the packaged helpers imported from llm_components. `sp` is rebound here,
# replacing the processor loaded earlier in the notebook.
# NOTE(review): load_sp_tokenizer() presumably loads the same model file - confirm.
sp = load_sp_tokenizer()
# seq_len and batch_size use the same value (8) as the hand-built pipeline above.
loader = LMDatasetLoader(sp, seq_len= 8, batch_size= 8)

In [17]:
# Build the loader's dataset over the same WikiText-103 training file used for `ds`.
# NOTE(review): training=False presumably disables shuffling so batches stay comparable - confirm in llm_components.
train_ds = loader.create('kaggle/input/wikitext-103/wikitext-103/wiki.train.tokens', training= False)

In [18]:
# Print the first element produced by the loader's dataset.
for item in train_ds.take(1):
    print(item)

(<tf.Tensor: shape=(8, 8), dtype=int32, numpy=
array([[   48,   250,  1288, 15933,  3199,  7429,   408,  3282],
       [   48,  2471, 15977,     0,   671,   250,  1288, 15933],
       [ 3199,   233,   281, 15914,     0,  2149,     0,  7429],
       [  408,   119,  1764,   281, 15914,     0, 15967,    11],
       [ 1156,    19,   250,  1288, 15933,  3199,    29,     8],
       [ 2132,  1825,   233,   118,    11,  5037,  3353,    37],
       [   92,   250,  1288, 15933,  3199,  7429,   408,  3282],
       [ 2524,  1150,    11,   120,     6, 12694,  1308,    83]],
      dtype=int32)>, <tf.Tensor: shape=(8, 8), dtype=int32, numpy=
array([[  250,  1288, 15933,  3199,  7429,   408,  3282,    48],
       [ 2471, 15977,     0,   671,   250,  1288, 15933,  3199],
       [  233,   281, 15914,     0,  2149,     0,  7429,   408],
       [  119,  1764,   281, 15914,     0, 15967,    11,  1156],
       [   19,   250,  1288, 15933,  3199,    29,     8,  2132],
       [ 1825,   233,   118,    11,  503

In [19]:
# Compare the input tensors of the first batch from each pipeline, element by element.
# item1/item2 stay bound after the loop and are inspected by the cells that follow.
# NOTE(review): in the recorded output only the first row matches. The loader's rows
# appear to start on the previous row's final target token, while the hand-built
# pipeline's 9-id windows do not overlap - confirm which windowing is intended.
for item1, item2 in zip(ds.take(1), train_ds.take(1)):
    print(item1[0] == item2[0])

tf.Tensor(
[[ True  True  True  True  True  True  True  True]
 [False False False False False False False False]
 [False False False False  True False False False]
 [False False False False False False False False]
 [False False False False False False False False]
 [False False False False False False False False]
 [False False False False False False False False]
 [False False False False False False False False]], shape=(8, 8), dtype=bool)


In [20]:
# Input tensor of the hand-built pipeline's first batch (item1 is left over from the comparison loop).
item1[0]

<tf.Tensor: shape=(8, 8), dtype=int32, numpy=
array([[   48,   250,  1288, 15933,  3199,  7429,   408,  3282],
       [ 2471, 15977,     0,   671,   250,  1288, 15933,  3199],
       [  281, 15914,     0,  2149,     0,  7429,   408,   119],
       [  281, 15914,     0, 15967,    11,  1156,    19,   250],
       [15933,  3199,    29,     8,  2132,  1825,   233,   118],
       [ 5037,  3353,    37,    92,   250,  1288, 15933,  3199],
       [  408,  3282,  2524,  1150,    11,   120,     6, 12694],
       [   83,  2014,  1296,   481,  1699,   113, 10082,    36]],
      dtype=int32)>

In [21]:
# Input tensor of the loader's first batch (item2 is left over from the comparison loop).
item2[0]

<tf.Tensor: shape=(8, 8), dtype=int32, numpy=
array([[   48,   250,  1288, 15933,  3199,  7429,   408,  3282],
       [   48,  2471, 15977,     0,   671,   250,  1288, 15933],
       [ 3199,   233,   281, 15914,     0,  2149,     0,  7429],
       [  408,   119,  1764,   281, 15914,     0, 15967,    11],
       [ 1156,    19,   250,  1288, 15933,  3199,    29,     8],
       [ 2132,  1825,   233,   118,    11,  5037,  3353,    37],
       [   92,   250,  1288, 15933,  3199,  7429,   408,  3282],
       [ 2524,  1150,    11,   120,     6, 12694,  1308,    83]],
      dtype=int32)>