In [19]:
 
import jax.numpy as jnp 
import re
import cupy as cp
import pickle
import time
import numpy as np 
import jax.numpy as jnp
import pandas as pd
import numpy as np
import jax
from tqdm import tqdm
from pathlib import Path

def log_time(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()  # Record start time
        result = func(*args, **kwargs)  # Execute the wrapped function
        end_time = time.time()  # Record end time
        elapsed_time = end_time - start_time
        print(f"Function '{func.__name__}' executed in {elapsed_time:.4f} seconds")
        return result

    return wrapper


class Transformer:

    def __init__(self,embedding_size,complete_text,epochs,batch_size,flush_vocab=True):
        self.embedding_size=embedding_size
        self.defaultkey=jax.random.key(55)
        self.epochs=epochs
        self.batch_size= batch_size
        self.flush_vocab=flush_vocab
        self.vocabulary=self.create_vocabulary(complete_text)
        
        pass

    def cross_entropy_loss(self, predictions, target):
        # Cross-entropy loss for a batch of predictions and targets
        batch_loss = -jnp.sum(target * jnp.log(predictions + 1e-9), axis=1)
        return jnp.mean(batch_loss)

    def softmax(self, x, axis=-1):
        x = jnp.clip(x, -1e4, 1e4)  # Clip for numerical stability
        e_x = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
        return e_x / jnp.sum(e_x, axis=axis, keepdims=True)
    
    @log_time
    def create_vocabulary(self,complete_text):

        existing_vocab = Path("data/vocabulary.pkl")
        if existing_vocab.is_file() and self.flush_vocab==False:
            with open('data/vocabulary.pkl', 'rb') as f:
                vocabulary=pickle.load(f)

        else: 
            text = re.sub(r'[^\w\s]', ' ', complete_text).split()
            words_list = list(set(text))
            vocabulary=dict()
            for i in words_list:
                vocabulary[i]=jax.random.uniform(jax.random.key(np.random.randint(10000)),self.embedding_size)
            
            print("Vocabulary size: ", len(vocabulary))
            with open('data/vocabulary.pkl', 'wb') as handle:
                pickle.dump(vocabulary, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
        return vocabulary
    
  
 
    @log_time
    def generate_input(self, batch_phrases):

        print(batch_phrases)
        phrase_vectors=[]   
        xi=[]
        
        phrase_vectors=[x.split() for x in batch_phrases]
        
        xi=jnp.array([[self.vocabulary[word] for word in phrase_vector] for phrase_vector in phrase_vectors])
        
        yield xi
 

    
    
    def train(self,X_train,yi):

         for epoch in range(self.epochs):
            self.iterations=0
            total_loss = 0 
            num_batches_per_epoch = len(X_train) // self.batch_size

            for i in tqdm(range(num_batches_per_epoch), desc=f"Epoch {epoch + 1}/{self.epochs}"):
                
                start = i * self.batch_size
                end = start + self.batch_size
                X_batch_phrases = X_train[start:end]

                for xi in self.generate_input(X_batch_phrases): 
                    print(xi) 

                
                # start = i * self.batch_size
                # end = start + self.batch_size
                
                # X_batch = X_train[start:end]
                # y_batch = y_train[start:end] 
            # if self.validation_split > 0:
            #     self.training_validation(X_train_validation, y_train_validation, Loss)
            # print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {(total_loss / num_batches_per_epoch):.4f}, learning rate: {self.learning_rate}")


        
 
df = pd.read_csv("data/bbc-text.csv")
# complete_text = ' '.join(df['text'].str.split()) 
complete_text = ' '.join(df["text"].tolist())
embedding_size = 512 

model = Transformer(embedding_size,complete_text,epochs=1,batch_size=2,flush_vocab=True)
 
X_train=["i love soy sauce", 
         "my dog is cute", 
         "you are crazy strong",
         "the friend is good"]
yi=[]
model.train(X_train,yi) 

Vocabulary size:  29457
Function 'create_vocabulary' executed in 3.2156 seconds


Epoch 1/1: 100%|██████████| 2/2 [00:00<?, ?it/s]

Function 'generate_input' executed in 0.0000 seconds
['i love soy sauce', 'my dog is cute']
[[[0.46195328 0.13579082 0.6548934  ... 0.16695213 0.07322681 0.21672976]
  [0.6734519  0.4317391  0.8660369  ... 0.21976137 0.05613589 0.0024066 ]
  [0.6274011  0.67449486 0.631884   ... 0.4727906  0.3639574  0.31044948]
  [0.42594814 0.50421154 0.39506936 ... 0.5835451  0.7005005  0.9929346 ]]

 [[0.57852244 0.3697554  0.23043132 ... 0.9848201  0.38312972 0.8222574 ]
  [0.30123627 0.5854081  0.2083242  ... 0.79928184 0.8889991  0.37048984]
  [0.5242338  0.14272118 0.10307431 ... 0.4563253  0.7924125  0.87941873]
  [0.96102893 0.40895152 0.5783787  ... 0.05882668 0.53607345 0.41859365]]]
Function 'generate_input' executed in 0.0000 seconds
['you are crazy strong', 'the friend is good']
[[[0.68645215 0.6083869  0.36374485 ... 0.9921442  0.42366004 0.5363716 ]
  [0.8448074  0.74833643 0.297593   ... 0.7307147  0.47371364 0.82315874]
  [0.8186954  0.10007882 0.42771804 ... 0.6315285  0.27904475 0.




# Prepare input lookup table

Step-by-Step Process:
Create a Vocabulary:
Map each word to an index (token) and initialize a random vector for each word.
Initialize Embedding Vectors:
For each word in the vocabulary, initialize a random embedding vector (say of dimension 3 or 512).
For Each Input Sequence:
Convert the words to their corresponding vectors using the vocabulary.
Stack the vectors to form an input matrix.

In [11]:
import jax.random   
seed=55

@log_time
def create_vocabulary(complete_text):

    existing_vocab = Path("data/vocabulary.pkl")
    if existing_vocab.is_file() and self.flush_vocab==False:
        with open('data/vocabulary.pkl', 'rb') as f:
            vocabulary=pickle.load(f)

    else: 
        text = re.sub(r'[^\w\s]', ' ', complete_text).split()
        words_list = list(set(text))
        vocabulary=dict()
        for i in words_list:
            vocabulary[i]=jax.random.uniform(jax.random.key(np.random.randint(10000)),self.embedding_size)
        
        print("Vocabulary size: ", len(vocabulary))
        with open('data/vocabulary.pkl', 'wb') as handle:
            pickle.dump(vocabulary, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return vocabulary

@log_time
def generate_input(batch_phrases):

    phrase_vectors=[]   
    xi=[]
    
    phrase_vectors=[x.split() for x in batch_phrases]
    
    xi=jnp.array([[self.vocabulary[word] for word in phrase_vector] for phrase_vector in phrase_vectors])
    
    return xi


embedding_size=512
key=jax.random.key(seed)
x = jax.random.uniform(key,(5,5))
y = jnp.arange(5)
x
 

Array([[0.3952378 , 0.45534575, 0.91380334, 0.29122305, 0.82850766],
       [0.58672273, 0.6880075 , 0.23995149, 0.81122804, 0.08536363],
       [0.34470356, 0.04894471, 0.00256085, 0.6435065 , 0.50082767],
       [0.22316742, 0.05539286, 0.23274505, 0.45073962, 0.51079834],
       [0.01448357, 0.23985529, 0.0051235 , 0.70521474, 0.73882663]],      dtype=float32)

In [12]:
x+=y

In [13]:
#x.at[1:3].set([50,40,6])