# Multi-Head Attention from Scratch Using CuPy

In this notebook, we implement the **Multi-Head Attention** mechanism from scratch using **CuPy**, a GPU-accelerated array library with a NumPy-compatible API. Multi-head attention is a key component of the Transformer model: it lets the model attend to different parts of the input simultaneously. The worked cells below prototype each step with NumPy and `jax.numpy` arrays.
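
Because CuPy mirrors much of NumPy's API, array code can often be written once against a module alias. The sketch below is an illustration under assumptions, not part of the original notebook: the alias `xp` and the fallback logic are hypothetical names and choices, and the block falls back to NumPy when no working CUDA device is found.

```python
import numpy as np

try:
    import cupy as cp
    cp.cuda.runtime.getDeviceCount()  # raises if no usable CUDA device/driver is present
    xp = cp
except Exception:
    xp = np  # assumption: fall back to NumPy on machines without a working GPU setup

def softmax(x, axis=-1):
    # Subtract the max value for numerical stability
    e_x = xp.exp(x - xp.max(x, axis=axis, keepdims=True))
    return e_x / xp.sum(e_x, axis=axis, keepdims=True)

scores = xp.asarray([[1.0, 2.0, 3.0]])
weights = softmax(scores)
print(type(weights).__module__, float(weights.sum()))
```

The same `softmax` body runs on either backend because it only uses calls that both libraries provide.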

---

## 1. Overview of Multi-Head Attention

Multi-head attention allows the model to have multiple attention "heads," each of which focuses on different parts of the input sequence. Each head computes its own attention values, and the results are concatenated and transformed into the final output.

The steps involved in multi-head attention:
- **Linear transformations**: Apply learned weight matrices to the queries (Q), keys (K), and values (V).
- **Scaled Dot-Product Attention**: For each head, compute the attention scores and apply them to the values.
- **Concatenation**: Concatenate the outputs from all heads.
- **Final Linear Transformation**: Apply a final linear transformation to the concatenated output.
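
The four steps listed above can be sketched in a few lines of NumPy. This is a minimal illustration under assumptions, not the notebook's implementation: `multi_head_attention`, `Wo`, and the random weights are hypothetical, and the scores are scaled by the per-head dimension.

```python
import numpy as np

def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

def multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads):
    """x: (batch, seq, d_model); all weight matrices: (d_model, d_model)."""
    batch, seq, d_model = x.shape
    d_head = d_model // num_heads

    def split(t):  # (batch, seq, d_model) -> (batch, heads, seq, d_head)
        return t.reshape(batch, seq, num_heads, d_head).transpose(0, 2, 1, 3)

    # 1. Linear transformations, then one slice of columns per head
    q, k, v = split(x @ Wq), split(x @ Wk), split(x @ Wv)
    # 2. Scaled dot-product attention for every head at once
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)
    heads = softmax(scores) @ v
    # 3. Concatenation of the head outputs along the feature axis
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)
    # 4. Final linear transformation
    return concat @ Wo

rng = np.random.default_rng(0)
x = rng.random((3, 7, 10))
Wq, Wk, Wv, Wo = (rng.random((10, 10)) / np.sqrt(10) for _ in range(4))
out = multi_head_attention(x, Wq, Wk, Wv, Wo, num_heads=5)
print(out.shape)  # (3, 7, 10)
```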

Mathematically, the output of the multi-head attention mechanism can be written as:

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O
$$

where each attention head is computed as:

$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$

---

## 2. CuPy Setup

Before we begin, make sure you have **CuPy** installed. You can install it via:

```bash
pip install cupy
```


In [161]:
import sys
import re
import pickle
import time
from pathlib import Path

import numpy as np
import pandas as pd
import cupy as cp
import jax
import jax.numpy as jnp
from tqdm import tqdm

np.set_printoptions(edgeitems=30, linewidth=100000, formatter=dict(float=lambda x: "%.3g" % x))

def softmax(x, axis=-1):
    # Subtract the max value for numerical stability
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

num_classes=2
word2vec_len = 10
num_phrases = 3
words_per_phrase = 7 
dk = dv = 10
 
num_heads=5
 
 
inputs = np.random.rand(num_phrases,words_per_phrase, word2vec_len)
target = softmax(np.random.rand(num_phrases,num_classes))

Q = np.random.rand(word2vec_len, dk) / jnp.sqrt(word2vec_len)
K = np.random.rand(word2vec_len, dk) / jnp.sqrt(word2vec_len)
V = np.random.rand(word2vec_len, dv) / jnp.sqrt(word2vec_len)
inputs.shape,Q.shape,K.shape,V.shape

((3, 7, 10), (10, 10), (10, 10), (10, 10))

### Consider an input of 3 phrases with 7 words each, where every word is represented by a vector of dimension 10. With `dk = dv = 10`, every word still has a 10-dimensional representation after attention.

In [162]:
inputs  # each input phrase is made of 7 words, each with an embedding of length 10

array([[[0.218, 0.563, 0.196, 0.275, 0.568, 0.155, 0.426, 0.0049, 0.0325, 0.463],
        [0.762, 0.183, 0.711, 0.974, 0.514, 0.945, 0.548, 0.534, 0.942, 0.313],
        [0.999, 0.963, 0.784, 0.742, 0.302, 0.408, 0.382, 0.0891, 0.324, 0.414],
        [0.35, 0.131, 0.818, 0.348, 0.508, 0.908, 0.41, 0.463, 0.907, 0.862],
        [0.224, 0.312, 0.908, 0.26, 0.129, 0.0262, 0.0881, 0.778, 0.285, 0.946],
        [0.611, 0.243, 0.21, 0.748, 0.684, 0.506, 0.42, 0.107, 0.7, 0.41],
        [0.552, 0.159, 0.151, 0.51, 0.923, 0.934, 0.609, 0.809, 0.26, 0.0536]],

       [[0.993, 0.179, 0.283, 0.982, 0.213, 0.898, 0.028, 0.295, 0.166, 0.795],
        [0.297, 0.493, 0.0927, 0.986, 0.307, 0.0541, 0.404, 0.766, 0.0191, 0.981],
        [0.24, 0.603, 0.928, 0.829, 0.719, 0.502, 0.742, 0.161, 0.611, 0.717],
        [0.26, 0.547, 0.817, 0.995, 0.393, 0.628, 0.731, 0.446, 0.959, 0.918],
        [0.641, 0.0695, 0.112, 0.784, 0.725, 0.718, 0.537, 0.0113, 0.916, 0.151],
        [0.885, 0.336, 0.727, 0.314, 0.

In [3]:
inputs.shape,Q.shape

((3, 7, 10), (10, 10))

In [4]:
jnp.matmul(inputs, Q) 

Array([[[0.742, 0.839, 1.04, 0.927, 0.653, 0.579, 0.895, 0.808, 0.76, 0.755],
        [0.661, 0.85, 1.2, 1.07, 0.849, 0.989, 0.633, 0.935, 0.788, 0.848],
        [0.767, 0.645, 0.993, 0.99, 0.681, 0.824, 0.681, 0.756, 0.646, 0.9],
        [1.02, 0.98, 1.34, 1.23, 1.12, 0.881, 1.03, 1.11, 0.822, 0.981],
        [0.44, 0.514, 0.836, 0.715, 0.603, 0.705, 0.245, 0.771, 0.615, 0.659],
        [0.563, 0.576, 0.978, 0.86, 0.807, 0.795, 0.698, 0.909, 0.614, 0.734],
        [0.683, 0.659, 0.841, 0.931, 0.723, 0.691, 0.644, 0.795, 0.641, 0.546]],

       [[0.664, 0.694, 0.943, 0.828, 0.772, 0.796, 0.384, 0.777, 0.553, 0.596],
        [0.724, 0.785, 1.18, 1.05, 0.86, 0.949, 0.803, 1.07, 0.771, 0.818],
        [0.928, 0.805, 1.13, 0.974, 0.959, 0.852, 0.681, 0.969, 0.696, 0.776],
        [0.943, 0.821, 1.2, 1.01, 0.958, 0.905, 0.731, 1.07, 0.738, 0.881],
        [0.987, 0.949, 1.42, 1.24, 0.987, 1.03, 0.98, 1.14, 0.911, 1.09],
        [0.464, 0.583, 0.907, 0.601, 0.468, 0.562, 0.527, 0.733, 0.618,

### With the number of attention heads fixed at `num_heads = 5`, each of the matrices Qval, Kval, Vval is split into 5 equal parts of 2 columns each. The result is a list of 5 elements; the first one holds the first two columns for every phrase, as shown below.

In [5]:
len(jnp.array_split(jnp.matmul(inputs, Q),num_heads,axis=2))

5

In [6]:
jnp.array_split(jnp.matmul(inputs, Q),num_heads,axis=2)[0]  # array_split returns a Python list of num_heads chunks of Qval (not an array); this is the first chunk

Array([[[0.742, 0.839],
        [0.661, 0.85],
        [0.767, 0.645],
        [1.02, 0.98],
        [0.44, 0.514],
        [0.563, 0.576],
        [0.683, 0.659]],

       [[0.664, 0.694],
        [0.724, 0.785],
        [0.928, 0.805],
        [0.943, 0.821],
        [0.987, 0.949],
        [0.464, 0.583],
        [0.818, 0.956]],

       [[0.626, 0.639],
        [0.374, 0.508],
        [0.775, 0.743],
        [0.871, 0.815],
        [0.456, 0.459],
        [0.792, 0.782],
        [0.753, 0.926]]], dtype=float32)

### We rearrange the array so that each phrase holds the list of its attention heads: 3 phrases, each containing 5 attention heads of size 7 (the number of words per phrase) by 2 (the slice of the embedding assigned to each head).

In [7]:
jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs, Q),num_heads,axis=2)), 0, 1).shape  # stack the chunks and swap axes to get (phrases, heads, words, columns per head)


(3, 5, 7, 2)
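
The split-stack-swap sequence can also be written as a single reshape followed by a transpose. The sketch below checks that both routes give the same array; `projected` is a hypothetical stand-in for `jnp.matmul(inputs, Q)`, and NumPy is used so the block runs on its own.

```python
import numpy as np

rng = np.random.default_rng(0)
num_phrases, words_per_phrase, d_model, num_heads = 3, 7, 10, 5
projected = rng.random((num_phrases, words_per_phrase, d_model))  # stands in for inputs @ Q

# Route used in the notebook: list of column chunks -> stacked array -> swap axes 0 and 1
via_split = np.swapaxes(np.array(np.array_split(projected, num_heads, axis=2)), 0, 1)

# Equivalent route: one reshape that exposes the head axis, then a transpose
d_head = d_model // num_heads
via_reshape = projected.reshape(
    num_phrases, words_per_phrase, num_heads, d_head
).transpose(0, 2, 1, 3)

print(via_split.shape, np.array_equal(via_split, via_reshape))  # (3, 5, 7, 2) True
```

The reshape route avoids building an intermediate Python list, which is a reasonable choice when the embedding size is an exact multiple of the number of heads.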

In [8]:
jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs, Q),num_heads,axis=2)), 0, 1)[0]  # heads of the first phrase; compare with the jnp.matmul(inputs, Q) output above

Array([[[0.742, 0.839],
        [0.661, 0.85],
        [0.767, 0.645],
        [1.02, 0.98],
        [0.44, 0.514],
        [0.563, 0.576],
        [0.683, 0.659]],

       [[1.04, 0.927],
        [1.2, 1.07],
        [0.993, 0.99],
        [1.34, 1.23],
        [0.836, 0.715],
        [0.978, 0.86],
        [0.841, 0.931]],

       [[0.653, 0.579],
        [0.849, 0.989],
        [0.681, 0.824],
        [1.12, 0.881],
        [0.603, 0.705],
        [0.807, 0.795],
        [0.723, 0.691]],

       [[0.895, 0.808],
        [0.633, 0.935],
        [0.681, 0.756],
        [1.03, 1.11],
        [0.245, 0.771],
        [0.698, 0.909],
        [0.644, 0.795]],

       [[0.76, 0.755],
        [0.788, 0.848],
        [0.646, 0.9],
        [0.822, 0.981],
        [0.615, 0.659],
        [0.614, 0.734],
        [0.641, 0.546]]], dtype=float32)

In [9]:
Qval = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs, Q),num_heads,axis=2)), 0, 1)
print("Qval.shape: ",Qval.shape)

Kval = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs, K),num_heads,axis=2)), 0, 1)
print("Kval.shape: ",Kval.shape)


Vval = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs,V),num_heads,axis=2)), 0, 1)
print("Vval.shape: ",Vval.shape)

Qval.shape:  (3, 5, 7, 2)
Kval.shape:  (3, 5, 7, 2)
Vval.shape:  (3, 5, 7, 2)


### To compute the attention weights, we first apply the formula

$$
\frac{QK^T}{\sqrt{d_k}}
$$
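
A tiny hand-checkable case of the scaled score formula is sketched below. The matrices `Q_head` and `K_head` are made up for illustration. Note that the notebook's cells divide by the square root of `dk = 10` (the full projection width), whereas the usual Transformer convention divides by the per-head dimension, which is what this sketch assumes.

```python
import numpy as np

Q_head = np.array([[1.0, 0.0],
                   [0.0, 1.0]])   # 2 query tokens, per-head dimension 2
K_head = np.array([[1.0, 0.0],
                   [1.0, 1.0]])   # 2 key tokens
d_head = Q_head.shape[-1]

# Q K^T is [[1, 1], [0, 1]]; dividing by sqrt(2) gives the scaled scores
scores = Q_head @ K_head.T / np.sqrt(d_head)
print(scores)
```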

In [10]:
Qval[0][0],Kval[0][0]

(Array([[0.742, 0.839],
        [0.661, 0.85],
        [0.767, 0.645],
        [1.02, 0.98],
        [0.44, 0.514],
        [0.563, 0.576],
        [0.683, 0.659]], dtype=float32),
 Array([[0.847, 0.728],
        [1.05, 0.979],
        [0.868, 0.849],
        [1.23, 1.22],
        [0.807, 0.543],
        [1, 0.869],
        [0.731, 0.734]], dtype=float32))

In [11]:
Qval[0][0],np.transpose(Kval, (0, 1, 3, 2))[0][0]

(Array([[0.742, 0.839],
        [0.661, 0.85],
        [0.767, 0.645],
        [1.02, 0.98],
        [0.44, 0.514],
        [0.563, 0.576],
        [0.683, 0.659]], dtype=float32),
 Array([[0.847, 1.05, 0.868, 1.23, 0.807, 1, 0.731],
        [0.728, 0.979, 0.849, 1.22, 0.543, 0.869, 0.734]], dtype=float32))

In [12]:
Qval[0][0]@jnp.transpose(Kval, (0, 1, 3, 2))[0][0]/ jnp.sqrt(dk)

Array([[0.392, 0.506, 0.429, 0.611, 0.333, 0.466, 0.366],
       [0.373, 0.483, 0.41, 0.584, 0.315, 0.443, 0.35],
       [0.354, 0.455, 0.384, 0.546, 0.307, 0.421, 0.327],
       [0.498, 0.642, 0.542, 0.772, 0.428, 0.592, 0.463],
       [0.236, 0.306, 0.259, 0.369, 0.201, 0.281, 0.221],
       [0.283, 0.365, 0.309, 0.44, 0.242, 0.337, 0.264],
       [0.335, 0.431, 0.364, 0.519, 0.288, 0.398, 0.311]], dtype=float32)

In [13]:
QKscaled = jnp.matmul(Qval, jnp.transpose(Kval, (0, 1, 3, 2))) / jnp.sqrt(dk)
 
QKscaled.shape

(3, 5, 7, 7)

In [14]:
QKscaled[0][0]

Array([[0.392, 0.506, 0.429, 0.611, 0.333, 0.466, 0.366],
       [0.373, 0.483, 0.41, 0.584, 0.315, 0.443, 0.35],
       [0.354, 0.455, 0.384, 0.546, 0.307, 0.421, 0.327],
       [0.498, 0.642, 0.542, 0.772, 0.428, 0.592, 0.463],
       [0.236, 0.306, 0.259, 0.369, 0.201, 0.281, 0.221],
       [0.283, 0.365, 0.309, 0.44, 0.242, 0.337, 0.264],
       [0.335, 0.431, 0.364, 0.519, 0.288, 0.398, 0.311]], dtype=float32)

In [15]:
Attention_weights = softmax(QKscaled)
print("Attention_weights shape:",Attention_weights.shape)

Attention_weights shape: (3, 5, 7, 7)


In [16]:
Attention = jnp.matmul(Attention_weights, Vval)
print("Attention shape:",Attention.shape)

Attention shape: (3, 5, 7, 2)
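
Each row of the attention weights sums to 1, so every output row is a weighted average of the value vectors. The sketch below uses made-up values (`V_head`, `flat_scores`) to show the limiting case: when all scores in a row are equal, the output is simply the mean of the value rows. This may help explain why the rows of the attention output look nearly identical when the scores are close to each other.

```python
import numpy as np

def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

V_head = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])    # 3 value vectors of dimension 2
flat_scores = np.zeros((3, 3))     # identical score for every query/key pair
weights = softmax(flat_scores)     # every entry equals 1/3
output = weights @ V_head          # every row equals the mean of V_head's rows

print(weights.sum(axis=-1))
print(output)
```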


In [17]:
Attention[0]

Array([[[0.959, 0.954],
        [0.958, 0.954],
        [0.956, 0.953],
        [0.963, 0.958],
        [0.952, 0.949],
        [0.954, 0.95],
        [0.956, 0.952]],

       [[0.736, 0.842],
        [0.738, 0.845],
        [0.736, 0.843],
        [0.739, 0.847],
        [0.734, 0.839],
        [0.736, 0.842],
        [0.735, 0.841]],

       [[0.996, 0.671],
        [0.998, 0.672],
        [0.997, 0.672],
        [0.999, 0.673],
        [0.996, 0.671],
        [0.997, 0.672],
        [0.997, 0.671]],

       [[0.586, 0.927],
        [0.585, 0.927],
        [0.585, 0.926],
        [0.588, 0.93],
        [0.583, 0.923],
        [0.585, 0.927],
        [0.585, 0.926]],

       [[0.678, 0.897],
        [0.679, 0.898],
        [0.678, 0.896],
        [0.679, 0.899],
        [0.678, 0.895],
        [0.678, 0.895],
        [0.677, 0.894]]], dtype=float32)

### To recover the original embedding size, we now concatenate the head outputs horizontally
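
The concatenation is the inverse of the head split. The sketch below compares the per-phrase loop with a vectorized transpose and reshape; `heads_out` is a hypothetical stand-in for the per-head attention output.

```python
import numpy as np

rng = np.random.default_rng(0)
heads_out = rng.random((3, 5, 7, 2))  # (phrases, heads, words, columns per head)

# Loop route: for each phrase, place the 5 head blocks of shape (7, 2) side by side
via_loop = np.array([np.concatenate(heads_out[i], axis=1)
                     for i in range(heads_out.shape[0])])

# Vectorized route: move the head axis next to the column axis, then merge them
via_reshape = heads_out.transpose(0, 2, 1, 3).reshape(3, 7, 10)

print(via_loop.shape, np.array_equal(via_loop, via_reshape))  # (3, 7, 10) True
```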

In [18]:
Attention=jnp.array([jnp.concatenate(Attention[i], axis=1) for i in range(num_phrases)])
Attention,Attention.shape,inputs.shape

(Array([[[0.959, 0.954, 0.736, 0.842, 0.996, 0.671, 0.586, 0.927, 0.678, 0.897],
         [0.958, 0.954, 0.738, 0.845, 0.998, 0.672, 0.585, 0.927, 0.679, 0.898],
         [0.956, 0.953, 0.736, 0.843, 0.997, 0.672, 0.585, 0.926, 0.678, 0.896],
         [0.963, 0.958, 0.739, 0.847, 0.999, 0.673, 0.588, 0.93, 0.679, 0.899],
         [0.952, 0.949, 0.734, 0.839, 0.996, 0.671, 0.583, 0.923, 0.678, 0.895],
         [0.954, 0.95, 0.736, 0.842, 0.997, 0.672, 0.585, 0.927, 0.678, 0.895],
         [0.956, 0.952, 0.735, 0.841, 0.997, 0.671, 0.585, 0.926, 0.677, 0.894]],
 
        [[1.02, 1.06, 0.787, 0.83, 1.09, 0.707, 0.661, 0.974, 0.779, 0.996],
         [1.02, 1.06, 0.789, 0.834, 1.09, 0.708, 0.664, 0.978, 0.781, 0.998],
         [1.02, 1.06, 0.789, 0.833, 1.09, 0.708, 0.663, 0.977, 0.78, 0.998],
         [1.02, 1.06, 0.789, 0.834, 1.09, 0.708, 0.663, 0.978, 0.781, 0.998],
         [1.02, 1.07, 0.791, 0.837, 1.09, 0.708, 0.664, 0.979, 0.782, 1],
         [1.02, 1.06, 0.786, 0.828, 1.08, 0.706,

In [19]:
linearlayer= np.random.rand(num_phrases,dv, word2vec_len)   
linear_bias = np.random.rand(num_phrases,1,word2vec_len)
linearlayer.shape,linear_bias.shape

((3, 10, 10), (3, 1, 10))

In [20]:
linear_bias

array([[[0.149, 0.965, 0.688, 0.605, 0.158, 0.303, 0.201, 0.959, 0.635, 0.277]],

       [[0.0505, 0.2, 0.491, 0.996, 0.123, 0.332, 0.941, 0.714, 0.573, 0.0514]],

       [[0.49, 0.273, 0.311, 0.749, 0.808, 0.119, 0.587, 0.472, 0.206, 0.71]]])

In [21]:
def layer_norm(x, epsilon=1e-6):
    # Calculate the mean and variance over the embedding axis
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    # Normalize the output
    x_norm = (x - mean) / jnp.sqrt(var + epsilon)
    return x_norm


#output_sublayer_one=layer_norm((Attention@linearlayer +linear_bias)+inputs)
output_sublayer_one=layer_norm(Attention+inputs)
output_sublayer_one

Array([[[1.37, -1.24, -0.374, 0.015, -0.713, 0.942, -1.28, 0.597, -0.849, 1.53],
        [1.66, 1.28, 1.02, -0.0108, -0.612, -1.48, -0.235, -1.21, -0.647, 0.237],
        [0.201, -0.88, 1.31, 1.53, -0.0492, -0.848, 0.789, -0.75, -1.74, 0.437],
        [0.406, -0.243, -1.89, 0.863, 0.165, -1.44, -0.0345, 1.41, -0.353, 1.12],
        [-0.502, 0.561, 1.41, -0.722, 1.15, -1, -1.54, 0.0923, 1.29, -0.735],
        [0.313, 1.08, -0.648, 1.51, 0.494, -1.31, -1.89, -0.363, 0.255, 0.561],
        [0.537, 0.639, 0.987, -0.541, -0.1, -1.06, -1.14, 1.39, -1.66, 0.955]],

       [[0.842, 1.43, 0.318, -1.15, 0.988, -1.57, 0.821, 0.0974, -1.14, -0.643],
        [1.23, 1.21, 0.241, 0.842, 0.565, -1.21, -1.94, -0.278, -0.769, 0.116],
        [-0.277, 1.09, -0.493, -0.871, 1.69, -0.454, 0.702, 0.704, -1.91, -0.187],
        [-0.864, 0.762, -0.412, 0.307, 1.79, -0.417, 0.0304, 1.34, -1.24, -1.3],
        [1.35, -0.467, 0.158, 0.316, 0.889, -0.273, -0.296, -1.86, -1.22, 1.4],
        [0.283, -0.315, -0.144

In [22]:
output_sublayer_one.shape

(3, 7, 10)
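
A quick way to sanity-check layer normalization is to confirm that each word vector ends up with mean close to 0 and variance close to 1. The sketch below also adds optional scale and shift parameters `gamma` and `beta`; these are an assumption borrowed from the standard formulation and are not used in the notebook.

```python
import numpy as np

def layer_norm(x, gamma=1.0, beta=0.0, epsilon=1e-6):
    # Normalize over the last (embedding) axis, then apply optional scale and shift
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

rng = np.random.default_rng(0)
x = rng.random((3, 7, 10))
y = layer_norm(x)

print(np.abs(y.mean(axis=-1)).max())             # close to 0
print(y.var(axis=-1).min(), y.var(axis=-1).max())  # both close to 1
```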

## Decoder

In [23]:
decoder_input_word_embedding_size=10
decoder_input_number_of_words_per_phrase=9
num_heads_decoder=5# dv=10 
dv_decoder=10
inputs_decoder = np.random.rand(num_phrases,decoder_input_number_of_words_per_phrase, decoder_input_word_embedding_size)# for the target language suppose the 
target_decoder = inputs_decoder

In [24]:
inputs_decoder 

array([[[0.519, 0.655, 0.631, 0.843, 0.515, 0.233, 0.721, 0.719, 0.754, 0.287],
        [0.298, 0.412, 0.777, 0.535, 0.0521, 0.413, 0.804, 0.803, 0.689, 0.6],
        [0.705, 0.191, 0.392, 0.681, 0.416, 0.42, 0.734, 0.14, 0.159, 0.259],
        [0.763, 0.423, 0.091, 0.569, 0.328, 0.198, 0.768, 0.434, 0.678, 0.899],
        [0.165, 0.115, 0.582, 0.133, 0.693, 0.453, 0.243, 0.962, 0.457, 0.185],
        [0.267, 0.0362, 0.547, 0.0591, 0.605, 0.792, 0.415, 0.299, 0.986, 0.6],
        [0.425, 0.335, 0.154, 0.462, 0.905, 0.116, 0.595, 0.893, 0.7, 0.437],
        [0.476, 0.295, 0.0113, 0.51, 0.949, 0.671, 0.707, 0.229, 0.652, 0.348],
        [0.31, 0.0342, 0.732, 0.503, 0.824, 0.907, 0.793, 0.277, 0.686, 0.418]],

       [[0.706, 0.76, 0.217, 0.297, 0.938, 0.124, 0.461, 0.0958, 0.584, 0.281],
        [0.842, 0.261, 0.798, 0.453, 0.259, 0.844, 0.217, 0.543, 0.392, 0.149],
        [0.922, 0.227, 0.287, 0.778, 0.393, 0.94, 0.199, 0.305, 0.82, 0.842],
        [0.818, 0.673, 0.232, 0.867, 0.196, 0

In [25]:
 
input_translation=[]
def pad_sequence(seq, max_len, pad_value=0):
    """Pad a sequence with a given value up to max_len."""
    current_len = seq.shape[0]
    pad_width = max_len - current_len
    if pad_width > 0:
        # Pad sequence with zeros (or any pad_value you provide)
        seq = jnp.pad(seq, ((0, pad_width), (0, 0)), mode='constant', constant_values=pad_value)
    return seq

# Example usage:
max_len = decoder_input_number_of_words_per_phrase  # Max sequence length for decoder input
 
for j in range(inputs_decoder.shape[0]):
    # Create padded sequences
    padded_sequences = [pad_sequence(inputs_decoder[j][0:i], max_len) for i in range(1, inputs_decoder.shape[1] + 1)]
    input_translation.append(padded_sequences)


# Convert to an array for batching
input_translation = jnp.array(input_translation)
 

In [26]:
input_translation.shape

(3, 9, 9, 10)
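
The zero-padded prefixes can also be produced without a Python loop by multiplying the sentence with a lower-triangular matrix. The sketch below compares both routes; `sentence` is a hypothetical stand-in for one decoder phrase.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, emb = 9, 10
sentence = rng.random((seq_len, emb))  # stands in for inputs_decoder[j]

# Loop route, mirroring pad_sequence: keep the first i words, pad the rest with zeros
loop_version = np.stack([
    np.pad(sentence[:i], ((0, seq_len - i), (0, 0)))
    for i in range(1, seq_len + 1)
])

# Vectorized route: row r of the lower-triangular matrix keeps words 0..r
keep = np.tril(np.ones((seq_len, seq_len)))
vectorized = keep[:, :, None] * sentence[None, :, :]

print(vectorized.shape, np.array_equal(loop_version, vectorized))  # (9, 9, 10) True
```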

In [27]:
input_translation[0]

Array([[[0.519, 0.655, 0.631, 0.843, 0.515, 0.233, 0.721, 0.719, 0.754, 0.287],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

       [[0.519, 0.655, 0.631, 0.843, 0.515, 0.233, 0.721, 0.719, 0.754, 0.287],
        [0.298, 0.412, 0.777, 0.535, 0.0521, 0.413, 0.804, 0.803, 0.689, 0.6],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

       [[0.519, 0.655, 0.631, 0.843, 0.515, 0.233, 0.721, 0.719, 0.754, 0.287],
        [0.298, 0.412, 0.777, 0.535, 0.0521, 0.413, 0.804, 0.803, 0.689, 0.6]

In [28]:
input_decoder=input_translation[0]

Q_decoder = np.random.rand(decoder_input_word_embedding_size, dv_decoder) / jnp.sqrt(decoder_input_word_embedding_size)
K_decoder = np.random.rand(decoder_input_word_embedding_size, dv_decoder) / jnp.sqrt(decoder_input_word_embedding_size)
V_decoder = np.random.rand(decoder_input_word_embedding_size, dv_decoder) / jnp.sqrt(decoder_input_word_embedding_size)

In [29]:
Q_decoder.shape

(10, 10)

In [30]:
Qval_decoder=input_decoder@Q_decoder
print("Qval.shape: ",Qval_decoder.shape)

Qval.shape:  (9, 9, 10)


In [31]:
Qval_decoder  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(input_decoder, Q_decoder),num_heads_decoder,axis=2)), 0, 1)
print("Qval.shape: ",Qval_decoder.shape)

Kval_decoder  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(input_decoder, K_decoder),num_heads_decoder,axis=2)), 0, 1)
print("Kval.shape: ",Kval_decoder.shape)


Vval_decoder  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(input_decoder,V_decoder),num_heads_decoder,axis=2)), 0, 1)
print("Vval.shape: ",Vval_decoder.shape)

Qval.shape:  (9, 5, 9, 2)
Kval.shape:  (9, 5, 9, 2)
Vval.shape:  (9, 5, 9, 2)


In [32]:
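# NOTE: jnp.triu without k=1 includes the main diagonal, so this term also adds -1e9 to each token's own position
QKscaled_decoder  = jnp.matmul(Qval_decoder, jnp.transpose(Kval_decoder, (0, 1, 3, 2))) / jnp.sqrt(dk) + jnp.triu(jnp.ones((9, 9)))* -1e9 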
# Step 1: Create a causal mask of shape (1, 1, 9, 9) to broadcast across heads and batch
mask = jnp.tril(jnp.ones((max_len, max_len)))  # (9, 9) lower triangular matrix
mask = mask.at[mask == 0].set(-jnp.inf)  # Set future tokens to -inf
mask = mask.at[mask == 1].set(0)  # Set allowed tokens to 0
mask = mask.reshape(1, 1, max_len, max_len)  # Reshape to (1, 1, 9, 9)

# Step 2: Apply mask to QKscaled_decoder (it will broadcast across batch and heads)
QKscaled_decoder = QKscaled_decoder + mask 
QKscaled_decoder[3]

Array([[[-1e+09, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0.445, -1e+09, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0.304, 0.282, -1e+09, -inf, -inf, -inf, -inf, -inf, -inf],
        [0.455, 0.418, 0.299, -1e+09, -inf, -inf, -inf, -inf, -inf],
        [0, 0, 0, 0, -1e+09, -inf, -inf, -inf, -inf],
        [0, 0, 0, 0, 0, -1e+09, -inf, -inf, -inf],
        [0, 0, 0, 0, 0, 0, -1e+09, -inf, -inf],
        [0, 0, 0, 0, 0, 0, 0, -1e+09, -inf],
        [0, 0, 0, 0, 0, 0, 0, 0, -1e+09]],

       [[-1e+09, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0.533, -1e+09, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0.352, 0.336, -1e+09, -inf, -inf, -inf, -inf, -inf, -inf],
        [0.399, 0.383, 0.323, -1e+09, -inf, -inf, -inf, -inf, -inf],
        [0, 0, 0, 0, -1e+09, -inf, -inf, -inf, -inf],
        [0, 0, 0, 0, 0, -1e+09, -inf, -inf, -inf],
        [0, 0, 0, 0, 0, 0, -1e+09, -inf, -inf],
        [0, 0, 0, 0, 0, 0, 0, -1e+09, -inf],
        [0, 0, 0, 0,

In [33]:
Attention_weights = softmax(QKscaled_decoder)
Attention_weights.shape

(9, 5, 9, 9)
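
The effect of the mask is easiest to see with flat scores. The sketch below, written in NumPy with hypothetical names, builds an additive causal mask that leaves the diagonal open, so each token can attend to itself and to earlier tokens. It then adds an upper-triangular `-1e9` term that includes the diagonal, as in the cell above, to show how that shifts attention away from the token's own position.

```python
import numpy as np

def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

seq_len = 4
scores = np.zeros((seq_len, seq_len))  # flat scores make the mask's effect easy to read

# Additive causal mask: 0 on and below the diagonal, -inf strictly above it
allowed = np.tril(np.ones((seq_len, seq_len), dtype=bool))
causal_mask = np.where(allowed, 0.0, -np.inf)
weights = softmax(scores + causal_mask)

# Adding triu(ones) * -1e9 (no k=1) also penalizes the diagonal itself
diagonal_penalty = np.triu(np.ones((seq_len, seq_len))) * -1e9
weights_diag_masked = softmax(scores + diagonal_penalty + causal_mask)

print(weights[1])              # token 1 splits attention between tokens 0 and 1
print(weights_diag_masked[1])  # token 1 attends only to token 0
```

Passing `k=1` to `triu` starts the masked region one step above the diagonal, which matches the first mask.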

In [34]:
Vval_decoder

Array([[[[0.939, 1.07],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0]],

        [[0.929, 1.18],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0]],

        [[1.08, 0.805],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0]],

        [[0.825, 0.794],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0]],

        [[0.681, 1.27],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0]]],


       [[[0.939, 1.07],
         [0.859, 1.08],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0],
         [0, 0]],

        [[0.929, 1

In [35]:
Attention = jnp.matmul(Attention_weights, Vval_decoder)

In [36]:
Attention.shape

(9, 5, 9, 2)

In [37]:
Attention[0]

Array([[[0.939, 1.07],
        [0.939, 1.07],
        [0.47, 0.534],
        [0.313, 0.356],
        [0.235, 0.267],
        [0.188, 0.213],
        [0.157, 0.178],
        [0.134, 0.152],
        [0.117, 0.133]],

       [[0.929, 1.18],
        [0.929, 1.18],
        [0.464, 0.59],
        [0.31, 0.393],
        [0.232, 0.295],
        [0.186, 0.236],
        [0.155, 0.197],
        [0.133, 0.169],
        [0.116, 0.147]],

       [[1.08, 0.805],
        [1.08, 0.805],
        [0.541, 0.402],
        [0.361, 0.268],
        [0.27, 0.201],
        [0.216, 0.161],
        [0.18, 0.134],
        [0.155, 0.115],
        [0.135, 0.101]],

       [[0.825, 0.794],
        [0.825, 0.794],
        [0.412, 0.397],
        [0.275, 0.265],
        [0.206, 0.199],
        [0.165, 0.159],
        [0.137, 0.132],
        [0.118, 0.113],
        [0.103, 0.0993]],

       [[0.681, 1.27],
        [0.681, 1.27],
        [0.34, 0.634],
        [0.227, 0.423],
        [0.17, 0.317],
        [0.136, 0.254]

In [38]:
Attention=jnp.array([jnp.concatenate(Attention[i], axis=1) for i in range(9)])
Attention.shape,input_decoder.shape

((9, 9, 10), (9, 9, 10))

In [39]:
Attention[1]

Array([[0.939, 1.07, 0.929, 1.18, 1.08, 0.805, 0.825, 0.794, 0.681, 1.27],
       [0.939, 1.07, 0.929, 1.18, 1.08, 0.805, 0.825, 0.794, 0.681, 1.27],
       [0.899, 1.07, 0.912, 1.09, 1.07, 0.861, 0.785, 0.703, 0.645, 1.25],
       [0.599, 0.716, 0.608, 0.729, 0.715, 0.574, 0.523, 0.468, 0.43, 0.833],
       [0.45, 0.537, 0.456, 0.547, 0.536, 0.431, 0.392, 0.351, 0.322, 0.625],
       [0.36, 0.43, 0.365, 0.437, 0.429, 0.345, 0.314, 0.281, 0.258, 0.5],
       [0.3, 0.358, 0.304, 0.365, 0.358, 0.287, 0.262, 0.234, 0.215, 0.416],
       [0.257, 0.307, 0.261, 0.312, 0.307, 0.246, 0.224, 0.201, 0.184, 0.357],
       [0.225, 0.269, 0.228, 0.273, 0.268, 0.215, 0.196, 0.176, 0.161, 0.312]], dtype=float32)

In [40]:
def layer_norm(x, epsilon=1e-6):
    # Calculate the mean and variance over the embedding axis
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    # Normalize the output
    x_norm = (x - mean) / jnp.sqrt(var + epsilon)
    return x_norm
residual_output = layer_norm(input_decoder + Attention)
residual_output.shape

(9, 9, 10)

In [178]:
phlenght=6
input_t=[34,55,67,27,45,78]
embedding_l=5
A=np.random.rand(len(input_t), len(input_t),embedding_l)
ins=input_t@A
ins

array([[198, 83.1, 140, 155, 193],
       [122, 140, 196, 197, 120],
       [156, 215, 188, 129, 183],
       [133, 180, 203, 148, 146],
       [143, 203, 119, 201, 205],
       [204, 89.6, 88.3, 170, 116]])

In [179]:
ins.shape

(6, 5)

# Summary Encoder

In [22]:
import numpy as np
import jax.numpy as jnp
def softmax(x, axis=-1):
    # Subtract the max value for numerical stability
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)
def layer_norm(x, epsilon=1e-6):
    # Calculate the mean and variance over the embedding axis
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)

    # Normalize the output
    x_norm = (x - mean) / jnp.sqrt(var + epsilon)
    # Also return the statistics and the embedding size, which the backward pass needs
    return x_norm,mean,var,x.shape[-1]
def relu(x):
    return np.maximum(0, x)
 

num_phrases = 1
words_per_phrase = 3 
dk = dv = word2vec_len = 4  # Transformer constraint: the projection width equals the embedding size, so the residual sums are well-defined
 
num_heads=2
 
 
inputs_encoder = np.random.rand(num_phrases,words_per_phrase, word2vec_len)
print("inputs.shape: ",inputs_encoder.shape)

Qe = np.random.rand(word2vec_len, dk) / jnp.sqrt(word2vec_len)
Ke = np.random.rand(word2vec_len, dk) / jnp.sqrt(word2vec_len)
Ve = np.random.rand(word2vec_len, dv) / jnp.sqrt(word2vec_len)

Q_E= jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_encoder, Qe),num_heads,axis=2)), 0, 1)
print("Qval.shape: ",Q_E.shape)

K_E = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_encoder, Ke),num_heads,axis=2)), 0, 1)
print("Kval.shape: ",K_E.shape)


V_E = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_encoder,Ve),num_heads,axis=2)), 0, 1)
print("Vval.shape: ",V_E.shape)


QKscaled = jnp.matmul(Q_E, jnp.transpose(K_E, (0, 1, 3, 2))) / jnp.sqrt(dk)

Attention_weights = softmax(QKscaled)
print("Attention_weights shape:",Attention_weights.shape)


Attention_E = jnp.matmul(Attention_weights, V_E)
print("Attention shape:",Attention_E.shape)


Attention_E=jnp.array([jnp.concatenate(Attention_E[i], axis=1) for i in range(num_phrases)])
print("Attention shape concat:",Attention_E.shape)


 


Xe=Attention_E+inputs_encoder
Ect1,mu_e,var_e,Ne=layer_norm(Xe)
print("Ect1.shape",Ect1.shape,Ne)

fl1_size=100
Wfl1e=np.random.rand(num_phrases,dv, dv)   
bfl1e=np.random.rand(num_phrases,1,dv)
Xe1=jnp.matmul(Ect1,Wfl1e)+bfl1e
print("Xe1.shape",Xe1.shape)

FLe1=relu(Xe1)
print("FLe1.shape",FLe1.shape)


fl2_size=50
Wfl2e=np.random.rand(num_phrases,FLe1.shape[2], dv)   
bfl2e=np.random.rand(num_phrases,1,dv)
FLe2=jnp.matmul(FLe1,Wfl2e)+bfl2e
print("FLe2.shape",FLe2.shape)

Xe2=FLe2+Ect1  # residual connection around the feed-forward block, whose input is Ect1
Ecout,mu_e2,var_e2,Ne2=layer_norm(Xe2)
print("Ecout.shape",Ecout.shape,Ne2)

inputs.shape:  (1, 3, 4)
Qval.shape:  (1, 2, 3, 2)
Kval.shape:  (1, 2, 3, 2)
Vval.shape:  (1, 2, 3, 2)
Attention_weights shape: (1, 2, 3, 3)
Attention shape: (1, 2, 3, 2)
Attention shape concat: (1, 3, 4)
Ect1.shape (1, 3, 4) 4
Xe1.shape (1, 3, 4)
FLe1.shape (1, 3, 4)
FLe2.shape (1, 3, 4)
Ecout.shape (1, 3, 4) 4
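
The feed-forward block can also be written with weights shared across phrases and a wider hidden layer. The sketch below is an assumption-laden variant, not the notebook's code: `feed_forward` is a hypothetical helper, the hidden width of 100 mirrors the unused `fl1_size`, and the residual connection adds the block's own input.

```python
import numpy as np

def relu(x):
    return np.maximum(0, x)

def layer_norm(x, epsilon=1e-6):
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

def feed_forward(x, W1, b1, W2, b2):
    """Position-wise feed-forward block: the same weights are applied to every word."""
    return relu(x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
d_model, d_hidden = 4, 100                 # d_hidden mirrors fl1_size (assumption)
x = rng.random((1, 3, d_model))            # stands in for Ect1
W1 = rng.random((d_model, d_hidden)) / np.sqrt(d_model)
b1 = np.zeros(d_hidden)
W2 = rng.random((d_hidden, d_model)) / np.sqrt(d_hidden)
b2 = np.zeros(d_model)

# Residual connection adds the block's input, then layer normalization follows
out = layer_norm(x + feed_forward(x, W1, b1, W2, b2))
print(out.shape)  # (1, 3, 4)
```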


In [20]:
np.mean([0.78880733, 0.3270942 , 0.27765104 ,0.9192637 ])

0.5782040675

# Cross Attention

In [6]:
dv_cross=dv
num_heads_cross=num_heads
#Qc = np.random.rand(Xe2.shape[-1], dv_cross) / jnp.sqrt(Xe2.shape[-1])
Kc = np.random.rand(Ecout.shape[-1], dv_cross) / jnp.sqrt(Ecout.shape[-1])
Vc = np.random.rand(Ecout.shape[-1], dv_cross) / jnp.sqrt(Ecout.shape[-1])



K_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Ecout, Kc),num_heads_cross,axis=2)), 0, 1)  # keys come from the encoder output
print("Kval.shape: ",K_C.shape)


V_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Ecout,Vc),num_heads_cross,axis=2)), 0, 1)  # values come from the encoder output
print("Vval.shape: ",V_C.shape)

K_C = K_C[0]  # Use the first phrase from the encoder output
V_C = V_C[0]  # Use the first phrase from the encoder output



Kval.shape:  (1, 2, 3, 2)
Vval.shape:  (1, 2, 3, 2)


# Summary Decoder

In [24]:
decoder_input_word_embedding_size=word2vec_len
decoder_input_number_of_words_per_phrase=7
num_heads_decoder=2  # dv_decoder = dv = 4, so each head gets 2 columns
dv_decoder=dv     
input_d = np.random.rand(num_phrases,decoder_input_number_of_words_per_phrase, decoder_input_word_embedding_size)# for the target language suppose the 
target_decoder = input_d

input_translation=[]
def pad_sequence(seq, max_len, pad_value=0):
    """Pad a sequence with a given value up to max_len."""
    current_len = seq.shape[0]
    pad_width = max_len - current_len
    if pad_width > 0:
        # Pad sequence with zeros (or any pad_value you provide)
        seq = jnp.pad(seq, ((0, pad_width), (0, 0)), mode='constant', constant_values=pad_value)
    return seq

# Example usage:
max_len = decoder_input_number_of_words_per_phrase  # Max sequence length for decoder input
 
for j in range(input_d.shape[0]):
    # Create padded sequences
    padded_sequences = [pad_sequence(input_d[j][0:i], max_len) for i in range(1, input_d.shape[1] + 1)]
    input_translation.append(padded_sequences)


# Convert to an array for batching
input_translation = jnp.array(input_translation)

inputs_d=input_translation[0]

Qd = np.random.rand(decoder_input_word_embedding_size, dv_decoder) / jnp.sqrt(decoder_input_word_embedding_size)
Kd = np.random.rand(decoder_input_word_embedding_size, dv_decoder) / jnp.sqrt(decoder_input_word_embedding_size)
Vd = np.random.rand(decoder_input_word_embedding_size, dv_decoder) / jnp.sqrt(decoder_input_word_embedding_size)


Q_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d, Qd),num_heads_decoder,axis=2)), 0, 1)
print("Qval.shape: ",Q_D.shape)

K_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d, Kd),num_heads_decoder,axis=2)), 0, 1)
print("Kval.shape: ",K_D.shape)


V_D  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(inputs_d,Vd),num_heads_decoder,axis=2)), 0, 1)
print("Vval.shape: ",V_D.shape)


# NOTE: jnp.triu without k=1 includes the main diagonal, so this term also adds -1e9 to each token's own position
QKscaled_decoder  = jnp.matmul(Q_D, jnp.transpose(K_D, (0, 1, 3, 2))) / jnp.sqrt(dk) + jnp.triu(jnp.ones((decoder_input_number_of_words_per_phrase, decoder_input_number_of_words_per_phrase)))* -1e9 
# Step 1: Create a causal mask of shape (1, 1, max_len, max_len) to broadcast across heads and batch
mask = jnp.tril(jnp.ones((max_len, max_len)))  # (max_len, max_len) lower triangular matrix, here (7, 7)
mask = mask.at[mask == 0].set(-jnp.inf)  # Set future tokens to -inf
mask = mask.at[mask == 1].set(0)  # Set allowed tokens to 0
mask = mask.reshape(1, 1, max_len, max_len)  # Reshape to (1, 1, 7, 7)

# Step 2: Apply mask to QKscaled_decoder (it will broadcast across batch and heads)
QKscaled_decoder = QKscaled_decoder + mask 

Attention_weights = softmax(QKscaled_decoder)


A_mask = jnp.matmul(Attention_weights, V_D)


A_mask=jnp.array([jnp.concatenate(A_mask[i], axis=1) for i in range(A_mask.shape[0])])  # concatenate the heads of every prefix sequence, not only the first



Xd = inputs_d + A_mask
Dt1,mu_d,var_d,N_d = layer_norm(Xd)
print("Dt1.shape",Dt1.shape)

Qc = np.random.rand(Dt1.shape[-1], dv_cross) / jnp.sqrt(Dt1.shape[-1])
Q_C  = jnp.swapaxes(jnp.array(jnp.array_split(jnp.matmul(Dt1, Qc),num_heads_decoder,axis=2)), 0, 1)
print("Qval.shape: ",Q_C.shape)
 

Qval.shape:  (7, 2, 7, 2)
Kval.shape:  (7, 2, 7, 2)
Vval.shape:  (7, 2, 7, 2)
Dt1.shape (7, 7, 4)
Qval.shape:  (7, 2, 7, 2)


# Cross Attention

In [33]:
QKscaled_cross_attention  = jnp.matmul(Q_C, jnp.transpose(jnp.expand_dims(K_C, axis=0) , (0, 1, 3, 2))) / jnp.sqrt(dv_decoder)
Attention_weights_cross = softmax(QKscaled_cross_attention)
Acr = jnp.matmul(Attention_weights_cross, jnp.expand_dims(V_C, axis=0))
Acr=jnp.array([jnp.concatenate(Acr[i], axis=1) for i in range(decoder_input_number_of_words_per_phrase)]) 
Res=Acr + Dt1
Dt2, mu_res,var_res,N_res = layer_norm(Res)  # Dt2 has shape (7, 7, 4)
print("Dt2 shape:", Dt2.shape)

 
Wfl1d=np.random.rand(num_phrases,dv, dv)   
bfl1d=np.random.rand(num_phrases,1,dv)
Xd1=jnp.matmul(Dt2,Wfl1d)+bfl1d
print("Xd1.shape",Xd1.shape)

 
FLd1=relu(Xd1)
print("FLe1.shape",FLd1.shape)


 
Wfl2d=np.random.rand(num_phrases,FLd1.shape[2], dv)   
bfl2d=np.random.rand(num_phrases,1,dv)
FLd2=jnp.matmul(FLd1,Wfl2d)+bfl2d
print("FLd2.shape",FLd2.shape)

Xd2=FLd2+Dt2
Dout,mu_d2,var_d2,N_d2=layer_norm(Xd2)
Dout.shape
print("Xd2.shape",Dout.shape)


W0=np.random.rand(num_phrases,dv, dv)   
b0=np.random.rand(num_phrases,1,dv)
Zout=jnp.matmul(Dout,W0)+b0
print("Zout.shape",Zout.shape)
SigmaZout = softmax(Zout) 
print("SigmaZout.shape",SigmaZout.shape)
SigmaZout[0]

Dt2 shape: (7, 7, 4)
Xd1.shape (7, 7, 4)
FLe1.shape (7, 7, 4)
FLd2.shape (7, 7, 4)
Xd2.shape (7, 7, 4)
Zout.shape (7, 7, 4)
SigmaZout.shape (7, 7, 4)


array([[0.369694  , 0.12853794, 0.27173984, 0.23002832],
       [0.34361538, 0.20036109, 0.30039287, 0.15563065],
       [0.343619  , 0.20036043, 0.30038974, 0.1556309 ],
       [0.34362486, 0.20035928, 0.3003845 , 0.15563135],
       [0.34363323, 0.20035768, 0.30037713, 0.15563194],
       [0.3436439 , 0.20035566, 0.30036774, 0.1556327 ],
       [0.3436569 , 0.20035319, 0.30035624, 0.15563364]], dtype=float32)
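
The cross-attention step can be condensed into one function in which queries come from the decoder and keys and values come from the encoder. The sketch below is illustrative only: `cross_attention`, `split_heads`, and the random inputs are hypothetical, and the scores are scaled by the per-head dimension. Broadcasting lets one encoder phrase serve all 7 decoder prefixes.

```python
import numpy as np

def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

def split_heads(t, num_heads):
    b, s, d = t.shape  # (batch, seq, d_model) -> (batch, heads, seq, d_head)
    return t.reshape(b, s, num_heads, d // num_heads).transpose(0, 2, 1, 3)

def cross_attention(dec, enc, Wq, Wk, Wv, num_heads):
    q = split_heads(dec @ Wq, num_heads)   # queries from the decoder state
    k = split_heads(enc @ Wk, num_heads)   # keys from the encoder output
    v = split_heads(enc @ Wv, num_heads)   # values from the encoder output
    weights = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1]))
    heads = weights @ v                    # encoder batch of 1 broadcasts over decoder batch
    b, h, s, dh = heads.shape
    return heads.transpose(0, 2, 1, 3).reshape(b, s, h * dh), weights

rng = np.random.default_rng(0)
d_model = 4
dec = rng.random((7, 7, d_model))   # 7 decoder prefixes of 7 positions each
enc = rng.random((1, 3, d_model))   # 1 encoder phrase of 3 words
Wq, Wk, Wv = (rng.random((d_model, d_model)) / np.sqrt(d_model) for _ in range(3))

out, weights = cross_attention(dec, enc, Wq, Wk, Wv, num_heads=2)
print(out.shape, weights.shape)  # (7, 7, 4) (7, 2, 7, 3)
```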

In [26]:
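def cross_entropy_loss(predictions, target):
    # Cross-entropy loss for a batch of predictions and targets.
    # axis=1 is the word-position axis for (batch, words, classes) inputs such as SigmaZout;
    # the class axis of such inputs is the last one (axis=-1).
    batch_loss = -jnp.sum(target * jnp.log(predictions + 1e-9), axis=1)
    return jnp.mean(batch_loss)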
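
For comparison, the sketch below computes cross-entropy along the class axis with one-hot targets, which is the usual setup for token prediction. It is a NumPy illustration under assumptions: the probabilities and token ids are made up, and the function name is reused only for readability.

```python
import numpy as np

def cross_entropy_loss(predictions, targets, eps=1e-9):
    # Sum over the class axis (last axis), then average over all tokens
    per_token = -np.sum(targets * np.log(predictions + eps), axis=-1)
    return per_token.mean()

predictions = np.array([[[0.7, 0.2, 0.1],
                         [0.1, 0.8, 0.1]]])  # (batch=1, words=2, classes=3)
token_ids = np.array([[0, 1]])               # correct class for each word
targets = np.eye(3)[token_ids]               # one-hot targets, shape (1, 2, 3)

loss = cross_entropy_loss(predictions, targets)
print(loss)  # mean of -log(0.7) and -log(0.8)
```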

In [27]:
cross_entropy_loss(SigmaZout, target_decoder)

Array(6.692228, dtype=float32)

In [28]:
Dout.shape,W0.shape,b0.shape

((7, 7, 4), (1, 4, 4), (1, 1, 4))

In [29]:
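dLoss_dZout=SigmaZout-target_decoder  # softmax + cross-entropy gradient with respect to the logits
print("dLoss_dZout.shape",dLoss_dZout.shape)
dLoss_W0=jnp.transpose(Dout,(0,2,1))@dLoss_dZout  # Zout = Dout @ W0 + b0, so dW0 = Dout^T @ dZout (one matrix per prefix sequence)
print("dLoss_W0.shape",dLoss_W0.shape)
dLoss_b0=dLoss_dZout
print("dLoss_b0.shape",dLoss_b0.shape)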

dLoss_dZout.shape (7, 7, 4)
dLoss_W0.shape (7, 4, 4)
dLoss_b0.shape (7, 7, 4)
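
The output-layer gradients can be verified numerically. The sketch below uses a single sequence with one-hot targets and a loss summed over words; all names and data are hypothetical. The shortcut `P - Y` for the logit gradient holds when each target row sums to 1.

```python
import numpy as np

def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / np.sum(e_x, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
X = rng.random((7, 4))                        # one sequence: 7 words, 4 features
W = rng.random((4, 4))
b = rng.random(4)
Y = np.eye(4)[rng.integers(0, 4, size=7)]     # one-hot targets

def loss_fn(W_, b_):
    P_ = softmax(X @ W_ + b_)
    return -np.sum(Y * np.log(P_))            # cross-entropy summed over words

P = softmax(X @ W + b)
dZ = P - Y               # gradient with respect to the logits
dW = X.T @ dZ            # gradient with respect to the weights
db = dZ.sum(axis=0)      # gradient with respect to the bias: sum over words

# Central finite difference on one weight entry
h = 1e-6
W_plus, W_minus = W.copy(), W.copy()
W_plus[1, 2] += h
W_minus[1, 2] -= h
numeric = (loss_fn(W_plus, b) - loss_fn(W_minus, b)) / (2 * h)
print(numeric, dW[1, 2])
```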


In [34]:
epsilon=1e-6
dLoss_Dout=dLoss_dZout@jnp.transpose(W0,(0,2,1))  # Zout = Dout @ W0 + b0, so the gradient flows back through W0 transposed
# Element-wise product with the diagonal term of the layer-norm derivative, d x_norm_i / d x_i
dLoss_FLd2=dLoss_Dout*((1-(1/N_d2))*(1/(jnp.sqrt(var_d2+epsilon)))-(1/N_d2)*(((Xd2-mu_d2)**2)/((var_d2+epsilon)**(3/2))))
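
The backward pass through layer normalization has to account for the fact that every input in a row affects the row's mean and variance, so the full gradient contains cross terms in addition to the diagonal one. The sketch below, a NumPy illustration with hypothetical names and data, implements the full row-wise gradient and checks one entry against a central finite difference.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def layer_norm_backward(g, x, eps=1e-6):
    """g is dLoss/dy for y = layer_norm(x); returns dLoss/dx."""
    mean = x.mean(axis=-1, keepdims=True)
    sigma = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    y = (x - mean) / sigma
    # dL/dx = (g - mean(g) - y * mean(g * y)) / sigma, with means over the last axis
    return (g - g.mean(axis=-1, keepdims=True)
            - y * (g * y).mean(axis=-1, keepdims=True)) / sigma

x = np.array([[0.2, 1.5, -0.7, 0.9],
              [1.0, -1.0, 0.5, 2.0]])
g = np.array([[0.3, -0.2, 0.5, 0.1],
              [1.0, 0.4, -0.6, 0.2]])   # upstream gradient dLoss/dy

analytic = layer_norm_backward(g, x)

# Central finite difference on entry (0, 1) of x, using loss = sum(g * y)
h = 1e-6
x_plus, x_minus = x.copy(), x.copy()
x_plus[0, 1] += h
x_minus[0, 1] -= h
numeric = (np.sum(g * layer_norm(x_plus)) - np.sum(g * layer_norm(x_minus))) / (2 * h)
print(numeric, analytic[0, 1])
```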