# Positional encoding

In [1]:
import numpy as np 
import jax.numpy as jnp
np.set_printoptions(edgeitems=30, linewidth=100000, formatter=dict(float=lambda x: "%.3g" % x)) 
def get_positional_encoding(seq_len, d_model):
    """
    Returns a non-learnable (sinusoidal) positional encoding.
    
    seq_len: Length of the input sequence.
    d_model: Dimension of the embeddings.
    """
    pos = np.arange(seq_len)[:, np.newaxis]  # Shape: [seq_len, 1]
    i = np.arange(d_model)[np.newaxis, :]    # Shape: [1, d_model]

    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))

    # Apply sine to even indices, cosine to odd indices
    pos_encoding = np.zeros((seq_len, d_model))
    pos_encoding[:, 0::2] = np.sin(pos * angle_rates[:, 0::2])  # sine on even indices
    pos_encoding[:, 1::2] = np.cos(pos * angle_rates[:, 1::2])  # cosine on odd indices

    return pos_encoding

def softmax(x, axis=-1):
    # Subtract the max value for numerical stability
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)
def layer_norm(x, epsilon=1e-6):
    # Calculate the mean and variance
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True) 
        
        # Normalize the output
        x_norm = (x - mean) / jnp.sqrt(var + epsilon) 
        #print(x)
        #print(mean)
        #print("mean",mean.shape)
        #print("x_norm.shape",x_norm.shape)
        return x_norm,mean,var,x.shape[-1]
def relu(x):
    return np.maximum(0, x)

def pad_sequence(seq, max_len, pad_value=0):
    """Pad a sequence with a given value up to max_len."""
    current_len = seq.shape[0]
    pad_width = max_len - current_len
    if pad_width > 0:
        # Pad sequence with zeros (or any pad_value you provide)
        seq = jnp.pad(seq, ((0, pad_width), (0, 0)), mode='constant', constant_values=pad_value)
    return seq

def create_timestaped_input(input_d,words_per_phrase):
    input_translation=[]
    for j in range(input_d.shape[0]):
    # Create padded sequences
        padded_sequences = [pad_sequence(input_d[j][0:i], words_per_phrase) for i in range(1, input_d.shape[1] + 1)]
        input_translation.append(padded_sequences)
    return jnp.array(input_translation)

def cross_entropy_loss(predictions, target):
    # Cross-entropy loss for a batch of predictions and targets
    batch_loss = -jnp.sum(target * jnp.log(predictions + 1e-9), axis=1)
    return jnp.mean(batch_loss)

def diff_norm(X,var,mu,N):
    epsilon=1e-6
    AA=((1-(1/N))*(1/(jnp.sqrt(var+epsilon))))
    BB=(1/N)*((X-mu)**2)
    CC=((var+epsilon)**(3/2))
    result=(AA-(BB/CC)) 
    return result

def redimension(X):
    return jnp.concatenate(jnp.swapaxes(X,0,1),axis=-1) 

def diffQKV(dAttention,Attention_weights,X1,X2,X3,dk):
    dAttention_weights=Attention_weights*(1-Attention_weights)
    V1=redimension(dAttention_weights@X1/jnp.sqrt(dk)) 
    
    V2=redimension(X2)
    
    V3=V1*V2*X3
    dLoss_dX=jnp.sum(jnp.transpose(dAttention,(0,2,1))@V3,axis=0)
    return dLoss_dX

## Input preparation

In [2]:
import re
import cupy as cp
import pickle
import time
import numpy as np 
import jax.numpy as jnp
import pandas as pd
import numpy as np
import jax
from tqdm import tqdm
from pathlib import Path

def create_vocabulary(complete_text,name):

        existing_vocab = Path(f"data/{name}.pkl")
        
        # Use re.findall to split considering punctuation
        text = re.findall(r'\[.*?\]|\w+|[^\w\s]', complete_text)
        words_list = list(set(text))
        vocabulary=dict()
        for i,j in enumerate(words_list):
            vocabulary[j]=(jax.random.uniform(jax.random.key(np.random.randint(10000)),embedding_size),i)
        
        print("Vocabulary size: ", len(vocabulary))
        with open(f"data/{name}.pkl", 'wb') as handle:
            pickle.dump(vocabulary, handle, protocol=pickle.HIGHEST_PROTOCOL)
    
        return vocabulary




def pad_sequences(sentences,lenght=0, pad_token='[PAD]'):
        """
        Pads the input sentences to have the same length by adding [PAD] tokens at the end.
        """
        
        # Split each sentence into words
        tokenized_sentences = [["[START]"]+sentence.split()+["[END]"] for sentence in sentences]
        
        if lenght==0: 
            # Find the maximum sentence length
            max_len = max(len(sentence) for sentence in tokenized_sentences)
        else:
            max_len=lenght
        
        # Pad each sentence with the [PAD] token to make them of equal length
        padded_sentences = [" ".join(sentence + [pad_token] * (max_len - len(sentence))) for sentence in tokenized_sentences]
        
        return padded_sentences
    
def generate_input(x_batch,y_batch,vocabulary,max_len):

        #print("batch prases original:\n",x_batch)
         
        xi=[]
        yi=[]
        #y_batch=[" ".join(y) for y in y_batch]
        #print(y_batch)
        phrase_vectors_x = [re.findall(r'\[.*?\]|\w+|[^\w\s]', x) for x in x_batch]
        phrase_vectors_y = [re.findall(r'\[.*?\]|\w+|[^\w\s]', y) for y in y_batch] 

    
     
        phrase_vectors_x= [i[0:max_len] for i in phrase_vectors_x] 
        print("phrase_vectors_x:\n",phrase_vectors_x)
        phrase_vectors_y= [i[0:max_len] for i in phrase_vectors_y]
        print("phrase_vectors_y:\n",phrase_vectors_y)

        phrase_vectors_target= [i[1:max_len] for i in phrase_vectors_y]
        print("phrase_vectors_target:\n",phrase_vectors_target)
 
        xi=jnp.array([[vocabulary[word][0] for word in phrase_vector] for phrase_vector in phrase_vectors_x])
        yi=jnp.array([[vocabulary[word][0]  for word in phrase_vector] for phrase_vector in phrase_vectors_y])
        ti=jnp.array([[vocabulary[word][1]  for word in phrase_vector] for phrase_vector in phrase_vectors_target])
        return xi,yi,ti
    


 

X_train=["i love soy sauce!", 
         "my dog... is cute", 
         "you are crazy strong!",
         "the friend is good, you know"]
y_train=["amo la salsa di soia!",
        "il cane... è tenero",
        "sei pazzo potente!",
        "l'amico è buono, vero?"]    

num_phrases = 4
words_per_phrase = 10
dk = dv = embedding_size = 4 # constrain of transformer all embedding size both input embedding and attention embedding are the same encoder
num_heads=2

complete_text_origin = ' '.join(X_train)
complete_text_target = ' '.join(y_train)
complete_text=complete_text_origin+" "+complete_text_target+" [START] [PAD] [END] "
vocabulary=create_vocabulary(complete_text,"vocabulary_test")   
vocab_size=len(vocabulary) 


pos_encoding=get_positional_encoding(words_per_phrase,embedding_size)
 
#inputs_e = np.random.rand(num_phrases,words_per_phrase, embedding_size)
x_train=pad_sequences(X_train,words_per_phrase)
y_train=pad_sequences(y_train,words_per_phrase) 
inputs_e,y,tg=generate_input(x_train,y_train,vocabulary,words_per_phrase)
inputs_e=pos_encoding+inputs_e
print("inputs.shape: ",inputs_e.shape)
 
 
#input_d = np.random.rand(num_phrases,words_per_phrase, embedding_size)
inputs_d=y+pos_encoding

print(inputs_d.shape)
# Convert to an array for batching
inputs_d = jnp.swapaxes(create_timestaped_input(y,words_per_phrase),0,1)
 
print("inputs_d complete shape",inputs_d.shape)# shape is: words_per_phrase,num_phrases,words_per_phrase,embedding_size at each
#target_d=np.random.rand(words_per_phrase,num_phrases,vocab_size)
def get_one_hot(index,vocab_size): 
    one_hot_vector = np.zeros(vocab_size)
    one_hot_vector[index] = 1 
    return one_hot_vector
vocab_size=len(vocabulary)
target_d=[[get_one_hot(index, vocab_size) for index in phrase] for phrase in tg]
target_d=jnp.swapaxes(create_timestaped_input(jnp.array(target_d),words_per_phrase),0,1) 
targets_d=[]

for i in range(target_d.shape[0]):
    ff=[]
    #print(i,target_d[i].shape)
    for j in range(target_d[i].shape[0]):
        ff.append(target_d[i][j][i])
        #print(targets_d[i][j][i])
    targets_d.append(ff)
        #print(ff)

targets_d=jnp.array(targets_d)
inputs_d.shape


Vocabulary size:  40
phrase_vectors_x:
 [['[START]', 'i', 'love', 'soy', 'sauce', '!', '[END]', '[PAD]', '[PAD]', '[PAD]'], ['[START]', 'my', 'dog', '.', '.', '.', 'is', 'cute', '[END]', '[PAD]'], ['[START]', 'you', 'are', 'crazy', 'strong', '!', '[END]', '[PAD]', '[PAD]', '[PAD]'], ['[START]', 'the', 'friend', 'is', 'good', ',', 'you', 'know', '[END]', '[PAD]']]
phrase_vectors_y:
 [['[START]', 'amo', 'la', 'salsa', 'di', 'soia', '!', '[END]', '[PAD]', '[PAD]'], ['[START]', 'il', 'cane', '.', '.', '.', 'è', 'tenero', '[END]', '[PAD]'], ['[START]', 'sei', 'pazzo', 'potente', '!', '[END]', '[PAD]', '[PAD]', '[PAD]', '[PAD]'], ['[START]', 'l', "'", 'amico', 'è', 'buono', ',', 'vero', '?', '[END]']]
phrase_vectors_target:
 [['amo', 'la', 'salsa', 'di', 'soia', '!', '[END]', '[PAD]', '[PAD]'], ['il', 'cane', '.', '.', '.', 'è', 'tenero', '[END]', '[PAD]'], ['sei', 'pazzo', 'potente', '!', '[END]', '[PAD]', '[PAD]', '[PAD]', '[PAD]'], ['l', "'", 'amico', 'è', 'buono', ',', 'vero', '?', '[END

(10, 4, 10, 4)

In [3]:
 
inputs_d[1]

Array([[[0.245, 0.303, 0.524, 0.329],
        [0.476, 0.401, 0.445, 0.732],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[0.245, 0.303, 0.524, 0.329],
        [0.953, 0.911, 0.0563, 0.625],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[0.245, 0.303, 0.524, 0.329],
        [0.194, 0.0428, 0.0714, 0.728],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]],

       [[0.245, 0.303, 0.524, 0.329],
        [0.612, 0.151, 0.198, 0.165],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
     

## Initialiazer

In [4]:
learning_rate=0.01



Qe = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Ke = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Ve = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Qc = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Kc = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Vc = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Qd = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Kd = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)
Vd = np.random.rand(embedding_size, embedding_size) / jnp.sqrt(embedding_size)

fl1_size=100
Wfl1e=np.random.rand(embedding_size, fl1_size)   
bfl1e=np.random.rand(fl1_size)

Wfl2e=np.random.rand(fl1_size, dv)   
bfl2e=np.random.rand(dv)


Wfl1d=np.random.rand(embedding_size, fl1_size)    
bfl1d=np.random.rand(fl1_size)

Wfl2d=np.random.rand(fl1_size, dv)    
bfl2d=np.random.rand(dv)

W0=np.random.rand(words_per_phrase*embedding_size,vocab_size)   
b0=np.random.rand(vocab_size)

# Summary Encoder

In [5]:
 
def forward_attention_encoder(inputs_e):
    global Qe,Ke,Ve

    Q_E= jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_e, Qe),num_heads,axis=2)), 0, 1)
    #print("Qval.shape: ",Q_E.shape)

    K_E = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_e, Ke),num_heads,axis=2)), 0, 1)
    #print("Kval.shape: ",K_E.shape)


    V_E = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_e,Ve),num_heads,axis=2)), 0, 1)
    #print("Vval.shape: ",V_E.shape)

    QKscaled = jnp.matmul(Q_E, jnp.transpose(K_E, (0, 1, 3, 2))) / jnp.sqrt(dk)

    Attention_weights_e = softmax(QKscaled)
    #print("Attention_weights shape:",Attention_weights_e.shape)


    Ae = jnp.matmul(Attention_weights_e, V_E)
    #print("Attention shape:",Ae.shape)


    Ae=jnp.array([jnp.concatenate(Ae[i], axis=1) for i in range(num_phrases)])
    #print("Attention shape concat:",Ae.shape)

    Xe=Ae+inputs_e
    Ect1,mu_e,var_e,Ne=layer_norm(Xe)
    #print("Ect1.shape",Ect1.shape,Ne)

    return Ae,Xe,Ect1,mu_e,var_e,Ne,Attention_weights_e,K_E,V_E,Q_E


def fully_connected_layers_encoder(Ect1):
    
    global Wfl1e,bfl1e,Wfl2e,bfl2e
    
    Xe1=jnp.matmul(Ect1,Wfl1e)+bfl1e
    FLe1=relu(Xe1)

    FLe2=jnp.matmul(FLe1,Wfl2e)+bfl2e

    Xe2=FLe2+Ect1
    Ecout,mu_e2,var_e2,N_e2=layer_norm(Xe2)
    #print("Ecout.shape",Ecout.shape)
    return Ecout,mu_e2,var_e2,N_e2,FLe1,Xe1,Xe2





In [6]:
def cross_attention_encoder(Ecout):
    global Kc,Vc 
    K_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Ecout, Kc),num_heads,axis=2)), 0, 1)
    #print("K_C.shape: ",K_C.shape)# shape is: num_phrases, numbheads, words_per_phrase, dv/num_heads 
    V_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Ecout,Vc),num_heads,axis=2)), 0, 1)
    #print("V_C.shape: ",V_C.shape)
    return K_C,V_C




## Decoder forward

In [7]:
def forward_attention_decoder(input_decoder):
    global Qd,Kd,Vd
    Q_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(input_decoder, Qd),num_heads,axis=2)), 0, 1)
    #print("Qval.shape: ",Q_D.shape)# numwords, num_phrases, numheads, num_words, dv/num_heads

    #K_D  = jnp.swapaxes(jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Kd),num_heads,axis=3)), 0, 1),1,2)
    K_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(input_decoder, Kd),num_heads,axis=2)), 0, 1)
    #print("Kval.shape: ",K_D.shape)


    #V_D  = jnp.swapaxes(jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Vd),num_heads,axis=3)), 0, 1),1,2)
    V_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(input_decoder, Vd),num_heads,axis=2)), 0, 1)

    QKscaled_decoder  = jnp.matmul(Q_D, jnp.transpose(K_D, (0, 1, 3, 2))) / jnp.sqrt(dv)
    # Step 1: Create a causal mask of shape (1, 1, 9, 9) to broadcast across heads and batch
    mask = jnp.tril(jnp.ones((words_per_phrase, words_per_phrase)))  # (9, 9) lower triangular matrix
    mask = mask.at[mask == 0].set(-jnp.inf)  # Set future tokens to -inf
    mask = mask.at[mask == 1].set(0)  # Set allowed tokens to 0
    mask = mask.reshape(1, 1, words_per_phrase, words_per_phrase)   

    # Step 2: Apply mask to QKscaled_decoder (it will broadcast across batch and heads)
    QKscaled_decoder = QKscaled_decoder + mask 

    Attention_weights_masked = softmax(QKscaled_decoder)


    A_mask = jnp.matmul(Attention_weights_masked, V_D)
    #print("A_mask.shape non concat: ",A_mask.shape)
    
    #A_mask=jnp.swapaxes(jnp.concatenate(jnp.swapaxes(A_mask,0,2),axis=-1),0,1)
    A_mask= jnp.concatenate(jnp.swapaxes(A_mask,0,1),axis=-1) 

    Xd = input_decoder + A_mask
    Dt1,mu_d,var_d,N_d = layer_norm(Xd)
    #print("A_mask.shape concat: ",A_mask.shape)
    #print("inputs_d.shape: ",input_decoder.shape)
    #print("Dt1.shape: ",Dt1.shape)
    return A_mask,Xd,Dt1,mu_d,var_d,N_d,Attention_weights_masked,Q_D,K_D,V_D
    
def cross_attention_decoder(Dt1):
    global Qc 
    Q_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Dt1, Qc),num_heads,axis=2)), 0, 1)
    #print("Q_C.shape: ",Q_C.shape)
    return Q_C

def cross_attention(Q_C,K_C,V_C,Dt1):
    QKscaled_cross_attention  = jnp.matmul(Q_C, jnp.transpose(K_C , (0, 1, 3, 2)))/ jnp.sqrt(dv)
    Attention_weights_cross = softmax(QKscaled_cross_attention)
    Acr = jnp.matmul(Attention_weights_cross, V_C)
    #print("Acr.shape non concat",Acr.shape)
    Acr=jnp.concatenate(jnp.swapaxes(Acr,0,1),axis=-1)
    #print("Acr.shape concat",Acr.shape)
    Res=Acr + Dt1
    Dt2, mu_res,var_res,N_res = layer_norm(Res)
    return Dt2, mu_res,var_res,N_res,Res,Attention_weights_cross

def fully_connected_layers_decoder(Dt2):
    
    global Wfl1d,bfl1d,Wfl2d,bfl2d
     
    Xd1=jnp.matmul(Dt2,Wfl1d)+bfl1d
   
    FLd1=relu(Xd1)
     
    FLd2=jnp.matmul(FLd1,Wfl2d)+bfl2d
    #print("FLd2.shape",FLd2.shape)
  
    Xd2=FLd2+Dt2
    Dout,mu_d2,var_d2,N_d2=layer_norm(Xd2)
 
    #print("Dout.shape",Dout.shape)
    Dout=Dout.reshape(num_phrases,Dout.shape[1]*Dout.shape[2])
    #print("Dout.shape",Dout.shape)
    return Dout,mu_d2,var_d2,N_d2,Xd2,Xd1,FLd1

def output_layer(Dout):
    
    global W0,b0 
    #print(Dout.shape,W0.shape)  
    Zout=jnp.matmul(Dout,W0)+b0 
    SigmaZout = softmax(Zout) 
    print("SigmaZout.shape",SigmaZout.shape)
    
    return SigmaZout 


def loss_calculation(SigmaZout,target):
    #print("target.shape",jnp.array(target).shape)
    Loss=cross_entropy_loss(SigmaZout, target)
    print("Loss:",Loss)
    return Loss




## BackPropagation

In [8]:
def derivate_dout(SigmaZout,target,Dout):
    global W0
    dLoss_dZout=SigmaZout-target
    #print("dLoss_dZout.shape",dLoss_dZout.shape)
    dLoss_W0=jnp.transpose(dLoss_dZout,(1,0))@Dout
    #print("dLoss_W0.shape",dLoss_W0.shape,"W0.shape",W0.shape)
    dLoss_b0=jnp.sum(dLoss_dZout, axis=0)
    #print("dLoss_b0.shape",dLoss_b0.shape,"b0.shape",b0.shape)
    dLoss_Dout=dLoss_dZout@W0.T
    dLoss_Dout=dLoss_Dout.reshape(num_phrases,words_per_phrase,embedding_size)
    #print("dLoss_Dout.shape",dLoss_Dout.shape)
    return dLoss_Dout,dLoss_W0,dLoss_b0


def derivate_fully_connected_layers_decoder(dLoss_Dout,Dt2,Xd2,var_d2,mu_d2,N_d2,Wfl2d,FLd1,Xd1):
 
    dLoss_FLd2=dLoss_Dout*diff_norm(Xd2,var_d2,mu_d2,N_d2)
    #print("dLoss_FLd2.shape",dLoss_FLd2.shape) 
    dLoss_Dt2_a=dLoss_FLd2
    #print("dLoss_Dt2_a.shape",dLoss_Dt2_a.shape) 
    #print("Dt2.shape",Dt2.shape) 
    dLoss_FLd1=dLoss_FLd2@jnp.transpose(Wfl2d,(1,0))
    #print("dLoss_FLd1.shape",dLoss_FLd1.shape) 
    #print("FLd1.shape",FLd1.shape) 
    dLoss_Wfl2d=jnp.sum(jnp.transpose(dLoss_FLd2,(0,2,1))@FLd1,axis=0)
    #print("dLoss_Wfl2d.shape",dLoss_Wfl2d.shape) # do the mean here over each phrase
    #print("Wfl2d.shape",Wfl2d.shape) 
    dLoss_bfl2d=jnp.sum(jnp.sum(dLoss_FLd2, axis=0),axis=0)
    #print("dLoss_bfl2d.shape",dLoss_bfl2d.shape) # do the mean here over each phrase
    #print("bfl2d.shape",bfl2d.shape) 
    if Xd1.all()>0:
        DLoss_Dt2_b=dLoss_FLd1@jnp.transpose(Wfl1d,(1,0))
    else:
        DLoss_Dt2_b=0
    DLoss_Dt2=dLoss_Dt2_a+DLoss_Dt2_b
    #print("DLoss_Dt2.shape",DLoss_Dt2.shape) # do the mean here over each phrase
    #print("Dt2.shape",Dt2.shape) 
    if Xd1.all()>0:
        dLoss_Wfl1d=jnp.sum(jnp.transpose(dLoss_FLd1,(0,2,1))@Dt2,axis=0)
    else:
        dLoss_Wfl1d=0
    #print("dLoss_Wfl1d.shape",dLoss_Wfl1d.shape) # do the mean here over each phrase
    #print("Wfl1d.shape",Wfl1d.shape) 
    if Xd1.all()>0:
        dLoss_bfl1d=jnp.sum(jnp.sum(dLoss_FLd1,axis=0),axis=0)
    else:
        dLoss_bfl1d=0
    return dLoss_Wfl2d,dLoss_bfl2d,dLoss_Wfl1d,dLoss_bfl1d,DLoss_Dt2

def derivative_cross_attention(dLoss_Dt2,Res,var_res,mu_res,N_res,Attention_weights_cross,K_C,V_C,Q_C,Ecout,Dt1):
    #print("dLoss_bfl1d.shape",dLoss_bfl1d.shape) # do the mean here over each phrase
    #print("bfl1d.shape",bfl1d.shape) 
    
    dLoss_Acr=dLoss_Dt2*diff_norm(Res,var_res,mu_res,N_res)
    #print("dLoss_Acr.shape",dLoss_Acr.shape) # do the mean here over each phrase
    #print("Acr.shape",Acr.shape) 
    dLoss_Dt1_a=dLoss_Dt2*diff_norm(Res,var_res,mu_res,N_res)
    #print("dLoss_Dt1.shape",dLoss_Dt1_a.shape) # do the mean here over each phrase
    #print("Dt1.shape",Dt1.shape) 
    dLoss_Qc=diffQKV(dLoss_Acr,Attention_weights_cross,K_C,V_C,Dt1,dk)
    #print("dLoss_dQc.shape",dLoss_Qc.shape) # do the mean here over each phrase
    #print("Qc.shape",Qc.shape) 
    dLoss_Kc=diffQKV(dLoss_Acr,Attention_weights_cross,Q_C,V_C,Ecout,dk)
    #print("dLoss_dKc.shape",dLoss_Kc.shape) # do the mean here over each phrase
    #print("Kc.shape",Kc.shape) 
    dLoss_Vc=f=np.sum(np.mean(np.transpose(np.expand_dims(dLoss_Acr, axis=1),(0,1,3,2))@(Attention_weights_cross@np.expand_dims(Ecout, axis=1)),axis=1),axis=0)
    #print("dLoss_dVc.shape",dLoss_Vc.shape) # do the mean here over each phrase
    #print("Vc.shape",Vc.shape) 
    return dLoss_Qc,dLoss_Kc,dLoss_Vc,Attention_weights_cross,dLoss_Dt1_a,dLoss_Acr
    
def derivative_attention_decoder(dLoss_Acr,Attention_weights_cross,dLoss_Dt1_a,Attention_weights_masked,Q_D,V_D,K_D,K_C,V_C,Qc,Xd,var_d,mu_d,N_d,input_d):
    
    dAttention_weights_cross=Attention_weights_cross*(1-Attention_weights_cross)
    V1=redimension(dAttention_weights_cross@K_C/jnp.sqrt(dk)) 

    V2=redimension(V_C)

    V3=V1*V2@Qc
    dLoss_Dt1_b=dLoss_Acr*V3
    #print("dLoss_Dt1_b.shape",dLoss_Dt1_b.shape) # do the mean here over each phrase
    #print("dLoss_Dt1_a.shape",dLoss_Dt1_a.shape) 
    dLoss_Dt1=dLoss_Dt1_a+dLoss_Dt1_b
    dLoss_Amask=dLoss_Dt1*diff_norm(Xd,var_d,mu_d,N_d)
    #print("dLoss_DAmask.shape",dLoss_Amask.shape)  
    dLoss_inputd_a=dLoss_Amask
    #print("dLoss_Dinputd_a.shape",dLoss_inputd_a.shape) 
    dLoss_Kd=diffQKV(dLoss_Amask,Attention_weights_masked,Q_D,V_D,input_d,dk) 
    #print("dLoss_Kd.shape",dLoss_Kd.shape) 
    dLoss_Qd=diffQKV(dLoss_Amask,Attention_weights_masked,K_D,V_D,input_d,dk) 
    #print("dLoss_Qd.shape",dLoss_Qd.shape) 
    dLoss_Vd=f=np.sum(np.mean(np.transpose(np.expand_dims(dLoss_Amask, axis=1),(0,1,3,2))@(Attention_weights_masked@np.expand_dims(input_d, axis=1)),axis=1),axis=0)
    return dLoss_Kd,dLoss_Qd,dLoss_Vd,dLoss_inputd_a,dLoss_Amask


def derivative_input_decoder(dLoss_Amask,Attention_weights_masked,K_D,V_D,Q_D,dLoss_inputd_a,input_d):
    global Qd,Kd,Vd
    dLoss_V_D=np.transpose(np.mean(np.transpose(np.expand_dims(dLoss_Amask, axis=1),(0,1,3,2))@Attention_weights_masked,axis=1),(0,2,1))
    dLoss_V_D.shape
    dLoss_inputd_v=dLoss_V_D@Vd

    # print("dLoss_inputd_v.shape",dLoss_inputd_v.shape) # do the mean here over each phrase
    # print("input_d.shape",input_d.shape) 

    dAttention_weights_masked=Attention_weights_masked*(1-Attention_weights_masked)
    V1=redimension(dAttention_weights_masked@K_D/jnp.sqrt(dk)) 
    V2=redimension(V_D)
    V3=V1*V2
    dLoss_Q_D=dLoss_Amask*V3
    dLoss_Q_D.shape
    dLoss_inputd_q=dLoss_Q_D@Qd
    #print("dLoss_inputd_q.shape",dLoss_inputd_q.shape)
 
    V1=redimension(dAttention_weights_masked@Q_D/jnp.sqrt(dk)) 
    V2=redimension(V_D)
    V3=V1*V2
    dLoss_K_D=dLoss_Amask*V3 
    dLoss_inputd_k=dLoss_K_D@Kd
    #print("dLoss_inputd_k.shape",dLoss_inputd_k.shape)
    dLoss_inputd=dLoss_inputd_a+dLoss_inputd_k+dLoss_inputd_q+dLoss_inputd_v

    dLoss_dWemb_decoder=dLoss_inputd*input_d
    return dLoss_inputd,dLoss_dWemb_decoder
    




In [9]:
def derivative_Ecout(Attention_weights_cross,dLoss_Acr,Q_C,V_C):
    global Kc,Vc
    dAttention_weights_cross=Attention_weights_cross*(1-Attention_weights_cross)
    V1=redimension(dAttention_weights_cross@Q_C/jnp.sqrt(dk)) 

    V2=redimension(V_C)

    V3=V1*V2

    
    dLoss_K_C=dLoss_Acr*V3
    dLoss_K_C.shape

    dLoss_Ecout_k=dLoss_K_C@Kc
    #print("dLoss_Ecout_k.shape",dLoss_Ecout_k.shape) 

    dLoss_V_C=np.transpose(np.mean(np.transpose(np.expand_dims(dLoss_Acr, axis=1),(0,1,3,2))@Attention_weights_cross,axis=1),(0,2,1))
    dLoss_V_C.shape
    dLoss_Ecout_v=dLoss_V_C@Vc

    #print("dLoss_Ecout_v.shape",dLoss_Ecout_v.shape) # do the mean here over each phrase
    dLoss_Ecout=dLoss_Ecout_k+dLoss_Ecout_v
    return dLoss_Ecout

def derivate_fully_connected_layers_encoder(dLoss_Ecout,Ect1,Xe2,var_e2,mu_e2,N_e2,FLe1,Xe1):
    global Wfl2e,Wfl1e
    dLoss_dFLe2=dLoss_Ecout*diff_norm(Xe2,var_e2,mu_e2,N_e2)
    dLoss_Ect1_a=dLoss_dFLe2
    #print(Wfl2e.shape)
    dLoss_dFLe1=dLoss_dFLe2@jnp.transpose(Wfl2e,(1,0))
    dLoss_dWfl2e=jnp.transpose(dLoss_dFLe2,(0,2,1))@FLe1
    #print(dLoss_dWfl2e)
    dLoss_dbfl2e=jnp.sum(dLoss_dFLe2,axis=1) 
    if Xe1.all()>0:
        dLoss_Ect1_b=dLoss_dFLe1@jnp.transpose(Wfl1e,(1,0))
    else:
        dLoss_Ect1_b=0

    dLoss_Ect1=dLoss_Ect1_b+dLoss_Ect1_a
    if Xe1.all()>0:
        dLoss_Wfl1e=jnp.transpose(dLoss_dFLe1,(0,2,1))@Ect1
    else:
        dLoss_Wfl1e=0
 
    if Xe1.all()>0:
        dLoss_bfl1e=jnp.transpose(dLoss_dFLe1,(0,2,1)) 
    else:
        dLoss_bfl1e=0
 
    return dLoss_dWfl2e,dLoss_dbfl2e,dLoss_Wfl1e,dLoss_bfl1e,dLoss_Ect1

 
def derivative_attention_encoder(dLoss_Ect1,Xe,var_e,mu_e,Ne,Attention_weights_e,K_E,V_E,Q_E):
    global inputs_e
    dLoss_Ae=dLoss_Ect1*diff_norm(Xe,var_e,mu_e,Ne)
    
    dLoss_inpute_a=dLoss_Ae
   
    dLoss_dQe=diffQKV(dLoss_Ae,Attention_weights_e,K_E,V_E,inputs_e,dk)
    #print("dLoss_dQe.shape",dLoss_dQe.shape) # do the mean here over each phrase
    #print("Qe.shape",Qe.shape) 
    dLoss_dKe=diffQKV(dLoss_Ae,Attention_weights_e,Q_E,V_E,inputs_e,dk)
    #print("dLoss_dKe.shape",dLoss_dKe.shape) # do the mean here over each phrase
    #print("Ke.shape",Ke.shape) 
    dLoss_dVe=f=np.sum(np.sum(np.transpose(np.expand_dims(dLoss_Ae, axis=1),(0,1,3,2))@(Attention_weights_e@np.expand_dims(inputs_e, axis=1)),axis=1),axis=0)
    #print("dLoss_dVe.shape",dLoss_dVe.shape) # do the mean here over each phrase
    return dLoss_dQe,dLoss_dKe,dLoss_dVe,dLoss_inpute_a,dLoss_Ae


def derivative_input_encoder(dLoss_Ae,Attention_weights_e,K_E,V_E,Q_E,dLoss_inpute_a):
    global Ve,Qe,Ke,inputs_e
    dLoss_V_E=np.transpose(np.sum(np.transpose(np.expand_dims(dLoss_Ae, axis=1),(0,1,3,2))@Attention_weights_e,axis=1),(0,2,1))
    dLoss_inpute_v=dLoss_V_E@Ve

  
    dAttention_weights_e=Attention_weights_e*(1-Attention_weights_e)
    V1=redimension(dAttention_weights_e@K_E/jnp.sqrt(dk))  
    V2=redimension(V_E) 
    V3=V1*V2 
    dLoss_Q_E=dLoss_Ae*V3  
    dLoss_inpute_q=dLoss_Q_E@Qe
    #print("dLoss_inpute_q.shape",dLoss_inpute_q.shape)
    
    V1=redimension(dAttention_weights_e@Q_E/jnp.sqrt(dk))  
    V2=redimension(V_E) 
    V3=V1*V2 
    dLoss_K_E=dLoss_Ae*V3  
    dLoss_inpute_k=dLoss_K_E@Ke
    #print("dLoss_inpute_k.shape",dLoss_inpute_k.shape)
    dLoss_inpute=dLoss_inpute_a+dLoss_inpute_k+dLoss_inpute_q+dLoss_inpute_v
    dLoss_dWemb_encoder=dLoss_inpute*inputs_e
    return dLoss_inpute,dLoss_dWemb_encoder





## Update

In [12]:
def print_vocabs(ans): 
    for idx, values in enumerate(ans):
        max_index = np.argmax(values)
        
        # Step 2: Find the word in the vocabulary with the corresponding position
        matched_word = None
        for word, (_, position) in vocabulary.items():
            if position == max_index:
                matched_word = word
                break
        print(f"List {idx + 1}: Max value index: {max_index}, Matched word: {matched_word}")

In [13]:
Ae,Xe,Ect1,mu_e,var_e,Ne,Attention_weights_e,K_E,V_E,Q_E=forward_attention_encoder(inputs_e) 
Ecout,mu_e2,var_e2,N_e2,FLe1,Xe1,Xe2=fully_connected_layers_encoder(Ect1)
K_C,V_C=cross_attention_encoder(Ecout)

step=0
learning_rate=0.001
for epochs in range(0,50):
    for step in range(inputs_d.shape[0]):
        inputs_decoder=inputs_d[step]
        target=targets_d[step]
        A_mask,Xd,Dt1,mu_d,var_d,N_d,Attention_weights_masked,Q_D,K_D,V_D=forward_attention_decoder(inputs_decoder)
        Q_C=cross_attention_decoder(Dt1)
        Dt2, mu_res,var_res,N_res,Res,Attention_weights_cross=cross_attention(Q_C,K_C,V_C,Dt1)
        Dout,mu_d2,var_d2,N_d2,Xd2,Xd1,FLd1=fully_connected_layers_decoder(Dt2)
        SigmaZout=output_layer(Dout)
        #print(SigmaZout)
        #print_vocabs(SigmaZout)
        Loss=loss_calculation(SigmaZout,target)
        dLoss_Dout,dLoss_W0,dLoss_b0=derivate_dout(SigmaZout,target,Dout)
        dLoss_Wfl2d,dLoss_bfl2d,dLoss_Wfl1d,dLoss_bfl1d,DLoss_Dt2=derivate_fully_connected_layers_decoder(dLoss_Dout,Dt2,Xd2,var_d2,mu_d2,N_d2,Wfl2d,FLd1,Xd1)
        dLoss_Qc,dLoss_Kc,dLoss_Vc,Attention_weights_cross,dLoss_Dt1_a,dLoss_Acr=derivative_cross_attention(DLoss_Dt2,Res,var_res,mu_res,N_res,Attention_weights_cross,K_C,V_C,Q_C,Ecout,Dt1)
        dLoss_Kd,dLoss_Qd,dLoss_Vd,dLoss_inputd_a,dLoss_Amask=derivative_attention_decoder(dLoss_Acr,Attention_weights_cross,dLoss_Dt1_a,Attention_weights_masked,Q_D,V_D,K_D,K_C,V_C,Qc,Xd,var_d,mu_d,N_d,inputs_decoder)
        dLoss_inputd,dLoss_dWemb_decoder=derivative_input_decoder(dLoss_Amask,Attention_weights_masked,K_D,V_D,Q_D,dLoss_inputd_a,inputs_decoder)
        
        W0=W0-learning_rate*dLoss_W0.T
        b0=b0-learning_rate*dLoss_b0
        Wfl2d=Wfl2d-learning_rate*dLoss_Wfl2d.T
        bfl2d=bfl2d-learning_rate*dLoss_bfl2d
        Wfl1d=Wfl1d-learning_rate*dLoss_Wfl1d.T
        bfl1d=bfl1d-learning_rate*dLoss_bfl1d
        Qc=Qc-learning_rate*dLoss_Qc
        Kc=Kc-learning_rate*dLoss_Kc
        Vc=Vc-learning_rate*dLoss_Vc
        Qd=Qd-learning_rate*dLoss_Qd
        Kd=Kd-learning_rate*dLoss_Kd
        Vd=Vd-learning_rate*dLoss_Vd
        inputs_d=inputs_d-learning_rate*dLoss_dWemb_decoder
        #print(input_d)

    dLoss_Ecout=derivative_Ecout(Attention_weights_cross,dLoss_Acr,Q_C,V_C)
    dLoss_dWfl2e,dLoss_dbfl2e,dLoss_Wfl1e,dLoss_bfl1e,dLoss_Ect1=derivate_fully_connected_layers_encoder(dLoss_Ecout,Ect1,Xe2,var_e2,mu_e2,N_e2,FLe1,Xe1)
    dLoss_dQe,dLoss_dKe,dLoss_dVe,dLoss_inpute_a,dLoss_Ae=derivative_attention_encoder(dLoss_Ect1,Xe,var_e,mu_e,Ne,Attention_weights_e,K_E,V_E,Q_E)
    dLoss_inpute,dLoss_dWemb_encoder=derivative_input_encoder(dLoss_Ae,Attention_weights_e,K_E,V_E,Q_E,dLoss_inpute_a)
    Wfl2e=Wfl2e-learning_rate*jnp.sum(jnp.transpose(dLoss_dWfl2e ,(0,2,1)),axis=0) 
    bfl2e=bfl2e-learning_rate*bfl2e
    Wfl1e=Wfl1e-learning_rate*jnp.sum(jnp.transpose(dLoss_Wfl1e ,(0,2,1)),axis=0) 
    bfl1e=bfl1e-learning_rate*bfl1e
    Qe=Qe-learning_rate*dLoss_dQe
    Ke=Ke-learning_rate*dLoss_dKe
    Ve=Ve-learning_rate*dLoss_dVe
    inputs_e=inputs_e-learning_rate*dLoss_dWemb_encoder

SigmaZout.shape (4, 40)
List 1: Max value index: 4, Matched word: salsa
List 2: Max value index: 4, Matched word: salsa
List 3: Max value index: 4, Matched word: salsa
List 4: Max value index: 4, Matched word: salsa
Loss: 6.0514264
SigmaZout.shape (4, 40)
List 1: Max value index: 5, Matched word: my
List 2: Max value index: 4, Matched word: salsa
List 3: Max value index: 5, Matched word: my
List 4: Max value index: 4, Matched word: salsa
Loss: 4.1016655
SigmaZout.shape (4, 40)
List 1: Max value index: 5, Matched word: my
List 2: Max value index: 5, Matched word: my
List 3: Max value index: 5, Matched word: my
List 4: Max value index: 4, Matched word: salsa
Loss: 4.0675735
SigmaZout.shape (4, 40)
List 1: Max value index: 5, Matched word: my
List 2: Max value index: 5, Matched word: my
List 3: Max value index: 5, Matched word: my
List 4: Max value index: 4, Matched word: salsa
Loss: 4.849557
SigmaZout.shape (4, 40)
List 1: Max value index: 5, Matched word: my
List 2: Max value index: 4, 

  dLoss_dVe=f=np.sum(np.sum(np.transpose(np.expand_dims(dLoss_Ae, axis=1),(0,1,3,2))@(Attention_weights_e@np.expand_dims(inputs_e, axis=1)),axis=1),axis=0)


SigmaZout.shape (4, 40)
List 1: Max value index: 35, Matched word: [PAD]
List 2: Max value index: 35, Matched word: [PAD]
List 3: Max value index: 35, Matched word: [PAD]
List 4: Max value index: 35, Matched word: [PAD]
Loss: 4.5799375
SigmaZout.shape (4, 40)
List 1: Max value index: 35, Matched word: [PAD]
List 2: Max value index: 35, Matched word: [PAD]
List 3: Max value index: 35, Matched word: [PAD]
List 4: Max value index: 35, Matched word: [PAD]
Loss: 4.2859097
SigmaZout.shape (4, 40)
List 1: Max value index: 35, Matched word: [PAD]
List 2: Max value index: 35, Matched word: [PAD]
List 3: Max value index: 35, Matched word: [PAD]
List 4: Max value index: 35, Matched word: [PAD]
Loss: 3.135861
SigmaZout.shape (4, 40)
List 1: Max value index: 35, Matched word: [PAD]
List 2: Max value index: 35, Matched word: [PAD]
List 3: Max value index: 35, Matched word: [PAD]
List 4: Max value index: 7, Matched word: .
Loss: 3.2769718
SigmaZout.shape (4, 40)
List 1: Max value index: 11, Matched w

In [41]:
vocabulary["[PAD]"]

(Array([0.15, 0.0448, 0.895, 0.0183], dtype=float32), 14)

In [33]:
ans=[[0.0166,0.0516,0.0201,0.028,0.00164,0.0517,0.101,0.0243],
     [0.0155,0.0516,0.0924,0.0228,0.00332,0.0361,0.00214,0.00615],
     [0.000674,0.133,0.00307,0.0113,0.0202,0.00332,0.0361,0.00615]]
     # 0.0155 0.00615 0.00332 0.0301 0.00986 0.0259 0.0924 0.00706 0.00312 0.0361 0.00214 0.0168 0.0971 0.00729 0.00114 0.00959 0.00401 0.0203 0.0203 6.59e-05 0.00569 0.0273 0.0326 0.0402 0.00848 0.00147 0.0134 0.000674 0.133 0.00307 0.0113 0.0202]

In [32]:
np.argmax(ans)
for word, (_, position) in vocab.items():
    if position == max_index:
        matched_word = word
        break

6

In [31]:
vocabulary

{'è': (Array([0.973, 0.511, 0.271, 0.473], dtype=float32), 0),
 'salsa': (Array([0.359, 0.784, 0.146, 0.976], dtype=float32), 1),
 'l': (Array([0.277, 0.0115, 0.537, 0.443], dtype=float32), 2),
 '[END]': (Array([0.599, 0.786, 0.261, 0.259], dtype=float32), 3),
 'cute': (Array([0.663, 0.166, 0.368, 0.155], dtype=float32), 4),
 '?': (Array([0.768, 0.767, 0.636, 0.119], dtype=float32), 5),
 'amo': (Array([0.654, 0.609, 0.00815, 0.674], dtype=float32), 6),
 'di': (Array([0.821, 0.0597, 0.312, 0.0862], dtype=float32), 7),
 'la': (Array([0.871, 0.638, 0.294, 0.107], dtype=float32), 8),
 'you': (Array([0.612, 0.358, 0.977, 0.175], dtype=float32), 9),
 'know': (Array([0.211, 0.299, 0.456, 0.19], dtype=float32), 10),
 'dog': (Array([0.538, 0.667, 0.276, 0.191], dtype=float32), 11),
 'soia': (Array([0.714, 0.932, 0.781, 0.0637], dtype=float32), 12),
 'buono': (Array([0.675, 0.268, 0.23, 0.118], dtype=float32), 13),
 '[PAD]': (Array([0.15, 0.0448, 0.895, 0.0183], dtype=float32), 14),
 'are': (Arr

In [15]:
targets_d[0]

Array([[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=float32)