In [1]:
import re
import numpy as np
import tensorflow as tf
import pandas as pd

from Attention import *

In [2]:
find_tokens = re.compile(r"[\w']+|[.,!?:;\[\]\(\)]")

def load_tokens(tokens_file):
    """
        Load the pre-trained word embeddings from the embedding file
        """
    # Load the word tokens
    tokens = pd.read_csv(tokens_file, index_col=0)
    tokens = tokens.index.values
    tokens_dict = {}
    tokens_dict['to_token'] = {token: i for i, token in enumerate(tokens)}
    tokens_dict['to_word'] = {i: token for i, token in enumerate(tokens)}
    return tokens_dict

def sim_tokenize(text, tokens_dict):
    """
        Tokenize the text and return the token indices
        """
    # Tokenize the text
    tokens = find_tokens.findall(text.lower())

    # Get the token indices
    tokens_indices = [tokens_dict['to_token'][token] for token in tokens]
    tokens_indices.insert(0, tokens_dict['to_token']['<start>'])
    tokens_indices.append(tokens_dict['to_token']['<end>'])
    [tokens_indices.append(tokens_dict['to_token']['<pad>']) for i in range(10)]
    return tokens_indices

def detokenize(tokens_indices, tokens_dict):
    """
        Convert the token indices to text
        """
    # Detokenize the text
    tokens = [tokens_dict['to_word'][i] for i in tokens_indices]
    text = ' '.join(tokens).replace('<pad>', '')
    return text


In [3]:
tok = load_tokens('tokens.csv')

In [4]:
tokenized = sim_tokenize('I am a test sentence.', tok)
rebuilt = detokenize(tokenized, tok)
print(f"Tokenized: {tokenized}")
print(f"Rebuilt: {rebuilt}")

Tokenized: [3, 11, 166, 12, 870, 4325, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Rebuilt: <start> i am a test sentence . <end>          


In [5]:
# q:  (128, 8, 96, 16) k:  (128, 8, 96, 16) v:  (128, 8, 96, 16)
x = tf.random.uniform((128, 96, 128))

mha = MultiHeadAttention(128, 8)

output, attn = mha(x, x, x, None)
print(output.shape, attn.shape)

(128, 96, 128) (128, 8, 96, 96)


In [6]:
# q:  (128, 8, 96, 16) k:  (128, 8, 96, 16) v:  (128, 8, 96, 16)
x = tf.random.uniform((128, 96, 128))

mha = Fastformer_MultiHeadAttention(128, 8, 96)

output = mha(x, x, x, None)
print(output.shape, attn.shape)


(128, 96, 128) (128, 8, 96, 96)
