In [1]:
 
import re
import cupy as cp
import pickle
import time
import numpy as np 
import jax.numpy as jnp
import pandas as pd
import numpy as np
import jax
from tqdm import tqdm
from pathlib import Path

def log_time(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Record start time
        result = func(*args, **kwargs)  # Execute the wrapped function
        end_time = time.time()  # Record end time
        elapsed_time = end_time - start_time
        print(f"Function '{func.__name__}' executed in {elapsed_time:.4f} seconds")
        return result

    return wrapper


class Transformer:

    def __init__(self,embedding_size,
                 complete_text_origin,
                 complete_text_target,
                 max_lenght_phrase,
                 epochs,
                 batch_size,
                 dv,
                 num_heads,
                 flush_vocab=True):
        
        self.embedding_size=embedding_size
        self.max_lenght_phrase=max_lenght_phrase
        self.defaultkey=jax.random.key(55)
        self.epochs=epochs
        self.batch_size= batch_size
        self.flush_vocab=flush_vocab
        self.dv=dv
        self.dk=dv
        self.num_heads=num_heads
        if complete_text_target!="":
            complete_text=complete_text_origin+" "+complete_text_target+" [START] [PAD] [END] "
            self.vocabulary=self.create_vocabulary(complete_text,"vocabulary") 
        else:
            complete_text=complete_text_origin+" [START] [PAD] [END] "
            self.vocabulary=self.create_vocabulary(complete_text,"vocabulary")

        # Initialize weights with Xavier/Glorot initialization
        self.Q_Encoder = np.random.randn(self.embedding_size, self.dv) / np.sqrt(self.embedding_size)  # * 0.01
        self.K_Encoder = np.random.randn(self.embedding_size, self.dv) / np.sqrt(self.embedding_size)  # * 0.01
        self.V_Encoder = np.random.randn(self.embedding_size, self.dv) / np.sqrt(self.embedding_size)  # * 0.01
        self.linearlayerAttentionEncoder= np.random.rand(self.batch_size,dv, embedding_size)   
        self.linear_biasAttentionEncoder = np.random.rand(self.batch_size,1,embedding_size)

    def cross_entropy_loss(self, predictions, target):
        # Cross-entropy loss for a batch of predictions and targets
        batch_loss = -jnp.sum(target * jnp.log(predictions + 1e-9), axis=1)
        return jnp.mean(batch_loss)

    def softmax(self, x, axis=-1):
        x = jnp.clip(x, -1e4, 1e4)  # Clip for numerical stability
        e_x = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
        return e_x / jnp.sum(e_x, axis=axis, keepdims=True)
    
    #@log_time
    def create_vocabulary(self,complete_text,name):

        existing_vocab = Path(f"data/{name}.pkl")
        if existing_vocab.is_file() and self.flush_vocab==False:
            with open('data/vocabulary.pkl', 'rb') as f:
                vocabulary=pickle.load(f)

        else: 
            # Use re.findall to split considering punctuation
            text = re.findall(r'\[.*?\]|\w+|[^\w\s]', complete_text)
            words_list = list(set(text))
            vocabulary=dict()
            for i in words_list:
                vocabulary[i]=jax.random.uniform(jax.random.key(np.random.randint(10000)),self.embedding_size)
            
            print("Vocabulary size: ", len(vocabulary))
            with open(f"data/{name}.pkl", 'wb') as handle:
                pickle.dump(vocabulary, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
        return vocabulary
    
  
 
    #@log_time
    def generate_input(self, x_batch,y_batch):

        #print("batch prases original:\n",x_batch)
         
        xi=[]
        y_batch=[" ".join(y) for y in y_batch]
        #print(y_batch)
        phrase_vectors_x = [re.findall(r'\[.*?\]|\w+|[^\w\s]', x) for x in x_batch]
        phrase_vectors_y = [re.findall(r'\[.*?\]|\w+|[^\w\s]', y) for y in y_batch]
        #print("phrase_vectors:\n",phrase_vectors)

        xi=jnp.array([[self.vocabulary[word] for word in phrase_vector] for phrase_vector in phrase_vectors_x])
        yi=jnp.array([[self.vocabulary[word] for word in phrase_vector] for phrase_vector in phrase_vectors_y])
        
        yield xi,yi
 
    #@log_time
    def pad_sequences(self,sentences, pad_token='[PAD]'):
        """
        Pads the input sentences to have the same length by adding [PAD] tokens at the end.
        """
        
        # Split each sentence into words
        tokenized_sentences = [re.findall(r'\[.*?\]|\w+|[^\w\s]', sentence) for sentence in sentences]
        
        max_lenght=max(len(sentence) for sentence in tokenized_sentences)

        if self.max_lenght_phrase==0: 
            # Find the maximum sentence length
            self.max_lenght_phrase = max_lenght
         
        if self.max_lenght_phrase>max_lenght:
            #print("self.max_lenght_phrase>max_lenght")
            # Pad each sentence with the [PAD] token to make them of equal length
            padded_sentences = [" ".join(map(str, sentence + [pad_token] * (self.max_lenght_phrase - len(sentence)))) for sentence in tokenized_sentences]
        else: 
            padded_sentences=[" ".join(map(str, (sentence + [pad_token] * (self.max_lenght_phrase - len(sentence)))[0:self.max_lenght_phrase])) for sentence in tokenized_sentences]
         


        # print("-----------------")
        # self.print_matrix(padded_sentences)
        # print("-----------------\n\n")
        return padded_sentences
    
    
    # add <start> <end> tokens
    #@log_time
    def padding_start_end_tokens_target(self,yi):
        return [f'[START] {sentence} [END]' for sentence in yi]
    

    #@log_time     
    def preprocess_target(self,yi):
        y=[]
        yi=sorted(yi, key=lambda x: len(x.split()), reverse=True)
        for j in yi:
            phrase=j.split()
            for i in range(1,len(phrase)):
                y.append(phrase[0:i]+["[PAD]" for x in range(len(phrase)-i-1)]+["[END]"])
        return y
     
    def print_matrix(self,x):
        for i in x:
            print(i)

    def layer_norm(self,x, epsilon=1e-6):
    # Calculate the mean and variance
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True) 
        # Normalize the output
        x_norm = (x - mean) / jnp.sqrt(var + epsilon) 
        return x_norm

    #@log_time 
    def MultiHeadsAttentionEncoder(self, Inputs):
        #print("cupy.matmul(Inputs, self.Q)", cupy.matmul(Inputs, self.Q).shape)
        #print("cupy.array_split:",cupy.array(cupy.array_split(cupy.matmul(Inputs, self.Q), self.num_heads, axis=2)).shape)
        # print("Kval shape:", self.Kval.shape)
        # print("Vval shape:", self.Vval.shape)
        # print("Q  shape:", self.Q_Encoder.shape)
        # print("K  shape:", self.K_Encoder.shape)
        # print("V  shape:", self.V_Encoder.shape)
        self.Qval_Encoder = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Inputs, self.Q_Encoder), self.num_heads, axis=2)), 0, 1)
        self.Kval_Encoder = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Inputs, self.K_Encoder), self.num_heads, axis=2)), 0, 1)
        self.Vval_Encoder = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Inputs, self.V_Encoder), self.num_heads, axis=2)), 0, 1)
        # print("Qval shape:", self.Qval_Encoder.shape)
        # print("Kval shape:", self.Kval_Encoder.shape)
        # print("Vval shape:", self.Vval_Encoder.shape)

        QKscaled_Encoder = jnp.matmul(self.Qval_Encoder, jnp.transpose(self.Kval_Encoder, (0, 1, 3, 2))) / jnp.sqrt(self.dk)
        # print("QKscaled shape:", QKscaled.shape)

        self.Attention_weights_Encoder = self.softmax(QKscaled_Encoder)
        # print("Attention_weights shape:", self.Attention_weights.shape)

        Attention_output_Encoder = jnp.matmul(self.Attention_weights_Encoder, self.Vval_Encoder)
        # print("Attention output shape:", Attention_output.shape)
        Attention_output_Encoder=jnp.array([jnp.concatenate(Attention_output_Encoder[i], axis=1) for i in range(self.batch_size)])

        return Attention_output_Encoder
    
    def linear_layer_attention_encoder(self,Encoder_attention_output): 
        return jnp.matmul(Encoder_attention_output,self.linearlayerAttentionEncoder)+self.linear_biasAttentionEncoder

    def inputs_add_and_norm_encoder(self,Encoder_attention_output, Inputs): 
        input_dimension_remappedEncoder=self.linear_layer_attention_encoder(Encoder_attention_output)+Inputs
        return self.layer_norm(input_dimension_remappedEncoder)
    
    def forward_step_encoder(self,Inputs):
        Encoder_attention_output=self.MultiHeadsAttentionEncoder(Inputs)
 
        output_sublayer_one=self.inputs_add_and_norm_encoder(Encoder_attention_output, Inputs)

        
        
        print("output_sublayer_one shape",output_sublayer_one.shape)
        
        pass

    def train(self,X_train,y_train):
         y_train=self.padding_start_end_tokens_target(self.pad_sequences(y_train))
         
         X_train=self.pad_sequences(X_train)
         
         for epoch in range(self.epochs):
            self.iterations=0
            total_loss = 0 
            num_batches_per_epoch = len(X_train) // self.batch_size
         
            for i in tqdm(range(num_batches_per_epoch), desc=f"Epoch {epoch + 1}/{self.epochs}"):

                start = i * self.batch_size
                end = start + self.batch_size
                X_batch_phrases = X_train[start:end]
                y_batch_phrases= y_train[start:end]
                y_batch_phrases=self.preprocess_target(y_batch_phrases)
                #self.print_matrix(y_batch_phrases)

                for xi,yi in self.generate_input(X_batch_phrases,y_batch_phrases): 
                    #print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
                    #self.print_matrix(xi)
                    print("input shape ",xi.shape,type(xi)) 
                    #print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
                    #self.print_matrix(yi)
                    #print("target shape",yi.shape,type(yi)) 
                    self.forward_step_encoder(xi)
                     


X_train=["i love soy sauce!", 
         "my dog... is cute", 
         "you are crazy strong!",
         "the friend is good, you know"]
y_train=["amo la salsa di soia!",
        "il cane... è tenero",
        "sei pazzo potente!",
        "l'amico è buono, vero?"]        
 
df = pd.read_csv("data/bbc-text.csv")
# complete_text = ' '.join(df['text'].str.split()) 
complete_text_origin = ' '.join(df["text"].tolist())
complete_text_target = ' '.join(y_train)
embedding_size = 15 #this is the initial size for the word embedding
max_words_per_phrase=10# consider +2 then adding start and end token this fix max lengh
batch_size=2 # cosidering 2 phrase per time

model = Transformer(embedding_size=embedding_size,
                    complete_text_origin=complete_text_origin,
                    complete_text_target=complete_text_target,
                    max_lenght_phrase=max_words_per_phrase,
                    epochs=1,
                    batch_size=2,
                    dv=8,
                    num_heads=4,
                    flush_vocab=True)
 

model.train(X_train,y_train) 

Vocabulary size:  29582


Epoch 1/1:   0%|          | 0/2 [00:00<?, ?it/s]

input shape  (2, 10, 15) <class 'jaxlib.xla_extension.ArrayImpl'>


Epoch 1/1: 100%|██████████| 2/2 [00:00<00:00,  4.36it/s]

output_sublayer_one shape (2, 10, 15)
input shape  (2, 10, 15) <class 'jaxlib.xla_extension.ArrayImpl'>
output_sublayer_one shape (2, 10, 15)





In [2]:
y_train=["amo la salsa di soia!",
        "il cane... è tenero",
        "sei pazzo potente!",
        "l'amico è buono, gia lo sai o no puo essere"]    

In [3]:
sorted(y_train, key=lambda x: len(x.split()), reverse=True)

["l'amico è buono, gia lo sai o no puo essere",
 'amo la salsa di soia!',
 'il cane... è tenero',
 'sei pazzo potente!']

In [4]:
sorted(y_train, key=lambda x: len(x.split()), reverse=True)

["l'amico è buono, gia lo sai o no puo essere",
 'amo la salsa di soia!',
 'il cane... è tenero',
 'sei pazzo potente!']

In [5]:
phrase='[START] amo la salsa di soia ! [PAD] [PAD] [PAD] [PAD] [END]'
phrase=phrase.split()
for i in range(1,len(phrase)):
    print(phrase[0:i]+["[PAD]" for x in range(len(phrase)-i-1)]+["[END]"])

['[START]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', 'salsa', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', 'salsa', 'di', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', 'salsa', 'di', 'soia', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', 'salsa', 'di', 'soia', '!', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', 'salsa', 'di', 'soia', '!', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', 'salsa', 'di', 'soia', '!', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[END]']
['[START]', 'amo', 'la', 'salsa', 'di', 'soia', '!', '[PAD]', '[PAD]'

In [6]:
batch_phrases = ["This is a test [PAD]!", "[START] Hello, how are you? [CLOSE]"]

# Modified regex to capture words like [PAD], [START], and [CLOSE]
phrase_vectors = [re.findall(r'\[.*?\]|\w+|[^\w\s]', x) for x in batch_phrases]
phrase_vectors

[['This', 'is', 'a', 'test', '[PAD]', '!'],
 ['[START]', 'Hello', ',', 'how', 'are', 'you', '?', '[CLOSE]']]

# Prepare input lookup table

Step-by-Step Process:
Create a Vocabulary:
Map each word to an index (token) and initialize a random vector for each word.
Initialize Embedding Vectors:
For each word in the vocabulary, initialize a random embedding vector (say of dimension 3 or 512).
For Each Input Sequence:
Convert the words to their corresponding vectors using the vocabulary.
Stack the vectors to form an input matrix.

In [7]:
import jax.random   
seed=55

@log_time
def create_vocabulary(complete_text):

    existing_vocab = Path("data/vocabulary.pkl")
    if existing_vocab.is_file() and flush_vocab==False:
        with open('data/vocabulary.pkl', 'rb') as f:
            vocabulary=pickle.load(f)

    else: 
       # Use re.findall to split considering punctuation
        text = re.findall(r'\w+|[^\w\s]', complete_text)
        words_list = list(set(text))
        vocabulary=dict()
        for i in words_list:
            vocabulary[i]=jax.random.uniform(jax.random.key(np.random.randint(10000)),self.embedding_size)
        
        print("Vocabulary size: ", len(vocabulary))
        with open('data/vocabulary.pkl', 'wb') as handle:
            pickle.dump(vocabulary, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return vocabulary

@log_time
def generate_input(batch_phrases):

    phrase_vectors=[]   
    xi=[]
    
    phrase_vectors=[x.split() for x in batch_phrases]
    
    xi=jnp.array([[vocabulary[word] for word in phrase_vector] for phrase_vector in phrase_vectors])
    
    return xi


embedding_size=512
key=jax.random.key(seed)
x = jax.random.uniform(key,(5,5))
y = jnp.arange(5)
x
 

Array([[0.3952378 , 0.45534575, 0.91380334, 0.29122305, 0.82850766],
       [0.58672273, 0.6880075 , 0.23995149, 0.81122804, 0.08536363],
       [0.34470356, 0.04894471, 0.00256085, 0.6435065 , 0.50082767],
       [0.22316742, 0.05539286, 0.23274505, 0.45073962, 0.51079834],
       [0.01448357, 0.23985529, 0.0051235 , 0.70521474, 0.73882663]],      dtype=float32)

# Prepare input for decoder

In [8]:
def pad_sequences(sentences,lenght=0, pad_token='[PAD]'):
        """
        Pads the input sentences to have the same length by adding [PAD] tokens at the end.
        """
        # Split each sentence into words
        tokenized_sentences = [sentence.split() for sentence in sentences]
        
        if lenght==0: 
            # Find the maximum sentence length
            max_len = max(len(sentence) for sentence in tokenized_sentences)
        else:
            max_len=lenght
        
        # Pad each sentence with the [PAD] token to make them of equal length
        padded_sentences = [" ".join(sentence + [pad_token] * (max_len - len(sentence))) for sentence in tokenized_sentences]
        
        return padded_sentences
    
    
    # add <start> <end> tokens
def prepropcess_target(yi):
    return [f'[START] {sentence} [END]' for sentence in yi]

In [9]:
text = "i love soy sauce!"

# Use re.findall to split considering punctuation
words = re.findall(r'\w+|[^\w\s]', text)

print(words)

['i', 'love', 'soy', 'sauce', '!']


In [10]:
prepropcess_target(pad_sequences(yi,8))

NameError: name 'yi' is not defined

In [14]:
pad_sequences(X_train,8)

['i love soy sauce [PAD] [PAD] [PAD] [PAD]',
 'my dog is cute [PAD] [PAD] [PAD] [PAD]',
 'you are crazy strong [PAD] [PAD] [PAD] [PAD]',
 'the friend is good [PAD] [PAD] [PAD] [PAD]']