In [None]:
import tensorflow as tf
import numpy as np

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import (TextVectorization, Dense, MultiHeadAttention, LayerNormalization, 
                                     Layer, Embedding, Input, Dropout)
from tensorflow.keras.callbacks import EarlyStopping

# Build dataset

In [None]:
FULL_VOCAB = 'abcdefghijklmnopqrstuvwxyz'  # the dataset functions take their vocab as a prefix of this string
SEQ_LEN = 10  # number of characters in every generated example

In [None]:
def dataset1(vocab_size=2, dataset_size=10000, seq_len=10, full_vocab=FULL_VOCAB):
  """
  Inserts a space after the first character of the vocabulary (and nowhere else).

  Samples `dataset_size` random strings of `seq_len` characters, drawn uniformly from
  the first `vocab_size` characters of `full_vocab`.

  Returns:
    concatenated_inputs: (dataset_size,) array of strings.
    outputs: (dataset_size, seq_len) float32 array; 1. where the character equals the
      first vocabulary character (a space follows it), 0. everywhere else.
  """
  assert vocab_size > 1
  vocab = list(full_vocab[:vocab_size])
  chars = np.random.choice(vocab, size=(dataset_size, seq_len))
  labels = (chars == vocab[0]).astype(np.float32)  # 1 = space, 0 = no space
  strings = np.array(list(map(''.join, chars)))
  return strings, labels

def dataset2(vocab_size=2, dataset_size=10000, seq_len=10, full_vocab=FULL_VOCAB):
  """
  Inserts a space after each occurrence of the bigram (1st, 2nd) vocabulary character (and nowhere else).

  The label 1. is placed on the position of the 2nd character of every vocab[0] -> vocab[1]
  bigram, i.e. the space goes right after the bigram. Position 0 is never labelled because
  it has no predecessor.

  Returns:
    concatenated_inputs: (dataset_size,) array of seq_len-character strings.
    outputs: (dataset_size, seq_len) float32 array of 0./1. labels (1 = space, 0 = no space).
  """
  assert vocab_size > 1
  vocab = list(full_vocab[:vocab_size])
  inputs = np.random.choice(vocab, size=(dataset_size, seq_len))
  outputs = np.zeros_like(inputs, dtype=np.float32)
  # Vectorised bigram test: position j (j >= 1) is labelled when inputs[:, j-1] == vocab[0]
  # and inputs[:, j] == vocab[1]. Also well-defined for seq_len in {0, 1} (no labels).
  outputs[:, 1:] = (inputs[:, :-1] == vocab[0]) & (inputs[:, 1:] == vocab[1])
  concatenated_inputs = np.array([''.join(row) for row in inputs])
  return concatenated_inputs, outputs

def dataset3(vocab_size=2, dataset_size=10000, seq_len=10, insert_space_every=3, full_vocab=FULL_VOCAB):
  """
  Inserts a space after a certain number of characters, no matter what the characters.

  Positions insert_space_every-1, 2*insert_space_every-1, ... (0-based) are labelled 1.,
  i.e. a space follows every `insert_space_every`-th character. The labels are the same
  for every example and do not depend on the sampled characters.

  Returns:
    concatenated_inputs: (dataset_size,) array of seq_len-character strings.
    outputs: (dataset_size, seq_len) float32 array of 0./1. labels (1 = space, 0 = no space).
  """
  assert vocab_size > 1
  # Guard: the periodic labelling below is only meaningful for a positive period.
  assert insert_space_every >= 1, 'insert_space_every must be >= 1'
  vocab = list(full_vocab[:vocab_size])
  inputs = np.random.choice(vocab, size=(dataset_size, seq_len))
  outputs = np.zeros_like(inputs, dtype=np.float32)
  outputs[:, insert_space_every - 1::insert_space_every] = 1.
  concatenated_inputs = np.array([''.join(row) for row in inputs])
  return concatenated_inputs, outputs

In [None]:
DATASET_FN = dataset2
SEED = 42

# Seed NumPy (the dataset functions sample with np.random.choice) and TensorFlow (tf.data
# shuffling, Keras weight initialisers) so Restart & Run All reproduces the same run.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# NOTE(review): with vocab_size=2 and SEQ_LEN=10 there are only 2**10 = 1024 distinct inputs,
# so three independent draws of 10000 examples overlap almost entirely; valid/test scores
# measure fit on already-seen strings rather than generalisation.
train_ds = tf.data.Dataset.from_tensor_slices(DATASET_FN(vocab_size=2, seq_len=SEQ_LEN))
valid_ds = tf.data.Dataset.from_tensor_slices(DATASET_FN(vocab_size=2, seq_len=SEQ_LEN))
test_ds = tf.data.Dataset.from_tensor_slices(DATASET_FN(vocab_size=2, seq_len=SEQ_LEN))
train_ds.element_spec

(TensorSpec(shape=(), dtype=tf.string, name=None),
 TensorSpec(shape=(10,), dtype=tf.float32, name=None))

In [None]:
# Only the training split is shuffled; evaluation order does not affect the metrics.
train_ds = train_ds.shuffle(buffer_size=1000).batch(batch_size=128)
valid_ds, test_ds = [split.batch(batch_size=128) for split in (valid_ds, test_ds)]

In [None]:
# Grab one (shuffled) training batch; despite the names, test_inputs / test_outputs come
# from train_ds and are reused by the smoke-test cells below.
test_inputs, test_outputs = next(iter(train_ds))
print(test_inputs)
tf.print(test_outputs, summarize=-1)

tf.Tensor(
[b'babaaabbaa' b'bbaabbaabb' b'babbbaaaba' b'abbaabbbaa' b'baabbaaaab'
 b'aaabbabbaa' b'aaaaaabaaa' b'ababbaabab' b'abbabababa' b'bbbbaaabab'
 b'bbbbbabbba' b'babbbbaaba' b'bbbbbaaaaa' b'bbbbaabbaa' b'bbbbaaaaba'
 b'abbaababbb' b'baaabbabab' b'aabbaaaaaa' b'abaabaaaab' b'aababaaaaa'
 b'bbbbbaabbb' b'ababbaabbb' b'aababbaaaa' b'bbbabbbbba' b'abbabbbaaa'
 b'baaaabaabb' b'bababaaaab' b'babbbbbbbb' b'aabbaaaabb' b'bbaaaababa'
 b'baaaabaaab' b'ababaaaaba' b'bababbabaa' b'babbababaa' b'baaaaabaaa'
 b'baaaabaaba' b'bbaaaabaaa' b'bababbbbaa' b'bbaaaabaab' b'aaaabbabaa'
 b'aabbabbaaa' b'baabaabbba' b'bbababbbba' b'baaabbabab' b'aabababbaa'
 b'bbbaabbbbb' b'aaaaabbabb' b'bbaabbaaba' b'aabbaaabbb' b'baaabbbbab'
 b'abbaaabbaa' b'aaaabbbbba' b'babbaaabbb' b'aaabbaaabb' b'baaaaababa'
 b'aabbbbaaaa' b'aabbabaaba' b'baaabbaaab' b'baabaabaaa' b'bbbaababbb'
 b'babbaaaaab' b'babbbbabab' b'ababbaabba' b'baaaaabbab' b'ababbbaaab'
 b'bbabbabbbb' b'baababbaaa' b'bbbbbabbab' b'aabbaabaab' b'abaabaa

# Build layers

In [None]:
def _text_only(text, label):
  """Keep only the raw strings of a (text, label) batch for TextVectorization.adapt."""
  return text


textvectorization = TextVectorization(split='character')
# A named function instead of a lambda avoids the AutoGraph lambda-source warning
# that Dataset.map emits for lambdas.
textvectorization.adapt(train_ds.map(_text_only))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [None]:
# Vocabulary learned by adapt(): index 0 is padding ('') and index 1 is '[UNK]' (see output).
textvectorization.get_vocabulary()

['', '[UNK]', 'a', 'b']

In [None]:
# Test

# Smoke test: character-level token ids for the peeked batch, one id per character.
tv_out = textvectorization(test_inputs)
tv_out

<tf.Tensor: shape=(128, 10), dtype=int64, numpy=
array([[3, 2, 3, ..., 3, 2, 2],
       [3, 3, 2, ..., 2, 3, 3],
       [3, 2, 3, ..., 2, 3, 2],
       ...,
       [3, 2, 2, ..., 2, 2, 3],
       [3, 2, 3, ..., 2, 3, 2],
       [2, 2, 3, ..., 2, 2, 2]])>

In [None]:
def positional_encodings(seq_len, d_model, max_wavelength=10000.):
    """Sinusoidal positional encodings (Vaswani et al., 2017).

    Column 2k holds sin(pos / max_wavelength**(2k / d_model)) and column 2k+1 holds
    cos(pos / max_wavelength**(2k / d_model)), for pos in range(seq_len).

    Args:
        seq_len: number of positions (rows).
        d_model: encoding width (columns). May be odd; the last column is then a sine.
        max_wavelength: base of the geometric wavelength progression (10000 in the paper).

    Returns:
        tf.float32 tensor of shape (seq_len, d_model).
    """
    positions = np.arange(seq_len)[:, None]            # (seq_len, 1)
    even_dims = np.arange(0, d_model, 2)[None, :]      # (1, ceil(d_model / 2))
    angles = positions / max_wavelength ** (even_dims / d_model)

    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    # With odd d_model there is one fewer cosine column than sine column.
    pe[:, 1::2] = np.cos(angles[:, :d_model // 2])
    return tf.constant(pe, dtype=tf.float32)

In [None]:
D_MODEL = 32  # width of the token embeddings and of the positional encodings
MAX_TOKENS = textvectorization.vocabulary_size()  # includes padding and UNK tokens

In [None]:
class InputEmbeddings(Layer):
    """Token embeddings plus fixed positional encodings.

    Args:
        d_model: embedding width; must equal the width of `pos_encodings` so they can be added.
        pos_encodings: 2-D tensor whose rows are indexed by position; the first
            `seq_len` rows are added to the token embeddings of each batch.
        max_tokens: vocabulary size passed to the Embedding layer (its input_dim).

    Token id 0 is treated as padding: the Embedding uses mask_zero=True and the
    resulting padding mask is exposed through compute_mask.
    """

    def __init__(self, d_model, pos_encodings, max_tokens, name='input_embeddings', **kwargs):
        super().__init__(name=name, **kwargs)
        self.pos_encodings = pos_encodings
        self.embedding = Embedding(max_tokens, d_model, mask_zero=True)

    def compute_mask(self, inputs, mask=None):
        # Any incoming mask is ignored; padding is recomputed from the token ids.
        return self.embedding.compute_mask(inputs)

    def call(self, inputs):
        token_vectors = self.embedding(inputs)
        # Use only as many positional rows as this batch has timesteps.
        seq_len = tf.shape(inputs)[-1]
        return token_vectors + self.pos_encodings[:seq_len]

In [None]:
# Test

# Smoke test: embed the token ids; expected shape is (batch, SEQ_LEN, D_MODEL).
input_embeddings = InputEmbeddings(D_MODEL, positional_encodings(SEQ_LEN, D_MODEL), MAX_TOKENS)
emb_out = input_embeddings(tv_out)
emb_out.shape

TensorShape([128, 10, 32])

In [None]:
def get_attention_mask(mask=None):
    """Expand a (batch, seq) padding mask into a pairwise (batch, seq, seq) attention mask.

    Entry [b, i, j] is True only when both position i and position j of example b are
    unmasked. Returns None when no mask is given.
    """
    if mask is not None:
        query_ok = mask[:, :, None]   # (batch, seq, 1)
        key_ok = mask[:, None, :]     # (batch, 1, seq)
        return query_ok & key_ok
    return None

In [None]:
class EncoderBlock(Layer):
    """Post-norm Transformer encoder block: self-attention, then a position-wise feed-forward net.

    Each sub-layer is wrapped in a residual connection followed by LayerNormalization.
    `supports_masking = True` forwards the incoming padding mask unchanged to the next layer.

    Args:
        num_heads: number of attention heads.
        key_dim: size of each attention head for query and key.
        d_model: output width of the feed-forward net; must match the input width for the
            second residual add.
        ff_dim: hidden width of the feed-forward net.
    """

    def __init__(self, num_heads, key_dim, d_model, ff_dim, name='encoder_block', **kwargs):
        super().__init__(name=name, **kwargs)
        self.supports_masking = True  # pass any incoming mask through to the output
        self.num_heads, self.key_dim, self.d_model, self.ff_dim = num_heads, key_dim, d_model, ff_dim
        self.multihead_attention = MultiHeadAttention(num_heads, key_dim)
        self.ff = Sequential([Dense(ff_dim, activation='relu'), Dense(d_model)])
        self.layernorm1 = LayerNormalization()
        self.layernorm2 = LayerNormalization()

    def call(self, inputs, mask=None):
        # Self-attention sub-layer; the pairwise mask only allows attention between
        # non-padded query/key positions.
        attended = self.multihead_attention(inputs, inputs,
                                            attention_mask=get_attention_mask(mask))
        normed = self.layernorm1(inputs + attended)
        # Feed-forward sub-layer with its own residual connection and normalisation.
        return self.layernorm2(normed + self.ff(normed))

In [None]:
# Test

# Smoke test: one encoder block preserves the (batch, seq, D_MODEL) shape.
encoder_block = EncoderBlock(num_heads=2, key_dim=16, d_model=D_MODEL, ff_dim=32)
enc_block_out = encoder_block(emb_out)
enc_block_out.shape

TensorShape([128, 10, 32])

In [None]:
# Padding mask from InputEmbeddings, passed through EncoderBlock (supports_masking=True).
# All True here because every example has exactly SEQ_LEN characters (no padding).
enc_block_out._keras_mask

<tf.Tensor: shape=(128, 10), dtype=bool, numpy=
array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])>

In [None]:
class ClassifierHead(Layer):
  """Per-position binary classifier producing one logit per timestep.

  Dense(units, relu) -> Dropout(dropout_rate) -> Dense(1), then the trailing unit axis is
  reshaped away so the output is (batch, seq). No sigmoid is applied, so the values are
  logits. `supports_masking = True` forwards the incoming padding mask to the output.

  Args:
    d_model: stored on the layer but not used in the computation.
    dropout_rate: dropout rate applied between the two Dense layers.
    units: width of the hidden Dense layer.
  """

  def __init__(self, d_model, dropout_rate, units, name='classifier_head', **kwargs):
    super().__init__(name=name, **kwargs)
    self.supports_masking = True
    self.d_model = d_model
    self.dropout_rate = dropout_rate
    self.units = units
    self.dense1 = Dense(units, activation='relu')
    self.dropout = Dropout(dropout_rate)
    self.dense2 = Dense(1)

  def call(self, inputs):
    leading_dims = tf.shape(inputs)[:2]  # (batch, seq)
    logits = self.dense2(self.dropout(self.dense1(inputs)))
    return tf.reshape(logits, leading_dims)

In [None]:
# Test

# Smoke test: the head maps (batch, seq, D_MODEL) to per-position logits (batch, seq)
# and forwards the padding mask.
classifier_head = ClassifierHead(D_MODEL, dropout_rate=0.1, units=32)
head_out = classifier_head(enc_block_out)
print(head_out._keras_mask)
head_out.shape

tf.Tensor(
[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]], shape=(128, 10), dtype=bool)


TensorShape([128, 10])

# Transformer Model

In [None]:
class Transformer(Model):
  """Character-level encoder that predicts, for every input character, whether a space follows it.

  Pipeline: TextVectorization -> InputEmbeddings (token + positional) -> EncoderBlock ->
  ClassifierHead. `call` returns per-position logits of shape (batch, seq). The custom
  train/test steps use the loss and metrics given to `compile`.
  """

  def __init__(self, d_model, seq_len, max_tokens, num_heads, key_dim, ff_dim, dropout_rate, units,
               textvectorization, name='transformer', **kwargs):
    super().__init__(name=name, **kwargs)
    self.d_model = d_model
    self.seq_len = seq_len
    self.max_tokens = max_tokens
    self.num_heads = num_heads
    self.key_dim = key_dim
    self.ff_dim = ff_dim
    self.dropout_rate = dropout_rate
    self.units = units
    self.textvectorization = textvectorization
    self.input_embeddings = InputEmbeddings(d_model, positional_encodings(seq_len, d_model),
                                            max_tokens)
    self.encoder_block = EncoderBlock(num_heads=num_heads, key_dim=key_dim, d_model=d_model, ff_dim=ff_dim)
    self.classifier_head = ClassifierHead(d_model, dropout_rate=dropout_rate, units=units)

  def train_step(self, data):
    inputs, y_true = data
    with tf.GradientTape() as tape:
      # training=True is required for the Dropout in the classifier head to be active
      # during fit(); without it the forward pass runs in inference mode.
      y_pred = self(inputs, training=True)
      loss = self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)
    grads = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
    self.compiled_metrics.update_state(y_true, y_pred)
    return {m.name: m.result() for m in self.metrics}

  def test_step(self, data):
    inputs, y_true = data
    y_pred = self(inputs, training=False)
    # Called for its side effect: it updates the loss tracker reported via self.metrics.
    self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)
    self.compiled_metrics.update_state(y_true, y_pred)
    return {m.name: m.result() for m in self.metrics}

  def call(self, inputs, training=None):
    h = self.textvectorization(inputs)
    h = self.input_embeddings(h)
    h = self.encoder_block(h, training=training)
    return self.classifier_head(h, training=training)

In [None]:
# Model hyperparameters
NUM_HEADS = 2
KEY_DIM = 16
FF_DIM = 32
DROPOUT_RATE = 0.1
UNITS = 20

transformer = Transformer(
    d_model=D_MODEL, seq_len=SEQ_LEN, max_tokens=MAX_TOKENS,
    num_heads=NUM_HEADS, key_dim=KEY_DIM, ff_dim=FF_DIM,
    dropout_rate=DROPOUT_RATE, units=UNITS,
    textvectorization=textvectorization,
)
# Call once on a real batch so the subclassed model creates its weights before summary().
_ = transformer(test_inputs)
transformer.summary()

Model: "transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 text_vectorization (TextVec  multiple                 0         
 torization)                                                     
                                                                 
 input_embeddings (InputEmbe  multiple                 128       
 ddings)                                                         
                                                                 
 encoder_block (EncoderBlock  multiple                 6464      
 )                                                               
                                                                 
 classifier_head (Classifier  multiple                 681       
 Head)                                                           
                                                                 
Total params: 7,273
Trainable params: 7,273
Non-trainab

In [None]:
# Test

# Smoke test: end-to-end call on raw strings returns per-position logits of shape (batch, SEQ_LEN).
transformer_out = transformer(test_inputs)
transformer_out.shape

TensorShape([128, 10])

In [None]:
# The padding mask computed by InputEmbeddings is still attached to the model output.
transformer_out._keras_mask

<tf.Tensor: shape=(128, 10), dtype=bool, numpy=
array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]])>

In [None]:
def masked_binary_crossentropy(y_true, y_pred):
  """Mean binary cross-entropy between per-position labels and logits.

  y_true: shape (batch_size, seq_len). 0. = no space, 1. = space.
  y_pred: shape (batch_size, seq_len). Logits (the model applies no sigmoid).

  If y_pred carries a Keras padding mask (`_keras_mask`), only unmasked positions are
  averaged; otherwise the mean is taken over every position.
  NOTE(review): when used through compile(), Keras may also weight the loss by the same
  mask; with fixed-length inputs (no padding) both agree. Confirm this if variable-length
  inputs are introduced.

  Returns a scalar tensor.
  """
  logits = y_pred
  labels = tf.cast(y_true, logits.dtype)  # 0 = no space, 1 = space

  # Numerically stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]. The naive
  # version produces 0 * log(0) = NaN once sigmoid saturates to exactly 0 or 1 in float32,
  # which happens for moderately large |logit| on a well-trained model.
  bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

  mask = getattr(y_pred, '_keras_mask', None)
  if mask is None:
    return tf.reduce_mean(bce)
  mask = tf.cast(mask, bce.dtype)
  return tf.math.divide_no_nan(tf.reduce_sum(bce * mask), tf.reduce_sum(mask))

In [None]:
# Test

# Smoke test: loss of the not-yet-trained model on the peeked batch.
masked_binary_crossentropy(test_outputs, transformer_out)

<tf.Tensor: shape=(), dtype=float32, numpy=0.57050836>

In [None]:
# Recompute the loss on a fresh training batch (rebinds test_inputs / test_outputs).
test_inputs, test_outputs = next(iter(train_ds))
y_pred = transformer(test_inputs)
loss = masked_binary_crossentropy(test_outputs, y_pred)
loss

<tf.Tensor: shape=(), dtype=float32, numpy=0.57452714>

In [None]:
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall, AUC

# The model outputs logits, not probabilities. Threshold at 0 (== probability 0.5) and let
# AUC apply the sigmoid itself; the metric defaults assume predictions in [0, 1].
# Explicit names keep the history/evaluate keys stable when this cell is re-run
# (otherwise Keras appends suffixes such as `precision_2`).
earlystopping = EarlyStopping(patience=2, monitor="val_binary_accuracy")
transformer.compile(loss=masked_binary_crossentropy, optimizer='adam',
                    metrics=[BinaryAccuracy(name='binary_accuracy', threshold=0.),
                             Precision(name='precision', thresholds=0.),
                             Recall(name='recall', thresholds=0.),
                             AUC(name='pr_auc', curve='PR', from_logits=True)])

history = transformer.fit(train_ds, validation_data=valid_ds, epochs=20, callbacks=[earlystopping])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20


In [None]:
# NOTE(review): with vocab_size=2 and SEQ_LEN=10 there are only 1024 distinct inputs, so test_ds
# largely repeats strings already in train_ds; perfect scores do not demonstrate generalisation.
transformer.evaluate(test_ds, return_dict=True)



{'loss': 0.0005387350684031844,
 'binary_accuracy': 1.0,
 'precision_2': 1.0,
 'recall_2': 1.0,
 'auc_2': 1.0}

In [None]:
# Probe: 'ababab' right-padded with 'b' to SEQ_LEN characters. With DATASET_FN = dataset2 the
# positions that close an 'a'->'b' bigram (indices 1, 3, 5) should get positive logits.
transformer(['ababab'.ljust(SEQ_LEN, 'b')])

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[-20.69848 ,  12.559976, -21.522467,   9.896942, -20.546288,
          8.217764, -11.243217, -14.479351, -14.61993 , -12.18022 ]],
      dtype=float32)>