## Enabling and testing the GPU

In [None]:
import datetime
import functools
import math
import os

import tensorflow as tf
import tensorflow_datasets as tfds

# %load_ext tensorboard

# report the installed TensorFlow version
print("Tensorflow version " + tf.__version__)

Tensorflow version 2.15.0


In [None]:
# do not log the device each op is placed on
tf.debugging.set_log_device_placement(False)

# logical GPUs visible to TensorFlow (an empty list on CPU-only hosts)
GPU = tf.config.list_logical_devices('GPU')
# mirror the model over the visible GPUs; MirroredStrategy rejects an empty device
# list, so without GPU pass None and let it pick the available local devices instead
GPU_STRATEGY = tf.distribute.MirroredStrategy(GPU or None)

print(GPU)

[LogicalDevice(name='/device:GPU:0', device_type='GPU')]


## Defining The Metadata

In [None]:
# META ########################################################################

N_DEPTH = 2 # D, tokenization depth: samples are padded to multiples of N_TOKEN_DIM ** N_DEPTH bytes (the encoder / decoder below hard-code 2 levels)
N_TOKEN_DIM = 4 # G, number of consecutive vectors merged into one at each tokenize level
N_ENCODING_DIM = 256 # U, one-hot size: one class per byte value
N_EMBEDDING_DIM = N_ENCODING_DIM # E, size of the byte embeddings
N_LATENT_DIM = N_EMBEDDING_DIM # L, size of the token (latent) vectors

N_EPOCHS = 16
N_EPOCHS_RAMPUP = 0 # epochs of linear learning rate warm-up
N_EPOCHS_SUSTAIN = 0 # epochs held at the peak learning rate before the exponential decay

N_BATCH = 128 # number of samples per batch
N_SAMPLE = 128 # number of characters per sample (=> 4 * N_SAMPLE bytes per sample, since UTF-32 uses 4 bytes per character)

R_MIN = 0.0001 # learning rate floor of the schedule
R_MAX = 0.001 # learning rate given to Adam at compile time; the scheduler in the training cell peaks at 0.4 * R_MAX
R_EXP = .9 # per-epoch decay factor of the learning rate schedule

VERSION = 'tokun-4-keras-1M200K' # run name, part of the log directory

## Loading The Data

In [None]:
# DATA ########################################################################

LANG = ['ar', 'de', 'en', 'es', 'hi', 'vi', 'zh']

def _load_mlqa(split: str) -> dict:
    """Load one split of MLQA for every language in LANG, as batched datasets keyed by language code."""
    return {
        __lang: tfds.load('mlqa/' + __lang, split=split, as_supervised=False, shuffle_files=True, data_dir='~/.cache/tensorflow/', batch_size=N_BATCH)
        for __lang in LANG}

# training uses the 'test' split and evaluation the 'validation' split
# (presumably because MLQA has no 'train' split - confirm in the TFDS catalog)
TRAIN = _load_mlqa(split='test')
TEST = _load_mlqa(split='validation')

## Layers

In [None]:
# EMBEDDING ###################################################################

class PositionalEmbedding(tf.keras.layers.Layer):
    """Add a trainable bias that depends on the position along two axes.

    The kernel keeps the input's size on `input_axis` and `output_axis` and has
    size 1 on every other axis, so it broadcasts over those (e.g. the batch):
    each index of the sequence axis gets its own bias vector, unlike a Dense
    bias which is shared by all positions.
    """

    def __init__(
        self,
        input_axis: int=1, # axis of the sequence
        output_axis: int=-1, # axis of the embedding
        **kwargs
    ):
        super().__init__(**kwargs)
        self._input_axis = input_axis
        self._output_axis = output_axis
        self._kernel = None

    def build(self, input_shape: tuple):
        __rank = len(input_shape)
        __kept = [self._input_axis % __rank, self._output_axis % __rank]
        # broadcastable kernel shape: full size on the two kept axes, 1 elsewhere
        __kernel_shape = []
        for __axis, __dim in enumerate(list(input_shape)):
            __kernel_shape.append(__dim if __axis in __kept else 1)
        self._kernel = self.add_weight(
            name="kernel",
            shape=__kernel_shape,
            initializer=tf.keras.initializers.GlorotNormal())
        self.built = True

    def call(self, inputs: tf.Tensor):
        # the same position biases are added to every sample
        return inputs + self._kernel

# RESHAPING ###################################################################

def _normalize_shape(shape: list) -> list:
    return [-1 if __d is None else __d for __d in shape]

def _normalize_dim(dim: int) -> int:
    return -1 if (dim is None or dim < 0) else dim

def _multiply_dim(dim_l: int, dim_r: int) -> int:
    return -1 if (dim_l == -1 or dim_r == -1) else dim_l * dim_r

def _divide_dim(dim_l: int, dim_r: int) -> int:
    return -1 if (dim_l == -1 or dim_r == -1) else dim_l // dim_r

class Divide(tf.keras.layers.Layer):
    """Reshape by moving a factor `factor` from one axis to another.

    The size of `input_axis` is divided by `factor` and the size of
    `output_axis` is multiplied by it, with a single `tf.reshape`. With
    `insert=True`, a new axis of size 1 is first inserted at `output_axis`,
    so the data ends up grouped on a new axis. Both axes are read relative to
    the OUTPUT rank (one more than the input rank when inserting).

    Unknown dimensions are given to `tf.reshape` as -1, so the reshape only
    succeeds when at most one dimension is unknown.
    """

    def __init__(
        self,
        input_axis: int, # relative to the NEW shape / rank
        output_axis: int, # same
        factor: int,
        insert: bool=False,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self._input_axis = input_axis
        self._output_axis = output_axis
        self._factor = factor
        self._insert = insert

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        # static shape, with the symbolic (None) dimensions replaced by -1
        __dims = _normalize_shape(list(inputs.shape))
        # rank of the output tensor
        __rank = len(__dims) + int(self._insert)
        __source = self._input_axis % __rank
        __destination = self._output_axis % __rank
        if self._insert:
            # new axis of size 1, which receives the factor below
            __dims.insert(__destination, 1)
        # shrink the source axis and grow the destination axis by the same factor
        __dims[__source] = _divide_dim(__dims[__source], self._factor)
        __dims[__destination] = _multiply_dim(__dims[__destination], self._factor)
        return tf.reshape(tensor=inputs, shape=__dims)

class Merge(tf.keras.layers.Layer):
    """Fuse two axes into a single one whose size is the product of both (a pure reshape).

    The merged axis takes the place of `left_axis` when `left=True`, of
    `right_axis` otherwise; the other axis is removed.
    """

    def __init__(
        self,
        left_axis: int=-2,
        right_axis: int=-1,
        left: bool=True,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self._left_axis = left_axis
        self._right_axis = right_axis
        self._left = left

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        # static shape, with the symbolic (None) dimensions replaced by -1
        __dims = _normalize_shape(list(inputs.shape))
        __rank = len(__dims)
        __left = self._left_axis % __rank
        __right = self._right_axis % __rank
        __kept, __dropped = (__left, __right) if self._left else (__right, __left)
        # the kept axis receives the product of both sizes, the other one disappears
        __dims[__kept] = _multiply_dim(__dims[__left], __dims[__right])
        del __dims[__dropped]
        return tf.reshape(tensor=inputs, shape=__dims)

## Blocks

In [None]:
# ENCODING BLOCKS #############################################################

class TokenizeBlock(tf.keras.layers.Layer):
    """Merge every `token_dim` consecutive vectors into a single latent vector.

    (B * G, E) -> group -> (B, G, E) -> + position bias -> merge -> (B, G * E) -> dense + relu -> (B, L)
    The shapes assume the default axes (left_axis=-2, right_axis=-1).
    """

    def __init__(
        self,
        left_axis: int=-2,
        right_axis: int=-1,
        token_dim: int=4,
        latent_dim: int=256,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        # (B * G, E) => (B, G, E)
        self._divide = Divide(input_axis=0, output_axis=1, factor=token_dim, insert=True, name='group')
        # (B, G, E) + (1, G, E)
        self._embedding = PositionalEmbedding(input_axis=left_axis, output_axis=right_axis, name='position-embeddings')
        # (B, G, E) => (B, G * E)
        self._merge = Merge(left_axis=left_axis, right_axis=right_axis, left=True, name='merge-embeddings')
        # (B, G * E) => (B, L), typically L = E
        self._dense = tf.keras.layers.Dense(
            units=latent_dim,
            activation='relu',
            use_bias=True,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            name='compress-embeddings')

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        __x = self._divide(inputs)
        __x = self._embedding(__x)
        __x = self._merge(__x)
        return self._dense(__x)

# DECODING BLOCKS #############################################################

class DetokenizeBlock(tf.keras.layers.Layer):
    """Expand each latent vector back into `token_dim` vectors (mirror of TokenizeBlock).

    (B, L) -> dense + relu -> (B, G * E) -> split -> (B, G, E) -> + position bias -> merge -> (B * G, E)
    """

    def __init__(self, token_dim: int=4, embedding_dim: int=256, **kwargs) -> None:
        super().__init__(**kwargs)
        # (B, L) => (B, G * E), typically L = E
        self._dense = tf.keras.layers.Dense(
            units=token_dim * embedding_dim,
            activation='relu',
            use_bias=True,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            name='decompress-embeddings')
        # (B, G * E) => (B, G, E)
        self._divide = Divide(input_axis=-2, output_axis=-1, insert=True, factor=embedding_dim, name='divide-embeddings')
        # (B, G, E) + (1, G, E)
        self._embedding = PositionalEmbedding(input_axis=-2, output_axis=-1, name='position-embeddings')
        # (B, G, E) => (B * G, E)
        self._merge = Merge(left_axis=0, right_axis=1, left=True)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        __x = self._dense(inputs)
        __x = self._divide(__x)
        __x = self._embedding(__x)
        return self._merge(__x)

# HEAD BLOCK ##################################################################

class HeadBlock(tf.keras.layers.Layer):
    """Project vectors onto `encoding_dim` classes and turn them into probabilities (softmax over the last axis)."""

    def __init__(self, encoding_dim: int=256, **kwargs) -> None:
        super().__init__(**kwargs)
        # (..., E) => (..., U), typically U = E
        self._dense = tf.keras.layers.Dense(
            units=encoding_dim,
            activation=None,
            use_bias=True,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            name='project-head')
        # (..., U)
        self._softmax = tf.keras.layers.Softmax(axis=-1, name='softmax')

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        __logits = self._dense(inputs)
        return self._softmax(__logits)

## Model

In [None]:
# ENCODER #####################################################################

class Encoder(tf.keras.models.Model):
    """Byte-level encoder: one-hot bytes (B * G * G, U) => latent vectors (B, L).

    A bias-free Dense embeds every byte, then two TokenizeBlocks each merge
    `token_dim` consecutive vectors, so each output vector stands for
    token_dim ** 2 input bytes.
    """

    def __init__(self, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, **kwargs) -> None:
        super().__init__(**kwargs)
        __stack = [
            # (B * G * G, U)
            tf.keras.Input(shape=(encoding_dim,), batch_size=batch_dim, name='input'),
            # (B * G * G, U) => (B * G * G, E)
            tf.keras.layers.Dense(units=embedding_dim, activation=None, use_bias=False, kernel_initializer='glorot_uniform', bias_initializer=None, name='embed-1'),
            # (B * G * G, E) => (B * G, E)
            TokenizeBlock(left_axis=-2, right_axis=-1, token_dim=token_dim, latent_dim=latent_dim, name='tokenize-4'),
            # (B * G, E) => (B, E)
            TokenizeBlock(left_axis=-2, right_axis=-1, token_dim=token_dim, latent_dim=latent_dim, name='tokenize-4-4'),
        ]
        self._encoder = tf.keras.Sequential(__stack)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        return self._encoder(x)

# DECODER #####################################################################

class Decoder(tf.keras.models.Model):
    """Byte-level decoder: latent vectors (B, L) => byte probabilities (B * G * G, U).

    Two DetokenizeBlocks each expand a vector into `token_dim` vectors, then
    HeadBlock projects every vector onto the `encoding_dim` byte classes with a softmax.
    """

    def __init__(self, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, **kwargs) -> None:
        super().__init__(**kwargs)
        __stack = [
            # (B, E)
            tf.keras.Input(shape=(latent_dim,), batch_size=batch_dim, name='input'),
            # (B, E) => (B * G, E)
            DetokenizeBlock(token_dim=token_dim, embedding_dim=embedding_dim, name='detokenize-4-4'),
            # (B * G, E) => (B * G * G, E)
            DetokenizeBlock(token_dim=token_dim, embedding_dim=embedding_dim, name='detokenize-4'),
            # (B * G * G, E) => (B * G * G, U)
            HeadBlock(encoding_dim=encoding_dim, name='project-head'),
        ]
        self._decoder = tf.keras.Sequential(__stack)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        return self._decoder(x)

# AUTOENCODER #################################################################

class AutoEncoder(tf.keras.models.Model):
    """Deterministic autoencoder (no sampling, no KL term): Encoder followed by Decoder.

    Both halves are built with the same dimensions; the output has one softmax
    over the `encoding_dim` byte classes per input row.
    """

    def __init__(self, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, **kwargs) -> None:
        super().__init__(**kwargs)
        __dims = dict(token_dim=token_dim, encoding_dim=encoding_dim, embedding_dim=embedding_dim, latent_dim=latent_dim, batch_dim=batch_dim)
        self._encoder = Encoder(**__dims)
        self._decoder = Decoder(**__dims)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        __latent = self._encoder(x)
        return self._decoder(__latent)

In [None]:
# build and compile the model inside the distribution strategy scope
with GPU_STRATEGY.scope():
    MODEL = AutoEncoder(token_dim=N_TOKEN_DIM, encoding_dim=N_ENCODING_DIM, embedding_dim=N_EMBEDDING_DIM, latent_dim=N_LATENT_DIM, batch_dim=None)
    __optimizer = tf.keras.optimizers.Adam(learning_rate=R_MAX)
    # the model ends with a softmax (HeadBlock), hence from_logits=False
    __loss = tf.keras.losses.CategoricalCrossentropy(from_logits=False, label_smoothing=0., axis=-1, reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, name='loss')
    MODEL.compile(optimizer=__optimizer, loss=__loss, metrics=['accuracy'])


## Train

In [None]:
# CONTROL #####################################################################

def learning_rate_hokusai(epoch: int, lr_min: float, lr_max: float, lr_exp: float, rampup: int, sustain: int) -> float:
    """Learning rate for `epoch`: linear warm-up, plateau, then exponential decay.

    - epochs [0, rampup): linear from lr_min towards lr_max
    - epochs [rampup, rampup + sustain): constant lr_max
    - later: lr_min + (lr_max - lr_min) * lr_exp ** (epochs elapsed since the plateau)
    """
    if epoch < rampup:
        return lr_min + (epoch * (lr_max - lr_min) / rampup)
    if epoch < rampup + sustain:
        return lr_max
    return lr_min + (lr_max - lr_min) * lr_exp ** (epoch - rampup - sustain)

In [None]:
# PREPROCESS ##################################################################

def shape(layer_count: int, group_size: int, flatten: bool=False) -> list:
    """Target shape of tokenized bytes: [-1] then `layer_count` axes of size `group_size`, or just [-1] when flattened."""
    if flatten:
        return [-1]
    return [-1] + [group_size] * layer_count

def _tokenize_scalar(text: str, layer_count: int=1, group_size: int=4, flatten: bool=False) -> tf.Tensor:
    """Encode one Python string as its UTF-32-BE byte values.

    The bytes are zero-padded to a multiple of group_size ** layer_count and
    reshaped with `shape(...)`; the result is an int32 tensor.
    """
    __block = group_size ** layer_count
    __values = list(text.encode('utf-32-be'))
    __values.extend((-len(__values) % __block) * [0])
    # int32 rather than uint8 (uint8 is not allowed here)
    __tensor = tf.convert_to_tensor(value=__values, dtype=tf.dtypes.int32)
    return tf.reshape(tensor=__tensor, shape=shape(layer_count=layer_count, group_size=group_size, flatten=flatten))

def tokenize(data: tf.Tensor, layer_count: int=1, group_size: int=4, sample_size: int=64, flatten: bool=False) -> tf.Tensor:
    """Turn a batch of UTF-8 strings into UTF-32-BE byte values grouped into tokens.

    Every sample is truncated or zero-padded to a fixed number of bytes: 4 bytes
    per character (UTF-32) times `sample_size`, rounded up to a multiple of
    group_size ** layer_count. The result is reshaped with `shape(...)`, so its
    first axis is NOT the batch axis.
    """
    __block = group_size ** layer_count
    # 4 bytes per character in UTF-32, rounded up to a whole number of tokens
    __length = __block * math.ceil(4 * sample_size / __block)
    # re-encode each string from UTF-8 to UTF-32-BE, still as byte strings: (B,)
    __utf32 = tf.strings.unicode_transcode(input=data, input_encoding='UTF-8', output_encoding='UTF-32-BE')
    # read the raw bytes as integers, with the same length for every sample: (B, __length)
    __values = tf.io.decode_raw(__utf32, out_type=tf.uint8, fixed_length=__length)
    # group the bytes into tokens, e.g. (-1, G, G) when not flattened
    return tf.reshape(tensor=__values, shape=shape(layer_count=layer_count, group_size=group_size, flatten=flatten))

def preprocess(dataset: tf.data.Dataset, key: str='context', layer_count: int=1, group_size: int=4, sample_size: int=64, flatten: bool=False) -> tf.data.Dataset:
    """Map a dataset of dict samples to (one-hot bytes, one-hot bytes) pairs for autoencoding.

    The text feature `key` is tokenized into UTF-32-BE byte values, one-hot
    encoded over the 256 byte values, and used both as input and target.
    """
    def _to_bytes(sample):
        # UTF-8 string feature => UTF-32-BE byte values
        return tokenize(data=sample[key], layer_count=layer_count, group_size=group_size, sample_size=sample_size, flatten=flatten)

    def _to_one_hot(values):
        # one class per possible byte value
        return tf.one_hot(indices=values, depth=256, axis=-1)

    def _to_pair(x):
        # the target of the autoencoder is its own input
        return (x, x)

    return dataset.map(_to_bytes).map(_to_one_hot).map(_to_pair)

In [None]:
# the same preprocessing settings for every language, in both splits
__preprocess = functools.partial(preprocess, key='context', layer_count=N_DEPTH, group_size=N_TOKEN_DIM, sample_size=N_SAMPLE, flatten=True)
TRAIN = {__l: __preprocess(dataset=__d) for __l, __d in TRAIN.items()}
TEST = {__l: __preprocess(dataset=__d) for __l, __d in TEST.items()}

In [None]:
# SAVE ########################################################################

# one log folder per run: .logs/<VERSION>/<YYYYmmdd-HHMMSS>
__timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
LOGPATH = os.path.join('.logs/', VERSION, __timestamp)
# TensorBoard summary writer on that folder
SUMMARY = tf.summary.create_file_writer(LOGPATH)

In [None]:
# TRAIN #######################################################################

# called during training
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=LOGPATH)
# the schedule starts at 0.4 * R_MAX (not R_MAX) and decays towards R_MIN
lr_callback = tf.keras.callbacks.LearningRateScheduler(functools.partial(learning_rate_hokusai, lr_min=R_MIN, lr_max=0.4 * R_MAX, lr_exp=R_EXP, rampup=N_EPOCHS_RAMPUP, sustain=N_EPOCHS_SUSTAIN), verbose=True)

# every language, one after the other (no shuffling across languages)
__train = functools.reduce(lambda __a, __b: __a.concatenate(__b), [TRAIN[__k] for __k in ['ar', 'en', 'es', 'de', 'hi', 'vi', 'zh']])

TRAINING_HISTORY = MODEL.fit(
    x=__train,
    # no batch_size: the datasets are already batched by tfds.load(batch_size=N_BATCH),
    # and Keras documents that batch_size must not be given for dataset inputs
    epochs=N_EPOCHS,
    validation_split=None,
    validation_data=TEST['ar'], # full of glyphs
    # validate about 8 times per run; the step is clamped to 1 because
    # N_EPOCHS < 8 would otherwise give range() a zero step (ValueError)
    validation_freq=list(range(1, N_EPOCHS + 1, max(1, N_EPOCHS // 8))),
    verbose=2,
    callbacks=[lr_callback, tb_callback])


Epoch 1: LearningRateScheduler setting learning rate to 0.0004.
Epoch 1/16
334/334 - 23s - loss: 0.0054 - accuracy: 0.9987 - val_loss: 0.0011 - val_accuracy: 0.9998 - lr: 4.0000e-04 - 23s/epoch - 70ms/step

Epoch 2: LearningRateScheduler setting learning rate to 0.00037000000000000005.
Epoch 2/16
334/334 - 23s - loss: 0.0037 - accuracy: 0.9992 - lr: 3.7000e-04 - 23s/epoch - 70ms/step

Epoch 3: LearningRateScheduler setting learning rate to 0.00034300000000000004.
Epoch 3/16
334/334 - 24s - loss: 0.0026 - accuracy: 0.9995 - val_loss: 8.0122e-04 - val_accuracy: 0.9999 - lr: 3.4300e-04 - 24s/epoch - 71ms/step

Epoch 4: LearningRateScheduler setting learning rate to 0.00031870000000000005.
Epoch 4/16
334/334 - 23s - loss: 0.0019 - accuracy: 0.9996 - lr: 3.1870e-04 - 23s/epoch - 70ms/step

Epoch 5: LearningRateScheduler setting learning rate to 0.00029683000000000004.
Epoch 5/16
334/334 - 24s - loss: 0.0014 - accuracy: 0.9998 - val_loss: 6.0164e-04 - val_accuracy: 0.9999 - lr: 2.9683e-04 -

In [None]:
# print the parameter counts of the encoder and decoder sub-models
MODEL.summary()

Model: "auto_encoder_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_2 (Encoder)         multiple                  592384    
                                                                 
 decoder_2 (Decoder)         multiple                  594176    
                                                                 
Total params: 1186560 (4.53 MB)
Trainable params: 1186560 (4.53 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


## Dataviz

In [None]:
# GENERIC #####################################################################

def _label(c: str) -> str:
    return '#{}'.format(c.encode('utf-32-be').hex())

def label(token: str) -> str:
    """Space-separated hex labels (see `_label`) of every character of `token`."""
    return ' '.join(map(_label, token))

def compare(left: str, right: str) -> float:
    """Fraction of positions where `left` and `right` hold the same item.

    Positions are paired up to the shorter length, but the ratio is always taken
    over len(left) (and 0.0 when `left` is empty).
    """
    __matches = 0
    for __l, __r in zip(left, right):
        if __l == __r:
            __matches += 1
    return __matches / max(1, len(left))

def chunk(seq: list, size: int, repeats: bool=True) -> list:
    """Split `seq` into consecutive slices of `size` items (the last one may be shorter).

    With `repeats=False`, duplicate slices are dropped and the first occurrence of
    each is kept in order of appearance, so the result is deterministic (a `set`
    would give an arbitrary order that changes between runs for strings).
    Duplicate removal requires hashable slices (e.g. `seq` as a str or tuple).
    `size` must be positive: `range` raises ValueError for 0.
    """
    __chunks = [seq[__i:__i + size] for __i in range(0, len(seq), size)]
    if repeats:
        return __chunks
    # dict keys are unique and keep insertion order
    return list(dict.fromkeys(__chunks))

In [None]:
# POSTPROCESS #################################################################

def interpret(output: tf.Tensor) -> tf.Tensor:
    """Collapse one-hot (or probability) vectors on the last axis into their int32 argmax index."""
    # int32 indices: uint8 is not allowed as argmax output type
    __indices = tf.argmax(input=output, axis=-1, output_type=tf.dtypes.int32)
    return __indices

def detokenize(tokens: tf.Tensor) -> str:
    """Decode a tensor of byte values (any shape, flattened in row-major order) as UTF-32-BE text.

    Invalid code units are replaced with U+FFFD instead of raising.
    """
    __flat = tf.reshape(tensor=tokens, shape=(-1,))
    __raw = bytes(__flat.numpy().tolist())
    return __raw.decode(encoding='utf-32-be', errors='replace')

def postprocess(output: tf.Tensor) -> str:
    """Convert one-hot encoded bytes (model inputs or predicted probabilities) back to text.

    The argmax along the last axis gives the byte values, which are flattened and
    decoded as UTF-32-BE; invalid sequences become U+FFFD and zero padding comes
    back as NUL characters. Returns a Python string.
    """
    # from one-hot to indices
    __output = interpret(output=output)
    # flatten and decode
    return detokenize(tokens=__output)

In [None]:
# SAVE ########################################################################

def write(data: object, path: str, tsv: bool=True) -> None:
    """Write one line per row of `data` to the text file at `path` (UTF-8, overwritten).

    Args:
        data: iterable of rows. With `tsv=True` each row is itself iterable and its
            values are joined with tabs (e.g. embedding vectors); otherwise each row
            is written as `str(row)` (e.g. metadata labels).
        path: destination file, truncated if it already exists.
        tsv: whether to tab-join the values of each row.
    """
    # explicit UTF-8: metadata rows may hold non-ASCII characters (Arabic, Hindi,
    # Chinese, ...) that the platform default encoding may not represent
    with open(path, 'w', encoding='utf-8') as __f:
        for __row in data:
            __line = '\t'.join(str(__v) for __v in __row) if tsv else str(__row)
            __f.write(__line + '\n')

In [None]:
# SAMPLES #####################################################################

SAMPLES = {}
TOKENS = {1: {}, 4: {}, 16: {}}
EMBEDDINGS = {1: {}, 4: {}, 16: {}}

for __l in TEST:
    # first batch of each evaluation dataset, input only (the target is the same tensor)
    __x = next(iter(TEST[__l]))[0]
    __o = MODEL(__x)
    # sample predictions (inputs, outputs)
    SAMPLES[__l] = (__x, __o)
    # decode the inputs once and reuse the text for both token sizes
    __text = postprocess(__x)
    # unique 1-tokens (characters), in order of first appearance
    TOKENS[1][__l] = list(dict.fromkeys(chunk(seq=__text, size=1, repeats=True)))
    # unique 4-tokens, in order of first appearance
    TOKENS[4][__l] = list(dict.fromkeys(chunk(seq=__text, size=4, repeats=True)))

# union over the languages, deduplicated in a deterministic order so that the saved
# metadata / embedding files keep the same row order from one run to the next
TOKENS[1]['all'] = list(dict.fromkeys(__t for __s in TOKENS[1].values() for __t in __s))
TOKENS[4]['all'] = list(dict.fromkeys(__t for __s in TOKENS[4].values() for __t in __s))

In [None]:
# EMBEDDINGS ##################################################################

# sublayers of the encoder's Sequential: `.layers` omits the Input entry, so index 0
# is the 'embed-1' Dense and index 1 the first TokenizeBlock ('tokenize-4')
__embed = MODEL._encoder._encoder.layers[0]
__tokenize = MODEL._encoder._encoder.layers[1]

for __l, __s in TOKENS[1].items():
    # re-encode the unique characters (no repeats), zero-padded to a multiple of N_TOKEN_DIM ** N_DEPTH bytes
    __token_x = tf.one_hot(indices=_tokenize_scalar(text=''.join(__s), layer_count=N_DEPTH, group_size=N_TOKEN_DIM, flatten=True), depth=256, axis=-1)
    # first level only: one vector per N_TOKEN_DIM bytes, i.e. per UTF-32 character when N_TOKEN_DIM = 4;
    # the trailing rows come from the padding and are dropped
    EMBEDDINGS[1][__l] = __tokenize(__embed(__token_x))[:len(__s)]

for __l, __s in TOKENS[4].items():
    # re-encode the unique 4-tokens (no repeats), zero-padded like above
    __token_x = tf.one_hot(indices=_tokenize_scalar(text=''.join(__s), layer_count=N_DEPTH, group_size=N_TOKEN_DIM, flatten=True), depth=256, axis=-1)
    # full encoder: one vector per N_TOKEN_DIM ** 2 bytes, i.e. per 4-character token when N_TOKEN_DIM = 4
    EMBEDDINGS[4][__l] = MODEL._encoder(__token_x)[:len(__s)]

In [None]:
# SAVE ########################################################################

# for each token size: one metadata file (token + hex labels per line) and one
# embedding file (one tab-separated vector per line), in the same row order
for __k, __meta_path, __vector_path in ((1, './metadata.1.tsv', './embeddings.1.tsv'), (4, './metadata.4.tsv', './embeddings.4.tsv')):
    write(data=[__c + ' ' + label(__c) for __c in TOKENS[__k]['all']], path=__meta_path, tsv=False)
    write(data=EMBEDDINGS[__k]['all'].numpy(), path=__vector_path, tsv=True)

In [None]:
# save the model to a single `.keras` archive in the working directory
MODEL.save('model.keras', save_format='keras')

In [None]:
# TEST ########################################################################

# sample text for a reconstruction check: a snippet of Python source code (unlike the MLQA training text)
__s = """class Encoder(tf.keras.models.Model):\n    def __init__(self, depth: int, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, attention: bool=False, **kwargs) -> None:\n        super(Encoder, self).__init__(**kwargs)\n        self._encoder = tf.keras.Sequential([\n            tf.keras.Input(shape=(encoding_dim,), batch_size=batch_dim, name='input'), # (B * G ^ D, U)\n            tf.keras.layers.Dense(units=embedding_dim, activation=None, use_bias=False, kernel_initializer='glorot_uniform', bias_initializer=None, name='embed-1'),] # (B * G ^ D, U) => (B * G ^ D, E)\n            + [_mmtl.TokenizeBlock(left_axis=-2, right_axis=-1, token_dim=token_dim, latent_dim=latent_dim, attention=attention, name='tokenize' + (__i + 1) * '-4') for __i in range(depth)]) # (B * G ^ i, E) => (B * G ^ (i-1), E)\n\n    def call(self, x: tf.Tensor) -> tf.Tensor:\n        return self._encoder(x)\n"""

# round-trip the sample text through the autoencoder, tokenized like the training data
__x = tf.one_hot(indices=_tokenize_scalar(text=__s, layer_count=N_DEPTH, group_size=N_TOKEN_DIM, flatten=True), depth=256, axis=-1)
# latent vectors, one per N_TOKEN_DIM ** 2 bytes
__e = MODEL._encoder(__x)
__p = MODEL(__x)
__y = postprocess(__p)

print(__s)
print(__y)
# fraction of characters reconstructed at the right position
print(compare(__s, __y))

class Encoder(tf.keras.models.Model):
    def __init__(self, depth: int, token_dim: int, encoding_dim: int, embedding_dim: int, latent_dim: int, batch_dim: int=None, attention: bool=False, **kwargs) -> None:
        super(Encoder, self).__init__(**kwargs)
        self._encoder = tf.keras.Sequential([
            tf.keras.Input(shape=(encoding_dim,), batch_size=batch_dim, name='input'), # (B * G ^ D, U)
            tf.keras.layers.Dense(units=embedding_dim, activation=None, use_bias=False, kernel_initializer='glorot_uniform', bias_initializer=None, name='embed-1'),] # (B * G ^ D, U) => (B * G ^ D, E)
            + [_mmtl.TokenizeBlock(left_axis=-2, right_axis=-1, token_dim=token_dim, latent_dim=latent_dim, attention=attention, name='tokenize' + (__i + 1) * '-4') for __i in range(depth)]) # (B * G ^ i, E) => (B * G ^ (i-1), E)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        return self._encoder(x)

class Encoder(tf.keras.models.Model):G    def __init__(self, depth: int, token_dim

In [None]:
# first German batch: decoded input vs decoded reconstruction
__l, __r = (postprocess(__t) for __t in SAMPLES['de'])

print(__l)
print(__r)
# fraction of characters reconstructed at the right position
print(compare(__l, __r))

1.0
