In [28]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import random

In [29]:
# https://github.com/karpathy/minGPT/blob/master/play_math.ipynb

def make_dataset():
  """Build every two-digit addition problem a + b as a row of seven digits.

  Both operands run over 0..99, giving 100 * 100 = 10000 rows, ordered with
  the first operand as the slow index and the second as the fast index.

  Returns:
    A list of 7-element lists of ints laid out as
    [a tens, a ones, b tens, b ones, sum hundreds, sum tens, sum ones].
  """
  return [
    [
      *divmod(a, 10),              # tens and ones digit of the first operand
      *divmod(b, 10),              # tens and ones digit of the second operand
      (a + b) // 100,              # hundreds digit of the sum (0 or 1)
      *divmod((a + b) % 100, 10),  # tens and ones digit of the sum
    ]
    for a in range(100)
    for b in range(100)
  ]
# Shuffle with a dedicated, seeded generator so the train/test membership is
# identical on every Restart & Run All. Only the split is made deterministic
# here: this cell does not seed TensorFlow, and the global `random` state is
# left untouched.
SPLIT_SEED = 1337
N_TRAIN = 8000  # rows used for training; the remaining rows are held out

ds = make_dataset()
random.Random(SPLIT_SEED).shuffle(ds)
ds = np.array(ds)

# Next-token setup: the input is digits 0..5 of each row and the target is the
# same row shifted left by one (digits 1..6), so output position k is trained
# to predict digit k+1 of the row.
ds_X = ds[:, 0:6]
ds_Y = np.copy(ds[:, 1:])
ds_X_train, ds_X_test = ds_X[:N_TRAIN], ds_X[N_TRAIN:]
ds_Y_train, ds_Y_test = ds_Y[:N_TRAIN], ds_Y[N_TRAIN:]

In [30]:
# https://keras.io/examples/nlp/text_classification_with_transformer/

class MultiHeadSelfAttention(layers.Layer):
  """Multi-head scaled dot-product self-attention.

  Queries, keys and values are all linear projections of the same input.
  `embed_dim` is divided evenly between `num_heads` heads, each of width
  `embed_dim // num_heads`. No mask is applied to the attention scores, so
  every position attends to every position of the sequence, including later
  ones.

  Args:
    embed_dim: width of the input and output features; must be divisible by
      `num_heads`.
    num_heads: number of attention heads.
  """

  def __init__(self, embed_dim, num_heads=8):
    super().__init__()
    assert embed_dim % num_heads == 0
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.projection_dim = embed_dim // num_heads

    # Four embed_dim x embed_dim projections: 4 * embed_dim**2 kernel weights.
    self.query_dense = layers.Dense(embed_dim)
    self.key_dense = layers.Dense(embed_dim)
    self.value_dense = layers.Dense(embed_dim)
    self.combine_heads = layers.Dense(embed_dim)

  def separate_heads(self, x, batch_size):
    """Split the last axis into heads and move the head axis before the sequence axis.

    `x` is expected to be (batch, seq, embed_dim); the result is
    (batch, num_heads, seq, projection_dim).
    """
    per_head = tf.reshape(
      x, (batch_size, -1, self.num_heads, self.projection_dim)
    )
    return tf.transpose(per_head, perm=[0, 2, 1, 3])

  def call(self, inputs):
    """Apply self-attention to `inputs` and return a tensor with `embed_dim` features."""
    n_batch = tf.shape(inputs)[0]

    # Project the input three ways and split each projection into heads.
    q = self.separate_heads(self.query_dense(inputs), n_batch)
    k = self.separate_heads(self.key_dense(inputs), n_batch)
    v = self.separate_heads(self.value_dense(inputs), n_batch)

    # Scaled dot-product attention, normalised over the key axis.
    head_width = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(head_width)
    attn = tf.nn.softmax(logits, axis=-1)
    context = tf.matmul(attn, v)

    # Undo the head split: (batch, heads, seq, width) -> (batch, seq, embed_dim).
    merged = tf.reshape(
      tf.transpose(context, perm=[0, 2, 1, 3]),
      (n_batch, -1, self.embed_dim),
    )
    return self.combine_heads(merged)

class TransformerBlock(layers.Layer):
  """One transformer encoder block: self-attention then a feed-forward network.

  Each sub-layer's output goes through dropout, is added to the sub-layer's
  input (residual connection) and is then layer-normalised.

  Args:
    embed_dim: feature width of the block's input and output.
    num_heads: number of attention heads.
    ff_dim: hidden width of the two-layer feed-forward network.
    rate: dropout rate applied to both sub-layer outputs.
  """

  def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
    super(TransformerBlock, self).__init__()
    self.att = MultiHeadSelfAttention(embed_dim, num_heads)
    self.ffn = keras.Sequential(
      [layers.Dense(ff_dim, activation="relu"),
       layers.Dense(embed_dim),]
    )
    self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
    self.dropout1 = layers.Dropout(rate)
    self.dropout2 = layers.Dropout(rate)

  def call(self, inputs, training=None):
    """Run the block on `inputs`.

    Args:
      inputs: tensor whose last axis has `embed_dim` features.
      training: forwarded to both dropout layers. It defaults to None so the
        block can be called without it; a required positional argument here
        makes any call that omits `training` fail with a TypeError, which
        happens on Keras versions that only pass `training` when it has a
        concrete value.

    Returns:
      A tensor with the same feature width as `inputs`.
    """
    attn_output = self.att(inputs)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(inputs + attn_output)
    ffn_output = self.ffn(out1)
    ffn_output = self.dropout2(ffn_output, training=training)
    return self.layernorm2(out1 + ffn_output)
  
class TokenAndPositionEmbedding(layers.Layer):
  """Embed token ids and add a learned embedding of each token's position.

  Args:
    maxlen: number of rows in the position-embedding table.
    vocab_size: number of rows in the token-embedding table.
    embed_dim: width of both embeddings.
  """

  def __init__(self, maxlen, vocab_size, embed_dim):
    super().__init__()
    self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
    self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

  def call(self, x):
    """Return token embeddings of `x` plus the embeddings of positions 0..len-1."""
    seq_len = tf.shape(x)[-1]
    # One position vector per index along the last axis of `x`; it broadcasts
    # over the leading axes when added to the token vectors.
    position_vectors = self.pos_emb(tf.range(start=0, limit=seq_len, delta=1))
    token_vectors = self.token_emb(x)
    return token_vectors + position_vectors

In [31]:
# Digit-level model: 6 input digits -> a softmax over the 10 digits at each of
# the 6 positions. The position table has `maxlen` rows, of which a length-6
# input only indexes the first 6.
maxlen = 10
in1 = layers.Input(shape=(6,))
x = in1
x = TokenAndPositionEmbedding(maxlen=maxlen, vocab_size=10, embed_dim=128)(x)
x = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=32)(x)
x = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=32)(x)
x = layers.Dense(units=10)(x)
x = layers.Softmax()(x)
m = keras.Model(inputs=in1, outputs=x)
# Targets are integer digits, so use the sparse form of cross-entropy.
m.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

In [32]:
# Print the layer-by-layer output shapes and parameter counts of `m`.
m.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_6 (InputLayer)        [(None, 6)]               0         
                                                                 
 token_and_position_embeddin  (None, 6, 128)           2560      
 g_5 (TokenAndPositionEmbedd                                     
 ing)                                                            
                                                                 
 transformer_block_2 (Transf  (None, 6, 128)           74912     
 ormerBlock)                                                     
                                                                 
 transformer_block_3 (Transf  (None, 6, 128)           74912     
 ormerBlock)                                                     
                                                                 
 dense_37 (Dense)            (None, 6, 10)             1290

In [33]:
# Train for 10 epochs on the training split. The returned History object is
# kept so the per-epoch loss can be inspected afterwards (history.history)
# rather than being echoed as an opaque `<...History at 0x...>` repr.
history = m.fit(ds_X_train, ds_Y_train, epochs=10, verbose=1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x22284954a30>

In [35]:
# Score only the final output position, whose target (ds_Y_test[:, -1]) is the
# ones digit of the sum. The two leading sum digits are part of ds_X_test, so
# this is a teacher-forced check of a single digit, not a full decode of the
# answer. The attention layers apply no mask, so the earlier output positions
# can read their own targets from the input and are not scored here.
aa = m.predict(ds_X_test)
correct = np.argmax(aa[:, -1, :], axis=-1) == ds_Y_test[:, -1]
correct.sum(), correct.shape

(2000, (2000,))