# Positional encoding

In [22]:
import numpy as np 
import jax.numpy as jnp
def get_positional_encoding(seq_len, d_model):
    """
    Returns a non-learnable (sinusoidal) positional encoding.
    
    seq_len: Length of the input sequence.
    d_model: Dimension of the embeddings.
    """
    pos = np.arange(seq_len)[:, np.newaxis]  # Shape: [seq_len, 1]
    i = np.arange(d_model)[np.newaxis, :]    # Shape: [1, d_model]

    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))

    # Apply sine to even indices, cosine to odd indices
    pos_encoding = np.zeros((seq_len, d_model))
    pos_encoding[:, 0::2] = np.sin(pos * angle_rates[:, 0::2])  # sine on even indices
    pos_encoding[:, 1::2] = np.cos(pos * angle_rates[:, 1::2])  # cosine on odd indices

    return pos_encoding

def softmax(x, axis=-1):
    # Subtract the max value for numerical stability
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)
def layer_norm(x, epsilon=1e-6):
    # Calculate the mean and variance
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True) 
        
        # Normalize the output
        x_norm = (x - mean) / jnp.sqrt(var + epsilon) 
        #print(x)
        #print(mean)
        #print("mean",mean.shape)
        #print("x_norm.shape",x_norm.shape)
        return x_norm,mean,var,x.shape[-1]
def relu(x):
    return np.maximum(0, x)

def pad_sequence(seq, max_len, pad_value=0):
    """Pad a sequence with a given value up to max_len."""
    current_len = seq.shape[0]
    pad_width = max_len - current_len
    if pad_width > 0:
        # Pad sequence with zeros (or any pad_value you provide)
        seq = jnp.pad(seq, ((0, pad_width), (0, 0)), mode='constant', constant_values=pad_value)
    return seq

def create_timestaped_input(input_d,words_per_phrase):
    input_translation=[]
    for j in range(input_d.shape[0]):
    # Create padded sequences
        padded_sequences = [pad_sequence(input_d[j][0:i], words_per_phrase) for i in range(1, input_d.shape[1] + 1)]
        input_translation.append(padded_sequences)
    return jnp.array(input_translation)

def cross_entropy_loss(predictions, target):
    # Cross-entropy loss for a batch of predictions and targets
    batch_loss = -jnp.sum(target * jnp.log(predictions + 1e-9), axis=1)
    return jnp.mean(batch_loss)

def diff_norm(X,var,mu,N):
    epsilon=1e-6
    AA=((1-(1/N))*(1/(jnp.sqrt(var+epsilon))))
    BB=(1/N)*((X-mu)**2)
    CC=((var+epsilon)**(3/2))
    result=(AA-(BB/CC)) 
    return result

def redimension(X):
    return jnp.concatenate(jnp.swapaxes(X,0,1),axis=-1) 

def diffQKV(dAttention,Attention_weights,X1,X2,X3,dk):
    dAttention_weights=Attention_weights*(1-Attention_weights)
    V1=redimension(dAttention_weights@X1/jnp.sqrt(dk)) 
    
    V2=redimension(X2)
    
    V3=V1*V2*X3
    dLoss_dX=jnp.sum(jnp.transpose(dAttention,(0,2,1))@V3,axis=0)
    return dLoss_dX

# Summary Encoder

In [51]:

 
vocab_size=7 
num_phrases = 5
words_per_phrase = 15 
dk = dv = embedding_size = 4 # constrain of transformer all embedding size both input embedding and attention embedding are the same encoder
num_heads=2
 
pos_encoding=get_positional_encoding(words_per_phrase,embedding_size)
 
inputs_e = np.random.rand(num_phrases,words_per_phrase, embedding_size)
inputs_e=pos_encoding+inputs_e
print("inputs.shape: ",inputs_e.shape)

Qe = np.random.rand(embedding_size, dk) / jnp.sqrt(embedding_size)
Ke = np.random.rand(embedding_size, dk) / jnp.sqrt(embedding_size)
Ve = np.random.rand(embedding_size, dv) / jnp.sqrt(embedding_size)

Q_E= jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_e, Qe),num_heads,axis=2)), 0, 1)
print("Qval.shape: ",Q_E.shape)

K_E = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_e, Ke),num_heads,axis=2)), 0, 1)
print("Kval.shape: ",K_E.shape)


V_E = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_e,Ve),num_heads,axis=2)), 0, 1)
print("Vval.shape: ",V_E.shape)


QKscaled = jnp.matmul(Q_E, jnp.transpose(K_E, (0, 1, 3, 2))) / jnp.sqrt(dk)

Attention_weights_e = softmax(QKscaled)
print("Attention_weights shape:",Attention_weights_e.shape)


Ae = jnp.matmul(Attention_weights_e, V_E)
print("Attention shape:",Ae.shape)


Ae=jnp.array([jnp.concatenate(Ae[i], axis=1) for i in range(num_phrases)])
print("Attention shape concat:",Ae.shape)


 


Xe=Ae+inputs_e
Ect1,mu_e,var_e,Ne=layer_norm(Xe)
print("Ect1.shape",Ect1.shape,Ne)

fl1_size=100
Wfl1e=np.random.rand(num_phrases,dv, fl1_size)   
bfl1e=np.random.rand(num_phrases,1,fl1_size)
Xe1=jnp.matmul(Ect1,Wfl1e)+bfl1e
print("Xe1.shape",Xe1.shape)

FLe1=relu(Xe1)
print("FLe1.shape",FLe1.shape)


fl2_size=50
Wfl2e=np.random.rand(num_phrases,FLe1.shape[2], dv)   
bfl2e=np.random.rand(num_phrases,1,dv)
FLe2=jnp.matmul(FLe1,Wfl2e)+bfl2e
print("FLe2.shape",FLe2.shape)

Xe2=FLe2+Ect1
Ecout,mu_e2,var_e2,N_e2=layer_norm(Xe2)
print("Ecout.shape",Ecout.shape,N_e2)

Kc = np.random.rand(Ecout.shape[-1], dk) / jnp.sqrt(Ecout.shape[-1])
Vc = np.random.rand(Ecout.shape[-1], dv) / jnp.sqrt(Ecout.shape[-1])



K_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Ecout, Kc),num_heads,axis=2)), 0, 1)
print("K_C.shape: ",K_C.shape)# shape is: num_phrases, numbheads, words_per_phrase, dv/num_heads


V_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Ecout,Vc),num_heads,axis=2)), 0, 1)
print("V_C.shape: ",V_C.shape)



inputs.shape:  (5, 15, 4)
Qval.shape:  (5, 2, 15, 2)
Kval.shape:  (5, 2, 15, 2)
Vval.shape:  (5, 2, 15, 2)
Attention_weights shape: (5, 2, 15, 15)
Attention shape: (5, 2, 15, 2)
Attention shape concat: (5, 15, 4)
Ect1.shape (5, 15, 4) 4
Xe1.shape (5, 15, 100)
FLe1.shape (5, 15, 100)
FLe2.shape (5, 15, 4)
Ecout.shape (5, 15, 4) 4
K_C.shape:  (5, 2, 15, 2)
V_C.shape:  (5, 2, 15, 2)


In [52]:
pos_encoding=get_positional_encoding(words_per_phrase,dv)
 
input_d = np.random.rand(num_phrases,words_per_phrase, embedding_size)
inputs_d=input_d+pos_encoding
 
# Convert to an array for batching
inputs_d = jnp.swapaxes(create_timestaped_input(input_d,words_per_phrase),0,1)
target_d=jnp.swapaxes(create_timestaped_input(input_d,words_per_phrase),0,1)
print("inputs_d complete shape",inputs_d.shape)# shape is: words_per_phrase,num_phrases,words_per_phrase,embedding_size at each





 

Qd = np.random.rand(embedding_size, dk) / jnp.sqrt(embedding_size)
Kd = np.random.rand(embedding_size, dk) / jnp.sqrt(embedding_size)
Vd = np.random.rand(embedding_size, dv) / jnp.sqrt(embedding_size)


learning_rate=0.01
step=0

for step in range(inputs_d.shape[0]):
    #Q_D  = jnp.swapaxes(jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Qd),num_heads,axis=3)), 0, 1),1,2)
    Q_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Qd),num_heads,axis=2)), 0, 1)
    #print("Qval.shape: ",Q_D.shape)# numwords, num_phrases, numheads, num_words, dv/num_heads

    #K_D  = jnp.swapaxes(jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Kd),num_heads,axis=3)), 0, 1),1,2)
    K_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Kd),num_heads,axis=2)), 0, 1)
    #print("Kval.shape: ",K_D.shape)


    #V_D  = jnp.swapaxes(jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Vd),num_heads,axis=3)), 0, 1),1,2)
    V_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d[step], Vd),num_heads,axis=2)), 0, 1)
    #print("Vval.shape: ",V_D.shape)


    #QKscaled_decoder  = jnp.matmul(Q_D, jnp.transpose(K_D, (0, 1, 2, 4,3))) / jnp.sqrt(dv) #+ jnp.triu(jnp.ones((words_per_phrase, words_per_phrase)))* -1e9 
    QKscaled_decoder  = jnp.matmul(Q_D, jnp.transpose(K_D, (0, 1, 3, 2))) / jnp.sqrt(dv)
    # Step 1: Create a causal mask of shape (1, 1, 9, 9) to broadcast across heads and batch
    mask = jnp.tril(jnp.ones((words_per_phrase, words_per_phrase)))  # (9, 9) lower triangular matrix
    mask = mask.at[mask == 0].set(-jnp.inf)  # Set future tokens to -inf
    mask = mask.at[mask == 1].set(0)  # Set allowed tokens to 0
    mask = mask.reshape(1, 1, words_per_phrase, words_per_phrase)   

    # Step 2: Apply mask to QKscaled_decoder (it will broadcast across batch and heads)
    QKscaled_decoder = QKscaled_decoder + mask 

    Attention_weights_masked = softmax(QKscaled_decoder)


    A_mask = jnp.matmul(Attention_weights_masked, V_D)
    #print("A_mask.shape non concat: ",A_mask.shape)
    
    #A_mask=jnp.swapaxes(jnp.concatenate(jnp.swapaxes(A_mask,0,2),axis=-1),0,1)
    A_mask= jnp.concatenate(jnp.swapaxes(A_mask,0,1),axis=-1) 
    #print("A_mask.shape concat: ",A_mask.shape)
    #print("inputs_d.shape: ",inputs_d[step].shape)



    Xd = inputs_d[step] + A_mask
    Dt1,mu_d,var_d,N_d = layer_norm(Xd)
    #print("Dt1.shape",Dt1.shape)

    Qc = np.random.rand(Dt1.shape[-1], dv) / jnp.sqrt(Dt1.shape[-1])
    #print("Qc.shape",Qc.shape)

    #Q_C  = jnp.swapaxes(jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Dt1, Qc),num_heads,axis=3)), 0, 1),1,2)
    Q_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Dt1, Qc),num_heads,axis=2)), 0, 1)
    #print("Q_C.shape: ",Q_C.shape)# shape words_per_phrase,num_heads,words_per_phrase,dv/num_heads
    #print("K_C.shape: ",K_C.shape)# shape num_phrase,num_heads,words_per_phrase,dv/num_heads
    #print("V_C.shape: ",V_C.shape)# shape num_phrase,num_heads,words_per_phrase,dv/num_heads

    QKscaled_cross_attention  = jnp.matmul(Q_C, jnp.transpose(K_C , (0, 1, 3, 2)))/ jnp.sqrt(dv)
    Attention_weights_cross = softmax(QKscaled_cross_attention)
    Acr = jnp.matmul(Attention_weights_cross, V_C)
    #print("Acr.shape non concat",Acr.shape)
    Acr=jnp.concatenate(jnp.swapaxes(Acr,0,1),axis=-1)
    #print("Acr.shape concat",Acr.shape)
    Res=Acr + Dt1
    Dt2, mu_res,var_res,N_res = layer_norm(Res)  # residual_output is (9, 9, 10)
    #print("Dt2 shape:", Dt2.shape)


    fl1d_size=100
    Wfl1d=np.random.rand(Dt2.shape[-1], fl1d_size)   
    #print("Wfl1d.shape",Wfl1d.shape)
    bfl1d=np.random.rand(fl1d_size)
    #print("bfl1d.shape",bfl1d.shape)
    Xd1=jnp.matmul(Dt2,Wfl1d)+bfl1d
    #print("Xd1.shape",Xd1.shape)

    FLd1=relu(Xd1)
    #print("FLe1.shape",FLd1.shape)


    Wfl2d=np.random.rand(FLd1.shape[-1], dv)   
    #print("Wfl2d.shape",Wfl2d.shape)
    bfl2d=np.random.rand(dv)
    #print("bfl2d.shape",bfl2d.shape)
    FLd2=jnp.matmul(FLd1,Wfl2d)+bfl2d
    #print("FLd2.shape",FLd2.shape)



    Xd2=FLd2+Dt2
    Dout,mu_d2,var_d2,N_d2=layer_norm(Xd2)
    Dout.shape
    #print("Dout.shape",Dout.shape)
    Dout=Dout.reshape(num_phrases,Dout.shape[1]*Dout.shape[2])
    #print("Dout.shape concat",Dout.shape)



    W0=np.random.rand(Dout.shape[-1],vocab_size)   
    b0=np.random.rand(vocab_size)
    Zout=jnp.matmul(Dout,W0)+b0
    #print("Zout.shape",Zout.shape)
    SigmaZout = softmax(Zout) 
    #print("SigmaZout.shape",SigmaZout.shape)
    SigmaZout 

    target_d=np.random.rand(words_per_phrase,num_phrases,vocab_size)
    target_d.shape,target_d[step].shape
    print("Loss:",cross_entropy_loss(SigmaZout, target_d[step]))

 

    ##################################################################backpropagatation
    dLoss_dZout=SigmaZout-target_d[step]
    #print("dLoss_dZout.shape",dLoss_dZout.shape)
    dLoss_W0=jnp.transpose(dLoss_dZout,(1,0))@Dout
    #print("dLoss_W0.shape",dLoss_W0.shape,"W0.shape",W0.shape)
    dLoss_b0=jnp.sum(dLoss_dZout, axis=0)
    #print("dLoss_b0.shape",dLoss_b0.shape,"b0.shape",b0.shape)
    dLoss_Dout=dLoss_dZout@W0.T
    dLoss_Dout=dLoss_Dout.reshape(num_phrases,words_per_phrase,embedding_size)
    #print("dLoss_Dout.shape",dLoss_Dout.shape)
    #print("Dout.shape",Dout.shape)
    dLoss_FLd2=dLoss_Dout*diff_norm(Xd2,var_d2,mu_d2,N_d2)
    #print("dLoss_FLd2.shape",dLoss_FLd2.shape) 
    dLoss_Dt2_a=dLoss_FLd2
    #print("dLoss_Dt2_a.shape",dLoss_Dt2_a.shape) 
    #print("Dt2.shape",Dt2.shape) 
    dLoss_FLd1=dLoss_FLd2@jnp.transpose(Wfl2d,(1,0))
    #print("dLoss_FLd1.shape",dLoss_FLd1.shape) 
    #print("FLd1.shape",FLd1.shape) 
    dLoss_Wfl2d=jnp.sum(jnp.transpose(dLoss_FLd2,(0,2,1))@FLd1,axis=0)
    #print("dLoss_Wfl2d.shape",dLoss_Wfl2d.shape) # do the mean here over each phrase
    #print("Wfl2d.shape",Wfl2d.shape) 
    dLoss_bfl2d=jnp.sum(jnp.sum(dLoss_FLd2, axis=0),axis=0)
    #print("dLoss_bfl2d.shape",dLoss_bfl2d.shape) # do the mean here over each phrase
    #print("bfl2d.shape",bfl2d.shape) 
    if Xd1.all()>0:
        DLoss_Dt2_b=dLoss_FLd1@jnp.transpose(Wfl1d,(1,0))
    else:
        DLoss_Dt2_b=0
    DLoss_Dt2=dLoss_Dt2_a+DLoss_Dt2_b
    #print("DLoss_Dt2.shape",DLoss_Dt2.shape) # do the mean here over each phrase
    #print("Dt2.shape",Dt2.shape) 
    if Xd1.all()>0:
        dLoss_Wfl1d=jnp.sum(jnp.transpose(dLoss_FLd1,(0,2,1))@Dt2,axis=0)
    else:
        dLoss_Wfl1d=0
    #print("dLoss_Wfl1d.shape",dLoss_Wfl1d.shape) # do the mean here over each phrase
    #print("Wfl1d.shape",Wfl1d.shape) 
    if Xd1.all()>0:
        dLoss_bfl1d=jnp.sum(jnp.sum(dLoss_FLd1,axis=0),axis=0)
    else:
        dLoss_bfl1d=0
    #print("dLoss_bfl1d.shape",dLoss_bfl1d.shape) # do the mean here over each phrase
    #print("bfl1d.shape",bfl1d.shape) 
    dLoss_Dt2=dLoss_Dt2_a+DLoss_Dt2_b
    dLoss_Dt2.shape,diff_norm(Res,var_res,mu_res,N_res).shape 
    dLoss_Acr=dLoss_Dt2*diff_norm(Res,var_res,mu_res,N_res)
    #print("dLoss_Acr.shape",dLoss_Acr.shape) # do the mean here over each phrase
    #print("Acr.shape",Acr.shape) 
    dLoss_Dt1_a=dLoss_Dt2*diff_norm(Res,var_res,mu_res,N_res)
    #print("dLoss_Dt1.shape",dLoss_Dt1_a.shape) # do the mean here over each phrase
    #print("Dt1.shape",Dt1.shape) 
    dLoss_Qc=diffQKV(dLoss_Acr,Attention_weights_cross,K_C,V_C,Dt1,dk)
    #print("dLoss_dQc.shape",dLoss_Qc.shape) # do the mean here over each phrase
    #print("Qc.shape",Qc.shape) 
    dLoss_Kc=diffQKV(dLoss_Acr,Attention_weights_cross,Q_C,V_C,Ecout,dk)
    #print("dLoss_dKc.shape",dLoss_Kc.shape) # do the mean here over each phrase
    #print("Kc.shape",Kc.shape) 
    dLoss_Vc=f=np.sum(np.mean(np.transpose(np.expand_dims(dLoss_Acr, axis=1),(0,1,3,2))@(Attention_weights_cross@np.expand_dims(Ecout, axis=1)),axis=1),axis=0)
    #print("dLoss_dVc.shape",dLoss_Vc.shape) # do the mean here over each phrase
    #print("Vc.shape",Vc.shape) 

    dAttention_weights_cross=Attention_weights_cross*(1-Attention_weights_cross)
    V1=redimension(dAttention_weights_cross@K_C/jnp.sqrt(dk)) 

    V2=redimension(V_C)

    V3=V1*V2@Qc
    dLoss_Dt1_b=dLoss_Acr*V3
    #print("dLoss_Dt1_b.shape",dLoss_Dt1_b.shape) # do the mean here over each phrase
    #print("dLoss_Dt1_a.shape",dLoss_Dt1_a.shape) 
    dLoss_Dt1=dLoss_Dt1_a+dLoss_Dt1_b
    dLoss_Amask=dLoss_Dt1*diff_norm(Xd,var_d,mu_d,N_d)
    #print("dLoss_DAmask.shape",dLoss_Amask.shape)  
    dLoss_inputd_a=dLoss_Amask
    #print("dLoss_Dinputd_a.shape",dLoss_inputd_a.shape) 
    dLoss_Kd=diffQKV(dLoss_Amask,Attention_weights_masked,Q_D,V_D,input_d,dk) 
    #print("dLoss_Kd.shape",dLoss_Kd.shape) 
    dLoss_Qd=diffQKV(dLoss_Amask,Attention_weights_masked,K_D,V_D,input_d,dk) 
    #print("dLoss_Qd.shape",dLoss_Qd.shape) 
    dLoss_Vd=f=np.sum(np.mean(np.transpose(np.expand_dims(dLoss_Amask, axis=1),(0,1,3,2))@(Attention_weights_masked@np.expand_dims(input_d, axis=1)),axis=1),axis=0)
    #print("dLoss_Vd.shape",dLoss_Vd.shape) # do the mean here over each phrase
    #print("Vd.shape",Vd.shape) 
    dLoss_V_D=np.transpose(np.mean(np.transpose(np.expand_dims(dLoss_Amask, axis=1),(0,1,3,2))@Attention_weights_masked,axis=1),(0,2,1))
    dLoss_V_D.shape
    dLoss_inputd_v=dLoss_V_D@Vd

    # print("dLoss_inputd_v.shape",dLoss_inputd_v.shape) # do the mean here over each phrase
    # print("input_d.shape",input_d.shape) 

    dAttention_weights_masked=Attention_weights_masked*(1-Attention_weights_masked)
    V1=redimension(dAttention_weights_masked@K_D/jnp.sqrt(dk)) 
    V2=redimension(V_D)
    V3=V1*V2
    dLoss_Q_D=dLoss_Amask*V3
    dLoss_Q_D.shape
    dLoss_inputd_q=dLoss_Q_D@Qd
    #print("dLoss_inputd_q.shape",dLoss_inputd_q.shape)

    dAttention_weights_masked=Attention_weights_masked*(1-Attention_weights_masked)
    V1=redimension(dAttention_weights_masked@Q_D/jnp.sqrt(dk)) 
    V2=redimension(V_D)
    V3=V1*V2
    dLoss_K_D=dLoss_Amask*V3
    dLoss_K_D.shape
    dLoss_inputd_k=dLoss_K_D@Kd
    #print("dLoss_inputd_k.shape",dLoss_inputd_k.shape)
    dLoss_inputd=dLoss_inputd_a+dLoss_inputd_k+dLoss_inputd_q+dLoss_inputd_v

    dLoss_dWemb_decoder=dLoss_inputd*input_d




    dAttention_weights_cross=Attention_weights_cross*(1-Attention_weights_cross)
    V1=redimension(dAttention_weights_cross@Q_C/jnp.sqrt(dk)) 

    V2=redimension(V_C)

    V3=V1*V2

    
    dLoss_K_C=dLoss_Acr*V3
    dLoss_K_C.shape

    dLoss_Ecout_k=dLoss_K_C@Kc
    #print("dLoss_Ecout_k.shape",dLoss_Ecout_k.shape) 

    dLoss_V_C=np.transpose(np.mean(np.transpose(np.expand_dims(dLoss_Acr, axis=1),(0,1,3,2))@Attention_weights_cross,axis=1),(0,2,1))
    dLoss_V_C.shape
    dLoss_Ecout_v=dLoss_V_C@Vc

    #print("dLoss_Ecout_v.shape",dLoss_Ecout_v.shape) # do the mean here over each phrase
    dLoss_Ecout=dLoss_Ecout_k+dLoss_Ecout_v

    dLoss_Ecout_k=dLoss_Kc*Kc 
    dLoss_dFLe2=dLoss_Ecout*diff_norm(Xe2,var_e2,mu_e2,N_e2)
    dLoss_Ect1_a=dLoss_dFLe2
    dLoss_dFLe1=dLoss_dFLe2@jnp.transpose(Wfl2e,(0,2,1))
    dLoss_dWfl2e=jnp.transpose(dLoss_dFLe2,(0,2,1))@FLe1
    dLoss_dbfl2e=jnp.sum(dLoss_dFLe2,axis=1)
    dLoss_dbfl2e.shape,bfl2e.reshape(bfl2e.shape[0],bfl2e.shape[-1]).shape
    if Xe1.all()>0:
        dLoss_Ect1_b=dLoss_dFLe1@jnp.transpose(Wfl1e,(0,2,1))
    else:
        dLoss_Ect1_b=0

    dLoss_Ect1=dLoss_Ect1_b+dLoss_Ect1_a
    if Xe1.all()>0:
        dLoss_Wfl1e=jnp.transpose(dLoss_dFLe1,(0,2,1))@Ect1
    else:
        dLoss_Wfl1e=0

    dLoss_Wfl1e.shape,Wfl1e.shape
    if Xe1.all()>0:
        dLoss_bfl1e=jnp.transpose(dLoss_dFLe1,(0,2,1)) 
    else:
        dLoss_bfl1e=0

    dLoss_bfl1e.shape,bfl1e.shape


    #################### Encoder BP
    dLoss_Ae=dLoss_Ect1*diff_norm(Xe,var_e,mu_e,Ne)
    dLoss_Ae.shape
    dLoss_inpute_a=dLoss_Ae
    dLoss_inpute_a.shape
    dLoss_dQe=diffQKV(dLoss_Ae,Attention_weights_e,K_E,V_E,inputs_e,dk)
    #print("dLoss_dQe.shape",dLoss_dQe.shape) # do the mean here over each phrase
    #print("Qe.shape",Qe.shape) 
    dLoss_dKe=diffQKV(dLoss_Ae,Attention_weights_e,Q_E,V_E,inputs_e,dk)
    #print("dLoss_dKe.shape",dLoss_dKe.shape) # do the mean here over each phrase
    #print("Ke.shape",Ke.shape) 
    dLoss_dVe=f=np.sum(np.sum(np.transpose(np.expand_dims(dLoss_Ae, axis=1),(0,1,3,2))@(Attention_weights_e@np.expand_dims(inputs_e, axis=1)),axis=1),axis=0)
    #print("dLoss_dVe.shape",dLoss_dVe.shape) # do the mean here over each phrase
    #print("Ve.shape",Ke.shape) 
    dLoss_V_E=np.transpose(np.sum(np.transpose(np.expand_dims(dLoss_Ae, axis=1),(0,1,3,2))@Attention_weights_e,axis=1),(0,2,1))
    dLoss_V_E.shape
    dLoss_inpute_v=dLoss_V_E@Ve

    #print("dLoss_inpute_v.shape",dLoss_inpute_v.shape)  
    dAttention_weights_e=Attention_weights_e*(1-Attention_weights_e)
    V1=redimension(dAttention_weights_e@K_E/jnp.sqrt(dk)) 

    V2=redimension(V_E)

    V3=V1*V2

    
    dLoss_Q_E=dLoss_Ae*V3
    dLoss_Q_E.shape

    dLoss_inpute_q=dLoss_Q_E@Qe
    #print("dLoss_inpute_q.shape",dLoss_inpute_q.shape)
    dAttention_weights_e=Attention_weights_e*(1-Attention_weights_e)
    V1=redimension(dAttention_weights_e@Q_E/jnp.sqrt(dk)) 

    V2=redimension(V_E)

    V3=V1*V2

    
    dLoss_K_E=dLoss_Ae*V3
    dLoss_K_E.shape

    dLoss_inpute_k=dLoss_K_E@Ke
    #print("dLoss_inpute_k.shape",dLoss_inpute_k.shape)
    dLoss_inpute=dLoss_inpute_a+dLoss_inpute_k+dLoss_inpute_q+dLoss_inpute_v
    dLoss_dWemb_encoder=dLoss_inpute*inputs_e

    W0=W0-learning_rate*dLoss_W0.T
    b0=b0-learning_rate*dLoss_b0
    Wfl2d=Wfl2d-learning_rate*dLoss_Wfl2d.T
    bfl2d=bfl2d-learning_rate*dLoss_bfl2d
    Wfl1d=Wfl1d-learning_rate*dLoss_Wfl1d.T
    bfl1d=bfl1d-learning_rate*dLoss_bfl1d
    Qc=Qc-learning_rate*dLoss_Qc
    Kc=Kc-learning_rate*dLoss_Kc
    Vc=Vc-learning_rate*dLoss_Vc
    Qd=Qd-learning_rate*dLoss_Qd
    Kd=Kd-learning_rate*dLoss_Kd
    Vd=Vd-learning_rate*dLoss_Vd
    input_d=input_d-learning_rate*dLoss_dWemb_decoder

    Wfl2e=Wfl2e-learning_rate*jnp.transpose(dLoss_dWfl2e ,(0,2,1))
    bfl2e=bfl2e-learning_rate*bfl2e
    Wfl1e=Wfl1e-learning_rate*jnp.transpose(dLoss_Wfl1e ,(0,2,1))
    bfl1e=bfl1e-learning_rate*bfl1e
    Qe=Qe-learning_rate*dLoss_dQe
    Ke=Ke-learning_rate*dLoss_dKe
    Ve=Ve-learning_rate*dLoss_dVe
    inputs_e=inputs_e-learning_rate*dLoss_dWemb_encoder

























inputs_d complete shape (15, 5, 15, 4)
Loss: 9.036615
Loss: 10.489328
Loss: 13.923627
Loss: 11.954871
Loss: 7.488725
Loss: 13.54043
Loss: 9.03979
Loss: 14.729125
Loss: 8.649913
Loss: 8.228803
Loss: 11.801648
Loss: 17.237494
Loss: 18.08123
Loss: 13.9557085
Loss: 11.357043


  dLoss_dVe=f=np.sum(np.sum(np.transpose(np.expand_dims(dLoss_Ae, axis=1),(0,1,3,2))@(Attention_weights_e@np.expand_dims(inputs_e, axis=1)),axis=1),axis=0)


## Update

In [26]:
learning_rate=0.01


In [38]:
W0=W0-learning_rate*dLoss_W0.T
b0=b0-learning_rate*dLoss_b0
Wfl2d=Wfl2d-learning_rate*dLoss_Wfl2d.T
bfl2d=bfl2d-learning_rate*dLoss_bfl2d
Wfl1d=Wfl1d-learning_rate*dLoss_Wfl1d.T
bfl1d=bfl1d-learning_rate*dLoss_bfl1d
Qc=Qc-learning_rate*dLoss_Qc
Kc=Kc-learning_rate*dLoss_Kc
Vc=Vc-learning_rate*dLoss_Vc
Qd=Qd-learning_rate*dLoss_Qd
Kd=Kd-learning_rate*dLoss_Kd
Vd=Vd-learning_rate*dLoss_Vd
input_d=input_d-learning_rate*dLoss_dWemb_decoder

Wfl2e=Wfl2e-learning_rate*jnp.transpose(dLoss_dWfl2e ,(0,2,1))
bfl2e=bfl2e-learning_rate*bfl2e
Wfl1e=Wfl1e-learning_rate*jnp.transpose(dLoss_Wfl1e ,(0,2,1))
bfl1e=bfl1e-learning_rate*bfl1e
Qe=Qe-learning_rate*dLoss_dQe
Ke=Ke-learning_rate*dLoss_dKe
Ve=Ve-learning_rate*dLoss_dVe
inputs_e=inputs_e-learning_rate*dLoss_dWemb_encoder